In [1]:
import anndata
import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import pytorch_lightning as pl
import scanpy as sc
import scprep as scp
import seaborn as sns
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from collections import defaultdict as dfd
from copy import deepcopy as dcp
from pathlib import Path, PurePath
from pytorch_lightning.loggers import TensorBoardLogger
from scipy import stats
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
from tqdm import tqdm
from PIL import ImageFile, Image
from matplotlib.image import imread
from scanpy import read_10x_mtx, read_visium
from torch.utils.data import DataLoader
from window_adata import *


In [2]:
""" Integrate two visium datasets """
data_dir1 = "./Alex_NatGen_6BreastCancer/"
data_dir2 = "./breast_cancer_10x_visium/"

samps1 = ["1142243F", "CID4290", "CID4465", "CID44971", "CID4535", "1160920F"]
samps2 = ["block1", "block2", "FFPE"]

sampsall = samps1 + samps2
samples1 = {i:data_dir1 + i for i in samps1}
samples2 = {i:data_dir2 + i for i in samps2}

# Marker gene list
gene_list = ["COX6C","TTLL12", "HSP90AB1", "TFF3", "ATP1A1", "B2M", "FASN", "SPARC", "CD74", "CD63", "CD24", "CD81"]

# # Load windowed dataset
import pickle
with open('10x_visium_dataset_without_window.pickle', 'rb') as f:
    adata_dict0 = pickle.load(f)
    
# Define the gridding size
sizes = [4000 for i in range(len(adata_dict0))]

# Split tiles into smaller patches according to gridding size
adata_dict = window_adata(adata_dict0, sizes)


Windowing 1142243F
Num spots:  4784
246
216
222
185
77
247
246
255
247
93
246
247
255
245
94
247
246
255
238
88
130
135
132
140
55
Total:  4787
Windowing CID4290
Num spots:  2714
793
576
1001
344
Total:  2714
Windowing CID4465
Num spots:  1310
345
258
149
558
Total:  1310
Windowing CID44971
Num spots:  1322
491
462
339
30
Total:  1322
Windowing CID4535
Num spots:  1431
564
232
632
3
Total:  1431
Windowing 1160920F
Num spots:  4895
210
251
251
239
83
226
255
232
240
102
231
230
246
247
99
238
246
247
255
93
144
147
164
160
60
Total:  4896
Windowing block1
Num spots:  3798
139
205
219
185
10
169
246
247
255
16
189
230
205
233
0
197
156
241
228
0
72
106
129
124
0
Total:  3801
Windowing block2
Num spots:  3987
224
247
246
229
208
246
247
231
243
207
211
196
221
195
254
205
81
97
108
92
Total:  3988
Windowing FFPE
Num spots:  2518
50
190
188
79
0
169
219
216
189
0
182
215
201
192
0
68
138
159
63
0
0
1
0
0
0
Total:  2519


In [3]:
# For training
from data_vit import ViT_Anndata

def dataset_wrap(fold = 0, train=True, dataloader=True):
    test_sample = sampsall[fold]
    test_sample_orig = sampsall[fold] # Split one sample as test sample
    val_sample = list(set(sampsall)-set(sampsall[fold]))[:3] # Split 3 samples as validation samples
    train_sample = list(set(sampsall)-set(test_sample)-set(val_sample)) # Other samples are training samples

    tr_name = list(set([i for i in list(adata_dict.keys()) for tr in train_sample if tr in i]))
    val_name = list(set([i for i in list(adata_dict.keys()) for val in val_sample if val in i]))
