In [1]:
import yaml
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import torch
# Run on the GPU when one is visible to this kernel, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Sanity check: number of GPUs visible to this process (the expectation here is 2).
# Note this reports visible devices, not how many are actually in use.
print(torch.cuda.device_count())

2


### Load the model

In [2]:
import os
# The training repository must be the working directory so that `train` is importable.
os.chdir("/home/ws/er5241/Repos/training-repo/")

from train import get_model, load_weights

model_params = {'network_type': 'Unet', 'encoder_name': 'resnet18', 'classes': 10, 'in_channels': 1}
model = get_model(**model_params).to(device)
# Evaluation only: switch layers such as BatchNorm/Dropout to inference behaviour.
# Equivalent to model.train(False).
model.eval()
print(f'model in training: {model.training}')

# Checkpoints that can be evaluated in this notebook (both evaluated with batch size 256).
CHECKPOINTS = {
    # initial brain model (teacher model)
    'teacher': '/home/ws/er5241/logs/medaka-supervised/2021-09-08_23-31-03/logdir/checkpoints/last.pth',
    # model trained on enlarged dataset (student model)
    'student': '/home/ws/er5241/logs_old/medaka-supervised/2022-05-03_11-15-57/logdir/checkpoints/last.pth',
}
# Select which checkpoint to load; set to 'student' to evaluate the student model instead.
CHECKPOINT_NAME = 'teacher'

chkpnt = {'checkpoint_addr': CHECKPOINTS[CHECKPOINT_NAME], 'checkpoint_key': 'model_state_dict'}
load_weights(model, **chkpnt)

>>>>>>>>>>>>>>>> initializing model
trainable parameters in model: 14323242
<<<<<<<<<<<<<<<< done in 0.27 sec.
model in training: False
>>>>>>>>>>>>>>>> loading checkpoint
<<<<<<<<<<<<<<<< done in 0.059 sec.


### Create the configuration used to build the data loaders

In [3]:
# Loader configuration consumed by `src.loaders.generic_loaders`.
# Layout mirrors the nesting exactly: 'seed' belongs to the gatherer entry,
# not to 'matcher_kwargs'.
cfg = {
    'gatherers': [
        {
            # '{}' placeholders are presumably filled with a per-sample id by the matcher — TODO confirm
            'matcher_kwargs': {
                'targets': '/mnt/HD-LSDF/Medaka/segmentations/workshop/brain_decropped/{}.tif',
                'volumes': '/mnt/HD-LSDF/Medaka/segmentations/workshop/medaka_decropped/{}.tif',
            },
            'seed': 42,
        },
    ],
    'datasets': [
        # Settings shared by every split.
        {
            'dataset_names': 'default',
            'dataset_kwargs': {'mode_3d': False, 'use_ram': True, 'localised_crop': True},
        },
        {
            'dataset_names': 'train',
            'aug_name': 'medium_aug_rot',
            'dataset_kwargs': {'crop_size': 256},
            'dataset_rebalance_function_name': 'TVSD_dataset_resample',
            'dataloader_kwargs': {'batch_size': 64},
        },
        # NOTE(review): the printed loader kwargs show shuffle=True and drop_last=True
        # for this split too, so validation metrics may skip a random remainder of
        # samples each run — consider overriding if generic_loaders allows it.
        {
            'dataset_names': 'valid',
            'dataset_kwargs': {'crop_size': 512},
            'dataloader_kwargs': {'batch_size': 32},
        },
    ],
}

In [4]:
import os

# Location of the training repository; `src.*` is imported relative to the
# working directory (assumes the cwd is on sys.path, as in a default kernel — TODO confirm).
REPO_ROOT = '/home/ws/er5241/Repos/training-repo'
os.chdir(REPO_ROOT)

import src.loaders
import src.callbacks

In [5]:
# Single source of truth for the loader seed; bind it before it is used.
seed = 42

# Build the loaders described by `cfg`.
loaders = src.loaders.generic_loaders(**cfg, seed=seed)

# presumably loaders[0] maps split name -> DataLoader — TODO confirm against generic_loaders
train_loader = loaders[0]['train']
val_loader = loaders[0]['valid']


getting TVSD datasets:   0%|          | 0/23 [00:00<?, ?it/s]

resampling TVSD datasets:   0%|          | 0/23 [00:00<?, ?it/s]

{'drop_last': True, 'shuffle': True, 'num_workers': 16, 'pin_memory': True, 'batch_size': 64}


getting TVSD datasets:   0%|          | 0/8 [00:00<?, ?it/s]

{'drop_last': True, 'shuffle': True, 'num_workers': 16, 'pin_memory': True, 'batch_size': 32}


In [6]:
# Run the loaded model over the validation split and collect one IoU value per batch.
# The metric is only meaningful with the model in evaluation mode.
assert not model.training, "call model.eval() before computing validation metrics"

ious = []

# Inference only: disabling autograd avoids building a graph for every batch,
# which otherwise wastes GPU memory for no benefit.
with torch.no_grad():
    for img_batch, lbl_batch in tqdm(val_loader, desc='validating'):
        pred_batch = model(img_batch.to(device))
        # presumably a per-batch scalar IoU (the mean below reads as a percentage) — TODO confirm
        iou_vals = src.callbacks.get_iou(pred_batch, lbl_batch.to(device))
        ious.append(iou_vals)

  0%|          | 0/507 [00:01<?, ?it/s]

In [7]:
# Average of the per-batch IoU values collected above (displayed as the cell output).
np.asarray(ious).mean()

83.415422862889