## 在Github 找到有人建過 pytorch 版的 Scaden
###### https://github.com/poseidonchan/TAPE/blob/main/Experiments/pytorch_scaden_PBMConly.ipynb

In [1]:
!nvidia-smi

Fri Sep  9 22:08:34 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.56       Driver Version: 418.56       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX 108...  Off  | 00000000:5E:00.0 Off |                  N/A |
|  0%   30C    P8     9W / 280W |  10941MiB / 11178MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce GTX 108...  Off  | 00000000:AF:00.0 Off |                  N/A |
|  0%   35C    P8    16W / 280W |   3471MiB / 11178MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                            

In [2]:
from tqdm import tqdm
import torch
import random
import numpy as np
import pandas as pd
import collections
import anndata
from anndata import read_h5ad
import torch.utils.data as Data
import torch.nn.functional as F
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import torch.nn as nn
import matplotlib.pyplot as plt
import scanpy as sc
from torchsummary import summary
from scipy.stats import pearsonr
from sklearn import preprocessing

In [3]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
adata = read_h5ad('./0902_CGL2_SCT.h5ad')
adata

In [4]:
adata = read_h5ad('./0902_CGL2_harmony.h5ad')
adata

AnnData object with n_obs × n_vars = 14 × 1971
    obs: 'Scaden results', 'CD8+/CD45RA+ Naive Cytotoxic T Cells', 'CD14_Monocytes', 'CD4+/CD45RA+/CD25-Naive T cells', 'CD4+/CD45RO+ Memory T Cells', 'CD56+ Natural Killer Cells', 'CD4+/CD25+ Regulatory T Cells_and_CD4+ T Helper Cells', 'CD19_B_Cells', 'CD8+ Cytotoxic T cells'

In [5]:
#adata.obs

In [6]:
print(adata.X.shape)
adata.X[0,0:20]

(14, 1971)


array([0.        , 0.        , 0.        , 0.10010195, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.18947189, 0.3992759 , 0.        , 0.        ,
       0.        , 0.        , 0.56196386, 0.0631573 , 0.        ],
      dtype=float32)

## Scaden

In [7]:
class MLP_no_dropout(nn.Module):
    def __init__(self, gene_size, unit):
        super(MLP_no_dropout, self).__init__()
        self.gene_size = gene_size
        self.unit = unit
        
        self.D1 = nn.Sequential(
            nn.Linear(self.gene_size, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, self.unit),
            nn.Softmax(dim=1)
        )
        self.D2 = nn.Sequential(
            nn.Linear(self.gene_size, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, self.unit),
            nn.Softmax(dim=1)
        )
        self.D3 = nn.Sequential(
            nn.Linear(self.gene_size, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),

            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, self.unit),
            nn.Softmax(dim=1)
        )
    def forward(self, inputs):
        out1 = self.D1(inputs)
        out2 = self.D2(inputs)
        out3 = self.D3(inputs)
        
        out = (out1+out2+out3)/3
        return out

In [8]:
class MLP_batch(nn.Module):
    def __init__(self, gene_size, unit):
        super(MLP_batch, self).__init__()
        self.gene_size = gene_size
        self.unit = unit
        
        self.D1 = nn.Sequential(
            nn.Linear(self.gene_size, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, self.unit),
            nn.Softmax(dim=1)
        )
        self.D2 = nn.Sequential(
            nn.Linear(self.gene_size, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, self.unit),
            nn.Softmax(dim=1)
        )
        self.D3 = nn.Sequential(
            nn.Linear(self.gene_size, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, self.unit),
            nn.Softmax(dim=1)
        )
    def forward(self, inputs):
        out1 = self.D1(inputs)
        out2 = self.D2(inputs)
        out3 = self.D3(inputs)
        
        out = (out1+out2+out3)/3
        return out

## Preprocess train, validation

### Training harmony

In [None]:
date = "0902"

In [None]:
dropout_rate = "no_dropout"

In [None]:
val_data = adata

