In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os

# Allow more than one Intel OpenMP runtime to be loaded in this process.
# Presumably a workaround for the "OMP: Error #15 ... already initialized" abort
# when torch and another library each bundle their own OpenMP — confirm.
# Set before `import torch` below so the runtime sees it at load time.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})

In [None]:
import torch

In [None]:
# List every visible CUDA device together with its index.
for gpu_idx in range(torch.cuda.device_count()):
    gpu_name = torch.cuda.get_device_properties(gpu_idx).name
    print(f"{gpu_idx}:", gpu_name)

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
    print(torch.cuda.get_device_name(device))

In [None]:
from pathlib import Path
from tqdm.auto import tqdm
from objprint import objstr
from datetime import datetime

In [None]:
from src import utils
from src.utils import Logger, same_seeds, load_config
from src.loader import get_dataloader
from src.SlimUNETR.SlimUNETR import SlimUNETR

from accelerate import Accelerator

In [None]:
# load_config() returns a (config, data_flag, is_HepaticVessel) triple;
# only config and data_flag are used in the cells below.
config, data_flag, is_HepaticVessel = load_config()
# Override the configured batch size for this interactive session.
config.trainer.batch_size = 3
data_flag

In [None]:
# Seed the RNGs from the config so the session is repeatable.
same_seeds(config.trainer.seed)

# One log directory per session. Build the timestamp explicitly so the
# directory name contains no spaces, colons or dots, which str(datetime.now())
# would produce (e.g. "2024-01-01 12:00:00.123456").
# Microseconds are kept so two quick re-runs still get distinct directories.
run_stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_%f")
logging_dir = Path.cwd() / "logs" / run_stamp

accelerator = Accelerator(
    cpu=False, log_with=["tensorboard"], project_dir=str(logging_dir)
)
# Only the local main process gets a log directory; the return value is unused,
# so Logger presumably configures logging as a side effect — confirm in src.utils.
Logger(logging_dir if accelerator.is_local_main_process else None)
accelerator.init_trackers('main')
accelerator.print(objstr(config))

accelerator.print("Load Model...")
model = SlimUNETR(**config.slim_unetr)
# Uses the notebook-level `device` (not accelerator.device) so the model and
# the inputs moved in the prediction cell are guaranteed to match.
model.to(device)
image_size = config.trainer.image_size

accelerator.print("Load Dataloader...")
train_loader, val_loader, unlab_loader = get_dataloader(config, data_flag, needs_unlab=True)

In [None]:
# Experiment directory whose checkpoints are restored below.
# The default is an absolute, machine-specific path. On other machines, set the
# SLIM_UNETR_EXP_DIR environment variable instead of editing this cell.
DEFAULT_EXP_DIR = 'D:\\Study\\Аспирантура\\experiments\\Slim-unetr\\tbad_dataset_unlab_stages_with_tflab\\seed25\\epoch800\\use_tfTrue\\ims_128_rot_prob0.8_lrelu_split_new_class_GDFL_g2.0_fr08_fw080915_unlab_ratio0.5_unlab_weight0.3_start_unlab_epoch200'
base_exp_path_save = Path(os.environ.get("SLIM_UNETR_EXP_DIR", DEFAULT_EXP_DIR))


In [None]:
# Restore model weights and training counters from the experiment directory.
# epoch=-1 presumably selects the most recent checkpoint — confirm in
# src.utils.resume_train_state.
model, starting_epoch, step, val_step = utils.resume_train_state(
    model,
    base_exp_path_save,
    train_loader,
    accelerator,
    epoch=-1,
)
print(f"Resuming training from epoch {starting_epoch}")

In [None]:
# Display the starting epoch returned by resume_train_state.
starting_epoch

In [None]:
# Take a single validation batch for visual inspection.
# next(iter(...)) replaces the enumerate/tqdm loop that broke on its first
# iteration; an empty loader now fails here with a clear message instead of
# leaving image_batch undefined for later cells.
image_batch = next(iter(val_loader), None)
if image_batch is None:
    raise RuntimeError("val_loader yielded no batches")