#     te_name = list(set([i for i in list(adata_dict.keys()) if test_sample in i]))
    te_name = test_sample
    
    if train:
        print("LOADED TRAINSET")
        trainset = ViT_Anndata(adata_dict = adata_dict, train_set = tr_name, gene_list = gene_list, train=True, flatten=False, ori=True, prune='NA', neighs=4, )
        valset = ViT_Anndata(adata_dict = adata_dict, train_set = val_name, gene_list = gene_list, train=True, flatten=False, ori=True, prune='NA', neighs=4, )
        train_loader = DataLoader(trainset, batch_size=1, num_workers=0, shuffle=True)
        val_loader = DataLoader(valset, batch_size=1, num_workers=0, shuffle=False)
        return train_loader, val_loader
    
    else:
        print("LOADED TESTSET")
        print("Test sample", te_name)
        testset = ViT_Anndata(adata_dict = adata_dict0, train_set = [te_name], gene_list = gene_list, train=False, flatten=False, ori=True, prune='NA', neighs=4, )
        test_loader = DataLoader(testset, batch_size=1, num_workers=0, shuffle=False)
        return test_loader
            

In [4]:
# fold = 0
# name = "VICReg"
# train_loader, val_loader = dataset_wrap(fold = fold, train=True, dataloader= True)
# test_loader = dataset_wrap(fold = fold, train=False, dataloader= True)


In [5]:
import warnings
warnings.filterwarnings("ignore")
import torch
import random
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn

from gcn import *
from transformer import *
from NB_module import *
from scipy.stats import pearsonr
from torch.utils.data import DataLoader
from copy import deepcopy as dcp
from collections import defaultdict as dfd
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score

from torchvision import models, transforms

# set random seed
def setup_seed(seed=12000):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    
def get_R(data1,data2,dim=1,func=pearsonr):
    adata1=data1.X
    adata2=data2.X
    r1,p1=[],[]
    for g in range(data1.shape[dim]):
        if dim==1:
            r,pv=func(adata1[:,g],adata2[:,g])
        elif dim==0:
            r,pv=func(adata1[g,:],adata2[g,:])
        r1.append(r)
        p1.append(pv)
    r1=np.array(r1)
    p1=np.array(p1)
    return r1,p1


class GATLayer(nn.Module):
    def __init__(self, c_in, c_out, num_heads=2, concat_heads=True, alpha=0.2):
        """
        Args:
            c_in: Dimensionality of input features
            c_out: Dimensionality of output features
            num_heads: Number of heads, i.e. attention mechanisms to apply in parallel. The
                        output features are equally split up over the heads if concat_heads=True.
            concat_heads: If True, the output of the different heads is concatenated instead of averaged.
            alpha: Negative slope of the LeakyReLU activation.
        """
        super().__init__()
        self.num_heads = num_heads
        self.concat_heads = concat_heads
        if self.concat_heads:
            assert c_out % num_heads == 0, "Number of output features must be a multiple of the count of heads."
            c_out = c_out // num_heads

        # Sub-modules and parameters needed in the layer
        self.projection = nn.Linear(c_in, c_out * num_heads)
        self.a = nn.Parameter(torch.Tensor(num_heads, 2 * c_out))  # One per head
        self.leakyrelu = nn.LeakyReLU(alpha)

        # Initialization from the original implementation
        nn.init.xavier_uniform_(self.projection.weight.data, gain=1.414)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

    def forward(self, node_feats, adj_matrix, print_attn_probs=False):
        """
        Args:
            node_feats: Input features of the node. Shape: [batch_size, c_in]
            adj_matrix: Adjacency matrix including self-connections. Shape: [batch_size, num_nodes, num_nodes]
            print_attn_probs: If True, the attention weights are printed during the forward pass
                               (for debugging purposes)
        """
        batch_size, num_nodes = node_feats.size(0), node_feats.size(1)

        # Apply linear layer and sort nodes by head
        node_feats = self.projection(node_feats)
        node_feats = node_feats.view(batch_size, num_nodes, self.num_heads, -1)

        # We need to calculate the attention logits for every edge in the adjacency matrix
        # Doing this on all possible combinations of nodes is very expensive
        # => Create a tensor of [W*h_i||W*h_j] with i and j being the indices of all edges
        # Returns indices where the adjacency matrix is not 0 => edges
        edges = adj_matrix.nonzero(as_tuple=False)
        node_feats_flat = node_feats.view(batch_size * num_nodes, self.num_heads, -1)
        edge_indices_row = edges[:, 0] * num_nodes + edges[:, 1]
        edge_indices_col = edges[:, 0] * num_nodes + edges[:, 2]
        a_input = torch.cat(
            [
                torch.index_select(input=node_feats_flat, index=edge_indices_row, dim=0),
                torch.index_select(input=node_feats_flat, index=edge_indices_col, dim=0),
            ],
            dim=-1,
        )  # Index select returns a tensor with node_feats_flat being indexed at the desired positions

        # Calculate attention MLP output (independent for each head)
        attn_logits = torch.einsum("bhc,hc->bh", a_input, self.a)
        attn_logits = self.leakyrelu(attn_logits)

        # Map list of attention values back into a matrix
        attn_matrix = attn_logits.new_zeros(adj_matrix.shape + (self.num_heads,)).fill_(-9e15)
        attn_matrix[adj_matrix[..., None].repeat(1, 1, 1, self.num_heads) == 1] = attn_logits.reshape(-1)

        # Weighted average of attention
        attn_probs = F.softmax(attn_matrix, dim=2)
        if print_attn_probs:
            print("Attention probs\n", attn_probs.permute(0, 3, 1, 2))
        atten = attn_probs.permute(0, 3, 1, 2)
        node_feats = torch.einsum("bijh,bjhc->bihc", attn_probs, node_feats)

        # If heads should be concatenated, we can do this by reshaping. Otherwise, take mean
        if self.concat_heads:
            node_feats = node_feats.reshape(batch_size, num_nodes, -1)
        else:
            node_feats = node_feats.mean(dim=2)
        return node_feats, atten