#將 train_data 分開
#sample x gene
v_exp = val_data.X
#v_obs = val_data.obs
test = torch.tensor(v_exp, dtype=torch.float32)
#test_label = torch.tensor(np.array(v_obs), dtype=torch.float32)
celltype = ['CD141+DC','CD4/CD8-C1-CCR7','CD4/CD8-C2-MKI67','CD8-C7-KLRD1','CD8-C9-SLC4A10','Central memory T cells',
            'Circulating NK','Conventional dendritic cells(CD1C DC)','Cytotoxicity CD8T','DC-C4-LAMP3',
            'Effector memory T cells','Exhausted CD8+ T (Tex) cells','ILCs','Liver-resident NK (lrNK) cell',
            'Lymphoid-B','M-C4-GPX3','M1','Mast','Mono','Myeloid-derived suppressor cells','NK','TAM-like',
            'Th0','Th1','Treg']

In [None]:
PATH = './2.model/0902_model/0902_SCT_combine_model_no_dropout.pth'

In [None]:
mlp = MLP_no_dropout(v_exp.shape[1], len(celltype)).to(device)
mlp.load_state_dict(torch.load(PATH))

loss_fn = nn.L1Loss()
optimizer = torch.optim.Adam(mlp.parameters(),lr=0.0001)

test_dataset = Data.TensorDataset(test)
test_dataset = torch.utils.data.DataLoader(test_dataset, batch_size=64)

#print(mlp)

In [9]:
date = "0902"

In [10]:
dropout_rate = "batch_nor"

In [11]:
val_data = adata

#將 train_data 分開
#sample x gene
v_exp = val_data.X
#v_obs = val_data.obs
test = torch.tensor(v_exp, dtype=torch.float32)
#test_label = torch.tensor(np.array(v_obs), dtype=torch.float32)
celltype = ['CD141+DC','CD4/CD8-C1-CCR7','CD4/CD8-C2-MKI67','CD8-C7-KLRD1','CD8-C9-SLC4A10','Central memory T cells',
            'Circulating NK','Conventional dendritic cells(CD1C DC)','Cytotoxicity CD8T','DC-C4-LAMP3',
            'Effector memory T cells','Exhausted CD8+ T (Tex) cells','ILCs','Liver-resident NK (lrNK) cell',
            'Lymphoid-B','M-C4-GPX3','M1','Mast','Mono','Myeloid-derived suppressor cells','NK','TAM-like',
            'Th0','Th1','Treg']

In [12]:
PATH = './2.model/0902_model/0902_harmony_combine_model_batchnor_before.pth'

In [13]:
mlp = MLP_batch(v_exp.shape[1], len(celltype)).to(device)
mlp.load_state_dict(torch.load(PATH))

loss_fn = nn.L1Loss()
optimizer = torch.optim.Adam(mlp.parameters(),lr=0.0001)

test_dataset = Data.TensorDataset(test)
test_dataset = torch.utils.data.DataLoader(test_dataset, batch_size=64)

#print(mlp)

## Combine three submodel

In [14]:
train_loss_list = []
t_loss =[]
val_loss_list = []
v_loss = []

cor = []
cor_list = []
val_cor = []
val_cor_list = []

t_cor_list = []
v_cor_list = []
t_big_cor_list = []
v_big_cor_list = []

epochs = 100
big_type = [1,3,4,5,6,7,8,10,11,13,14,19,20,21,22,24]


In [15]:
#Train
for inte in range(epochs):
    print("--- Start training! ---", end="\r")
    print('      epoch: {}            '.format(inte+1))
    train_loss = 0
    val_loss = 0
    a = 0
    b = 0
    t_big = 0
    v_big = 0
    
#Validation    
    print("--- Now evaluation!!! ---",end="\r")
    mlp.eval()
    
    # Tell torch not to calculate gradients
    with torch.no_grad():
        for n, D in enumerate(test_dataset):
            pred = mlp(D[0].float().cuda())
            #VAL = loss_fn(pred, D[1].float().cuda())
            optimizer.zero_grad()
            optimizer.step()
            
        
