In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import cv2
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torchvision.models import vgg19, vgg19_bn
from torch.utils.data import Dataset, DataLoader
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
from glob import glob
from tqdm.auto import tqdm
from copy import deepcopy

In [None]:
# Root folder of the competition data; the train/test folders and the label file live below it.
DATA_DIR = '../input/rsna-miccai-brain-tumor-radiogenomic-classification/'
train_path = DATA_DIR + 'train/'
test_path = DATA_DIR + 'test/'

# Every column is read as text (presumably so the case IDs in the first column keep
# any leading zeros and still match the folder names on disk -- TODO confirm).
df = pd.read_csv(DATA_DIR + 'train_labels.csv', dtype = str)

# Replace the label column with a true integer column. Assigning through
# `df.iloc[:, 1] = ...` writes into the existing object-dtype column on pandas >= 2.0,
# which leaves Python ints inside an object array that torch.from_numpy cannot convert.
df[df.columns[1]] = df.iloc[:, 1].astype(int)

# First, we define the functions that read the DICOM files and preprocess the images. The slices of each series are then sorted by the number in their file names and concatenated into a single volume that is passed to the model.

In [None]:
def read_dicom(path):
    """Read one DICOM file and return it as a min-max normalised 256x256 image.

    Parameters
    ----------
    path : str
        Location of the DICOM file.

    Returns
    -------
    numpy.ndarray
        Float array with a trailing singleton axis, i.e. shape (256, 256, 1)
        for a single-frame (2-D) slice. Values lie in [0, 1]; a slice whose
        pixels are all equal is returned as all zeros.
    """
    # dcmread is the supported reader; pydicom.read_file is a deprecated alias.
    dcm = pydicom.dcmread(path)
    img = apply_voi_lut(dcm.pixel_array, dcm).astype(np.float64)
    lo, hi = img.min(), img.max()
    if hi > lo:
        img = (img - lo) / (hi - lo)
    else:
        # Constant slice: the min-max formula would divide by zero and fill the image with NaN.
        img = np.zeros_like(img)
    return cv2.resize(img, (256, 256))[..., None]

def read_dicom1(path):
    """Read one DICOM slice, normalise it, crop a fixed window and resize to 128x128.

    Parameters
    ----------
    path : str
        Location of the DICOM file.

    Returns
    -------
    numpy.ndarray
        Float array with a leading singleton axis, i.e. shape (1, 128, 128)
        for a single-frame (2-D) slice, so that slices can be concatenated
        along axis 0. Values lie in [0, 1]; a slice whose pixels are all equal
        is returned as all zeros so that every slice shares the same scale.
    """
    # dcmread is the supported reader; pydicom.read_file is a deprecated alias.
    dcm = pydicom.dcmread(path)
    img = apply_voi_lut(dcm.pixel_array, dcm).astype(np.float64)
    lo, hi = img.min(), img.max()
    # Guard the division: a constant slice has hi == lo.
    img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
    rows, cols = img.shape[:2]
    # Keep rows 15%-67% and columns 23%-78% of the slice.
    img = img[int(0.15 * rows):int(0.67 * rows), int(0.23 * cols):int(0.78 * cols)]
    img = cv2.resize(img, (128, 128))
    return img[None, ...]

def read_file_dicom(instance_name, file_name, train_path, num = 128):
    """Load the central slices of one series folder as a single volume.

    The folder `<train_path>/<instance_name>/<file_name>` is listed, its
    `.dcm` files are ordered by the integer that follows the last '-' in the
    file name (e.g. `Image-12.dcm` -> 12), and at most `num` slices centred on
    the middle of the series are read with `read_dicom1`.

    Parameters
    ----------
    instance_name : str
        Name of the case folder.
    file_name : str
        Name of the series sub-folder (e.g. 'FLAIR').
    train_path : str
        Root folder that contains the case folders.
    num : int, default 128
        Maximum number of slices to keep. Shorter series are returned whole.

    Returns
    -------
    numpy.ndarray
        The per-slice arrays returned by `read_dicom1`, concatenated along axis 0.

    Raises
    ------
    ValueError
        If the folder contains no `.dcm` files.
    """
    folder = os.path.join(train_path, str(instance_name), file_name)

    def slice_number(name):
        # 'Image-12.dcm' -> 12; sorting on the integer avoids the lexicographic order 1, 10, 11, 2, ...
        return int(os.path.splitext(name)[0].rsplit('-', 1)[-1])

    files = sorted((f for f in os.listdir(folder) if f.lower().endswith('.dcm')), key = slice_number)
    if not files:
        raise ValueError(f'no DICOM slices found in {folder}')
    if len(files) > num:
        # Centre the window; start + num (rather than mid +/- num // 2) also
        # yields exactly `num` slices when `num` is odd.
        start = (len(files) - num) // 2
        files = files[start:start + num]
    slices = [read_dicom1(os.path.join(folder, f)) for f in files]
    return np.concatenate(slices, axis = 0)

