# **Deep Learning For Lung Cancer Prediction Based on Transcriptomic Data : a Basic MLP with Transfer Learning**
> Author : **Aymen MERROUCHE**. <br>
> In this notebook, we implement a basic MLP for our binary classification task. We first pre-train the MLP on the non-lung cancer dataset, then fine-tune it on the lung cancer dataset. The final classification layer of the pre-trained network is not kept for fine-tuning.

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
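import datetime
from pathlib import Path

# explicit imports for names used in the cells below
# (otherwise they are only reachable through the star imports)
import yaml
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter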

from utils import *
from train import *
from data_utils import *
from modules.MLP import *
from modules.focal_loss import *
%load_ext autoreload
%autoreload 2

In [2]:
# device to use, if cuda available then use cuda else use cpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Working on : ", device)

Working on :  cuda


In [3]:
# load hyperparameters
# data paths args
with open('./configs/data_paths.yaml', 'r') as stream:
    data_paths_args = yaml.load(stream, Loader=yaml.Loader)

# MLP args
with open('./configs/mlp.yaml', 'r') as stream:
    mlp_args = yaml.load(stream, Loader=yaml.Loader)
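
The contents of `./configs/mlp.yaml` are not shown in this notebook. Judging from the experiment name printed during pre-training, the file holds six keys, where the `_pt` and `_ft` suffixes separate the pre-training settings from the fine-tuning ones. The sketch below rebuilds that experiment name from an equivalent in-memory dictionary. The dictionary `example_args` is an assumption reconstructed from the log, not a copy of the file.

```python
# Hyperparameters reconstructed from the experiment name logged during
# pre-training; the real values live in ./configs/mlp.yaml (not shown here).
example_args = {
    "epochs": 25,
    "batch_size_pt": 128,    # batch size used for pre-training
    "lr_pt": 0.001,          # learning rate used for pre-training
    "batch_size_ft": 128,    # batch size used for fine-tuning
    "lr_ft": 0.001,          # learning rate used for fine-tuning
    "no_tensorboard": True,  # when True, no SummaryWriter is created
}

# Same naming scheme as in the training cell: key=value pairs joined by "_".
expe_name = "_".join(f"{key}={val}" for key, val in example_args.items())
print(expe_name)
```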

## **1 - Pre-Training on the Non Lung Dataset :**

### **1 - 1 - Get the Data :**

In [4]:
%%time
# Getting the data
# dataset
non_lung_dataset = TranscriptomicVectorsDatasetNonLung(data_paths_args["path_to_pan_cancer_hdf5_files"])
non_lung_dataloader_train, non_lung_dataloader_validation = get_data_loaders(
    non_lung_dataset,
    batch_size_train=mlp_args["batch_size_pt"],
    batch_size_validation=mlp_args["batch_size_pt"],
)

CPU times: user 12.1 s, sys: 418 ms, total: 12.5 s
Wall time: 12 s
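
`get_data_loaders` comes from the project's `data_utils` module and its implementation is not shown here. As a rough sketch of what such a helper could do, the function below splits a dataset at random into a training and a validation part and wraps each in a `DataLoader`. The name `make_loaders`, the `val_fraction` and `seed` parameters, and the toy dataset are all hypothetical; the split actually used in this notebook may differ.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def make_loaders(dataset, batch_size_train, batch_size_validation,
                 val_fraction=0.2, seed=0):
    """Hypothetical stand-in for get_data_loaders: random train/validation split."""
    n_val = int(len(dataset) * val_fraction)
    n_train = len(dataset) - n_val
    # A seeded generator makes the split reproducible across runs.
    generator = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=generator)
    # Shuffle only the training data; validation order does not matter.
    train_loader = DataLoader(train_set, batch_size=batch_size_train, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size_validation, shuffle=False)
    return train_loader, val_loader


# Toy data: 100 vectors of 16 features with binary labels.
toy_dataset = TensorDataset(
    torch.randn(100, 16, dtype=torch.double),
    torch.randint(0, 2, (100,)),
)
train_loader, val_loader = make_loaders(toy_dataset, 32, 32)
```

With imbalanced labels, a stratified split that preserves the class ratio in both parts may be preferable to a purely random one.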


### **1 - 2 - Network, Criterion and Training :**

In [5]:
# network
net = MLP(len(non_lung_dataset[0][0])).to(device).double()

# loss and optimizer  
criterion = FocalLoss().to(device)
optimizer = optim.Adam(net.parameters(), lr=mlp_args['lr_pt'])

# Logging + Experiment

ignore_keys = {'no_tensorboard'}
# get hyperparameters with values in a dict
hparams = {**mlp_args}
# generate a name for the experiment
expe_name = '_'.join([f"{key}={val}" for key, val in hparams.items()])
print("Experimenting with : \n \t"+expe_name)
# path where to save the model
savepath = Path('/tempory/transcriptomic_data/pre_trained_mlp_checkpt.pt')
# Tensorboard summary writer
if mlp_args['no_tensorboard']:
    writer = None
else:
    writer = SummaryWriter("runs/runs"+"_"+datetime.datetime.now().strftime("%Y%m%d-%H%M%S")+expe_name)
    
# start the experiment
checkpoint = CheckpointState(net, optimizer, savepath=savepath)
fit(checkpoint, criterion, non_lung_dataloader_train, non_lung_dataloader_validation, mlp_args['epochs'], writer=writer)
if not mlp_args['no_tensorboard']:
    writer.close()

Experimenting with : 
 	epochs=25_batch_size_pt=128_lr_pt=0.001_batch_size_ft=128_lr_ft=0.001_no_tensorboard=True
Training on GPU 



Epoch 1/25: 100%|██████████| 41/41 [00:02<00:00, 16.63it/s, loss=1.9883e-01]


Epoch 1/25, Train Loss: 9.9510e-02, Test Loss: 0.1415
Epoch 1/25, Train Accuracy: 72.11%, Test Accuracy: 64.69%
Epoch 1/25, Train AUC: 77.48%, Test AUC: 72.08%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.95      0.63      0.76      2248
      Cancer       0.21      0.75      0.33       295

    accuracy                           0.65      2543
   macro avg       0.58      0.69      0.54      2543
weighted avg       0.86      0.65      0.71      2543



Epoch 2/25: 100%|██████████| 41/41 [00:02<00:00, 16.45it/s, loss=1.1185e-01]


Epoch 2/25, Train Loss: 9.1687e-02, Test Loss: 0.0983
Epoch 2/25, Train Accuracy: 72.39%, Test Accuracy: 69.75%
Epoch 2/25, Train AUC: 79.06%, Test AUC: 73.71%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.94      0.70      0.80      2248
      Cancer       0.23      0.66      0.34       295

    accuracy                           0.70      2543
   macro avg       0.58      0.68      0.57      2543
weighted avg       0.86      0.70      0.75      2543



Epoch 3/25: 100%|██████████| 41/41 [00:02<00:00, 16.48it/s, loss=1.0500e-01]


Epoch 3/25, Train Loss: 8.3851e-02, Test Loss: 0.1247
Epoch 3/25, Train Accuracy: 74.48%, Test Accuracy: 66.82%
Epoch 3/25, Train AUC: 77.66%, Test AUC: 71.60%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.95      0.66      0.78      2248
      Cancer       0.22      0.72      0.33       295

    accuracy                           0.67      2543
   macro avg       0.58      0.69      0.56      2543
weighted avg       0.86      0.67      0.73      2543



Epoch 4/25: 100%|██████████| 41/41 [00:02<00:00, 16.50it/s, loss=1.0893e-01]


Epoch 4/25, Train Loss: 7.6306e-02, Test Loss: 0.1045
Epoch 4/25, Train Accuracy: 75.15%, Test Accuracy: 69.28%
Epoch 4/25, Train AUC: 80.82%, Test AUC: 73.64%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.94      0.69      0.80      2248
      Cancer       0.23      0.68      0.34       295

    accuracy                           0.69      2543
   macro avg       0.58      0.69      0.57      2543
weighted avg       0.86      0.69      0.75      2543



Epoch 5/25: 100%|██████████| 41/41 [00:02<00:00, 16.57it/s, loss=4.0901e-02]


Epoch 5/25, Train Loss: 5.2452e-02, Test Loss: 0.0887
Epoch 5/25, Train Accuracy: 81.11%, Test Accuracy: 73.47%
Epoch 5/25, Train AUC: 87.02%, Test AUC: 74.85%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.94      0.75      0.83      2248
      Cancer       0.25      0.63      0.36       295

    accuracy                           0.73      2543
   macro avg       0.59      0.69      0.59      2543
weighted avg       0.86      0.73      0.78      2543



Epoch 6/25: 100%|██████████| 41/41 [00:02<00:00, 16.55it/s, loss=5.6630e-02]


Epoch 6/25, Train Loss: 3.6650e-02, Test Loss: 0.0826
Epoch 6/25, Train Accuracy: 84.26%, Test Accuracy: 76.33%
Epoch 6/25, Train AUC: 91.81%, Test AUC: 75.49%
Classification Report on Val Set : 
              precision    recall  f1-score   support

   No Cancer       0.93      0.79      0.85      2248
      Cancer       0.26      0.58      0.36       295

    accuracy                           0.76      2543
   macro avg       0.60      0.68      0.61      2543
weighted avg       0.86      0.76      0.80      2543



Epoch 7/25:  56%|█████▌    | 23/41 [00:01<00:01, 15.52it/s, loss=3.0793e-02]


KeyboardInterrupt: 
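
The `FocalLoss` criterion used above is imported from `modules.focal_loss`, and its code is not part of this notebook. The reports above show why such a loss is a natural choice: the validation set contains 295 cancer samples against 2248 non-cancer ones. As a rough illustration, the per-sample binary focal loss of Lin et al. can be written in plain Python as below. The parameter names `gamma` and `alpha` follow that formulation; the notebook's module may use different defaults or a different reduction over the batch.

```python
import math


def binary_focal_loss(p, y, gamma=2.0, alpha=None):
    """Focal loss for one sample: p is the predicted P(y=1), y is 0 or 1."""
    p_t = p if y == 1 else 1.0 - p  # probability assigned to the true class
    # The factor (1 - p_t)^gamma shrinks the loss of well-classified samples.
    loss = -((1.0 - p_t) ** gamma) * math.log(p_t)
    if alpha is not None:  # optional class-balancing weight
        loss *= alpha if y == 1 else 1.0 - alpha
    return loss


easy = binary_focal_loss(0.9, 1)  # correct and confident: strongly down-weighted
hard = binary_focal_loss(0.1, 1)  # wrong and confident: close to the full log loss
cross_entropy = binary_focal_loss(0.1, 1, gamma=0.0)  # gamma=0 is plain log loss
print(easy, hard, cross_entropy)
```

Because easy samples contribute little, the abundant and mostly easy majority class dominates the gradient less than it would with plain cross-entropy.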

## **2 - Fine-Tuning on the Lung dataset :**

### **2 - 1 - Load Pre-Trained Model :**

In [6]:
# Load the pre-trained model
net = MLP(len(non_lung_dataset[0][0])).to(device).double()
optimizer = optim.Adam(net.parameters(), lr=mlp_args['lr_ft'])
# path of the best pre-trained checkpoint: the save path defined above, with a "_best" suffix
savepath = Path('/tempory/transcriptomic_data/pre_trained_mlp_checkpt_best.pt')
checkpoint = CheckpointState(net, optimizer, savepath=savepath)
checkpoint.load()
pretrained = checkpoint.model

RuntimeError: CUDA out of memory. Tried to allocate 12.00 MiB (GPU 0; 7.79 GiB total capacity; 1.60 GiB already allocated; 12.62 MiB free; 1.63 GiB reserved in total by PyTorch)
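
The error above reports only 12.62 MiB free on a 7.79 GiB GPU while this PyTorch process reserves about 1.6 GiB, which suggests that most of the memory is held elsewhere. One common way to make checkpoint loading independent of GPU availability is to deserialize on the CPU with `map_location` and move the model to the device afterwards. The sketch below uses plain `torch.save` and `torch.load` rather than the notebook's `CheckpointState` helper, whose internals are not shown. The `model_state` key, the `best_checkpoint_path` helper and the toy model are hypothetical; only the `_best` suffix follows the naming visible in the cell above.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn


def best_checkpoint_path(savepath: Path) -> Path:
    """Hypothetical helper: insert "_best" before the file extension."""
    return savepath.with_name(f"{savepath.stem}_best{savepath.suffix}")


model = nn.Linear(8, 1).double()  # toy stand-in for the pre-trained MLP

with tempfile.TemporaryDirectory() as tmp:
    ckpt = best_checkpoint_path(Path(tmp) / "pre_trained_mlp_checkpt.pt")
    torch.save({"model_state": model.state_dict()}, ckpt)
    # map_location="cpu" keeps every tensor on the CPU while loading,
    # so no GPU memory is needed at this step.
    state = torch.load(ckpt, map_location="cpu")

restored = nn.Linear(8, 1).double()
restored.load_state_dict(state["model_state"])

# Move to the GPU only once the weights are in place, if one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
restored = restored.to(device)
```

If the GPU is still full at that point, freeing memory held by earlier cells or by other processes remains necessary before training on it.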

### **2 - 2 - Get the Data :**

In [None]:
%%time
# Getting the data
# dataset
lung_dataset = TranscriptomicVectorsDatasetLung(data_paths_args["path_to_pan_cancer_hdf5_files"])
lung_dataloader_train, lung_dataloader_validation = get_data_loaders(
    lung_dataset,
    batch_size_train=mlp_args["batch_size_ft"],
    batch_size_validation=mlp_args["batch_size_ft"],
)

### **2 - 3 - Fine Tuning Procedure :**

In [None]:
# Beginning of the transfer learning procedure
net = fine_tune_mlp(pretrained)
net = net.to(device).double()
criterion = FocalLoss().to(device)
optimizer = optim.Adam(net.parameters(), lr=mlp_args['lr_ft'])
savepath = Path('/tempory/transcriptomic_data/fine_tuned_mlp_checkpt.pt')
checkpoint = CheckpointState(net, optimizer, savepath=savepath)
fit(checkpoint, criterion, lung_dataloader_train, lung_dataloader_validation, mlp_args['epochs'])
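
`fine_tune_mlp` is defined in the project's helper modules and is not shown in this notebook. Following the description at the top (the final classification layer is not kept), a minimal sketch of the idea might look like the code below. It assumes the MLP is an `nn.Sequential` whose last module is the `nn.Linear` classification head; the name `replace_classification_head` and the toy layer sizes are hypothetical.

```python
import copy

import torch
import torch.nn as nn


def replace_classification_head(pretrained: nn.Sequential,
                                n_outputs: int = 1) -> nn.Sequential:
    """Keep the pre-trained hidden layers and attach a freshly initialised head."""
    model = copy.deepcopy(pretrained)  # leave the pre-trained network untouched
    old_head = model[-1]
    # The new layer has the same input width but randomly initialised weights.
    model[-1] = nn.Linear(old_head.in_features, n_outputs)
    return model


# Toy stand-in for the pre-trained MLP.
pretrained_toy = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 1))
fine_tuned_toy = replace_classification_head(pretrained_toy)
```

In this sketch all layers stay trainable, so the hidden layers start from their pre-trained weights and are updated together with the new head. Freezing them is a separate design choice that this notebook does not specify.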