print(" ----- Finish!! ----- ", end="\r")

      epoch: 1            
      epoch: 2            
      epoch: 3            
      epoch: 4            
      epoch: 5            
      epoch: 6            
      epoch: 7            
      epoch: 8            
      epoch: 9            
      epoch: 10            
      epoch: 11            
      epoch: 12            
      epoch: 13            
      epoch: 14            
      epoch: 15            
      epoch: 16            
      epoch: 17            
      epoch: 18            
      epoch: 19            
      epoch: 20            
      epoch: 21            
      epoch: 22            
      epoch: 23            
      epoch: 24            
      epoch: 25            
      epoch: 26            
      epoch: 27            
      epoch: 28            
      epoch: 29            
      epoch: 30            
      epoch: 31            
      epoch: 32            
      epoch: 33            
      epoch: 34            
      epoch: 35            
      epoch: 36            
 

In [16]:
model = "combine_model"

In [17]:
df = pd.DataFrame(pred.cpu().detach().numpy())
df = df*100
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,15,16,17,18,19,20,21,22,23,24
0,0.656015,6.453666,0.768478,4.208766,1.711249,23.937073,1.410413,3.489094,9.807796,0.927695,...,0.669512,4.619187,1.38189,1.947082,4.356538,0.539389,4.087608,3.16534,0.439703,2.732345
1,0.672803,6.568323,0.782105,4.170029,1.731601,23.595064,1.468344,3.591688,9.417401,0.964153,...,0.687672,4.599662,1.422911,1.959408,4.461899,0.564665,4.12184,3.191759,0.462567,2.874877
2,0.642758,6.388806,0.756561,4.143879,1.694888,24.085365,1.331089,3.462164,10.115726,0.890888,...,0.651254,4.589,1.345221,1.9009,4.380582,0.492553,4.166725,3.18274,0.422401,2.609814
3,0.662778,6.432861,0.772553,4.125156,1.721803,23.966825,1.420724,3.472657,9.901963,0.918007,...,0.676007,4.566734,1.3729,1.924317,4.390543,0.527891,4.145574,3.189193,0.445519,2.787483
4,0.638319,6.400237,0.753536,4.225998,1.682347,24.098427,1.318871,3.366527,10.279809,0.868069,...,0.64586,4.581454,1.287504,1.909834,4.285573,0.501148,4.051194,3.168123,0.416203,2.543793
5,0.654502,6.540128,0.766733,4.363331,1.703228,23.737345,1.421851,3.33431,9.805918,0.912278,...,0.666816,4.659933,1.319691,1.953964,4.238883,0.563509,3.934856,3.19067,0.439919,2.725107
6,0.635506,6.352584,0.752524,4.157109,1.682888,24.222301,1.312239,3.450688,10.193728,0.869164,...,0.644318,4.544249,1.284395,1.89088,4.366358,0.496668,4.098042,3.160381,0.415725,2.561428
7,0.641125,6.410493,0.755701,4.136579,1.688721,24.099165,1.318618,3.484531,10.074389,0.886544,...,0.648696,4.583266,1.338185,1.896849,4.400714,0.49013,4.17237,3.179897,0.420714,2.583388
8,0.654554,6.413641,0.764487,4.14672,1.715249,24.02631,1.400277,3.39887,9.919557,0.914965,...,0.665617,4.563376,1.357731,1.911407,4.375976,0.519795,4.118711,3.223119,0.437713,2.732165
9,0.655097,6.463675,0.765322,4.193246,1.714998,23.924068,1.404198,3.403003,9.849689,0.911486,...,0.665524,4.581716,1.336415,1.914054,4.36975,0.529535,4.066313,3.221978,0.43841,2.736626


In [None]:
df.to_csv('./4.Result/0902_CGL2_SCT_reults.csv')

In [18]:
df.to_csv('./4.Result/0902_CGL2_harmony_reults.csv')