class Feature_extractor(nn.Module):
    def __init__(self, name="resnet", dim=1024, num_layer=2, ):
        super().__init__()
        self.ft_extra = self.ft_extr(name)
        self.emb_dim = self.return_emb(name)
        self.projection = nn.Sequential(nn.Linear(self.emb_dim, dim))
        self.GATLayer=GATLayer(c_in=dim, c_out=dim)
        self.GAT=nn.ModuleList([self.GATLayer for _ in range(num_layer)])
        self.jknet=nn.Sequential(
            nn.LSTM(dim,dim,2),
            SelectItem(0),
        )
        self.dropout = nn.Dropout(0.2)

    def forward(self,patch,adj):
        # Resize the tiles

        x = transforms.Resize(224)(patch)
        x = self.ft_extra(x)
        x = self.projection(x)
        x = self.dropout(x).unsqueeze(0)
        
        # GAT with layer-aggregation
        jk=[]
        for layer in self.GAT:
            x, attn=layer(x,adj)
            jk.append(x)
        x=torch.cat(jk,0)
        
        # Jumping knowledge-LSTM
        x=self.jknet(x).mean(0)
        return x, attn
    
    # Define a function to fine-tune a pretrained model
    def ft_extr(self, name):
        if name == "resnet":
            model_ft = torchvision.models.resnet50(weights=models.ResNet50_Weights)
            for param in model_ft.parameters():
                param.requires_grad = False
            model_ft.fc = nn.Sequential(nn.Identity())
            dim = 2048
        elif name == "efficient":
            model_ft = torchvision.models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights)
            for param in model_ft.parameters():
                param.requires_grad = False
            model_ft.classifier = nn.Sequential(nn.Identity())
            dim = 1280
        elif name == 'swin':
            model_ft = torchvision.models.swin_s(weights=models.Swin_S_Weights)
            for param in model_ft.parameters():
                param.requires_grad = False
            model_ft.head = nn.Sequential(nn.Identity())
            dim = 768
        return model_ft
    
    def return_emb(self, name):
        dim=2048
        if name=="resnet":
            dim=2048
        elif name=="efficient":
            dim=1280
        elif name=="swin":
            dim=768
        return dim

