In [2]:
import numpy as np
import nibabel as nib
from pathlib import Path
import json 
from PIL import Image
from random import randint
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import math
import torch.nn.functional as F
import copy
from monai.losses import DiceCELoss
from monai.networks.nets import UNETR
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
)
from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)
from tqdm import tqdm
import os
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
import pickle

  from torch.distributed.optim import ZeroRedundancyOptimizer


In [4]:
def get_default_device():
    """Return the torch device to run on: CUDA when torch sees a GPU, otherwise CPU."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)


def to_device(data, device):
    """Move ``data`` to ``device``.

    ``data`` is either a single object with a ``.to`` method (e.g. a tensor) or a
    (possibly nested) list/tuple of such objects. Containers are rebuilt as
    lists; each leaf is moved with ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return list(map(lambda item: to_device(item, device), data))


class DeviceDataloader():
    """Iterable wrapper that moves every batch from an underlying loader to a device."""

    def __init__(self, dl, device):
        # The wrapped loader and the target device; both stay public attributes.
        self.dl, self.device = dl, device

    def __iter__(self):
        # Generator function: nothing runs until the first batch is requested.
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        # Number of batches, delegated to the wrapped loader.
        return len(self.dl)


In [5]:

                                                                ####################################################
                                                                ######## Reading data from my local PC #############





masks_path = Path(r'C:\working_space\pilsen_pigs_2023_cvat_backup\masks_for_all_organs')
data_path = Path(r'C:\working_space\pilsen_pigs_2023_cvat_backup\workspase')


def _numeric_subdirs(root):
    """Return the names of the purely numeric sub-directories of ``root``, in numeric order."""
    return sorted(
        (d.name for d in root.iterdir() if d.is_dir() and d.name.isdigit()),
        key=int,
    )


def _load_slices(folder, mode=None):
    """Stack every file in ``folder`` into one array, one entry per file.

    Files are read in sorted file-name order, because ``Path.iterdir`` gives no
    ordering guarantee and slice order matters for a 3D volume.
    NOTE(review): name order equals slice order only if slice numbers are
    zero-padded in the file names — confirm for this dataset.

    Args:
        folder: directory containing one image file per slice.
        mode: optional PIL mode to convert each slice to (e.g. 'L'); None keeps the stored mode.

    Returns:
        np.ndarray with the slices stacked along a new first axis.
    """
    slices = []
    for slice_file in sorted(folder.iterdir()):
        # The context manager closes the underlying file handle once the pixels are copied out.
        with Image.open(slice_file) as img:
            if mode is not None:
                img = img.convert(mode)
            slices.append(np.array(img))
    return np.array(slices)


sorted_pictures = _numeric_subdirs(data_path)
sorted_masks = _numeric_subdirs(masks_path)

# Image slices live in <data_path>/<id>/data and are converted to 8-bit grayscale ('L');
# mask slices live in <masks_path>/<id> and are read as stored.
# Paths are built from the roots above so `masks_path` is not overwritten by the loop.
images_3d_full = [_load_slices(data_path / name / 'data', mode='L') for name in sorted_pictures]
print("Shape of the first 3D image: ", images_3d_full[0].shape)

masks_3d_full = [_load_slices(masks_path / name) for name in sorted_masks]
print("Shape of the first 3D mask", masks_3d_full[0].shape)
 


Shape of the first 3D image:  (882, 512, 512)
Shape of the first 3D mask (882, 512, 512)


In [11]:
                                                                ########################################################
                                                                #Processing data from my local PC and using dataloaders#


class DataTrain(Dataset):
    """Map-style training dataset over pre-stacked samples and labels.

    Both inputs are copied into float32 tensors at construction time. Each item
    is a ``(sample, label)`` pair with a singleton channel axis prepended.
    """

    def __init__(self, data, annotation):
        self.traininputtensor = torch.tensor(data, dtype=torch.float)
        self.output = torch.tensor(annotation, dtype=torch.float)

    def __len__(self):
        # Number of samples = size of the first axis.
        return self.traininputtensor.size(dim=0)

    def __getitem__(self, index):
        # `[None, ...]` adds the leading channel dimension (same as unsqueeze(0)).
        sample = self.traininputtensor[index][None, ...]
        label = self.output[index][None, ...]
        return sample, label


