### TUGDA in the Multi-task Learning setting
In this notebook we show how TUGDA in the MTL setting (Figure 1 of the preprint) can be trained on the GDSC dataset (*in vitro* setting). Here we train and evaluate TUGDA with the best set of hyperparameters (previously found on validation data) using 3-fold cross-validation (Figure 2, TUGDA model).

In [18]:
#install required libs
!pip install pandas
!pip install torch==1.5.1+cu101  -f https://download.pytorch.org/whl/torch_stable.html
#in this version we require pytorch-lightning==0.9.0
!pip install pytorch-lightning==0.9.0

In [2]:
import pandas as pd
import numpy as np
import random
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
#call pytorch lightning functions
import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning import Trainer, seed_everything

In [3]:
# Drug (task) names to train and predict: taken from the column headers of
# the fold-1 test response file.
folder = 'data/'
drug_list = pd.read_csv('{}/cl_y_test_o_k1.csv'.format(folder), index_col=0).columns

In [4]:
# Load the 3-fold CV splits. For every fold, the features (x) and responses (y)
# of each split land in train_data_report / test_data_report under
# '<x|y>_k_fold<k>'.
train_data_report = {}
test_data_report = {}

for fold in (1, 2, 3):
    for split_name, report in (('train', train_data_report), ('test', test_data_report)):
        for var in ('x', 'y'):
            csv_path = '{}/cl_{}_{}_o_k{}.csv'.format(folder, var, split_name, fold)
            report['{}_k_fold{}'.format(var, fold)] = pd.read_csv(csv_path, index_col=0)

In [5]:
# Pick the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [6]:
class MetricsCallback(Callback):
    """PyTorch Lightning callback that records the trainer's metrics after each test run.

    ``self.metrics`` gets one entry per ``on_test_end`` call: a shallow ``dict``
    snapshot of ``trainer.callback_metrics`` at that moment.
    """

    def __init__(self):
        super().__init__()
        self.metrics = []

    def on_test_end(self, trainer, pl_module):
        # Store a copy, not the trainer's live dict, so later metric updates
        # (e.g. another test run on the same trainer) cannot rewrite entries
        # that were already recorded.
        self.metrics.append(dict(trainer.callback_metrics))