class CNN_GAT(pl.LightningModule):
    def __init__(self, learning_rate=1e-4, name="resnet", dim=1024, n_genes=12, num_layer=4, 
                 zinb=0.25, nb=False, policy='mean', bake=5, lamb=0.5, 
                ):
        super().__init__()
#          self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.nb=nb
        self.zinb=zinb
        self.bake=bake
        self.lamb=lamb
        
        
        """ Feature Extractor """
        self.vit = Feature_extractor(name=name, num_layer=num_layer)
        
        self.n_genes=n_genes
        """ ZINB Loss """
        if self.zinb>0:
            if self.nb:
                self.hr=nn.Linear(dim, n_genes)
                self.hp=nn.Linear(dim, n_genes)
            else:
                self.mean = nn.Sequential(nn.Linear(dim, n_genes), MeanAct())
                self.disp = nn.Sequential(nn.Linear(dim, n_genes), DispAct())
                self.pi = nn.Sequential(nn.Linear(dim, n_genes), nn.Sigmoid())

        """Data Augmentation"""
        self.coef=nn.Sequential(
                nn.Linear(dim,dim),
                nn.ReLU(),
                nn.Linear(dim,1),
            )
        self.imaug=transforms.Compose([
            transforms.RandomGrayscale(0.1),
            transforms.RandomRotation(90),
            transforms.RandomHorizontalFlip(0.2),
        ])
        """ Regression Module """
        self.gene_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, n_genes),
        )
        
    def forward(self, patch, adj):
        """ Feature Extraction """
        h, attn = self.vit(patch, adj)
        
        """ Gene expression prediction """
        x = self.gene_head(h)
        
        """ZINB Distribution"""
        extra=None
        if self.zinb>0:
            if self.nb:
                r=self.hr(h)
                p=self.hp(h)
                extra=(r,p)
            else:
                m = self.mean(h)
                d = self.disp(h)
                p = self.pi(h)
                extra=(m,d,p) 
        return x,extra,h,attn


    def training_step(self, batch, batch_idx):
        patch, center, exp, adj, oris, sfs, *_ = batch
        patch, exp = patch.squeeze(0), exp.squeeze(0)

        """ Model inference """
        pred,extra,h,attn = self(patch, adj)

        """ Regression Loss """
        mse_loss = F.mse_loss(pred, exp)
        self.log('mse_loss', mse_loss,on_epoch=True, prog_bar=True, logger=True)

        """ ZINB Loss """
        zinb_loss=0
        if self.zinb>0:
            if self.nb:
                r,p=extra
                zinb_loss = NB_loss(oris.squeeze(0),r,p)
            else:
                m,d,p=extra
                zinb_loss = ZINB_loss(oris.squeeze(0),m,d,p,sfs.squeeze(0))
        self.log('zinb_loss', zinb_loss,on_epoch=True, prog_bar=True, logger=True)
        
        """ Total Loss """
        loss = mse_loss + self.zinb*zinb_loss
        self.log('train_loss', loss,on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        patch, center, exp, adj, oris, sfs, *_ = batch
        patch, exp = patch.squeeze(0), exp.squeeze(0)

        """ Model Inference """
        pred,extra,h,attn = self(patch, adj)

        """ Regression Loss """
        mse_loss = F.mse_loss(pred, exp)
        self.log('val_loss', mse_loss,on_epoch=True, prog_bar=True, logger=True)
        return mse_loss

    
    def test_step(self, batch, batch_idx):
        patch, center, exp, adj, oris, sfs, *_ = batch
        patch, exp = patch.squeeze(0), exp.squeeze(0)

        """ Model Inference """
        pred,extra,h,attn = self(patch, adj)

        """Pearson correlation coeficient"""
        adata1 = ad.AnnData(pred.cpu().detach().numpy())
        adata2 = ad.AnnData(exp.cpu().detach().numpy())
        R=get_R(adata1,adata2)[0]
        mean_pcc=np.nanmean(R)
        self.log('test_mean_PCC', mean_pcc, on_epoch=True, prog_bar=True, logger=True)
        return R
    
    def configure_optimizers(self):
        # self.hparams available because we called self.save_hyperparameters()
        optim=torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        optim_dict = {'optimizer': optim}
        return optim_dict


[easydl] tensorflow not available!


In [6]:
"""For training only"""
import gc
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import time

start_time = time.time()
"""Training loops"""
seed=12000
epochs=1
dim=1024
name = "resnet"   # ["swin", "efficient", "resnet"]
fold=1

"""Load dataset"""
train_loader, val_loader = dataset_wrap(fold=fold, train=True, dataloader=True)
test_loader = dataset_wrap(fold=fold, train=False, dataloader=True)


LOADED TRAINSET
Loading imgs...
Loading imgs...
LOADED TESTSET
Test sample CID4290
Loading imgs...


In [7]:
"""Define model"""
model = CNN_GAT(name=name, dim=dim)
setup_seed(seed)

"""Setup trainer"""
logger = pl.loggers.CSVLogger("logs", name=f"./{name}_fold{fold}")
trainer = pl.Trainer(accelerator='auto',  callbacks=[EarlyStopping(monitor='val_loss',mode='min')], max_epochs=epochs,logger=False)
trainer.fit(model, train_loader, val_loader)
#     torch.save(model.state_dict(),f"./model/{name}-seed{seed}-epochs{epochs}-sampleIndex{fold}.ckpt")
trainer = pl.Trainer(accelerator='cpu', logger=False)
trainer.test(model, test_loader)

"""Save model and clean memory"""
gc.collect()
end_time = time.time()
execution_time = end_time - start_time
print("Training time: ", execution_time/3600, " hours")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params
------------------------------------------------
0 | vit       | Feature_extractor | 43.5 M
1 | mean      | Sequential        | 12.3 K
2 | disp      | Sequential        | 12.3 K
3 | pi        | Sequential        | 12.3 K
4 | coef      | Sequential        | 1.1 M 
5 | gene_head | Sequential        | 14.3 K
------------------------------------------------
21.0 M    Trainable params
23.5 M    Non-trainable params
44.6 M    Total params
178.213   Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Epoch 0: 100%|██████████| 77/77 [00:29<00:00,  2.58it/s, mse_loss_step=0.770, zinb_loss_step=3.690, train_loss_step=1.690]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/45 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/45 [00:00<?, ?it/s][A
Validation DataLoader 0:   2%|▏         | 1/45 [00:00<00:29,  1.50it/s][A
Validation DataLoader 0:   4%|▍         | 2/45 [00:00<00:17,  2.48it/s][A
Validation DataLoader 0:   7%|▋         | 3/45 [00:00<00:13,  3.16it/s][A
Validation DataLoader 0:   9%|▉         | 4/45 [00:01<00:12,  3.40it/s][A
Validation DataLoader 0:  11%|█         | 5/45 [00:01<00:12,  3.14it/s][A
Validation DataLoader 0:  13%|█▎        | 6/45 [00:02<00:13,  2.85it/s][A
Validation DataLoader 0:  16%|█▌        | 7/45 [00:02<00:12,  2.98it/s][A
Validation DataLoader 0:  18%|█▊        | 8/45 [00:02<00:12,  2.95it/s][A
Validation DataLoader 0:  20%|██        | 9/45 [00:02<00:11,  3.05it/s][A
Validation DataLoader 0:  22%|██▏       | 10/45

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 77/77 [00:46<00:00,  1.67it/s, mse_loss_step=0.770, zinb_loss_step=3.690, train_loss_step=1.690, val_loss=0.817, mse_loss_epoch=0.794, zinb_loss_epoch=8.050, train_loss_epoch=2.810]

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs





SLURM auto-requeueing enabled. Setting signal handlers.


Testing DataLoader 0: 100%|██████████| 1/1 [00:59<00:00, 59.60s/it]


Training time:  0.04993143618106842  hours


###### 