In [1]:
import torch
import torch.nn as nn
import pandas as pd
from project.dataset import Dataset, VALDODataset
from torch.utils.data import DataLoader
from project.preprocessing import NiftiToTensorTransform, z_score_normalization
from project.utils import collate_fn, plot_mri_slice, plot_all_slices, plot_all_slices_from_array, collatev2
import winsound

INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.18 (you have 1.4.7). Upgrade using: pip install --upgrade albumentations


In [2]:
import logging
logger = logging.getLogger('andy')
logger.setLevel(logging.DEBUG)

# getLogger returns the same logger object on every call, so re-running this
# cell would otherwise attach one more FileHandler each time and every record
# would be written to andy.log once per run. Attach the handler only once.
if not any(isinstance(h, logging.FileHandler) for h in logger.handlers):
    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s'
    )
    fh = logging.FileHandler('andy.log')
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
# NOTE(review): `ds` is not used anywhere else in this section of the notebook;
# keep it only if Dataset() has side effects later cells rely on — confirm.
ds = Dataset()

# Per-case table; its mri / masks / target columns feed VALDODataset below.
data = pd.read_csv('targets.csv')

In [4]:
# target_shape=(50, 50) presumably is the in-plane size each slice is brought to
# (the RPN input size 50**2 in the config cell matches it) — confirm in
# project.preprocessing. rpn_mode is a project-specific switch of the transform.
transform = NiftiToTensorTransform(target_shape=(50, 50), rpn_mode=True)

# Columns of targets.csv handed to VALDODataset.
cases, masks, target = data.mri, data.masks, data.target

In [5]:
# Project dataset over the CSV columns; z_score_normalization is passed as the
# normalization callable and `transform` as the per-case transform.
dataset = VALDODataset(
    cases=cases, masks=masks, target=target,
    transform=transform, normalization=z_score_normalization,
)

In [6]:
# Two cases per batch in random order. collatev2 yields each batch as a sequence
# of per-case (slices, masks, target, case) tuples, as the sample-trial cells unpack it.
# NOTE(review): no seeded generator is given, so the shuffle order differs between runs.
dloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collatev2)

### Config for fitter

In [7]:
from project.model import RPN

# Everything the Fitter is configured with: model, optimizer class, device,
# number of epochs, loss module and learning rate.
config = dict(
    # NOTE(review): the positional RPN arguments are defined in project.model;
    # 50**2 presumably matches the flattened 50x50 slices from the transform — confirm.
    model=RPN(50**2, 4, 5, 2500).to(device),
    # Passed as a class, not an instance — presumably instantiated by the Fitter with `lr`; confirm.
    optimizer=torch.optim.Adam,
    device=device,
    epochs=1,
    loss=nn.SmoothL1Loss(),
    lr=1e-7,
)



### Sample trial

In [40]:
# First batch of the loader, kept in the (index, batch) pair shape that
# enumerate produces because the cells below address the batch as sample[1].
sample = (0, next(iter(dloader)))

pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
INFO:nibabel.global:pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
INFO:nibabel.global:pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


In [41]:
# The model instance built (and moved to `device`) in the config cell.
model = config['model']

In [42]:
# The loss module from the config cell, so the trial uses the same criterion as training.
loss = config['loss']

In [43]:
# First case of the batch. Note: this rebinds `masks` and `target`, which
# earlier held the CSV columns used to build the dataset.
slices, masks, target, case = sample[1][0]

In [44]:
# Flatten every slice to one row, giving (num_slices, 1, <elements per slice>),
# converted to float and moved to `device`.
num_slices = slices.size(0)
x, masks = (t.view(num_slices, 1, -1).float().to(device) for t in (slices, masks))

In [45]:
# Forward pass for the first case; the model also takes the case's target.
y1 = model(x, target)

In [46]:
# Second case of the same batch.
slices, masks, target, case = sample[1][1]

In [47]:
# Same flattening as for the first case: (num_slices, 1, <elements per slice>),
# float, on `device`.
num_slices = slices.size(0)
x, masks = (t.view(num_slices, 1, -1).float().to(device) for t in (slices, masks))

