# Mounting Google Drive and getting missing libraries

In [2]:
import os
# Best-effort Google Colab setup: outside Colab the import fails and the
# notebook simply keeps working on the local filesystem.
try:
    from google.colab import drive
    drive.mount('/content/drive')
except Exception:
    # Still deliberately broad (mounting is optional), but unlike a bare
    # `except:` this no longer swallows KeyboardInterrupt / SystemExit.
    pass

# Move into the project folder when it exists relative to the current
# directory; otherwise stay where we are.
try:
    os.chdir("drive/MyDrive/ViT_Lung_Cancer-main")
except OSError:
    pass

In [None]:
# %pip install ml_collections
# %pip install einops
# %pip install monai
import monai
# monai.config.print_config()

# Parts of Transformer

In [2]:
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

#from os.path import join as pjoin

import torch
import torchvision
import torch.nn as nn
import numpy as np

from torchvision.transforms import Resize
from torch.nn import CrossEntropyLoss, MSELoss, Dropout, Softmax, Linear, Conv2d, LayerNorm, Conv3d, AdaptiveAvgPool3d, MultiLabelSoftMarginLoss, BCELoss, Sigmoid, Conv1d, BCEWithLogitsLoss
from torch.nn.modules.utils import _pair
from scipy import ndimage

import models.configs as configs

# from models.modeling_resnet import ResNetV2
# from models.coatnet import CoAtNet


logger = logging.getLogger(__name__)


# Sub-paths of the pretrained-checkpoint entries belonging to one encoder
# block. Block.load_from joins them as
# "Transformer/encoderblock_{n}/<sub-path>/<leaf>" (leaf is "kernel", "bias"
# or "scale") to index the `weights` mapping.
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
# Feed-forward (MLP) layers of the block.
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
# Layer norms applied before the attention and before the MLP.
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"

def np2th(weights, conv=False):
    """Wrap a NumPy array as a torch tensor.

    When ``conv`` is true the axes are first reordered with
    ``transpose([3, 2, 0, 1])``, i.e. an HWIO convolution kernel becomes OIHW.
    """
    array = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(array)


