In [19]:
import torch
import torch.nn as nn
import pandas as pd
from project.dataset import Dataset, VALDODataset
from torch.utils.data import DataLoader
from project.preprocessing import NiftiToTensorTransform, z_score_normalization
from project.utils import collate_fn, plot_mri_slice, plot_all_slices, plot_all_slices_from_array, collatev2
import winsound
from torchvision.models import resnet18, ResNet18_Weights
from project.utils import memcheck, compute_statistics

In [20]:
import logging
import os

# File logger shared with the fitter (passed as `logger=` below).
# The handler is attached only if one for the same file is not already present, so
# re-running this cell does not stack FileHandlers and duplicate every log line.
LOG_FILE = 'andy.log'

logger = logging.getLogger('andy')
logger.setLevel(logging.DEBUG)

_log_path = os.path.abspath(LOG_FILE)
if not any(
    isinstance(h, logging.FileHandler) and h.baseFilename == _log_path
    for h in logger.handlers
):
    fh = logging.FileHandler(LOG_FILE)
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s'
    ))
    logger.addHandler(fh)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

### Load dataset

In [21]:
# Target table, presumably one row per (MRI volume, target slice) pair.
data = pd.read_csv('targets.csv')
# Project dataset helper; used below to look up the raw MRI paths of a cohort.
ds = Dataset()
data.shape

(7986, 7)

In [15]:
# Keep only rows whose target slice is flagged as containing a microbleed.
is_positive = data['has_microbleed_slice'] == 1
data = data.loc[is_positive].reset_index(drop=True)
data

Unnamed: 0,mri,masks,target,has_microbleed_case,has_microbleed_slice,cohort,max_value
0,C:\Users\araza\Documents\1\gits\thesis_project...,C:\Users\araza\Documents\1\gits\thesis_project...,8,1,1,1,928.405273
1,C:\Users\araza\Documents\1\gits\thesis_project...,C:\Users\araza\Documents\1\gits\thesis_project...,9,1,1,1,928.405273
2,C:\Users\araza\Documents\1\gits\thesis_project...,C:\Users\araza\Documents\1\gits\thesis_project...,11,1,1,1,928.405273
3,C:\Users\araza\Documents\1\gits\thesis_project...,C:\Users\araza\Documents\1\gits\thesis_project...,12,1,1,1,928.405273
4,C:\Users\araza\Documents\1\gits\thesis_project...,C:\Users\araza\Documents\1\gits\thesis_project...,15,1,1,1,928.405273
...,...,...,...,...,...,...,...
359,C:\Users\araza\Documents\1\gits\thesis_project...,C:\Users\araza\Documents\1\gits\thesis_project...,25,1,1,3,241.000000
360,C:\Users\araza\Documents\1\gits\thesis_project...,C:\Users\araza\Documents\1\gits\thesis_project...,26,1,1,3,241.000000
361,C:\Users\araza\Documents\1\gits\thesis_project...,C:\Users\araza\Documents\1\gits\thesis_project...,24,1,1,3,448.000000
362,C:\Users\araza\Documents\1\gits\thesis_project...,C:\Users\araza\Documents\1\gits\thesis_project...,25,1,1,3,448.000000


#### Select a cohort

In [16]:
# Restrict to cohort 1: keep rows whose MRI path is one of that cohort's raw volumes.
ch1 = ds.load_raw_mri(1)
in_cohort = data['mri'].isin(ch1)
data = data.loc[in_cohort]
data.shape

(45, 7)

### Train and Test Split

In [None]:
from sklearn.model_selection import GroupShuffleSplit

# Each row is one target slice of an MRI volume (the model consumes the whole volume plus
# the slice index), and several rows share the same `mri` path. A random row-level split
# would put slices of the same volume in both train and test and leak information into
# the evaluation, so split by volume: every MRI lands entirely in one split.
# Note: with grouping, `test_size` is the fraction of *volumes* held out, not of rows.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=12)
train_idx, test_idx = next(splitter.split(data, groups=data['mri']))

data_train = data.iloc[train_idx].reset_index(drop=True)
data_test = data.iloc[test_idx].reset_index(drop=True)

In [None]:
# (rows, columns) of the train and test splits, for a quick sanity check.
tuple(split.shape for split in (data_train, data_test))

### Normalization

In [None]:
# `data_train` has one row per target slice, so the same volume path repeats.
# Pass each volume once: assuming `compute_statistics` reduces to a global min/max
# (as the names suggest), duplicates cannot change the result and only repeat file I/O.
global_min, global_max = compute_statistics(data_train.mri.drop_duplicates().tolist())

In [None]:
# Training-set intensity bounds handed to the transform's normalization below.
(global_min, global_max)

In [None]:
# Preprocessing shared by train and validation: resize slices to 300x300 and normalize
# with the training-set (min, max) bounds. 300x300 is as large as my GPU memory can handle.
transform = NiftiToTensorTransform(
    rpn_mode=True,
    target_shape=(300, 300),
    normalization=(global_min, global_max),
)

### Dataloaders

In [None]:
# Both splits are built the same way: plain Python lists of MRI paths, mask paths and
# target slice indices, plus the shared preprocessing transform.
train_set = VALDODataset(
    cases=data_train.mri.tolist(),
    masks=data_train.masks.tolist(),
    target=data_train.target.tolist(),
    transform=transform,
)

val_set = VALDODataset(
    cases=data_test.mri.tolist(),
    masks=data_test.masks.tolist(),
    target=data_test.target.tolist(),
    transform=transform,
)

batch_size = 5

train_loader = DataLoader(
    train_set,
    shuffle=True,
    batch_size=batch_size,
    collate_fn=collatev2,
)

# No shuffling for validation: a fixed order keeps batch composition identical across
# epochs, so per-batch validation losses are comparable from one epoch to the next.
val_loader = DataLoader(
    val_set,
    shuffle=False,
    batch_size=batch_size,
    collate_fn=collatev2,
)

### Config for fitter

In [None]:
from project.model import RPN

# Box-regression head: `output_dim=4` matches the 4-value boxes it is trained against.
rpn = RPN(
    input_dim=512,
    output_dim=4,
    image_size=300,
    nh=4,
)

# Settings consumed by the Fitter: model, optimizer class (instantiated by the fitter,
# presumably with `lr`), device, number of epochs and the loss criterion.
config = dict(
    model=rpn.to(device),
    optimizer=torch.optim.Adam,
    device=device,
    epochs=20,
    loss=nn.SmoothL1Loss(),
    lr=0.0001,
)

#### Load Pretrained Embedder

#### Load RPN Weights

### Sample trial

### Fitter

In [None]:
from project import Fitter

class RPNFitter(Fitter):
    """Fitter for the RPN box regressor.

    Each batch from the loader is iterated as ``(slices, masks, target, case)``
    tuples. For every case the model predicts a box for slice ``target``, and it
    is compared against ``masks[target]`` divided by ``bbox_scale``.
    """

    # Divisor applied to the ground-truth box coordinates before computing the loss.
    # Presumably matches the 300x300 ``target_shape`` of the preprocessing transform.
    bbox_scale = 300

    def _predict_batch(self, batch):
        """Run the model on every case in ``batch``.

        Returns:
            ``(predictions, targets)``: two stacked tensors with one row per case,
            the targets already divided by ``bbox_scale``.
        """
        predictions = []
        targets = []
        for slices, masks, target, case in batch:
            x = slices.squeeze(1).float().to(self.device)
            boxes = masks.squeeze(1).float().to(self.device) / self.bbox_scale
            predictions.append(self.model(x, target))
            targets.append(boxes[target])
        return torch.stack(predictions), torch.stack(targets)

    def train_one_epoch(self, train_loader):
        """Run one optimization pass over ``train_loader``.

        Returns:
            List with one loss value (a NumPy array, detached and on CPU) per batch.
        """
        self.model.train()
        loss_history = []
        counter = 0
        for batch in train_loader:
            self.log('----------------- BATCH -----------------')
            predictions, targets = self._predict_batch(batch)
            losses = self.loss(predictions, targets)

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()

            # `counter` counts cases seen so far; log roughly every 10 cases.
            counter += len(batch)
            if counter % 10 == 0:
                self.log(f'Current error:\t{losses}')

            loss_history.append(losses.detach().cpu().numpy())

        return loss_history

    def validation(self, val_loader):
        """Evaluate on ``val_loader`` without tracking gradients.

        Returns:
            List with one loss value (a NumPy array on CPU) per batch.
        """
        self.model.eval()
        loss_history = []
        with torch.inference_mode():
            for batch in val_loader:
                predictions, targets = self._predict_batch(batch)
                losses = self.loss(predictions, targets)
                loss_history.append(losses.cpu().numpy())
        return loss_history
                

In [None]:
# Build the training-loop helper from the config above; the logger is passed through,
# presumably so the fitter's `self.log(...)` calls end up in the file log.
fitter = RPNFitter(config, logger=logger)

### Training

In [None]:
# Full training run; presumably returns (training, validation) loss histories per epoch.
history = fitter.fit(train_loader, val_loader)
thist, vhist = history

In [None]:
# Audible "training finished" notification: three 500 Hz, 500 ms beeps (winsound is Windows-only).
for _beep in range(3):
    winsound.Beep(500, 500)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# `thist` / `vhist` presumably hold one list of per-batch losses per epoch;
# averaging over axis 1 gives one point per epoch.
th = torch.tensor(np.array(thist))
vh = torch.tensor(np.array(vhist))

train_curve = th.mean(1).numpy()
val_curve = vh.mean(1).numpy()

fig, ax = plt.subplots()
sns.lineplot(x=np.arange(1, len(train_curve) + 1), y=train_curve, ax=ax, label='train')
sns.lineplot(x=np.arange(1, len(val_curve) + 1), y=val_curve, ax=ax, label='validation')
ax.set(xlabel='Epoch', ylabel='Mean batch loss', title='RPN loss per epoch')
plt.show()

### Save the weights

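# Summary

- Added min–max normalization of MRI intensities, using the global minimum and maximum computed on the training split.

**TODO**
- Train per cohort (currently only cohort 1 is selected).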

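# Trial

Manual sanity check: run the trained model on one training sample and overlay the predicted box (red) on the ground-truth box (green).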

In [None]:
# Model held by the fitter (the one just trained), used for the manual check below.
model = fitter.model

In [None]:
# First (batch_index, batch) pair from the (shuffled) training loader.
train_batches = enumerate(train_loader)
sample = next(train_batches)

In [None]:
# `sample` is (batch_index, batch); take the second case of that batch.
_batch_idx, first_batch = sample
slices, masks, target, case = first_batch[1]

In [None]:
# Prepare inputs exactly like RPNFitter does: the fitter divides the ground-truth box
# coordinates by 300 before computing the loss, so `T` must be on that same scale, or
# the loss below compares a normalized prediction against unnormalized boxes.
x = slices.squeeze(1).float().to(device)
T = masks.squeeze(1).float().to(device) / 300

In [None]:
# Inference only: eval() disables training-only behaviour (e.g. dropout, if the model has
# any) and no_grad() skips building an autograd graph for this one-off forward pass.
# RPNFitter.train_one_epoch calls model.train() again, so later training is unaffected.
model.eval()
with torch.no_grad():
    y = model(x, target)

In [None]:
# Loss for this single case, using the same criterion the fitter trains with.
criterion = fitter.loss
criterion(y, T[target])

In [None]:
# Ground-truth box for the target slice, unscaled (as loaded; not divided by 300).
masks[target]

In [None]:
# Raw model output; it was trained against boxes divided by 300.
y

In [None]:
# Ground-truth box as a flat integer CPU tensor, ready for plotting.
bbox = masks[target]
bbox = bbox.squeeze().cpu().long()

In [None]:
# Map the normalized prediction back to the 300-pixel scale (inverse of the fitter's /300).
# The floating-point guard keeps the cell idempotent: after the first run `y` is an
# integer tensor, so re-running the cell does not multiply by 300 a second time.
# round() instead of plain truncation avoids a systematic sub-pixel shift.
if y.is_floating_point():
    y = (y * 300).squeeze().detach().cpu().round().long()

In [None]:
# Side-by-side: ground-truth box vs. prediction rescaled by 300, both as integer tensors.
bbox, y

In [None]:
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import seaborn as sns


def _box_patch(box, color, label):
    """Return an unfilled Rectangle for a 4-value ``(x1, y1, x2, y2)`` box.

    Coordinates are converted to plain ints so matplotlib never receives 0-d tensors.
    """
    x1, y1, x2, y2 = (int(v) for v in box)
    return patches.Rectangle(
        (x1, y1),
        x2 - x1,
        y2 - y1,
        linewidth=1, edgecolor=color, facecolor='none', label=label,
    )


# Overlay ground truth (green) and prediction (red) on the target slice.
fig, ax = plt.subplots()
sns.heatmap(x[target].squeeze().cpu().numpy(), ax=ax)

truth = _box_patch(bbox, 'g', 'ground truth')
pred = _box_patch(y, 'r', 'prediction')
ax.add_patch(truth)
ax.add_patch(pred)
ax.legend(handles=[truth, pred])
ax.set_title(f'Slice {target}: ground truth vs. prediction')
plt.show()