In [1]:
import anndata
import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import pytorch_lightning as pl
import scanpy as sc
import scprep as scp
import seaborn as sns
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from collections import defaultdict as dfd
from copy import deepcopy as dcp
from pathlib import Path, PurePath
from pytorch_lightning.loggers import TensorBoardLogger
from scipy import stats
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
from tqdm import tqdm
from PIL import ImageFile, Image
from matplotlib.image import imread
from scanpy import read_10x_mtx, read_visium
from torch.utils.data import DataLoader
from window_adata import *


In [2]:
# !pip install squidpy
# !pip install pytorch_lightning
# !pip install scprep

In [3]:
""" Integrate two visium datasets """
data_dir1 = "./Alex_NatGen_6BreastCancer/"
data_dir2 = "./breast_cancer_10x_visium/"

samps1 = ["1142243F", "CID4290", "CID4465", "CID44971", "CID4535", "1160920F"]
samps2 = ["block1", "block2", "FFPE",]

sampsall = samps1 + samps2
samples1 = {i:data_dir1 + i for i in samps1}
samples2 = {i:data_dir2 + i for i in samps2}

# Marker gene list
gene_list = ["COX6C","TTLL12", "HSP90AB1", "TFF3", "ATP1A1", "B2M", "FASN", "SPARC", "CD74", "CD63", "CD24", "CD81"]

# # Load windowed dataset
import pickle
with open('10x_visium_dataset_without_window.pickle', 'rb') as f:
    adata_dict0 = pickle.load(f)
    
# Define the gridding size
sizes = [4000 for i in range(len(adata_dict0))]

# Split tiles into smaller patches according to gridding size
adata_dict = window_adata(adata_dict0, sizes)

Windowing 1142243F
Num spots:  4784
246
216
222
185
77
247
246
255
247
93
246
247
255
245
94
247
246
255
238
88
130
135
132
140
55
Total:  4787
Windowing CID4290
Num spots:  2714
793
576
1001
344
Total:  2714
Windowing CID4465
Num spots:  1310
345
258
149
558
Total:  1310
Windowing CID44971
Num spots:  1322
491
462
339
30
Total:  1322
Windowing CID4535
Num spots:  1431
564
232
632
3
Total:  1431
Windowing 1160920F
Num spots:  4895
210
251
251
239
83
226
255
232
240
102
231
230
246
247
99
238
246
247
255
93
144
147
164
160
60
Total:  4896
Windowing block1
Num spots:  3798
139
205
219
185
10
169
246
247
255
16
189
230
205
233
0
197
156
241
228
0
72
106
129
124
0
Total:  3801
Windowing block2
Num spots:  3987
224
247
246
229
208
246
247
231
243
207
211
196
221
195
254
205
81
97
108
92
Total:  3988
Windowing FFPE
Num spots:  2518
50
190
188
79
0
169
219
216
189
0
182
215
201
192
0
68
138
159
63
0
0
1
0
0
0
Total:  2519
Windowing 1168993F
Num spots:  4898
244
249
248
244
68
246
247
255
246


In [4]:
# For training
from data_vit import ViT_Anndata

def dataset_wrap(fold = 0, train=True, dataloader=True):
    test_sample = sampsall[fold]
    test_sample_orig = sampsall[fold] # Split one sample as test sample
    val_sample = list(set(sampsall)-set(sampsall[fold]))[:3] # Split 3 samples as validation samples
    train_sample = list(set(sampsall)-set(test_sample)-set(val_sample)) # Other samples are training samples

    tr_name = list(set([i for i in list(adata_dict.keys()) for tr in train_sample if tr in i]))
    val_name = list(set([i for i in list(adata_dict.keys()) for val in val_sample if val in i]))
    te_name = list(set([i for i in list(adata_dict.keys()) if test_sample in i]))
    if train:
        print("LOADED TRAINSET")
        trainset = ViT_Anndata(adata_dict = adata_dict, train_set = tr_name, gene_list = gene_list, train=True, flatten=False, ori=True, prune='NA', neighs=4, )
        valset = ViT_Anndata(adata_dict = adata_dict, train_set = val_name, gene_list = gene_list, train=True, flatten=False, ori=True, prune='NA', neighs=4, )
        train_loader = DataLoader(trainset, batch_size=1, num_workers=0, shuffle=True)
        val_loader = DataLoader(valset, batch_size=1, num_workers=0, shuffle=False)
        return train_loader, val_loader
    
    else:
        print("LOADED TESTSET")
        testset = ViT_Anndata(adata_dict = adata_dict, train_set = te_name, gene_list = gene_list, train=True, flatten=False, ori=True, prune='NA', neighs=4, )
        test_loader = DataLoader(testset, batch_size=1, num_workers=0, shuffle=False)
        return test_loader
            