def swish(x):
    """Swish activation: the input scaled element-wise by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate


# Lookup table from activation name to the callable implementing it.
ACT2FN = {
    "gelu": torch.nn.functional.gelu,
    "relu": torch.nn.functional.relu,
    "swish": swish,
}

In [3]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        config: object exposing ``hidden_size`` and a ``transformer`` mapping
            with the keys ``"num_heads"`` and ``"attention_dropout_rate"``.
        vis: when truthy, ``forward`` also returns the attention
            probabilities; otherwise ``None`` is returned in their place.
    """

    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        # The output projection consumes the concatenated heads, which are
        # all_head_size wide. That equals hidden_size whenever hidden_size is
        # divisible by the head count; sizing it this way keeps forward()
        # working when it is not.
        self.out = Linear(self.all_head_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the last axis into heads: [B, N, H*d] -> [B, H, N, d]."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply self-attention to ``hidden_states``.

        Returns:
            ``(attention_output, weights)`` where ``weights`` holds the
            softmax probabilities (taken before dropout) when ``self.vis``
            is truthy and is ``None`` otherwise.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot product between every pair of tokens, per head.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        # Weighted sum of the values, then merge the heads back together.
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights


class Mlp(nn.Module):
    """Two-layer feed-forward network: hidden_size -> mlp_dim -> hidden_size.

    Reads ``config.hidden_size`` plus the ``"mlp_dim"`` and ``"dropout_rate"``
    entries of ``config.transformer``. GELU is used between the two layers and
    the same dropout module is applied after each of them.
    """

    def __init__(self, config):
        super(Mlp, self).__init__()
        inner_dim = config.transformer["mlp_dim"]
        self.fc1 = Linear(config.hidden_size, inner_dim)
        self.fc2 = Linear(inner_dim, config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and near-zero (std 1e-6) normal biases."""
        linears = (self.fc1, self.fc2)
        # Both weights are drawn before either bias so that the random
        # number stream is consumed in the same order as before.
        for linear in linears:
            nn.init.xavier_uniform_(linear.weight)
        for linear in linears:
            nn.init.normal_(linear.bias, std=1e-6)

    def forward(self, x):
        expanded = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(expanded))


class Embeddings(nn.Module):
    """Build the token sequence of a 2D image: patch embeddings, a leading
    class token and additive position embeddings, followed by dropout.

    The layout is "hybrid" when ``config.patches`` has a non-None ``"grid"``
    entry (patches are then taken from a ResNetV2 feature map); otherwise the
    image is cut directly into patches of ``config.patches["size"]``.

    NOTE(review): ``ResNetV2`` is not imported in the visible part of this
    file (its import appears commented out), so the hybrid path presumably
    fails with a NameError - confirm before relying on it.
    """

    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)
        rows, cols = img_size[0], img_size[1]

        grid = config.patches.get("grid")
        if grid is None:
            self.hybrid = False
            patch_size = _pair(config.patches["size"])
            n_patches = (rows // patch_size[0]) * (cols // patch_size[1])
        else:
            self.hybrid = True
            grid = config.patches["grid"]
            patch_size = (rows // 16 // grid[0], cols // 16 // grid[1])
            n_patches = (rows // 16) * (cols // 16)
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
                                         width_factor=config.resnet.width_factor, in_channels = in_channels)
            # The patch projection then operates on the backbone's channels.
            in_channels = self.hybrid_model.width * 16

        # A strided convolution both cuts the patches and projects them.
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        batch = x.shape[0]
        if self.hybrid:
            x = self.hybrid_model(x)
        # [B, hidden, h, w] -> [B, h*w, hidden]
        tokens = self.patch_embeddings(x).flatten(2).transpose(-1, -2)
        class_tokens = self.cls_token.expand(batch, -1, -1)
        sequence = torch.cat((class_tokens, tokens), dim=1)
        return self.dropout(sequence + self.position_embeddings)
    

class Embeddings3D(nn.Module):
    """Build the token sequence of a 3D volume: 3D patch embeddings, a
    leading class token and additive position embeddings, then dropout.

    Args:
        config: object exposing ``hidden_size``, a ``patches`` mapping
            (``"size"`` and optionally ``"grid"``) and
            ``transformer["dropout_rate"]``.
        img_size: in-plane size, an int or a pair.
        in_channels: number of channels of the input volume.
        depth_kernel_size: number of consecutive slices covered by one patch;
            also used as the stride along the depth axis.
        out_depth: depth to which the patch grid is adaptively average-pooled,
            so that volumes with different slice counts give the same number
            of tokens.

    NOTE(review): when ``config.patches["grid"]`` is set, ``n_patches`` is not
    multiplied by ``out_depth`` and no backbone is built - that branch looks
    unfinished; confirm before using it.
    """

    def __init__(self, config, img_size, in_channels=3, depth_kernel_size = 5, out_depth = 3):
        super(Embeddings3D, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)
        self.depth_kernel_size = depth_kernel_size

        if config.patches.get("grid") is not None:
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            n_patches = (img_size[0] // 16) * (img_size[1] // 16)
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])* out_depth
            self.hybrid = False

        # Non-overlapping 3D patches: the stride equals the kernel size.
        self.patch_embeddings = Conv3d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=(depth_kernel_size,patch_size[0], patch_size[1]),
                                       stride=(depth_kernel_size,patch_size[0], patch_size[1]))
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

        # Only the depth axis is pooled; height and width are kept as they are.
        self.adapt_pool = AdaptiveAvgPool3d((out_depth, None, None))

    def forward(self, x):
        """Embed a batch of volumes given as [B, C, H, W, D]."""
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)
        # Conv3d expects the depth axis first: [B, C, H, W, D] -> [B, C, D, H, W].
        x = x.permute(0, 1, 4, 2, 3)

        # Cast to the convolution's parameter dtype in every case. Previously
        # the cast happened only when padding was needed, so the dtype reaching
        # the convolution depended on the number of slices.
        x = x.to(self.patch_embeddings.weight.dtype)

        # Append all-zero slices so the depth becomes a multiple of
        # depth_kernel_size, in a single concatenation.
        missing = (-x.shape[2]) % self.depth_kernel_size
        if missing:
            filler = x.new_zeros((x.shape[0], x.shape[1], missing, x.shape[3], x.shape[4]))
            x = torch.cat((x, filler), dim=2)

        x = self.patch_embeddings(x)
        # Adaptive average pooling gives every volume the same depth.
        x = self.adapt_pool(x)  # [B, hidden, out_depth, h, w]
        x = x.permute(0, 1, 3, 4, 2)  # [B, hidden, h, w, out_depth]
        x = x.flatten(2)
        x = x.transpose(-1, -2)
        x = torch.cat((cls_tokens, x), dim=1)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings


class Block(nn.Module):
    """One encoder block: layer norm + self-attention and layer norm + MLP,
    each wrapped in a residual connection."""

    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.attn = Attention(config, vis)
        self.ffn = Mlp(config)

    def forward(self, x):
        """Return ``(output, attention_weights)`` for the input sequence."""
        attended, weights = self.attn(self.attention_norm(x))
        x = attended + x
        x = self.ffn(self.ffn_norm(x)) + x
        return x, weights

    def load_from(self, weights, n_block):
        """Copy the parameters of encoder block ``n_block`` out of ``weights``.

        ``weights`` is indexed with keys of the form
        ``"Transformer/encoderblock_{n_block}/<sub-path>/<leaf>"``.
        """
        prefix = f"Transformer/encoderblock_{n_block}"
        side = self.hidden_size

        def fetch(scope, leaf):
            # Tensor stored under "<prefix>/<scope>/<leaf>".
            return np2th(weights[prefix + '/' + scope + '/' + leaf])

        with torch.no_grad():
            attn_layers = (
                (self.attn.query, ATTENTION_Q),
                (self.attn.key, ATTENTION_K),
                (self.attn.value, ATTENTION_V),
                (self.attn.out, ATTENTION_OUT),
            )
            # Everything is read before anything is written, so a missing key
            # leaves the attention parameters untouched.
            # Kernels are flattened to a square matrix and transposed before
            # being copied into the Linear weights.
            attn_kernels = [fetch(scope, "kernel").view(side, side).t() for _, scope in attn_layers]
            attn_biases = [fetch(scope, "bias").view(-1) for _, scope in attn_layers]
            for (linear, _), kernel in zip(attn_layers, attn_kernels):
                linear.weight.copy_(kernel)
            for (linear, _), bias in zip(attn_layers, attn_biases):
                linear.bias.copy_(bias)

            mlp_layers = ((self.ffn.fc1, FC_0), (self.ffn.fc2, FC_1))
            mlp_kernels = [fetch(scope, "kernel").t() for _, scope in mlp_layers]
            mlp_biases = [fetch(scope, "bias").t() for _, scope in mlp_layers]
            for (linear, _), kernel in zip(mlp_layers, mlp_kernels):
                linear.weight.copy_(kernel)
            for (linear, _), bias in zip(mlp_layers, mlp_biases):
                linear.bias.copy_(bias)

            norm_layers = ((self.attention_norm, ATTENTION_NORM), (self.ffn_norm, MLP_NORM))
            for norm, scope in norm_layers:
                norm.weight.copy_(fetch(scope, "scale"))
                norm.bias.copy_(fetch(scope, "bias"))


class Encoder(nn.Module):
    """Stack of ``config.transformer["num_layers"]`` Blocks followed by a
    final layer norm."""

    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        depth = config.transformer["num_layers"]
        for _ in range(depth):
            self.layer.append(copy.deepcopy(Block(config, vis)))

    def forward(self, hidden_states):
        """Return ``(encoded, attn_weights)``.

        ``attn_weights`` lists the attention weights of every block when
        ``self.vis`` is truthy and is an empty list otherwise.
        """
        collected = []
        state = hidden_states
        for block in self.layer:
            state, block_weights = block(state)
            if self.vis:
                collected.append(block_weights)
        return self.encoder_norm(state), collected


class Transformer(nn.Module):                                                                                       # Embedding + encoder
    """Embedding layer followed by the encoder stack.

    Args:
        config: model configuration forwarded to the embeddings and encoder.
        img_size: in-plane input size forwarded to the embeddings.
        in_channels: number of input channels.
        vis: forwarded to the encoder; enables returning attention weights.
        embeddings_type: ``'2D'`` (Embeddings) or ``'3D'`` (Embeddings3D).
        depth_kernel_size: used only by the 3D embeddings.
        out_depth: used only by the 3D embeddings.

    Raises:
        ValueError: if ``embeddings_type`` is neither ``'2D'`` nor ``'3D'``.
    """

    def __init__(self, config, img_size, in_channels, vis, embeddings_type, depth_kernel_size, out_depth):
        super(Transformer, self).__init__()
        if embeddings_type == '2D':
            self.embeddings = Embeddings(config, img_size=img_size, in_channels=in_channels)
        elif embeddings_type == '3D':
            self.embeddings = Embeddings3D(config, img_size=img_size, in_channels=in_channels, depth_kernel_size = depth_kernel_size,  out_depth = out_depth)
        else:
            # Fail at construction time instead of leaving self.embeddings
            # undefined and crashing later inside forward().
            raise ValueError(
                "Unsupported embeddings_type %r: expected '2D' or '3D'" % (embeddings_type,)
            )
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        """Return ``(encoded, attn_weights)`` as produced by the encoder."""
        embedding_output = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)
        return encoded, attn_weights


class VisionTransformer(nn.Module): # Transformer + loss function
    """Transformer followed by a linear classification head and a loss.

    Args:
        config: model configuration; ``config.classifier``,
            ``config.hidden_size`` and ``config.patches`` are read here.
        img_size, in_channels, vis, embeddings_type, depth_kernel_size,
            out_depth: forwarded to ``Transformer``.
        num_classes: output size of the classification head.
        loss_weights: optional tensor of per-class weights, used only by the
            ``'CrossEntropy'`` loss (useful for unbalanced training sets).
        zero_head: when true, ``load_from`` zeroes the head bias instead of
            loading the pretrained head.
        loss_type: one of ``'CrossEntropy'``, ``'MSE'``,
            ``'MultiLabelSoftMarginLoss'``, ``'BCELoss'``, ``'BCEWithLogits'``.
    """

    def __init__(self, config, img_size=224, in_channels=3, num_classes=21843, loss_weights=None, zero_head=False, vis=False, embeddings_type = '2D', depth_kernel_size = 5, out_depth = 3, loss_type = 'CrossEntropy'):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.loss_type = loss_type
        self.classifier = config.classifier
        self.in_channels = in_channels
        self.transformer = Transformer(config, img_size, in_channels, vis, embeddings_type, depth_kernel_size, out_depth)

        self.sigmoid = Sigmoid()

        self.head = Linear(config.hidden_size, self.num_classes)                                                       # Classification layer

        # No fixed patch size is recorded when the config defines a grid.
        if 'grid' not in config.patches.keys():
            self.patch_size = config.patches.size
        else:
            self.patch_size = None
        self.loss_weights = loss_weights # useful when you have an unbalanced training set

    def forward(self, x, labels=None):
        """Classify ``x``.

        Returns:
            ``(logits, loss)`` when ``labels`` is given, otherwise
            ``(logits, attn_weights)``.

        Raises:
            NameError: if ``labels`` is given and ``loss_type`` is unknown.
        """
        x, attn_weights = self.transformer(x)
        # The class token (position 0) summarises the sequence.
        logits = self.head(x[:, 0])

        if labels is None:
            return logits, attn_weights
        return logits, self._compute_loss(logits, labels)

    def _compute_loss(self, logits, labels):
        """Evaluate the loss selected by ``self.loss_type``."""
        if self.loss_type == 'CrossEntropy':
            class_weights = self.loss_weights.to(logits.device) if torch.is_tensor(self.loss_weights) else None
            loss_fct = CrossEntropyLoss(weight = class_weights)
            return loss_fct(logits.view(-1, self.num_classes), labels)
        if self.loss_type == 'MSE':
            return MSELoss()(logits.float(), labels.float())
        if self.loss_type == 'BCEWithLogits':
            return BCEWithLogitsLoss()(logits.float(), labels.float())
        if self.loss_type == 'MultiLabelSoftMarginLoss':
            # This criterion applies a log-sigmoid itself, so it must receive
            # raw logits; feeding it sigmoid outputs squashed them twice.
            return MultiLabelSoftMarginLoss()(logits, labels)
        if self.loss_type == 'BCELoss':
            # BCELoss expects probabilities.
            return BCELoss()(self.sigmoid(logits), labels.float())
        raise NameError('ATTENTION! Loss type not managed')

    def load_from(self, weights):
        """Initialise the model from a pretrained ``weights`` mapping.

        The patch-embedding kernel is adapted to 1 or 2 input channels and to
        a patch size other than (16, 16); position embeddings are resized
        when the token count differs from the checkpoint.

        NOTE(review): with ``zero_head`` only the head bias is zeroed; the
        head weight keeps its random initialisation - confirm this is wanted.
        """
        with torch.no_grad():
            if self.zero_head:
                nn.init.zeros_(self.head.bias)
            else:
                self.head.weight.copy_(np2th(weights["head/kernel"]).t())
                self.head.bias.copy_(np2th(weights["head/bias"]).t())

            new_weights = np2th(weights["embedding/kernel"], conv=True)
            # Collapse the input-channel axis to its mean for 1- or 2-channel
            # inputs (the mean is duplicated for two channels).
            if self.in_channels == 1 :
                new_weights = new_weights.mean(dim = 1, keepdim = True)
            elif self.in_channels == 2:
                new_weights = new_weights.mean(dim = 1, keepdim = True)
                new_weights = torch.cat([new_weights, new_weights], dim = 1)

            # Rescale the kernel relative to a 16x16 reference patch.
            if self.patch_size is not None and self.patch_size != (16,16):
                scale_factor = self.patch_size[0] / 16
                new_weights = torch.nn.functional.interpolate(new_weights,scale_factor = scale_factor, mode = 'bilinear')

            self.transformer.embeddings.patch_embeddings.weight.copy_(new_weights)

            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
            self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))
            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)

                # Keep the class-token embedding apart from the patch grid.
                if self.classifier == "token":
                    posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
                    ntok_new -= 1
                else:
                    posemb_tok, posemb_grid = posemb[:, :0], posemb[0]

                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

                # Linear interpolation over the two grid axes only.
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            # Each encoder block loads its own entries; the child name of a
            # block inside the ModuleList is its index.
            for bname, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
                gn_weight = np2th(weights["gn_root/scale"]).view(-1)
                gn_bias = np2th(weights["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=bname, n_unit=uname)
                        

class ParallelVisionTransformer(nn.Module): # Parallel Transformer + loss function
    """Four Transformers applied to the same input with different patch
    sizes, merged by a 1x1 convolution and classified by one linear head.

    Branch ``k`` (1..4) uses square patches whose side is obtained by halving
    ``img_size`` ``k`` times. The class tokens of the four branches are
    stacked, reduced to one vector by ``last_conv`` and passed to ``head``.

    Args:
        config: model configuration; copied per branch with its own
            ``patches.size``.
        img_size: integer image size, must be divisible by 16.
        in_channels, vis, embeddings_type, depth_kernel_size, out_depth:
            forwarded to every ``Transformer``.
        num_classes: output size of the classification head.
        loss_weights: optional tensor of per-class weights, used only by the
            ``'CrossEntropy'`` loss.
        zero_head: when true, ``load_from`` zeroes the head bias instead of
            loading the pretrained head.
        loss_type: one of ``'CrossEntropy'``, ``'MSE'``,
            ``'MultiLabelSoftMarginLoss'``, ``'BCELoss'``, ``'BCEWithLogits'``.
    """

    def __init__(self, config, img_size=224, in_channels=3, num_classes=21843, loss_weights=None, zero_head=False, vis=False, embeddings_type = '2D', depth_kernel_size = 5, out_depth = 3, loss_type = 'CrossEntropy'):
        super(ParallelVisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.loss_type = loss_type
        self.classifier = config.classifier
        self.in_channels = in_channels

        assert img_size%16==0, 'Image size must be divisible by 16'
        # Build transformer1..transformer4, halving the patch side each time.
        patch_size = img_size
        for index in range(1, 5):
            patch_size = int(patch_size/2)
            branch_config = copy.deepcopy(config)
            branch_config.patches.size = (patch_size, patch_size)
            setattr(self, 'transformer%d' % index,
                    Transformer(branch_config, img_size, in_channels, vis, embeddings_type, depth_kernel_size, out_depth))

        self.sigmoid = Sigmoid()

        # Learned weighted sum of the four class tokens.
        self.last_conv = Conv1d(4, 1, kernel_size = 1)

        self.head = Linear(config.hidden_size, self.num_classes)

        if 'grid' not in config.patches.keys():
            self.patch_size = config.patches.size
        else:
            self.patch_size = None
        self.loss_weights = loss_weights

    def _branches(self):
        """Return the four branch Transformers in order."""
        return (self.transformer1, self.transformer2, self.transformer3, self.transformer4)

    def forward(self, x, labels=None):
        """Classify ``x``.

        Returns:
            ``(logits, loss)`` when ``labels`` is given, otherwise
            ``(logits, attn)`` where ``attn`` maps ``'attn_weights1'`` ..
            ``'attn_weights4'`` to the attention weights of each branch.

        Raises:
            NameError: if ``labels`` is given and ``loss_type`` is unknown.
        """
        class_tokens = []
        attn = {}
        for index, branch in enumerate(self._branches(), start=1):
            encoded, branch_weights = branch(x)
            class_tokens.append(encoded[:, 0, :])
            attn['attn_weights%d' % index] = branch_weights

        x = torch.stack(class_tokens, dim = 1)  # [B, 4, hidden]
        x = self.last_conv(x)                   # [B, 1, hidden]
        logits = self.head(x[:, 0])

        if labels is None:
            return logits, attn
        return logits, self._compute_loss(logits, labels)

    def _compute_loss(self, logits, labels):
        """Evaluate the loss selected by ``self.loss_type``."""
        if self.loss_type == 'CrossEntropy':
            class_weights = self.loss_weights.to(logits.device) if torch.is_tensor(self.loss_weights) else None
            loss_fct = CrossEntropyLoss(weight = class_weights)
            return loss_fct(logits.view(-1, self.num_classes), labels)
        if self.loss_type == 'MSE':
            return MSELoss()(logits.float(), labels.float())
        if self.loss_type == 'BCEWithLogits':
            # Accepted here as well, for parity with VisionTransformer.
            return BCEWithLogitsLoss()(logits.float(), labels.float())
        if self.loss_type == 'MultiLabelSoftMarginLoss':
            # This criterion applies a log-sigmoid itself, so it must receive
            # raw logits; feeding it sigmoid outputs squashed them twice.
            return MultiLabelSoftMarginLoss()(logits, labels)
        if self.loss_type == 'BCELoss':
            # BCELoss expects probabilities.
            return BCELoss()(self.sigmoid(logits), labels.float())
        raise NameError('ATTENTION! Loss type not managed')

    def load_from(self, weights):
        """Initialise the head and all four branches from ``weights``.

        The same pretrained entries are loaded into every branch; the
        patch-embedding kernel and the position embeddings are resized to
        each branch's own patch size and token count.

        NOTE(review): with ``zero_head`` only the head bias is zeroed; the
        head weight keeps its random initialisation - confirm this is wanted.
        """
        with torch.no_grad():
            if self.zero_head:
                nn.init.zeros_(self.head.bias)
            else:
                self.head.weight.copy_(np2th(weights["head/kernel"]).t())
                self.head.bias.copy_(np2th(weights["head/bias"]).t())

            # There is no single `self.transformer` on this class, so every
            # branch is loaded individually.
            for branch in self._branches():
                self._load_branch(branch, weights)

    @torch.no_grad()
    def _load_branch(self, transformer, weights):
        """Copy the pretrained ``weights`` into one branch Transformer.

        NOTE(review): written for the 2D embeddings (4-D convolution kernel);
        loading into 3D embeddings presumably fails on the kernel copy -
        confirm before using it with ``embeddings_type='3D'``.
        """
        embeddings = transformer.embeddings

        new_weights = np2th(weights["embedding/kernel"], conv=True)
        # Collapse the input-channel axis to its mean for 1- or 2-channel
        # inputs (the mean is duplicated for two channels).
        if self.in_channels == 1 :
            new_weights = new_weights.mean(dim = 1, keepdim = True)
        elif self.in_channels == 2:
            new_weights = new_weights.mean(dim = 1, keepdim = True)
            new_weights = torch.cat([new_weights, new_weights], dim = 1)

        # Resize the kernel to this branch's own patch size. Skipped, as in
        # the single-transformer model, when the config defines a grid.
        kernel_size = tuple(embeddings.patch_embeddings.kernel_size)[-2:]
        if self.patch_size is not None and tuple(new_weights.shape[-2:]) != kernel_size:
            new_weights = torch.nn.functional.interpolate(new_weights, size = kernel_size, mode = 'bilinear')

        embeddings.patch_embeddings.weight.copy_(new_weights)
        embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
        embeddings.cls_token.copy_(np2th(weights["cls"]))
        transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
        transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

        posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
        posemb_new = embeddings.position_embeddings
        if posemb.size() == posemb_new.size():
            embeddings.position_embeddings.copy_(posemb)
        else:
            logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
            ntok_new = posemb_new.size(1)

            # Keep the class-token embedding apart from the patch grid.
            if self.classifier == "token":
                posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
                ntok_new -= 1
            else:
                posemb_tok, posemb_grid = posemb[:, :0], posemb[0]

            gs_old = int(np.sqrt(len(posemb_grid)))
            gs_new = int(np.sqrt(ntok_new))
            print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
            posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

            # Linear interpolation over the two grid axes only.
            zoom = (gs_new / gs_old, gs_new / gs_old, 1)
            posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
            posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
            posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
            embeddings.position_embeddings.copy_(np2th(posemb))

        # Each encoder block loads its own entries; the child name of a block
        # inside the ModuleList is its index.
        for bname, block in transformer.encoder.named_children():
            for uname, unit in block.named_children():
                unit.load_from(weights, n_block=uname)

        if embeddings.hybrid:
            embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
            gn_weight = np2th(weights["gn_root/scale"]).view(-1)
            gn_bias = np2th(weights["gn_root/bias"]).view(-1)
            embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
            embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

            for bname, block in embeddings.hybrid_model.body.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=bname, n_unit=uname)

def compute_att_map(attn_weights, device, imgs_size):
    """Compute attention-rollout masks for the class token.

    Args:
        attn_weights: sequence with one attention tensor per layer, each
            [B, num_heads, num_patches+1, num_patches+1].
        device: device on which the identity matrix is created; it has to
            match the device of ``attn_weights``.
        imgs_size: size of the image batch, given either as an indexable size
            (e.g. the result of ``imgs.size()``) or as a callable such as the
            bound method ``imgs.size``. Entries 1 and 2 are used as the
            output height and width.

    Returns:
        Tensor [B, imgs_size[1], imgs_size[2]]: per-sample masks, each scaled
        by its own maximum before resizing and detached from the graph.
    """
    att_mat = torch.stack(attn_weights, dim = 1) # [B, num_layers, num_heads, num_patches+1, num_patches+1]
    # Average the attention weights across all heads.
    att_mat = torch.mean(att_mat, dim = 2) # [B, num_layers, num_patches+1, num_patches+1]
    # To account for residual connections, add an identity matrix to the
    # attention matrix and re-normalize the rows.
    residual_att = torch.eye(att_mat.size(2)).to(device)
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

    # Attention rollout: multiply the per-layer matrices from the first layer
    # upwards. Only the product over all layers is needed, so a running
    # product (batched over B) replaces the per-sample buffer of partials.
    rollout = aug_att_mat[:, 0]
    for n in range(1, aug_att_mat.size(1)):
        rollout = torch.matmul(aug_att_mat[:, n], rollout)

    # Row 0 is the class token; drop its attention to itself and fold the
    # remaining entries back into the square patch grid.
    grid_size = int(np.sqrt(aug_att_mat.size(-1)))
    masks = rollout[:, 0, 1:].reshape(-1, grid_size, grid_size).detach()
    masks = (masks/masks.amax(dim = (1,2), keepdim = True))

    # Accept both a bound `Tensor.size` method and a plain size object; the
    # latter is not callable and previously raised a TypeError here.
    if callable(imgs_size):
        height, width = imgs_size(1), imgs_size(2)
    else:
        height, width = imgs_size[1], imgs_size[2]
    masks = Resize((height, width))(masks)

    return masks


class CNNClassifier(nn.Module):
    """CNN image classifier with optional fusion of clinical features.

    A torchvision backbone (or a CoAtNet) is instantiated, its first
    convolution is rebuilt when ``in_channels != 3`` and its last linear layer
    is replaced. Without clinical data the new last layer emits
    ``num_classes`` logits. With clinical data the backbone emits a 512-d
    embedding that is combined with the clinical vector and passed through
    ``final_classifier``.

    Args:
        img_size: side of the (square) input image; only used by CoAtNet.
        in_channels: number of input channels.
        num_classes: number of output classes.
        loss_weights: optional tensor of per-class weights for the loss.
        loss_type: only ``'CrossEntropy'`` builds a loss; any other value
            leaves the model without a loss function.
        model_type: backbone name (see ``_TORCHVISION_BUILDERS``) or a name
            containing ``'CoAtNet'`` and a variant digit 0-5.
        pretrained: forwarded positionally to the torchvision constructor.
        use_clinical_data: enable the clinical fusion branch.
        fusion_strategy: ``'learned_features'`` (clinical vector goes through
            a small MLP) or ``'features'`` (clinical vector used as is).
        combination_type: ``'concat'``, ``'sum'`` or ``'mul'``.
        clin_feats: number of clinical input features (int or numeric string).

    Raises:
        ValueError: for an unknown ``model_type`` / CoAtNet variant.
    """

    # Width of the image and clinical embeddings that get fused.
    _EMBED_DIM = 512

    # model_type -> name of the constructor in ``torchvision.models``.
    _TORCHVISION_BUILDERS = {
        'ResNet18': 'resnet18',
        'ResNet50': 'resnet50',
        'ResNet101': 'resnet101',
        'AlexNet': 'alexnet',
        'DenseNet121': 'densenet121',
        'Vgg16': 'vgg16',
        'MobileNet_v2': 'mobilenet_v2',
        'EfficientNet_b6': 'efficientnet_b6',
        'EfficientNet_b5': 'efficientnet_b5',
    }

    # (variant digit, num_blocks, channels); checked in this order, the first
    # digit contained in ``model_type`` wins.
    # NOTE(review): 1026 in variant 2 looks like it may be meant as 1024 —
    # left untouched; confirm against the CoAtNet implementation in use.
    _COATNET_VARIANTS = (
        ('0', (2, 2, 3, 5, 2), (64, 96, 192, 384, 768)),
        ('1', (2, 2, 6, 14, 2), (64, 96, 192, 384, 768)),
        ('2', (2, 2, 6, 14, 2), (128, 128, 256, 512, 1026)),
        ('3', (2, 2, 6, 14, 2), (192, 192, 384, 768, 1536)),
        ('4', (2, 2, 12, 28, 2), (192, 192, 384, 768, 1536)),
        ('5', (2, 2, 12, 28, 2), (192, 256, 512, 1280, 2048)),
    )

    def __init__(self, img_size=224, in_channels=3, num_classes=21843, loss_weights=None, loss_type = 'CrossEntropy', model_type = 'ResNet18', pretrained = False, use_clinical_data = False, fusion_strategy = 'learned_features', combination_type = 'concat', clin_feats = '10'):
        super(CNNClassifier, self).__init__()
        self.num_classes = num_classes
        self.loss_type = loss_type
        self.loss_weights = loss_weights
        self.in_channels = in_channels
        self.use_clinical_data = use_clinical_data
        self.fusion_strategy = fusion_strategy
        self.combination_type = combination_type

        if self.loss_type == 'CrossEntropy':
            class_weights = self.loss_weights if torch.is_tensor(self.loss_weights) else None
            self.loss_fct = CrossEntropyLoss(weight=class_weights)
        else:
            # No loss is built for other types; forward() reports this clearly
            # if labels are supplied.
            self.loss_fct = None

        self.classifier = self._build_backbone(model_type, pretrained, img_size, in_channels, num_classes)

        # Rebuild the first convolution when the input is not 3-channel.
        if in_channels != 3:
            self._replace_first_conv(model_type, in_channels)

        # New last layer, sized from the backbone's own head so that every
        # supported backbone works (not only those with 2048 features).
        old_head = self._get_head(model_type)
        keep_bias = old_head.bias is not None
        if self.use_clinical_data:
            new_head = torch.nn.Sequential(
                torch.nn.Linear(old_head.in_features, self._EMBED_DIM, bias=keep_bias),
                torch.nn.ReLU(inplace=True),
            )

            if self.fusion_strategy == 'learned_features':
                self.clinical_feats = torch.nn.Sequential(
                    nn.Linear(int(clin_feats), 1024),
                    nn.ReLU(inplace=True),
                    nn.Linear(1024, self._EMBED_DIM),
                    nn.ReLU(inplace=True),
                )

            # NOTE(review): with fusion_strategy='features' the raw clinical
            # vector must itself be 512 wide for these sizes to match.
            fused_dim = self._EMBED_DIM * 2 if self.combination_type == "concat" else self._EMBED_DIM
            # NOTE(review): this head emits a single logit regardless of
            # num_classes, while forward() reshapes logits to
            # (-1, num_classes) for the loss — confirm the intended width.
            self.final_classifier = torch.nn.Sequential(
                torch.nn.Dropout(p=0.25),
                torch.nn.Linear(fused_dim, self._EMBED_DIM),
                torch.nn.ReLU(inplace=True),
                torch.nn.Dropout(p=0.25),
                torch.nn.Linear(self._EMBED_DIM, 1)
            )
        else:
            new_head = nn.Linear(in_features=old_head.in_features, out_features=num_classes, bias=keep_bias)

        self._set_head(model_type, new_head)

    def _build_backbone(self, model_type, pretrained, img_size, in_channels, num_classes):
        """Instantiate the backbone network selected by ``model_type``."""
        builder_name = self._TORCHVISION_BUILDERS.get(model_type)
        if builder_name is not None:
            return getattr(torchvision.models, builder_name)(pretrained)

        if 'CoAtNet' in model_type:
            for digit, num_blocks, channels in self._COATNET_VARIANTS:
                if digit in model_type:
                    # NOTE(review): CoAtNet must be importable in this scope;
                    # its import is commented out at the top of the file.
                    return CoAtNet((img_size, img_size), in_channels, list(num_blocks), list(channels), num_classes, block_types=['C', 'C', 'T', 'T'])
            raise ValueError(f"Unknown CoAtNet variant in model_type: {model_type!r}")

        raise ValueError(f"Unsupported model_type: {model_type!r}")

    @staticmethod
    def _conv_like(conv, in_channels):
        """Return a fresh Conv2d copying ``conv``'s geometry with new input channels."""
        return nn.Conv2d(
            in_channels=in_channels,
            out_channels=conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            bias=conv.bias is not None,
        )

    def _replace_first_conv(self, model_type, in_channels):
        """Swap the backbone's first convolution for one taking ``in_channels``."""
        net = self.classifier
        if 'ResNet' in model_type:
            net.conv1 = self._conv_like(net.conv1, in_channels)
        elif model_type in ['AlexNet', 'Vgg16']:
            net.features[0] = self._conv_like(net.features[0], in_channels)
        elif 'DenseNet' in model_type:
            net.features.conv0 = self._conv_like(net.features.conv0, in_channels)
        elif 'MobileNet' in model_type or 'EfficientNet' in model_type:
            net.features[0][0] = self._conv_like(net.features[0][0], in_channels)
        elif 'CoAtNet' in model_type:
            net.s0[0][0] = self._conv_like(net.s0[0][0], in_channels)

    def _get_head(self, model_type):
        """Return the backbone's current last linear layer."""
        net = self.classifier
        if 'ResNet' in model_type or 'CoAtNet' in model_type:
            return net.fc
        if model_type in ['AlexNet', 'Vgg16']:
            return net.classifier[6]
        if 'DenseNet' in model_type:
            return net.classifier
        if 'MobileNet' in model_type or 'EfficientNet' in model_type:
            return net.classifier[1]
        raise ValueError(f"Unsupported model_type: {model_type!r}")

    def _set_head(self, model_type, head):
        """Install ``head`` as the backbone's last layer."""
        net = self.classifier
        if 'ResNet' in model_type or 'CoAtNet' in model_type:
            net.fc = head
        elif model_type in ['AlexNet', 'Vgg16']:
            net.classifier[6] = head
        elif 'DenseNet' in model_type:
            net.classifier = head
        elif 'MobileNet' in model_type or 'EfficientNet' in model_type:
            net.classifier[1] = head
        else:
            raise ValueError(f"Unsupported model_type: {model_type!r}")

    def forward(self, x, labels = None, clinical_feats = None):
        """Run the classifier.

        Args:
            x: image batch fed to the backbone.
            labels: optional targets; when given, the loss is also returned.
            clinical_feats: clinical vectors, required when the model was
                built with ``use_clinical_data=True``.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``labels`` is None.

        Raises:
            ValueError: on a missing clinical input, an unknown fusion
                strategy / combination type, or labels without a loss.
        """
        logits = self.classifier(x)

        if self.use_clinical_data:
            if clinical_feats is None:
                raise ValueError("clinical_feats is required when use_clinical_data is True.")

            if self.fusion_strategy == 'learned_features':
                cl = self.clinical_feats(clinical_feats)
            elif self.fusion_strategy == 'features':
                cl = clinical_feats
            else:
                raise ValueError(f"Unsupported fusion_strategy: {self.fusion_strategy!r}")

            if self.combination_type == 'concat':
                fused = torch.cat((logits, cl), dim = -1)
            elif self.combination_type == 'sum':
                fused = logits + cl
            elif self.combination_type == 'mul':
                fused = logits * cl
            else:
                raise ValueError(f"Unsupported combination_type: {self.combination_type!r}")

            logits = self.final_classifier(fused)

        if labels is None:
            return logits, None

        if self.loss_fct is None:
            raise ValueError(f"No loss function is available for loss_type: {self.loss_type!r}")
        loss = self.loss_fct(logits.view(-1, self.num_classes), labels)
        return logits, loss
        

# Registry of model configurations, keyed by the name used to select them.
# Each entry is produced by calling the matching factory in ``models.configs``.
CONFIGS = {
    config_name: make_config()
    for config_name, make_config in (
        ('ViT-B_16', configs.get_b16_config),
        ('ViT-B_32', configs.get_b32_config),
        ('ViT-L_16', configs.get_l16_config),
        ('ViT-L_32', configs.get_l32_config),
        ('ViT-H_14', configs.get_h14_config),
        ('R50-ViT-B_16', configs.get_r50_b16_config),
        ('R50-ViT-MRI', configs.get_r50_MRI_config),
        ('ViT-MRI', configs.get_MRI_config),
        ('ViT-1Lay', configs.get_MinConfig),
        ('testing', configs.get_testing),
        ('ViT-half', configs.get_halfHidden),
        ('ViT-half2', configs.get_halfHiddenHeads),
        ('ViT-h12l2', configs.get_h12l2Config),
        ('ViT-h8l2', configs.get_h8l2Config),
        ('ViT-h8l2hid384', configs.get_h8l2hid384Config),
        ('ViT-h4l2hid384', configs.get_h4l2hid384Config),
        ('ViT-parallel_h12l2', configs.get_ParallelConfig_h12l2),
        ('ViT-parallel', configs.get_ParallelViT_config),
    )
}


In [None]:
import sys
import os
import random
import json
import math
import platform
from monai.data import CacheDataset

class NGSLungDataset(CacheDataset):
    """Cached dataset over one cross-validation fold of a JSON split file.

    The split file maps ``fold<N>`` to ``train`` / ``val`` / ``test`` lists of
    samples. ``section`` selects which list is served. When ``execute_test``
    is False the test samples are appended to the training list and the test
    section is empty.
    """

    def __init__(self, root_dir, split_path, section, num_fold, transforms, seed = 100, cache_num = sys.maxsize, cache_rate=1.0, num_workers=0, execute_test = True):
        """Validate ``root_dir``, pick the samples and hand them to CacheDataset.

        Raises:
            ValueError: if ``root_dir`` is not a directory or ``section`` is
                not one of 'training', 'validation', 'test'.
        """
        if not os.path.isdir(root_dir):
            raise ValueError("Root directory root_dir must be a directory.")

        self.text_labels = ['negative', 'positive']
        self.section = section
        self.num_fold = num_fold
        self.seed = seed
        self.execute_test = execute_test

        samples = self._generate_data_list(split_path)
        super().__init__(
            samples,
            transforms,
            cache_num=cache_num,
            cache_rate=cache_rate,
            num_workers=num_workers,
        )

    def _generate_data_list(self, split_path):
        """Return the sample list of the configured fold and section."""
        with open(split_path) as fp:
            splits = json.load(fp)

        if self.section not in ('training', 'validation', 'test'):
            raise ValueError(
                    f"Unsupported section: {self.section}, "
                    "available options are ['training', 'validation', 'test']."
                )

        fold_key = f'fold{self.num_fold}'
        if self.section == 'validation':
            samples = splits[fold_key]['val']
        elif self.section == 'training':
            samples = splits[fold_key]['train']
            if not self.execute_test:
                # No held-out test: the test samples are used for training too.
                samples = samples + splits[fold_key]['test']
        elif self.execute_test:
            samples = splits[fold_key]['test']
        else:
            samples = []

        # The split file may store Windows-style paths; convert the
        # separators when running on any other platform.
        if platform.system() != 'Windows':
            for sample in samples:
                for field, value in sample.items():
                    if isinstance(value, str):
                        sample[field] = value.replace('\\', '/')
        return samples

    # Disabled helper, kept for reference:
    # def get_label_proportions(self):
    #     c = [None]*2
    #     label_props = [None]*2
    #     for i in range(2):
    #         c[i] = len([el['label'] for el in self.data if el['label'] == i])
    #     for i in range(len(c)):
    #         label_props[i] = max(c)/c[i]
    #     return label_props

# Trying to solve the ITK reader problem

In [None]:
# Notebook-only setup cell (IPython `!` / `%` magics): installs MONAI with the
# ITK extra, plus itk and SimpleITK.
# NOTE(review): the magics run shell commands, so the try/except below
# presumably does not intercept their failures — confirm.
try:
  # Clone MONAI and install it in editable mode with the itk extra.
  !git clone https://github.com/Project-MONAI/MONAI.git
  %cd MONAI
  !pip install -e '.[itk]'
except:
  pass
!pip install itk
#!pip install itk==5.3rc4 # trying a different version
!pip install 'monai[itk]'

# `%cd MONAI` above changed the working directory; go back to the project.
# `os` is not imported in this cell before this line; it relies on an import
# made by an earlier cell.
os.chdir("/content/drive/MyDrive/ViT_Lung_Cancer-main")

# Install monai-weekly only when `import monai` fails.
!python -c "import monai" || pip install -q "monai-weekly[itk, pillow]"
!pip install -q "SimpleITK"
# Same command as a few lines above.
!pip install 'monai[itk]'

import os
import shutil
import numpy as np
from PIL import Image
import tempfile
from monai.data import ITKReader, PILReader
from monai.transforms import (
    LoadImage, LoadImaged, EnsureChannelFirstd,
    Resized, EnsureTyped, Compose)
from monai.config import print_config
import SimpleITK as sitk
#!pwd

fatal: destination path 'MONAI' already exists and is not an empty directory.
/content/drive/MyDrive/ViT_Lung_Cancer-main/MONAI
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Obtaining file:///content/drive/MyDrive/ViT_Lung_Cancer-main/MONAI
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting itk>=5.2
  Downloading itk-5.2.1.post1-cp37-cp37m-manylinux2014_x86_64.whl (8.3 kB)
Collecting itk-registration==5.2.1.post1
  Downloading itk_registration-5.2.1.post1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (20.3 MB)
[K     |████████████████████████████████| 20.3 MB 1.8 MB/s 
[?25hCollecting itk-numerics==5.2.1.post1
  Downloading itk_numerics-5.2.1.post1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (54.5 MB)
[K     |████████████████████████████████| 54.5 MB 1.2 MB/s 
[?25hCollecting itk-segm

[K     |████████████████████████████████| 48.4 MB 21 kB/s 
[?25hLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# get_SpecificLoader

In [None]:
import logging
import os
import json
import torch
import math

from monai.data.image_reader import ITKReader

# from dataset.NGSLungDatasetCV import NGSLungDataset as DS

from torch.utils.data import DataLoader

from utils.transforms import (CorrectSpacing, 
                              PatchedImage, 
                              ResizeWithRatioD,
                              CenterPatchedImage,
                              ListPatchedImage,
                              ResizeWithRatioDVariableDim, 
                              PatchedImageAllSlice,
                              TensorPad, ConvertListToTensor, 
                              MyCropForegroundd, 
                              ExpandDims, 
                              ConcatChDim, 
                              AsDepthFirstD, 
                              DeleteKeys,
                              PrepareClinicalData, 
                              DeleteNotUsableClinicalData)
from monai.transforms import (DivisiblePadD,  
    LoadImageD, 
    NormalizeIntensityD,
    AddChannelD,
    Compose,
    RandFlipD,
    RandRotate90D,
    ToTensorD,
    ScaleIntensityD,
    SpacingD,
    OrientationD,
    RandCropByPosNegLabelD,
    RandSpatialCropD,
    CropForegroundD, 
    ScaleIntensityRanged, 
    IdentityD)

logger = logging.getLogger(__name__)


def get_loss_weights(split_path, label_key):
    """Compute class weights for a binary label from the full sample list.

    The sample file is chosen from a marker in ``split_path``
    ('114sample' or '131samples'). Each class gets the inverse of its
    frequency, rescaled so that the larger weight is 1.

    Args:
        split_path: path of the cross-validation split file; only inspected
            for the dataset marker.
        label_key: key of the 0/1 label inside every sample.

    Returns:
        ``torch.tensor([w_negative, w_positive])``, or ``None`` when the
        split does not match a known dataset or when one of the two classes
        has no samples (inverse-frequency weights are undefined then).
    """
    if '114sample' in split_path:
        path_data = os.path.join('data','ngslung_pathAllGene114.json')
    elif '131samples' in split_path:
        path_data = os.path.join('data','ngslung_pathAllGene131_WithCombo_MultiLabel.json')
    else:
        return None

    with open(path_data) as fp:
        samples = json.load(fp)

    labels = [sample[label_key] for sample in samples]
    num = len(labels)
    pos = sum(labels)
    neg = num - pos
    if pos == 0 or neg == 0:
        # Avoid dividing by zero on an empty or single-class sample list.
        return None

    p_pos = num / pos
    p_neg = num / neg
    largest = max(p_pos, p_neg)
    return torch.tensor([p_neg / largest, p_pos / largest])

def _intensity_window(window):
    """Return the (a_min, a_max) intensity range for a named window."""
    if window == 'parenchyma':
        return -1350, 150
    if window == 'mediastinum':
        return -115, 235
    raise ValueError('Error! Invalid window name.')


def get_SpecificLoader(dataset, label_key, img_size, num_patches, section, train_batch_size, root_dir, split_path, num_fold, execute_test, inner_loop_idx = None, eval_stride=1, window = 'parenchyma', k_divisible=74, padding=True, max_depth=None, depth_kernel_size=None):
    """Build the DataLoader of one section of an NGS lung dataset variant.

    Args:
        dataset: name of the preprocessing variant; selects the transform
            pipelines (see the branches below).
        label_key: key of the label inside each sample.
        img_size: side of the final image; a patch has side
            ``img_size / sqrt(num_patches)``.
        num_patches: number of patches per image.
        section: 'training', 'validation' or 'test'. Training uses the
            training pipeline and shuffles; the others use the validation
            pipeline without shuffling.
        train_batch_size: batch size of the returned loader (all sections).
        root_dir, split_path, num_fold, execute_test: forwarded to
            ``NGSLungDataset``.
        inner_loop_idx: unused.
        eval_stride: stride forwarded to ``ListPatchedImage``.
        window: 'parenchyma' or 'mediastinum'; intensity range used by the
            '*ScaledRange*' variants.
        k_divisible: divisor forwarded to ``MyCropForegroundd`` for the two
            in-plane axes.
        padding: forwarded as ``mode`` to ``DivisiblePadD`` in the
            'NGSLUNG_crop_ValidVoting_ScaledRange' variant.
            NOTE(review): the default ``True`` does not look like a padding
            mode name — confirm what callers pass.
        max_depth, depth_kernel_size: only for 'NGSLUNG_crop_ScaledRange_3D'
            with a batch size other than 1: volumes are padded in depth to a
            multiple of ``ceil(max_depth / depth_kernel_size) *
            depth_kernel_size``. When omitted, module-level variables with
            the same names are used if they exist.

    Returns:
        A ``torch.utils.data.DataLoader`` over an ``NGSLungDataset``.

    Raises:
        ValueError: unknown ``section``, ``dataset`` or ``window``, or
            missing depth parameters for the 3D variant.
    """
    if section not in ('training', 'validation', 'test'):
        raise ValueError(
            f"Unsupported section: {section}, "
            "available options are ['training', 'validation', 'test']."
        )

    KEYS = ('image','mask', label_key)
    vol_keys = KEYS[:-1]  # image and mask

    patch_per_side = int(math.sqrt(num_patches))
    spatial_size = (-1, int(img_size/patch_per_side), int(img_size/patch_per_side))
    patch_side = spatial_size[1]
    in_plane_k = (patch_side, patch_side, -1)

    # Each helper returns a fresh list so pipelines never share instances.
    def resampled_prefix():
        # load, add channel, fix spacing, resample to pixdim 1, normalise,
        # rescale, reorient to RAS
        return [
            LoadImageD(keys = vol_keys, reader = 'ITKReader'),
            AddChannelD(keys = vol_keys),
            CorrectSpacing(vol_keys),
            SpacingD(keys = vol_keys, pixdim=(1., 1., 1.), mode = ("bilinear")),
            NormalizeIntensityD(keys = vol_keys),
            ScaleIntensityD(keys = vol_keys),
            OrientationD(keys = vol_keys, axcodes = 'RAS'),
        ]

    def rescaled_prefix():
        # same as above without resampling and normalisation
        return [
            LoadImageD(keys = vol_keys, reader = 'ITKReader'),
            AddChannelD(keys = vol_keys),
            CorrectSpacing(vol_keys),
            ScaleIntensityD(keys = vol_keys),
            OrientationD(keys = vol_keys, axcodes = 'RAS'),
        ]

    def windowed_prefix(a_min, a_max, convert_label = False):
        # the image intensity is clipped to [a_min, a_max] and mapped to [0, 1]
        steps = [LoadImageD(keys = vol_keys, reader = 'ITKReader')]
        if convert_label:
            steps.append(ConvertListToTensor(keys = KEYS[-1]))
        steps += [
            AddChannelD(keys = vol_keys),
            CorrectSpacing(vol_keys),
            ScaleIntensityRanged(keys = KEYS[0], a_min = a_min, a_max = a_max, b_min = 0, b_max=1, clip = True),
            OrientationD(keys = vol_keys, axcodes = 'RAS'),
        ]
        return steps

    def flip_rotate(keys):
        # training-only augmentation
        return [
            RandFlipD(keys = keys, prob = 0.5, spatial_axis=0),
            RandRotate90D(keys = keys, prob=0.5, spatial_axes=(0,1)),
        ]

    if dataset in ['NGSLUNG']:
        def mask_centred_crop():
            return RandCropByPosNegLabelD(keys = vol_keys, label_key=KEYS[1], spatial_size = (patch_side, patch_side, num_patches), pos = 1, neg = 0)

        train_transforms = Compose(
            resampled_prefix()
            + [
                ResizeWithRatioD(keys = vol_keys, image_size=patch_side),
                DivisiblePadD(KEYS[0], k = spatial_size, mode = 'constant'),
                mask_centred_crop(),
            ]
            + flip_rotate(vol_keys)
            + [
                ToTensorD(keys = vol_keys),
                PatchedImage(vol_keys, num_patches = num_patches),
            ]
        )

        val_transforms = Compose(
            resampled_prefix()
            + [
                ResizeWithRatioD(keys = vol_keys, image_size=patch_side),
                DivisiblePadD(KEYS[0], k = patch_side, mode = 'constant'),
                mask_centred_crop(),
                ToTensorD(keys = vol_keys),
                PatchedImage(vol_keys, num_patches = num_patches),
            ]
        )

    elif dataset in ['NGSLUNG_crop']:
        train_transforms = Compose(
            resampled_prefix()
            + [
                CropForegroundD(vol_keys, KEYS[1]),
                ResizeWithRatioD(keys = vol_keys, image_size=patch_side),
                DivisiblePadD(KEYS[0], k = in_plane_k, mode = 'constant'),
            ]
            + flip_rotate(vol_keys)
            + [
                ToTensorD(keys = vol_keys),
                PatchedImage(vol_keys, num_patches = num_patches),
            ]
        )

        val_transforms = Compose(
            resampled_prefix()
            + [
                CropForegroundD(vol_keys, KEYS[1]),
                ResizeWithRatioD(keys = vol_keys, image_size=patch_side),
                DivisiblePadD(KEYS[0], k = in_plane_k, mode = 'constant'),
                ToTensorD(keys = vol_keys),
                CenterPatchedImage(vol_keys, num_patches = num_patches),
            ]
        )

    elif dataset in ['NGSLUNG_crop_ValidVoting']:
        train_transforms = Compose(
            rescaled_prefix()
            + [
                CropForegroundD(vol_keys, KEYS[1]),
                ResizeWithRatioD(keys = vol_keys, image_size=patch_side),
                DivisiblePadD(KEYS[0], k = in_plane_k, mode = 'constant'),
            ]
            + flip_rotate(vol_keys)
            + [
                ToTensorD(keys = vol_keys),
                PatchedImage(vol_keys, num_patches = num_patches),
            ]
        )

        val_transforms = Compose(
            rescaled_prefix()
            + [
                CropForegroundD(vol_keys, KEYS[1]),
                ResizeWithRatioD(keys = vol_keys, image_size=patch_side),
                DivisiblePadD(KEYS[0], k = in_plane_k, mode = 'constant'),
                ToTensorD(keys = vol_keys),
                ListPatchedImage(vol_keys, num_patches = num_patches, stride = eval_stride),
            ]
        )

    elif dataset in ['NGSLUNG_crop_ValidVoting_ScaledRange']:
        print('dataset: ', dataset)
        a_min, a_max = _intensity_window(window)
        print(spatial_size[1])
        crop_divisor = [k_divisible, k_divisible, 1]

        # Multi-label targets are stored as lists and converted to tensors;
        # the rest of the pipeline is the same for both label kinds.
        convert_label = 'multi_label' in label_key

        train_transforms = Compose(
            windowed_prefix(a_min, a_max, convert_label)
            + [
                MyCropForegroundd(vol_keys, KEYS[1], k_divisible = crop_divisor),
                ResizeWithRatioD(keys = vol_keys, image_size=patch_side),
                DivisiblePadD(KEYS[0], k = in_plane_k, mode = padding),
            ]
            + flip_rotate(vol_keys)
            + [
                ToTensorD(keys = vol_keys),
                PatchedImage(vol_keys, num_patches = num_patches),
            ]
        )

        val_transforms = Compose(
            windowed_prefix(a_min, a_max, convert_label)
            + [
                MyCropForegroundd(vol_keys, KEYS[1], k_divisible = crop_divisor),
                ResizeWithRatioD(keys = vol_keys, image_size=patch_side),
                DivisiblePadD(vol_keys, k = in_plane_k, mode = padding),
                ToTensorD(keys = vol_keys),
                ListPatchedImage(vol_keys, num_patches = num_patches, stride = eval_stride),
            ]
        )

    elif dataset in ['NGSLUNG_crop_ScaledRange_AllVolume']:
        # create a 2D patched image using all patches of the volume
        a_min, a_max = _intensity_window(window)

        train_transforms = Compose(
            windowed_prefix(a_min, a_max)
            + [
                CropForegroundD(vol_keys, KEYS[1]),
                ResizeWithRatioDVariableDim(keys = KEYS[0], image_size=img_size),
            ]
            + flip_rotate(vol_keys)
            + [
                ToTensorD(keys = vol_keys),
                PatchedImageAllSlice(vol_keys),
                TensorPad(KEYS[0], image_size = img_size, pad_value = 0),
            ]
        )

        val_transforms = Compose(
            windowed_prefix(a_min, a_max)
            + [
                CropForegroundD(vol_keys, KEYS[1]),
                ResizeWithRatioDVariableDim(keys = KEYS[0], image_size=img_size),
                ToTensorD(keys = vol_keys),
                PatchedImageAllSlice(vol_keys),
                TensorPad(KEYS[0], image_size = img_size, pad_value = 0),
            ]
        )

    elif dataset in ['NGSLUNG_crop_ScaledRange_3D']:
        # The selected window is honoured here as in the other variants.
        a_min, a_max = _intensity_window(window)

        # Depth padding is only needed to batch volumes of different depth.
        pad_depth = None
        if train_batch_size != 1:
            # Legacy fallback: these values used to be read from
            # module-level variables.
            if max_depth is None:
                max_depth = globals().get('max_depth')
            if depth_kernel_size is None:
                depth_kernel_size = globals().get('depth_kernel_size')
            if max_depth is None or depth_kernel_size is None:
                raise ValueError(
                    "max_depth and depth_kernel_size are required for "
                    "'NGSLUNG_crop_ScaledRange_3D' when train_batch_size != 1."
                )
            pad_depth = math.ceil(max_depth/depth_kernel_size)*depth_kernel_size

        def volume_steps():
            steps = windowed_prefix(a_min, a_max) + [
                CropForegroundD(vol_keys, KEYS[1]),
                ResizeWithRatioD(keys = vol_keys, image_size=img_size),
                DivisiblePadD(keys = vol_keys, k = (img_size,img_size,-1), method = 'symmetric'),
            ]
            if pad_depth is not None:
                steps.append(DivisiblePadD(keys = vol_keys, k = (-1,-1,pad_depth), method = 'end'))
            return steps

        train_transforms = Compose(
            volume_steps() + flip_rotate(vol_keys) + [ToTensorD(keys = vol_keys)]
        )
        val_transforms = Compose(volume_steps() + [ToTensorD(keys = vol_keys)])

    elif dataset in ['NGSLUNG_attMap']:
        KEYS = ('patched_img','att_maps', label_key)
        map_keys = KEYS[:-1]
        train_transforms = Compose(
            [
                LoadImageD(keys = map_keys),
                AddChannelD(keys = map_keys),
                ScaleIntensityD(keys = map_keys),
                ExpandDims(keys = map_keys),
                RandSpatialCropD(keys = map_keys,roi_size=[-1,-1,1], random_size=False),
                ConcatChDim(keys = map_keys, deleteLastDimIf1 = True),
            ]
            + flip_rotate(KEYS[0])
            + [ToTensorD(keys = KEYS[0])]
        )

        val_transforms = Compose([
            LoadImageD(keys = map_keys),
            AddChannelD(keys = map_keys),
            ScaleIntensityD(keys = map_keys),
            ExpandDims(keys = map_keys),
            ConcatChDim(keys = map_keys),
            AsDepthFirstD(keys = KEYS[0]),
            ToTensorD(keys = KEYS[0]),
        ])

    elif dataset in ['NGSLUNG_patchedImage']:
        KEYS = ('patched_img', label_key)
        img_keys = KEYS[:-1]
        del_keys = ['image','mask','att_maps_meta_dict','patched_img_meta_dict']
        train_transforms = Compose(
            [
                LoadImageD(keys = img_keys),
                AddChannelD(keys = img_keys),
                ScaleIntensityD(keys = img_keys),
                ExpandDims(keys = img_keys),
                RandSpatialCropD(keys = img_keys,roi_size=[-1,-1,1], random_size=False),
                DeleteKeys(keys = img_keys, del_keys = del_keys,deleteLastDimIf1 = True),
            ]
            + flip_rotate(KEYS[0])
            + [ToTensorD(keys = KEYS[0])]
        )

        val_transforms = Compose([
            LoadImageD(keys = img_keys),
            AddChannelD(keys = img_keys),
            ScaleIntensityD(keys = img_keys),
            ExpandDims(keys = img_keys),
            DeleteKeys(keys = img_keys,del_keys = del_keys,),
            AsDepthFirstD(keys = KEYS[0]),
            ToTensorD(keys = KEYS[0]),
        ])

    else:
        raise ValueError(f"Unsupported dataset: {dataset}")

    if section == 'training':
        transforms = train_transforms
        shuffle = True
    else:
        transforms = val_transforms
        shuffle = False

    data = NGSLungDataset(root_dir = root_dir, split_path = split_path, section = section, num_fold = num_fold, transforms = transforms, execute_test = execute_test)

    print(f'Section: {section}, Shuffle: {shuffle}')

    loader = DataLoader(data,
                        batch_size = train_batch_size,
                        num_workers=0,
                        shuffle = shuffle
                        )

    return loader

# New model

In [None]:
class ViTResNet(nn.Module):
    """Two-stage model: a ``VisionTransformer`` followed by a ``CNNClassifier``.

    The transformer is run first to obtain its attention weights. An attention
    map derived from those weights is combined with the input images, and the
    combined tensor is fed to the CNN classifier, whose output is the model's
    prediction.

    Args:
        config: Transformer configuration, forwarded to ``VisionTransformer``.
        img_size: Forwarded to both sub-models.
        in_channels: Forwarded to both sub-models.
        num_classes: Forwarded to both sub-models.
        loss_weights: Forwarded to ``VisionTransformer``.
        zero_head: Forwarded to ``VisionTransformer``.
        vis: Forwarded to ``VisionTransformer``.
        embeddings_type: Forwarded to ``VisionTransformer``.
        loss_type: Forwarded to both sub-models.
        classifier_net: Forwarded to ``CNNClassifier`` as ``model_type``.
        vit_pretrained_dir: Optional path to a saved state dict for the
            transformer. When not ``None`` it is loaded into ``self.vit``.
    """

    def __init__(self, config, img_size=224, in_channels=3, num_classes=21843, loss_weights=None, zero_head=True, vis=True, embeddings_type = '2D', loss_type = 'CrossEntropy', classifier_net = 'ResNet18', vit_pretrained_dir = None):
        super(ViTResNet, self).__init__()
        self.vit = VisionTransformer(config=config, img_size=img_size, in_channels=in_channels, num_classes=num_classes, loss_weights=loss_weights, zero_head=zero_head, vis=vis, embeddings_type = embeddings_type, loss_type = loss_type)
        self.classifier = CNNClassifier(img_size=img_size, in_channels=in_channels, num_classes=num_classes, loss_type=loss_type, model_type = classifier_net)
        if vit_pretrained_dir is not None:
            # Load onto CPU first so a checkpoint saved on a GPU can be restored
            # on a CPU-only host; load_state_dict copies the values into the
            # module's existing parameters, so their device is unaffected.
            state_dict = torch.load(vit_pretrained_dir, map_location="cpu")
            self.vit.load_state_dict(state_dict)

    def forward(self, x, labels = None):
        """Run the transformer, then classify the images plus attention map.

        Args:
            x: Batch of input images as a torch tensor.
            labels: Optional labels; not used here, returned unchanged.

        Returns:
            Tuple ``(output, labels)`` where ``output`` is the first element
            returned by the CNN classifier.
        """
        # Keep the original images around: the transformer output is a separate
        # tensor. torch tensors have no ``.copy()``; ``clone()`` is the
        # tensor equivalent.
        imgs = x.clone()
        vit_out, attn_weights = self.vit(x)
        att_map = compute_att_map(attn_weights, vit_out.device, imgs.shape)
        # The original called ``np.concatenate(imgs, x)``, which passes the
        # second array as the ``axis`` argument and joins the images with the
        # transformer output instead of the attention map computed above.
        # NOTE(review): assumes ``att_map`` is a tensor on the same device as
        # ``imgs`` with matching batch/spatial dims, and that it should be
        # stacked on the channel axis -- confirm against ``compute_att_map``.
        # The classifier's ``in_channels`` must account for the extra channels.
        classifier_input = torch.cat((imgs, att_map), dim=1)
        output, _ = self.classifier(classifier_input)
        return output, labels

## Training the new model

In [None]:
# Training-split loader for the EGFR label (fold 4 of the 5-fold split file).
loader = get_SpecificLoader(
    dataset="NGSLUNG_crop_ValidVoting_ScaledRange",
    label_key="EGFR",
    img_size=224,
    num_patches=9,
    section="training",
    train_batch_size=32,
    root_dir="data",
    split_path="data/5BalancedCrossValFold_EGFR_131samples.json",
    num_fold=4,
    execute_test=True,
    inner_loop_idx=None,
)

In [None]:
import torch.optim as optim

# Two-class ViT + CNN model, trained with cross-entropy and SGD with momentum.
model = ViTResNet(CONFIGS['ViT-h12l2'], num_classes=2)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(2):

    running_loss = 0.0
    for i, data in enumerate(loader):
        # A DataLoader cannot be indexed (``loader[i]`` raises TypeError); the
        # batch is already available as ``data``.
        # NOTE(review): assumes each batch unpacks into (inputs, labels). If the
        # dataset yields dicts, select the image and label entries by key here.
        inputs, labels = data

        optimizer.zero_grad()
        # ViTResNet.forward returns a (predictions, labels) tuple; only the
        # predictions are passed to the loss.
        outputs, _ = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the mean loss over every window of 2000 mini-batches.
        running_loss += loss.item()
        if i % 2000 == 1999:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')