In [5]:
import os
import gc
import random
import time
import argparse
import pickle as pk
import numpy as np
import pandas as pd
import h5py
import json
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import TensorBoardLogger
from tqdm import tqdm
from scipy.stats import pearsonr, spearmanr
from scipy import stats
from copy import deepcopy as dcp
from collections import defaultdict as dfd
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score
from pathlib import Path, PurePath
from matplotlib.image import imread
from PIL import ImageFile, Image
from anndata import AnnData, read_csv, read_text, read_excel, read_mtx, read_loom, read_hdf
from anndata import (
    AnnData,
    read_csv,
    read_text,
    read_excel,
    read_mtx,
    read_loom,
    read_hdf,
)
from window_adata import window_adata
from HIST2ST import Hist2ST
from utils import *

# For OOD dataset training
from data_vit import ViT_Anndata

# Define data directories and sample lists
data_dir1 = "./Alex_NatGen_6BreastCancer/"
data_dir2 = "./breast_cancer_10x_visium/"

# samps1: Alex et al. visium samples (train/validation split in dataset_wrap)
# samps2: 10x visium samples (test split in dataset_wrap)
samps1 = ["1142243F", "CID4290", "CID4465", "CID44971", "CID4535", "1160920F"]
samps2 = ["block1", "block2", "FFPE"]

sampsall = samps1 + samps2
# data_dir1/data_dir2 already end with "/", so plain concatenation gives a path.
samples1 = {i: data_dir1 + i for i in samps1}
samples2 = {i: data_dir2 + i for i in samps2}

# Target marker genes (the model below is built with n_genes=len(gene_list)).
gene_list = [
    "COX6C", "TTLL12", "HSP90AB1", "TFF3", "ATP1A1", "B2M", "FASN", "SPARC", "CD74", "CD63", "CD24", "CD81"
]

# Load the *un-windowed* per-sample dataset dict (the file is the
# "without_window" pickle); windowing into sub-tiles happens further below.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only ever point this at trusted, locally produced files.
with open("../10x_visium_dataset_without_window.pickle", "rb") as f:
    adata_dict0 = pickle.load(f)

# Only the 10x samples get their var names de-duplicated.
# NOTE(review): presumably those objects carry duplicate gene symbols — confirm.
for i in samps2:
    adata_dict0[i].var_names_make_unique()

# Gridding (window) size handed to window_adata: one entry per sample in
# adata_dict0, which may contain more samples than samps1 + samps2.
GRID_SIZE = 4000
sizes = [GRID_SIZE] * len(adata_dict0)

# Split tiles into smaller patches according to gridding size
adata_dict = window_adata(adata_dict0, sizes)