In [9]:
import warnings
warnings.filterwarnings("ignore")

# Import necessary libraries
import torch
import random
import numpy as np
import pytorch_lightning as pl
import torchvision.transforms as tf
import torch.nn as nn

# Import custom modules
from gcn import *
from NB_module import *
from transformer import *
from scipy.stats import pearsonr
from torch.utils.data import DataLoader
from copy import deepcopy as dcp
from collections import defaultdict as dfd
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score

# Import torchvision models and transforms
from torchvision import models, transforms

# Define a function to set a random seed for reproducibility
def setup_seed(seed=12000):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

# Define a function to calculate Pearson correlation coefficient
def get_R(data1, data2, dim=1, func=pearsonr):
    # Calculate Pearson correlation coefficient for each gene
    adata1 = data1.X
    adata2 = data2.X
    r1, p1 = [], []
    for g in range(data1.shape[dim]):
        if dim == 1:
            r, pv = func(adata1[:, g], adata2[:, g])
        elif dim == 0:
            r, pv = func(adata1[g, :], adata2[g, :])
        r1.append(r)
        p1.append(pv)
    r1 = np.array(r1)
    p1 = np.array(p1)
    return r1, p1

# Define a function to fine-tune a pretrained model
def ft_extra(name="resnet"):
    if name == "resnet":
        model_ft = torchvision.models.resnet50(weights=models.ResNet50_Weights)
        for param in model_ft.parameters():
            param.requires_grad = False
        model_ft.fc = nn.Sequential(nn.Identity())
        dim = 2048
    elif name == "efficient":
        model_ft = torchvision.models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights)
        for param in model_ft.parameters():
            param.requires_grad = False
        model_ft.classifier = nn.Sequential(nn.Identity())
        dim = 1280
    elif name == 'swin':
        model_ft = torchvision.models.swin_s(weights=models.Swin_S_Weights)
        for param in model_ft.parameters():
            param.requires_grad = False
        model_ft.head = nn.Sequential(nn.Identity())
        dim = 768
    return model_ft, dim

# Define a custom Vision Transformer (ViT) class

class ViT(nn.Module):
    def __init__(self, name="resnet", dim=1024,
                 depth1=2, depth2=8, depth3=4, 
                 heads=8, dim_head=64, mlp_dim=1024,
                 policy='mean', gcn=True
                ):
        super().__init__()

        # Initialize with the given hyperparameters
        self.ft_extractor, ft_dim = ft_extra(name=name)  # Feature extractor and its output dimension
        self.projection = nn.Sequential(nn.Linear(ft_dim, dim))  # Linear projection layer
        self.transformer = nn.Sequential(*[attn_block(dim, heads, dim_head, mlp_dim, 0.2) for i in range(depth2)])  # Transformer blocks
        self.GCN = nn.ModuleList([gs_block(dim, dim, policy, gcn) for i in range(depth3)])  # Graph Convolutional Network blocks
        self.jknet = nn.Sequential(
            nn.LSTM(dim, dim, 2),  # LSTM layer
            SelectItem(0),  # Custom module to select an item from the LSTM output
        )
        self.dropout = nn.Dropout(0.2)  # Dropout layer with a dropout rate of 0.2
        self.tf = transforms.Compose([
            transforms.Resize(256),  # Resize image to 256x256
            transforms.CenterCrop(224),  # Center crop to 224x224
        ])

    def forward(self, patch, ct, adj):
        # Resize and crop the input image patches
        x = self.tf(patch.squeeze())

        # Extract features from the input image
        x = self.ft_extractor(x)
        
        # Project features to the specified embedding dimension
        x = self.projection(x)
        
        # Apply dropout to the projected features
        x = self.dropout(x.squeeze(0)).unsqueeze(0)

        # Pass the features through the transformer
        x = self.transformer(x + ct).squeeze(0)
        
        # Apply Graph Convolutional Networks
        jk = []
        for layer in self.GCN:
            x = layer(x, adj.squeeze(0))
            jk.append(x.unsqueeze(0))
        x = torch.cat(jk, 0)

        # Apply LSTM and compute the mean of the LSTM output
        x = self.jknet(x).mean(0)
        
        # Return the final output
        return x


