In [100]:
%notebook code_logs/tugda_da.ipynb

In [45]:
import pandas as pd
import numpy as np
import random
from itertools import cycle

In [46]:
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import Function

In [47]:
device = 'cpu' if torch.backends.mps.is_available() else 'cpu'
print(device)

cpu


In [48]:
class ReverseLayerF(Function):

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha

        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha

        return output, None

In [96]:
class Tugda_da(nn.Module):

    def __init__(self,params,train_data,test_data,y_train,y_test):
        super(Tugda_da,self).__init__()
        self.learning_rate = params['lr']
        self.batch_size = params['bs_source']
        self.mu = params['mu']
        self.lambda_ = params['lambda_']
        self.gamma = params['gamma']
        self.lambda_disc = params['lambda_disc']
        self.bs_disc = params['bs_disc']
        self.n_epoch = params['n_epochs']
        self.passes = params['passes']
        self.num_tasks = params['num_tasks']

        self.train_data = train_data
        self.y_train = y_train
        self.test_data = test_data
        self.y_test = y_test

        input_dim=self.train_data.shape[1]

        feature_extractor = [nn.Linear(input_dim, params['hidden_units_1']), 
                             nn.Dropout(p=params['dropout']),
                             nn.ReLU()]
        
        self.feature_extractor = nn.Sequential(*feature_extractor) #L layer in diag

        latent_basis =  [nn.Linear(params['hidden_units_1'], params['latent_space']),
                         nn.Dropout(p=params['dropout']),
                         nn.ReLU()]
        
        self.latent_basis = nn.Sequential(*latent_basis) #Z layer in diag
        
        #task-specific weights
        self.MTL = nn.Linear(params['latent_space'], self.num_tasks) #S Layer in diag| also the prediction outputs
        
        #decoder weights
        decTTF = [nn.Linear( self.num_tasks , params['latent_space']), nn.ReLU()] # A layer in diagram or decoder task to feature layer 
        self.decTTF = nn.Sequential(*decTTF) 

        #uncertainty (aleatoric)
        self.log_vars = torch.zeros(self.num_tasks, requires_grad=True, device=device)
        
        #discriminator
        task_classifier = [nn.Linear(params["latent_space"], params["n_units_disc"]), 
                             nn.Dropout(p=params['dropout']), 
                             nn.ReLU(),
                             nn.Linear(params["n_units_disc"], 1), 
                             nn.Sigmoid()
                            ]
        self.task_classifier = nn.Sequential(*task_classifier)

        self.prepare_data()

    def forward(self, input_data, alpha):
        x = self.feature_extractor(input_data) #usually the input data is treated as x
        x = self.latent_basis(x) # here h 
        preds = self.MTL(x)
        x_hat = self.decTTF(preds) # task transfer to features # here was h_hat
        reverse_feature = ReverseLayerF.apply(x, alpha)
        task_classifier_output = self.task_classifier(reverse_feature)
        return preds, x, x_hat, task_classifier_output
    
    def prepare_data(self):
        train_dataset = TensorDataset(torch.FloatTensor(self.train_data),
                                      torch.FloatTensor(self.y_train))

        
        target_unl_dataset = TensorDataset(torch.FloatTensor(self.test_data))
        
        self.train_dataset = train_dataset
        self.target_unl_dataset = target_unl_dataset
    
    def train_dataloader(self):
        dataloader1 = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True) #in vitro data
        dataloader2 = DataLoader(self.target_unl_dataset, batch_size=self.bs_disc, shuffle=True) #in vivo data 
        self.len_dataloader = min(len(dataloader1), len(dataloader2))
        x=zip(dataloader1, cycle(dataloader2))
        return x
    
    def test_dataloader(self):
        dataloader = DataLoader(self.target_unl_dataset, batch_size=len(self.target_unl_dataset), shuffle=False)
        return dataloader
    
    def MSE_loss(self, x, x_hat):
        mse_loss = torch.nn.MSELoss()
        return mse_loss(x, x_hat)
    
    def task_mse_ignore_nan(self, preds, labels):
        mse_loss = torch.nn.MSELoss(reduction='none')
        per_task_loss = torch.zeros(labels.size(1), device=device)

        for k in range(labels.size(1)):
            precision = torch.exp(-self.log_vars[k]) #didnt understand this.
            diff = mse_loss(preds[~torch.isnan(labels[:,k]), k], labels[~torch.isnan(labels[:,k]), k])
            per_task_loss[k] = torch.mean(precision * diff + self.log_vars[k])

        return torch.mean(per_task_loss[~torch.isnan(per_task_loss)]), per_task_loss

    def mse_ignore_nan_test(self, preds, labels):
        mse_loss = torch.nn.MSELoss(reduction='mean')
        per_task_loss = torch.zeros(labels.size(1), device=device)
        per_sample_loss = torch.zeros(labels.size(0), device=device)

        
        for k in range(labels.size(1)):
            per_task_loss[k] = mse_loss(preds[~torch.isnan(labels[:,k]), k], labels[~torch.isnan(labels[:,k]), k])
        
        #per class loss
        for k in range(labels.size(0)):
            per_sample_loss[k] = mse_loss(preds[k, ~torch.isnan(labels[k,:])], labels[k, ~torch.isnan(labels[k, :])])
        
        
        return torch.mean(per_task_loss[~torch.isnan(per_task_loss)]), per_task_loss,per_sample_loss
    
    def binary_classification_loss(self, preds, labels):
        bin_loss = torch.nn.BCELoss()
        return bin_loss(preds, labels)
    
    def train(self):

        optimizer = torch.optim.Adagrad([
            {'params': self.feature_extractor.parameters()},
            {'params': self.latent_basis.parameters()},
            {'params': self.MTL.parameters()},
            {'params': self.decTTF.parameters()},
            {'params': self.task_classifier.parameters()},
            {'params': self.log_vars}
        ], lr=self.learning_rate)
        
        for epoch in range(self.n_epoch):
            train_dataloader=self.train_dataloader()
            train_loss=0
            for batch_idx,data in enumerate(train_dataloader):
                x,y=data[0] #in vitro
                x_unl=data[1][0] #invivo

                p=float(batch_idx+epoch*self.len_dataloader)/self.n_epoch/self.len_dataloader
                alpha=2. / (1. + np.exp(-10 * p)) - 1

                preds_simulation=torch.zeros(y.size(0),y.size(1),self.passes)#,device=device) 
                for simulation in range(self.passes):
                    preds, h, h_hat, domain_out_source = self.forward(x, alpha)
                    preds_simulation[:,:, simulation]=preds
                
                preds_mean = torch.mean(preds_simulation, axis=2) #2D predictions after p passes
                preds_var = torch.var(preds_simulation, axis=2) #2D
                total_unc = torch.mean(preds_var, axis=0) # mu_t this is epistemic uncertainity 1D of tasks i.e y.shape[1]


                m_loss, task_loss = self.task_mse_ignore_nan(preds_mean, y)
                recon_loss= self.MSE_loss(h,h_hat)

                a = 1 + (total_unc + torch.sum(torch.abs(self.decTTF[0].weight.T),1)) #multiplier to L_bnn

                loss_weight = ( a[~torch.isnan(task_loss)] ) * task_loss[~torch.isnan(task_loss)]
                loss_weight = torch.sum(loss_weight)

                #Regularizer gamma(Z-sig(ZSA))
                l1_S = self.mu * self.MTL.weight.norm(1)
                L = self.latent_basis[0].weight.norm(2) + self.feature_extractor[0].weight.norm(2)
                l2_L = self.lambda_ * L

                #domain discriminator source
                #normal data pass
                # 0 is one class of invitro
                zeros = torch.zeros(y.size(0),1)#, device=self.device)
                d_loss_source = self.binary_classification_loss(domain_out_source, zeros)
                
                #domain discriminator target
                #unlabelled data pass
                #1 is second class which is invivo 
                preds, h, h_hat, domain_out_target = self.forward(x_unl, alpha)
                ones = torch.ones(x_unl.size(0),1)#, device=self.device) 
                d_loss_target = self.binary_classification_loss(domain_out_target, ones)
                
                d_loss = d_loss_source + d_loss_target #l_adv loss or the adversarial loss 

                #total loss
                loss = loss_weight + recon_loss + l1_S + l2_L + (self.lambda_disc *d_loss)

                loss.backward()
                optimizer.step()

                train_loss+=loss

            train_loss/=self.len_dataloader
            print('Epoch [{}/{}] : Loss {}'.format(epoch,self.n_epoch,train_loss))#, 'source_loss': task_loss, 'disc_loss': d_loss})
        return None #loss, m_loss, d_loss
    
    # def test(self):
        

