# Implementing a basic CNN-RNN classifier

Special thanks to the authors of [this](https://github.com/pytorch/ignite/blob/master/examples/notebooks/EfficientNet_Cifar100_finetuning.ipynb) notebook, and [this one](https://github.com/PacktPublishing/PyTorch-Computer-Vision-Cookbook/blob/master/Chapter10/DeployingModel.ipynb), and also [this one](https://www.kaggle.com/protan/ignite-example).

In [None]:
import os
import pandas as pd
import torch
import numpy as np
import random
# Seed NumPy, the stdlib `random` module and PyTorch with one shared value so
# that repeated runs of the notebook draw the same random numbers.
SEED = 2021
for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
    seed_fn(SEED)

## Dataset

In [None]:
# Returns list containing all ids (XXXXX) in the specified folder
def get_ids(d="../input/rsna-miccai-png/train"):
    """Return the names of all entries found in directory ``d``, sorted ascending."""
    return sorted(os.listdir(d))

# Reads the labels in the csv file
def get_labels(f="../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv"):
    """Read the CSV at ``f`` and return its ``MGMT_value`` column as a list.

    Rows are ordered by ascending ``BraTS21ID`` before the column is extracted.
    """
    table = pd.read_csv(f)
    ordered = table.sort_values(by="BraTS21ID")
    mgmt_column = ordered["MGMT_value"]
    return mgmt_column.tolist()

In [None]:
from torch.utils.data import Dataset, DataLoader, Subset
import glob
from PIL import Image
import cv2


# Make sure vids_path does not contain a trailing /
# vid_type is one in ["FLAIR", "T1w", "T1wCE", "T2w"]
class DatasetRSNE(Dataset):
    """Dataset yielding, per case id, a fixed-size stack of frames for each scan type.

    Args:
        ids: sequence of case folder names located under ``vids_path``.
        labels: optional sequence aligned with ``ids``; when given, ``__getitem__``
            returns ``(frames, label)`` instead of just ``frames``.
        transform: callable applied to every resampled 2-D frame. It is called
            whenever a case has at least one image, so it must not be ``None`` then.
        vids_path: root folder (no trailing ``/``) holding one sub-folder per id.
        vid_type: scan types to load; each is a sub-folder name inside the case folder.
        d: target volume size, indexed as ``(d[0], d[1], d[2])`` =
            (frame height, frame width, number of frames).
    """

    def __init__(self, ids, labels=None, transform=None, vids_path="../input/rsna-miccai-png/train",
                 vid_type=("FLAIR", "T1w"), d=(288, 288, 96)):
        self.transform = transform
        self.ids = ids
        self.labels = labels
        self.vids_path = vids_path
        # Copy into a fresh list: the former list default was one object shared by
        # every instance, so mutating it on one dataset leaked into all others.
        self.vid_type = list(vid_type)
        self.d = d

    def __len__(self):
        return len(self.ids)

    @staticmethod
    def _slice_sort_key(path):
        """Natural-order key so that ``Image-2.png`` sorts before ``Image-10.png``."""
        import re  # local import: `re` is not among this file's top-level imports

        # A capturing split alternates text / digit tokens (digits at odd positions),
        # so tokens compared at the same position always share a type.
        tokens = re.split(r"(\d+)", path)
        return [int(tok) if i % 2 else tok for i, tok in enumerate(tokens)]

    def _load_volume(self, paths):
        """Read the images at ``paths`` and resample them to ``d[2]`` frames of ``d[0]`` x ``d[1]``.

        NOTE(review): assumes single-channel (2-D) PNGs - confirm against the data.
        """
        slices = []
        for p2i in paths:
            # The context manager closes the file handle once the pixels are copied out.
            with Image.open(p2i) as img:
                slices.append(cv2.resize(np.array(img), dsize=(self.d[1], self.d[0]),
                                         interpolation=cv2.INTER_LINEAR))
        by_row = np.array(slices).transpose(1, 2, 0)
        # Second resize runs along the slice axis so every case ends with d[2] frames.
        resampled = np.array([cv2.resize(by_row[i], dsize=(self.d[2], self.d[1]),
                                         interpolation=cv2.INTER_LINEAR)
                              for i in range(self.d[0])])
        return resampled.transpose(2, 0, 1)

    def __getitem__(self, idx):
        all_frames = []
        for vid in self.vid_type:
            pattern = self.vids_path + "/" + self.ids[idx] + "/" + vid + "/*.png"
            # glob returns files in arbitrary order; sort numerically so the frame
            # sequence is deterministic (assumes file names carry the slice index).
            path2imgs = sorted(glob.glob(pattern), key=self._slice_sort_key)
            if not path2imgs:
                # if no frames are available for the video, return a tensor of zeros
                frames_tr = torch.zeros(self.d[2], 1, self.d[0], self.d[1])
            else:
                frames = self._load_volume(path2imgs)
                # Re-seed before every frame so a random transform that draws from
                # `random` / `np.random` makes the same draw for all frames of a volume.
                seed = np.random.randint(10**9)
                frames_tr = []
                for frame in frames:
                    random.seed(seed)
                    np.random.seed(seed)
                    frames_tr.append(self.transform(frame))
                frames_tr = torch.stack(frames_tr)
            all_frames.append(frames_tr)
        # NOTE(review): squeeze() drops *every* size-1 axis, so a single scan type
        # loses its type axis too; kept as-is to preserve the returned shape.
        all_frames = torch.squeeze(torch.stack(all_frames, dim=1))
        if self.labels is not None:
            label = self.labels[idx]
            return all_frames, torch.tensor(label, dtype=torch.float32)
        return all_frames

In [None]:
import torchvision.transforms as transforms
from sklearn.preprocessing import StandardScaler
import SimpleITK as sitk

# Many other transforms are possible
#            transforms.RandomHorizontalFlip(p=0.5),  
#            transforms.RandomAffine(degrees=0, translate=(0.1,0.1)),    
#            ...

# Applies the N4 Bias Field Correction to remove radiofrequency inhomogeneity
class N4BiasFieldCorrect(object):
    """Transform applying ``SimpleITK.N4BiasFieldCorrectionImageFilter`` to one image.

    Args:
        max_iter: maximum number of iterations for the single fitting level, or
            ``None`` to leave the filter's own setting untouched.

    Raises:
        TypeError: if ``max_iter`` is neither an ``int`` nor ``None``.
    """

    def __init__(self, max_iter=30):
        # Explicit check instead of `assert` (asserts are stripped under `python -O`).
        # None is accepted because __call__ already has a branch for it, which the
        # former assert made unreachable.
        if max_iter is not None and not isinstance(max_iter, int):
            raise TypeError("max_iter must be an int or None")
        self.max_iter = max_iter

    def __call__(self, sample):
        """Return the bias-corrected ``sample`` as a torch tensor."""
        sample = np.asarray(sample)
        inputImage = sitk.GetImageFromArray(sample)
        # Mask of pixels considered foreground. The 0.1 threshold is an unvalidated
        # heuristic carried over from the original code.
        maskImage = sitk.GetImageFromArray((sample > 0.1) * 1)
        inputImage = sitk.Cast(inputImage, sitk.sitkFloat32)
        maskImage = sitk.Cast(maskImage, sitk.sitkUInt8)
        corrector = sitk.N4BiasFieldCorrectionImageFilter()
        numberFittingLevels = 1
        if self.max_iter is not None:
            corrector.SetMaximumNumberOfIterations([self.max_iter] * numberFittingLevels)
        corrected_image = corrector.Execute(inputImage, maskImage)

        # A SimpleITK image has to be converted with GetArrayFromImage; wrapping it
        # in np.array() does not yield the pixel buffer torch.from_numpy needs.
        return torch.from_numpy(sitk.GetArrayFromImage(corrected_image))


class std_sc(object):
    """Transform that standardises a sample with sklearn's ``StandardScaler``.

    A new scaler is fitted on the sample itself at every call (centering with the
    mean and scaling by the standard deviation, per ``StandardScaler`` defaults).
    """

    def __call__(self, sample):
        """Return ``StandardScaler().fit_transform(sample)``."""
        # The original computed the scaled array but never returned it, so the
        # transform silently produced None.
        return StandardScaler().fit_transform(sample)


# Per-frame preprocessing pipelines. Only ToTensor is active for both splits.
# Candidates that are deliberately left disabled:
#   transforms.Resize((h, w))
#   N4BiasFieldCorrect()  - for the moment this doesn't work and it is unclear
#                           what it does, so it is not included as a transform
#   std_sc()              - a standard scaling could be added here, but it
#                           probably does not make much sense
train_transforms = transforms.Compose([transforms.ToTensor()])
test_transforms = transforms.Compose([transforms.ToTensor()])

In [None]:
# Case ids come from the image folder listing, labels from the CSV.
# NOTE(review): pairing them by position assumes both sources contain the same
# cases in the same order - confirm.
training_ids = get_ids("../input/rsna-miccai-png/train")
training_labels = get_labels("../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv")
n = len(training_ids)
print(len(training_ids), len(training_labels))

# The first 7 * (n // 10) cases are used for training, the remainder for validation.
split_at = (n // 10) * 7
train_ids, val_ids = training_ids[:split_at], training_ids[split_at:]
train_labels, val_labels = training_labels[:split_at], training_labels[split_at:]
print(len(train_ids), len(train_labels))
print(len(val_ids), len(val_labels))

test_ids = get_ids("../input/rsna-miccai-png/test")
print(len(test_ids))

In [None]:
# Training and validation datasets use their own transform pipelines.
train_ds = DatasetRSNE(ids=train_ids, labels=train_labels, transform=train_transforms)
print(len(train_ds))

val_ds = DatasetRSNE(ids=val_ids, labels=val_labels, transform=test_transforms)
print(len(val_ds))

# Inspect one training sample: tensor shape, label and value range.
imgs, label = train_ds[100]
imgs.shape, label, imgs.min(), imgs.max()

In [None]:
# Test-time preprocessing: tensor conversion only
# (transforms.Resize((h, w)) is left disabled).
test_transformer = transforms.Compose([transforms.ToTensor()])
test_ds = DatasetRSNE(
    ids=test_ids,
    transform=test_transformer,
    vids_path="../input/rsna-miccai-png/test",
)

In [None]:
# Inspect one test sample. The test dataset is built without labels, so only the
# tensor is shown; the earlier version also displayed `label`, a stale value left
# over from the training sample inspected above.
imgs = test_ds[5]
imgs.shape, torch.min(imgs), torch.max(imgs)

In [None]:
def collate_fn_train_rnn(batch):
    """Collate ``(frames, label)`` samples into a pair of batched tensors.

    Samples whose frame tensor is empty (``len(frames) == 0``) are dropped
    together with their label, so images and labels always stay aligned.

    Args:
        batch: iterable of ``(frames, label)`` pairs.

    Returns:
        ``(imgs_tensor, labels_tensor)``, both stacked along a new first axis.
        When every sample was dropped, both are empty tensors (0 elements).
    """
    # Filter pairs jointly. Filtering the images first and then zipping the labels
    # against the filtered list paired labels with the wrong samples.
    kept = [(imgs, label) for imgs, label in batch if len(imgs) > 0]
    if not kept:
        # The former fallback referenced undefined names (h, w) and then stacked an
        # empty label list, so it could never succeed; return empty tensors instead.
        return torch.zeros(0), torch.zeros(0)
    imgs_batch, label_batch = zip(*kept)
    imgs_tensor = torch.stack(imgs_batch)
    # as_tensor avoids the copy-construct warning when labels are already tensors.
    labels_tensor = torch.stack([torch.as_tensor(l) for l in label_batch])
    return imgs_tensor, labels_tensor

def collate_fn_test_rnn(batch):
    """Stack the non-empty frame tensors of ``batch`` along a new first axis."""
    kept = []
    for imgs in batch:
        if len(imgs) > 0:
            kept.append(imgs)
    return torch.stack(kept)


batch_size = 4

# Labelled splits share the same loader settings (shuffled, training collate).
labelled_kwargs = dict(batch_size=batch_size, shuffle=True, collate_fn=collate_fn_train_rnn)
train_dl = DataLoader(train_ds, **labelled_kwargs)
val_dl = DataLoader(val_ds, **labelled_kwargs)
# The test loader keeps dataset order and has no labels to collate.
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False,
                     collate_fn=collate_fn_test_rnn)

In [None]:
# for xb,yb in train_dl:
#     print(xb.shape, yb.shape)
#     break
    
# for xb,yb in val_dl:
#     print(xb.shape, yb.shape)
#     break

# for xb in test_dl:
#     print(xb.shape)
#     break

## Define model

In [None]:
from torch import nn
from torchvision import models

class Resnet18Rnn(nn.Module):
    """One ResNet-18 + LSTM branch per scan type, joined by a sigmoid-activated linear head.

    ``params_model`` keys read by the constructor:
        out_features: size of the final linear layer's output.
        dr_rate: dropout probability.
        pretrained: when true, load ResNet-18 weights from the bundled ``.pth`` file.
        rnn_hidden_size, rnn_num_layers, bidirectional: LSTM configuration.
        stack_input: if True, the 1-channel frame is repeated to 3 channels;
            otherwise a small conv block (``In1Out3``) maps 1 -> 3 channels.
        n_types: number of scan-type branches used in ``forward`` (at most 4).
    """

    def __init__(self, params_model):
        super().__init__()
        out_features = params_model["out_features"]
        dr_rate = params_model["dr_rate"]
        pretrained = params_model["pretrained"]
        self.rnn_hidden_size = params_model["rnn_hidden_size"]
        rnn_num_layers = params_model["rnn_num_layers"]
        self.bidirectional = bidirectional = params_model["bidirectional"]
        self.stack_input = params_model["stack_input"]
        self.n_types = params_model["n_types"]

        self.In1Out3 = nn.Sequential(
                    nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3, padding=1),
                    nn.BatchNorm2d(3),
                    nn.ReLU())
        # All four branches are always created (baseModel1..4 / rnn1..4 stay available
        # as attributes); only the first n_types of them are used in forward().
        self.baseModel1 = models.resnet18(pretrained=False)
        self.baseModel2 = models.resnet18(pretrained=False)
        self.baseModel3 = models.resnet18(pretrained=False)
        self.baseModel4 = models.resnet18(pretrained=False)
        self.rnn1 = nn.LSTM(self.baseModel1.fc.in_features, self.rnn_hidden_size, rnn_num_layers, bidirectional=bidirectional)
        self.rnn2 = nn.LSTM(self.baseModel2.fc.in_features, self.rnn_hidden_size, rnn_num_layers, bidirectional=bidirectional)
        self.rnn3 = nn.LSTM(self.baseModel3.fc.in_features, self.rnn_hidden_size, rnn_num_layers, bidirectional=bidirectional)
        self.rnn4 = nn.LSTM(self.baseModel4.fc.in_features, self.rnn_hidden_size, rnn_num_layers, bidirectional=bidirectional)
        self.mods = [self.baseModel1, self.baseModel2, self.baseModel3, self.baseModel4]
        self.RNNs = [self.rnn1, self.rnn2, self.rnn3, self.rnn4]
        for i in range(self.n_types):
            if pretrained:
                self.mods[i].load_state_dict(torch.load("../input/pretrained-model-weights-pytorch/resnet18-5c106cde.pth"))
            # Drop the classification layer so the backbone outputs its feature vector.
            self.mods[i].fc = Identity()
        self.dropout = nn.Dropout(dr_rate)
        # if two classes, just one output neuron with logistic activation function
        self.fc1 = nn.Linear(self.rnn_hidden_size*self.n_types*(self.bidirectional+1), out_features)
        self.sigm = nn.Sigmoid()

    def forward(self, x):
        """Return the flattened sigmoid outputs for a batch.

        Args:
            x: 5-D tensor laid out as (batch, time, scan type, height, width).

        Returns:
            1-D tensor holding ``out_features`` values per sample, flattened.
        """
        _, ts, _, _, _ = x.shape
        outputs = []
        for v_type in range(self.n_types):  # up to four video types (FLAIR, T1w, T1wCE, T2w)
            xx = x[:, :, v_type]
            state = None  # None makes the LSTM start from a zero state
            for ii in range(ts):
                # Add the channel axis the CNN expects: (B, H, W) -> (B, 1, H, W).
                frame = xx[:, ii].unsqueeze(1)
                if self.stack_input:
                    frame = frame.expand(-1, 3, -1, -1)
                else:
                    frame = self.In1Out3(frame)
                y = self.mods[v_type](frame)
                # nn.LSTM defaults to (seq_len, batch, features). Feeding (B, 1, F)
                # made the LSTM treat the batch as one sequence, so samples leaked
                # into each other; (1, B, F) is one time step for B samples.
                out, state = self.RNNs[v_type](y.unsqueeze(0), state)
            outputs.append(out.squeeze(0))

        # Concatenate per-type features for each sample; stack + reshape interleaved
        # features of different samples whenever n_types > 1.
        outputs = torch.cat(outputs, dim=1)
        # Dropout regularises the features, not the output logit (where it would
        # randomly force predictions to sigmoid(0) = 0.5 during training).
        outputs = self.sigm(self.fc1(self.dropout(outputs)))
        return torch.reshape(outputs, (-1,))