In [10]:
class tugda_mtl(pl.LightningModule):
    """TUGDA model for the multi-task learning (MTL) setting.

    Network: ``feature_extractor`` (Linear -> Dropout -> ReLU) feeds a shared
    ``latent_basis`` (Linear -> Dropout -> ReLU). The task-specific layer ``S``
    maps the latent code ``h`` to one prediction per task. The decoder ``A``
    maps those predictions back to the latent space (``h_hat``).

    Args:
        params: hyperparameter dict. Keys read here: ``lr``, ``bs``, ``mu``,
            ``lambda_``, ``gamma``, ``num_tasks``, ``passes``,
            ``hidden_units_1``, ``latent_space``, ``dropout``.
        train_data: training inputs; ``train_data.shape[1]`` is used as the
            input dimension.
        y_train: training targets, one column per task. NaN marks a missing
            label; the losses ignore it.
        test_data, y_test: test inputs/targets, with the same conventions.
    """

    def __init__(self, params, train_data, y_train,
                 test_data, y_test
                ):
        super(tugda_mtl, self).__init__()

        self.learning_rate = params['lr']
        self.batch_size = params['bs']
        self.mu = params['mu']            # weight of the L1 penalty on S
        self.lambda_ = params['lambda_']  # weight of the norm penalty on the shared layers
        self.gamma = params['gamma']      # weight of the h -> h_hat reconstruction loss
        self.num_tasks = params['num_tasks']
        self.passes = params['passes']    # stochastic (dropout) forward passes per batch

        self.train_data = train_data
        self.y_train = y_train
        self.test_data = test_data
        self.y_test = y_test

        input_dim = self.train_data.shape[1]

        feature_extractor = [nn.Linear(input_dim, params['hidden_units_1']),
                             nn.Dropout(p=params['dropout']),
                             nn.ReLU()]
        self.feature_extractor = nn.Sequential(*feature_extractor)

        latent_basis = [nn.Linear(params['hidden_units_1'], params['latent_space']),
                        nn.Dropout(p=params['dropout']),
                        nn.ReLU()]
        self.latent_basis = nn.Sequential(*latent_basis)

        # task-specific weights
        self.S = nn.Linear(params['latent_space'], self.num_tasks)

        # decoder weights
        A = [nn.Linear(self.num_tasks, params['latent_space']), nn.ReLU()]
        self.A = nn.Sequential(*A)

        # Uncertainty (aleatoric): one learnable log-variance per task.
        # Registered as an nn.Parameter so that it is saved in state_dict
        # (checkpoints) and moved together with the module. A free tensor
        # pinned to the global `device` is neither saved nor moved.
        self.log_vars = nn.Parameter(torch.zeros(self.num_tasks))

    def forward(self, input_data):
        """Return ``(preds, h, h_hat)``: task predictions, latent code, and h reconstructed from preds."""
        x = self.feature_extractor(input_data)
        h = self.latent_basis(x)
        preds = self.S(h)
        h_hat = self.A(preds)
        return preds, h, h_hat

    def prepare_data(self):
        """Wrap the stored train/test arrays as float TensorDatasets."""
        # NOTE(review): newer Lightning versions expect state like this to be
        # assigned in `setup()`; kept here for the pinned 0.9.0 API.
        self.train_dataset = TensorDataset(torch.FloatTensor(self.train_data),
                                           torch.FloatTensor(self.y_train))
        self.test_dataset = TensorDataset(torch.FloatTensor(self.test_data),
                                          torch.FloatTensor(self.y_test))

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=8)

    def test_dataloader(self):
        # the whole test set is evaluated as a single batch
        return DataLoader(self.test_dataset, batch_size=len(self.test_dataset), shuffle=False, num_workers=8)

    def configure_optimizers(self):
        # log_vars is a registered parameter, so self.parameters() already
        # includes it; adding it again would duplicate it in the param group.
        return torch.optim.Adagrad(self.parameters(), lr=self.learning_rate)

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx,
                       second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
        """Plain optimizer step followed by zero_grad (overrides Lightning's default hook)."""
        optimizer.step()
        optimizer.zero_grad()

    def mse_ignore_nan(self, preds, labels):
        """Uncertainty-weighted per-task MSE that skips NaN labels.

        For task k the loss is ``mean(exp(-log_vars[k]) * sq_err + log_vars[k])``
        over the samples whose label is not NaN. A task with no observed label
        in the batch averages an empty tensor, which yields NaN. Such tasks are
        excluded from the returned mean.

        Returns:
            (mean over non-NaN task losses, per-task loss vector of length labels.size(1))
        """
        mse_loss = torch.nn.MSELoss(reduction='none')
        per_task_loss = torch.zeros(labels.size(1), device=preds.device)

        for k in range(labels.size(1)):
            observed = ~torch.isnan(labels[:, k])
            precision = torch.exp(-self.log_vars[k])
            diff = mse_loss(preds[observed, k], labels[observed, k])
            per_task_loss[k] = torch.mean(precision * diff + self.log_vars[k])

        return torch.mean(per_task_loss[~torch.isnan(per_task_loss)]), per_task_loss

    def mse_ignore_nan_test(self, preds, labels):
        """Plain per-task MSE over non-NaN labels (no uncertainty weighting).

        Returns:
            (mean over non-NaN task losses, per-task MSE vector)
        """
        mse_loss = torch.nn.MSELoss(reduction='mean')
        per_task_loss = torch.zeros(labels.size(1), device=preds.device)

        for k in range(labels.size(1)):
            observed = ~torch.isnan(labels[:, k])
            per_task_loss[k] = mse_loss(preds[observed, k], labels[observed, k])

        return torch.mean(per_task_loss[~torch.isnan(per_task_loss)]), per_task_loss

    #autoencoder loss
    def MSE_loss(self, x, x_hat):
        mse_loss = torch.nn.MSELoss()
        return mse_loss(x, x_hat)

    def forward_pass(self, fw_batch, batch_idx):
        """Compute the TUGDA training loss for one batch.

        Runs ``self.passes`` stochastic forward passes. Predictions are averaged
        over the passes, and their variance (averaged over the batch) gives a
        per-task uncertainty. That uncertainty, plus the L1 mass of each task's
        decoder column, weights the per-task losses.

        Returns:
            (total_loss, per-task loss vector)
        """
        x, y = fw_batch

        preds_simulation = torch.zeros(y.size(0), y.size(1), self.passes, device=x.device)
        for simulation in range(self.passes):
            # h / h_hat of the last pass feed the reconstruction loss below
            preds, h, h_hat = self.forward(x)
            preds_simulation[:, :, simulation] = preds

        preds_mean = torch.mean(preds_simulation, dim=2)
        if self.passes > 1:
            preds_var = torch.var(preds_simulation, dim=2)
        else:
            # unbiased variance of a single sample is NaN and would poison the loss
            preds_var = torch.zeros_like(preds_mean)
        total_unc = torch.mean(preds_var, dim=0)

        #prediction loss
        local_loss, task_loss = self.mse_ignore_nan(preds_mean, y)
        #autoencoder loss
        recon_loss = self.gamma * self.MSE_loss(h, h_hat)

        # per-task weight: 1 + uncertainty + L1 norm of that task's decoder column
        a = 1 + (total_unc + torch.sum(torch.abs(self.A[0].weight.T), 1))
        valid = ~torch.isnan(task_loss)
        loss_weight = torch.sum(a[valid] * task_loss[valid])
        l1_S = self.mu * self.S.weight.norm(1)
        L = self.latent_basis[0].weight.norm(2) + self.feature_extractor[0].weight.norm(2)
        l2_L = self.lambda_ * L

        #total loss
        total_loss = loss_weight + recon_loss + l1_S + l2_L
        return total_loss, task_loss

    def training_step(self, train_batch, batch_idx):
        loss, task_loss = self.forward_pass(train_batch, batch_idx)

        logs = {'train_loss': loss, 'task_loss': task_loss}
        return {'loss': loss, 'log': logs}

    def test_step(self, test_batch, batch_idx):
        """MC-dropout evaluation: average ``self.passes`` dropout-enabled predictions.

        Each pass re-seeds python/numpy/torch RNGs with the pass index, so the
        dropout masks (and hence the predictions) are reproducible. As a side
        effect this resets the global RNG state.
        """
        x, y = test_batch

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Enable only the dropout layers (MC dropout). They are switched back
        # to eval mode in `finally`, so an error cannot leave them active.
        self.feature_extractor[1].train()
        self.latent_basis[1].train()
        try:
            preds_simulation = torch.zeros(y.size(0), y.size(1), self.passes, device=x.device)
            for simulation in range(self.passes):
                #to reproduce predictions
                seed = simulation
                random.seed(seed)
                np.random.seed(seed)
                torch.manual_seed(seed)

                preds, _, _ = self.forward(x)
                preds_simulation[:, :, simulation] = preds
        finally:
            #disable dropouts
            self.feature_extractor[1].eval()
            self.latent_basis[1].eval()

        preds_mean = torch.mean(preds_simulation, dim=2)
        loss, task_losses_per_class = self.mse_ignore_nan_test(preds_mean, y)

        return {'test_loss': loss,
                'test_task_losses_per_class': task_losses_per_class.detach().cpu().numpy(),
                'test_preds': preds_mean.detach().cpu().numpy(),
                }

