In [None]:
import scgpt as scg
import torch as tc
import numpy as np
import pandas as pd
import scanpy as sc
from pathlib import Path
from torch.autograd import Function
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn 
from torch.optim import Adam
import torch.nn.functional as F
from tqdm import tqdm


In [None]:
# Reference table; its column layout is what the new data is aligned to
# further down. The four tier columns are removed right after loading.
ground_truth_genes = (
    pd.read_parquet('../data/MERGED_normalized_5000genesTIERS.parquet')
    .drop(columns=['Tier_1', 'Tier_2', 'Tier_3', 'Tier_4'])
)
ground_truth_genes

In [None]:
# Data set to embed. While debugging, append `.iloc[:500, :]` to the read
# call to work on the first 500 rows only.
dat_raw = pd.read_parquet(
    '../data/Metastasized_5000genesTIERS.parquet'
)
dat_raw

In [None]:
# Bring `dat_raw` into the column layout of the reference table:
#   * reference columns come first, in reference order; columns missing from
#     `dat_raw` are created and filled with NaN,
#   * columns that exist only in `dat_raw` are kept and appended afterwards.
# An explicit reindex is used instead of concatenating onto an empty frame,
# which relies on deprecated pandas behaviour (FutureWarning, and result
# dtypes that depend on the pandas version). Row index and values of
# `dat_raw` are left untouched.
aligned_columns = ground_truth_genes.columns.union(dat_raw.columns, sort=False)
df_new_aligned = dat_raw.reindex(columns=aligned_columns)
df_new_aligned


In [None]:
# Everything from the third column onwards is used as model input below
# (dataset construction and the model's input width both read from `dat`).
# NOTE(review): presumably the first two columns are metadata/identifiers — confirm.
dat = df_new_aligned.iloc[:,2:]
dat.head()

In [None]:
class DS(Dataset):
    """Per-patient dataset: one item is a random selection of one patient's rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Numeric table, one row per sample. Values are cast to float32 and
        NaNs are replaced by 0.
    patients : pandas.Series
        Patient label for every row of ``df`` (same length and order).
    max_rows : int, default 2000
        Upper bound on the number of rows returned per item. Patients with
        more rows are randomly subsampled without replacement.
    """

    def __init__(self, df, patients, max_rows=2000):
        self.df = df
        self.max_rows = max_rows

        self.patients = np.array(patients)
        # Order of first appearance defines the patient -> item index mapping.
        self.unique_patients = patients.drop_duplicates().to_numpy()

        self.data_tensor = tc.FloatTensor((self.df.values.astype(float)))
        self.data_tensor = tc.nan_to_num(self.data_tensor, 0.0)

        # Pre-split the rows by patient once so __getitem__ is a dict lookup.
        self.data_dict = {
            patient: self.data_tensor[self.patients == patient, :]
            for patient in tqdm(self.unique_patients[:])
        }

        n_patients = self.unique_patients.shape[0]
        self.patient_ids_dict = {
            patient: self.make_one_hot(n_patients, i)
            for i, patient in enumerate(self.unique_patients)
        }

    def __len__(self):
        """Number of distinct patients."""
        return len(self.data_dict)

    def __getitem__(self, idx):
        """Return ``(rows, labels)`` for the ``idx``-th patient.

        ``rows`` holds the patient's rows in random order, truncated to at
        most ``max_rows``. ``labels`` repeats the patient's one-hot vector
        once per returned row.
        """
        patient = self.unique_patients[idx]
        current_data_tensor = self.data_dict[patient]
        nsamples = current_data_tensor.shape[0]

        # A single permutation covers both cases: slicing is a no-op when the
        # patient has no more than `max_rows` rows.
        n_keep = min(nsamples, self.max_rows)
        selection = tc.randperm(nsamples)[:n_keep]

        # (n_patients,) * (n_keep, 1) broadcasts to (n_keep, n_patients).
        labels = self.patient_ids_dict[patient] * tc.ones(n_keep, 1)
        return current_data_tensor[selection, :], labels

    def make_one_hot(self, length, idx):
        """Return a float vector of ``length`` zeros with a 1 at position ``idx``."""
        x = tc.zeros(length)
        x[idx] = 1
        return x





