In [18]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torch.nn import functional as F
from torch.utils import data as torch_data
from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchinfo import summary
import time

import sklearn
from sklearn import model_selection as sk_model_selection
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score
from skimage.transform import resize

import nibabel as nib
import matplotlib.pyplot as plt

from unet_down import UNet

import argparse

In [3]:
from unet_down import UNet

# Spatial size of each input volume; presumably must match the Data_WB_90_90_90 volumes — confirm.
input_shape = (90, 90, 90)

# 3D (dim=3) UNet: 1 input channel, 2 output channels, 4 blocks starting at 32 filters,
# ReLU activations, batch normalization and 'same' convolutions.
model = UNet(
    in_channels=1,
    out_channels=2,
    n_blocks=4,
    input_shape=input_shape,
    start_filters=32,
    activation='relu',
    normalization='batch',
    conv_mode='same',
    dim=3,
    hidden_channels=2048,
)

In [28]:
import os

# Root directory of the preprocessed whole-brain volumes. The default is the original
# machine-specific absolute path; set the WB_PATH environment variable to override it
# without editing the notebook.
WB_PATH = os.environ.get('WB_PATH', '/mnt/24CC5B14CC5ADF9A/Brain_Tumor_Classification/Datasets/Data_WB_90_90_90')

def construct_target_volume(scan_id,mri_type,scale_size=260):
    """Load one BraTS2021 scan of the given MRI sequence as a float64 NumPy array.

    Args:
        scan_id: BraTS21 id used verbatim in the directory and file name
            (e.g. '00000' -> BraTS2021_00000/BraTS2021_00000_<mri_type>.nii.gz).
        mri_type: sequence suffix of the NIfTI file (e.g. 'flair').
        scale_size: NOT used — the volume is returned at its stored size with no
            resampling. Kept only so existing callers keep working.

    Returns:
        The image data as a float64 array.
    """
    path = os.path.join(WB_PATH, f'BraTS2021_{scan_id}', f'BraTS2021_{scan_id}_{mri_type}.nii.gz')
    # Request float64 directly instead of get_fdata().astype('float'), which made a
    # second full-size copy of every volume for no change in dtype or values.
    return nib.load(path).get_fdata(dtype=np.float64)

class Dataset(torch_data.Dataset):
    """Map-style dataset yielding one single-channel MRI volume per scan id.

    Args:
        ids: indexable sequence of scan ids passed to ``construct_target_volume``.
        targets: indexable sequence of labels aligned with ``ids``; only read when
            ``if_pred`` is False.
        mri_type: MRI sequence name forwarded to ``construct_target_volume``.
        if_pred: when True, items carry the scan id (``"id"``) instead of a label (``"y"``).
    """

    def __init__(self, ids, targets, mri_type, if_pred = False):
        self.ids, self.targets = ids, targets
        self.mri_type, self.if_pred = mri_type, if_pred

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        """Return ``{"X": volume, "y": label}`` or, in prediction mode, ``{"X": volume, "id": scan_id}``.

        ``volume`` is a float32 tensor with a leading channel axis of size 1 added in
        front of the loaded array's shape.
        """
        scan_id = self.ids[index]
        volume = construct_target_volume(scan_id, self.mri_type, scale_size=90)
        x = torch.tensor(volume, dtype=torch.float32).unsqueeze(0)

        if self.if_pred:
            return {"X": x, "id": scan_id}

        label = torch.tensor(self.targets[index], dtype=torch.long)
        return {"X": x, "y": label}

        
def get_train_valid_split(label_path, n_splits=5, seed=None,
                          exclude_ids=('00109', '00123', '00709')):
    """Load the MGMT label table and build stratified K-fold train/validation splits.

    Args:
        label_path: path (or file-like object) of a CSV with ``BraTS21ID`` and
            ``MGMT_value`` columns.
        n_splits: number of stratified folds.
        seed: ``random_state`` for the shuffled StratifiedKFold. ``None`` falls back to
            the notebook-level ``SEED``, looked up at call time as before.
        exclude_ids: scan ids removed from the table before splitting.

    Returns:
        ``(X, y, splits)``: arrays of ids (read as strings so zero-padding is kept) and
        integer labels, plus a list of ``(train_idx, valid_idx)`` pairs, one per fold.
    """
    if seed is None:
        seed = SEED

    # Read ids as strings so leading zeros ('00109') survive the CSV round-trip.
    train_df = pd.read_csv(label_path, dtype={'BraTS21ID': 'str', 'MGMT_value': 'int'})
    keep = ~train_df['BraTS21ID'].isin(list(exclude_ids))
    train_df = train_df[keep].reset_index(drop=True)

    X = train_df['BraTS21ID'].values
    y = train_df['MGMT_value'].values

    kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    return X, y, list(kfold.split(X, y))

LABEL_PATH = './train_labels.csv'
SEED = 42

X,y,SPLIT = get_train_valid_split(LABEL_PATH)

# Use the first stratified fold as the train/validation split.
train_idx,valid_idx = SPLIT[0]
X_train,X_valid = X[train_idx],X[valid_idx]
y_train,y_valid = y[train_idx],y[valid_idx]

train_data_retriever = Dataset(
        X_train, 
        y_train,
        'flair'
)


train_loader = torch_data.DataLoader(
        train_data_retriever,
        batch_size=1,
        shuffle=True,
        num_workers=8,
        pin_memory = True,
        # Seeded generator: the shuffle order, and therefore the sample fetched below,
        # is reproducible across Restart & Run All.
        generator=torch.Generator().manual_seed(SEED),
)


# Grab one batch to use as the example input when tracing the model graph.
# Use the builtin next(): DataLoader iterators in recent PyTorch releases no longer
# provide a .next() method.
dataiter = iter(train_loader)
images = next(dataiter)['X']

In [40]:
from torch.utils.tensorboard import SummaryWriter

# Event files for this experiment go under runs/half_Unet, inside the runs/ root
# that the TensorBoard cell is pointed at.
writer = SummaryWriter(log_dir='runs/half_Unet')

In [56]:
# Trace `model` on the sample batch and log its graph to TensorBoard.
# close() sits in `finally` so the event file is flushed and closed even if tracing fails.
try:
    writer.add_graph(model, images)
finally:
    writer.close()

In [57]:
!tensorboard --logdir=runs

TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.8.0 at http://localhost:6006/ (Press CTRL+C to quit)
^C
