# Example code for the training process

In [None]:
import scprep
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import graphtools as gt
import datetime
import scanpy as sc
import sklearn.preprocessing as preprocessing
import loompy as lp
import umap.umap_ as umap
from sklearn.utils import shuffle

import torch.autograd
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torchvision import transforms
from torchvision import datasets
import torch.utils.data as Data  #Data是用来批训练的模块
from torchvision.utils import save_image
import numpy as np
import os
import pandas as pd

import sklearn
from sklearn import linear_model
from sklearn.metrics import r2_score

# Seed NumPy's global RNG, PyTorch's CPU RNG and every CUDA device's RNG with the
# same value so that weight initialisation and random pairing are reproducible.
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
  _seed_fn(999)

In [None]:
# Load the prepared PBMC dataset. Later cells read RNA from `.X`, protein
# measurements from `.obsm['protein_expression']` and labels from `.obs['celltype']`.
adata_new = sc.read_h5ad('PBMC10_5.h5ad') # prepared dataset holding measurements from two different techniques

In [None]:
def KNN_Matching(data1, data2, label_list):
  """Pair every row of ``data2`` with a randomly chosen same-label row of ``data1``.

  Despite the name, no nearest-neighbour search is performed. For each row
  index ``j`` of ``data2`` (in order), a row index ``i`` of ``data1`` is drawn
  uniformly at random from the rows whose label equals ``label_list[1][j]``.
  The same ``i`` may be drawn for several ``j``. Draws use NumPy's global RNG,
  so results follow ``np.random.seed``.

  Args:
    data1: first dataset; only ``len(data1)`` is used.
    data2: second dataset; only ``len(data2)`` is used.
    label_list: ``[labels1, labels2]``, the per-row labels of ``data1`` and
      ``data2`` respectively, indexable by row position.

  Returns:
    ``[pairs, ids1]``. ``pairs`` is a list of ``(i, j)`` tuples, one per row
    of ``data2`` in row order. ``ids1`` is ``list(range(len(data1)))``.

  Raises:
    ValueError: if a ``data2`` label does not occur among the ``data1`` labels.
  """
  celltype1 = label_list[0]
  celltype2 = label_list[1]

  id_list1 = list(range(len(data1)))

  # Group data1 row indices by label once, in ascending order as before.
  # The old code rescanned every data1 label for each data2 row (O(n1 * n2)).
  # Keeping the candidate order unchanged keeps the seeded draws identical.
  rows_by_label = {}
  for i in id_list1:
    rows_by_label.setdefault(celltype1[i], []).append(i)

  result_pair = []
  for item in range(len(data2)):
    candidates = rows_by_label.get(celltype2[item])
    if not candidates:
      # np.random.choice on an empty pool fails with an opaque message; name the label instead.
      raise ValueError(
          f"KNN_Matching: no row of data1 has label {celltype2[item]!r} "
          f"(needed for data2 row {item})")
    k = np.random.choice(candidates)
    result_pair.append((k, item))

  return [result_pair, id_list1]

# Prepare for training

In [None]:
# Subset used for a separate correlation analysis. It keeps six protein columns
# plus six RNA genes (presumably related markers — TODO confirm the pairing).
pro_test = adata_new.obsm['protein_expression'][['CD3','CD4','CD8a','CD14','CD16','CD19']]

# Slicing an AnnData returns a view. `.copy()` makes it an independent object
# before `.obsm` and the obs names are modified below. Without it, anndata
# materialises the view implicitly and emits a warning.
adata_0 = adata_new[:,['CD3D','CD8B','CD8A','CYBB','CHL1','CCL5']].copy()

adata_0.obsm['protein_expression'] = pro_test

adata_0.obs_names_make_unique()

adata_0.write_h5ad('correlationfind_pbmc.h5ad')

# Split the first 5000 cells into two pseudo-batches: 2000 cells (batch1) and 3000 cells (batch2).
adata_1 = adata_new[0:2000,:]

adata_2 = adata_new[2000:5000,:]

train_data_b1 = adata_1.X
train_label_b1 = adata_1.obsm['protein_expression']

train_data_b2 = adata_2.X
train_label_b2 = adata_2.obsm['protein_expression']