def setup_seed(seed=12000):
    """Seed the Python, NumPy and torch RNGs and pin cuDNN to deterministic mode.

    Calls ``random.seed``, ``np.random.seed``, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed`` with the same value, then sets
    ``cudnn.deterministic = True`` and ``cudnn.benchmark = False``.

    Args:
        seed: integer applied to every generator (default 12000).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Define a function to calculate correlation between two datasets
def get_R(data1, data2, dim=1, func=pearsonr):
    """Correlate two AnnData-like objects column-by-column or row-by-row.

    For ``dim=1`` column ``g`` of ``data1.X`` is paired with column ``g`` of
    ``data2.X``; for ``dim=0`` rows are paired instead. ``func`` is applied
    to every pair.

    Args:
        data1, data2: objects exposing an ``X`` matrix of identical shape
            (e.g. AnnData). Sparse matrices (anything with ``toarray``) are
            densified first.
        dim: 1 to pair columns, 0 to pair rows.
        func: two-sample statistic returning ``(statistic, pvalue)``;
            defaults to ``scipy.stats.pearsonr``.

    Returns:
        ``(r, p)``: 1-D NumPy arrays holding one statistic and one p-value
        per column (``dim=1``) or per row (``dim=0``).

    Raises:
        ValueError: if ``dim`` is not 0 or 1, or the two ``X`` matrices
            differ in shape.
    """
    if dim not in (0, 1):
        # Any other value would match neither pairing branch; fail clearly.
        raise ValueError(f"dim must be 0 or 1, got {dim!r}")

    def _dense(mat):
        # Slicing a sparse matrix yields a 1xN matrix, which pearsonr rejects.
        return np.asarray(mat.toarray() if hasattr(mat, "toarray") else mat)

    x1 = _dense(data1.X)
    x2 = _dense(data2.X)
    if x1.shape != x2.shape:
        raise ValueError(f"shape mismatch: {x1.shape} vs {x2.shape}")

    if dim == 0:
        # Pairing rows is the same as pairing columns of the transposes.
        x1, x2 = x1.T, x2.T

    r1, p1 = [], []
    for g in range(x1.shape[1]):
        r, pv = func(x1[:, g], x2[:, g])
        r1.append(r)
        p1.append(pv)
    return np.array(r1), np.array(p1)

def _match_names(keys, samples):
    """Return, sorted, every key that contains any of ``samples`` as a substring."""
    return sorted({k for k in keys for s in samples if s in k})


# Load and prepare datasets for training
def dataset_wrap(dataloader=True):
    """Build the train / validation / test splits for OOD evaluation.

    Train: Alex visium samples (``samps1``) minus the validation samples.
    Validation: Alex visium samples "1160920F" and "CID4290".
    Test: 10x visium samples (``samps2``).
    A key of the windowed ``adata_dict`` joins a split when it contains one
    of that split's sample names as a substring.

    Args:
        dataloader: if True return ``DataLoader`` objects (batch size 1,
            train shuffled), otherwise the underlying ``ViT_Anndata`` datasets.

    Returns:
        ``(train, val, test)`` as loaders or datasets depending on
        ``dataloader``.
    """
    val_sample = ["1160920F", "CID4290"]  # Alex visium samples
    # sorted() instead of list(set(...)): string-set order changes between
    # interpreter runs (hash randomisation), which would make dataset order —
    # and therefore training — non-reproducible despite setup_seed().
    train_sample = sorted(set(samps1) - set(val_sample))  # Alex visium samples
    test_sample = samps2  # 10x visium samples

    keys = list(adata_dict.keys())
    tr_name = _match_names(keys, train_sample)
    val_name = _match_names(keys, val_sample)
    te_name = _match_names(keys, test_sample)

    trainset = ViT_Anndata(adata_dict=adata_dict, train_set=tr_name, gene_list=gene_list)
    valset = ViT_Anndata(adata_dict=adata_dict, train_set=val_name, gene_list=gene_list)
    testset = ViT_Anndata(adata_dict=adata_dict, train_set=te_name, gene_list=gene_list)

    print("LOADED TRAINSET")
    if not dataloader:
        return trainset, valset, testset

    train_loader = DataLoader(trainset, batch_size=1, num_workers=0, shuffle=True)
    # Evaluation order does not need randomising; Lightning warns when a
    # validation loader shuffles.
    val_loader = DataLoader(valset, batch_size=1, num_workers=0, shuffle=False)
    test_loader = DataLoader(testset, batch_size=1, num_workers=0, shuffle=False)
    return train_loader, val_loader, test_loader

# Training parameters
seed = 12000
epochs = 350

# Seed every RNG *before* building the datasets and the model so that the
# model's weight initialisation and the loader shuffling are reproducible.
setup_seed(seed)

# Load datasets (trainer.fit below needs train_loader / val_loader).
train_loader, val_loader, test_loader = dataset_wrap(dataloader=True)

# Define the Hist2ST model
model = Hist2ST(
    depth1=2, depth2=8, depth3=4, n_pos=128,
    n_genes=len(gene_list), learning_rate=1e-5,
    kernel_size=5, patch_size=7, fig_size=112,
    heads=16, channel=32, dropout=0.2,
    zinb=0.25, nb=False,
    bake=5, lamb=0.5,
    policy='mean',
)

start_time = time.time()

# Setup the PyTorch Lightning trainer; early stopping monitors "Train_loss"
# (mode="min") with EarlyStopping's default patience.
trainer = pl.Trainer(
    accelerator='auto',
    callbacks=[EarlyStopping(monitor='Train_loss', mode='min')],
    max_epochs=epochs,
    logger=False,
)
trainer.fit(model, train_loader, val_loader)

# Stop the clock before cleanup so only training is measured.
end_time = time.time()
execution_time = end_time - start_time

# Free memory: drop the references first, *then* collect, and return cached
# CUDA blocks. No checkpoint is written here.
# NOTE(review): `trainer` presumably still references the model after fit;
# delete it as well if its GPU memory must be released.
del train_loader, val_loader, test_loader, model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("Training time:", execution_time/3600, "hours")


Windowing 1142243F
Num spots:  4784
246
216
222
185
77
247
246
255
247
93
246
247
255
245
94
247
246
255
238
88
130
135
132
140
55
Total:  4787
Windowing CID4290
Num spots:  2714
793
576
1001
344
Total:  2714
Windowing CID4465
Num spots:  1310
345
258
149
558
Total:  1310
Windowing CID44971
Num spots:  1322
491
462
339
30
Total:  1322
Windowing CID4535
Num spots:  1431
564
232
632
3
Total:  1431
Windowing 1160920F
Num spots:  4895
210
251
251
239
83
226
255
232
240
102
231
230
246
247
99
238
246
247
255
93
144
147
164
160
60
Total:  4896
Windowing block1
Num spots:  3798
139
205
219
185
10
169
246
247
255
16
189
230
205
233
0
197
156
241
228
0
72
106
129
124
0
Total:  3801
Windowing block2
Num spots:  3987
224
247
246
229
208
246
247
231
243
207
211
196
221
195
254
205
81
97
108
92
Total:  3988
Windowing FFPE
Num spots:  2518
50
190
188
79
0
169
219
216
189
0
182
215
201
192
0
68
138
159
63
0
0
1
0
0
0
Total:  2519
Windowing 1168993F
Num spots:  4898
244
249
248
244
68
246
247
255
246


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type       | Params
-----------------------------------------------
0 | patch_embedding | Conv2d     | 4.7 K 
1 | x_embed         | Embedding  | 131 K 
2 | y_embed         | Embedding  | 131 K 
3 | vit             | ViT        | 71.4 M
4 | mean            | Sequential | 12.3 K
5 | disp            | Sequential | 12.3 K
6 | pi              | Sequential | 12.3 K
7 | coef            | Sequential | 1.1 M 
8 | gene_head       | Sequential | 14.3 K
-----------------------------------------------
72.8 M    Trainable params
0         Non-trainable params
72.8 M    Total params
291.006   Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Epoch 0: 100%|██████████| 36/36 [00:10<00:00,  3.54it/s, mse_loss_step=2.000, bake_loss_step=0.000159, zinb_loss_step=3.290, Train_loss_step=2.820, mse_loss_epoch=4.200, bake_loss_epoch=0.000205, zinb_loss_epoch=5.670, Train_loss_epoch=5.620]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 36/36 [00:11<00:00,  3.10it/s, mse_loss_step=2.000, bake_loss_step=0.000159, zinb_loss_step=3.290, Train_loss_step=2.820, mse_loss_epoch=4.200, bake_loss_epoch=0.000205, zinb_loss_epoch=5.670, Train_loss_epoch=5.620]
Training time: 0.003986349370744493 hours