In [11]:
# Best set of hyperparameters found for this dataset setting (GDSC).
net_params = dict(
    # tuned hyperparameters
    hidden_units_1=1024,
    latent_space=700,
    lr=0.001,
    dropout=0.1,
    mu=0.01,
    lambda_=0.001,
    gamma=0.0001,
    bs=300,
    passes=50,
    num_tasks=200,
    epochs=100,
)

In [12]:
# Training and testing: one fresh model per CV fold, each trained from the same seed.
error_list = []  # one (drug, test MSE) array per fold, appended at the end of each iteration
pcorr_list = []

metrics_callback = MetricsCallback()

for k in range(1,4):

    X_train = train_data_report['x_k_fold{}'.format(k)].values
    X_test = test_data_report['x_k_fold{}'.format(k)].values

    y_train_df = train_data_report['y_k_fold{}'.format(k)]
    y_test_df = test_data_report['y_k_fold{}'.format(k)]
    # The per-task errors below are labelled with drug_list. Every fold must
    # therefore have the same drug columns, in the same order, and exactly one
    # column per model output.
    assert list(y_train_df.columns) == list(drug_list), 'fold {}: train drug columns differ from drug_list'.format(k)
    assert list(y_test_df.columns) == list(drug_list), 'fold {}: test drug columns differ from drug_list'.format(k)
    assert len(drug_list) == net_params['num_tasks'], 'num_tasks does not match the number of drugs'

    y_train = y_train_df.values
    y_test = y_test_df.values

    trainer = pl.Trainer(
        max_epochs=net_params['epochs'],
        gpus=1 if torch.cuda.is_available() else None,
        callbacks=[metrics_callback],
        deterministic=True,
        reload_dataloaders_every_epoch=True
    )

    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    seed_everything(seed)
    model = tugda_mtl(net_params, X_train, y_train,
                      X_test, y_test)

    trainer.fit(model)

    # use model after training or load weights
    results = trainer.test(model)

    # Test error per drug for this fold. Keep every fold, not just the last one.
    error_mtl_nn_results = np.concatenate((np.array(drug_list, ndmin=2).T,
                                           np.array(results[0]['test_task_losses_per_class'], ndmin=2).T), axis=1)
    error_list.append(error_mtl_nn_results)



GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type       | Params
-------------------------------------------------
0 | feature_extractor | Sequential | 1 M   
1 | latent_basis      | Sequential | 717 K 
2 | S                 | Linear     | 140 K 
3 | A                 | Sequential | 140 K 


