In [None]:
import torch
import torch.nn as nn
import pandas as pd
from project.dataset import Dataset, VALDODataset
from torch.utils.data import DataLoader
from project.preprocessing import NiftiToTensorTransform, z_score_normalization
from project.utils import collate_fn, plot_mri_slice, plot_all_slices, plot_all_slices_from_array
import winsound

In [None]:
import logging
import os

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# File logger for training progress (written to andy.log in the working dir).
# Re-running this cell must not attach a second FileHandler for the same file,
# otherwise every message would be written to the log once per re-run.
LOG_PATH = os.path.abspath('andy.log')
logger = logging.getLogger('andy')
logger.setLevel(logging.DEBUG)

if not any(
    isinstance(handler, logging.FileHandler) and handler.baseFilename == LOG_PATH
    for handler in logger.handlers
):
    # Only construct the handler when needed: FileHandler opens the file immediately.
    fh = logging.FileHandler(LOG_PATH)
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s'
    ))
    logger.addHandler(fh)

# Last expression of the cell, so the selected device is actually displayed.
device

In [None]:
# NOTE(review): `ds` is not referenced again in this notebook; it is kept (and
# kept first) in case the Dataset constructor has side effects — confirm.
ds = Dataset()

# Record table providing the `mri`, `masks` and `target` columns used below.
data = pd.read_csv(filepath_or_buffer='targets.csv')

In [None]:
# target_shape is presumably the per-slice output size (50x50); rpn_mode's exact
# effect is defined in project.preprocessing — TODO confirm.
transform = NiftiToTensorTransform(rpn_mode=True, target_shape=(50, 50))

# Split the record table into the three per-record inputs of VALDODataset.
cases, masks, target = data['mri'], data['masks'], data['target']

In [None]:
# One item per record: MRI case, its masks and its target. The transform and
# z-score normalization are passed through to VALDODataset to apply per item.
dataset = VALDODataset(
    cases=cases, masks=masks, target=target,
    transform=transform, normalization=z_score_normalization,
)

In [None]:
# Seeded generator so the shuffle order is reproducible across Restart & Run All.
SHUFFLE_SEED = 42

dloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=1,  # the training loop reads targets[0], i.e. one case per batch
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

### Config for fitter

In [None]:
from project.model import RPN

# Training configuration handed to the fitter. The optimizer is passed as a
# class (torch.optim.Adam), presumably instantiated by Fitter with `lr`.
config = dict(
    # 50**2 presumably matches the flattened 50x50 slice — TODO confirm RPN signature.
    model=RPN(50 ** 2, 4, 5, 2500).to(device),
    optimizer=torch.optim.Adam,
    device=device,
    epochs=1,
    loss=nn.SmoothL1Loss(),  # alternative tried: nn.MSELoss()
    lr=1e-7,
)

### Sample trial

In [None]:
sample = next(enumerate(dloader))

In [None]:
model = config['model']

In [None]:
slices, masks, target, case = sample[1]

In [None]:
num_slices = slices.shape[0]
x = slices.view(num_slices, 1, -1).float().to(device)
masks = masks.view(num_slices, 1, -1).float().to(device)

In [None]:
model(x, target[0])

### Fitter

In [None]:
from project import Fitter

class RPNFitter(Fitter):
    """Fitter that trains the RPN one case (batch) at a time.

    Uses the attributes the base ``Fitter`` provides (``model``, ``device``,
    ``loss``, ``optimizer``) and the notebook-level ``logger``.
    """

    def train_one_epoch(self, train_loader, log_every=50):
        """Run one training pass over ``train_loader``.

        Args:
            train_loader: sized iterable yielding ``(slices, masks, targets, cases)``.
                Only ``targets[0]`` is used, so each batch is assumed to hold a
                single case (batch_size=1) — TODO revisit when batching is added.
            log_every: log progress and the current loss every ``log_every`` batches.

        Returns:
            list[float]: the loss of every batch, as plain Python floats.
        """
        self.model.train()
        loss_history = []
        total = len(train_loader)  # progress relative to the loader actually iterated

        for step, (slices, masks, targets, _cases) in enumerate(train_loader, start=1):
            target = targets[0]
            num_slices = slices.shape[0]
            # Flatten each slice: (num_slices, 1, -1).
            x = slices.view(num_slices, 1, -1).float().to(self.device)
            masks = masks.view(num_slices, 1, -1).float().to(self.device)

            y = self.model(x, target)
            loss = self.loss(y, masks[target])

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Store a float, not the tensor: keeping the tensor would hold a device
            # tensor plus its autograd history (grad_fn chain) for the whole epoch.
            loss_value = loss.item()
            loss_history.append(loss_value)

            if step % log_every == 0:
                logger.info(f'Progress:\t{step}/{total}')
                logger.info(f'Current error:\t{loss_value}')

        return loss_history

In [None]:
# Wrap the training config in the RPN-specific fitter (its train_one_epoch is defined in the cell above).
fitter = RPNFitter(config)

### Training

In [None]:
# NOTE(review): the same loader is passed twice; if fit() takes (train_loader,
# val_loader) — presumably — any validation metric is measured on training data.
# Split the dataset (e.g. torch.utils.data.random_split) before trusting it.
hist = fitter.fit(dloader, dloader)

In [None]:
# Audible "training finished" signal: three 500 Hz beeps of 500 ms each (winsound is Windows-only).
for _beep in range(3):
    winsound.Beep(500, 500)

In [None]:
import seaborn as sns

# hist[0] is assumed to be the per-batch training losses returned by fit() — TODO
# confirm Fitter.fit's return layout. float() accepts both Python floats and 0-d
# tensors (including CUDA tensors that require grad), so the plot works either way.
train_losses = [float(loss_value) for loss_value in hist[0]]

ax = sns.lineplot(data=train_losses)
ax.set(title='RPN training loss per batch', xlabel='Batch', ylabel='Training loss');

# Summary

The target slice is now included in the dataset. However, training is slow: there are now more than 7,000 records. The dataloader reloads the same MRI volume once for every slice it contains, which takes much longer than loading each volume once and iterating over its slices.

**Next goal**: implement batches, fix `collate_fn`