class DataTest(Dataset):
    """Map-style evaluation dataset; same item layout as the training dataset.

    Samples and labels are held as float32 tensors; indexing returns
    ``(sample, label)`` each with a leading channel axis of size 1.
    """

    def __init__(self, data, annotation):
        self.testinputtensor = torch.tensor(data, dtype=torch.float)
        self.output = torch.tensor(annotation, dtype=torch.float)

    def __getitem__(self, index):
        pair = (self.testinputtensor[index], self.output[index])
        # Prepend the channel axis to both members of the pair.
        return tuple(t[None, ...] for t in pair)

    def __len__(self):
        return self.testinputtensor.size(dim=0)


def reshape_mask(mask, num_clases=8, depth=100, label_step=25):
    """Convert a label volume into per-class binary channels.

    Every non-zero voxel with value ``v`` sets a 1 in channel ``v // label_step``
    (with the default step: 25 -> channel 1, 50 -> channel 2, ...). Zero voxels are
    skipped.
    NOTE(review): because zeros are skipped, channel 0 (background) stays empty, so the
    result is not a complete one-hot encoding — confirm this is what the loss expects.

    Args:
        mask: integer label volume of shape (D, H, W), or (1, D, H, W) with a leading
            singleton axis. Only the first ``depth`` slices are encoded.
        num_clases: number of output channels.
        depth: number of slices along the first axis to encode.
        label_step: spacing between consecutive class values in ``mask``.

    Returns:
        np.uint8 array of shape (num_clases, depth, H, W).

    Raises:
        ValueError: if ``mask`` is not 3-D (after dropping a leading singleton axis), has
            fewer than ``depth`` slices, or contains a value mapping outside [0, num_clases).
    """
    mask = np.asarray(mask)
    if mask.ndim == 4 and mask.shape[0] == 1:
        mask = mask[0]
    if mask.ndim != 3:
        raise ValueError(f"expected a (D, H, W) mask, got shape {mask.shape}")
    if mask.shape[0] < depth:
        raise ValueError(f"mask has {mask.shape[0]} slices, fewer than depth={depth}")

    volume = mask[:depth]
    new_mask = np.zeros((num_clases, depth) + volume.shape[1:], dtype=np.uint8)

    # Vectorised replacement for the per-voxel loop: coordinates of all labelled voxels.
    z, x, y = np.nonzero(volume)
    # Integer division gives a usable channel index (true division would yield floats).
    classes = volume[z, x, y].astype(np.int64) // label_step
    if classes.size and (classes.min() < 0 or classes.max() >= num_clases):
        raise ValueError(
            f"mask values map to classes {classes.min()}..{classes.max()}, "
            f"outside [0, {num_clases})"
        )
    new_mask[classes, z, x, y] = 1
    return new_mask



def cut_data(data, z_shape=100):
    """Split a volume into consecutive, non-overlapping chunks along its first axis.

    Metacentrum does not have enough memory to process one full image, so the volume
    is cut into pieces. Trailing slices that do not fill a whole chunk are dropped,
    because the network needs every input to have the same depth.

    Args:
        data: array of shape (Z, ...); any number of trailing axes is allowed.
        z_shape: chunk depth along the first axis; must be positive.

    Returns:
        A new array (a copy, not a view of ``data``) of shape
        (Z // z_shape, z_shape, ...) with ``data``'s dtype. It is empty along the
        first axis when Z < z_shape.

    Raises:
        ValueError: if ``z_shape`` is not positive.
    """
    data = np.asarray(data)
    if z_shape <= 0:
        raise ValueError(f"z_shape must be positive, got {z_shape}")
    n_chunks = data.shape[0] // z_shape
    usable = data[:n_chunks * z_shape]  # drop the incomplete tail chunk
    return usable.reshape((n_chunks, z_shape) + data.shape[1:]).copy()