In [None]:
class ReverseLayerF(Function):
    """Autograd function that is the identity going forward and flips the
    sign of the gradient (scaled by ``alpha``) going backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Remember the scaling factor; it is only needed in backward().
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = -(grad_output * ctx.alpha)
        # `alpha` is not differentiable, hence the None for its gradient.
        return reversed_grad, None
    
    

class Backbone(nn.Module):
    """Fully connected network mapping ``inp`` features to ``outp`` outputs.

    Layout: three Linear -> BatchNorm1d -> LeakyReLU -> Dropout(0.2) stages of
    width ``hidden``, then Linear(hidden, 1000) -> LeakyReLU -> Dropout(0.2)
    -> Linear(1000, outp). No dropout is applied to the raw input (an input
    dropout layer is disabled in this version).
    """

    def __init__(self, inp, hidden, outp):
        super().__init__()

        # Modules are created and appended in execution order, so their
        # positions inside the Sequential (and thus the state-dict keys)
        # are fixed by this order.
        stack = []
        width_in = inp
        for _ in range(3):
            stack += [
                nn.Linear(width_in, hidden),
                nn.BatchNorm1d(hidden),
                nn.LeakyReLU(),
                nn.Dropout(0.2),
            ]
            width_in = hidden
        stack += [
            nn.Linear(hidden, 1000),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(1000, outp),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        return self.layers(x)
        


class Model(nn.Module):
    """Backbone with two heads: a projector used by the invariance /
    variance / covariance loss in ``forward`` and a patient classifier used
    by the adversarial loss in ``dal``.

    Parameters
    ----------
    inp : int
        Number of input features.
    hidden : int
        Hidden width passed on to the backbone.
    npatients : int
        Number of outputs of the classifier head.
    """

    def __init__(self, inp, hidden, npatients):
        super().__init__()

        self.npatients = npatients
        self.backbone = Backbone(inp, hidden, 64)

        self.projector = nn.Sequential(
            nn.Linear(64, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 512),
        )

        self.classifier = nn.Sequential(
            nn.Linear(64, 512),
            nn.LeakyReLU(),
            nn.Linear(512, self.npatients),
        )

    @staticmethod
    def _off_diagonal(arr):
        """Flattened off-diagonal entries of a square matrix."""
        n = len(arr)
        return arr.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

    @staticmethod
    def _variance_term(z):
        """Half of the mean hinge ``relu(1 - std)`` over the columns of ``z``."""
        std = tc.sqrt(z.var(dim=0) + 0.0001)
        return tc.mean(F.relu(1 - std)) / 2

    def _covariance_term(self, z):
        """Sum of squared off-diagonal covariances of centred ``z``, over 256."""
        cov = (z.T @ z) / (z.shape[0] - 1)
        # NOTE(review): the divisor is 256 while the projector outputs 512
        # features — confirm this scaling is intended.
        return self._off_diagonal(cov).pow(2).sum() / 256

    def forward(self, x, y):
        """Return the scalar loss for two batches ``x`` and ``y``.

        Both inputs go through backbone and projector. The loss is
        ``25 * invariance + 25 * variance + 1 * covariance``.
        """
        z_x = self.projector(self.backbone(x))
        z_y = self.projector(self.backbone(y))

        # invariance: the two projections should agree
        repr_loss = F.mse_loss(z_x, z_y)

        # centre each column before the variance / covariance terms
        z_x = z_x - z_x.mean(dim=0)
        z_y = z_y - z_y.mean(dim=0)

        # variance: push every column's std towards at least 1
        std_loss = self._variance_term(z_x) + self._variance_term(z_y)

        # covariance: decorrelate different columns
        cov_loss = self._covariance_term(z_x) + self._covariance_term(z_y)

        return 25. * repr_loss + 25. * std_loss + 1. * cov_loss

    def dal(self, x, patient_y):
        """Cross-entropy of the patient classifier on gradient-reversed
        backbone features (reversal factor 1.0)."""
        features = self.backbone(x)
        logits = self.classifier(ReverseLayerF.apply(features, 1.0))
        return F.cross_entropy(logits, patient_y)



In [None]:
# Per-patient dataset: rows of `dat` grouped by the 'Pseudo' column of `dat_raw`.
ds = DS(
    dat,
    dat_raw['Pseudo'],
)
#model = Model(dat.shape[1], 5000, npatients = ds.unique_patients.shape[0])
# One patient per batch, patients visited in random order.
dl = DataLoader(
    ds,
    batch_size=1,
    shuffle=True,
)


In [None]:
model = Model(dat.shape[1], 5000, npatients=365)  # npatients: only from old training approach to include domain adversarial loss

# Restore the trained weights on CPU.
# NOTE(review): tc.load unpickles the file — only load checkpoints from a
# trusted source.
checkpoint_path = './save/vicreg/model_per_patient_batchnorm_highinputdropout5000.pt'
model.load_state_dict(tc.load(checkpoint_path, map_location=tc.device('cpu')))


In [None]:
class DS(Dataset):
    """Row-wise dataset: item ``i`` is row ``i`` of ``df`` as a float32 tensor.

    NaNs in ``df`` are replaced by 0. Note that this definition replaces the
    earlier per-patient class of the same name.
    """

    def __init__(self, df):
        self.df = df

        as_float = self.df.values.astype(float)
        # float64 -> float32, then blank out missing values
        self.data_tensor = tc.nan_to_num(tc.from_numpy(as_float).float(), 0.0)

    def __len__(self):
        return len(self.data_tensor)

    def __getitem__(self, idx):
        return self.data_tensor[idx]

# Rebinds `ds` (previously the per-patient dataset) to the row-wise dataset
# over `dat` that the embedding DataLoader below iterates.
ds = DS(dat)


In [None]:
# NOTE(review): this import repeats the one in the first cell of the notebook.
from torch.utils.data import Dataset, DataLoader
# Inference runs on the CPU.
device = tc.device('cpu')

In [None]:
# Unshuffled batches, so the embedded rows come out in the order of `ds`.
embedding_dl = DataLoader(
    ds,
    batch_size=5000,
    shuffle=False,
)

In [None]:
# Inference mode: eval() switches BatchNorm to running statistics and
# disables Dropout; no_grad() avoids building the autograd graph.
model.eval().to(device)
with tc.no_grad():
    # Call the module itself rather than `.forward` so registered hooks run.
    batch_embeddings = [model.backbone(X.to(device)).cpu() for X in tqdm(embedding_dl)]

# One row per input row, one column per backbone output dimension.
embeddings = pd.DataFrame(tc.cat(batch_embeddings, dim=0).numpy())

# Name the embedding dimensions D0, D1, ...
embeddings.columns = ['D' + str(column) for column in embeddings.columns]



In [None]:
# Attach the first two columns of `dat_raw` to the embeddings row by row.
# `pd.concat(axis=1)` aligns on index labels, and `embeddings` carries a fresh
# 0..n-1 index, so the embeddings are given `dat_raw`'s index first. Without
# this, a non-default index on `dat_raw` would misalign rows and create NaNs.
# NOTE(review): the features were sliced from `df_new_aligned`, while these
# two columns are taken from `dat_raw` — confirm both share the same leading
# columns.
assert len(embeddings) == len(dat_raw), "one embedding row per input row expected"
embeddings_frame = pd.concat(
    (dat_raw.iloc[:, :2], embeddings.set_axis(dat_raw.index, axis=0)),
    axis=1,
)

# Make sure the target directory exists before writing.
output_path = Path('./embeddings/VICREG_embedding_validation_external.parquet')
output_path.parent.mkdir(parents=True, exist_ok=True)
embeddings_frame.to_parquet(output_path)
embeddings_frame