# Define a custom CNN_ST class that extends pytorch_lightning LightningModule
class CNN_ST(pl.LightningModule):
    def __init__(self, learning_rate=1e-5, name="resnet", dim=1024, n_pos=128, n_genes=12,
                 depth1=2, depth2=8, depth3=4, heads=16,
                 zinb=0.25, nb=False, policy='mean', bake=5, lamb=0.5):
        super().__init__()
        self.learning_rate = learning_rate
        self.nb = nb
        self.zinb = zinb
        self.bake = bake
        self.lamb = lamb

        # Position Embedding
        self.x_embed = nn.Embedding(n_pos, dim)
        self.y_embed = nn.Embedding(n_pos, dim)

        # Feature Extractor (ViT)
        self.vit = ViT(
            heads=heads, name=name,
            dim=dim, depth1=depth1, depth2=depth2, depth3=depth3,
            mlp_dim=dim, policy=policy, gcn=True, )

        self.n_genes = n_genes

        # ZINB Loss
        if self.zinb > 0:
            if self.nb:
                self.hr = nn.Linear(dim, n_genes)
                self.hp = nn.Linear(dim, n_genes)
            else:
                self.mean = nn.Sequential(nn.Linear(dim, n_genes), MeanAct())
                self.disp = nn.Sequential(nn.Linear(dim, n_genes), DispAct())
                self.pi = nn.Sequential(nn.Linear(dim, n_genes), nn.Sigmoid())

        # Data augmentation
        self.coef = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, 1),
        )
        self.imaug = transforms.Compose([
            transforms.RandomGrayscale(0.1),
            transforms.RandomRotation(90),
            transforms.RandomHorizontalFlip(0.2),
        ])

        # Regression Module
        self.gene_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, n_genes),
        )

    # Forward pass
    def forward(self, patch, centers, adj):
        # Spatial location for Transformer
        centers_x = self.x_embed(centers[:, :, 0].long())
        centers_y = self.y_embed(centers[:, :, 1].long())
        ct = centers_x + centers_y

        # Feature Extraction
        h = self.vit(patch, ct, adj)

        # Gene expression prediction
        x = self.gene_head(h)

        # ZINB Distribution
        extra = None
        if self.zinb > 0:
            if self.nb:
                r, p = self.hr(h)
                extra = (r, p)
            else:
                m = self.mean(h)
                d = self.disp(h)
                p = self.pi(h)
                extra = (m, d, p)

        h = self.coef(h)
        return x, extra, h
    
    def aug(self,patch,center,adj):
        bake_x=[]
        # generate 5 additional image patches
        for i in range(self.bake):
            new_patch=self.imaug(patch.squeeze(0)).unsqueeze(0)
            x,_,h=self(new_patch,center,adj)
            bake_x.append((x.unsqueeze(0),h.unsqueeze(0)))
        return bake_x
    
    def distillation(self,bake_x):
        new_x,coef=zip(*bake_x)
        coef=torch.cat(coef,0)
        new_x=torch.cat(new_x,0)
        coef=F.softmax(coef,dim=0)
        new_x=(new_x*coef).sum(0)
        return new_x
    

    def training_step(self, batch, batch_idx):
        patch, center, exp, adj, oris, sfs, *_ = batch
        adj=adj.squeeze(0)
        exp=exp.squeeze(0)

        """ Model inference """
        pred,extra,h = self(patch, center, adj)

        """ Regression Loss """
        mse_loss = F.mse_loss(pred, exp)
        self.log('mse_loss', mse_loss,on_epoch=True, prog_bar=True, logger=True)

        """ ZINB Loss """
        zinb_loss=0
        if self.zinb>0:
            if self.nb:
                r,p=extra
                zinb_loss = NB_loss(oris.squeeze(0),r,p)
            else:
                m,d,p=extra
                zinb_loss = ZINB_loss(oris.squeeze(0),m,d,p,sfs.squeeze(0))
        self.log('zinb_loss', zinb_loss,on_epoch=True, prog_bar=True, logger=True)
        
        """ Self-distillation loss """
        bake_loss=0
        bake_x=self.aug(patch,center,adj)
        new_pred=self.distillation(bake_x)
        bake_loss+=F.mse_loss(new_pred,pred)
        self.log('bake_loss', bake_loss,on_epoch=True, prog_bar=True, logger=True)

        """ Total Loss """
        loss = mse_loss + self.zinb*zinb_loss+self.lamb*bake_loss
        self.log('train_loss', loss,on_epoch=True, prog_bar=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        patch, center, exp, adj, oris, sfs, *_ = batch
        adj=adj.squeeze(0)
        exp=exp.squeeze(0)

        """ Model Inference """
        pred,extra,h = self(patch, center, adj)

        """ Regression Loss """
        mse_loss = F.mse_loss(pred, exp)
        self.log('val_loss', mse_loss,on_epoch=True, prog_bar=True, logger=True)

        return mse_loss

    
    def test_step(self, batch, batch_idx):
        patch, center, exp, adj, oris, sfs, *_ = batch
        adj=adj.squeeze(0)
        exp=exp.squeeze(0)

        """ Model Inference """
        gene_exp,extra,h = self(patch, center, adj)

        """Pearson correlation coeficient"""
        adata1 = ad.AnnData(gene_exp.cpu().detach().numpy())
        adata2 = ad.AnnData(exp.cpu().detach().numpy())
        R=get_R(adata1,adata2)[0]
        mean_pcc=np.nanmean(R)
        self.log('test_mean_PCC', mean_pcc, on_epoch=True, prog_bar=True, logger=True)

        

    def configure_optimizers(self):
        # self.hparams available because we called self.save_hyperparameters()
        optim=torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        optim_dict = {'optimizer': optim}
        return optim_dict
    

In [11]:
import gc
from data_vit import ViT_Anndata
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import time

start_time = time.time()
"""Training loops"""
seed = 12000
epochs = 1 # Num epochs
dim = 1024 # Hidden layer dimension
name = "resnet"   # ["swim_s", "efficientnet", "resnet"]
fold=0 # LOOCV

"""Load dataset"""
train_loader, val_loader = dataset_wrap(fold = fold, train=True, dataloader= True)
test_loader = dataset_wrap(fold = fold, train=False, dataloader= True)



In [10]:
"""Define model"""
model = CNN_ST(name=name, dim=dim)
setup_seed(seed)

"""Setup trainer"""
logger = pl.loggers.CSVLogger("logs", name=f"./CNN_ST/{name}_fold{fold}")
trainer = pl.Trainer(accelerator='auto',  callbacks=[EarlyStopping(monitor='val_loss',mode='min')], max_epochs=epochs,logger=False)
trainer.fit(model, train_loader, val_loader)
trainer.test(model, test_loader)

"""Save model and clean memory"""
#     torch.save(model.state_dict(),f"./model/Earlystop/{name}-seed{seed}-epochs{epochs}-sampleIndex{fold}.ckpt")
gc.collect()
del train_loader, val_loader, test_loader
end_time = time.time()
execution_time = end_time - start_time
print("Training time: ", execution_time/3600, " hours")


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params
-----------------------------------------
0 | x_embed   | Embedding  | 131 K 
1 | y_embed   | Embedding  | 131 K 
2 | vit       | ViT        | 97.0 M
3 | mean      | Sequential | 12.3 K
4 | disp      | Sequential | 12.3 K
5 | pi        | Sequential | 12.3 K
6 | coef      | Sequential | 1.1 M 
7 | gene_head | Sequential | 14.3 K
-----------------------------------------
74.8 M    Trainable params
23.5 M    Non-trainable params
98.3 M    Total params
393.388   Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Epoch 0: 100%|██████████| 90/90 [02:08<00:00,  1.42s/it, mse_loss_step=0.959, zinb_loss_step=3.440, bake_loss_step=1.5e-5, train_loss_step=1.820] 
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/32 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/32 [00:00<?, ?it/s][A
Validation DataLoader 0:   3%|▎         | 1/32 [00:00<00:00, 39.89it/s][A
Validation DataLoader 0:   6%|▋         | 2/32 [00:00<00:03,  9.13it/s][A
Validation DataLoader 0:   9%|▉         | 3/32 [00:00<00:04,  6.62it/s][A
Validation DataLoader 0:  12%|█▎        | 4/32 [00:00<00:04,  6.17it/s][A
Validation DataLoader 0:  16%|█▌        | 5/32 [00:00<00:04,  5.82it/s][A
Validation DataLoader 0:  19%|█▉        | 6/32 [00:01<00:04,  5.42it/s][A
Validation DataLoader 0:  22%|██▏       | 7/32 [00:01<00:04,  5.07it/s][A
Validation DataLoader 0:  25%|██▌       | 8/32 [00:01<00:04,  4.94it/s][A
Validation DataLoader 0:  28%|██▊       | 9/32 [00:01<00:04,  4.87it/s][A
Validation DataLoader 0

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 90/90 [02:17<00:00,  1.53s/it, mse_loss_step=0.959, zinb_loss_step=3.440, bake_loss_step=1.5e-5, train_loss_step=1.820, val_loss=0.780, mse_loss_epoch=2.350, zinb_loss_epoch=14.60, bake_loss_epoch=3.01e-5, train_loss_epoch=6.000]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
SLURM auto-requeueing enabled. Setting signal handlers.


Testing DataLoader 0: 100%|██████████| 25/25 [00:06<00:00,  3.64it/s]


Training time:  0.12267369204097324  hours
