In [1]:
# This mounts your Google Drive to the Colab VM.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

### Change to current directory ###
path ='/content/drive/MyDrive/Deep_learning/project/TransUnet_copy/'
transunet_model_path = '/content/drive/MyDrive/Deep_learning/project/trained_models/TransUNET_model - 20 epochs'
dino_model_path = '/content/drive/MyDrive/Deep_learning/project/trained_models/DINO_TransUNET_model - 20 epochs'


Mounted at /content/drive


In [2]:
%cd drive/MyDrive/Deep_learning/project/TransUnet_copy/

# Installs
!pip install ml_collections medpy

# Imports
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.modules.loss import CrossEntropyLoss
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
from tqdm import tqdm
from utils.utils import DiceLoss
from torchvision import transforms
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
from datasets.dataset_sartorius import Sartorius_dataset, RandomGenerator
from google.colab.patches import cv2_imshow

/content/drive/MyDrive/Deep_learning/project/TransUnet_copy
Collecting ml_collections
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
[K     |████████████████████████████████| 77 kB 3.2 MB/s 
[?25hCollecting medpy
  Downloading MedPy-0.4.0.tar.gz (151 kB)
[K     |████████████████████████████████| 151 kB 21.1 MB/s 
Collecting SimpleITK>=1.1.0
  Downloading SimpleITK-2.1.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (48.4 MB)
[K     |████████████████████████████████| 48.4 MB 1.9 MB/s 
[?25hBuilding wheels for collected packages: ml-collections, medpy
  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone
  Created wheel for ml-collections: filename=ml_collections-0.1.1-py3-none-any.whl size=94524 sha256=2512e6eaf2bb1803d31e1996616b0dfadd7d468629509efea672897fb7d23f86
  Stored in directory: /root/.cache/pip/wheels/b7/da/64/33c926a1b10ff19791081b705879561b715a8341a856a3bbd2
  Building wheel for medpy (setup.py) ... [?25l[?25hdone
  Created wheel for medp

In [3]:
# Dataset Definitions
train_base_dir = '/content/drive/MyDrive/Deep_learning/project/data/train_npz'
test_base_dir = '/content/drive/MyDrive/Deep_learning/project/data/test_npz'
list_dir = path + 'lists/lists_Sartorius/'
img_size = 224    # original image size [520, 704]
rand_seed = 1234
num_classes = 2
n_skip = 3
vit_name = 'R50-ViT-B_16'
vit_patches_size = 16
z_spacing = 1
deterministic = True



db_train = Sartorius_dataset(base_dir=train_base_dir, list_dir=list_dir, split="train",transform=transforms.Compose([RandomGenerator(output_size=[img_size, img_size])]))

db_test = Sartorius_dataset(base_dir=test_base_dir, split="test", list_dir=list_dir,transform=transforms.Compose([RandomGenerator(output_size=[img_size, img_size])]))

def worker_init_fn(worker_id):
    random.seed(rand_seed + worker_id)


In [4]:
# Set random seeds for reproducibility
random.seed(rand_seed)
np.random.seed(rand_seed)
torch.manual_seed(rand_seed)
torch.cuda.manual_seed(rand_seed)

# Set model Configuration
config_vit = CONFIGS_ViT_seg[vit_name]
config_vit.n_classes = num_classes
config_vit.n_skip = n_skip
if vit_name.find('R50') != -1:
    config_vit.patches.grid = (int(img_size / vit_patches_size), int(img_size / vit_patches_size))

# Load models
dino_model = torch.load(dino_model_path)
#transunet_model = torch.load(transunet_model_path) 

dino_model.train()
#transunet_model.eval()

VisionTransformer(
  (transformer): Transformer(
    (embeddings): Embeddings(
      (hybrid_model): ResNetV2(
        (root): Sequential(
          (conv): StdConv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
          (gn): GroupNorm(32, 64, eps=1e-06, affine=True)
          (relu): ReLU(inplace=True)
        )
        (body): Sequential(
          (block1): Sequential(
            (unit1): PreActBottleneck(
              (gn1): GroupNorm(32, 64, eps=1e-06, affine=True)
              (conv1): StdConv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (gn2): GroupNorm(32, 64, eps=1e-06, affine=True)
              (conv2): StdConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (gn3): GroupNorm(32, 256, eps=1e-06, affine=True)
              (conv3): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (relu): ReLU(inplace=True)
              (downsample): StdConv2d(64, 256, 

In [5]:
checkloader = DataLoader(db_train, batch_size=25, shuffle=True, num_workers=1, pin_memory=True,
                        worker_init_fn=worker_init_fn)


In [6]:
def test_sartorius():

    dino_model.train()
    ce_loss = CrossEntropyLoss()
    dice_loss = DiceLoss(num_classes)
    
    iter_num = 0

    loss = 0


    for i_batch, sampled_batch in enumerate(checkloader):

        image_batch, label_batch = sampled_batch['image'], sampled_batch['label']
        image_batch, label_batch = image_batch.cuda(), label_batch.cuda()

        outputs = dino_model(image_batch)

        loss_ce = ce_loss(outputs, label_batch[:].long())
        loss_dice = dice_loss(outputs, label_batch, softmax=True)
        curr_loss = 0.5 * loss_ce + 0.5 * loss_dice
        
        loss += curr_loss

        iter_num = iter_num + 1

        print(f"iteration {iter_num} : loss : {curr_loss.item()}, loss_ce: {loss_ce.item()}")

    print(f"\nmean loss after {iter_num} iterations is - {loss.item()}")

In [7]:
test_sartorius()

iteration 1 : loss : 0.2710210084915161, loss_ce: 0.2610531151294708


RuntimeError: ignored