## Train the Transformer Module

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Image
from IPython.core.debugger import set_trace
from torch import nn as nn
import torch
import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader
from sklearn import preprocessing
from torch.nn import TransformerEncoderLayer, TransformerEncoder
from torch.nn import functional as F
from Data.Drosophilla.FlyDataMod import FlyDataModule
import yaml
import os
from pytorch_lightning.callbacks import EarlyStopping

In [None]:
# Data module for the first training run: S2 cell line, window radius 5,
# two windows per batch, "gamma" labels with label_val=10.
_initial_dm_options = {
    "cell_line": "S2",
    "data_win_radius": 5,
    "batch_size": 2,
    "label_type": "gamma",
    "label_val": 10,
}
dm = FlyDataModule(**_initial_dm_options)
dm.setup()

In [None]:
import math
from IPython.core.debugger import set_trace

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    Even embedding columns hold sines and odd columns hold cosines of the
    position scaled by geometrically spaced frequencies.

    Args:
        d_model: size of the embedding dimension; may be odd or even.
        dropout: dropout probability applied after the encodings are added.
        max_len: number of positions for which encodings are precomputed.
            Sequences passed to ``forward`` must not be longer than this.
    """

    def __init__(self, d_model, dropout=0.1, max_len=11):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # `angles` has ceil(d_model / 2) columns but there are only
        # d_model // 2 odd columns to fill: drop the surplus column when
        # d_model is odd and keep all of them when it is even.
        pe[:, 1::2] = torch.cos(angles)[:, :d_model // 2]
        # Stored as [max_len, 1, d_model] so it broadcasts over the batch axis.
        pe = pe.unsqueeze(0).transpose(0, 1)
        # A buffer follows the module across devices without being trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """

        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
    
def weighted_mse(output, target, alpha=11):
    """Mean squared error in which each term is scaled by ``(alpha - target) / alpha``.

    The weight shrinks linearly as the target grows, reaches zero at
    ``target == alpha`` and becomes negative beyond it.

    Args:
        output: predictions; axis 1 is the axis that is averaged per sample.
        target: ground-truth values with the same shape as ``output``.
        alpha: scale of the weighting term (defaults to 11).

    Returns:
        Scalar tensor: the weighted squared errors are summed over axis 1,
        divided by the length of that axis, and averaged over what remains.
    """
    seq_len = output.shape[1]
    weights = (alpha - target) / alpha
    per_sample = 1 / seq_len * torch.sum((target - output) ** 2 * weights, dim=1)
    return torch.mean(per_sample)

