Copied from [[PyTorch Train] RSNA Video Classification + W&B 🚀](https://www.kaggle.com/heyytanay/pytorch-train-rsna-video-classification-w-b) and edited to run without Internet access.

<div class="alert alert-success">
    <h2 align='center'>📔 Imports and Installation</h2>
</div>

[Original notebook by heyytanay](https://www.kaggle.com/heyytanay/pytorch-train-rsna-video-classification-w-b)

In [None]:
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
import timm


In [None]:
import os
import sys
import re
import gc
import platform
import random

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau

#import einops

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold, train_test_split
import json

import glob
import cv2

#from rich import print as _pprint
#from rich.progress import track

import albumentations as A
from albumentations.pytorch import ToTensorV2

#import wandb

import warnings
warnings.simplefilter('ignore')

<div class="alert alert-success">
    <h2 align='center'>⛽ Utility Functions </h2>
</div>

In [None]:
# Default to the Kaggle input layout; fall back to the local checkout when the
# Kaggle competition folder is not present on this machine.
data_directory = '../input/rsna-miccai-brain-tumor-radiogenomic-classification'
pytorch3dpath = "../input/efficientnetpyttorch3d/EfficientNet-PyTorch-3D"
if not os.path.exists("../input/rsna-miccai-brain-tumor-radiogenomic-classification"):
    data_directory = '/media/roland/data/kaggle/rsna-miccai-brain-tumor-radiogenomic-classification'
    pytorch3dpath = "EfficientNet-PyTorch-3D"

In [None]:
# MRI sequences used by this notebook; the first entry is the default for get_path.
mri_types = ['FLAIR']


def get_patient_id(patient_id):
    """Return ``patient_id`` as an integer string zero-padded to five characters."""
    return f"{int(patient_id):05d}"


def get_path(row, is_test=False, mri_type=mri_types[0]):
    """
    Build the PNG folder path for one patient.

    ``row`` must expose a ``BraTS21ID`` attribute. ``is_test`` selects the
    ``test`` folder instead of ``train``; ``mri_type`` is the sequence folder.
    The returned string ends with a trailing slash.
    """
    patient_id = get_patient_id(row.BraTS21ID)
    if not is_test:
        return f'../input/rsna-miccai-png/train/{patient_id}/{mri_type}/'
    return f'../input/rsna-miccai-png/test/{patient_id}/{mri_type}/'

def wandb_log(**kwargs):
    """
    Send every keyword argument to W&B, one ``wandb.log`` call per key.

    Relies on a module-level ``wandb`` name being available when at least one
    keyword argument is given.
    """
    for name in kwargs:
        wandb.log({name: kwargs[name]})
        
def cprint(string):
    """
    Print ``string`` wrapped in ``[black]...[/black]`` markup tags.
    """
    message = f"[black]{string}[/black]"
    print(message)

<div class="alert alert-success">
    <h2 align='center'>🚀 Config Dictionary and W&B Integration </h2>
</div>

In [None]:
# Global hyper-parameters and run metadata, looked up by key throughout the notebook.
Config = {
    # Data
    'MAX_FRAMES': 12,          # frames (slices) read per patient
    'EPOCHS': 15,
    'LR': 1.2e-5,
    'IMG_SIZE': (224, 224),    # passed to cv2.resize
    # Model
    'FEATURE_EXTRACTOR': 'resnet34',
    'DR_RATE': 0.35,           # dropout probability
    'NUM_CLASSES': 1,
    'RNN_HIDDEN_SIZE': 100,
    'RNN_LAYERS': 1,
    # Data loading
    'TRAIN_BS': 8,
    'VALID_BS': 4,
    'NUM_WORKERS': 2,
    # Run metadata
    'infra': "Kaggle",
    'competition': 'rsna_miccai',
    '_wandb_kernel': 'tanaym',
}

In [None]:
class Augments:
    """
    Namespace for the albumentations pipelines used for training and validation.

    Both pipelines currently contain a single ``ToTensorV2`` step.
    """
    train_augments = A.Compose([ToTensorV2(p=1.0)], p=1.)

    valid_augments = A.Compose([ToTensorV2(p=1.0)], p=1.)

<div class="alert alert-success">
    <h2 align='center'>💻 Custom Dataset Class</h2>
</div>

<div class="alert alert-block alert-info" style="font-size:14px; font-family:verdana; line-height: 1.7em;">
    📌 This custom Dataset reads up to "MAX_FRAMES" images from a patient's FLAIR folder, makes a list of those frames, and converts it to a torch tensor.
</div>

In [None]:
class RSNADataset(Dataset):
    """
    Dataset returning a stack of PNG slices for one patient.

    Each item is ``(frames, label)`` (or ``(frames, idx)`` when ``is_test`` is
    true), where ``frames`` is a NumPy array with the slices stacked on the
    last axis, i.e. ``(height, width, n_frames)``.

    Args:
        df: DataFrame with a ``path`` column (folder prefix of the PNG slices)
            and, unless ``is_test``, an ``MGMT_value`` column. Rows are looked
            up with ``df.loc[idx]``, so the index is expected to be 0..len-1.
        augments: Optional callable invoked as ``augments(image=frame)`` and
            returning a mapping with an ``'image'`` entry (albumentations
            convention). Applied to every frame independently.
        is_test: When true, the item's second element is ``idx`` instead of
            the label.
    """

    def __init__(self, df, augments=None, is_test=False):
        self.df = df
        self.augments = augments
        self.is_test = is_test

    def __getitem__(self, idx):
        row = self.df.loc[idx]
        frames = []
        for path in self.getPaths(row):
            img = cv2.imread(path)
            if img is None:
                # cv2.imread signals failure by returning None; fail with the
                # offending path instead of a cryptic error inside cv2.resize.
                raise FileNotFoundError(f"Could not read image: {path}")
            img = cv2.resize(img, Config['IMG_SIZE'])
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            frames.append(img)

        if self.augments:
            # Augment before stacking: the previous code tried to ``append``
            # to the already-stacked ndarray, which raises AttributeError.
            frames = [self._augment(frame) for frame in frames]

        frames_tr = np.stack(frames, axis=2)

        if self.is_test:
            return frames_tr, idx
        label = torch.tensor(row['MGMT_value']).float()
        return frames_tr, label

    def __len__(self):
        return len(self.df)

    def _augment(self, frame):
        """Apply ``self.augments`` to one 2-D frame and return a 2-D ndarray."""
        augmented = self.augments(image=frame)['image']
        # Tensor-producing transforms (e.g. ToTensorV2) add a singleton channel
        # axis; drop it so the stacked output keeps the (H, W, frames) layout
        # that the training loop permutes.
        return np.squeeze(np.asarray(augmented))

    def getPaths(self, row):
        """Return at most ``Config['MAX_FRAMES']`` slice paths in natural order."""
        sortedPaths = self.sort(glob.glob(row['path'] + '*.png'))
        # The window always starts at the first slice. The end is expressed
        # relative to ``start`` so a non-zero start still yields a full window.
        start = 0
        return sortedPaths[start:start + Config['MAX_FRAMES']]

    def sort(self, entry):
        """Sort strings naturally, so that ``Image-2`` comes before ``Image-10``."""
        # https://stackoverflow.com/a/2669120/7636462
        convert = lambda text: int(text) if text.isdigit() else text
        alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]

        return sorted(entry, key=alphanum_key)

<div class="alert alert-success">
    <h2 align='center'>📈 Model Class with ResNet Backbone</h2>
</div>

In [None]:
# Load the index of the offline timm checkpoint dataset. Its 'resnet' section is
# looked up by backbone name below; the resulting entry is later used as the
# checkpoint file name inside '../input/timm-pretrained-resnet/resnet/'.
with open('../input/timm-pretrained-resnet/index.json','r') as f:
    index=json.load(f)
    
# Backbone name chosen in Config and its matching entry in the index.
backstone=Config['FEATURE_EXTRACTOR']
where=index['resnet'][backstone]

In [None]:

class ResNextModel(nn.Module):
    """
    Single-channel timm backbone initialised from a local checkpoint.

    The architecture name comes from the module-level ``backstone`` and the
    checkpoint file name from ``where``. The model is created with
    ``in_chans=1``, so the checkpoint's first convolution weights are summed
    over their input-channel axis to match.
    """

    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model(backstone, pretrained=False, in_chans=1)
        pretrained = f'../input/timm-pretrained-resnet/resnet/{where}'
        # Load onto the CPU so the checkpoint is usable on hosts without CUDA;
        # the caller moves the whole model to the target device afterwards.
        state_dict = torch.load(pretrained, map_location='cpu')
        # Collapse the input-channel axis of conv1 to one grayscale channel.
        conv1_weight = state_dict['conv1.weight']
        state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True)
        self.backbone.load_state_dict(state_dict)

    def forward(self, x):
        """Run the wrapped backbone on ``x`` and return its output."""
        return self.backbone(x)