# Combined object with an obs['batch'] column ('batch1' / 'batch2'). This rebinds
# adata_0; the correlation subset above has already been written to disk.
adata_0 = adata_1.concatenate(adata_2, batch_categories=['batch1','batch2'])

label = [np.array(adata_1.obs['celltype']), np.array(adata_2.obs['celltype'])]


# Pair every batch2 cell with a random batch1 cell of the same cell type.
pair_info, res_id = KNN_Matching(adata_1.X, adata_2.X, label)


rna_train_index = [i[0] for i in pair_info]  # batch1 rows (RNA side)
pro_train_index = [i[1] for i in pair_info]  # batch2 rows (protein side), in order

# Training pairs: RNA from batch1 matched by cell type to protein from batch2.
# `.values` is used because obsm['protein_expression'] is a DataFrame here.
train_data = adata_1.X[rna_train_index]
train_label = adata_2.obsm['protein_expression'].values[pro_train_index]

# Training process

In [None]:

class Mish(nn.Module):
  """Parameter-free Mish activation, applied element-wise: x * tanh(softplus(x))."""

  def forward(self, x):
    return torch.mul(x, F.softplus(x).tanh())

# Neural-network models
# Trained with a supervised learning method
# Goal: translate RNA expression data into protein expression data

class generator_r2p(nn.Module):
    """RNA -> protein generator; it is the `Encoder` model in the training cell.

    Maps a 2000-feature input to 14 outputs through two hidden
    Linear -> BatchNorm1d -> Mish stages (1024 and 512 units) and a final
    Linear(512, 14). No output activation is applied.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): `relu_l` and `lin` are never used in forward(). They are
        # kept, in the same construction order, so that the parameter set,
        # state_dict keys and seeded weight initialisation stay unchanged.
        self.relu_l = nn.ReLU(True)
        widths = (2000, 1024, 512)
        stages = []
        for n_in, n_out in zip(widths, widths[1:]):
            stages.extend([nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), Mish()])
        stages.append(nn.Linear(widths[-1], 14))
        self.gen = nn.Sequential(*stages)
        self.lin = nn.Linear(2000, 14)

    def forward(self, x):
        return self.gen(x)



class generator_p2r(nn.Module):
    """Protein -> RNA generator; it is the `Decoder` model in the training cell.

    Maps 14 input features to 2000 outputs. The deep path has four
    Linear -> BatchNorm1d -> Mish stages widening 14 -> 128 -> 256 -> 512 -> 1024,
    followed by Linear(1024, 2000). It is summed with a direct Linear(14, 2000)
    shortcut, and ReLU is applied to the sum, so every output is non-negative.
    """

    def __init__(self):
        super().__init__()
        self.relu_l = nn.ReLU(True)
        widths = (14, 128, 256, 512, 1024)
        stages = []
        for n_in, n_out in zip(widths, widths[1:]):
            stages.extend([nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), Mish()])
        stages.append(nn.Linear(widths[-1], 2000))
        self.gen = nn.Sequential(*stages)
        # Shortcut branch added to the deep path in forward().
        self.lin = nn.Linear(14, 2000)

    def forward(self, x):
        deep = self.gen(x)
        shortcut = self.lin(x)
        return self.relu_l(deep + shortcut)

class discriminator_r2p(nn.Module):
    """Small discriminator over 14-feature inputs, producing one score per row.

    Layers: Linear(14, 6) -> BatchNorm1d(6) -> Mish -> Linear(6, 1) -> ReLU.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): the final ReLU clamps the score at 0, so this does not
        # emit unbounded logits. Confirm that this matches the loss it is used with.
        self.disc = nn.Sequential(
            nn.Linear(14, 6), nn.BatchNorm1d(6), Mish(),
            nn.Linear(6, 1), nn.ReLU(True),
        )

    def forward(self, x):
        return self.disc(x)


class discriminator_p2r(nn.Module):
    """MLP discriminator over 4000-feature inputs, producing one unbounded score per row.

    Six Linear -> Mish stages narrow the width 4000 -> 2048 -> 1024 -> 512 ->
    256 -> 128 -> 64. A final Linear(64, 1) follows, with no output activation
    and no normalisation layers.
    """

    def __init__(self):
        super().__init__()
        widths = (4000, 2048, 1024, 512, 256, 128, 64)
        layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(Mish())
        layers.append(nn.Linear(widths[-1], 1))
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        return self.disc(x)

