In [1]:
import os
import numpy as np
from nifti_dataset import NiftiDataset, RandomCrop3D, ToTensor
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import transforms

import matplotlib.pyplot as plt
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
input_dir = "/home/nbaranov/projects/04_cv/MedicalImageAnalysis/data/small_data/small/"
f_size = (14,8)

t1_dir = os.path.join(input_dir, 't1')
t2_dir = os.path.join(input_dir, 't2')

In [20]:
valid_split = 0.1
batch_size = 1
n_jobs = 8
n_epochs = 50

In [21]:
tfms = transforms.Compose([RandomCrop3D((32, 32, 90)), ToTensor()])

# set up training and validation data loader for nifti images
dataset = NiftiDataset(t1_dir, t2_dir, tfms, preload=False)  # set preload=False if you have limited CPU memory
num_train = len(dataset)
indices = list(range(num_train))
val_size = int(valid_split * num_train)

In [22]:
np.random.seed(666)
valid_idx = np.random.choice(indices, size=val_size, replace=False)
train_idx = list(set(indices) - set(valid_idx))
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
train_loader = DataLoader(dataset, sampler=train_sampler, batch_size=batch_size,
                          num_workers=n_jobs, pin_memory=True)
valid_loader = DataLoader(dataset, sampler=valid_sampler, batch_size=batch_size,
                          num_workers=n_jobs, pin_memory=True)

In [23]:
from model import SimpleEncDec

model = SimpleEncDec((batch_size, 1, 32, 32, 90))
model.load_state_dict(torch.load("results/trained.pth"))
model.eval()


SimpleEncDec(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1))
      (1): ReLU()
      (2): AdaptiveMaxPool3d(output_size=(16, 16, 45))
    )
    (1): Sequential(
      (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1))
      (1): ReLU()
      (2): AdaptiveMaxPool3d(output_size=(8, 8, 23))
    )
  )
  (decoder): Sequential(
    (0): Sequential(
      (0): ConvTranspose3d(32, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1))
      (1): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose3d(16, 1, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 2), output_padding=(1, 1, 1))
      (1): ReLU()
    )
  )
)

In [24]:
from sklearn.metrics import mean_absolute_error, mean_squared_error

preds = []
with torch.no_grad():
    for src, tgt in valid_loader:
        src, tgt = src.float(), tgt.float()

        pred = model(src)
        mse = mean_absolute_error(pred.detach().numpy(), tgt.detach().numpy())
        print(mse)
        # preds.append(pred.detach().numpy())

ValueError: Found array with dim 5. Estimator expected <= 2.

In [25]:
src.shape

torch.Size([1, 1, 32, 32, 90])