class Identity(nn.Module):
    """Pass-through layer whose ``forward`` returns its input unchanged."""

    def forward(self, x):
        return x

In [None]:
# Select device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cpu':
    print("WARNING: Training without GPU can be very slow!")

In [None]:
# Hyper-parameters consumed by Resnet18Rnn.__init__.
params_model = dict(
    out_features=1,
    dr_rate=0.1,
    pretrained=True,
    rnn_num_layers=1,
    rnn_hidden_size=150,
    bidirectional=True,
    stack_input=True,
    n_types=1,
)

# Build the network, move it to the selected device and put it in training mode.
my_resnet18rnn = Resnet18Rnn(params_model)
my_resnet18rnn = my_resnet18rnn.to(device).train()

# my_resnet18rnn.load_state_dict(torch.load('../input/checkpoint5600/cnnrnn_cnnrnn_0.5600.pt')) # path of your weights
# my_resnet18rnn.to(device).train()

In [None]:
# my_resnet18rnn

In [None]:
from itertools import chain

import torch.optim as optim

criterion = nn.BCELoss()

# The CNN backbones get their own learning rate (5e-4); the LSTMs, the 1->3
# channel conv block and the linear head use the optimizer-wide default (1e-3).
_net = my_resnet18rnn
backbone_groups = [
    {"params": module.parameters(), 'lr': 5e-4}
    for module in (_net.baseModel1, _net.baseModel2, _net.baseModel3, _net.baseModel4)
]
default_lr_groups = [
    {"params": module.parameters()}
    for module in (_net.rnn1, _net.rnn2, _net.rnn3, _net.rnn4, _net.In1Out3, _net.fc1)
]
optimizer = optim.Adam(
    backbone_groups + default_lr_groups,
    lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False,
)