# Hyper-parameters.
EPOCH1 = 100  # old value: 200
MAX_ITER = train_data.shape[0]  # number of matched (RNA, protein) training pairs
batch = 32  # NOTE(review): unused; BATCH below is the batch size actually used
BATCH = 32
# NOTE(review): b1, b2 and lambda_1 are not used in this cell (Adam keeps its default betas).
b1 = 0.9
b2 = 0.999
lambda_1 = 1/10

# Choose the device once. Previously the models were moved to the GPU only when
# one was available, but every batch called `.cuda()` unconditionally. That
# crashed on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Encoder maps RNA -> protein; Decoder maps protein -> RNA. They are constructed
# in the same order as before, so the seeded weight initialisation is unchanged.
Encoder = generator_r2p().to(device)
Decoder = generator_p2r().to(device)
Encoder.train()
Decoder.train()

criterion = nn.SmoothL1Loss().to(device)
encoder_optimizer = torch.optim.Adam(Encoder.parameters(), lr=0.00001)
decoder_optimizer = torch.optim.Adam(Decoder.parameters(), lr=0.00001)


def _get_batch(start):
  """Return the (RNA, protein) float tensors for rows [start, start + BATCH) on `device`."""
  rna_batch = torch.FloatTensor(train_data[start:start + BATCH, :]).to(device)
  pro_batch = torch.FloatTensor(train_label[start:start + BATCH, :]).to(device)
  return rna_batch, pro_batch


def _optimizer_step(optimizer, loss):
  """Clear `optimizer`'s gradients, back-propagate `loss` and apply one update."""
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()


# BatchNorm1d raises in training mode on a single-row batch. A trailing 1-row
# remainder (MAX_ITER % BATCH == 1) is therefore skipped instead of aborting the run.
batch_starts = [s for s in range(0, MAX_ITER, BATCH) if min(BATCH, MAX_ITER - s) > 1]

# Train the models. Each epoch makes four passes over the same batches, in order.
for epoch in range(EPOCH1):
  print(epoch)
  print("###########################Encoder Part#######################")
  # Supervised RNA -> protein; updates the Encoder only.
  for start in batch_starts:
    rna_batch, pro_batch = _get_batch(start)
    err_r2p = criterion(Encoder(rna_batch), pro_batch)
    _optimizer_step(encoder_optimizer, err_r2p)
    if start % 100 == 0:
      print('encoder part loss', err_r2p)
  print("###########################Decoder Part#######################")
  # Supervised protein -> RNA; updates the Decoder only.
  for start in batch_starts:
    rna_batch, pro_batch = _get_batch(start)
    err_p2r = criterion(Decoder(pro_batch), rna_batch)
    _optimizer_step(decoder_optimizer, err_p2r)
    if start % 100 == 0:
      print('decoder part loss', err_p2r)
  print("###########################Construct Part#######################")
  # Cycle RNA -> protein -> RNA. Only the Decoder is stepped. The Encoder
  # gradients this produces are cleared by encoder_optimizer.zero_grad()
  # before the Encoder's next update.
  for start in batch_starts:
    rna_batch, _ = _get_batch(start)
    err_cons = criterion(Decoder(Encoder(rna_batch)), rna_batch)
    _optimizer_step(decoder_optimizer, err_cons)
    if start % 100 == 0:
      print('construct part loss', err_cons)
  print("###########################Reconstruct Part#######################")
  # Cycle protein -> RNA -> protein. Only the Encoder is stepped. The Decoder
  # gradients this produces are cleared by decoder_optimizer.zero_grad()
  # before the Decoder's next update.
  for start in batch_starts:
    _, pro_batch = _get_batch(start)
    err_cons = criterion(Encoder(Decoder(pro_batch)), pro_batch)
    _optimizer_step(encoder_optimizer, err_cons)
    if start % 100 == 0:
      print('reconstruct part loss', err_cons)

print("#######################finished pre train##########################")

