In [1]:
import torch
import torch.nn as nn
from project.dataset import Dataset, VALDODataset
from torch.utils.data import DataLoader
from project.preprocessing import NiftiToTensorTransform, z_score_normalization
from project.utils import collate_fn, plot_mri_slice, plot_all_slices, plot_all_slices_from_array

INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.14 (you have 1.4.7). Upgrade using: pip install --upgrade albumentations


In [2]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [3]:
# Project dataset handle; its load_raw_mri() / load_cmb_masks() methods are called in the next cell.
ds = Dataset()

In [4]:
# Raw MRI cases and the matching CMB masks (presumably cerebral microbleeds,
# per the VALDO naming — confirm) from the project dataset.
cases = ds.load_raw_mri()
masks = ds.load_cmb_masks()

# target_shape=(50, 50) is presumably the per-slice output size; rpn_mode=True
# selects the transform's RPN-specific output format — TODO confirm in project.preprocessing.
transform = NiftiToTensorTransform(target_shape=(50, 50), rpn_mode=True)

In [5]:
# Wrap the loaded cases/masks in the project's VALDO dataset, attaching the
# slice transform and z-score normalisation (applied per sample — presumably
# inside __getitem__; confirm in project.dataset).
dataset = VALDODataset(
    cases=cases,
    masks=masks,
    normalization=z_score_normalization,
    transform=transform,
)

pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
INFO:nibabel.global:pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


In [6]:
# One dataset item per batch, reshuffled every epoch; project collate_fn assembles the batch.
# NOTE(review): no seed / generator is set, so the shuffle order is not reproducible between runs.
dloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

In [7]:
from project import Fitter

class RPNFitter(Fitter):
    """Fitter for the region-proposal network (RPN) on batches of MRI slices.

    Each loader batch is unpacked as ``(slices, masks, case, counts)``; only
    ``slices`` and ``masks`` are used. Both are flattened to
    ``(num_slices, 1, -1)`` and all slices of a batch are passed through
    ``self.model`` in one forward call. Training and validation share this
    preprocessing so that both losses are computed on identically shaped
    tensors (``self.loss`` may otherwise broadcast mismatched shapes silently).

    Uses ``self.model``, ``self.loss``, ``self.optimizer`` and ``self.device``,
    which are expected to be set up by ``project.Fitter`` (presumably from the
    config dict — confirm against the base class).
    """

    def _prepare_batch(self, slices, masks):
        """Flatten one batch into ``(x, target)`` on ``self.device``.

        Args:
            slices: tensor whose leading dimension indexes slices; all other
                dimensions are flattened into a single feature axis.
            masks: targets with the same leading dimension as ``slices``.

        Returns:
            ``(x, target)``: float tensors of shape ``(num_slices, 1, -1)``.
        """
        num_slices = slices.shape[0]
        x = slices.view(num_slices, 1, -1).float().to(self.device)
        target = masks.view(num_slices, 1, -1).float().to(self.device)
        return x, target

    def train_one_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``.

        Args:
            train_loader: iterable yielding ``(slices, masks, case, counts)``.

        Returns:
            list of loss tensors, one per batch, detached from autograd.
        """
        self.model.train()
        loss_history = []
        for slices, masks, _case, _counts in train_loader:
            x, target = self._prepare_batch(slices, masks)
            # All slices of the batch go through the RPN in a single call.
            y = self.model(x)
            loss = self.loss(y, target)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Store a detached copy so the history does not keep references
            # into each batch's autograd graph.
            loss_history.append(loss.detach())
        return loss_history

    def validation(self, val_loader):
        """Compute the loss on every batch of ``val_loader`` without updating weights.

        Uses exactly the same batching and target reshaping as
        ``train_one_epoch`` so the two loss curves are comparable.

        Args:
            val_loader: iterable yielding ``(slices, masks, case, counts)``.

        Returns:
            list of loss tensors, one per batch.
        """
        self.model.eval()
        loss_history = []
        with torch.inference_mode():
            for slices, masks, _case, _counts in val_loader:
                x, target = self._prepare_batch(slices, masks)
                y = self.model(x)
                loss_history.append(self.loss(y, target))
        return loss_history

In [8]:
from project.model import RPN

# Training configuration handed to RPNFitter.
# RPN's first argument, 50**2, presumably matches the flattened 50x50 slices
# (NiftiToTensorTransform target_shape=(50, 50)); the meaning of 4, 5 and 200
# is defined in project.model.RPN — TODO document.
# NOTE(review): SGD with lr=1e-7 is an unusually small step size; confirm it is intentional.
config = dict(
    model=RPN(50**2, 4, 5, 200).to(device),
    optimizer=torch.optim.SGD,  # passed as the class; the Fitter presumably instantiates it
    device=device,
    epochs=1,
    loss=nn.SmoothL1Loss(),
    lr=1e-7,
)



In [9]:
# Build the fitter from the config dict (model, optimizer class, device, epochs, loss, lr).
fitter = RPNFitter(config)

In [10]:
# Run training via the base Fitter.
# NOTE(review): `dloader` is passed for both arguments. If these are the train and
# validation loaders (as RPNFitter.train_one_epoch / validation suggest), the
# validation loss is measured on the training data and is not a held-out
# estimate — consider torch.utils.data.random_split on `dataset`.
hist = fitter.fit(dloader, dloader)

In [11]:
# Stack the tensors in hist[0] into one 1-D tensor for display.
# NOTE(review): hist[0] is presumably the per-batch training losses from
# train_one_epoch — confirm the structure returned by project.Fitter.fit.
torch.stack(hist[0])

tensor([ 0.9175,  4.5481,  9.7643,  2.2781,  0.3871,  0.4180,  1.2394,  0.9201,
         8.6340,  0.4156,  1.5528,  1.9274,  2.4495,  1.3684,  1.5424,  0.3756,
         0.4152,  0.8729,  1.4309,  0.3974,  0.4088,  1.2001,  0.4424,  0.3977,
         4.0276,  1.4983,  1.0815,  0.8704,  0.8437,  0.3533,  1.1841,  0.9467,
         0.6137,  0.7853,  0.8948,  1.7257,  0.3669,  6.7948,  2.0156, 10.1898,
         0.3784,  0.9816,  0.4172,  0.4304,  5.1425,  1.0109,  0.3982,  1.5937,
         1.4813,  9.4026,  0.4620,  0.9299,  0.3549,  0.3677,  1.3552,  4.8386,
         0.8475,  0.4291,  1.8925,  3.9732,  1.5179,  0.8681,  0.9597,  0.8694,
         0.9458,  0.4686,  0.7921, 14.0145,  0.4824,  0.3809,  0.8293,  4.0983],
       device='cuda:0', grad_fn=<StackBackward0>)

In [12]:
# Handle to the trained model held by the fitter (the same object, not a copy).
model = fitter.model

In [13]:
# Pull the first batch from the loader for interactive inspection.
# NOTE(review): this rebinds the global `masks`, shadowing the dataset-level masks
# loaded earlier; re-running the dataset cell after this one would pick up this batch.
slices, masks, case, counts = next(iter(dloader))

In [14]:
# Leading dimension of `slices`; used to flatten each slice into its own row below.
num_slices = slices.shape[0]

In [18]:
# Run the trained RPN on the inspected batch, flattened the same way as in training.
# eval() + inference_mode(): inference-time layer behaviour (e.g. dropout, if any)
# and no autograd graph is recorded. The fitter calls model.train() again at the
# start of each training epoch, so this does not affect later training.
x = slices.view(num_slices, 1, -1).float().to(device)
model.eval()
with torch.inference_mode():
    y = model(x)