In [1]:
import torch
import torch.nn as nn
from project.dataset import Dataset, VALDODataset
from torch.utils.data import DataLoader
from project.preprocessing import NiftiToTensorTransform, z_score_normalization
from project.utils import collate_fn, plot_mri_slice, plot_all_slices, plot_all_slices_from_array

INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.15 (you have 1.4.7). Upgrade using: pip install --upgrade albumentations


In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
ds = Dataset()

In [4]:
cases = ds.load_raw_mri()
masks = ds.load_cmb_masks()

transform = NiftiToTensorTransform(target_shape = (50, 50), rpn_mode=True)

In [5]:
dataset = VALDODataset(
    cases=cases,
    masks=masks,
    transform=transform,
    normalization=z_score_normalization,
)

pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
INFO:nibabel.global:pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


In [6]:
dloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=1,
    collate_fn=collate_fn,
)

In [7]:
from project.model import RPN

config = {
    'model': RPN(50**2, 4, 5, 200).to(device),
    'optimizer': torch.optim.Adam,
    'device': device,
    'epochs': 10,
    'loss': nn.SmoothL1Loss(),
    # 'loss': nn.MSELoss(),
    'lr': 0.0000001
}



In [8]:
from project import Fitter

class RPNFitter(Fitter):
    def train_one_epoch(self, train_loader):
        self.model.train()
        # for all samples in train_loader
        loss_history = []
        for slices, masks, case, counts in train_loader:
            num_slices = slices.shape[0]
            masks = masks.view(num_slices, 1, -1).float().to(self.device)
            # x = slices.view(num_slices, 1, 1, -1).float().to(self.device)
            x = slices.view(num_slices, 1, -1).float().to(self.device)
            y = []
            # feed each slice to rpn
            y = self.model(x)
            # for slc in x:
                # out = self.model(slc)
                # y.append(out)
                
            # y = torch.stack(y)
            # calculate loss
            losses = self.loss(y, masks)
            loss_history.append(losses)
            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
        
        return loss_history
        # requery rpn with 
        
    def validation(self, val_loader):
        self.model.eval()
        with torch.inference_mode():
            loss_history = []
            # feed all samples
            for slices, masks, case, counts in val_loader:
                num_slices = slices.shape[0]
                masks = masks.float().to(self.device)
                x = slices.view(num_slices, 1, 1, -1).float().to(self.device)
                y = []
                for slc in x:
                    out = self.model(slc)
                    y.append(out)
                y = torch.stack(y)
                # calculate loss
                losses = self.loss(y, masks)
                loss_history.append(losses)
            
            return loss_history
            # get prediction per slice

In [9]:
fitter = RPNFitter(config)

In [None]:
hist = fitter.fit(dloader, dloader)

In [18]:
torch.stack(hist[9])

tensor([ 1.7102,  0.7020,  0.8433,  0.1469,  0.1413,  3.7790,  0.8573,  0.6262,
         1.3749,  1.1366,  6.6733,  3.8952,  3.9173,  1.3003,  0.5793,  0.1515,
         0.2007,  0.1513,  9.8213,  0.5605,  0.9524,  0.1280,  0.8081,  0.7316,
         1.2720,  0.1308,  0.5574,  0.5824,  0.1254,  0.5852,  0.1766,  4.7182,
         0.1523,  0.7586,  0.1557,  0.7385,  8.5513,  0.5908,  1.7726,  5.0154,
         0.1531,  0.9307,  1.2182,  1.4003,  0.6001,  1.1548,  0.1523,  0.1645,
         4.3647,  0.1196,  0.1606,  0.9408,  0.1858,  0.1150,  2.2279,  9.3290,
         0.1500,  0.6154,  1.4492,  1.4121,  0.3936,  0.1495,  0.6419,  1.6959,
         0.1626,  0.1442, 14.3137,  1.2627, 10.1454,  0.7176,  2.0415,  0.7344],
       device='cuda:0', grad_fn=<StackBackward0>)