SyntaxError: incomplete input (2652268761.py, line 192)

In [84]:
#cell-line dataset;
gdsc_dataset = pd.read_csv('data/GDSCDA_fpkm_AUC_all_drugs.zip', index_col=0)
#gene set range
gene_list = gdsc_dataset.columns[0:1780]
drug_list = gdsc_dataset.columns[1780:]

In [85]:
#pdx novartis dataset;
pdx_dataset = pd.read_csv('data/PDX_MTL_DA.csv', index_col=0)
drugs_pdx = pdx_dataset.columns[1780:]

In [98]:
net_params = {
 'hidden_units_1': 1500,
 'latent_space': 800,
 'lr': 0.001,
 'dropout': 0.1,
 'mu': 1,
 'lambda_': 1,
 'gamma': 0.01,
 'n_units_disc': 500,
 'n_epochs': 150, #50
 'bs_disc': 64,
 'bs_source': 300,
 'lambda_disc': 0.3,
 'num_tasks': 200,
 'passes': 20}#50

In [88]:
gene_list = gdsc_dataset.columns[0:1780]
drug_list = gdsc_dataset.columns[1780:]

X_train = gdsc_dataset[gene_list].values #In vitro data
y_train = gdsc_dataset[drug_list].values
 
X_test = pdx_dataset[gene_list].values #In-Vivo data
y_test = pdx_dataset[drugs_pdx].values


