# In [None]:
import os

import numpy as np 
import pandas as pd

from time import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import FloatTensor, LongTensor
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import accuracy_score,  roc_auc_score, confusion_matrix

# For RSNA dataset
from albumentations import (ToFloat, Normalize, VerticalFlip, HorizontalFlip, Compose, Resize,
                            RandomBrightnessContrast, HueSaturationValue, Blur, GaussNoise,
                            Rotate, RandomResizedCrop, Cutout, ShiftScaleRotate, ToGray)
from albumentations.pytorch import ToTensorV2
import pydicom # for DICOM images

import gc 

# For config file
import argparse
import yaml
import sys

# ---------- #

# Setting Training Configuration

# NOTE(review): presumably this neutralises the kernel arguments that Jupyter/IPython
# place in sys.argv. argparse parses sys.argv[1:], which is empty after this
# assignment, so set_params falls back to the YAML defaults. When the file is run as
# a plain script this also discards any real command-line flags.
sys.argv = ["-f"]

def set_params(config_path, argv=None):
    """Load the training configuration from a YAML file, with command-line overrides.

    Each upper-case key of the YAML mapping becomes the default of the matching
    flag (e.g. ``EPOCHS`` -> ``--epochs``). A flag given on the command line wins
    over the file. Keys missing from the file default to ``None``.

    Args:
        config_path: Path to a YAML file whose top level is a mapping.
        argv: Argument list to parse instead of ``sys.argv[1:]``. Useful in
            notebooks and tests. ``None`` keeps the original behaviour.

    Returns:
        argparse.Namespace with the attributes FOLDS, EPOCHS, PATIENCE, WORKERS,
        LR, WD, LR_PATIENCE, LR_FACTOR, BATCH_SIZE1, BATCH_SIZE2, VERSION and MODEL.
    """
    with open(config_path, encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat that as "no keys".
        config = yaml.safe_load(f) or {}

    # (flag, attribute name, conversion). The conversion is applied to values
    # given on the command line and to string defaults. For example, PyYAML reads
    # `1e-4` (no decimal point) as a str. Numeric settings therefore reach
    # range(), DataLoader and torch.optim as real numbers. None = keep the value as-is.
    options = [
        ('--folds', 'FOLDS', int),
        ('--epochs', 'EPOCHS', int),
        ('--patience', 'PATIENCE', int),
        ('--workers', 'WORKERS', int),
        ('--lr', 'LR', float),
        ('--wd', 'WD', float),
        ('--lr_patience', 'LR_PATIENCE', int),
        ('--lr_factor', 'LR_FACTOR', float),
        ('--batch_size1', 'BATCH_SIZE1', int),
        ('--batch_size2', 'BATCH_SIZE2', int),
        ('--version', 'VERSION', None),
        ('--model', 'MODEL', None),
    ]

    parser = argparse.ArgumentParser()
    for flag, dest, arg_type in options:
        parser.add_argument(flag, dest=dest, type=arg_type, default=config.get(dest))

    return parser.parse_args(argv)

# ---------- #

# Dataclass for loading into model
csv_columns = ['laterality', 'age', 'implant', 'site_id', 'machine_id']

class RSNADataset(Dataset):
    """Map-style dataset pairing each DICOM image with its tabular features and label.

    Every row of ``dataframe`` is expected to provide:
    - a ``path_CC`` column with the path to a DICOM file;
    - the feature columns named in the module-level ``csv_columns`` list;
    - a ``cancer`` target column.

    Args:
        dataframe: pandas DataFrame with the columns above. Rows are addressed by
            position (``iloc``), so the index labels do not matter.
        is_train: If True, apply random augmentation and return an
            ``(image, csv, target)`` tuple. Otherwise only resize and return a
            dict with keys ``"image"``, ``"csv"`` and ``"target"``.
        img_size: Side length in pixels of the square output image (default 224,
            the size previously hard-coded for the training crops).
    """

    def __init__(self, dataframe,
                 is_train=True, img_size=224):
        self.dataframe, self.is_train = dataframe, is_train

        # Data augmentation (different for training and evaluation)
        # NOTE(review): no intensity normalisation is applied; raw DICOM pixel
        # values reach the model as float32.
        if is_train:
            self.transform = Compose([
                RandomResizedCrop(height=img_size, width=img_size),
                # albumentations adds 1 to scale_limit, so (-0.2, 0.2) samples
                # zoom factors in [0.8, 1.2]. Passing [0.8, 1.2] would mean [1.8, 2.2].
                ShiftScaleRotate(rotate_limit=90, scale_limit=(-0.2, 0.2)),
                HorizontalFlip(p=0.5),
                VerticalFlip(p=0.5),
                ToTensorV2(),
            ])
        else:
            # Resize to the training resolution, so evaluation inputs match what
            # the model was trained on and differently sized images can be batched.
            self.transform = Compose([
                Resize(height=img_size, width=img_size),
                ToTensorV2(),
            ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        """Load, transform and return the sample at position ``index``."""
        # Read the DICOM belonging to this row, once.
        image_path = self.dataframe['path_CC'].iloc[index]
        image = pydicom.dcmread(image_path).pixel_array.astype(np.float32)

        # Tabular features for the same row.
        csv_data = np.array(self.dataframe.iloc[index][csv_columns].values, dtype=np.float32)

        # Apply transforms. ToTensorV2 yields a channel-first tensor; for a 2-D
        # pixel array this is presumably (1, H, W).
        transf_image = self.transform(image=image)['image']

        # Repeat the single grey channel along axis 0 to get 3 channels.
        transf_image = np.concatenate([transf_image] * 3, axis=0)

        target = self.dataframe['cancer'].iloc[index]
        if self.is_train:
            return transf_image, csv_data, target
        return {"image": transf_image,
                "csv": csv_data,
                "target": target}


# ---------- #
        
# Functions for training the model

def _train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass of ``model`` over every batch in ``loader``.

    Returns:
        Tuple ``(loss_sum, n_correct)``: the sum of the per-batch losses and the
        number of samples whose rounded sigmoid output equals the target.
    """
    model.train()
    loss_sum, n_correct = 0.0, 0
    for image, meta, targets in loader:
        # Clear gradients from the previous step before the new forward pass.
        optimizer.zero_grad()
        out = model(image, meta)  # raw logits; sigmoid is applied below for predictions
        loss = criterion(out, targets.unsqueeze(1).float())
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        preds = torch.round(torch.sigmoid(out))  # hard 0/1 predictions
        n_correct += (preds.cpu() == targets.cpu().unsqueeze(1)).sum().item()
    return loss_sum, n_correct


def _predict_probs(model, loader):
    """Return sigmoid probabilities for every sample in ``loader`` as a 1-D numpy array.

    Batches are concatenated in loader order, so with ``shuffle=False`` entry i
    belongs to dataset row i, whatever the size of the last batch.
    """
    model.eval()
    chunks = []
    with torch.no_grad():
        for image, meta, _ in loader:
            chunks.append(torch.sigmoid(model(image, meta)).cpu())
    if not chunks:
        return np.zeros(0, dtype=np.float32)
    return torch.cat(chunks).reshape(-1).numpy()


def train_folds(args, model, train_original):
    """Train ``model`` with stratified K-fold cross-validation and early stopping.

    For every fold:
    - the model is reset to the weights it had when this function was called;
    - a fresh Adam optimiser, BCEWithLogitsLoss and a ReduceLROnPlateau scheduler
      (mode='max' on validation ROC AUC) are created;
    - training runs for at most ``args.EPOCHS`` epochs.

    Whenever validation ROC AUC improves, the weights are saved to the current
    working directory as ``Fold{i}_Epoch{e}_ValidAcc{acc}_ROC{roc}.pth``. A fold
    stops early after ``args.PATIENCE`` consecutive epochs without improvement.

    Args:
        args: Namespace (see ``set_params``) providing FOLDS, EPOCHS, PATIENCE,
            WORKERS, LR, WD, LR_PATIENCE, LR_FACTOR, BATCH_SIZE1 and BATCH_SIZE2.
        model: torch module called as ``model(image, meta)``. Presumably it
            returns one logit per sample, shape (batch, 1) -- TODO confirm.
        train_original: DataFrame with the columns RSNADataset reads and a
            binary ``cancer`` target used for stratification.
    """
    # Training params from configuration in args
    FOLDS = args.FOLDS
    EPOCHS = args.EPOCHS
    PATIENCE = args.PATIENCE
    WORKERS = args.WORKERS or 0   # None (missing from config) -> load in the main process
    LR = args.LR
    WD = args.WD
    LR_PATIENCE = args.LR_PATIENCE
    LR_FACTOR = args.LR_FACTOR

    BATCH_SIZE1 = args.BATCH_SIZE1           # for train
    BATCH_SIZE2 = args.BATCH_SIZE2           # for valid

    # Snapshot the starting weights so every fold trains from the same point.
    # Without this, later folds would start from weights already fitted on
    # their own validation rows.
    initial_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    # Split in folds, preserving the cancer / no-cancer ratio in each fold.
    strat_fold = StratifiedKFold(n_splits=FOLDS)
    k_folds = strat_fold.split(X=np.zeros(len(train_original)),
                               y=train_original['cancer'].astype(int))

    for i, (train_index, valid_index) in enumerate(k_folds):
        print(f"---------- Fold: {i+1} ----------")
        model.load_state_dict(initial_state)

        # Best ROC score in this fold, and patience reset for every fold
        best_roc = None
        patience_f = PATIENCE

        # Optimizer / Scheduler / Criterion
        optimizer = torch.optim.Adam(model.parameters(), lr=LR,
                                     weight_decay=WD)
        scheduler = ReduceLROnPlateau(optimizer=optimizer, mode='max',
                                      patience=LR_PATIENCE, verbose=True, factor=LR_FACTOR)
        criterion = nn.BCEWithLogitsLoss()

        # --- Read in Data ---
        train_data = train_original.iloc[train_index].reset_index(drop=True)
        valid_data = train_original.iloc[valid_index].reset_index(drop=True)

        train = RSNADataset(train_data, is_train=True)
        # NOTE(review): validation reuses the training augmentations (random
        # crop/rotate/flip), so the metrics below are noisy. Switching to
        # is_train=False needs a resizing eval transform and dict-batch unpacking.
        valid = RSNADataset(valid_data, is_train=True)

        # NOTE(review): batches are never moved to a device, so the model is
        # presumably expected to live on the CPU.
        train_loader = DataLoader(train, batch_size=BATCH_SIZE1,
                                  shuffle=True, num_workers=WORKERS)
        valid_loader = DataLoader(valid, batch_size=BATCH_SIZE2,
                                  shuffle=False, num_workers=WORKERS)
        valid_targets = valid_data['cancer'].values

        # === EPOCHS ===
        for epoch in range(EPOCHS):
            start_time = time()

            # === TRAIN ===
            train_losses, correct = _train_one_epoch(model, train_loader, optimizer, criterion)
            train_acc = correct / len(train_index)

            # === EVAL ===
            valid_probs = _predict_probs(model, valid_loader)
            valid_acc = accuracy_score(valid_targets, np.where(valid_probs > 0.5, 1, 0))
            valid_roc = roc_auc_score(valid_targets, valid_probs)

            # PRINT INFO (wall-clock seconds spent on train + eval first)
            final_logs = '{:.0f}s | Epoch: {}/{} | Loss: {:.4} | Acc_tr: {:.3} | Acc_vd: {:.3} | ROC: {:.3}'.\
                            format(time() - start_time, epoch+1, EPOCHS,
                                   train_losses, train_acc, valid_acc, valid_roc)
            print(final_logs)

            # === SAVE MODEL ===
            # Update scheduler (lowers the learning rate when ROC plateaus)
            scheduler.step(valid_roc)
            model_name = f"Fold{i+1}_Epoch{epoch+1}_ValidAcc{valid_acc:.3f}_ROC{valid_roc:.3f}.pth"

            if best_roc is None or valid_roc > best_roc:
                best_roc = valid_roc
                patience_f = PATIENCE  # improvement: reset patience
                torch.save(model.state_dict(), model_name)
                continue

            # No improvement in ROC: spend one unit of patience.
            patience_f -= 1
            if patience_f == 0:
                # Only printed: this function has no log-file handle to write to.
                stop_logs = 'Early stopping (no improvement in {} epochs) | Best ROC: {}'.\
                            format(PATIENCE, best_roc)
                print(stop_logs)
                break

        # === CLEANING ===
        # Drop this fold's datasets/loaders before building the next ones.
        del train, valid, train_loader, valid_loader
        gc.collect()
