In [1]:
import torch

from utils.fetch_dataset import fetch_dataset
from utils.quality_measures import *
from utils.save_plots import output_scatter_plot
from utils.adjust_tensor import adjust_tensor

from dataset import MRIDataset
from unet import UNet3D
from unet import UNet3D

In [2]:
# Download the dataset archives (the logged output shows classes.zip and
# patients.zip being saved under /content).
# NOTE(review): presumably this also extracts them into the
# ./patients/<channel>/ directories read by the training loop — TODO confirm.
fetch_dataset()

Downloading...
From (original): https://drive.google.com/uc?id=1aIzlyrPTgrzLKwGS1_3alvNJKgLvbGQ3
From (redirected): https://drive.google.com/uc?id=1aIzlyrPTgrzLKwGS1_3alvNJKgLvbGQ3&confirm=t&uuid=fc10ab53-6a30-4ac2-96ef-6a96bf8b8930
To: /content/classes.zip
100%|██████████| 6.94M/6.94M [00:00<00:00, 96.4MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1ki-aYI07KEbi7mWsPWxeAvmMfVsWrBCq
From (redirected): https://drive.google.com/uc?id=1ki-aYI07KEbi7mWsPWxeAvmMfVsWrBCq&confirm=t&uuid=c764e275-c299-4d98-ab01-6813980b5cef
To: /content/patients.zip
100%|██████████| 2.27G/2.27G [00:18<00:00, 124MB/s]


In [3]:
import torch

# Training/evaluation configuration for the per-channel UNet3D runs.
BATCH_SIZE = 1      # batch size of the training DataLoader
IN_CHANNELS = 1     # input channels passed to UNet3D(IN_CHANNELS, 1)
EPOCHS = 1          # passes over the training split in train_model
SEED = 42           # seeds torch's global RNG (split, weight init, shuffling)

SHOW_IMAGES = True  # when True, draw prediction-vs-mask scatter plots

# Seed once, up front, so Restart & Run All reproduces the same split and
# initialization; the trailing ';' hides the returned Generator's repr.
torch.manual_seed(SEED);

In [4]:
import torch

# Run on the GPU when CUDA is available, otherwise on the CPU. Importing torch
# here keeps this cell working on a fresh kernel instead of relying on an
# earlier cell to have brought `torch` into the namespace.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [5]:
import torch
import torch.nn.functional as F


def adjust_tensor(data, mask):
    """Lazily split a 5-D tensor and its mask into eight matching octants.

    Both tensors are indexed as ``[:, :, depth, height, width]``. Each of the
    last three axes is cut at ``size // 2``. The generator yields
    ``(data_part, mask_part)`` pairs in depth-major order: the lower depth half
    comes first, and within each depth half the width cut varies fastest.

    NOTE(review): this redefines the ``adjust_tensor`` imported from
    ``utils.adjust_tensor`` in the first cell. Keep a single definition so it
    is unambiguous which one the training loop uses.

    Raises:
        ValueError: on first iteration, if ``data`` is not 5-D (shape unpack).
        AssertionError: on first iteration, if H or W is odd. The message is
            Polish for "Dimensions H and W must be divisible by 2."
    """
    _, _, depth, height, width = data.shape

    assert height % 2 == 0 and width % 2 == 0, "Wymiary H i W muszą być podzielne przez 2."

    # For each spatial axis: (lower half, upper half) as slice objects.
    cuts = []
    for size in (depth, height, width):
        mid = size // 2
        cuts.append((slice(None, mid), slice(mid, None)))

    depth_cuts, height_cuts, width_cuts = cuts
    for d_slc in depth_cuts:
        for h_slc in height_cuts:
            for w_slc in width_cuts:
                region = (slice(None), slice(None), d_slc, h_slc, w_slc)
                yield data[region], mask[region]

In [33]:
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam

def train_model(train_dataloader, model, log_every=3):
    """Train ``model`` on octant chunks of every scan for ``EPOCHS`` epochs.

    Each batch ``(data, mask)`` is split by ``adjust_tensor`` into eight
    sub-volumes, and one optimizer step is taken per sub-volume. The loss is
    ``BCEWithLogitsLoss``, so the model output is treated as raw logits, and
    the optimizer is ``Adam`` with default settings. Every ``log_every``-th
    chunk (counted across the whole epoch) the loss is printed. When
    ``SHOW_IMAGES`` is set, that chunk's thresholded prediction is also
    plotted against its mask.

    Reads the notebook globals ``device``, ``EPOCHS`` and ``SHOW_IMAGES``.

    Args:
        train_dataloader: iterable of ``(data, mask)`` batches. ``data`` must be
            5-D, because ``adjust_tensor`` unpacks ``_, _, D, H, W``. ``mask``
            must be 4-D, since a channel axis is inserted here. It is presumably
            ``(batch, D, H, W)`` — TODO confirm against MRIDataset.
        model: module mapping a chunk to logits with the same shape as the
            corresponding mask chunk.
        log_every: print/plot period in chunks (default 3, as before).

    Returns:
        The trained model, moved to ``device``.
    """
    criterion = BCEWithLogitsLoss()
    # Move the model first, then build the optimizer over its parameters
    # (the order recommended by the torch.optim documentation).
    model = model.to(device)
    optimizer = Adam(params=model.parameters())

    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0.0
        step = 0
        for idx, (data, mask) in enumerate(train_dataloader):
            # Insert the channel axis right after the batch axis. unsqueeze(0)
            # only produced the (B, 1, D, H, W) layout when B == 1; for larger
            # batches the mask no longer matched the output shape.
            chunks = adjust_tensor(data, mask.unsqueeze(1))
            for part_idx, (chunked_data, chunked_mask) in enumerate(chunks):
                chunked_data = chunked_data.float().to(device)
                chunked_mask = chunked_mask.float().to(device)

                optimizer.zero_grad()
                output = model(chunked_data)
                loss = criterion(output, chunked_mask)

                if step % log_every == 0:
                    print(f"Scan loss: {loss.item()}")
                    if SHOW_IMAGES:
                        output_scatter_plot(
                            (output.squeeze(0).squeeze(0) > 0).float(),
                            chunked_mask.squeeze(0).squeeze(0),
                            epoch, idx, part_idx,
                        )

                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                step += 1

        # Guard the average: an empty loader previously raised ZeroDivisionError.
        if step:
            print(f"Loss: {train_loss / step}")
        else:
            print("Loss: n/a (no training chunks)")
    return model

In [34]:
def eval_model(test_dataloader, model):
    """Evaluate ``model`` on octant chunks of every test scan.

    Each ``(data, mask)`` batch is split by ``adjust_tensor`` into eight
    sub-volumes. For each one, ``evaluate(output, mask_chunk)`` is called on
    the model's raw output. When ``SHOW_IMAGES`` is set, the prediction
    thresholded at 0 is also plotted against the mask.

    Reads the notebook globals ``device`` and ``SHOW_IMAGES``.

    Args:
        test_dataloader: iterable of ``(data, mask)`` batches with the same
            layout as in training. ``data`` is 5-D; ``mask`` presumably
            ``(batch, D, H, W)`` — TODO confirm. The leading two axes are
            squeezed away, which assumes batch size 1.
        model: trained module producing logits.

    Returns:
        The same model, on ``device`` and in eval mode.
    """
    model = model.to(device)
    model.eval()
    # Inference only: skip building the autograd graph. This saves memory, and
    # the outputs handed to evaluate/plotting no longer require grad.
    with torch.no_grad():
        for idx, (data, mask) in enumerate(test_dataloader):
            # Channel axis goes right after the batch axis (B, 1, D, H, W).
            chunks = adjust_tensor(data, mask.unsqueeze(1))
            for part_idx, (chunked_data, chunked_mask) in enumerate(chunks):
                chunked_data = chunked_data.float().to(device)
                chunked_mask = chunked_mask.float().to(device).squeeze(0).squeeze(0)
                output = model(chunked_data).squeeze(0).squeeze(0)

                evaluate(output, chunked_mask)

                if SHOW_IMAGES:
                    output_scatter_plot((output > 0).float(), chunked_mask, None, idx, part_idx)

    return model

In [None]:
from torch.utils.data import DataLoader

# Train and evaluate a separate UNet3D for each channel directory.
channel_labels = ['ch1', 'ch2', 'ch3']

# Seed for a dedicated split RNG. Each channel's 80/20 split is then
# reproducible and independent of how much of torch's global RNG earlier work
# (weight init, shuffling, a previous channel's training) has consumed.
# NOTE(review): the splits coincide across channels only if MRIDataset lists
# patients in the same order in every channel directory — TODO confirm.
SPLIT_SEED = 42

for channel_label in channel_labels:
    print(f"Channel: {channel_label}")
    dataset = MRIDataset(patients_dir=f"./patients/{channel_label}")
    split_generator = torch.Generator().manual_seed(SPLIT_SEED)
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [0.8, 0.2], generator=split_generator
    )
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    model = UNet3D(IN_CHANNELS, 1)
    model = train_model(train_dataloader, model)
    model = eval_model(test_dataloader, model)