from torch.optim.lr_scheduler import ReduceLROnPlateau
# NOTE(review): no call to scheduler.step(...) is visible in this file, so the
# learning rate would never actually be reduced - confirm whether that is intended.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=5,
    threshold=1e-5,
    threshold_mode='abs',
    min_lr=1e-6,
    eps=1e-08,
    verbose=False,
)

In [None]:
def process_function(engine, batch):
    """Run one optimisation step on ``batch`` and return the loss as a Python float.

    ``batch`` is indexed as ``(inputs, targets)``; ``engine`` is unused.
    """
    inputs = batch[0].to(device)
    targets = batch[1].to(device)

    my_resnet18rnn.train()
    optimizer.zero_grad()
    loss = criterion(my_resnet18rnn(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss.item()

In [None]:
def eval_function(engine, batch):
    """Return ``(predictions, targets)`` for ``batch`` with gradients disabled.

    The model is switched to evaluation mode first; ``engine`` is unused.
    """
    my_resnet18rnn.eval()
    with torch.no_grad():
        inputs = batch[0].to(device)
        targets = batch[1].to(device)
        predictions = my_resnet18rnn(inputs)
    return predictions, targets

In [None]:
from ignite.engine import Engine
# One engine runs the optimisation steps; two independent engines run evaluation
# so that training-set and validation-set metrics are kept apart.
trainer = Engine(process_function)
train_evaluator, validation_evaluator = (Engine(eval_function) for _ in range(2))

In [None]:
from ignite.metrics import RunningAverage

# Expose a running average of the trainer's output (the loss) as metric 'ce'.
running_loss = RunningAverage(output_transform=lambda x: x)
running_loss.attach(trainer, 'ce')

In [None]:
def thresholded_output_transform(output):
    """Round the predictions of a ``(y_pred, y)`` pair; the targets pass through unchanged."""
    probabilities, targets = output
    return torch.round(probabilities), targets

In [None]:
from ignite.metrics import Accuracy, Loss
from ignite.contrib.metrics import ROC_AUC

# Both evaluators report thresholded accuracy and the criterion loss ('ce');
# the validation evaluator additionally reports ROC AUC.
for evaluator in (train_evaluator, validation_evaluator):
    Accuracy(output_transform=thresholded_output_transform).attach(evaluator, 'accuracy')
    Loss(criterion).attach(evaluator, 'ce')
ROC_AUC().attach(validation_evaluator, "AUC")

In [None]:
from ignite.contrib.handlers import ProgressBar

# Progress bar on the trainer that displays the running 'ce' metric.
pbar = ProgressBar(persist=True, bar_format="")
pbar.attach(trainer, metric_names=['ce'])

In [None]:
from ignite.handlers import Checkpoint, DiskSaver, EarlyStopping, TerminateOnNan
from ignite.engine import Events

def score_function(engine):
    """Return the 'AUC' entry of ``engine.state.metrics``."""
    return engine.state.metrics['AUC']

# Early stopping is evaluated each time the validation evaluator completes,
# using score_function (validation AUC) with a patience of 6.
handler = EarlyStopping(
    patience=6,
    score_function=score_function,
    trainer=trainer,
)
validation_evaluator.add_event_handler(Events.COMPLETED, handler)

In [None]:
from ignite.handlers import ModelCheckpoint, EarlyStopping

# Rolling checkpoint written from the trainer at the end of every epoch.
checkpointer = ModelCheckpoint(
    'checkpoint', 'textcnn',
    save_interval=1, n_saved=2,
    create_dir=True, save_as_state_dict=True, require_empty=False,
)
# Single checkpoint ranked by score_function, written from the validation evaluator.
best_model_save = ModelCheckpoint(
    'best_model', 'cnnrnn',
    n_saved=1, create_dir=True, save_as_state_dict=True,
    score_function=score_function, require_empty=False,
)
trainer.add_event_handler(
    Events.EPOCH_COMPLETED, checkpointer, {'cnnrnn': my_resnet18rnn})
validation_evaluator.add_event_handler(
    Events.EPOCH_COMPLETED, best_model_save, {'cnnrnn': my_resnet18rnn})

In [None]:
# Per-epoch metric histories, filled by print_logs. The presence of the 'AUC'
# key is what marks a history as belonging to the validation split.
training_history = {name: [] for name in ('accuracy', 'ce')}
validation_history = {name: [] for name in ('accuracy', 'ce', 'AUC')}

def print_logs(engine, dataloader, mode, history_dict):
    """Evaluate on ``dataloader``, record the metrics in ``history_dict`` and print a summary.

    A history dict containing the key "AUC" selects the validation evaluator (and
    adds the AUC to the printed line); any other dict selects the training evaluator.
    ``engine`` only supplies the epoch number shown in the message.
    """
    with_auc = "AUC" in history_dict
    if with_auc:
        validation_evaluator.run(dataloader)
        metrics = validation_evaluator.state.metrics
    else:
        train_evaluator.run(dataloader, max_epochs=1)
        metrics = train_evaluator.state.metrics
    for name, value in metrics.items():
        history_dict[name].append(value)

    template = " Results - Epoch {} - Avg accuracy: {:.2f} | Avg loss: {:.2f}"
    values = [engine.state.epoch, metrics['accuracy'], metrics['ce']]
    if with_auc:
        template += " | AUC: {:.2f}"
        values.append(metrics['AUC'])
    print(mode + template.format(*values))


# After every training epoch, evaluate and log first on the training loader and
# then on the validation loader.
for loader, split_name, history in (
        (train_dl, 'Training', training_history),
        (val_dl, 'Validation', validation_history)):
    trainer.add_event_handler(Events.EPOCH_COMPLETED, print_logs, loader, split_name, history)

In [None]:
# Run the training engine over train_dl for at most `epochs` epochs; the handlers
# registered above take care of logging, checkpointing and early stopping.
epochs = 4
trainer.run(train_dl, max_epochs=epochs)

TODO: try reducing the batch_size a little in exchange for larger images.

In [None]:
# Display the metrics held by the training-set evaluator's state.
train_evaluator.state.metrics

In [None]:
# Print the recorded loss history, training first and then validation.
for history in (training_history, validation_history):
    print(history['ce'])

In [None]:
import matplotlib.pyplot as plt

%matplotlib inline


# Plot against the number of epochs actually recorded: the histories can be
# shorter than `epochs` (early stopping) or longer (cell re-run), and passing
# range(epochs) then makes plt.plot fail on mismatched x/y lengths.
train_ce = training_history['ce']
val_ce = validation_history['ce']
plt.plot(range(len(train_ce)), train_ce, 'dodgerblue', label='training')
plt.plot(range(len(val_ce)), val_ce, 'orange', label='validation')
plt.xlim(0, epochs);
plt.xlabel('Epoch')
plt.ylabel('Binary Cross Entropy Loss')
plt.title('Binary Cross Entropy on Training/Validation Set')
plt.legend();

In [None]:
# Use the recorded history length for the x-axis: it can differ from `epochs`
# (early stopping, cell re-run), in which case range(epochs) makes plt.plot fail.
val_auc = validation_history['AUC']
plt.plot(range(len(val_auc)), val_auc, 'orange', label='validation')
plt.xlim(0, epochs);
plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.title('AUC on Validation Set')
plt.legend();

## Inference

In [None]:
import pathlib

# Pick the first entry found under best_model/ and fail with a clear message
# (instead of a bare StopIteration) when no checkpoint has been written yet.
model_path = next(pathlib.Path('best_model').rglob('*'), None)
if model_path is None:
    raise FileNotFoundError("no checkpoint found under 'best_model'")
print(model_path)

# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
model_state_dict = torch.load(model_path, map_location=device)
my_resnet18rnn.load_state_dict(model_state_dict)
# change model mode to 'evaluation'
# disable dropout and use learned batch norm statistics
my_resnet18rnn.eval()

predictions = []
labels = []

In [None]:
my_resnet18rnn.eval()

# Collect per-batch outputs and concatenate once at the end; concatenating onto a
# growing array inside the loop copies all earlier predictions on every batch.
batch_predictions = []
with torch.no_grad():
    for batch in test_dl:
        y_pred = my_resnet18rnn(batch.to(device))
        # move from GPU to CPU and convert to numpy array
        batch_predictions.append(y_pred.cpu().numpy())

# float64 matches what concatenating onto the initial empty list used to produce.
predictions = (np.concatenate(batch_predictions).astype(np.float64)
               if batch_predictions else np.array([]))

In [None]:
predictions_str = list(predictions)

# test_dl is not shuffled, so predictions follow the order of test_ids. The case
# ids are not contiguous integers, so a 0..n-1 index would mislabel every row;
# use the real ids. Column names mirror those of the training label CSV.
submission = pd.DataFrame({'BraTS21ID': test_ids, 'MGMT_value': predictions_str})

In [None]:
# Preview the first ten rows of the submission table.
submission.head(10)