In [1]:
import sys
from pathlib import Path

sys.path.append(str(Path.cwd().parent))

from torchsummary import summary

import matplotlib.pyplot as plt
from hydra import compose, initialize
from omegaconf import OmegaConf

from tqdm import tqdm
import random
import numpy as np

import matplotlib.patches as mpatches
from typing import List, Optional, Tuple

import torch
import hydra
import pyrootutils
import pytorch_lightning as pl
from omegaconf import DictConfig
from pytorch_lightning import (Callback, LightningDataModule, LightningModule,
                               Trainer)
from pytorch_lightning.loggers import TensorBoardLogger

# pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

import src.utils.default as utils

log = utils.get_pylogger(__name__)

torch.set_float32_matmul_precision('medium')
# torch.autograd.set_detect_anomaly(True)

from src.models.simclr import SimCLR

from src.models.unet3d.model_encoders import UNet3D as UNet3D_Encoder


import SimpleITK as sitk

sitk.ProcessObject_SetGlobalWarningDisplay(False)
test_path = '/mrhome/vladyslavz/Pictures/test_imgs'

In [2]:
segm_cfg = OmegaConf.load('/mrhome/vladyslavz/git/central-sulcus-analysis/sulci_segm_logs/CS1x_tversky_BVISA_SST_monai_PRETRAINED/runs/2023-04-10_16-04-53/.hydra/config.yaml')
CHKP = '/mrhome/vladyslavz/git/central-sulcus-analysis/sulci_segm_logs/CS1x_tversky_BVISA_SST_monai_PRETRAINED/runs/2023-04-10_16-04-53/checkpoints/epoch-146-Esubj-0.4199.ckpt'


segm_datamodule: LightningDataModule = hydra.utils.instantiate(segm_cfg.data)

segm_model: LightningModule = hydra.utils.instantiate(segm_cfg.model,
                                                      freeze_encoder=False,
                                                      monai=True)

print(segm_cfg.data)

2023-04-11 11:23:59,786 - Len of train examples 38 len of validation examples 12
Loading encoder weights from checkpoint...
/mrhome/vladyslavz/git/central-sulcus-analysis/sulci_segm_logs/sst-bvisa-1x-monai-BasicUnet/runs/2023-04-10_11-38-09/checkpoints/epoch-108_val_loss-0.078.ckpt
U-Net Embedding dimension: 458752


Attribute 'encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder'])`.
Attribute 'net' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['net'])`.
Attribute 'loss_function' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss_function'])`.


{'_target_': 'src.data.bvisa_augm_dm.CS_DataModule', 'dataset_cfg': {'dataset': 'bvisa', 'target': 'central_sulcus', 'dataset_path': '/mrhome/vladyslavz/git/central-sulcus-analysis/data/brainvisa_augm/nobackup/generated', 'use_half_brain': False, 'resample': None, 'crop2content': False, 'padd2same_size': '256-124-256'}, 'train_batch_size': 1, 'validation_batch_size': 1, 'num_workers': 1, 'double_validation': True}


In [3]:
segm_model = segm_model.load_from_checkpoint(CHKP,
                                             freeze_encoder=False,
                                             monai=True)

Loading encoder weights from checkpoint...
/mrhome/vladyslavz/git/central-sulcus-analysis/sulci_segm_logs/sst-bvisa-1x-monai-BasicUnet/runs/2023-04-10_11-38-09/checkpoints/epoch-108_val_loss-0.078.ckpt
U-Net Embedding dimension: 458752


In [4]:
idx = 0

val_sample = segm_datamodule.val_dataset[idx]

img = val_sample['image']
target = val_sample['target']
caseid = val_sample['caseid']
sitk_img = sitk.ReadImage(segm_datamodule.val_dataset.img_paths[idx][0])



In [5]:
with torch.no_grad():
    segm_pred = segm_model(img.unsqueeze(0))

In [11]:
sitk_img.GetSize()

(256, 256, 124)