def read_all_dicom(instance_name, train_path = train_path, files = ['FLAIR', 'T1w', 'T1wCE', 'T2w']):
    """Load several series of one case and join them along the last axis.

    Each name in `files` is loaded with `read_file_dicom` from
    `train_path/instance_name/<name>`; the resulting arrays are concatenated
    on axis -1.

    NOTE(review): concatenating on the last axis requires every series to
    agree in all other dimensions, including the slice count -- confirm this
    holds before requesting more than one series.
    """
    volumes = []
    for series_name in files:
        volumes.append(read_file_dicom(instance_name, series_name, train_path))
    return np.concatenate(volumes, axis = -1)

In [None]:
# Fixed seed so the train/validation split is identical on every Restart & Run All.
SEED = 42
split_rng = np.random.default_rng(SEED)

files = df.iloc[:, 0].values
# Force a numeric dtype so the labels can be turned into tensors whatever dtype the frame holds.
labels = df.iloc[:, 1].astype(int).values

# Hold out one fifth of the cases (sampled without replacement) for validation.
val_index = split_rng.choice(len(files), len(files)//5, replace = False)
train_index = np.delete(np.arange(len(files)), val_index)
train_files = files[train_index]
train_labels = labels[train_index]
val_files = files[val_index]
val_labels = labels[val_index]

# Next, we define the dataset class and build the data loaders.

In [None]:
class RSNADataset(Dataset):
    """Dataset that yields one (volume, label) pair per case folder.

    Parameters
    ----------
    path : str
        Root folder that contains the case folders.
    instances : sequence of str
        Case folder names, one per sample.
    labels : array-like
        One numeric label per entry of `instances`; converted to float32 once
        here, so integer, float and object-dtype inputs all work.
    files : sequence of str, default ('FLAIR',)
        Series sub-folders forwarded to `read_all_dicom`.

    Raises
    ------
    ValueError
        If `instances` and `labels` differ in length.
    """
    def __init__(self, path, instances, labels, files = ('FLAIR',)):
        self.path = path
        self.instances = instances
        self.labels = np.asarray(labels, dtype = np.float32)
        self.files = list(files)
        if len(self.labels) != len(self.instances):
            raise ValueError('instances and labels must have the same length')

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        """Return (volume, label): a float32 tensor of the loaded volume and a float32 tensor of shape (1,)."""
        instance_name = self.instances[idx]
        volume = read_all_dicom(instance_name, train_path = self.path, files = self.files)
        instance = torch.from_numpy(volume).float()
        # Build the label from a scalar so negative indices behave like positive ones.
        label = torch.tensor([self.labels[idx]], dtype = torch.float32)
        return instance, label

In [None]:
# Training loader: one case per batch, reshuffled every epoch.
train_data = DataLoader(
    RSNADataset(train_path, train_files, train_labels),
    batch_size = 1,
    shuffle = True,
)
# Validation loader: built with the DataLoader defaults.
val_data = DataLoader(RSNADataset(train_path, val_files, val_labels))

# We initialize the model with VGG19 as the per-slice feature extractor. The resulting feature vectors are passed through a transformer encoder, and the encoded values are finally fed to a fully connected layer to predict the label.

In [None]:
class MyModel(nn.Module):
    """Slice-wise VGG19 features -> transformer encoder over slices -> top-k pooled score.

    Every slice of a volume is copied to three channels and passed through the
    VGG19 convolutional stack, then average-pooled to a 512-vector. The slice
    vectors form one sequence for the transformer encoder, a linear layer scores
    each encoded slice, and the mean of the `top_k` highest scores is squashed
    with a sigmoid.

    Parameters
    ----------
    weights_path : str or None, default '../input/vgg19/vgg19.pth'
        State-dict file loaded into VGG19; None skips loading.
    top_k : int, default 15
        Number of highest slice scores that are averaged (all slices when the
        volume has fewer).
    """
    def __init__(self, weights_path = '../input/vgg19/vgg19.pth', top_k = 15):
        super(MyModel, self).__init__()
        vgg = vgg19()
        if weights_path is not None:
            # map_location keeps loading independent of the device the file was saved from.
            vgg.load_state_dict(torch.load(weights_path, map_location = 'cpu'))
        # NOTE(review): presumably vgg19().features has fewer than 40 modules, so this keeps all of them -- confirm.
        self.feature_detector = vgg.features[:40]
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.encoder = nn.TransformerEncoder(encoder_layer = nn.TransformerEncoderLayer(d_model = 512, nhead = 8), num_layers = 3)
        self.fc = nn.Linear(512, 1)
        self.sigmoid = nn.Sigmoid()
        self.top_k = top_k

    def _score_volume(self, volume):
        """Return the pooled pre-sigmoid score, shape (1,), of one (slices, H, W) volume."""
        x = volume[:, None].repeat(1, 3, 1, 1)          # (slices, 3, H, W)
        x = self.feature_detector(x)
        x = self.avgpool(x)[:, :, 0, 0]                 # (slices, 512)
        # The encoder layer is built with its default sequence-first layout
        # (seq, batch, feature): slices go on axis 0 so attention mixes
        # information across the slices of the volume.
        x = self.encoder(x[:, None, :])                 # (slices, 1, 512)
        scores = self.fc(x).flatten()                   # (slices,)
        k = min(self.top_k, scores.numel())
        # Sort along the slice axis so the k highest scores are the ones averaged.
        return scores.sort(descending = True)[0][:k].mean().reshape(1)

    def forward(self, x):
        """Map a batch of volumes (batch, slices, H, W) to probabilities of shape (batch, 1)."""
        pooled = torch.stack([self._score_volume(volume) for volume in x], dim = 0)
        return self.sigmoid(pooled)

In [None]:
# Build the network with its default arguments.
model = MyModel()

In [None]:
# Binary cross-entropy on the model's sigmoid outputs, optimised with Adam over all parameters.
criterion = nn.BCELoss()
opt = torch.optim.Adam(params = model.parameters(), lr = 5e-5)

In [None]:
def train_model(model, train_data, optimizer, criteria, epochs, val_data = None, device = 'cpu', log_every = 80):
    """Train `model` in place and, when validation data is given, restore its best weights.

    Parameters
    ----------
    model : torch.nn.Module
        Network returning probabilities of the same shape as the labels.
    train_data : iterable of (image, label) batches
        Training batches; must support `len()`-free iteration (a DataLoader or a list).
    optimizer : torch.optim.Optimizer
        Optimizer already bound to `model.parameters()`.
    criteria : callable
        Loss called as `criteria(pred, label)`.
    epochs : int
        Number of passes over `train_data`.
    val_data : iterable of (image, label) batches, optional
        When given, the loss on it is computed after every epoch and the
        weights of the epoch with the lowest validation loss are loaded back
        into `model` at the end.
    device : str, default 'cpu'
        Device the model and batches are moved to.
    log_every : int, default 80
        Print the running mean training loss every `log_every` steps
        (0 or None disables it).

    Returns
    -------
    None
    """
    model = model.to(device)
    best_state = None
    # Infinity (not a finite constant) so the first finite validation loss always counts as an improvement.
    best_val = float('inf')
    for epoch in range(epochs):
        train_loss, train_correct, train_seen, train_step = 0.0, 0, 0, 0
        model.train()
        for image, label in tqdm(train_data):
            image = image.to(device)
            label = label.to(device)
            pred = model(image)
            loss = criteria(pred, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_step += 1
            # Threshold at 0.5 and count matches over the whole batch, not just its first sample.
            train_correct += ((pred.detach() >= 0.5).to(label.dtype) == label).sum().item()
            train_seen += label.shape[0]
            if log_every and train_step % log_every == 0:
                print(f'{train_loss / train_step}')
        # Loss is averaged per batch, accuracy per sample; max(..., 1) guards an empty loader.
        train_loss /= max(train_step, 1)
        train_acc = train_correct / max(train_seen, 1)
        torch.cuda.empty_cache()
        if val_data is not None:
            val_loss, val_correct, val_seen, val_step = 0.0, 0, 0, 0
            model.eval()
            # No gradients are needed for validation; disabling autograd avoids building graphs.
            with torch.no_grad():
                for image, label in val_data:
                    image = image.to(device)
                    label = label.to(device)
                    pred = model(image)
                    val_loss += criteria(pred, label).item()
                    val_step += 1
                    val_correct += ((pred >= 0.5).to(label.dtype) == label).sum().item()
                    val_seen += label.shape[0]
            val_loss /= max(val_step, 1)
            val_acc = val_correct / max(val_seen, 1)
            if val_loss < best_val:
                best_state = deepcopy(model.state_dict())
                best_val = val_loss
            print('epoch:', epoch, 'loss:', train_loss, 'acc:', train_acc, ';;;; val_loss:', val_loss, 'val_acc:', val_acc)
            torch.cuda.empty_cache()
        else:
            print('epoch:', epoch, 'loss: ', train_loss, 'acc:', train_acc)
    # best_state stays None without validation data or when no epoch produced a comparable loss (e.g. NaN).
    if best_state is not None:
        model.load_state_dict(best_state)

In [None]:
# Train on the GPU when one is available and fall back to the CPU otherwise,
# so the cell also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_model(model, train_data, opt, criterion, 15, val_data = val_data, device = device)

  0%|          | 0/468 [00:00<?, ?it/s]

0.6952187955379486
0.6961167763918639
0.6960212789475918
0.6935452966019511
0.6956108662486077
epoch: 0 loss: 0.6959783824590536 acc: 0.46153846153846156 ;;;; val_loss: 0.6997804672290118 val_acc: 0.4358974358974359


  0%|          | 0/468 [00:00<?, ?it/s]

0.6865115188062191
0.694757865741849
0.6962431378662586
0.6967140460386873
0.6970116446912289
epoch: 1 loss: 0.6953978422615263 acc: 0.48504273504273504 ;;;; val_loss: 0.6990140141584934 val_acc: 0.4358974358974359


  0%|          | 0/468 [00:00<?, ?it/s]

0.6949882946908474
0.6940025471150875
0.6946269219120343
0.695854957960546
0.69545365691185
epoch: 2 loss: 0.6950635585265282 acc: 0.48504273504273504 ;;;; val_loss: 0.6984352179062672 val_acc: 0.4358974358974359


  0%|          | 0/468 [00:00<?, ?it/s]

0.6974190019071103
0.6960452012717724
0.6947157981495062
0.695728431455791
0.6956389400362969
epoch: 3 loss: 0.694813396431442 acc: 0.48504273504273504 ;;;; val_loss: 0.697714049082536 val_acc: 0.4358974358974359


  0%|          | 0/468 [00:00<?, ?it/s]

0.6976111300289631
0.6971059557050466
0.6959434891740481
0.695882385969162
0.6950115510821342
epoch: 4 loss: 0.6945765963476948 acc: 0.48504273504273504 ;;;; val_loss: 0.6972331954882696 val_acc: 0.4358974358974359


  0%|          | 0/468 [00:00<?, ?it/s]

0.6942917615175247
0.6925254691392183
0.6928828363617261
0.6935764906927944
0.6943544971942902
epoch: 5 loss: 0.6943346310375084 acc: 0.48504273504273504 ;;;; val_loss: 0.6966694501730112 val_acc: 0.4358974358974359


  0%|          | 0/468 [00:00<?, ?it/s]

0.6928689979016781
0.6943769153207541
0.6950034809609255
0.6941969899460674
0.6945778372883796
epoch: 6 loss: 0.6941389910176269 acc: 0.48504273504273504 ;;;; val_loss: 0.6960499470050519 val_acc: 0.4358974358974359


  0%|          | 0/468 [00:00<?, ?it/s]

0.6934008166193962
0.6946640551090241
0.6943921372294426
0.6942555813118816
0.6945174716413021
epoch: 7 loss: 0.6939943816162583 acc: 0.48504273504273504 ;;;; val_loss: 0.6956130159206879 val_acc: 0.4358974358974359


  0%|          | 0/468 [00:00<?, ?it/s]

0.6970626465976238
0.6953885097056627
0.6947073464592298
0.6948529014363884
0.694025205373764
epoch: 8 loss: 0.693862246015133 acc: 0.48504273504273504 ;;;; val_loss: 0.6951111280001127 val_acc: 0.4358974358974359


  0%|          | 0/468 [00:00<?, ?it/s]

0.6893782526254654
0.6928161811083555
0.6934713756044706
0.6932540012523531
0.6934722258150577
epoch: 9 loss: 0.6937117803300548 acc: 0.48504273504273504 ;;;; val_loss: 0.6946791181197534 val_acc: 0.4358974358974359


  0%|          | 0/468 [00:00<?, ?it/s]

0.6953121840953826
0.6939303800463676
0.6942089629669984
0.6939143994823098
0.6937375386059285
epoch: 10 loss: 0.6935624741336219 acc: 0.48504273504273504 ;;;; val_loss: 0.6942937190716083 val_acc: 0.4358974358974359


  0%|          | 0/468 [00:00<?, ?it/s]

0.693454758822918
0.6932396683841944
0.6934326328337193
0.6935605136677623
0.6933242927491665
epoch: 11 loss: 0.6934557338046212 acc: 0.48504273504273504 ;;;; val_loss: 0.6939663688341776 val_acc: 0.4358974358974359


  0%|          | 0/468 [00:00<?, ?it/s]

0.6935548484325409
0.6935166105628013
0.6935516575972239
0.6935068162158131
0.6933656127750873
epoch: 12 loss: 0.693372633594733 acc: 0.48504273504273504 ;;;; val_loss: 0.6936046068484967 val_acc: 0.4358974358974359


  0%|          | 0/468 [00:00<?, ?it/s]

0.6936320468783379
0.6934056241065264
0.6932899152239164
0.6932834211736918
0.6933140544593335
epoch: 13 loss: 0.6933083912500968 acc: 0.48504273504273504 ;;;; val_loss: 0.6932770793254559 val_acc: 0.4358974358974359


  0%|          | 0/468 [00:00<?, ?it/s]

0.6931694127619267
0.6932531509548425
0.6932518551747004
0.6932296503335238
0.6931870265305042
epoch: 14 loss: 0.6932380012212656 acc: 0.4829059829059829 ;;;; val_loss: 0.6929894853860904 val_acc: 0.5641025641025641


In [None]:
# Write the model's current state_dict to the working directory.
torch.save(obj = model.state_dict(), f = './state.pth')

# Finally, we run the trained model on the provided test data to produce the predictions for the submission.

In [None]:
# Put the model in evaluation mode before predicting.
model.eval()
# Order the test case folders by their numeric value rather than alphabetically.
test_files = sorted(np.array(os.listdir(test_path)), key = int)
# Placeholder labels: the dataset needs a label per case, but get_preds discards them.
test_labels = np.ones((len(test_files),))
test_data = DataLoader(RSNADataset(test_path, test_files, test_labels))

In [None]:
def get_preds(model, test_data, device = None):
    """Run `model` over `test_data` and return one prediction per sample.

    Parameters
    ----------
    model : torch.nn.Module
        Network whose output has shape (batch, 1).
    test_data : iterable of (image, label) batches
        The labels are ignored.
    device : str, optional
        Device used for inference. Defaults to 'cuda' when available, else 'cpu'.

    Returns
    -------
    numpy.ndarray
        1-D array holding column 0 of the model outputs, in iteration order
        (empty when `test_data` yields nothing).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.cuda.empty_cache()
    was_training = model.training
    model.to(device)
    # Inference must not depend on the caller having called eval() beforehand.
    model.eval()
    preds = []
    try:
        # No gradients are needed; disabling autograd avoids building a graph per case.
        with torch.no_grad():
            for image, _ in tqdm(test_data):
                preds.append(model(image.to(device)).cpu())
    finally:
        # Hand the model back in the mode it arrived in.
        model.train(was_training)
        torch.cuda.empty_cache()
    if not preds:
        return np.empty((0,), dtype = np.float32)
    return torch.cat(preds, dim = 0).numpy()[:, 0]

In [None]:
# One model output per test case, in the order the loader yields them.
preds = get_preds(model = model, test_data = test_data)

  0%|          | 0/87 [00:00<?, ?it/s]

In [None]:
# One row per test case: the case folder name and the model's prediction for it.
submission = pd.DataFrame(data = {'BraTS21ID':test_files, 'MGMT_value':preds})
# Written without the index so the file holds only the two columns above.
submission.to_csv(path_or_buf = './submission.csv', index = False)