Training: 0it [00:00, ?it/s]

Saving latest checkpoint..


Testing: 0it [00:00, ?it/s]

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type       | Params
-------------------------------------------------
0 | feature_extractor | Sequential | 1 M   
1 | latent_basis      | Sequential | 717 K 
2 | S                 | Linear     | 140 K 
3 | A                 | Sequential | 140 K 


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(1.8383, device='cuda:0'),
 'test_preds': array([[ 5.5901303 ,  3.5795941 ,  3.8705695 , ...,  1.2044197 ,
         3.9309645 ,  8.334039  ],
       [ 1.2868376 ,  1.8721877 ,  2.4557369 , ...,  1.2568722 ,
        -0.61222005,  6.030042  ],
       [ 4.191142  ,  2.702798  ,  4.1943035 , ...,  0.48288536,
         2.7910597 ,  7.878966  ],
       ...,
       [ 4.408374  ,  1.7468319 ,  3.9545064 , ...,  1.9170301 ,
         2.4439368 ,  6.893946  ],
       [ 5.5004554 ,  3.5991457 ,  4.3986564 , ...,  1.3953224 ,
         4.053058  ,  9.278416  ],
       [ 4.348958  ,  2.8188868 ,  3.2787073 , ...,  3.7183821 ,
         3.2583432 ,  7.179334  ]], dtype=float32),
 'test_task_losses_per_class': array([1.4914658 , 1.3344768 , 0.88051796, 2.2467782 , 1.5596995 ,
       3.5224605 , 2.424226  , 2.3212197 , 1.6756608 , 2.6915495 ,
       1.6154008 , 1.9730062 , 1.7105

Training: 0it [00:00, ?it/s]

Saving latest checkpoint..


Testing: 0it [00:00, ?it/s]

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type       | Params
-------------------------------------------------
0 | feature_extractor | Sequential | 1 M   
1 | latent_basis      | Sequential | 717 K 
2 | S                 | Linear     | 140 K 
3 | A                 | Sequential | 140 K 


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(1.7584, device='cuda:0'),
 'test_preds': array([[3.3611708 , 2.166865  , 2.6094744 , ..., 3.636953  , 3.1647692 ,
        8.365946  ],
       [4.6335254 , 2.4563174 , 3.911688  , ..., 1.1904571 , 2.5462885 ,
        6.9986057 ],
       [3.7746596 , 2.6394184 , 3.618499  , ..., 0.30261213, 1.8948536 ,
        6.629425  ],
       ...,
       [5.496294  , 1.7671915 , 5.0049424 , ..., 1.4731125 , 2.834544  ,
        9.644545  ],
       [1.7615435 , 1.7464737 , 2.256626  , ..., 2.788451  , 0.7374495 ,
        5.73632   ],
       [4.406636  , 1.975105  , 3.4456518 , ..., 2.7866516 , 3.4155595 ,
        7.1431456 ]], dtype=float32),
 'test_task_losses_per_class': array([1.411931  , 1.1247501 , 0.65603256, 1.9193974 , 1.4431998 ,
       3.5096447 , 2.6746562 , 2.1912496 , 1.9675112 , 2.7095287 ,
       1.4945976 , 2.19008   , 1.803638  , 1.9738765 , 0.74348086,
      

Training: 0it [00:00, ?it/s]

Saving latest checkpoint..


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(1.7833, device='cuda:0'),
 'test_preds': array([[3.8143811 , 1.8225929 , 3.7388253 , ..., 2.9014597 , 2.376441  ,
        7.878894  ],
       [3.5160081 , 2.1327918 , 3.0487273 , ..., 0.37313232, 1.8789779 ,
        6.543551  ],
       [4.3462915 , 2.7900941 , 3.8394225 , ..., 0.5774908 , 2.6929893 ,
        6.9358087 ],
       ...,
       [4.8131986 , 1.4028168 , 3.9412835 , ..., 1.5468093 , 3.3820288 ,
        7.945622  ],
       [1.7894878 , 2.346606  , 2.2630565 , ..., 2.6876516 , 0.93378997,
        5.619823  ],
       [2.663619  , 1.5849105 , 2.5142999 , ..., 2.1392007 , 0.89365065,
        6.041117  ]], dtype=float32),
 'test_task_losses_per_class': array([2.0588343 , 1.33407   , 0.83493567, 2.2446628 , 1.5029362 ,
       3.7772398 , 2.755147  , 2.6765773 , 1.8384305 , 2.6004894 ,
       1.8577863 , 2.2867053 , 1.8857011 , 1.905037  , 0.9244382 ,
      