In [99]:
model = Tugda_da(net_params, X_train, X_test, y_train, y_test)
model.train()

Epoch [0/150] : Loss 5465.60400390625
Epoch [1/150] : Loss 3281.790283203125
Epoch [2/150] : Loss 3029.475830078125
Epoch [3/150] : Loss 2707.848388671875
Epoch [4/150] : Loss 2533.971923828125
Epoch [5/150] : Loss 2363.306884765625
Epoch [6/150] : Loss 2216.306884765625
Epoch [7/150] : Loss 2055.726318359375
Epoch [8/150] : Loss 1956.2987060546875
Epoch [9/150] : Loss 1844.7503662109375
Epoch [10/150] : Loss 1749.3697509765625
Epoch [11/150] : Loss 1642.3861083984375
Epoch [12/150] : Loss 1558.0196533203125
Epoch [13/150] : Loss 1476.8834228515625
Epoch [14/150] : Loss 1404.9962158203125
Epoch [15/150] : Loss 1337.8436279296875
Epoch [16/150] : Loss 1271.7005615234375
Epoch [17/150] : Loss 1208.3077392578125
Epoch [18/150] : Loss 1151.0533447265625
Epoch [19/150] : Loss 1095.324462890625
Epoch [20/150] : Loss 1048.3060302734375
Epoch [21/150] : Loss 1003.436767578125
Epoch [22/150] : Loss 958.6067504882812
Epoch [23/150] : Loss 918.3923950195312
Epoch [24/150] : Loss 880.2361450195312