In [None]:
image = image_batch["image"]
label = image_batch["label"]

# Inference only. eval() puts any dropout/normalization layers into inference
# mode, and no_grad() avoids building an autograd graph, which saves GPU memory.
model.eval()
with torch.no_grad():
    logits = model(image.to(device))
# Sigmoid on the CPU copy; assumes the model outputs raw logits — confirm.
pred = logits.cpu().sigmoid()

In [None]:
# Shape of the image batch (torch.Size).
image.size()

In [None]:
# Shape of the label batch (torch.Size).
label.size()

In [None]:
import numpy as np
from matplotlib import pyplot as plt

from lab_unlab_trainer import Transforms

In [None]:
# Augmentation with both probabilities set to 1, presumably so flip and
# rotation are always applied and the inverse round trip is exercised.
transform = Transforms(
    rot_range_z=np.pi * 0.4,  # z-rotation range parameter (radians) — semantics in Transforms
    rot_prob=1,
    flip_prob=1,
)

In [None]:
# Apply the random augmentation to the image batch, then to the label with
# randomize=False. This presumably replays the same sampled flip/rotation so
# image and label stay aligned — confirm in Transforms.
image_tf = transform(image)
label_tf = transform(label, randomize=False)

In [None]:
# Undo the augmentation on the image. If the inverse is correct, the result
# should line up with the untransformed `label`.
image_tf_inv = transform.inverse(image_tf)

In [None]:
# Visual check of the augmentation: overlay the transformed label (at half
# intensity) on the transformed image for a grid of slices taken along the
# last axis (presumably depth — confirm), channel 0.
ind = 1          # sample index within the batch
slice_num = 2    # number of grid rows
n_cols = 4       # slices shown per row
row_stride = 8   # slice offset between consecutive rows

last_slice = (slice_num - 1) * row_stride + n_cols - 1
assert last_slice < image_tf[ind].shape[-1], "slice grid exceeds volume depth"

# squeeze=False keeps `ax` 2-D so ax[row, col] also works when slice_num == 1.
fig, ax = plt.subplots(
    slice_num, n_cols, figsize=(10 * n_cols, 10 * slice_num), squeeze=False
)
for row in range(slice_num):
    for col in range(n_cols):
        z = col + row * row_stride
        ax[row, col].imshow(
            image_tf[ind][0, :, :, z] + label_tf[ind][0, :, :, z] / 2,
            cmap="gray",
        )
        ax[row, col].set_title(f"slice {z}")
        ax[row, col].axis('off')
fig.suptitle(f"Augmented sample {ind}: image + label / 2")
plt.show()
In [None]:
# Visual check of the inverse transform: overlay the ORIGINAL (untransformed)
# label on the inverse-transformed image. The masks should line up if
# transform.inverse is correct. Slices are taken along the last axis, channel 0.
ind = 1          # sample index within the batch
slice_num = 2    # number of grid rows
n_cols = 4       # slices shown per row
row_stride = 8   # slice offset between consecutive rows

last_slice = (slice_num - 1) * row_stride + n_cols - 1
assert last_slice < image_tf_inv[ind].shape[-1], "slice grid exceeds volume depth"

# squeeze=False keeps `ax` 2-D so ax[row, col] also works when slice_num == 1.
fig, ax = plt.subplots(
    slice_num, n_cols, figsize=(10 * n_cols, 10 * slice_num), squeeze=False
)
for row in range(slice_num):
    for col in range(n_cols):
        z = col + row * row_stride
        ax[row, col].imshow(
            image_tf_inv[ind][0, :, :, z] + label[ind][0, :, :, z] / 2,
            cmap="gray",
        )
        ax[row, col].set_title(f"slice {z}")
        ax[row, col].axis('off')
fig.suptitle(f"Inverse-transformed sample {ind}: image + original label / 2")
plt.show()