class Identity(nn.Module):
    """Pass-through layer: ``forward`` hands its input back unchanged."""

    def forward(self, x):
        return x

class RSNAModel(nn.Module):
    """
    CNN + LSTM classifier over a stack of image frames.

    Every frame is encoded by the backbone (its ``fc`` head replaced by
    ``Identity``); the per-frame features are fed one time step at a time to
    an LSTM, and the last LSTM output goes through dropout, a linear layer
    and a sigmoid.

    Args:
        pretrained: Accepted for interface compatibility; not used, the
            backbone always loads its local checkpoint.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.backbone = ResNextModel()
        num_features = self.backbone.backbone.fc.in_features

        self.backbone.backbone.fc = Identity()
        self.dropout = nn.Dropout(Config['DR_RATE'])
        # batch_first=True: forward() feeds (batch, 1, features) tensors.
        # Without it the LSTM reads the batch axis as the time axis and mixes
        # information between the samples of a batch.
        self.rnn = nn.LSTM(
            num_features,
            Config['RNN_HIDDEN_SIZE'],
            Config['RNN_LAYERS'],
            batch_first=True,
        )
        self.fc1 = nn.Linear(Config['RNN_HIDDEN_SIZE'], Config['NUM_CLASSES'])

    def forward(self, x):
        """
        Args:
            x: Tensor unpacked as ``(batch, frames, height, width)``.

        Returns:
            Sigmoid outputs of shape ``(batch, NUM_CLASSES)``.

        Raises:
            ValueError: If ``x`` contains no frames.
        """
        b_z, fr, h, w = x.shape
        if fr == 0:
            raise ValueError("RSNAModel.forward needs at least one frame")

        state = None  # None lets the LSTM start from its default zero state
        for ii in range(fr):
            # Out-of-place unsqueeze: add the channel axis without mutating a
            # view of the input tensor.
            y = self.backbone(x[:, ii].unsqueeze(1))
            out, state = self.rnn(y.unsqueeze(1), state)

        out = self.dropout(out[:, -1])
        out = self.fc1(out)
        out = torch.sigmoid(out)
        return out

<div class="alert alert-success">
    <h2 align='center'>🏴‍☠️ Training and Validation Functions</h2>
</div>

In [None]:
def train_one_epoch(model, train_dataloader, optimizer, loss_fn, epoch, device, log_wandb=True, verbose=False):
    """
    Run one optimisation pass over ``train_dataloader``.

    For every batch: zero the gradients, move the data to ``device``, run the
    model, compute ``loss_fn(preds, targets)``, back-propagate and step the
    optimizer. Returns the sum of the batch losses divided by the number of
    batches. ``epoch`` is accepted but not used here.
    """
    model.train()
    batch_losses = []
    prog_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
    for batch, (clips, labels) in prog_bar:
        optimizer.zero_grad()

        clips = clips.to(device, torch.float)
        labels = labels.to(device, torch.float)
        # Rearrange the frames into the layout the model wants to receive:
        # 'b h w f -> b f h w'
        clips = clips.permute(0, 3, 1, 2)
        preds = model(clips).view(-1)
        del clips
        gc.collect()

        loss = loss_fn(preds, labels)
        loss.backward()
        optimizer.step()

        loss_item = loss.item()
        batch_losses.append(loss_item)
        prog_bar.set_description(f"loss: {loss_item:.4f}")

        if log_wandb == True:
            wandb_log(batch_train_loss=loss_item)

        if verbose == True and batch % 20 == 0:
            print(f"Batch: {batch}, Loss: {loss_item}")

    return sum(batch_losses) / len(train_dataloader)

@torch.no_grad()
def valid_one_epoch(model, valid_dataloader, loss_fn, epoch, device, log_wandb=True, verbose=False):
    """
    Evaluate ``model`` on ``valid_dataloader`` without tracking gradients.

    Returns the sum of the batch losses divided by the number of batches.
    ``epoch`` is accepted but not used here.
    """
    model.eval()
    batch_losses = []
    prog_bar = tqdm(enumerate(valid_dataloader), total=len(valid_dataloader))

    for batch, (clips, labels) in prog_bar:
        clips = clips.to(device, torch.float)
        labels = labels.to(device, torch.float)

        # Rearrange the frames into the layout the model wants to receive:
        # 'b h w f -> b f h w'
        clips = clips.permute(0, 3, 1, 2)
        preds = model(clips).view(-1)
        del clips
        gc.collect()

        loss_item = loss_fn(preds, labels).item()
        batch_losses.append(loss_item)
        prog_bar.set_description(f"val_loss: {loss_item:.4f}")

        if log_wandb == True:
            wandb_log(batch_val_loss=loss_item)

        if verbose == True and batch % 10 == 0:
            print(f"Batch: {batch}, Loss: {loss_item}")

    return sum(batch_losses) / len(valid_dataloader)

<div class="alert alert-success">
    <h2 align='center'>🏗 Training and Validating the Model</h2>
</div>

In [None]:
# W&B logging is disabled because this notebook runs without Internet access.
log_wandb = False
if torch.cuda.is_available():
    print("Using GPU: {}\n".format(torch.cuda.get_device_name()))
    device = torch.device('cuda')
else:
    print("\nGPU not found. Using CPU: {}\n".format(platform.processor()))
    device = torch.device('cpu')


# Load the training labels from the configured data directory, so the local
# (non-Kaggle) location selected for ``data_directory`` works as well.
df = pd.read_csv(f'{data_directory}/train_labels.csv')
df['path'] = df.apply(get_path, axis=1)
#df=df[:20]  #close it for submitting

# Remove two patient ids from the dataframe since there are no FLAIR directories for these ids.
df = df.loc[~df.BraTS21ID.isin([109, 709])].reset_index(drop=True)

# Stratified 90/10 split on the label; indices are reset because the datasets
# look rows up by position-like labels 0..len-1.
train_df, valid_df = train_test_split(df, test_size=0.1, stratify=df.MGMT_value.values)
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)

print(f'Size of Training Set: {len(train_df)}, Validation Set: {len(valid_df)}')

In [None]:
def prepera_data():
    """
    Build the training and validation DataLoaders.

    Wraps the module-level ``train_df`` and ``valid_df`` in ``RSNADataset``
    and returns ``(train_loader, valid_loader)``. Only the training loader
    shuffles.
    """
    def make_loader(frame, batch_size, shuffle):
        return DataLoader(
            RSNADataset(frame),
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=Config['NUM_WORKERS'],
        )

    training = make_loader(train_df, Config['TRAIN_BS'], True)
    validation = make_loader(valid_df, Config['VALID_BS'], False)
    return training, validation

train_loader,valid_loader=prepera_data()

In [None]:
def hunt_model():
    """
    Train an ``RSNAModel`` for ``Config['EPOCHS']`` epochs.

    After each epoch the model is validated; whenever the validation loss
    improves, the weights are saved to ``model_<FEATURE_EXTRACTOR>.pt``.

    Returns:
        The model as it stands after the final epoch (not necessarily the
        best checkpoint written to disk).
    """
    model = RSNAModel()
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=Config['LR'])

    # RSNAModel.forward already ends with torch.sigmoid, so its outputs are
    # probabilities. BCEWithLogitsLoss would apply a second sigmoid on top;
    # BCELoss is the matching criterion for probability inputs.
    train_loss_fn = nn.BCELoss()
    valid_loss_fn = nn.BCELoss()

    print(f"\nUsing Backbone: {Config['FEATURE_EXTRACTOR']}")

    # Start from +inf so the first epoch always produces a checkpoint.
    current_loss = float('inf')
    for epoch in range(Config['EPOCHS']):
        print(f"\n{'--'*8} EPOCH: {epoch+1} {'--'*8}\n")

        train_loss = train_one_epoch(model, train_loader, optimizer, train_loss_fn, epoch=epoch, device=device, log_wandb=log_wandb)

        valid_loss = valid_one_epoch(model, valid_loader, valid_loss_fn, epoch=epoch, device=device, log_wandb=log_wandb)

        print(f"val_loss: {valid_loss:.4f}")

        if log_wandb == True:
            wandb_log(
                train_loss=train_loss,
                valid_loss=valid_loss
            )

        if valid_loss < current_loss:
            current_loss = valid_loss
            torch.save(model.state_dict(), f"model_{Config['FEATURE_EXTRACTOR']}.pt")
    return model

model=hunt_model()
modelfiles=[model]

In [None]:
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
from skimage.transform import resize
def load_dicom_image(path, img_size, voi_lut=True, rotate=0):
    """
    Read one DICOM slice and resize it.

    Args:
        path: Path of the ``.dcm`` file.
        img_size: Target size handed to ``cv2.resize``.
        voi_lut: When true, pass the pixel data through ``apply_voi_lut``.
        rotate: Accepted for interface compatibility; not used.

    Returns:
        The resized pixel array, or ``None`` when the slice's maximum value
        is 0 (treated as an empty slice by the caller).
    """
    # dcmread is the supported reader; read_file is a deprecated alias that
    # newer pydicom releases no longer provide.
    dicom = pydicom.dcmread(path)
    data = dicom.pixel_array
    if voi_lut:
        data = apply_voi_lut(data, dicom)

    #data=crop_center(data,100,100)

    if np.max(data) == 0:
        return None
    return cv2.resize(data, img_size)


def load_dicom_images_3d(scan_id,img_size, mri_type="FLAIR", rotate=0,split='test'):
    """
    Load up to ``Config['MAX_FRAMES']`` non-empty DICOM slices of one scan.

    Slices are read in natural file-name order from
    ``{data_directory}/{split}/{scan_id}/{mri_type}/``, stacked on the last
    axis and min-max scaled to [0, 1].

    Args:
        scan_id: Patient folder name.
        img_size: Per-slice target size forwarded to ``load_dicom_image``.
        mri_type: MRI sequence sub-folder.
        rotate: Selects an optional perturbation: 0 none, 1 uniform noise,
            2 Gaussian blur, 3 brightness scale with clipping at 1,
            4 Laplace filter, 5 centre crop, any other positive value scales
            by 0.99.
        split: Data split sub-folder.

    Returns:
        A float array with the slices on the last axis. When fewer than
        ``MAX_FRAMES`` usable slices exist, the volume is resized along the
        last axis to ``MAX_FRAMES``.

    Raises:
        ValueError: If the scan contains no non-empty slice.
    """
    files = sorted(
        glob.glob(f"{data_directory}/{split}/{scan_id}/{mri_type}/*.dcm"),
        key=lambda var: [int(x) if x.isdigit() else x for x in re.findall(r'[^0-9]|[0-9]+', var)],
    )

    max_frames = Config['MAX_FRAMES']
    slices = []
    for f in files:
        # Honour the img_size argument instead of silently using the global.
        temp = load_dicom_image(f, img_size=img_size)
        if temp is None:
            # Empty slice: skip it, it does not count towards max_frames.
            continue
        slices.append(temp)
        if len(slices) >= max_frames:
            break
    frame_count = len(slices)
    if frame_count == 0:
        raise ValueError(
            f"No non-empty {mri_type} slices found for scan {scan_id} in split '{split}'"
        )

    img3d = np.stack(slices, axis=2)
    img3d = img3d - np.min(img3d)
    peak = np.max(img3d)
    # Guard the normalisation: a constant volume would otherwise divide 0 by 0.
    img3d = img3d / peak if peak > 0 else img3d.astype(np.float64)

    if rotate > 0:
        if rotate == 1:
            img3d = img3d + np.random.random(img3d.shape) * 0.012
        elif rotate == 2:
            # Imported here because the notebook never imports scipy.ndimage.
            from scipy.ndimage import gaussian_filter
            img3d = gaussian_filter(img3d, sigma=7)
        elif rotate == 3:
            img3d = img3d * 1.05
            img3d[img3d > 1] = 1.
        elif rotate == 4:
            from scipy.ndimage import laplace
            img3d = laplace(img3d)
        elif rotate == 5:
            # NOTE(review): crop_center is not defined in the visible notebook;
            # this branch needs that helper to be provided elsewhere.
            img3d = crop_center(img3d, (150, 150, 150))
        else:
            img3d = img3d * 0.99

    if frame_count < max_frames:
        # Too few usable slices: stretch the frame axis up to max_frames while
        # keeping the in-plane size (the previous code used an undefined SIZE).
        img3d = resize(img3d, (img3d.shape[0], img3d.shape[1], max_frames))
        print(scan_id,frame_count)
    return img3d
    #return np.expand_dims(img3d,0)


In [None]:
class TestSet(Dataset):
    """
    Inference dataset that loads DICOM volumes for the rows of ``df``.

    Args:
        df: DataFrame with a ``BraTS21ID`` column. Rows are looked up with
            ``df.loc[idx]``, so the index is expected to be 0..len-1.
        mri_type: MRI sequence folder to read for every scan.
    """

    def __init__(self, df, mri_type):
        self.df = df
        self.mri_type = mri_type
        self.label_smoothing = 0
        self.augment = False

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(frames, idx)`` for the row labelled ``idx``."""
        row = self.df.loc[idx]
        # Go through int() first: a row of purely numeric columns is upcast to
        # float, and str(1.0).zfill(5) would give '001.0' instead of '00001'.
        scan_id = str(int(row['BraTS21ID'])).zfill(5)
        frames = load_dicom_images_3d(scan_id, img_size=Config['IMG_SIZE'], mri_type=self.mri_type)
        return frames, idx

