In [None]:
import scanpy as sc

In [None]:
from torch.utils.data import TensorDataset, DataLoader
import torch

def make_dataset(tissue, sample, num_extra=0):
    """Build a tensor of the cells of one patient in one tissue, optionally padded.

    Reads the notebook-level ``processed_adata``; rows are first filtered on
    ``obs['patient'] == sample`` and then on ``obs['tissue'] == tissue``.

    Args:
        tissue: value matched against ``obs['tissue']``.
        sample: value matched against ``obs['patient']``.
        num_extra: number of extra rows, drawn with replacement from the
            selected rows, that are appended after the original rows.

    Returns:
        ``torch.Tensor`` built from the selected ``.X`` rows followed by the
        ``num_extra`` resampled rows.
    """
    patient_cells = processed_adata[processed_adata.obs['patient']==sample]
    selected_cells = patient_cells[patient_cells.obs['tissue']==tissue]

    base_rows = selected_cells.X
    # Indices of the rows that get duplicated as padding (sampling with replacement).
    resampled_idx = np.random.choice(base_rows.shape[0], num_extra, replace=True)
    padded_rows = np.concatenate([base_rows, base_rows[resampled_idx]], axis=0)
    return torch.from_numpy(padded_rows)


# Integer class label for each patient id: 'BC1' -> 0, ..., 'BC8' -> 7.
patient_to_label = dict(
    (f'BC{patient_number}', patient_number - 1)
    for patient_number in range(1, 9)
)


def make_data_loader(tissue, samples, with_labels=True, batch_size=1<<7):
    """Create a shuffled ``DataLoader`` over the cells of several patients in one tissue.

    Reads the notebook-level ``processed_adata`` and ``patient_to_label``.

    Args:
        tissue: value matched against ``obs['tissue']``.
        samples: sequence of patient ids; ``samples[0]`` comes first in the
            concatenated data, the remaining ids follow in their order.
        with_labels: when true, each batch also carries the per-cell patient
            label taken from ``patient_to_label`` (stored as floating point).
        batch_size: batch size handed to the ``DataLoader``.

    Returns:
        ``DataLoader`` over a ``TensorDataset`` of the expression rows (and the
        labels when requested), with shuffling and ``drop_last=True``.
    """
    tissue_cells = processed_adata[processed_adata.obs['tissue']==tissue]
    per_sample = {}
    for name in samples:
        per_sample[name] = tissue_cells[tissue_cells.obs['patient']==name]

    first_name = samples[0]
    first_cells = per_sample[first_name]
    # One constant-valued label vector per sample, in concatenation order.
    label_chunks = [np.zeros(first_cells.shape[0]) + patient_to_label[first_name]]

    other_cells = [cells for name, cells in per_sample.items() if name != first_name]
    merged = first_cells.concatenate(*other_cells)

    for name, cells in per_sample.items():
        if name != first_name:
            label_chunks.append(np.zeros(cells.shape[0]) + patient_to_label[name])
    labels = np.concatenate(label_chunks)

    tensors = [torch.from_numpy(merged.X)]
    if with_labels:
        tensors.append(torch.from_numpy(labels))

    return DataLoader(
        TensorDataset(*tensors),
        batch_size=batch_size,
        shuffle=True,
        num_workers=1,
        pin_memory=True,
        drop_last=True,
    )

