# Training Notebook

This notebook trains models on the data prepared with the data preparation notebook.

In [1]:
# Automatically re-import edited project modules (e.g. gbm_project) before each cell runs.
%load_ext autoreload
%autoreload 2
# Make the TensorBoard notebook extension available.
%reload_ext tensorboard
# Uncomment to show matplotlib figures in separate Qt windows instead of inline.
#%matplotlib qt

In [2]:

# This is a blanket include statement cell, some of these may not be used in this notebook specifically
import os
from datetime import datetime
import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from collections import OrderedDict
import SimpleITK as sitk
#import logging
#logging.getLogger("tensorflow").setLevel(logging.ERROR)
from collections import OrderedDict
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets

import torchmetrics

#import initial_ml as iml
from gbm_project import data_prep as dp
from gbm_project.pytorch.run_model_torch import RunModel

from MedicalNet.models import resnet

In [3]:
# Select the compute device and seed the global RNGs for reproducibility.
device = "cuda" if torch.cuda.is_available() else "cpu"
# torch.manual_seed seeds the CPU generator *and* every CUDA device, so it also
# covers what torch.cuda.manual_seed_all did. The CUDA-only call is silently
# ignored on CPU-only machines and leaves the CPU generator unseeded.
torch.manual_seed(42)
# NumPy's global generator is seeded too, in case data handling draws from it.
np.random.seed(42)
print(f"using {device} device")

using cuda device


In [4]:
# Data locations and experiment selection.
#   csv_dir:       directory holding the CSV files of the UPENN-GBM dataset
#   mod_image_dir: root directory containing the per-modality npy conversions
#   image_dir:     location of the npy files for the selected modality (derived below)
#   modality:      modality group to train on: 'struct', 'DTI' or 'DSC'
#   classifier:    classification target passed to dp.retrieve_patients (e.g. 'MGMT')
csv_dir = '../../data/upenn_GBM/csvs/radiomic_features_CaPTk/'
mod_image_dir = '../../data/upenn_GBM'

# Select the modality group by changing this value.
modality = 'struct'
classifier = 'MGMT'

if modality not in ('struct', 'DTI', 'DSC'):
    raise ValueError(f"unknown modality {modality!r}; expected 'struct', 'DTI' or 'DSC'")

image_dir = os.path.join(mod_image_dir, f"numpy_conversion_{modality}_augmented_channels")
# Fail fast with a clear message, rather than deep inside training, when the
# npy conversion for this modality has not been generated.
if not os.path.isdir(image_dir):
    raise FileNotFoundError(f"no npy conversion found for modality {modality!r} at {image_dir}")

# List of patients that will be used in training.
patients = dp.retrieve_patients(csv_dir, image_dir, modality='npy', classifier=classifier)

In [5]:
# Start from the default hyper-parameter/model configurations and override the
# entries this experiment needs.
from gbm_project.pytorch.gen_params_torch_cfg import gen_params, model_config

# Data / input settings.
gen_param_overrides = {
    'data_dir': image_dir,
    'n_channels': 1,        # number of channels pulled from the npy file
    'channel_idx': 0,       # npy channel holding the desired modality/derivative
    'use_clinical': False,  # add clinical data (currently sex and age)
    'n_classes': 1,
}

# Model / optimisation settings.
model_config_overrides = {
    # layers fine-tuned when using plain transfer learning (no SpotTune)
    'no_freeze': ['conv_seg', 'layer4'],
    'n_epochs': 90,
    'lr_sched': True,       # use a learning-rate scheduler
    'spottune': True,       # use SpotTune as the network
    'lr_patience': 20,
    'learning_rate': 1e-5,
    'agent_learning_rate': 1e-4,
    # epochs at which to change the GS temperature, and the values to change it to
    'temp_steps': [0],
    'temp_vals': [1e2],
    # dropout rates
    'dropout': 0.0,
    'agent_dropout': 0.0,
}

for key, value in gen_param_overrides.items():
    gen_params[key] = value
for key, value in model_config_overrides.items():
    model_config[key] = value

In [6]:
# Split the selected modality into a held-out test set (X_test, y_test) and the
# remaining data (X_kfold, y_kfold), together with the k-fold splitter used on it.
# NOTE(review): `classifier` is not passed here, unlike dp.retrieve_patients;
# confirm split_image_v2 uses the same target.
split_seed = model_config['split_seed']
X_test, y_test, kfold, X_kfold, y_kfold = dp.split_image_v2(
    csv_dir,
    mod_image_dir,
    n_cat=1,
    n_splits=5,
    modality=modality,
    seed=split_seed,
)

## Training cells
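The following cells run 5-fold training in two different situations: on the single derivative selected by `gen_params['channel_idx']`, or on every derivative of the selected modality.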

In [None]:
# 5-fold cross-validation on the single derivative selected by gen_params['channel_idx'].
#
# Network to train (`model_name`, with the `transfer` value each is used with):
#   'MedResNet50'       - ResNet50 with MedicalNet weights   (transfer=True)
#   'ResNet50_torch'    - ResNet50 with ImageNet weights     (transfer=True)
#   'ResNet50'          - ResNet50, randomly initialised     (transfer=False)
#   'spottune'          - SpotTune with MedicalNet weights   (transfer=True)
#   'spottune_imagenet' - SpotTune with ImageNet weights     (transfer=True)
model_name = 'spottune'
transfer = True
# The agent network belongs to the SpotTune architecture only, so it is created
# only for the SpotTune variants instead of being toggled by hand.
use_agent = model_name.startswith('spottune')

for i, (train_index, val_index) in enumerate(kfold.split(X_kfold, y_kfold)):
    print(f"---------Fold {i}--------------")
    # Seed before anything fold-specific is built so every fold starts from the
    # same RNG state, whatever RunModel's constructor does.
    torch.manual_seed(42)
    model = RunModel(model_config, gen_params)
    model.set_model(model_name=model_name, transfer=transfer)
    if use_agent:
        model.set_agent()

    # The held-out fold is the validation set; X_test/y_test is the separate test set.
    model.set_train_data(X_kfold[train_index], y_kfold.iloc[train_index])
    model.set_val_data(X_kfold[val_index], y_kfold.iloc[val_index])
    model.set_test_data(X_test, y_test)
    model.run()
    print(f"--------- End Fold {i}--------------")
    del model
    

In [None]:
# 5-fold cross-validation repeated for every derivative of the selected modality.
# Each derivative is stored at a fixed channel index of the npy file; for
# 'struct' these are T2, FLAIR, T1 and T1GD.
#
# struct and DTI have four derivatives, DSC only three.
n_derivatives = {'struct': 4, 'DTI': 4, 'DSC': 3}[modality]

# Network to train (`model_name`, with the `transfer` value each is used with):
#   'MedResNet50'       - ResNet50 with MedicalNet weights   (transfer=True)
#   'ResNet50_torch'    - ResNet50 with ImageNet weights     (transfer=True)
#   'ResNet50'          - ResNet50, randomly initialised     (transfer=False)
#   'spottune'          - SpotTune with MedicalNet weights   (transfer=True)
#   'spottune_imagenet' - SpotTune with ImageNet weights     (transfer=True)
model_name = 'spottune'
transfer = True
# The agent network belongs to the SpotTune architecture only.
use_agent = model_name.startswith('spottune')

# gen_params is shared with the other cells. Remember the configured channel and
# restore it when this sweep ends (or fails), so channel_idx is not silently
# left at the last derivative for cells run afterwards.
configured_channel_idx = gen_params['channel_idx']
try:
    for k in range(n_derivatives):
        gen_params['channel_idx'] = k
        print(f"========= Derivative channel {k} =========")
        # To also sweep the GS temperature, loop over candidate values here and
        # set model_config['temp_vals'] = [value] before the fold loop.
        for i, (train_index, val_index) in enumerate(kfold.split(X_kfold, y_kfold)):
            print(f"---------Fold {i}--------------")
            # Seed before anything fold-specific is built so every fold starts
            # from the same RNG state.
            torch.manual_seed(42)
            model = RunModel(model_config, gen_params)
            model.set_model(model_name=model_name, transfer=transfer)
            if use_agent:
                model.set_agent()

            # The held-out fold is the validation set; X_test/y_test is the separate test set.
            model.set_train_data(X_kfold[train_index], y_kfold.iloc[train_index])
            model.set_val_data(X_kfold[val_index], y_kfold.iloc[val_index])
            model.set_test_data(X_test, y_test)
            model.run()
            print(f"--------- End Fold {i}--------------")
            del model
finally:
    gen_params['channel_idx'] = configured_channel_idx
    