In [None]:
import torch

from utils.fetch_dataset import fetch_dataset
from utils.quality_measures import *
from utils.save_plots import output_scatter_plot
from utils.adjust_tensor import adjust_tensor

from dataset import MRIDataset
from unet import UNet3D

In [1]:
# Presumably downloads / prepares the MRI scans that MRIDataset reads below.
# NOTE(review): its effect is not visible in this notebook; confirm in utils.fetch_dataset.
fetch_dataset()

In [5]:
BATCH_SIZE = 1
IN_CHANNELS = 1
EPOCHS = 1

SHOW_IMAGES = True

# Seed torch's global RNG so model initialisation and DataLoader shuffling are
# reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED);

In [6]:
from torch.utils.data import DataLoader

# Seed used only for the train/test partition. A dedicated generator keeps the
# split identical every time this cell runs, independent of how much of the
# global RNG other cells have consumed. Re-running the cell after training
# therefore cannot move training scans into the test set.
SPLIT_SEED = 42

dataset = MRIDataset()

# 80% of the scans for training, 20% held out for evaluation.
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Training batches are shuffled; evaluation visits held-out scans one at a time in a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
# Build the 3-D U-Net. The second positional argument is presumably the number
# of output channels (one logit channel, consistent with the BCEWithLogitsLoss
# used in training). TODO: confirm against unet.UNet3D's signature.
model = UNet3D(IN_CHANNELS, 1)
model

In [None]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [9]:
import torch
import torch.nn.functional as F


def adjust_tensor(data, mask):
    """Split a 5-D volume and its mask into eight matching octants.

    This is a generator, so nothing runs until the first ``next()``, including
    the shape check. ``data.shape`` is unpacked as (N, C, D, H, W).

    The last three axes are cut at their midpoints. ``(data_part, mask_part)``
    pairs are yielded with the depth half varying slowest and the width half
    varying fastest. ``mask`` is cut with the indices computed from
    ``data``'s shape.

    With an odd D, the second depth half is one slice thicker than the first.

    Raises:
        ValueError: if ``data`` does not have exactly five dimensions.
        AssertionError: if H or W is odd. The message, in Polish, reads
            "Dimensions H and W must be divisible by 2".
    """
    _, _, depth, height, width = data.shape

    assert height % 2 == 0 and width % 2 == 0, "Wymiary H i W muszą być podzielne przez 2."

    def split_at(mid):
        # Lower half first, then upper half.
        return slice(None, mid), slice(mid, None)

    every = slice(None)
    for depth_part in split_at(depth // 2):
        for height_part in split_at(height // 2):
            for width_part in split_at(width // 2):
                window = (every, every, depth_part, height_part, width_part)
                yield data[window], mask[window]

In [None]:
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam

def test_sizing(dshape: tuple, dtype=torch.float32):
    """Smoke-test that a fresh UNet3D can train on a volume of shape ``dshape``.

    Random input data and a random binary target mask of shape ``dshape`` are
    created on ``device``. Both are split into eight octants with
    ``adjust_tensor``, and one forward/backward/optimizer step runs per octant.
    This surfaces shape mismatches or out-of-memory errors before real training.

    Args:
        dshape: Full volume shape. ``adjust_tensor`` unpacks it as
            (N, C, D, H, W), and H and W must be even.
        dtype: Floating dtype used to sample the random tensors. They are cast
            to float32 before use.

    Returns:
        List of the per-octant loss values as Python floats.
    """
    data = torch.randn(dshape, dtype=dtype).float().to(device)
    print(f'Data shape: {data.shape}')
    # BCEWithLogitsLoss expects targets in [0, 1]. Use a random binary mask
    # rather than Gaussian noise, which would contain negative targets.
    mask = (torch.rand(dshape, dtype=dtype) > 0.5).float().to(device)
    print(f'Mask shape: {mask.shape}')

    # Local probe model, so the notebook-level `model` is left untouched.
    probe = UNet3D(1, 1).to(device)
    probe.train()

    criterion = BCEWithLogitsLoss()
    optimizer = Adam(params=probe.parameters())

    losses = []
    for chunked_data, chunked_mask in adjust_tensor(data, mask):
        # Reset gradients per octant; otherwise they accumulate across steps.
        optimizer.zero_grad()
        out = probe(chunked_data)
        loss = criterion(out, chunked_mask)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


test_sizing(dshape=(1, 1, 240, 448, 448))

In [None]:
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam

# Log the loss (and, if SHOW_IMAGES, plot the prediction) every PLOT_EVERY chunks.
PLOT_EVERY = 3

criterion = BCEWithLogitsLoss()
# Move the model before constructing the optimizer, as the torch.optim docs
# recommend, so the optimizer is built over the parameters on their final device.
model = model.to(device)
optimizer = Adam(params=model.parameters())

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    n_chunks = 0
    for idx, (data, mask) in enumerate(train_dataloader):
        # Insert the channel axis so the mask matches the data's (N, C, D, H, W)
        # layout used by adjust_tensor. This assumes masks arrive as (N, D, H, W).
        # TODO: confirm against MRIDataset. For BATCH_SIZE == 1 this equals the
        # former unsqueeze(0), and it stays correct for larger batches.
        mask = mask.unsqueeze(1)
        for part_idx, (chunked_data, chunked_mask) in enumerate(adjust_tensor(data, mask)):
            # Transfer one octant at a time to bound device memory use.
            chunked_data = chunked_data.float().to(device)
            chunked_mask = chunked_mask.float().to(device)

            optimizer.zero_grad()
            output = model(chunked_data)
            loss = criterion(output, chunked_mask)

            if n_chunks % PLOT_EVERY == 0:
                print(f"Scan loss: {loss.item()}")
                if SHOW_IMAGES:
                    output_scatter_plot(
                        (output.squeeze(0).squeeze(0) > 0).float(),
                        chunked_mask.squeeze(0).squeeze(0),
                        epoch,
                        idx,
                        part_idx,
                    )

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            n_chunks += 1

    # Mean loss over all octants seen this epoch (guard against an empty loader).
    if n_chunks:
        print(f"Loss: {train_loss / n_chunks}")
    else:
        print("Loss: n/a (no training chunks)")

In [None]:
# Evaluate on the held-out scans. No gradients are needed, so run under
# torch.no_grad() to avoid building an autograd graph for every 3-D octant.
model.eval()
with torch.no_grad():
    for idx, (data, mask) in enumerate(test_dataloader):
        # Insert the channel axis so the mask matches the data's (N, C, D, H, W)
        # layout. This assumes masks arrive as (N, D, H, W). TODO: confirm
        # against MRIDataset.
        mask = mask.unsqueeze(1)
        # Octants are moved to the device one at a time, as in training,
        # instead of transferring the whole volume up front.
        for part_idx, (chunked_data, chunked_mask) in enumerate(adjust_tensor(data, mask)):
            chunked_data = chunked_data.float().to(device)
            # Drop the batch and channel axes (the test loader uses batch_size=1).
            chunked_mask = chunked_mask.float().to(device).squeeze(0).squeeze(0)
            output = model(chunked_data).squeeze(0).squeeze(0)

            evaluate(output, chunked_mask)
            if SHOW_IMAGES:
                output_scatter_plot((output > 0).float(), chunked_mask, None, idx, part_idx)