### TUGDA on Domain Adaptation settings (PDX)
This notebook shows how TUGDA in the domain adaptation setting (figure 1, preprint) can be trained using the GDSC (*in vitro*) and PDX (*in vivo*) datasets. Here we train and evaluate TUGDA with the best set of hyperparameters (previously found on validation data). The results (figure 3) were reproduced using PyTorch (1.5.1, CUDA 10.1) on a GPU.

In [23]:
#install required libs
!pip install pandas
!pip install torch==1.5.1+cu101  -f https://download.pytorch.org/whl/torch_stable.html
#in this version we require pytorch-lightning==0.9.0
!pip install pytorch-lightning==0.9.0

In [14]:
import pandas as pd
import numpy as np
import random
from itertools import cycle

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import Function
import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning import seed_everything

In [15]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [16]:
#flip discriminator gradient;
class ReverseLayerF(Function):
    """Gradient-reversal layer.

    The forward pass returns the input unchanged; the backward pass negates
    the incoming gradient and scales it by ``alpha``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Stash the scaling factor so backward() can read it.
        ctx.alpha = alpha
        identity = x.view_as(x)
        return identity

    @staticmethod
    def backward(ctx, grad_output):
        # The tensor stays the left operand so that a NumPy scalar alpha is
        # handled by Tensor.__mul__ rather than by NumPy.
        reversed_grad = torch.neg(grad_output) * ctx.alpha
        # Second slot is the gradient w.r.t. alpha, which is not differentiated.
        return reversed_grad, None

In [17]:
class MetricsCallback(Callback):
    """Lightning callback that collects the trainer's callback metrics.

    ``metrics`` is a list; ``trainer.callback_metrics`` is appended to it once
    when training ends and once when testing ends.
    """

    def __init__(self):
        super().__init__()
        self.metrics = []

    def _record(self, trainer):
        # Append whatever the trainer currently exposes as callback metrics.
        self.metrics.append(trainer.callback_metrics)

    def on_test_end(self, trainer, pl_module):
        self._record(trainer)

    def on_train_end(self, trainer, pl_module):
        self._record(trainer)


In [18]:
class tugda_da(pl.LightningModule):
    """TUGDA multi-task regressor trained with adversarial domain adaptation.

    A shared feature extractor and latent basis feed
    (i) a task-specific linear layer ``S`` that emits one prediction per task,
    (ii) a decoder ``A`` that maps the predictions back to the latent space, and
    (iii) a domain discriminator placed behind a gradient-reversal layer that
    is trained to tell labelled source samples (label 0) from unlabelled
    target samples (label 1).

    Args:
        params: dict of hyperparameters. Keys read: ``lr``, ``bs_source``,
            ``mu``, ``lambda_``, ``gamma``, ``lambda_disc``, ``bs_disc``,
            ``n_epochs``, ``passes``, ``num_tasks``, ``hidden_units_1``,
            ``latent_space``, ``dropout``, ``n_units_disc``.
        train_data: source features, 2-D array ``(n_source, n_features)``.
        y_train: source labels, 2-D array with one column per task; NaN
            entries are treated as missing and skipped by the task loss.
        x_target_unl: unlabelled target features, 2-D array with the same
            number of columns as ``train_data``.
        test_data, y_test, y_test_da: accepted to keep the constructor
            signature stable; the class does not use them.
    """

    def __init__(self, params, train_data, y_train, x_target_unl,
                 test_data, y_test, y_test_da
                ):
        super(tugda_da, self).__init__()

        self.learning_rate = params['lr']
        self.batch_size = params['bs_source']
        self.mu = params['mu']
        self.lambda_ = params['lambda_']
        self.gamma = params['gamma']
        self.lambda_disc = params['lambda_disc']
        self.bs_disc = params['bs_disc']
        self.n_epoch = params['n_epochs']
        self.passes = params['passes']
        self.num_tasks = params['num_tasks']

        self.train_data = train_data
        self.y_train = y_train
        self.x_target_unl = x_target_unl

        input_dim = self.train_data.shape[1]
        feature_extractor = [nn.Linear(input_dim, params['hidden_units_1']),
                             nn.Dropout(p=params['dropout']),
                             nn.ReLU()]
        self.feature_extractor = nn.Sequential(*feature_extractor)

        latent_basis = [nn.Linear(params['hidden_units_1'], params['latent_space']),
                        nn.Dropout(p=params['dropout']),
                        nn.ReLU()]
        self.latent_basis = nn.Sequential(*latent_basis)

        #task-specific weights
        self.S = nn.Linear(params['latent_space'], self.num_tasks)

        #decoder weights
        A = [nn.Linear(self.num_tasks, params['latent_space']), nn.ReLU()]
        self.A = nn.Sequential(*A)

        #uncertainty (aleatoric): one learnable log-variance per task.
        #Registered as a Parameter so it follows the module across devices and
        #is stored in state_dict/checkpoints, instead of being a free tensor
        #pinned to a module-level device.
        self.log_vars = nn.Parameter(torch.zeros(self.num_tasks))

        #discriminator
        domain_classifier = [nn.Linear(params["latent_space"], params["n_units_disc"]),
                             nn.Dropout(p=params['dropout']),
                             nn.ReLU(),
                             nn.Linear(params["n_units_disc"], 1),
                             nn.Sigmoid()
                            ]

        self.domain_classifier = nn.Sequential(*domain_classifier)

    def forward(self, input_data, alpha):
        """Run the network on a batch.

        Args:
            input_data: float tensor ``(batch, n_features)``.
            alpha: scale applied to the reversed gradient that flows from the
                discriminator back into the latent representation.

        Returns:
            ``(preds, h, h_hat, domain_output)``: task predictions, latent
            representation, its reconstruction from the predictions, and the
            discriminator's sigmoid output of shape ``(batch, 1)``.
        """
        x = self.feature_extractor(input_data)
        h = self.latent_basis(x)
        preds = self.S(h)
        h_hat = self.A(preds)
        reverse_feature = ReverseLayerF.apply(h, alpha)
        domain_output = self.domain_classifier(reverse_feature)
        return preds, h, h_hat, domain_output

    def prepare_data(self):
        """Wrap the source (features, labels) and target features as TensorDatasets."""
        train_dataset = TensorDataset(torch.FloatTensor(self.train_data),
                                      torch.FloatTensor(self.y_train))

        target_unl_dataset = TensorDataset(torch.FloatTensor(self.x_target_unl))

        self.train_dataset = train_dataset
        self.target_unl_dataset = target_unl_dataset

    def train_dataloader(self):
        """Pair every source batch with a target batch.

        The target loader is cycled, so ``zip`` stops when the source loader
        is exhausted. The returned object is a one-shot iterator and has to be
        rebuilt for every epoch.
        """
        dataloader1 = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
        dataloader2 = DataLoader(self.target_unl_dataset, batch_size=self.bs_disc, shuffle=True)
        # zip(dataloader1, cycle(dataloader2)) always yields len(dataloader1)
        # batches; using min(...) here let the warm-up progress in
        # forward_pass run past 1 whenever the target loader was the shorter one.
        self.len_dataloader = len(dataloader1)
        return zip(dataloader1, cycle(dataloader2))

    def test_dataloader(self):
        """Serve the whole unlabelled target set as a single, unshuffled batch."""
        dataloader = DataLoader(self.target_unl_dataset, batch_size=len(self.target_unl_dataset), shuffle=False)
        return dataloader

    def configure_optimizers(self):
        """One Adagrad optimizer over every sub-network and the task log-variances."""
        opt_cla = torch.optim.Adagrad([
            {'params': self.feature_extractor.parameters()},
            {'params': self.latent_basis.parameters()},
            {'params': self.S.parameters()},
            {'params': self.A.parameters()},
            {'params': self.domain_classifier.parameters()},
            {'params': [self.log_vars]}
        ], lr=self.learning_rate)

        return opt_cla

    def binary_classification_loss(self, preds, labels):
        """Binary cross-entropy between discriminator outputs and domain labels.

        The discriminator emits ``(batch, 1)`` while the labels are ``(batch,)``;
        predictions are reshaped to the label shape because BCELoss expects
        matching shapes (a mismatch is deprecated in torch 1.5 and rejected by
        later releases).
        """
        bin_loss = torch.nn.BCELoss()
        return bin_loss(preds.reshape(labels.shape), labels)

    def mse_ignore_nan(self, preds, labels):
        """Uncertainty-weighted squared error per task, skipping NaN labels.

        For task ``k`` the loss is ``mean(exp(-log_vars[k]) * sq_err + log_vars[k])``
        over the samples whose label is not NaN. A task with no labelled sample
        in the batch yields NaN and is left out of the returned mean.

        Returns:
            ``(mean over non-NaN tasks, per-task loss vector)``.
        """
        mse_loss = torch.nn.MSELoss(reduction='none')
        per_task_loss = torch.zeros(labels.size(1), device=labels.device)

        for k in range(labels.size(1)):
            observed = ~torch.isnan(labels[:, k])
            precision = torch.exp(-self.log_vars[k])
            diff = mse_loss(preds[observed, k], labels[observed, k])
            per_task_loss[k] = torch.mean(precision * diff + self.log_vars[k])

        return torch.mean(per_task_loss[~torch.isnan(per_task_loss)]), per_task_loss

    def mse_ignore_nan_test(self, preds, labels):
        """Plain MSE per task and per sample, skipping NaN labels.

        Returns:
            ``(mean over non-NaN tasks, per-task MSE, per-sample MSE)``.
        """
        mse_loss = torch.nn.MSELoss(reduction='mean')
        per_task_loss = torch.zeros(labels.size(1), device=labels.device)
        per_sample_loss = torch.zeros(labels.size(0), device=labels.device)

        for k in range(labels.size(1)):
            observed = ~torch.isnan(labels[:, k])
            per_task_loss[k] = mse_loss(preds[observed, k], labels[observed, k])

        #per sample loss
        for k in range(labels.size(0)):
            observed = ~torch.isnan(labels[k, :])
            per_sample_loss[k] = mse_loss(preds[k, observed], labels[k, observed])

        return torch.mean(per_task_loss[~torch.isnan(per_task_loss)]), per_task_loss, per_sample_loss

    def MSE_loss(self, x, x_hat):
        """Mean squared error between two tensors of the same shape."""
        mse_loss = torch.nn.MSELoss()
        return mse_loss(x, x_hat)

    def forward_pass(self, fw_batch, batch_idx):
        """Compute the training objective for one (source, target) batch pair.

        Args:
            fw_batch: ``((x, y), (unl,))`` as produced by ``train_dataloader``.
            batch_idx: index of the batch inside the current epoch.

        Returns:
            ``(loss, m_loss, d_loss)``: total objective, mean task loss on the
            source batch, and summed discriminator loss (source + target).
        """
        x, y = fw_batch[0]
        #unlabelled data (target)
        unl = fw_batch[1][0]

        #warm-up: alpha grows from 0 towards 1 with training progress p
        p = float(batch_idx + self.current_epoch * self.len_dataloader) / self.n_epoch / self.len_dataloader
        alpha = 2. / (1. + np.exp(-10 * p)) - 1

        # Repeated stochastic (dropout) passes; their spread is the per-task
        # uncertainty used to weight the task losses below.
        preds_simulation = torch.zeros(y.size(0), y.size(1), self.passes, device=y.device)
        for simulation in range(self.passes):
            preds, h, h_hat, domain_out_source = self.forward(x, alpha)
            preds_simulation[:, :, simulation] = preds

        preds_mean = torch.mean(preds_simulation, dim=2)
        preds_var = torch.var(preds_simulation, dim=2)
        total_unc = torch.mean(preds_var, dim=0)

        m_loss, task_loss = self.mse_ignore_nan(preds_mean, y)
        # h / h_hat come from the last source pass of the loop above.
        recon_loss = self.gamma * self.MSE_loss(h, h_hat)
        # Per-task weight: 1 + uncertainty + L1 norm of the task's decoder weights.
        a = 1 + (total_unc + torch.sum(torch.abs(self.A[0].weight.T), 1))
        valid_tasks = ~torch.isnan(task_loss)
        loss_weight = torch.sum(a[valid_tasks] * task_loss[valid_tasks])
        l1_S = self.mu * self.S.weight.norm(1)
        L = self.latent_basis[0].weight.norm(2) + self.feature_extractor[0].weight.norm(2)
        l2_L = self.lambda_ * L

        #domain discriminator source (label 0)
        zeros = torch.zeros(y.size(0), device=domain_out_source.device)
        d_loss_source = self.binary_classification_loss(domain_out_source, zeros)

        #domain discriminator target (label 1)
        _, _, _, domain_out_target = self.forward(unl, alpha)
        ones = torch.ones(unl.size(0), device=domain_out_target.device)
        d_loss_target = self.binary_classification_loss(domain_out_target, ones)

        d_loss = d_loss_source + d_loss_target

        #total loss
        loss = loss_weight + recon_loss + l1_S + l2_L + (self.lambda_disc * d_loss)
        return loss, m_loss, d_loss

    def training_step(self, train_batch, batch_idx):
        """Return the total loss plus the three logged loss terms."""
        loss, task_loss, disc_loss = self.forward_pass(train_batch, batch_idx)
        logs = {'total_loss': loss, 'source_loss': task_loss, 'disc_loss': disc_loss}
        return {'loss': loss, 'log': logs}

    def test_step(self, test_batch, batch_idx):
        """Predict on a target batch by averaging ``passes`` dropout-enabled passes.

        Side effects: every pass reseeds the global ``random``, NumPy and torch
        generators with the pass index and sets the cuDNN deterministic flags.

        Returns:
            ``{'preds': ndarray (batch, num_tasks)}`` with the mean prediction.
        """
        x_unl = test_batch[0]

        #dropout on (feature extractor and latent basis only)
        self.feature_extractor[1].train()
        self.latent_basis[1].train()

        try:
            #TARGET STEPS
            #get model preds
            preds_simulation = torch.zeros(x_unl.size(0), self.num_tasks, self.passes, device=x_unl.device)

            for simulation in range(self.passes):

                seed = simulation
                random.seed(seed)
                np.random.seed(seed)
                torch.manual_seed(seed)
                torch.backends.cudnn.deterministic = True
                torch.backends.cudnn.benchmark = False

                preds, h_target_ls, _, _ = self.forward(x_unl, 0)
                preds_simulation[:, :, simulation] = preds

            preds_mean = torch.mean(preds_simulation, dim=2)
        finally:
            #disable dropouts again, even if a pass raised
            self.feature_extractor[1].eval()
            self.latent_basis[1].eval()

        return {'preds': preds_mean.detach().cpu().numpy()
                }
    

In [19]:
#cell-line dataset;
gdsc_dataset = pd.read_csv('data/GDSCDA_fpkm_AUC_all_drugs.zip', index_col=0)
#column layout: the first 1780 columns are taken as the gene set,
#every column after them as the drug list
gdsc_columns = gdsc_dataset.columns
gene_list, drug_list = gdsc_columns[:1780], gdsc_columns[1780:]

In [20]:
#pdx novartis dataset;
pdx_dataset = pd.read_csv('data/PDX_MTL_DA.csv', index_col=0)
#columns from position 1780 onwards are taken as the PDX drug columns
drugs_pdx = pdx_dataset.columns[slice(1780, None)]

In [21]:
#best set of hyperparameters found on validation settings;
net_params = dict(
    hidden_units_1=1500,  # output width of the feature-extractor linear layer
    latent_space=800,     # width of the latent basis
    lr=0.001,             # Adagrad learning rate
    dropout=0.1,          # dropout probability used in every dropout layer
    mu=1,                 # weight of the L1 penalty on the task-specific layer S
    lambda_=1,            # weight of the L2 norms of the shared layers
    gamma=0.01,           # weight of the latent reconstruction loss
    n_units_disc=500,     # hidden units of the domain discriminator
    n_epochs=50,          # training epochs (also drives the alpha warm-up)
    bs_disc=64,           # batch size of the unlabelled target loader
    bs_source=300,        # batch size of the labelled source loader
    lambda_disc=0.3,      # weight of the discriminator loss
    num_tasks=200,        # number of outputs of the task layer
    passes=50,            # stochastic forward passes per batch
)

In [22]:
#training and testing
# Train TUGDA on the labelled source table (gdsc_dataset) with the PDX samples
# as unlabelled target data, then predict on those same PDX samples.
print(net_params)

metrics_callback = MetricsCallback()

# Same column split as in the data-loading cell (first 1780 columns = genes).
gene_list = gdsc_dataset.columns[0:1780]
drug_list = gdsc_dataset.columns[1780:]

# Source domain: features and labels.
X_train = gdsc_dataset[gene_list].values
y_train = gdsc_dataset[drug_list].values

# Target domain, restricted to the source gene columns.
# NOTE(review): X_test is not used anywhere else in this cell.
X_test = pdx_dataset[gene_list].values
y_test = pdx_dataset[drugs_pdx].values

# Unlabelled target features fed to the discriminator during training and
# served again by tugda_da.test_dataloader().
X_train_unl = pdx_dataset[gene_list].values

# reload_dataloaders_every_epoch is needed because tugda_da.train_dataloader()
# returns a one-shot zip iterator.
trainer = pl.Trainer(
    max_epochs=net_params['n_epochs'],
    gpus=1 if torch.cuda.is_available() else None,
    callbacks=[metrics_callback],
    deterministic=True,
    reload_dataloaders_every_epoch=True
)

# Seed every RNG before the model is built so weight initialisation is repeatable.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

seed_everything(seed)
# NOTE(review): the last three positional arguments (test_data, y_test,
# y_test_da) receive X_train, y_train, y_test; tugda_da.__init__ does not
# store any of them, so they have no effect on training or testing.
model = tugda_da(net_params, X_train, y_train, X_train_unl,
              X_train, y_train, y_test)

trainer.fit(model)

# Runs test_step on the unlabelled target set.
results = trainer.test(model)

# Per-sample predictions averaged over the dropout passes of test_step.
preds = results[0]['preds']

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type       | Params
-------------------------------------------------
0 | feature_extractor | Sequential | 2 M   
1 | latent_basis      | Sequential | 1 M   
2 | S                 | Linear     | 160 K 
3 | A                 | Sequential | 160 K 
4 | domain_classifier | Sequential | 401 K 


{'hidden_units_1': 1500, 'latent_space': 800, 'lr': 0.001, 'dropout': 0.1, 'mu': 1, 'lambda_': 1, 'gamma': 0.01, 'n_units_disc': 500, 'n_epochs': 50, 'bs_disc': 64, 'bs_source': 300, 'lambda_disc': 0.3, 'num_tasks': 200, 'passes': 50}


Training: 0it [00:00, ?it/s]

Saving latest checkpoint..


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'preds': array([[0.8763215 , 0.602008  , 0.8809368 , ..., 0.6359824 , 0.737716  ,
        0.7230629 ],
       [1.0196947 , 0.7056573 , 0.951037  , ..., 0.7362957 , 0.9649322 ,
        0.89155716],
       [0.83114064, 0.5705545 , 0.9167524 , ..., 0.7746541 , 0.83101714,
        0.90812194],
       ...,
       [0.8948665 , 0.76307493, 0.9733892 , ..., 0.83099943, 0.8795322 ,
        0.93821555],
       [0.6325121 , 0.66811043, 0.78564125, ..., 0.6174042 , 0.8663467 ,
        0.7862173 ],
       [0.85000914, 0.5742485 , 0.90454763, ..., 0.4991126 , 0.85457283,
        0.89456856]], dtype=float32)}
--------------------------------------------------------------------------------