In [48]:
# Forward pass for the second case.
y2 = model(x, target)

In [49]:
# Inspect the second case's masks indexed by its target — the ground truth used
# as t2 below. The recorded output holds 4 values (presumably box coordinates; confirm).
sample[1][1][1][sample[1][1][2]]

tensor([[[12.4023, 29.6875, 12.7930, 30.0781]]], dtype=torch.float64)

In [50]:
# Ground truth for each of the two cases: that case's masks indexed by its
# target, as float on `device`.
t1, t2 = (sample[1][i][1][sample[1][i][2]].float().to(device) for i in (0, 1))

In [51]:
# Predictions and matching ground truths of the batch, in case order.
Y, T = [y1, y2], [t1, t2]

In [67]:
# Loss over the whole batch: predictions and ground truths are each stacked
# along a new leading dimension before being compared.
loss(torch.stack(Y), torch.stack(T))

tensor(10.3537, device='cuda:0', grad_fn=<SmoothL1LossBackward0>)

### Fitter

In [None]:
from project import Fitter

class RPNFitter(Fitter):
    """Fitter that trains the RPN model with one optimizer step per batch.

    Batches are expected in the layout used by the sample-trial cells: a
    sequence of per-case ``(slices, masks, target, case)`` tuples (as produced
    here by ``collatev2``).
    """

    def train_one_epoch(self, train_loader):
        """Train ``self.model`` for one pass over ``train_loader``.

        For every case in a batch, the slices are flattened to
        ``(num_slices, 1, -1)`` and passed to the model together with the
        case's ``target``. The ground truth is ``masks[target]``. Predictions
        and ground truths of the whole batch are stacked and reduced by
        ``self.loss`` into one loss, followed by a single optimizer step.

        Args:
            train_loader: sized iterable (e.g. a ``DataLoader``) of batches,
                each a sequence of ``(slices, masks, target, case)`` tuples.

        Returns:
            list[float]: the loss value of every optimizer step, in order.
        """
        self.model.train()
        loss_history = []
        num_batches = len(train_loader)

        for step, batch in enumerate(train_loader, start=1):
            predictions, ground_truths = [], []
            for slices, masks, target, _case in batch:
                num_slices = slices.shape[0]
                x = slices.view(num_slices, 1, -1).float().to(self.device)
                predictions.append(self.model(x, target))
                # Same indexing as the sample trial (t1/t2): the case's masks at its target.
                ground_truths.append(masks[target].float().to(self.device))

            losses = self.loss(torch.stack(predictions), torch.stack(ground_truths))
            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()

            # Keep a plain Python float rather than the (CUDA, grad-tracking) tensor:
            # it is detached from autograd and the history can be plotted directly.
            loss_history.append(losses.item())

            if step % 50 == 0:
                # `step` counts batches, so report it against the number of batches.
                logger.info(f'Progress:\t{step}/{num_batches}')
                logger.info(f'Current error:\t{losses}')

        return loss_history

In [None]:
# Build the RPN-specific fitter from the config dict.
fitter = RPNFitter(config)

### Training

In [None]:
# NOTE(review): the same loader is passed twice. If the second argument is the
# validation loader (presumably fit(train_loader, val_loader) — confirm against
# project.Fitter), validation is measured on the training data; use a held-out split.
hist = fitter.fit(dloader, dloader)

In [None]:
# Audible "training finished" signal: three beeps at 500 Hz lasting 500 ms each.
# NOTE(review): winsound exists only on Windows, so this cell (and its import) fails elsewhere.
winsound.Beep(500, 500)
winsound.Beep(500, 500)
winsound.Beep(500, 500)

In [None]:
import seaborn as sns

# Per-step training loss of the first entry of `hist` (presumably the first
# epoch's history returned by Fitter.fit — confirm).
# float() accepts both plain floats and 0-d tensors (also CUDA / grad-tracking
# ones), so the plot works whatever the history stores.
epoch_losses = [float(value) for value in hist[0]]
ax = sns.lineplot(data=epoch_losses)
ax.set(title='Training loss', xlabel='Step', ylabel='Loss');

# Summary