In [2]:
import json, os, pathlib as p
import nibabel as nib
import numpy as np
import random
from monai.networks.nets import UNet
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import pickle
from model import UnetGenerator3D
from dataset import TrainDataset
from preprocessing import create_LR_img, scale_to_reference_img, pad_to_shape, extract_3D_patches, reconstruct_from_patches, split_dataset, create_and_save_LR_imgs, get_patches


In [3]:
DATA_DIR = p.Path.home() / "data" / "bobsrepository"
LR_DIR = DATA_DIR / "LR"
LR_SCALE_FACTOR = 2  # scale_factor passed to create_and_save_LR_imgs

# Recursively collect the T1w, T2w and low-resolution T2w volumes. Sorting gives
# each list a deterministic order, since the lists are later zipped positionally.
t1_files = sorted(DATA_DIR.rglob("*T1w.nii.gz"))
t2_files = sorted(DATA_DIR.rglob("*T2w.nii.gz"))
t2_LR_files = sorted(DATA_DIR.rglob("*T2w_LR.nii.gz"))

# Cache guard: generate the low-resolution T2w images only when none exist yet,
# instead of un-commenting a line by hand. Without this, a fresh "Run All" with no
# LR files on disk would silently yield an empty dataset after zip().
# NOTE(review): assumes create_and_save_LR_imgs writes files matching
# "*T2w_LR.nii.gz" into LR_DIR (which lies under DATA_DIR, so rglob finds them) —
# TODO confirm against the preprocessing module.
if not t2_LR_files:
    create_and_save_LR_imgs(t2_files, scale_factor=LR_SCALE_FACTOR, output_dir=LR_DIR)
    t2_LR_files = sorted(DATA_DIR.rglob("*T2w_LR.nii.gz"))



In [4]:

# Pair each subject's T1w, T2w and low-res T2w volume by position in the sorted
# file lists. zip() silently truncates to the shortest input, so a missing file
# would drop or misalign subjects without any error — check the counts first.
assert t1_files, f"no *T1w.nii.gz files found under {DATA_DIR}"
assert len(t1_files) == len(t2_files) == len(t2_LR_files), (
    f"file count mismatch: {len(t1_files)} T1w, {len(t2_files)} T2w, "
    f"{len(t2_LR_files)} T2w_LR"
)
# NOTE(review): equal counts do not prove the pairing is per-subject; it relies on
# the three sorted lists sharing the same subject order — TODO verify via filenames.
files = list(zip(t1_files, t2_files, t2_LR_files))

# SPLIT DATASET
# Seed every RNG split_dataset might draw from so the split is reproducible across
# kernel restarts (which generator it actually uses is not visible here).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
train, val, test = split_dataset(files)


71


In [5]:


# EXTRACT PATCHES

patch_size = (64, 64, 64)
stride = (32, 32, 32)
# Reference image handed to get_patches: the first T1 volume in sorted order.
ref_img = nib.load(str(t1_files[0]))
# Common padded volume shape. With the patch size and stride above, each axis
# holds (size - patch) / stride + 1 windows: 5 x 6 x 5 = 150 patches per volume.
target_shape = (192, 224, 192)

# Fail fast on a configuration the sliding window cannot tile exactly; otherwise
# voxels at the far border of each volume would never be covered by a patch.
# (Assumes extract_3D_patches is a plain sliding window with no border handling
# of its own — TODO confirm.)
for dim_index, (dim_size, dim_patch, dim_stride) in enumerate(
    zip(target_shape, patch_size, stride)
):
    assert dim_size >= dim_patch and (dim_size - dim_patch) % dim_stride == 0, (
        f"axis {dim_index}: target size {dim_size} is not tiled by "
        f"patch {dim_patch} with stride {dim_stride}"
    )

train_t1, train_t2, train_t2_LR = get_patches(train, patch_size, stride, target_shape, ref_img)
val_t1, val_t2, val_t2_LR = get_patches(val, patch_size, stride, target_shape, ref_img)
test_t1, test_t2, test_t2_LR = get_patches(test, patch_size, stride, target_shape, ref_img)


In [None]:

# Manual, per-volume version of the patch pipeline for the training split:
# scale to ref_img -> pad to target_shape -> extract sliding-window patches.
# NOTE(review): this appears to duplicate get_patches(train, ...), whose results
# are already in train_t1 / train_t2 / train_t2_LR, and t1_input / t2_output /
# t2_LR_input are not used by any visible cell — consider deleting this cell once
# you have confirmed get_patches does the same thing.


def _volume_to_patches(path):
    """Load one NIfTI volume and return its 3D patches.

    The volume goes through ``scale_to_reference_img`` against ``ref_img``, is
    padded to ``target_shape`` with ``pad_to_shape``, and is then tiled by
    ``extract_3D_patches`` with ``patch_size`` / ``stride``. All of these come
    from the notebook's global configuration.
    """
    img = scale_to_reference_img(nib.load(path), ref_img)
    # Pad to the fixed target_shape so every volume gets the same patch grid.
    # (Not a multiple of the patch size: 224 is a multiple of the 32-voxel stride.)
    img = pad_to_shape(img, target_shape)
    return extract_3D_patches(img.get_fdata(), patch_size, stride)


t1_input = []
t2_output = []
t2_LR_input = []

for t1_file, t2_file, t2_LR_file in train:
    # Process all three modalities before appending, so a failure on any one file
    # leaves the three lists aligned.
    t1_patches = _volume_to_patches(t1_file)
    t2_patches = _volume_to_patches(t2_file)
    t2_LR_patches = _volume_to_patches(t2_LR_file)
    t1_input.append(t1_patches)
    t2_output.append(t2_patches)
    t2_LR_input.append(t2_LR_patches)


In [7]:
# Sanity-check the training patches produced by get_patches.
print(len(train_t1))  # number of training samples (volumes)
print(len(train_t1[0]))  # number of patches in the first training sample
print(train_t1[0][0].shape)  # shape of the first patch

# Turn the expectations into checks instead of comments: every patch must have
# the configured size, and every volume was padded to target_shape, so all
# volumes should yield the same number of patches.
assert tuple(train_t1[0][0].shape) == tuple(patch_size), (
    f"expected patch shape {tuple(patch_size)}, got {train_t1[0][0].shape}"
)
assert len({len(volume_patches) for volume_patches in train_t1}) == 1, (
    "training volumes yielded differing patch counts"
)

49
150
(64, 64, 64)


In [None]:
# NETWORK TRAINING

batch_size = 4
shuffle = True

# Flatten the per-volume patch lists into one flat list per modality, so the
# DataLoader samples individual patches rather than whole volumes.
input_1 = [patch for img_patches in train_t1 for patch in img_patches]  # T1 patches
input_2 = [patch for img_patches in train_t2_LR for patch in img_patches]  # low-res T2 patches
output = [patch for img_patches in train_t2 for patch in img_patches]  # T2 target patches
# The three lists are passed side by side to TrainDataset (presumably indexed
# together) — make sure they are aligned.
assert len(input_1) == len(input_2) == len(output), (
    f"patch count mismatch: {len(input_1)} T1, {len(input_2)} T2_LR, {len(output)} T2"
)

train_dataset = TrainDataset(input_1, input_2, output)
# Keyword arguments make explicit which DataLoader parameters these values bind to.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)

# TODO: instantiate the network here. The original line was an incomplete
# `net = `, which is a SyntaxError and stopped this whole cell from running.
# Candidates imported at the top of the notebook: UnetGenerator3D (project-local
# `model` module) or monai.networks.nets.UNet.
