# **Deep Learning For Lung Cancer Prediction Based on Transcriptomic Data : a Basic MLP with Transfer Learning**
> Author : **Aymen MERROUCHE**. <br>
> In this notebook, we implement a basic MLP for our binary classification task. We first pre-train the MLP on the non-lung cancer dataset. Then, in a transfer learning fashion, we fine-tune it on the lung cancer dataset: we keep only the dense layers and discard the final classification layer.
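
To make the "keep the dense layers, drop the classification layer" idea concrete, here is a minimal sketch. `TinyMLP`, `replace_head`, the layer sizes, and the two-class output are all hypothetical stand-ins chosen for illustration; the actual `MLP` and `fine_tune_mlp` used in this notebook live in the project modules and may be structured differently.

```python
import copy

import torch
import torch.nn as nn


class TinyMLP(nn.Module):
    """Hypothetical stand-in for the project's MLP: dense layers plus a head."""

    def __init__(self, in_features, hidden=64, n_classes=2):
        super().__init__()
        # Dense feature layers: these are the layers we want to transfer.
        self.features = nn.Sequential(
            nn.Linear(in_features, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        # Final classification layer: replaced before fine-tuning.
        self.classifier = nn.Linear(hidden, n_classes)

    def forward(self, x):
        return self.classifier(self.features(x))


def replace_head(pretrained, n_classes=2):
    """Copy the pre-trained network and swap in a freshly initialised head."""
    model = copy.deepcopy(pretrained)
    model.classifier = nn.Linear(model.classifier.in_features, n_classes)
    return model


torch.manual_seed(0)
pretrained = TinyMLP(in_features=100)
fine_tuned = replace_head(pretrained)
out = fine_tuned(torch.randn(8, 100))
```

The dense layers of `fine_tuned` start from the pre-trained weights, while the new head starts from a random initialisation and is learned during fine-tuning.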

In [1]:
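import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import datetime
import yaml
from pathlib import Path
# optim, SummaryWriter, yaml and Path are used below; importing them explicitly
# avoids relying on the star imports to bring them into the namespace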

from utils import *
from train import *
from data_utils import *
from modules.MLP import *
from modules.focal_loss import *
%load_ext autoreload
%autoreload 2

In [2]:
# device to use, if cuda available then use cuda else use cpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Working on : ", device)

Working on :  cuda


In [3]:
# load hyperparameters
# data paths args
with open('./configs/data_paths.yaml', 'r') as stream:
    data_paths_args = yaml.load(stream, Loader=yaml.Loader)

# MLP args
with open('./configs/mlp.yaml', 'r') as stream:
    mlp_args = yaml.load(stream, Loader=yaml.Loader)

## **1 - Pre-Training on the Non Lung Dataset :**

### **1 - 1 - Get the Data :**

In [22]:
%%time
# Getting the data
# dataset
non_lung_dataset = TranscriptomicVectorsDatasetNonLung(data_paths_args["path_to_pan_cancer_hdf5_files"])
non_lung_dataloader_train, non_lung_dataloader_validation = get_data_loaders(
    non_lung_dataset,
    batch_size_train=mlp_args["batch_size_pt"],
    batch_size_validation=mlp_args["batch_size_pt"],
)

CPU times: user 12.4 s, sys: 191 ms, total: 12.6 s
Wall time: 12 s
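
The implementation of `get_data_loaders` is not shown in this notebook. As a rough illustration of what such a helper typically does, the sketch below splits a dataset into training and validation subsets and wraps each in a `DataLoader`. The name `make_loaders`, the 20% validation fraction, the fixed seed, and the toy tensors are assumptions made for the example, not the project's actual choices.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def make_loaders(dataset, batch_size_train, batch_size_validation,
                 val_fraction=0.2, seed=0):
    """Hypothetical helper: random train/validation split plus two loaders."""
    n_val = int(len(dataset) * val_fraction)
    n_train = len(dataset) - n_val
    # A seeded generator keeps the split reproducible across runs.
    generator = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=generator)
    # Shuffle only the training data; validation order does not matter.
    train_loader = DataLoader(train_set, batch_size=batch_size_train, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size_validation, shuffle=False)
    return train_loader, val_loader


# Toy data standing in for the transcriptomic vectors (double precision, binary labels).
toy = TensorDataset(torch.randn(1000, 20, dtype=torch.double),
                    torch.randint(0, 2, (1000,)))
train_loader, val_loader = make_loaders(toy, batch_size_train=128, batch_size_validation=128)
x_batch, y_batch = next(iter(train_loader))
```

With strongly imbalanced labels, a stratified split may be preferable to a purely random one so that both subsets keep a similar class ratio.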


### **1 - 2 - Network, Criterion and Training :**

In [23]:
# network
net = MLP(len(non_lung_dataset[0][0])).to(device).double()

# loss and optimizer  
criterion = FocalLoss().to(device)
optimizer = optim.Adam(net.parameters(), lr=mlp_args['lr_pt'])

# Logging + Experiment

ignore_keys = {'no_tensorboard'}
# get hyperparameters with values in a dict
hparams = {**mlp_args}
# generate a name for the experiment
expe_name = '_'.join([f"{key}={val}" for key, val in hparams.items()])
print("Experimenting with : \n \t"+expe_name)
# path where to save the model
savepath = Path('/tempory/transcriptomic_data/pre_trained_mlp_checkpt.pt')
# Tensorboard summary writer
if mlp_args['no_tensorboard']:
    writer = None
else:
    writer = SummaryWriter("runs/runs"+"_"+datetime.datetime.now().strftime("%Y%m%d-%H%M%S")+expe_name)
    
# start the experiment
checkpoint = CheckpointState(net, optimizer, savepath=savepath)
fit(checkpoint, criterion, non_lung_dataloader_train, non_lung_dataloader_validation, mlp_args['epochs'], writer=writer)
if not mlp_args['no_tensorboard']:
    writer.close()

Experimenting with : 
 	epochs=25_batch_size_pt=128_lr_pt=0.001_batch_size_ft=128_lr_ft=0.001_no_tensorboard=True
Training on GPU 



Epoch 1/25: 100%|██████████| 41/41 [00:02<00:00, 16.43it/s, loss=1.7146e-01]


Epoch 1/25, Train Loss: 9.8874e-02, Test Loss: 0.1370
Epoch 1/25, Train Accuracy: 71.36%, Test Accuracy: 64.58%
Epoch 1/25, Train AUC: 77.40%, Test AUC: 75.01%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.95      0.63      0.76      2258
      Cancer       0.20      0.75      0.32       285

    accuracy                           0.65      2543
   macro avg       0.58      0.69      0.54      2543
weighted avg       0.87      0.65      0.71      2543



Epoch 2/25: 100%|██████████| 41/41 [00:02<00:00, 16.53it/s, loss=7.3316e-02]


Epoch 2/25, Train Loss: 8.6726e-02, Test Loss: 0.1128
Epoch 2/25, Train Accuracy: 74.35%, Test Accuracy: 68.75%
Epoch 2/25, Train AUC: 80.45%, Test AUC: 75.41%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.95      0.68      0.79      2258
      Cancer       0.22      0.73      0.34       285

    accuracy                           0.69      2543
   macro avg       0.59      0.70      0.57      2543
weighted avg       0.87      0.69      0.74      2543



Epoch 3/25: 100%|██████████| 41/41 [00:02<00:00, 16.54it/s, loss=8.8954e-02]


Epoch 3/25, Train Loss: 8.3718e-02, Test Loss: 0.1135
Epoch 3/25, Train Accuracy: 74.97%, Test Accuracy: 69.69%
Epoch 3/25, Train AUC: 80.66%, Test AUC: 76.55%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.95      0.69      0.80      2258
      Cancer       0.23      0.74      0.35       285

    accuracy                           0.70      2543
   macro avg       0.59      0.71      0.58      2543
weighted avg       0.87      0.70      0.75      2543



Epoch 4/25: 100%|██████████| 41/41 [00:02<00:00, 16.58it/s, loss=2.8556e-02]


Epoch 4/25, Train Loss: 6.8399e-02, Test Loss: 0.0910
Epoch 4/25, Train Accuracy: 77.10%, Test Accuracy: 72.67%
Epoch 4/25, Train AUC: 84.19%, Test AUC: 77.70%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.95      0.73      0.83      2258
      Cancer       0.25      0.69      0.36       285

    accuracy                           0.73      2543
   macro avg       0.60      0.71      0.59      2543
weighted avg       0.87      0.73      0.77      2543



Epoch 5/25: 100%|██████████| 41/41 [00:02<00:00, 16.51it/s, loss=1.5504e-01]


Epoch 5/25, Train Loss: 4.8252e-02, Test Loss: 0.0673
Epoch 5/25, Train Accuracy: 82.43%, Test Accuracy: 77.49%
Epoch 5/25, Train AUC: 87.69%, Test AUC: 75.94%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.94      0.80      0.86      2258
      Cancer       0.27      0.61      0.38       285

    accuracy                           0.77      2543
   macro avg       0.61      0.70      0.62      2543
weighted avg       0.87      0.77      0.81      2543



Epoch 6/25:  34%|███▍      | 14/41 [00:00<00:01, 14.94it/s, loss=7.0421e-02]


KeyboardInterrupt: 
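
The validation set above is imbalanced (2258 "No Cancer" samples against 285 "Cancer" samples), which is the usual motivation for a focal loss. The `FocalLoss` class from `modules.focal_loss` is not shown here, so the following is only a sketch of the standard formulation, FL = -(1 - p_t)^γ log(p_t), for a network that outputs one logit per class. The class name `FocalLossSketch`, the default γ = 2, and the two-logit output are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLossSketch(nn.Module):
    """Sketch of a multi-class focal loss on raw logits (not the project's class)."""

    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        # Log-probability assigned to the true class of each sample.
        log_probs = F.log_softmax(logits, dim=1)
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        # The factor (1 - p_t)^gamma shrinks the loss of well-classified samples.
        loss = -((1.0 - pt) ** self.gamma) * log_pt
        return loss.mean()


logits = torch.tensor([[4.0, 0.0], [0.0, 0.5]], dtype=torch.double)
targets = torch.tensor([0, 1])
focal = FocalLossSketch(gamma=2.0)(logits, targets)
cross_entropy = F.cross_entropy(logits, targets)
```

With γ = 0 the modulating factor equals 1 and the expression reduces to the ordinary cross-entropy; larger values of γ put more weight on hard, misclassified samples.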

## **2 - Fine-Tuning on the Lung dataset :**

### **2 - 1 - Load Pre-Trained Model :**

In [24]:
# Load the pretrained Model
net = MLP(len(non_lung_dataset[0][0])).to(device).double()
optimizer = optim.Adam(net.parameters(), lr=mlp_args['lr_ft'])
# path to the best pre-training checkpoint: the save path defined above, with a "_best" suffix
savepath = Path('/tempory/transcriptomic_data/pre_trained_mlp_checkpt_best.pt')
checkpoint = CheckpointState(net, optimizer, savepath=savepath)
checkpoint.load()
pretrained = checkpoint.model
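
`CheckpointState` is a project helper whose code is not reproduced in this notebook. As a hedged illustration of the underlying mechanism, the sketch below saves and restores a model and its optimizer with `torch.save` and `load_state_dict`. The function names, the dictionary keys, and the toy `nn.Linear` model are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(path, model, optimizer, epoch):
    """Store the model weights, the optimizer state and the epoch in one file."""
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch}, path)


def load_checkpoint(path, model, optimizer):
    """Restore the states in place and return the stored epoch."""
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epoch"]


# Toy model standing in for the MLP.
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
path = os.path.join(tempfile.mkdtemp(), "toy_checkpt.pt")
save_checkpoint(path, model, optimizer, epoch=5)

# A fresh model with the same architecture receives the saved weights.
restored = nn.Linear(4, 2)
restored_optimizer = optim.Adam(restored.parameters(), lr=1e-3)
epoch = load_checkpoint(path, restored, restored_optimizer)
```

The restored model must be built with the same architecture as the saved one, which is why the cell above re-creates `MLP` with the same input size before loading.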

### **2 - 2 - Get the Data :**

In [26]:
%%time
# Getting the data
# dataset
lung_dataset = TranscriptomicVectorsDatasetLung(data_paths_args["path_to_pan_cancer_hdf5_files"])
lung_dataloader_train, lung_dataloader_validation = get_data_loaders(
    lung_dataset,
    batch_size_train=mlp_args["batch_size_ft"],
    batch_size_validation=mlp_args["batch_size_ft"],
)

CPU times: user 11.1 s, sys: 106 ms, total: 11.2 s
Wall time: 11.1 s


### **2 - 3 - Fine Tuning Procedure :**

In [27]:
# Beginning of the transfer learning procedure
net = fine_tune_mlp(pretrained)
net = net.to(device).double()
criterion = FocalLoss().to(device)
optimizer = optim.Adam(net.parameters(), lr=mlp_args['lr_ft'])
savepath = Path('/tempory/transcriptomic_data/fine_tuned_mlp_checkpt.pt')
checkpoint = CheckpointState(net, optimizer, savepath=savepath)
fit(checkpoint, criterion, lung_dataloader_train, lung_dataloader_validation, mlp_args['epochs'])

Training on GPU 



Epoch 1/25: 100%|██████████| 5/5 [00:00<00:00, 14.84it/s, loss=1.5134e-01]


Epoch 1/25, Train Loss: 8.5251e-02, Test Loss: 0.1014
Epoch 1/25, Train Accuracy: 72.72%, Test Accuracy: 74.06%
Epoch 1/25, Train AUC: 84.56%, Test AUC: 64.02%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.93      0.74      0.83       257
      Cancer       0.14      0.44      0.22        25

    accuracy                           0.72       282
   macro avg       0.54      0.59      0.52       282
weighted avg       0.86      0.72      0.77       282



Epoch 2/25: 100%|██████████| 5/5 [00:00<00:00, 16.52it/s, loss=5.6080e-02]


Epoch 2/25, Train Loss: 3.3142e-02, Test Loss: 0.0752
Epoch 2/25, Train Accuracy: 86.39%, Test Accuracy: 77.20%
Epoch 2/25, Train AUC: 92.36%, Test AUC: 70.41%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.94      0.80      0.87       257
      Cancer       0.19      0.48      0.27        25

    accuracy                           0.77       282
   macro avg       0.57      0.64      0.57       282
weighted avg       0.87      0.77      0.81       282



Epoch 3/25: 100%|██████████| 5/5 [00:00<00:00, 16.77it/s, loss=2.5031e-02]


Epoch 3/25, Train Loss: 2.9853e-02, Test Loss: 0.0925
Epoch 3/25, Train Accuracy: 84.85%, Test Accuracy: 74.62%
Epoch 3/25, Train AUC: 92.50%, Test AUC: 70.04%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.94      0.78      0.85       257
      Cancer       0.17      0.48      0.26        25

    accuracy                           0.75       282
   macro avg       0.56      0.63      0.55       282
weighted avg       0.87      0.75      0.80       282



Epoch 4/25: 100%|██████████| 5/5 [00:00<00:00, 16.39it/s, loss=1.6200e-02]


Epoch 4/25, Train Loss: 1.1688e-02, Test Loss: 0.0856
Epoch 4/25, Train Accuracy: 91.50%, Test Accuracy: 76.68%
Epoch 4/25, Train AUC: 97.63%, Test AUC: 69.42%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.94      0.79      0.86       257
      Cancer       0.18      0.48      0.27        25

    accuracy                           0.77       282
   macro avg       0.56      0.64      0.56       282
weighted avg       0.87      0.77      0.81       282



Epoch 5/25: 100%|██████████| 5/5 [00:00<00:00, 16.00it/s, loss=9.5307e-03]


Epoch 5/25, Train Loss: 1.2698e-02, Test Loss: 0.0764
Epoch 5/25, Train Accuracy: 91.91%, Test Accuracy: 77.20%
Epoch 5/25, Train AUC: 97.38%, Test AUC: 68.02%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.94      0.80      0.87       257
      Cancer       0.19      0.48      0.27        25

    accuracy                           0.77       282
   macro avg       0.57      0.64      0.57       282
weighted avg       0.87      0.77      0.81       282



Epoch 6/25: 100%|██████████| 5/5 [00:00<00:00, 15.80it/s, loss=6.0172e-03]


Epoch 6/25, Train Loss: 6.0508e-03, Test Loss: 0.0666
Epoch 6/25, Train Accuracy: 93.41%, Test Accuracy: 79.79%
Epoch 6/25, Train AUC: 98.97%, Test AUC: 68.45%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.94      0.82      0.88       257
      Cancer       0.21      0.48      0.29        25

    accuracy                           0.79       282
   macro avg       0.58      0.65      0.59       282
weighted avg       0.88      0.79      0.83       282



KeyboardInterrupt: 
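
The low precision on the "Cancer" class in the reports above follows largely from the class imbalance. The short computation below reconstructs approximate confusion counts from the epoch 6 validation report (recall 0.48 on 25 "Cancer" samples, recall 0.82 on 257 "No Cancer" samples) and recovers the reported precision and F1 score. Because the reported recalls are rounded, the reconstructed counts are an approximation, and the helper name `positive_class_metrics` is hypothetical.

```python
def positive_class_metrics(recall_pos, recall_neg, n_pos, n_neg):
    """Rebuild approximate confusion counts, then precision and F1 for the positive class."""
    tp = round(recall_pos * n_pos)   # cancer samples correctly flagged
    tn = round(recall_neg * n_neg)   # non-cancer samples correctly cleared
    fn = n_pos - tp                  # cancer samples missed
    fp = n_neg - tn                  # non-cancer samples flagged as cancer
    precision = tp / (tp + fp)
    f1 = 2 * tp / (2 * tp + fp + fn)
    return precision, f1


precision, f1 = positive_class_metrics(recall_pos=0.48, recall_neg=0.82, n_pos=25, n_neg=257)
print(f"precision={precision:.2f}, f1={f1:.2f}")  # precision=0.21, f1=0.29
```

Even with 82% of the "No Cancer" samples classified correctly, the remaining 18% amount to about 46 false positives, which outnumber the 12 true positives and pull the precision down to roughly 0.21.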