# Inference helper: run a trained model over a whole dataset.
def get_train_result(G, testd):
  """Run model ``G`` in eval mode over ``testd`` and return its output as a NumPy array.

  Args:
    G: a torch ``nn.Module``. It is switched to eval mode and left there.
    testd: array-like input (e.g. a NumPy array or a DataFrame's ``.values``),
      converted to float32.

  Returns:
    ``np.ndarray`` of float32 holding ``G``'s output for the whole input.
  """
  G.eval()
  # Put the input on the same device as the model instead of hard-coding
  # `.cuda()`, so this also works on CPU-only machines. Parameter-free models
  # fall back to the CPU.
  first_param = next(G.parameters(), None)
  device = first_param.device if first_param is not None else torch.device('cpu')
  test_data1 = torch.as_tensor(np.asarray(testd), dtype=torch.float32, device=device)
  # Inference only: no autograd graph is needed, which saves memory on large inputs.
  with torch.no_grad():
    test_list = G(test_data1)
  return test_list.cpu().numpy()

# Translate batch2's measured protein into RNA with the trained Decoder.
rna = get_train_result(Decoder, train_label_b2.values)

# Integrated RNA matrix: measured batch1 RNA stacked on top of predicted batch2 RNA.
# Row order matches adata_0, which concatenated adata_1 before adata_2.
# NOTE(review): np.vstack assumes train_data_b1 (adata_1.X) is a dense array — confirm.
rna_t = np.vstack([train_data_b1,rna])

adata_0.X = rna_t
# 20-component PCA of the integrated matrix (stored in adata_0.obsm['X_pca']).
sc.tl.pca(adata_0, 20)


In [None]:
# 2-D UMAP of the PCA embedding of the integrated RNA matrix.
data_umap = umap.UMAP().fit_transform(adata_0.obsm['X_pca'])

# Colour by cell type. The first 5000 cells of adata_new are the adata_1 and
# adata_2 cells, in the same order as the rows of adata_0.
scprep.plot.scatter2d(data_umap, c=adata_new.obs['celltype'][0:5000], figsize=(12,8), cmap="Spectral",
                      ticks=False, label_prefix="UMAP", s = 20, title = 'Celltype Specific')

In [None]:
# Same RNA UMAP, coloured by the 'batch' column that concatenate() added (batch1 / batch2).
# The title names the batch colouring; it previously read 'Celltype Specific'.
scprep.plot.scatter2d(data_umap, c=adata_0.obs['batch'], figsize=(12,8), cmap="Spectral",
                      ticks=False, label_prefix="UMAP", s = 20, title = 'Batch Specific')

In [None]:
# Translate batch1's measured RNA into protein with the trained Encoder.
pro = get_train_result(Encoder, train_data_b1)

# Integrated protein matrix: predicted batch1 protein stacked on top of measured batch2 protein.
pro_t = np.vstack([pro, train_label_b2.values])

In [None]:

# 2-D UMAP computed directly on the integrated protein matrix (no PCA step).
data_umap = umap.UMAP().fit_transform(pro_t)

# Colour by cell type (the first 5000 cells of adata_new, i.e. batch1 then batch2).
scprep.plot.scatter2d(data_umap, c=adata_new.obs['celltype'][0:5000], figsize=(12,8), cmap="Spectral",
                      ticks=False, label_prefix="UMAP", s = 20, title = 'Celltype Specific')


In [None]:
# Same protein UMAP, coloured by the 'batch' column that concatenate() added (batch1 / batch2).
# The title names the batch colouring; it previously read 'Celltype Specific'.
scprep.plot.scatter2d(data_umap, c=adata_0.obs['batch'], figsize=(12,8), cmap="Spectral",
                      ticks=False, label_prefix="UMAP", s = 20, title = 'Batch Specific')


# Store the integrated outputs as an AnnData (.h5ad) file.

In [None]:
# The first 5000 cells (batch1 then batch2) carry the integrated matrices.
# `.copy()` makes the slice an independent AnnData. Assigning `.X` / `.obsm`
# on a view otherwise triggers an implicit copy with a warning.
adata_new0 = adata_new[0:5000].copy()

# RNA: measured batch1 rows followed by protein -> RNA predictions for batch2.
adata_new0.X = rna_t

# Protein: protein predicted from RNA for batch1, followed by measured batch2 protein.
# NOTE(review): this replaces the original protein DataFrame with a plain ndarray,
# dropping the column names. Confirm downstream readers expect that.
adata_new0.obsm['protein_expression'] = pro_t

adata_new0.write_h5ad('Finished_Running_BIDpart1.h5ad')