In [12]:
segm_pred_sitk.GetSize()

(256, 124, 256)

In [13]:
segm_pred_bin = torch.softmax(segm_pred, dim=1)[:,:1,:,:,:].squeeze(0).squeeze(0).numpy()

segm_pred_sitk = sitk.GetImageFromArray(segm_pred_bin)
target_img = sitk.GetImageFromArray(target.numpy().astype(np.uint8))

# segm_pred_sitk.CopyInformation(sitk_img)
# target_img.CopyInformation(sitk_img)

In [19]:
sitk_img = sitk.GetImageFromArray(img[0].numpy().astype(np.float32))

In [20]:
sitk.WriteImage(segm_pred_sitk, f'{test_path}/segm_pred_sitk.nii.gz')
sitk.WriteImage(sitk_img, f'{test_path}/img.nii.gz')
sitk.WriteImage(target_img, f'{test_path}/target.nii.gz')

In [30]:
imgs = [x for x in Path('/mrhome/vladyslavz/git/SynthSeg/data/training_label_maps').glob('**/training_seg*.nii.gz')]

In [31]:
for i in tqdm(imgs):
    img = sitk.ReadImage(str(i))
    img_array = sitk.GetArrayFromImage(img)

    img_array = (img_array ==24).astype(np.int16)
    img_array_csf = sitk.GetImageFromArray(img_array)
    img_array_csf.CopyInformation(img)
    
    out = str(i).replace('training_seg', 'csf_mask')
    sitk.WriteImage(img_array_csf, out)

100%|██████████| 20/20 [00:03<00:00,  5.54it/s]


In [28]:
out

'/mrhome/vladyslavz/git/central-sulcus-analysis/data/synthseg_corrected/nobackup/training_seg_20/csf_mask_99.nii.gz'

In [29]:
img_array.sum()

0

In [35]:
from src.models.UNet3D import  BasicUNet3D

In [32]:
simclr = SimCLR.load_from_checkpoint('/mrhome/vladyslavz/git/central-sulcus-analysis/sulci_segm_logs/sst-bvisa-1x-monai-BasicUnet/runs/2023-04-10_11-38-09/checkpoints/epoch-108_val_loss-0.078.ckpt')

Attribute 'encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder'])`.


U-Net Embedding dimension: 458752


In [36]:
unet = BasicUNet3D.load_from_checkpoint('/mrhome/vladyslavz/git/central-sulcus-analysis/sulci_segm_logs/CS1x_tversky_BVISA_SST_monai_PRETRAINED_frozenENCODER_orientCOrrect/runs/2023-04-11_10-38-38/checkpoints/epoch-280-Esubj-0.4681.ckpt')

Loading encoder weights from checkpoint...
/mrhome/vladyslavz/git/central-sulcus-analysis/sulci_segm_logs/sst-bvisa-1x-monai-BasicUnet/runs/2023-04-10_11-38-09/checkpoints/epoch-108_val_loss-0.078.ckpt


Attribute 'encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder'])`.


U-Net Embedding dimension: 458752
Freezing encoder...


Attribute 'net' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['net'])`.
Attribute 'loss_function' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss_function'])`.


In [37]:
simclr.encoder.conv_0 = unet.net.conv_0 
simclr.encoder.down_1 = unet.net.down_1 
simclr.encoder.down_2 = unet.net.down_2 
simclr.encoder.down_3 = unet.net.down_3 
simclr.encoder.down_4 = unet.net.down_4 

In [41]:

trainer = Trainer()
trainer.fit(simclr)

The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /mrhome/vladyslavz/anaconda3/envs/css/lib/python3.10 ...
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
GPU available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='gpu', devices=1)`.


ValueError: An invalid dataloader was passed to `Trainer.fit(train_dataloaders=...)`. Either pass the dataloader to the `.fit()` method OR implement `def train_dataloader(self):` in your LightningModule/LightningDataModule.

In [None]:
simclr.save()