# In [None]:
##分别转化为灰度图和热图两种格式 (convert each sample into two formats: grayscale images and heatmaps)
import os

import cv2
import datatable as dt
import feather
import matplotlib.cm as cm
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import ticker
from PIL import Image
from pyDeepInsight import ImageTransformer, LogScaler
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from umap import UMAP

## Read the preprocessed gene-expression matrix
# The first row is skipped and no header is inferred, so columns are labelled
# 0..n-1.
# pd.read_csv cannot parse an .xlsx workbook (a zipped XML container, not
# delimited text), so Excel files go through pd.read_excel instead. That needs
# an Excel engine installed (e.g. openpyxl for .xlsx).
EXPRESSION_PATH = "/home/WuHX/gbx/Study.xlsx"
_read_matrix = (
    pd.read_excel
    if EXPRESSION_PATH.lower().endswith((".xlsx", ".xls"))
    else pd.read_csv
)
features_norm = _read_matrix(EXPRESSION_PATH, header=None, skiprows=[0])
# Alias (not a copy) used by the cells below.
positive_norm = features_norm

## Inspect the overall feature distribution with UMAP
# UMAP (cosine distance, 2 components) is used as the feature extractor that
# ImageTransformer relies on to place features on a 32x32 pixel grid.
umap_model = UMAP(
    n_components=2,
    n_neighbors=50,
    min_dist=0.5,
    metric='cosine',
)

# Standardise every column before embedding.
scaler = StandardScaler()
positive_norm_scaled = scaler.fit_transform(positive_norm)

# NOTE(review): nothing is plotted here, and `it` / `positive_img` are
# reassigned by later cells, so this result is not what gets exported --
# confirm whether this cell is still needed.
it = ImageTransformer(pixels=32, feature_extractor=umap_model)
positive_img = it.fit_transform(positive_norm_scaled)


## Inspect the overall feature distribution with t-SNE
# Both the embedding and the transformer are seeded with 1701 so that the
# feature layout is reproducible between runs.
tsne = TSNE(
    n_components=2,
    metric='cosine',
    perplexity=30,
    n_jobs=-1,
    random_state=1701,
)

it = ImageTransformer(
    feature_extractor=tsne,
    pixels=50,
    n_jobs=-1,
    random_state=1701,
)

# plot=True asks ImageTransformer to draw the fitted layout on this canvas.
plt.figure(figsize=(10, 10))
_ = it.fit(positive_norm, plot=True)


## Heatmap of how many features fall into each pixel
fdm = it.feature_density_matrix()
# Empty pixels become NaN so seaborn leaves them blank instead of colouring them.
fdm[fdm == 0] = np.nan

plt.figure(figsize=(5, 5))

ax = sns.heatmap(
    fdm,
    cmap="viridis",
    square=True,
    linewidths=0.01,
    linecolor="lightgrey",
)
# Major ticks every 5 pixels on both axes.
for axis in (ax.xaxis, ax.yaxis):
    axis.set_major_locator(ticker.MultipleLocator(5))
# Draw a frame around the heatmap.
for spine in ax.spines.values():
    spine.set_visible(True)
_ = plt.title("Genes per pixel")


## Compare feature-density maps at several grid resolutions
# Each entry is a pixel-grid size. A plain int presumably means a square grid;
# the tuple is (rows, cols).
px_sizes = [25, (25, 50), 50, 100]

fig, ax = plt.subplots(1, len(px_sizes), figsize=(25, 7))
for panel, px in zip(ax, px_sizes):
    it.pixels = px
    fdm = it.feature_density_matrix()
    # Empty pixels become NaN so seaborn leaves them blank.
    fdm[fdm == 0] = np.nan
    cax = sns.heatmap(
        fdm,
        ax=panel,
        cmap="viridis",
        cbar=False,
        square=True,
        linewidth=0.01,
        linecolor="lightgrey",
    )
    cax.set_title('Dim {} x {}'.format(*it.pixels))
    for spine in cax.spines.values():
        spine.set_visible(True)
    for axis in (cax.xaxis, cax.yaxis):
        axis.set_major_locator(ticker.MultipleLocator(5))