In [None]:
def min_cube(array_3d):
    """
    Crop ``array_3d`` to the smallest axis-aligned box holding all non-zero entries.

    The result is a basic-slicing view of the first three axes of the input.
    """
    nonzero_coords = np.nonzero(array_3d)
    # Exclusive upper bound and inclusive lower bound per axis.
    stop = np.max(nonzero_coords, axis=1) + 1
    start = np.min(nonzero_coords, axis=1)
    return array_3d[start[0]:stop[0], start[1]:stop[1], start[2]:stop[2]]

In [None]:
def predict(model, df, mri_type, split):
    """
    Run ``model`` over every row of ``df`` and collect its outputs.

    Args:
        model: Model called on ``(batch, frames, height, width)`` float tensors.
        df: DataFrame handed to ``TestSet``. NOTE: an ``MRI_Type`` column is
            written into it in place.
        mri_type: MRI sequence to load.
        split: Accepted for interface compatibility; not used.

    Returns:
        DataFrame with an ``MGMT_value`` column, indexed by ``BraTS21ID``.
        The index values are the ``idx`` values yielded by ``TestSet`` (row
        labels of ``df``), not the patient ids.
    """
    print("Predict:",mri_type, df.shape)
    df.loc[:,"MRI_Type"] = mri_type
    test_data = TestSet(df, mri_type)
    test_loader = DataLoader(
        test_data,
        batch_size=4,
        shuffle=False,
        num_workers=1
    )
    model.to(device)
    model.eval()

    y_pred = []
    ids = []

    with torch.no_grad():
        for e, batch in enumerate(test_loader, 1):
            print(f"{e}/{len(test_loader)}", end="\r")
            frames, ids_ = batch
            # (batch, height, width, frames) -> (batch, frames, height, width)
            frames = frames.permute(0, 3, 1, 2)
            frames = frames.to(device, torch.float)
            tmp_pred = model(frames).view(-1)
            # view(-1) is always 1-D, so tolist() yields a list even for a
            # single prediction. The old ``tmp_pred.size == 1`` test compared
            # a bound method with an int and could never be true.
            y_pred.extend(tmp_pred.tolist())
            ids.extend(ids_.tolist())

    preddf = pd.DataFrame({"BraTS21ID": ids, "MGMT_value": y_pred})
    preddf = preddf.set_index("BraTS21ID")
    return preddf