def make_train_test_split_tensors(tissue, train_samples, test_samples):
    """Return ``(train_tensor, test_tensor)`` for one tissue, split by patient id.

    Reads the notebook-level ``processed_adata``. ``train_samples`` and
    ``test_samples`` are concatenated with ``+``, so both must be sequences of
    the same kind (e.g. lists).

    Args:
        tissue: value matched against ``obs['tissue']``.
        train_samples: patient ids whose cells form the training tensor.
        test_samples: patient ids whose cells form the test tensor.

    Returns:
        Two CPU tensors built from the ``.X`` of the merged train and test cells.
    """
    tissue_cells = processed_adata[processed_adata.obs['tissue']==tissue]
    per_sample = {
        name: tissue_cells[tissue_cells.obs['patient']==name]
        for name in train_samples + test_samples
    }

    def _merge(names):
        # Start from the first listed sample and append the others that belong to
        # this split, in the insertion order of ``per_sample``.
        head = names[0]
        merged = per_sample[head]
        if len(names) > 1:
            tail = [
                cells for name, cells in per_sample.items()
                if name != head and name in names
            ]
            merged = merged.concatenate(*tail)
        return merged

    train_cells = _merge(train_samples)
    test_cells = _merge(test_samples)
    train_tensor = torch.from_numpy(train_cells.X).to('cpu')
    test_tensor = torch.from_numpy(test_cells.X).to('cpu')
    return train_tensor, test_tensor
    
###### BELOW are more models and more losses

# we now have information regarding the spatiality of the VIM gene

import os


# no tensorflow, it was being very uncooperative
import torch
# import torchvision
# import torchsummary

import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as U
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader


import numpy as np
import scipy as sp
import pandas as pd

import sys
sys.path.extend([".", ".."])

def turn_on_model(model):
    """Enable gradient tracking on every parameter of ``model``."""
    for weight in model.parameters():
        weight.requires_grad_(True)
        
def turn_off_model(model):
    """Disable gradient tracking on every parameter of ``model`` (freeze it)."""
    for weight in model.parameters():
        weight.requires_grad_(False)

        