plt.tight_layout()

# Use a 32x32 grid for the image conversion that follows.
it.pixels = 32


## Fit the layout and convert the samples into images
# NOTE(review): this fits on the unscaled positive_norm, while the UMAP cell
# used the StandardScaler output -- confirm which input is intended.
positive_img = it.fit_transform(positive_norm)
positive_img.shape


## Preview and export the converted grayscale images
# Grayscale images get their own folder: the heatmap export originally wrote to
# the same "<Study>/<i>.png" names and silently overwrote these files.
GRAY_DIR = "/home/GaoBX/work/Study/gray"


def _to_uint8(image):
    """Map *image* from [0, 1] to 8-bit pixel values (0-255).

    Values are clipped to [0, 1] first. positive_img is built from
    positive_norm without any rescaling here, and out-of-range floats do not
    saturate when cast to uint8 (they wrap or are undefined).
    """
    return np.uint8(np.clip(image, 0.0, 1.0) * 255)


def _show_gray(img):
    """Draw *img* on a new 6.2x6.2-inch figure with axes hidden; return the figure."""
    fig = plt.figure(figsize=(6.2, 6.2))
    plt.axis('off')
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img)
    return fig


# Preview a single sample inline.
image = positive_img[1]
image.shape
img = Image.fromarray(_to_uint8(image))
fig = _show_gray(img)

os.makedirs(GRAY_DIR, exist_ok=True)
for i, sample in enumerate(positive_img, start=1):
    img = Image.fromarray(_to_uint8(sample))
    fig = _show_gray(img)
    # Negative padding crops the remaining margin around the image.
    fig.savefig(os.path.join(GRAY_DIR, "{}.png".format(i)),
                dpi=50, bbox_inches='tight', pad_inches=-0.1)
    # Close rather than clf(): pyplot keeps every figure alive until it is
    # closed, so clearing alone leaks one figure per sample.
    plt.close(fig)

    
## Preview and export the converted images as heatmaps
# Heatmaps get their own folder: the grayscale export originally used the same
# "<Study>/<i>.png" names, so each heatmap silently overwrote a grayscale image.
HEATMAP_DIR = "/home/GaoBX/work/Study/heatmap"


def _collapse_channels(image):
    """Return a 2-D matrix: *image* summed over axis 2 if it is 3-D, otherwise as-is."""
    return np.sum(image, axis=2) if image.ndim == 3 else image


def _plot_heatmap(matrix):
    """Draw *matrix* as a borderless viridis heatmap on a new 5x5-inch figure.

    Tick marks and spines are hidden and no colour bar is drawn.
    Returns the new figure and the heatmap Axes.
    """
    fig = plt.figure(figsize=(5, 5))
    # NOTE(review): without vmin/vmax the colour scale is inferred from each
    # matrix's own range, so colours are not comparable across samples --
    # confirm this is intended.
    ax = sns.heatmap(matrix, cmap="viridis", linewidths=0,
                     square=True, cbar=False)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(5))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(5))
    ax.tick_params(left=False, bottom=False)
    for spine in ax.spines.values():
        spine.set_visible(False)
    return fig, ax


# Preview a single sample and save a 300-dpi copy of it.
image = positive_img[1]
fdm = _collapse_channels(image)
fig, ax = _plot_heatmap(fdm)
fig.savefig('heatmap.png', dpi=300, bbox_inches='tight', pad_inches=0.1)

os.makedirs(HEATMAP_DIR, exist_ok=True)
for i, feature in enumerate(positive_img, start=1):
    fdm = _collapse_channels(feature)
    fig, ax = _plot_heatmap(fdm)
    # Negative padding crops the remaining margin around the heatmap.
    fig.savefig(os.path.join(HEATMAP_DIR, "{}.png".format(i)),
                dpi=50, bbox_inches='tight', pad_inches=-0.1)
    # Close rather than clf(): pyplot keeps every figure alive until it is
    # closed, so clearing alone leaks one figure per sample.
    plt.close(fig)