# Experiment with one 3D picture: cut the first volume and its mask into
# equal-depth chunks, then convert the first mask chunk to per-class channels.
one_3d_picture, one_3d_mask = images_3d_full[0], masks_3d_full[0]

cutted_picture, cutted_mask = (cut_data(volume) for volume in (one_3d_picture, one_3d_mask))

reshaped_mask = reshape_mask(cutted_mask[0])
print("Shape for changed mask: ", reshaped_mask.shape)



Shape for changed mask:  (8, 100, 512, 512)


In [10]:


device = get_default_device()

# Target model for the 8-class task. It is disabled; it is kept as comments rather
# than as a bare string literal, which Jupyter would echo as cell output.
# NOTE(review): UNETR splits the input into 16^3 patches; a depth of 100 is not a
# multiple of 16, which MONAI may reject with pos_embed="perceptron" — confirm
# before enabling.
# model = UNETR(
#     in_channels=1,
#     out_channels=8,
#     img_size=(100, 512, 512),
#     feature_size=16,
#     hidden_size=768,
#     mlp_dim=3072,
#     num_heads=12,
#     pos_embed="perceptron",
#     norm_name="instance",
#     res_block=True,
#     dropout_rate=0.0,
# ).to(device)


# Donor network: a 14-output UNETR on 96^3 inputs, used only as a source of weights.
modelwise = UNETR(   #just_to_copy_weights
    in_channels=1,
    out_channels=14,
    img_size=(96, 96, 96),
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    pos_embed="perceptron",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
).to(device)


def copy_matching_weights(src_model, dst_model):
    """Copy parameters from ``src_model`` into ``dst_model`` wherever name and shape agree.

    Parameters of ``dst_model`` with no same-named entry in ``src_model``'s state
    dict are left untouched. Same-named entries whose shapes differ are skipped
    with a printed message.

    Returns:
        list[str]: names of the parameters skipped because of a shape mismatch.
    """
    src_state = src_model.state_dict()
    skipped = []
    # In-place writes into leaf parameters must not be recorded by autograd.
    with torch.no_grad():
        for name, param_dst in dst_model.named_parameters():
            param_src = src_state.get(name)
            if param_src is None:
                continue
            if param_src.size() == param_dst.size():
                param_dst.copy_(param_src)
            else:
                print(f"Skipping layer {name} due to size mismatch")
                skipped.append(name)
    return skipped


# Weight transfer, disabled until `model` above is enabled. The earlier draft took
# `param_src` from `model.state_dict()` (copying the model into itself); the helper
# above reads from the donor's state dict instead.
# NOTE(review): torch.load unpickles the checkpoint; load only trusted files
# (recent torch versions accept weights_only=True).
# modelwise.load_state_dict(torch.load(os.path.join(r"/storage/brno2/home/yauheni", "best_metric_model.pth")))
# copy_matching_weights(modelwise, model)


' modelwise.load_state_dict(torch.load(os.path.join(r"/storage/brno2/home/yauheni", "best_metric_model.pth")))\nmodel_state_dict1 = modelwise.state_dict()\nfor name_dst, param_dst in model.named_parameters():\n    if name_dst in modelwise.state_dict():\n        param_src = model.state_dict()[name_dst]\n        if param_src.size() == param_dst.size():\n            param_dst.data.copy_(param_src.data)\n        else:\n            print(f"Skipping layer {name_dst} due to size mismatch") '

In [12]:
# Smoke test: push one random single-channel 96^3 volume through the donor model.
# The input is created on the model's device (a CPU tensor fed to a CUDA model
# raises a device-mismatch error), and no_grad avoids storing activations for backprop.
tensor = torch.randn(1, 1, 96, 96, 96, device=device)  # now we know shape of tensor we want
with torch.no_grad():
    out = modelwise(tensor)
out.shape

torch.Size([1, 14, 96, 96, 96])