In [None]:
import numpy as np
import nibabel as nib
from pathlib import Path
import json 
from PIL import Image
from random import randint
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import math
import torch.nn.functional as F
import copy
from monai.losses import DiceCELoss
from monai.networks.nets import UNETR
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
)
from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)
from tqdm import tqdm
import os
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
import pickle
import config

In [None]:
def get_default_device():
    """Return the CUDA device when PyTorch reports one available, else the CPU device."""
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(backend)


def to_device(data, device):
    """Move ``data`` onto ``device``.

    A list or tuple is processed element by element (recursively) and the
    result is always returned as a list. Anything else is moved with its own
    ``.to(device, non_blocking=True)`` method.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)

    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved


class DeviceDataloader:
    """Wrap an iterable of batches so each batch is moved to ``device`` as it is yielded."""

    def __init__(self, dl, device):
        # Underlying batch source and the device every batch is sent to.
        self.dl = dl
        self.device = device

    def __iter__(self):
        # Still a generator: the wrapped loader is not touched until the first batch is requested.
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        # Number of batches, as reported by the wrapped loader.
        return len(self.dl)


In [None]:
class DataTrain(Dataset):
    """Training dataset holding inputs and annotations as float32 tensors.

    Samples are indexed along the first axis of ``data`` / ``annotation``.
    """

    def __init__(self, data, annotation):
        # Both arrays are copied into float32 tensors once, up front.
        self.traininputtensor = torch.tensor(data, dtype=torch.float)
        self.output = torch.tensor(annotation, dtype=torch.float)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, each with a new leading axis of size 1."""
        image = self.traininputtensor[index]
        label = self.output[index]
        return torch.unsqueeze(image, 0), torch.unsqueeze(label, 0)

    def __len__(self):
        """Number of samples, i.e. the length of the first axis of the input tensor."""
        return self.traininputtensor.shape[0]


class DataTest(Dataset):
    """Evaluation dataset holding inputs and annotations as float32 tensors.

    Samples are indexed along the first axis of ``data`` / ``annotation``.
    """

    def __init__(self, data, annotation):
        # Both arrays are copied into float32 tensors once, up front.
        self.testinputtensor = torch.tensor(data, dtype=torch.float)
        self.output = torch.tensor(annotation, dtype=torch.float)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, each with a new leading axis of size 1."""
        image = self.testinputtensor[index]
        label = self.output[index]
        return torch.unsqueeze(image, 0), torch.unsqueeze(label, 0)

    def __len__(self):
        """Number of samples, i.e. the length of the first axis of the input tensor."""
        return self.testinputtensor.shape[0]


def reshape_mask(mask,  gray_step=25, max_gray=255):
    """Placeholder: the body of this function is missing from the notebook.

    A ``def`` line with no body is a SyntaxError, which stopped every other
    definition in this cell from being created. Until the real logic is
    restored, the function fails loudly instead of breaking the cell.

    NOTE(review): the parameter names suggest a conversion between gray levels
    (multiples of ``gray_step`` up to ``max_gray``) and label values, but that
    is a guess - restore the original implementation.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("reshape_mask has no implementation in this notebook")




def cut_data(data, annotation, cut_parameter):
    """Split ``data`` and ``annotation`` into consecutive chunks of ``cut_parameter`` items.

    Chunk boundaries are taken from ``len(data)`` and the same slices are
    applied to ``annotation``. The last chunk is shorter when the length is
    not a multiple of ``cut_parameter``.

    Args:
        data: sliceable sequence supporting ``len()``.
        annotation: sliceable sequence, sliced with the same bounds as ``data``.
        cut_parameter: positive integer chunk length.

    Returns:
        ``[data_chunks, annotation_chunks]``, two lists of slices.

    Raises:
        ValueError: if ``cut_parameter`` is not positive.
    """
    if cut_parameter <= 0:
        raise ValueError("cut_parameter must be a positive integer")

    # range() takes no keyword arguments, so start/stop/step are passed positionally.
    starts = range(0, len(data), cut_parameter)
    final_set_data = [data[i:i + cut_parameter] for i in starts]
    final_set_annotation = [annotation[i:i + cut_parameter] for i in starts]
    return [final_set_data, final_set_annotation]

In [None]:


# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = get_default_device()

# Target network: single-channel input, 8 output channels.
# NOTE(review): 100 is not a multiple of 16. MONAI's UNETR uses 16-voxel patches
# and, with pos_embed="perceptron", appears to require every img_size dimension
# to be divisible by the patch size - confirm this constructs on the installed
# MONAI version.
model = UNETR(
    in_channels=1,
    out_channels=8,
    img_size=(100, 512, 512),
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    pos_embed="perceptron",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
)
# Module.to() moves the parameters in place and returns the same module.
model = model.to(device)



# Donor network, built only so the checkpoint can be loaded and its weights
# copied into `model`. Its configuration presumably matches the one the
# checkpoint was trained with; load_state_dict would fail otherwise.
modelwise = UNETR(
    in_channels=1,
    out_channels=14,
    img_size=(96, 96, 96),
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    pos_embed="perceptron",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
).to(device)

# NOTE(review): hard-coded absolute path; consider moving it into configuration.
checkpoint_path = os.path.join(r"/storage/brno2/home/yauheni", "best_metric_model.pth")
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
modelwise.load_state_dict(torch.load(checkpoint_path, map_location=device))
model_state_dict1 = modelwise.state_dict()

# Copy every pretrained tensor whose name and shape match a parameter of `model`.
# The source must be the donor's state dict: the previous code read it from
# `model` itself, so each parameter was copied onto itself and nothing transferred.
with torch.no_grad():
    for name_dst, param_dst in model.named_parameters():
        param_src = model_state_dict1.get(name_dst)
        if param_src is None:
            # No tensor of this name in the checkpoint: keep the fresh initialisation.
            continue
        if param_src.size() == param_dst.size():
            param_dst.copy_(param_src)
        else:
            print(f"Skipping layer {name_dst} due to size mismatch")