In [None]:
# Read the submission template from the configured data directory, so the
# local (non-Kaggle) location selected for ``data_directory`` works as well.
df = pd.read_csv(f'{data_directory}/sample_submission.csv')
# Independent working copy; ``df`` keeps the original template columns.
submission = df.copy(deep=True)

In [None]:
# Attach the test-set PNG folder path of every patient.
submission['path'] = submission.apply(lambda row: get_path(row,is_test=True), axis=1)

# Accumulate the predictions of every (model, MRI type) pair; the sum is
# divided by the number of models in a later cell.
submission["MGMT_value"] = 0
for m, mtype in zip(modelfiles, mri_types):
    pred = predict(m, submission, mtype, split="test")
    # Series addition aligns on the index: predict() indexes its result by the
    # row labels of ``submission``.
    submission["MGMT_value"] += pred["MGMT_value"]

In [None]:
# Turn the accumulated sum into the mean over all models.
submission["MGMT_value"] /= len(modelfiles)
#submission.drop(columns=['path','MRI_Type'],inplace=True)

In [None]:
# Copy the averaged predictions into the template frame that gets written out.
df['MGMT_value']=submission["MGMT_value"]

In [None]:
#submission=submission.set_index('BraTS21ID')
# Write the submission file without the DataFrame index, then display it.
df.to_csv("submission.csv",index=False)
df

In [None]:
# Histogram of the predicted values as a quick sanity check.
submission["MGMT_value"].hist()

In [None]:
# Display the working frame, including the helper columns added above.
submission