# better new arch    
class StandardEncoder(nn.Module):
    """MLP encoder producing a latent mean and a latent log-variance.

    The trunk ``part1`` is five ``Linear -> BatchNorm1d -> ReLU`` stages of
    width ``hidden_dim``; two separate linear heads map the trunk output to
    ``latent_dim`` values each.

    Args:
        input_dim: number of input features.
        latent_dim: size of each of the two outputs.
        hidden_dim: width of every hidden stage.
    """

    def __init__(self, input_dim, latent_dim, hidden_dim=512):
        super().__init__()

        # Only the first stage changes the width; the other four are hidden -> hidden.
        stage_inputs = [input_dim] + [hidden_dim] * 4
        stack = []
        for width_in in stage_inputs:
            stack += [
                nn.Linear(width_in, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
            ]
        self.part1 = nn.Sequential(*stack)

        self.to_mean = nn.Linear(hidden_dim, latent_dim)
        self.to_logvar = nn.Linear(hidden_dim, latent_dim)

        self.latent_dim = latent_dim

    def forward(self, x):
        """Return ``(mean, logvar)`` computed from the shared trunk features."""
        features = self.part1(x)
        mean = self.to_mean(features)
        logvar = self.to_logvar(features)
        return mean, logvar
    
    
class StandardDecoder(nn.Module):
    """MLP decoder mapping a latent vector back to ``input_dim`` features.

    ``net`` is five ``Linear -> BatchNorm1d -> ReLU`` stages of width
    ``hidden_dim`` followed by a linear projection to ``input_dim``. A final
    ``ReLU`` is appended unless ``no_final_relu`` is set.

    Args:
        input_dim: number of reconstructed features.
        latent_dim: size of the latent input.
        hidden_dim: width of every hidden stage.
        no_final_relu: when true, the output projection is left unbounded
            (the original notes describe this as the non-ZINB case).
    """

    def __init__(self, input_dim, latent_dim, hidden_dim=512, no_final_relu=False):
        super().__init__()

        stage_inputs = [latent_dim] + [hidden_dim] * 4
        stack = []
        for width_in in stage_inputs:
            stack += [
                nn.Linear(width_in, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
            ]
        stack.append(nn.Linear(hidden_dim, input_dim))
        if not no_final_relu:
            # Clamp the reconstruction to be non-negative.
            stack.append(nn.ReLU())
        self.net = nn.Sequential(*stack)

        self.latent_dim = latent_dim

    def forward(self, x):
        """Return the reconstruction tensor for latent input ``x``."""
        return self.net(x)


class Discriminator(nn.Module):
    """Small MLP (``latent_dim -> 64 -> 32 -> 32 -> end_dim``) emitting raw logits.

    Args:
        latent_dim: size of the input vectors.
        spectral: when true, every linear layer is wrapped in spectral
            normalisation.
        end_dim: number of output logits.
    """

    def __init__(self, latent_dim, spectral=True, end_dim=2):
        super().__init__()

        widths = [latent_dim, 1<<6, 1<<5, 1<<5, end_dim]
        stack = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            dense = nn.Linear(width_in, width_out)
            stack.append(U.spectral_norm(dense) if spectral else dense)
            stack.append(nn.ReLU())
        # No activation after the last layer: the losses used with this model
        # operate on logits.
        stack.pop()
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the logits for ``x``."""
        return self.net(x)
    
class VAE(nn.Module):
    """(Variational) auto-encoder wrapper around an encoder and a decoder.

    The encoder must return ``(mean, logvar)`` and expose ``latent_dim``; the
    decoder maps a latent tensor to a reconstruction.

    Args:
        encoder: module returning ``(mean, logvar)`` with a ``latent_dim`` attribute.
        decoder: module mapping latents to reconstructions.
        is_vae: when true the latent is sampled with the reparameterisation
            trick; when false the latent is the encoder mean (plain auto-encoder).
        use_latent_norm: when true the latent is passed through a
            ``BatchNorm1d`` layer before decoding.
    """

    def __init__(self, encoder, decoder, is_vae=True, use_latent_norm=True):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.is_vae = is_vae
        # Registered even when ``use_latent_norm`` is false so that the
        # state_dict layout does not depend on the flag.
        self.latent_normalizer = nn.BatchNorm1d(self.encoder.latent_dim)
        self.use_latent_norm = use_latent_norm

    def reparam_trick(self, mean, logvar):
        """Sample ``mean + eps * exp(0.5 * logvar)``, or return ``mean`` when not a VAE."""
        if not self.is_vae:
            # Deterministic mode: no noise is needed, so none is drawn.
            return mean
        sigma = torch.exp(0.5*logvar)
        eps = torch.randn_like(sigma)
        return mean + eps*sigma

    def _latent_from_moments(self, mean, logvar):
        # Single definition of the latent used by both get_latent() and forward():
        # reparameterise first, then (optionally) batch-normalise the result.
        latent = self.reparam_trick(mean, logvar)
        if self.use_latent_norm:
            latent = self.latent_normalizer(latent)
        return latent

    def get_latent(self, x):
        """Return the latent for ``x``, computed exactly as in ``forward``.

        Previously this normalised ``mean`` and ``logvar`` separately before
        reparameterising, which disagreed with ``forward`` and also fed the
        ``logvar`` statistics into the BatchNorm running averages.
        """
        mean, logvar = self.encoder(x)
        return self._latent_from_moments(mean, logvar)

    def forward(self, x, noise_latent_lambda=0.):
        """Encode, build the latent and decode.

        Args:
            x: input batch.
            noise_latent_lambda: if non-zero, Gaussian noise scaled by this
                factor is added to the latent before decoding.

        Returns:
            ``(recon_x, mean, logvar, latent)`` where ``mean`` and ``logvar``
            are the raw encoder outputs and ``latent`` is what was decoded.
        """
        mean, logvar = self.encoder(x)
        latent = self._latent_from_moments(mean, logvar)

        if noise_latent_lambda:
            latent = latent + noise_latent_lambda*torch.randn_like(latent)

        recon_x = self.decoder(latent)
        return recon_x, mean, logvar, latent
       
    
#### LOSS FUNCTIONS
# gives option for VAE type of loss
def old_mse_loss(x, recon_x, weights=None):
    """Mean squared error between ``recon_x`` and ``x``, scaled by ``1e5``.

    ``weights`` is accepted for signature compatibility but is not used.
    """
    mean_squared_error = F.mse_loss(recon_x, x)
    return mean_squared_error * 1e5


def discrim_criter(pred, true):
    """Binary cross-entropy on logits ``pred`` against targets ``true``, scaled by ``1e5``."""
    cross_entropy = F.binary_cross_entropy_with_logits(pred, true)
    return cross_entropy * 1e5


def weighted_mse(a, b, weights=None):
    """Squared-error loss between ``a`` and ``b``, scaled by ``1e5``.

    With ``weights`` the result is the *sum* of the element-wise weighted
    squared errors; without it the result is the plain mean squared error.
    """
    if weights is None:
        unscaled = F.mse_loss(a, b)
    else:
        unscaled = torch.sum(((a-b)**2)*weights)
    return unscaled * 1e5

def old_vae_loss(x, recon_x, mean, logvar, weights=None, this_lambda=0.,):
    """Reconstruction loss plus ``this_lambda`` times the Gaussian KL term.

    Args:
        x: target batch.
        recon_x: reconstruction of ``x``.
        mean, logvar: latent Gaussian parameters used for the KL term.
        weights: optional element-wise weights; when given the reconstruction
            term is computed by ``weighted_mse``.
        this_lambda: multiplier on the KL divergence (``0.`` by default).

    Returns:
        Scalar tensor ``reconstruction + this_lambda * kl``.
    """
    if weights is not None:
        reconstruction = weighted_mse(recon_x, x, weights=weights)
    else:
        reconstruction = F.mse_loss(recon_x, x) * 1e5

    # Summed KL divergence of N(mean, exp(logvar)) from the standard normal.
    kl_terms = 1 + logvar - mean.pow(2) - logvar.exp()
    kl_div = -.5 * torch.sum(kl_terms)

    return reconstruction + this_lambda*kl_div

def discrim_loss(pred, true):
    """Binary cross-entropy on logits ``pred`` against targets ``true``, scaled by ``1e5``."""
    scale = 1e5
    return scale * F.binary_cross_entropy_with_logits(pred, true)

# don't do any requires_grad stuff in here
def adv_vae_loss(
    x, recon_x,
    mean, logvar, discrim_preds,
    alpha, beta, weights=None,
):
    """Combined auto-encoder and adversarial loss for the encoder update.

    The adversarial term scores ``discrim_preds`` against the *source* label
    ``[1., 0.]`` for every row, so minimising it pushes the discriminator
    logits for this batch towards the source class.

    Args:
        x: input batch.
        recon_x: reconstruction of ``x``.
        mean, logvar: latent Gaussian parameters, forwarded to ``old_vae_loss``.
        discrim_preds: discriminator logits with one row per row of ``x`` and
            two columns.
        alpha: weight of the auto-encoder part.
        beta: weight of the adversarial part.
        weights: optional element-wise reconstruction weights.

    Returns:
        ``(alpha * vae_part + beta * adversarial, vae_part, adversarial)``.
    """
    vae_part_loss = old_vae_loss(
        x, recon_x, mean, logvar, weights=weights)

    source_label = [1., 0.]
    # Build the targets next to the predictions instead of relying on a
    # notebook-level ``device`` variable, which is only defined in a later cell.
    discrim_labels = torch.tensor(
        [source_label] * x.shape[0],
        dtype=discrim_preds.dtype,
        device=discrim_preds.device,
    )
    total_discrim_loss = F.binary_cross_entropy_with_logits(
        discrim_preds, discrim_labels,
    ) * 1e5

    discrim_part_loss = beta * total_discrim_loss
    return alpha * vae_part_loss + discrim_part_loss, vae_part_loss, total_discrim_loss

In [None]:
# Reference (omics) auto-encoder: build the data loader and model, then either
# retrain it or restore the weights saved in 'ref_vae.pt'.
#
# include everything or only the non clonotypes
# do everything (how to get the clonotypes)


device = torch.device("cuda")
train_feature = omics_train.X.todense() / .1  # don't have this

# very small number on this one
batch_size = 1<<7
epochs = 30

ref_data_loader = DataLoader(
    TensorDataset(
        torch.from_numpy(train_feature),
    ), batch_size=batch_size,
    shuffle=True, num_workers=1, pin_memory=True,
)

input_cell_dim = 19089


# ODD FINDING - DEEPER MAKES LATENT SPACE LOOK BETTER
# ALSO LARGER HIDDEN DIM BY FACTOR OF 2
ref_vae = VAE(
    StandardEncoder(input_cell_dim, 1<<7, hidden_dim=1<<11),  # hidden was 1<<10
    StandardDecoder(input_cell_dim, 1<<7, hidden_dim=1<<11,),
    is_vae=False,
    use_latent_norm=True,  # was True for all else
).to(device)

ref_vae_opt = optim.Adam( #5-5 is 181
    ref_vae.parameters(), lr=1e-5, #betas=(.5,.999), 5e get .074, after 10, same is avg .057
)

# Summed batch losses, one entry per trained epoch (stays empty when loading).
epoch_losses = []

# Set to 1 to retrain from scratch instead of loading the checkpoint.
# NOTE(review): the retrain branch never writes 'ref_vae.pt'; save the weights
# manually if the retrained model should be reused.
need_retrain = 0
if need_retrain:
    # got to < .115 avg loss after 150 epochs
    # best was ====> Epoch: 1000 Average loss: 0.0979890559

    for epoch in range(1, epochs+1):
        epoch_loss = 0.0
        for _id, [batch,] in enumerate(ref_data_loader):
            batch = batch.to(device)
            ref_vae_opt.zero_grad()

            recon_x, mean, logvar, latent = ref_vae(batch)
            batch_loss = old_vae_loss(
                batch, recon_x, mean, logvar, weights=None,
            )

            batch_loss.backward()
            epoch_loss += batch_loss.item()
            ref_vae_opt.step()
            if not (_id % 500):
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.10f}'.format(
                    epoch,
                    _id * len(batch),
                    len(ref_data_loader.dataset),
                    # fraction of batches done, as a percentage (was scaled by 25.)
                    100. * _id / len(ref_data_loader),
                    batch_loss.item() / len(batch),
                ))

        print('====> Epoch: {} Average loss: {:.10f}'.format(
                  epoch, epoch_loss / len(ref_data_loader.dataset)))

        epoch_losses.append(epoch_loss)

else:
    # map_location lets the checkpoint load regardless of the device it was saved from.
    ref_vae.load_state_dict(torch.load('ref_vae.pt', map_location=device))

In [None]:
# Train a cell-type classifier on the latent codes of the reference auto-encoder.
latent_dim = 1<<7
ref_vae = ref_vae.to('cpu')

orig_cells_dataset = (
    torch
    .from_numpy(train_feature)
    .float()
    .to('cpu')
)

# Pure inference over the whole training set: no autograd graph is needed.
# NOTE(review): ref_vae is not switched to eval() here, so its BatchNorm layers
# use (and update) batch statistics - confirm this is intended.
with torch.no_grad():
    _, _, _, latent = ref_vae(orig_cells_dataset)

latent = latent.numpy()

# NOTE(review): the label vocabulary comes from `omics_data`, while the training
# labels below come from `omics_train`; a class absent from the training set
# would make its class weight a division by zero.
annos = np.unique(omics_data.obs['leiden_cell_type'])
anno_to_label = dict(zip(annos, range(len(annos))))
label_to_anno = dict(zip(range(len(annos)), annos))
final_output_shape = len(label_to_anno)

celltype_classifier = Discriminator(latent_dim, end_dim=final_output_shape).to(device)

celltype_classifier_opt = optim.Adam(
    celltype_classifier.parameters(), lr=1e-3,
)

# NOTE: rebinding the notebook-level batch_size also affects later cells.
batch_size = 1<<5


celltype_train_list = np.array([
    anno_to_label[ct]
    for ct in omics_train.obs['leiden_cell_type']
])

celltype_data_loader = DataLoader(
    TensorDataset(
        torch.from_numpy(latent),
        torch.from_numpy(celltype_train_list),
    ), batch_size=batch_size,
    shuffle=True, num_workers=1, pin_memory=True,
)

# Inverse-frequency class weights: total cells / cells in the class.
num_cells = len(celltype_train_list)
class_weights = torch.tensor([
    (float(num_cells) / np.sum(celltype_train_list == class_label)) for class_label in range(final_output_shape)
]).float()

class_weights = class_weights.to(device)
criter = nn.CrossEntropyLoss(weight=class_weights)


epochs = 32
for epoch in range(epochs):
    epoch_loss = 0.0
    for _id, ([this_batch, this_label]) in enumerate(celltype_data_loader):
        this_batch = this_batch.to(device)
        this_label = this_label.to(device)
        celltype_classifier_opt.zero_grad()
        predicted_labels = celltype_classifier(this_batch.float())
        this_batch_loss = criter(
            predicted_labels,
            this_label,
        )
        this_batch_loss.backward()
        epoch_loss += this_batch_loss.item()
        celltype_classifier_opt.step()

    print('====> Epoch: {} Average loss: {:.10f}'.format(
          epoch+1, epoch_loss / len(celltype_data_loader.dataset)))

# Training-set accuracy of the fitted classifier, evaluated on the CPU.
celltype_classifier = celltype_classifier.to('cpu')
with torch.no_grad():
    orig_pred_labels = celltype_classifier(torch.from_numpy(latent)).numpy()
orig_pred_labels = np.argmax(orig_pred_labels, axis=1)
num_final_correct = np.sum(orig_pred_labels == celltype_train_list)

print(f'final_accuracy:{ num_final_correct / float(num_cells)}')
# get around 96 % accuracy, which is solid i suppose

In [None]:
# Set-up for the adversarial (cross-modal) training of the Raman auto-encoder:
# models, optimisers, loss weights and the two data loaders.
# train the adversarial wiring of the model
# first, do cross modal

latent_dim = 1<<7
sample1_beta = 0  # for now, just have this as 0 (4e-3 was the alternative)

input_cell_dim = 930

ref_vae = ref_vae.to(device)

# Same architecture as the reference model, but the decoder output is not
# passed through a final ReLU.
raman_vae = VAE(
    StandardEncoder(input_cell_dim, 1<<7, hidden_dim=1<<11),  # hidden was 1<<10
    StandardDecoder(input_cell_dim, 1<<7, hidden_dim=1<<11, no_final_relu=True),
    is_vae=False,
    use_latent_norm=True,  # was True for all else
).to(device)
# lr notes: betas=(.5,.999), 5e get .074, after 10, same is avg .057
raman_opt = optim.Adam(raman_vae.parameters(), lr=5e-5)

raman_discrim = Discriminator(latent_dim).to(device)
# lr notes: 5e-4; was 1e-5 before, 5e-6, best 1e-4
raman_discrim_opt = optim.Adam(raman_discrim.parameters(), lr=4e-3)


# now actually do the training

# Weight of the auto-encoder part of the loss (maybe make this 0).
alpha = 1e0

# Weight of the adversarial part. 8e-4 was good before adding the other sample;
# other values tried: 2e-3, 0, 1e-2, 1e-8.
beta = 3e-4

# Weight of the cell-type classification part.
# a good loss for the celltype is .006 ish
# 5 was good from when there was extra day 12 in the training
# 1e2 is the alternative: make this big, see if helps
raman_beta = 5e1


print(f"begin_raman_latent_train")

# variables of interest are:
    # raman_device
    # raman_vae
    # raman_opt
    # raman_discrim
    # raman_discrim_opt
    # raman_discrim_sched

# try this and just repeating the ref_key one many times
# also make this multip gpus training this is taking a while

# train_feature = omics_train.X / .1
train_feature = omics_train.X.todense() / .1
train_feature_raman = raman_train.X / .1  # see if this helps

ref_dataset = TensorDataset(torch.from_numpy(train_feature))
ref_data_loader = DataLoader(
    ref_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
)

raman_celltype_train_list = np.array(
    [anno_to_label[cell_type] for cell_type in raman_train.obs['tg_celltype']]
)

raman_dataset = TensorDataset(
    torch.from_numpy(train_feature_raman),
    torch.from_numpy(raman_celltype_train_list),
)
raman_data_loader = DataLoader(
    raman_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
)

celltype_classifier = celltype_classifier.to(device)
# Adversarial training of the Raman auto-encoder against the reference latent
# space, or restoring the weights saved in 'raman_vae.pt'.
# Set to 1 to retrain instead of loading the checkpoint.
# NOTE(review): the retrain branch never writes 'raman_vae.pt'.
need_retrain_raman = 0
if need_retrain_raman:
    # maybe do just ten epochs
    epochs = 105  # earlier runs used 150, 90, 75, 100 and 30
    # Discriminator targets: reference latents are "source", Raman latents "target".
    source_label, target_label = [1., 0.], [0., 1.]
    for epoch in range(1, epochs + 1):

        discrim_epoch_loss = 0.
        vae_part_epoch_loss = 0.
        raman_vae_epoch_loss = 0.
        celltype_part_epoch_loss = 0.
        print(f"begin epoch {epoch}")
        # zip() stops at the shorter of the two loaders.
        for _id, ([ref_batch,], [raman_batch, raman_celltypes], ) in enumerate(zip(
            ref_data_loader,
            raman_data_loader,
        )):

            raman_opt.zero_grad()
            raman_discrim_opt.zero_grad()
            ref_batch = ref_batch.to(device)
            raman_batch = raman_batch.to(device)
            raman_celltypes = raman_celltypes.to(device)

            #### first part: update the discriminator on fixed latents
            # The encodings are only inputs here, so no autograd graph is built
            # for the two encoders.
            with torch.no_grad():
                ref_encoded = ref_vae.get_latent(ref_batch)
                raman_encoded = raman_vae.get_latent(raman_batch)

            # put all tensors on the right device
            encodeds = torch.cat((ref_encoded, raman_encoded), axis=0)
            discrim_labels = torch.tensor(
                [source_label] * ref_encoded.shape[0]
                + [target_label] * raman_encoded.shape[0]
            ).to(device)

            pred_discrim_labels = raman_discrim(encodeds)
            batch_discrim_loss = discrim_loss(
                pred_discrim_labels, discrim_labels,
            )

            batch_discrim_loss.backward()
            discrim_epoch_loss += batch_discrim_loss.item()
            raman_discrim_opt.step()


            #### second part: update the Raman auto-encoder with the discriminator frozen
            turn_off_model(raman_discrim)

            recon_raman_batch, raman_batch_mean, raman_batch_logvar, raman_batch_latent = raman_vae(raman_batch)

            # this was from before, do not use this
    #                 raman_batch_latent = raman_vae.reparam_trick(raman_batch_mean, raman_batch_logvar)
            raman_vae_discrim_preds = raman_discrim(raman_batch_latent)

            raman_vae_batch_loss, vae_part_batch_loss, _ = adv_vae_loss(
                raman_batch.detach(), recon_raman_batch, # second_batch_latent,
                raman_batch_mean, raman_batch_logvar, raman_vae_discrim_preds,
                alpha, beta,   # was 1e-1, 1e-2 was bad, 1e-1 was best
            )

            #### next part THIS IS NEW
            # Cell-type supervision on the Raman latent through the classifier.
            # Its optimiser is not stepped in this loop.
            raman_celltype_preds = celltype_classifier(raman_batch_latent)
            raman_celltype_loss =  criter(raman_celltype_preds, raman_celltypes)
            raman_vae_batch_loss = raman_vae_batch_loss + raman_beta * raman_celltype_loss

            raman_vae_batch_loss.backward()
            celltype_part_epoch_loss += raman_celltype_loss.item()
            raman_vae_epoch_loss += raman_vae_batch_loss.item()
            vae_part_epoch_loss += vae_part_batch_loss.item()
            raman_opt.step()

            # undo the above
            turn_on_model(raman_discrim)


            if not (_id % 500):
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLosses: {:.6f} {:.6f}'.format(
                epoch, _id * len(raman_batch), len(raman_data_loader.dataset),
                100. * _id / len(raman_data_loader),
                batch_discrim_loss.item() / len(raman_batch),
                raman_vae_batch_loss.item() / len(raman_batch)))

        print('====> Epoch: {} Average adv vae loss: {:.10f}'.format(
              epoch, raman_vae_epoch_loss / len(raman_data_loader.dataset)))

        print('====> Epoch: {} Average vae part loss: {:.10f}'.format(
              epoch, vae_part_epoch_loss / len(raman_data_loader.dataset)))

        print('====> Epoch: {} Average celltype part loss: {:.10f}'.format(
              epoch, celltype_part_epoch_loss / len(raman_data_loader.dataset)))

        print('====> Epoch: {} Average discrim vae loss: {:.10f}'.format(
              epoch, discrim_epoch_loss / len(raman_data_loader.dataset)))

        print(f"end epoch {epoch}")
        # Disabled experiment: double beta when the VAE part dominates the loss.
        if 0 and vae_part_epoch_loss >  .9 * raman_vae_epoch_loss:
            beta *= 2
            print(f"updating at epoch {epoch} to beta {beta}")
    print(f"end_raman_latent_train")

else:
    # map_location lets the checkpoint load regardless of the device it was saved from.
    raman_vae.load_state_dict(torch.load('raman_vae.pt', map_location=device))

In [None]:
# see how well transferred and originals align
# Decode Raman latents with the reference decoder and compare them, in a joint
# UMAP, with the reference model's own reconstructions.
transfer_vae = VAE(
    raman_vae.encoder,
    ref_vae.decoder,
    is_vae=False,
    use_latent_norm=True,  # was True for all else
)
# VAE() creates a fresh, untrained BatchNorm for the latent. Reuse the one that
# was trained together with the Raman encoder so the decoder receives latents
# from the same space the adversarial training produced.
transfer_vae.latent_normalizer = raman_vae.latent_normalizer

transfer_vae = transfer_vae.to('cpu')
ref_vae = ref_vae.to('cpu')


orig_cells_dataset = (
    torch
    .from_numpy(train_feature)
    .float()
    .to('cpu')
)

orig_cells_dataset_raman = (
    torch
    .from_numpy(train_feature_raman)
    .float()
    .to('cpu')
)

# Inference only: skip building autograd graphs over the full data sets.
# NOTE(review): the models are not switched to eval(), so BatchNorm layers use
# batch statistics here - confirm this is intended.
with torch.no_grad():
    recon, _, _, _ = ref_vae(orig_cells_dataset)
    recon_raman, _, _, _ = transfer_vae(orig_cells_dataset_raman)


recon_adata = sc.AnnData(recon.numpy())
recon_adata.obs = omics_train.obs
recon_adata_raman = sc.AnnData(recon_raman.numpy())
recon_adata_raman.obs = raman_train.obs
# concatenate() adds the 'batch' column used for colouring below.
together_recon = recon_adata.concatenate(recon_adata_raman)

sc.pp.pca(together_recon, n_comps=30)
sc.pp.neighbors(together_recon, n_neighbors=30)
sc.tl.umap(together_recon)

sc.pl.umap(together_recon, color='batch')

In [None]:
# Joint UMAP coloured by modality ('batch') and by the available cell-type annotations.
umap_color_keys = ['batch', 'leiden_cell_type', 'tg_celltype', 'together_celltype']
sc.pl.umap(together_recon, color=umap_color_keys, wspace=0.4)