class TransformerModule(pl.LightningModule):
    """Transformer encoder with a linear read-out, trained with a (weighted) MSE loss.

    Args:
        ntoken: number of values predicted for every sequence position.
        ninp: embedding size of the input features.
        nhead: number of attention heads per encoder layer.
        nhid: width of the feed-forward sub-layer of each encoder layer.
        nlayers: number of stacked encoder layers.
        dropout: dropout probability for the positional encoding and the
            encoder layers.
        optimi: optimizer name, one of "Adam", "SGD" or "RMSprop".
        lr: learning rate handed to the optimizer.
        loss_type: "weighted" trains with ``weighted_mse``; any other value
            trains with ``F.mse_loss``.
    """

    def __init__(self,
                ntoken,
                ninp,
                nhead,
                nhid,
                nlayers,
                dropout=0.5,
                optimi="Adam",
                lr=.01,
                loss_type="weighted"):
        print("init")
        super().__init__()

        self.ntoken    = ntoken
        self.ninp      = int(ninp)
        self.nhead     = nhead
        self.nhid      = nhid
        self.nlayers   = nlayers
        self.dropout   = dropout
        self.optimi    = optimi
        self.lr        = lr
        self.loss_type = loss_type

        self.pos_enc   = PositionalEncoding(ninp, dropout)
        encoder_layer  = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layer, nlayers)
        self.decoder   = nn.Linear(ninp, ntoken)
        # Attention mask cache; rebuilt lazily in forward().
        self.src_mask  = None
        self.init_weights()
        self.save_hyperparameters()

    def init_weights(self):
        """Zero the decoder bias and draw its weights uniformly from [-0.1, 0.1]."""
        initrange = 0.1
        # Zero the bias; zeroing the weight would be undone by the uniform
        # initialisation on the next line.
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def _generate_square_subsequent_mask(self, sz):
        """Return an ``sz`` x ``sz`` float mask: 0 on and below the diagonal, -inf above."""
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def forward(self, src, has_mask=True):
        """Encode ``src`` and project every position to ``ntoken`` values.

        Args:
            src: tensor laid out as [batch, sequence, embedding].
            has_mask: when True, a square subsequent mask is passed to the
                encoder so a position only attends to itself and earlier
                positions; when False no mask is used.

        Returns:
            Tensor laid out as [batch, sequence, ntoken].
        """
        # The encoder layers are built with the default sequence-first layout.
        src = src.permute(1, 0, 2)
        if has_mask:
            seq_len = len(src)
            stale = (self.src_mask is None
                     or self.src_mask.size(0) != seq_len
                     or self.src_mask.device != src.device)
            if stale:
                # Rebuild when the length changes or the input moved devices.
                mask = self._generate_square_subsequent_mask(seq_len)
                self.src_mask = mask.to(src.device).float()
        else:
            self.src_mask = None
        src    = src * math.sqrt(self.ninp)
        src    = self.pos_enc(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return output.permute(1, 0, 2)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss selected by ``loss_type``."""
        feature, label = batch
        feature        = feature.float()
        label          = label.float()
        output         = self.forward(feature)
        if self.loss_type == "weighted":
            loss = weighted_mse(output, label)
        else:
            loss = F.mse_loss(output, label)
        # The key is kept as-is even when the plain MSE is selected.
        self.log("train weighted mse loss", loss)
        return loss

    def validation_step(self, batch, batch_indx):
        """Compute and log the weighted MSE on a validation batch.

        The weighted loss is used regardless of ``loss_type``; the logged key
        "val weighted mse loss" is the one the early-stopping callbacks and
        the Optuna objective in this notebook look up.
        """
        feature, label = batch
        feature        = feature.float()
        # Cast like training_step does so both losses see the same dtype.
        label          = label.float()
        output         = self.forward(feature)
        loss           = weighted_mse(output, label)
        self.log("val weighted mse loss", loss)
        return loss

    def test_step(self, batch, batch_indx):
        """Plot label against prediction for the first sample of a test batch."""
        feature, label = batch
        feature        = feature.float()
        output         = self.forward(feature)
        fig, ax = plt.subplots(1)
        ax.plot(label[0,:,0].cpu(), label="label")
        # detach() keeps plotting working even if gradients are being tracked.
        ax.plot(output[0,:,0].detach().cpu(), label="output")
        ax.legend()
        plt.show()

    def configure_optimizers(self):
        """Build the optimizer named by ``optimi``.

        Raises:
            ValueError: if ``optimi`` is not "Adam", "SGD" or "RMSprop".
        """
        optimizer_classes = {
            "Adam": torch.optim.Adam,
            "SGD": torch.optim.SGD,
            "RMSprop": torch.optim.RMSprop,
        }
        if self.optimi not in optimizer_classes:
            raise ValueError(
                "Unknown optimizer {!r}; expected one of {}".format(
                    self.optimi, sorted(optimizer_classes)))
        return optimizer_classes[self.optimi](self.parameters(), lr=self.lr)
        
    

In [None]:
# Baseline run: one encoder layer with a single head, no dropout, Adam at 1e-3.
_baseline_hparams = {
    "ntoken": 1,
    "ninp": 29,
    "nhead": 1,
    "nhid": 2048,
    "nlayers": 1,
    "dropout": 0,
    "optimi": "Adam",
    "lr": .001,
}
model = TransformerModule(**_baseline_hparams)

# Train on one GPU for at most 200 epochs.
trainer = pl.Trainer(max_epochs=200, gpus=1)
trainer.fit(model, dm)


In [None]:
# Run the test loop with the trainer fitted in the previous cell (no model or
# datamodule is passed explicitly). TransformerModule.test_step draws a
# label-vs-output plot for the first sample of the batch it receives.
trainer.test()

Next, we build a few helper functions for evaluating a trained model.

In [None]:
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from scipy.stats import pearsonr
from scipy.stats import spearmanr


def _window_centers(values):
    """Return, for every sample in a batch, the first channel of its middle position."""
    per_sample = values.reshape(values.shape[0], values.shape[1], -1)
    return per_sample[:, per_sample.shape[1] // 2, 0].tolist()


def getModelPredictions(model,
                   dm,
                   tvt):
    """Collect the label and the prediction at the centre of every window of a split.

    Args:
        model: module called as ``model(feature)`` on each batch.
        dm: data module providing ``train_dataloader``, ``val_dataloader``
            and ``test_dataloader``.
        tvt: which split to read: "train", "val" or "test".

    Returns:
        ``(labels, outputs)``: two equally long lists of Python floats, one
        entry per window in the split.

    Raises:
        ValueError: if ``tvt`` is not one of the three split names.
    """
    loader_getters = {
        "test": "test_dataloader",
        "train": "train_dataloader",
        "val": "val_dataloader",
    }
    if tvt not in loader_getters:
        raise ValueError(
            "tvt must be 'train', 'val' or 'test', got {!r}".format(tvt))
    dataloader = getattr(dm, loader_getters[tvt])()

    # Follow the model's own device instead of assuming it sits on cuda:0.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    full_label_vec  = []
    full_output_vec = []
    was_training = model.training
    # Inference only: switch dropout off and skip gradient bookkeeping.
    model.eval()
    try:
        with torch.no_grad():
            for feature, label in dataloader:
                feature = feature.to(device).float()
                label   = label.to(device).float()
                output  = model(feature)
                # extend() keeps one entry per window for any batch size.
                full_label_vec.extend(_window_centers(label))
                full_output_vec.extend(_window_centers(output))
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    return full_label_vec, full_output_vec

def getModelMetrics(model,
                   dm,
                   tvt):
    """Score a model's predictions on one split.

    Predictions are gathered with ``getModelPredictions(model, dm, tvt)`` and
    compared against the labels.

    Returns:
        Dict with the keys 'mse', 'mae', 'r2', 'pearson' and 'spearman'; the
        two correlation entries hold the coefficient only, not the p-value.
    """
    truth, predicted = getModelPredictions(model, dm, tvt)
    return {
        'mse': mean_squared_error(truth, predicted),
        'mae': mean_absolute_error(truth, predicted),
        'r2': r2_score(truth, predicted),
        'pearson': pearsonr(truth, predicted)[0],
        'spearman': spearmanr(truth, predicted)[0],
    }

def createPlot(model,
                dm,
                tvt,
               fig_name,
               start=0,
              end=200):
    """Plot labels against predictions for a slice of a split and save the figure.

    Args:
        model: model handed to ``getModelPredictions``.
        dm: data module handed to ``getModelPredictions``; its ``label_type``
            attribute is used as the y-axis label.
        tvt: split name, "train", "val" or "test".
        fig_name: path the figure is written to.
        start: index of the first window to draw.
        end: index one past the last window to draw.
    """
    fig, ax = plt.subplots(1, figsize=(20,7))
    label_vec, output_vec = getModelPredictions(model, dm, tvt)
    ax.plot(label_vec[start:end],  label="label",      color="crimson")
    ax.plot(output_vec[start:end], label="prediction", color="silver")
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.set_ylabel(dm.label_type)
    # Separate the split name from the word "data" ("val data", not "valdata").
    ax.set_xlabel(tvt + " data")
    # Address this figure explicitly rather than pyplot's "current" one.
    ax.legend()
    fig.savefig(fig_name)


In [None]:
# Rebuild the S2 data module with one window per batch, then report the
# validation metrics of the model trained above and save a validation plot.
_eval_dm_options = {
    "cell_line": "S2",
    "data_win_radius": 5,
    "batch_size": 1,
    "label_type": "gamma",
    "label_val": 10,
}
dm = FlyDataModule(**_eval_dm_options)
dm.setup()

_val_scores = getModelMetrics(model, dm, 'val')
print(_val_scores)
createPlot(model, dm, "val", "val.png")

In [None]:
#pulled from https://github.com/optuna/optuna/issues/1186

import yaml
import os
from pytorch_lightning.loggers import LightningLoggerBase

class DictLogger(LightningLoggerBase):
    """PyTorch Lightning logger that keeps metrics in memory and hparams on disk.

    Every metrics dict handed to ``log_metrics`` is appended to
    ``self.metrics``. Hyperparameters are written to
    ``<root_dir>/optuna/version_<version>/hparams.yaml``.

    Args:
        version: identifier used in the ``version_<version>`` directory name.
        root_dir: directory under which the ``optuna`` folder is created.
    """

    def __init__(self, version, root_dir):
        super(DictLogger, self).__init__()
        self.metrics = []
        self._version = version
        self.root_dir = root_dir

    def log_metrics(self, metrics, step=None):
        """Append ``metrics`` to the in-memory history; ``step`` is ignored."""
        self.metrics.append(metrics)

    @property
    def version(self):
        """Return the version identifier given at construction."""
        return self._version

    @property
    def experiment(self):
        """Return the experiment object associated with this logger (always None)."""
        return None

    def log_hyperparams(self, params):
        """
        Record hyperparameters.
        Args:
            params: mapping or :class:`~argparse.Namespace` containing the hyperparameters
        """
        version_dir = os.path.join(self.root_dir, 'optuna', 'version_' + str(self._version))
        # makedirs also creates missing parents and tolerates existing folders.
        os.makedirs(version_dir, exist_ok=True)
        # Dump a plain dict so the file is ordinary YAML instead of a
        # Python-object-tagged document tied to the container class.
        if isinstance(params, dict):
            params = dict(params)
        elif hasattr(params, '__dict__'):
            params = dict(vars(params))
        with open(os.path.join(version_dir, "hparams.yaml"), 'w') as outfile:
            print("logging them hyperparams:"+str(self.root_dir))
            yaml.dump(params, outfile)

    @property
    def name(self):
        """Return the experiment name."""
        return 'optuna'

In [None]:
# Experiment 1
# Determine Hyper Parameters for Transformer
import optuna
import pytorch_lightning as pl
import os

def objective(trial):
    """Optuna objective: train one transformer and return its validation loss.

    The trial chooses the learning rate, dropout, batch size and optimizer;
    the architecture is fixed (six layers, one head, 29 input features).

    Args:
        trial: the Optuna trial supplying the hyperparameter suggestions; its
            number is used as the logger version.

    Returns:
        The most recent "val weighted mse loss" value recorded by the logger.

    Raises:
        KeyError: if no validation loss was logged during training.
    """
    rootdir = "Experiments/Transformer_Hyperparameter_Tuning"
    # Also creates the parent "Experiments" directory on a fresh checkout.
    os.makedirs(rootdir, exist_ok=True)
    logger = DictLogger(trial.number,
                       rootdir)
    trainer = pl.Trainer(
        logger=logger,
        gpus=1,
        max_epochs=200,
        default_root_dir=rootdir
    )

    #hyper params
    lr         = trial.suggest_categorical("lr", [1e-5, 1e-4, 1e-3, 1e-2, 1e-1])
    dropout    = trial.suggest_categorical("dropout",[0, 0.1, 0.2, 0.3])
    batch_size = trial.suggest_categorical("batch_size", [1,4,16,64])
    optimi     = trial.suggest_categorical("optimi", ["Adam", "SGD"])

    dm = FlyDataModule(cell_line="S2",
                  data_win_radius=5,
                  batch_size=batch_size,
                  label_type="gamma",
                  label_val=10)
    dm.setup()

    model   = TransformerModule(
                ntoken=1,
                ninp=29,
                nhead=1,
                nhid=2048,
                nlayers=6,
                dropout=dropout,
                optimi=optimi,
                lr=lr)

    trainer.fit(model, dm)

    # The last logged record may hold training metrics only, so walk back to
    # the latest record that actually contains the validation loss.
    metric_name = 'val weighted mse loss'
    for logged in reversed(logger.metrics):
        if metric_name in logged:
            return logged[metric_name]
    raise KeyError(metric_name)

        
# Minimise the objective over ten trials, then report the best one.
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=10)

_finished_trials = len(study.trials)
print(f"Number of finished trials: {_finished_trials}")
print("Best Trial")
trial = study.best_trial
print(trial)


In [None]:
# Experiment 2
# Vary number of layers
import yaml
import os
from pytorch_lightning.callbacks import EarlyStopping



# Hyperparameters that DictLogger saved for tuning trial 4.
# NOTE: yaml.Loader can construct arbitrary Python objects, which a dump of
# Lightning's hparams container may require. Only use it on files this
# notebook wrote itself, never on untrusted input.
with open("Experiments/Transformer_Hyperparameter_Tuning/optuna/version_4/hparams.yaml") as hparams_file:
    best_params = yaml.load(hparams_file, Loader=yaml.Loader)
print(best_params)

dm = FlyDataModule(cell_line="S2",
                  data_win_radius=5,
                  batch_size=1,
                  label_type='gamma')
dm.setup()

# Create the output directory (and missing parents) once, before the sweep.
root_dir = "Experiments/Transformer_Num_Layers"
os.makedirs(root_dir, exist_ok=True)

# Train one model per depth, from 9 layers down to 1.
for num_layers in reversed(range(1,10)):
    # A fresh callback per run so the patience counter starts from zero.
    early_stop_callback = EarlyStopping(
        monitor="val weighted mse loss",
        min_delta=0.00,
        patience=3,
        verbose=False,
        mode='min')

    # The layer count doubles as the logger version (version_<num_layers>).
    logger = DictLogger(num_layers,
                       root_dir)

    print("Training:"+str(num_layers))
    model = TransformerModule(
        ntoken=best_params['ntoken'],
        ninp=best_params['ninp'],
        nhid=best_params['nhid'],
        nhead=best_params['nhead'],
        nlayers=num_layers,
        dropout=best_params['dropout'],
        optimi=best_params['optimi'],
        lr=best_params['lr'])

    trainer = pl.Trainer(
        logger=logger,
        gpus=1,
        max_epochs=200,
        default_root_dir=root_dir,
        callbacks=[early_stop_callback])

    trainer.fit(model, dm)


In [None]:
# Score the checkpoint of tuning trial 4 on the test split and save a plot.
_test_dm_options = {
    "cell_line": "S2",
    "data_win_radius": 5,
    "batch_size": 1,
    "label_type": "gamma",
    "label_val": 10,
}
dm = FlyDataModule(**_test_dm_options)
dm.setup()

best_weights = "Experiments/Transformer_Hyperparameter_Tuning/optuna/version_4/checkpoints/epoch=199-step=195599.ckpt"
best_model = TransformerModule.load_from_checkpoint(best_weights)
best_model = best_model.to("cuda:0")
print(best_model)

_test_scores = getModelMetrics(best_model, dm, 'test')
print(_test_scores)
createPlot(best_model, dm, "test", "test.png")

In [None]:
# Summary table of the ten tuning trials: tuned hyperparameters plus the
# validation metrics of each trial's checkpoint. Uses the `dm` built in the
# previous cell.
import glob  # this cell runs before the later cell that imports glob


def _as_cell(value):
    """Render one table entry: four significant digits for floats, str() otherwise."""
    if isinstance(value, float):
        return "{:.4g}".format(value)
    return str(value)


metrics      = ['mse','mae','r2', 'pearson','spearman']
# Look the tuned hyperparameters up by name; positions in the YAML mapping
# shift whenever the model gains or loses a constructor argument.
hparam_names = ['dropout', 'lr', 'optimi']
tuning_dir   = "Experiments/Transformer_Hyperparameter_Tuning/optuna/version_"

table_rows = [hparam_names + metrics]
for i in range(0,10):
    # NOTE: yaml.Loader can construct arbitrary Python objects; only use it
    # on hparams files written by this notebook.
    with open(tuning_dir+str(i)+"/hparams.yaml") as hparams_file:
        hpar = yaml.load(hparams_file, Loader=yaml.Loader)
    layer_weights = glob.glob(tuning_dir+str(i)+"/checkpoints/*")[0]
    layer_model   = TransformerModule.load_from_checkpoint(layer_weights).to("cuda:0")
    results = getModelMetrics(layer_model,
                   dm,
                   'val')
    table_rows.append([_as_cell(hpar[name]) for name in hparam_names]
                      + [_as_cell(results[metric]) for metric in metrics])

# Let numpy size the string dtype so no entry is cut off mid-number.
hparm_vals = np.array(table_rows)
print(hparm_vals)

fig, ax = plt.subplots(1)
ax.axis('off')
table = ax.table(cellText=table_rows)
table.auto_set_font_size(False)
table.set_fontsize(12)
table.scale(1.5, 1.5)
plt.show()
    

In [None]:
# view 
import glob
# Compare train and validation metrics of the models trained with 1-9 layers.
dm         = FlyDataModule(cell_line="S2",
                  data_win_radius=5,
                  batch_size=1,
                  label_type="gamma",
                  label_val=10)
dm.setup()

train_metrics = {}
val_metrics   = {}

for i in range(1,10):
    # Load each checkpoint once and score it on both splits.
    layer_weights = glob.glob("Experiments/Transformer_Num_Layers/optuna/version_"+str(i)+"/checkpoints/*")[0]
    layer_model   = TransformerModule.load_from_checkpoint(layer_weights).to("cuda:0")
    for full_metrics, tvt in zip([train_metrics, val_metrics],['train', 'val']):
        # Score on the split that matches the dict being filled, so the
        # "train" and "val" curves are computed from different data.
        metrics = getModelMetrics(layer_model, dm, tvt)
        for key, value in metrics.items():
            full_metrics.setdefault(key, []).append(value)

for key in train_metrics:
    layer_counts = list(range(1, len(train_metrics[key])+1))
    fig, ax = plt.subplots(1)
    ax.plot(layer_counts, train_metrics[key], label="train", color="cornflowerblue")
    ax.plot(layer_counts, val_metrics[key], label="val", color="indigo")
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_ylabel(key)
    ax.set_xlabel("num layers")
    ax.legend()
    plt.show()
    

In [None]:
# Experiment 3
# Vary window size

import yaml
import os
from pytorch_lightning.callbacks import EarlyStopping



# Hyperparameters that DictLogger saved for the 5-layer run of Experiment 2.
# NOTE: yaml.Loader can construct arbitrary Python objects, which a dump of
# Lightning's hparams container may require. Only use it on files this
# notebook wrote itself, never on untrusted input.
with open("Experiments/Transformer_Num_Layers/optuna/version_5/hparams.yaml") as hparams_file:
    best_params = yaml.load(hparams_file, Loader=yaml.Loader)
print(best_params)

# Create the output directory (and missing parents) once, before the sweep.
root_dir = "Experiments/Transformer_Window_Radius"
os.makedirs(root_dir, exist_ok=True)

# Train one model per window radius, from 5 down to 1.
for data_win_radius in reversed(range(1,6)):
    # A fresh callback per run so the patience counter starts from zero.
    early_stop_callback = EarlyStopping(
        monitor="val weighted mse loss",
        min_delta=0.00,
        patience=3,
        verbose=False,
        mode='min')

    dm = FlyDataModule(cell_line="S2",
                  data_win_radius=data_win_radius,
                  batch_size=1,
                  label_type='gamma')
    dm.setup()

    # The radius doubles as the logger version (version_<data_win_radius>).
    logger = DictLogger(data_win_radius,
                       root_dir)

    print("Training:"+str(data_win_radius))
    model = TransformerModule(
        ntoken=best_params['ntoken'],
        ninp=best_params['ninp'],
        nhid=best_params['nhid'],
        nhead=best_params['nhead'],
        nlayers=best_params['nlayers'],
        dropout=best_params['dropout'],
        optimi=best_params['optimi'],
        lr=best_params['lr'])

    trainer = pl.Trainer(
        logger=logger,
        gpus=1,
        max_epochs=200,
        default_root_dir=root_dir,
        callbacks=[early_stop_callback])
    trainer.fit(model, dm)

In [None]:
# view 
import glob
# Compare train and validation metrics of the models trained with window
# radius 1-5 (checkpoint version_<i> was trained with radius i).
train_metrics = {}
val_metrics   = {}

for i in range(1,6):
    # Evaluate every model on windows of the radius it was trained with;
    # a fixed radius-5 module would feed all of them 11-position windows.
    dm = FlyDataModule(cell_line="S2",
                  data_win_radius=i,
                  batch_size=1,
                  label_type="gamma",
                  label_val=10)
    dm.setup()

    # Load each checkpoint once and score it on both splits.
    layer_weights = glob.glob("Experiments/Transformer_Window_Radius/optuna/version_"+str(i)+"/checkpoints/*")[0]
    layer_model   = TransformerModule.load_from_checkpoint(layer_weights).to("cuda:0")
    for full_metrics, tvt in zip([train_metrics, val_metrics],['train', 'val']):
        # Score on the split that matches the dict being filled, so the
        # "train" and "val" curves are computed from different data.
        metrics = getModelMetrics(layer_model, dm, tvt)
        for key, value in metrics.items():
            full_metrics.setdefault(key, []).append(value)

for key in train_metrics:
    radii = list(range(1, len(train_metrics[key])+1))
    fig, ax = plt.subplots(1)
    ax.plot(radii, train_metrics[key], label="train", color="cornflowerblue")
    ax.plot(radii, val_metrics[key], label="val", color="indigo")
    ax.set_ylabel(key)
    ax.set_xticks([1,2,3,4,5])
    ax.set_xlabel("Window Radius")
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.legend()
    plt.show()

In [None]:
# Prediction using Transformer across different cell lines
cell_lines = ['S2', 'KC', 'BG']

# Create the output directory (and missing parents) once, before training.
rootdir = "Experiments/Transformer_Cross_Cell_Line_Comparison"
os.makedirs(rootdir, exist_ok=True)

# Train one model per cell line; the position of the line in `cell_lines`
# becomes the logger version, which the evaluation cell uses to find the
# matching checkpoint.
for trial, train_line in enumerate(cell_lines):
    dm = FlyDataModule(cell_line=train_line,
                  data_win_radius=4,
                  batch_size=1,
                  label_type='gamma')
    dm.setup()

    # A fresh callback per run so the patience counter starts from zero.
    early_stop_callback = EarlyStopping(
        monitor="val weighted mse loss",
        min_delta=0.00,
        patience=3,
        verbose=False,
        mode='min')

    logger = DictLogger(trial,
                       rootdir)
    trainer = pl.Trainer(
        logger=logger,
        gpus=1,
        max_epochs=200,
        default_root_dir=rootdir,
        callbacks=[early_stop_callback]
    )
    model = TransformerModule(
        ntoken=1,
        ninp=29,
        nhid=2048,
        nhead=1,
        nlayers=5,
        dropout=0.2,
        optimi="SGD",
        lr=0.0001)
    trainer.fit(model, dm)

In [None]:
import glob
# all_metrics[i, j] holds the test metrics of the model trained on
# cell_lines[i] when it is evaluated on cell_lines[j].
all_metrics = {}
for j, test_line in enumerate(cell_lines):
    # Build the data module for the line being tested, not for whatever
    # `train_line` was left over from an earlier loop.
    dm = FlyDataModule(cell_line=test_line,
                  data_win_radius=4,
                  batch_size=1,
                  label_type='gamma')
    dm.setup()
    for i, train_line in enumerate(cell_lines):
        layer_weights = glob.glob("Experiments/Transformer_Cross_Cell_Line_Comparison/optuna/version_"+str(i)+"/checkpoints/*")[0]
        layer_model   = TransformerModule.load_from_checkpoint(layer_weights).to("cuda:0")
        all_metrics[i, j] = getModelMetrics(layer_model, dm, 'test')
        
        

In [None]:
def getmetric(met, i, j):
    """Return metric ``met`` stored under key ``(i, j)`` of ``all_metrics``."""
    return all_metrics[i, j][met]


print(getmetric('r2',0,0))

# One table per metric. `all_metrics` is filled with the training line as the
# first index and the test line as the second, so rows are training lines and
# columns are test lines.
n_lines = len(cell_lines)
for met in ['mse','mae','r2','pearson','spearman']:
    vals = np.zeros((n_lines, n_lines))
    for i in range(n_lines):
        for j in range(n_lines):
            # Round to two decimals for display; the float() makes the
            # string-to-number conversion explicit.
            vals[i, j] = float("{:.2f}".format(getmetric(met, i, j)))
    print(met)
    fig, ax = plt.subplots(1, figsize=(10,10))
    table = ax.table(vals,
             cellLoc="center",
            colLabels=cell_lines,
            rowLabels=cell_lines)
    table.set_fontsize(14)
    table.scale(1.5, 1.5)
    ax.axis('off')
    plt.subplots_adjust(left=0.2, top=0.8)
    plt.show()

# Different TAD characterization