In [8]:

# install torchvision compatible with torch==2.2.1+cu121
# (resolves torch==2.2.1+cu121 from the PyTorch CUDA 12.1 wheel index as a dependency)
%pip install torchvision==0.17.1+cu121 --extra-index-url https://download.pytorch.org/whl/cu121

# install torchsummary
# NOTE(review): unpinned -- pin a version so fresh environments are reproducible.
%pip install torchsummary

# install captum (pinned)
# Restart the kernel after this cell so the (re)installed packages are picked up.
%pip install captum==0.7.0

Found existing installation: torch 2.2.1+cu121
Uninstalling torch-2.2.1+cu121:
  Successfully uninstalled torch-2.2.1+cu121
Found existing installation: torchvision 0.17.1+cu121
Uninstalling torchvision-0.17.1+cu121:
  Successfully uninstalled torchvision-0.17.1+cu121
Note: you may need to restart the kernel to use updated packages.
Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu121
Collecting torchvision==0.17.1+cu121
  Using cached https://download.pytorch.org/whl/cu121/torchvision-0.17.1%2Bcu121-cp312-cp312-linux_x86_64.whl (7.0 MB)
Collecting torch==2.2.1 (from torchvision==0.17.1+cu121)
  Using cached https://download.pytorch.org/whl/cu121/torch-2.2.1%2Bcu121-cp312-cp312-linux_x86_64.whl (757.2 MB)
Installing collected packages: torch, torchvision
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2/2[0m [torchvision][0m [torchvision]
[1A[2KSuccessfully installed torch-2.2.1+cu121 torchvision-0.17.1+cu121

[1m[[0m[34;49mnotice[0m

In [9]:
import pip
#python3 -m pip install --upgrade pip
import glob
import numpy as np
import sys
import os
from matplotlib import pyplot as plt
import torch

import torchvision
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF

from torchsummary import summary
import torch.optim as optim
from torchvision.ops import sigmoid_focal_loss

from captum.attr import Occlusion




  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class ChannelAttentionBlock(nn.Module):
    """CBAM-style channel attention.

    Global average- and max-pooled channel descriptors are passed through one
    shared two-layer MLP (``fc1`` -> ReLU -> ``fc2``). Their sum is squashed
    with a sigmoid and used to rescale every channel of the input.

    Args:
        channel: number of input channels ``C``.
        ratio: reduction ratio; the MLP hidden width is ``max(1, C // ratio)``.
    """

    def __init__(self, channel, ratio):
        super(ChannelAttentionBlock, self).__init__()
        # No `.to(device)` here: sub-modules follow the parent model when it is
        # moved, so construction does not depend on a global `device`.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Guard against a zero-width hidden layer when channel < ratio.
        hidden = max(1, channel // ratio)
        self.fc1 = nn.Linear(channel, hidden)
        self.fc2 = nn.Linear(hidden, channel)

    def _mlp(self, descriptor):
        """Shared MLP applied to a (batch, channel) descriptor."""
        return self.fc2(F.relu(self.fc1(descriptor)))

    def forward(self, x):
        """Return ``x`` with each channel scaled by its attention weight in (0, 1)."""
        avg_desc = self.avg_pool(x).flatten(1)
        max_desc = self.max_pool(x).flatten(1)
        scale = torch.sigmoid(self._mlp(avg_desc) + self._mlp(max_desc))
        # Broadcast the per-channel weights over the spatial dimensions.
        return x * scale[:, :, None, None]
    

In [None]:
class SpatialAttentionBlock(nn.Module):
    """CBAM-style spatial attention.

    The channel-wise mean and max maps are stacked (2 channels), convolved to a
    single map, squashed with a sigmoid, and used to rescale every spatial
    position of the input.

    Args:
        kernel_size: odd convolution kernel size (default 7). Padding is
            ``kernel_size // 2`` so the attention map keeps the input's size.

    Raises:
        ValueError: if ``kernel_size`` is even (the attention map would not
            match the input's spatial size).
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttentionBlock, self).__init__()
        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd to preserve spatial size, got {kernel_size}")
        self.kernel_size = kernel_size
        self.conv = nn.Conv2d(2, 1, kernel_size=self.kernel_size, stride=1, padding=self.kernel_size // 2, bias=False)
        nn.init.xavier_uniform_(self.conv.weight)

    def forward(self, input_feature):
        """Return ``input_feature`` scaled by a per-position attention map in (0, 1)."""
        # Intermediates stay on the input's device; no global `device` is used.
        avg_pool = torch.mean(input_feature, dim=1, keepdim=True)
        max_pool = torch.max(input_feature, dim=1, keepdim=True).values
        descriptor = torch.cat([avg_pool, max_pool], dim=1)
        attention = torch.sigmoid(self.conv(descriptor))
        return input_feature * attention
        

In [None]:
def cbam_block(i, ratio=8):
    """Build a CBAM attention module: channel attention followed by spatial attention.

    Args:
        i: number of input (and output) channels.
        ratio: reduction ratio forwarded to ``ChannelAttentionBlock``.

    Returns:
        ``nn.Sequential`` that applies ``ChannelAttentionBlock`` and then
        ``SpatialAttentionBlock``; the output has the same shape as the input.
    """
    # Both stages must be kept: CBAM refines features with channel attention
    # first and then applies spatial attention to the refined map. Assigning
    # both blocks to one variable drops the channel stage.
    # No `.to(device)`: the caller moves the parent model.
    return nn.Sequential(
        ChannelAttentionBlock(i, ratio=ratio),
        SpatialAttentionBlock(),
    )



In [7]:
class DepthwiseSeparableConv(nn.Module):
    """Depthwise-separable 2-D convolution.

    A per-channel (``groups=in_channels``) convolution of size ``kernel_size``
    followed by a 1x1 pointwise convolution that mixes channels.

    Args:
        in_channels: channels of the input.
        out_channels: channels produced by the pointwise convolution.
        kernel_size: kernel size of the depthwise convolution.
        padding: padding of the depthwise convolution (int or ``"same"``).
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super().__init__()
        # Depthwise is created first so seeded initialisation order is unchanged.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=padding,
            groups=in_channels,
        )
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        per_channel = self.depthwise(x)
        mixed = self.pointwise(per_channel)
        return mixed

In [None]:
class BM(nn.Module):
    """BM attention module (forward logic is a placeholder in this notebook).

    Holds channel attention, global-context 1x1 projections (``g``, ``theta``,
    ``phi``), a local-context depthwise-separable branch, spatial attention, a
    multi-scale branch, two learnable scalar weights, and an output projection
    ``W`` back to ``c`` channels.

    Args:
        c: number of input channels; the intermediate width is ``c // 2``.
    """

    def __init__(self, c):
        super(BM, self).__init__()
        self.c = c
        self.inter = c // 2  # intermediate width used by the attention branches

        # Channel attention
        self.channel_attention = ChannelAttentionBlock(c, ratio=8)

        # Contextual feature extraction (global context)
        self.g = nn.Conv2d(self.c, self.inter, kernel_size=1, padding="same")
        self.theta = nn.Conv2d(self.c, self.inter, kernel_size=1, padding="same")
        self.phi = nn.Conv2d(self.c, self.inter, kernel_size=1, padding="same")

        # Local context feature extraction
        self.local_context = nn.Sequential(
            DepthwiseSeparableConv(self.c, self.inter, kernel_size=3, padding="same"),
            nn.Dropout(p=0.1),
        )

        # Spatial attention mechanism
        self.spatial_attention = SpatialAttentionBlock()
        self.conv = nn.Conv2d(self.inter, 1, kernel_size=1, padding="same")

        # Multi-scale context extraction (global + local)
        self.multi_scale_conv = nn.Sequential(
            DepthwiseSeparableConv(self.inter, self.inter, kernel_size=1, padding="same"),
            nn.Dropout(p=0.3),
            DepthwiseSeparableConv(self.inter, self.inter, kernel_size=3, padding="same"),
            nn.Dropout(p=0.2),
        )

        # Adaptive weighting for global and local context (learnable scalars)
        self.global_weight = nn.Parameter(torch.ones(1))
        self.local_weight = nn.Parameter(torch.ones(1))

        # Final output transformation back to `c` channels
        self.W = nn.Sequential(
            nn.Conv2d(self.inter, self.c, kernel_size=1, padding="same"),
            nn.BatchNorm2d(self.c),
            nn.ReLU(inplace=True),
        )
        # BatchNorm weight and bias start at zero, so W initially outputs zeros.
        nn.init.constant_(self.W[1].weight, 0)
        nn.init.constant_(self.W[1].bias, 0)

    def forward(self, x):
        """Apply BM attention to ``x``.

        Raises:
            NotImplementedError: always; the BM attention logic is not
                included in this notebook.
        """
        # Raise instead of returning None, so callers fail here with a clear
        # message rather than later with an unrelated TypeError.
        raise NotImplementedError(
            "BM.forward is a placeholder: the BM attention logic is not implemented in this notebook."
        )

In [None]:
def convbatch(i, k, d, f):
    """Conv2d ("same" padding, Xavier-uniform weight) -> BatchNorm2d -> ReLU.

    Args:
        i: input channels.
        k: kernel size.
        d: dilation.
        f: output channels (filters).

    Returns:
        ``nn.Sequential`` of the three layers.
    """
    layers = [nn.Conv2d(i, f, kernel_size=k, dilation=d, padding="same")]
    nn.init.xavier_uniform_(layers[0].weight)
    layers += [nn.BatchNorm2d(f), nn.ReLU()]
    return nn.Sequential(*layers)



In [None]:
class RSconv(nn.Module):
    """RSconv layer with a learnable ``angle`` (forward logic is a placeholder).

    Args:
        i: input channels.
        k: kernel size.
        d: dilation.
        f: output channels.
        angle: initial value of the learnable ``angle`` parameter (tensor or
            number, default 0.0). It is copied, so the caller's tensor is not
            aliased.
    """

    def __init__(self, i, k, d, f, angle=0.0):
        super(RSconv, self).__init__()
        self.conv = nn.Conv2d(i, f, kernel_size=k, dilation=d, padding='same')
        nn.init.xavier_uniform_(self.conv.weight)
        # Conv/BN weights and biases are already trainable Parameters.
        self.bn = nn.BatchNorm2d(f, track_running_stats=True)
        # Initialise from the `angle` argument instead of silently ignoring it.
        self.angle = nn.Parameter(
            torch.as_tensor(angle, dtype=torch.float32).detach().clone().reshape(-1),
            requires_grad=True,
        )

    def forward(self, x):
        """Raises NotImplementedError: the RSconv logic is not included in this notebook."""
        raise NotImplementedError(
            "RSconv.forward is a placeholder: the RSconv logic is not implemented in this notebook."
        )

In [None]:
class SVconv(nn.Module):
    """SVconv layer with a learnable kernel mask (forward logic is a placeholder).

    Args:
        i: input channels.
        k: kernel size.
        d: dilation rate.
        f: output channels.

    The dense kernel covers the dilated receptive field ``rf = (k - 1) * d + 1``.
    """

    def __init__(self, i, k, d, f):
        super(SVconv, self).__init__()
        self.k = k  # kernel size
        self.d = d  # dilation rate
        self.rf = (k - 1) * d + 1  # receptive-field size
        # Learnable (rf x rf) mask for the kernel.
        self.mask = nn.Parameter(torch.rand(self.rf, self.rf))
        # Kernel W of shape (f, i, rf, rf). randn is immediately overwritten by
        # Xavier init but is kept so the RNG stream (seeded init) is unchanged.
        self.weight = nn.Parameter(torch.randn(f, i, self.rf, self.rf))
        nn.init.xavier_uniform_(self.weight)
        # BN weight/bias are trainable Parameters by default.
        self.bn = nn.BatchNorm2d(f, track_running_stats=True)

    def forward(self, x):
        """Raises NotImplementedError: the SVconv logic is not included in this notebook."""
        raise NotImplementedError(
            "SVconv.forward is a placeholder: the SVconv logic is not implemented in this notebook."
        )
        
        
        
        



         
        

    

In [None]:
class ATKPool(nn.Module):
    """Average top-k pooling over non-overlapping square windows.

    In every ``kernel_size x kernel_size`` window (stride = ``kernel_size``),
    the ``k`` largest values are averaged. ``k = 1`` is max pooling and
    ``k = kernel_size**2`` is average pooling.

    Args:
        initial_k: initial value of the parameter ``k``. In ``forward`` it is
            clamped to ``[1, kernel_size**2]`` and truncated to an int.
        kernel_size: window size and stride (default 2).

    Note:
        ``k`` is an ``nn.Parameter``, but the int conversion is not
        differentiable, so ``k`` gets no gradient from this module.
    """

    def __init__(self, initial_k=1, kernel_size=2):
        super(ATKPool, self).__init__()
        self.k = nn.Parameter(torch.tensor(float(initial_k)))
        self.kernel_size = kernel_size

    def forward(self, input):
        batch_size, channels, height, width = input.size()
        ks = self.kernel_size
        window = ks * ks
        k = int(torch.clamp(self.k.detach(), min=1, max=window))
        # unfold -> (B, C * window, L) with the channel index varying slowest,
        # so this reshape groups each channel's window values along dim 2.
        patches = F.unfold(input, kernel_size=ks, stride=ks).reshape(batch_size, channels, window, -1)
        # topk avoids fully sorting each window; values are the same as
        # sorting descending and taking the first k.
        avg_top_k = torch.topk(patches, k, dim=2).values.mean(dim=2)
        out_h = (height - ks) // ks + 1
        out_w = (width - ks) // ks + 1
        return avg_top_k.reshape(batch_size, channels, out_h, out_w)

In [None]:
class RESP(nn.Module):
    """RESP pooling module (forward logic is a placeholder in this notebook).

    Args:
        initial_threshold: constant used to initialise both the weight and the
            bias of the 1x1 ``threshold`` convolution.
        initial_k: initial value of ``self.k`` and of the inner ``ATKPool``'s ``k``.
        initial_k1: initial value of ``self.k1``.
        initial_weight: initial value of the learnable scalar ``self.weight``.
        in_channels: input channels of ``threshold`` (default 16, matching the
            16-channel feature maps ``DenseModel`` passes in).
    """

    def __init__(self, initial_threshold=0.5, initial_k=1, initial_k1=1, initial_weight=0.5, in_channels=16):
        super(RESP, self).__init__()
        self.max = nn.MaxPool2d(kernel_size=2)
        self.threshold = nn.Conv2d(in_channels, 1, kernel_size=1)
        nn.init.constant_(self.threshold.weight, initial_threshold)
        nn.init.constant_(self.threshold.bias, initial_threshold)
        self.k = nn.Parameter(torch.tensor(float(initial_k)))
        self.k1 = nn.Parameter(torch.tensor(float(initial_k1)))
        # float() so an int argument still yields a floating-point tensor;
        # nn.Parameter cannot require grad on an integer tensor.
        self.weight = nn.Parameter(torch.tensor(float(initial_weight)))
        self.atkp = ATKPool(initial_k)

    def forward(self, x):
        """Raises NotImplementedError: the RESP logic is not included in this notebook."""
        raise NotImplementedError(
            "RESP.forward is a placeholder: the RESP logic is not implemented in this notebook."
        )

              
          

         
         

        

In [None]:
class Branch1(nn.Module):
    """Densely connected branch of four 3x3 convolutions (growth rate 16).

    Blocks 1-2 use ``RSconv`` layers and blocks 3-4 use ``convbatch`` layers.
    Each layer's output is concatenated with its input (16 -> 32 -> 48 -> 64
    channels). CBAM attention is applied to the 64-channel stack, and a 1x1
    ``convbatch`` reduces it to 16 channels.

    Args:
        block: stage index in ``DenseModel``; must be 1, 2, 3 or 4.

    Raises:
        ValueError: if ``block`` is not 1, 2, 3 or 4.
    """

    def __init__(self, block):
        super(Branch1, self).__init__()
        if block not in (1, 2, 3, 4):
            raise ValueError(f"block must be 1, 2, 3 or 4, got {block!r}")
        self.block = block
        self.initialized = False
        # Both layer sets are always built, although forward() uses only one of
        # them for a given block. No `.to(device)`: the parent model is moved.
        self.convbatch1 = RSconv(i=16, k=3, d=1, f=16, angle=torch.zeros(1))
        self.convbatch2 = RSconv(i=16, k=3, d=1, f=16, angle=torch.zeros(1))
        self.convbatch3 = RSconv(i=32, k=3, d=1, f=16, angle=torch.zeros(1))
        self.convbatch4 = RSconv(i=48, k=3, d=1, f=16, angle=torch.zeros(1))
        self.convbatch5 = convbatch(i=16, k=3, d=1, f=16)
        self.convbatch6 = convbatch(i=16, k=3, d=1, f=16)
        self.convbatch7 = convbatch(i=32, k=3, d=1, f=16)
        self.convbatch8 = convbatch(i=48, k=3, d=1, f=16)
        self.convbatch = convbatch(i=64, k=1, d=1, f=16)
        self.cbam = cbam_block(i=64, ratio=8)

    def forward(self, x):
        if self.block in (1, 2):
            layers = (self.convbatch1, self.convbatch2, self.convbatch3, self.convbatch4)
        elif self.block in (3, 4):
            layers = (self.convbatch5, self.convbatch6, self.convbatch7, self.convbatch8)
        else:
            raise ValueError(f"block must be 1, 2, 3 or 4, got {self.block!r}")
        return self._dense_chain(x, layers)

    def _dense_chain(self, x, layers):
        """Each layer after the first sees the concatenation of all earlier outputs."""
        features = layers[0](x)
        for layer in layers[1:]:
            features = torch.cat([features, layer(features)], dim=1)
        return self.convbatch(self.cbam(features))

    
    


In [None]:
class Branch2(nn.Module):
    """Densely connected branch of four convolutions (growth rate 16).

    Blocks 1-2 use 5x5 ``convbatch`` layers; blocks 3-4 use ``SVconv`` layers
    (k=3, dilation 2). Each layer's output is concatenated with its input
    (16 -> 32 -> 48 -> 64 channels). CBAM attention is applied to the
    64-channel stack, and a 1x1 ``convbatch`` reduces it to 16 channels.

    Args:
        block: stage index in ``DenseModel``; must be 1, 2, 3 or 4.

    Raises:
        ValueError: if ``block`` is not 1, 2, 3 or 4.
    """

    def __init__(self, block):
        super(Branch2, self).__init__()
        if block not in (1, 2, 3, 4):
            raise ValueError(f"block must be 1, 2, 3 or 4, got {block!r}")
        self.block = block
        # Both layer sets are always built, although forward() uses only one
        # of them for a given block.
        self.convbatch1 = convbatch(i=16, k=5, d=1, f=16)
        self.convbatch2 = convbatch(i=16, k=5, d=1, f=16)
        self.convbatch3 = convbatch(i=32, k=5, d=1, f=16)
        self.convbatch4 = convbatch(i=48, k=5, d=1, f=16)
        self.convbatch5 = SVconv(i=16, k=3, d=2, f=16)
        self.convbatch6 = SVconv(i=16, k=3, d=2, f=16)
        self.convbatch7 = SVconv(i=32, k=3, d=2, f=16)
        self.convbatch8 = SVconv(i=48, k=3, d=2, f=16)
        self.convbatch = convbatch(i=64, k=1, d=1, f=16)
        self.cbam = cbam_block(i=64, ratio=8)

    def forward(self, x):
        if self.block in (1, 2):
            layers = (self.convbatch1, self.convbatch2, self.convbatch3, self.convbatch4)
        elif self.block in (3, 4):
            layers = (self.convbatch5, self.convbatch6, self.convbatch7, self.convbatch8)
        else:
            raise ValueError(f"block must be 1, 2, 3 or 4, got {self.block!r}")
        return self._dense_chain(x, layers)

    def _dense_chain(self, x, layers):
        """Each layer after the first sees the concatenation of all earlier outputs."""
        features = layers[0](x)
        for layer in layers[1:]:
            features = torch.cat([features, layer(features)], dim=1)
        return self.convbatch(self.cbam(features))
    
    



In [None]:
class DenseModel(nn.Module):
    """Overall LUNG Net architecture (the dense head is not defined yet).

    Stem: two conv-BN-ReLU layers (1 -> 16 channels). Then four stages. Each
    stage concatenates ``Branch2`` (16 ch), ``Branch1`` (16 ch) and a 1x1
    conv-BN-ReLU shortcut (1 ch) into 33 channels. It then applies attention
    (CBAM in stages 1-2, ``BM`` in stages 3-4) and reduces back to 16 channels
    with a 1x1 conv + BN + ReLU. Stages 1-2 are followed by two 2x2 max-pools
    and dropout (p=0.1, p=0.15); stages 3-4 by ``RESP`` pooling. The result is
    flattened and passed through a sigmoid.

    NOTE(review): ``self.conv``/``self.norm`` are reused by all four stages,
    ``self.cbam`` by stages 1-2 and ``self.BM`` by stages 3-4. Their weights
    (and the BatchNorm running statistics) are therefore shared across
    resolutions; confirm this is intended.
    """

    def __init__(self):
        super(DenseModel, self).__init__()

        # Not used by forward(); kept so the parameter list and state_dict
        # keys stay the same.
        self.angle = nn.Parameter(torch.zeros(1, requires_grad=True))

        # No `.to(device)` calls: sub-modules follow the model when the caller
        # does `DenseModel().to(device)`, so construction does not read a
        # global `device`. Construction order is unchanged, so seeded
        # initialisation is unchanged too.

        # Stem: 1 -> 16 channels.
        self.conv1 = self._xavier_conv(1, 16, kernel_size=5, padding=2)
        self.norm1 = nn.BatchNorm2d(16)
        self.conv2 = self._xavier_conv(16, 16, kernel_size=3, padding=1)
        self.norm2 = nn.BatchNorm2d(16)

        # 1x1 reduction 33 -> 16 channels, reused after every stage.
        self.conv = self._xavier_conv(33, 16, kernel_size=1, padding=0)
        self.norm = nn.BatchNorm2d(16)

        # Stage 1
        self.block1conv1 = self._xavier_conv(16, 1, kernel_size=1)
        self.norm_block1conv1 = nn.BatchNorm2d(1)
        self.block1branch2 = Branch2(1)
        self.block1branch1 = Branch1(1)
        # Stage 2
        self.block2conv1 = self._xavier_conv(16, 1, kernel_size=1)
        self.norm_block2conv1 = nn.BatchNorm2d(1)
        self.block2branch2 = Branch2(2)
        self.block2branch1 = Branch1(2)
        # Stage 3
        self.block3conv1 = self._xavier_conv(16, 1, kernel_size=1)
        self.norm_block3conv1 = nn.BatchNorm2d(1)
        self.block3branch2 = Branch2(3)
        self.block3branch1 = Branch1(3)
        # Stage 4
        self.block4conv1 = self._xavier_conv(16, 1, kernel_size=1)
        self.norm_block4conv1 = nn.BatchNorm2d(1)
        self.block4branch2 = Branch2(4)
        self.block4branch1 = Branch1(4)

        self.pool = nn.MaxPool2d(kernel_size=2)
        self.pooling = RESP()

        self.flatten = nn.Flatten()

        # TODO: the three dense layers are to be defined here.

        self.cbam = cbam_block(i=33, ratio=8)
        self.BM = BM(c=33)
        self.initialized = False

    @staticmethod
    def _xavier_conv(in_channels, out_channels, **kwargs):
        """Create an ``nn.Conv2d`` whose weight is Xavier-uniform initialised."""
        conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        nn.init.xavier_uniform_(conv.weight)
        return conv

    def _stage(self, x, shortcut_conv, shortcut_norm, branch2, branch1, attention):
        """Run one stage and return its 16-channel output.

        Concatenates [branch2(x), branch1(x), ReLU(BN(1x1 conv(x)))]
        (16 + 16 + 1 = 33 channels), applies ``attention``, then the shared
        ``conv``/``norm`` + ReLU.
        """
        shortcut = F.relu(shortcut_norm(shortcut_conv(x)))
        out_branch2 = branch2(x)
        out_branch1 = branch1(x)
        merged = torch.cat([out_branch2, out_branch1, shortcut], dim=1)
        merged = attention(merged)
        return F.relu(self.norm(self.conv(merged)))

    def forward(self, x):
        x = F.relu(self.norm1(self.conv1(x)))
        x = F.relu(self.norm2(self.conv2(x)))

        x = self._stage(x, self.block1conv1, self.norm_block1conv1,
                        self.block1branch2, self.block1branch1, self.cbam)
        x = self.pool(self.pool(x))
        x = F.dropout(x, p=0.1, training=self.training)

        x = self._stage(x, self.block2conv1, self.norm_block2conv1,
                        self.block2branch2, self.block2branch1, self.cbam)
        x = self.pool(self.pool(x))
        x = F.dropout(x, p=0.15, training=self.training)

        x = self._stage(x, self.block3conv1, self.norm_block3conv1,
                        self.block3branch2, self.block3branch1, self.BM)
        x = self.pooling(x)

        x = self._stage(x, self.block4conv1, self.norm_block4conv1,
                        self.block4branch2, self.block4branch1, self.BM)
        x = self.pooling(x)

        x = self.flatten(x)
        # TODO: pass x through the dense layers once they are defined.
        return torch.sigmoid(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DenseModel().to(device)
input_size = (1, 512, 512)

# Shape sanity check on a random input. eval() + no_grad() keep this dry run
# from updating BatchNorm running statistics or building an autograd graph.
# Training mode is restored afterwards (even on error) for the training cell.
dummy_input = torch.randn(1, *input_size).to(device)
model.eval()
try:
    with torch.no_grad():
        output = model(dummy_input)
finally:
    model.train()
print("Output shape:", output.shape)




          


In [None]:
# torchsummary's summary() defaults to device="cuda"; pass the model's actual
# device type so this cell also runs on CPU-only machines.
summary(model, (1, 512, 512), device=device.type)

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# TODO: read the datasets here:
#   train_dataset = ...  (training split)
#   val_dataset = ...    (validation split)
#   test_dataset = ...   (test split)

# TODO: wrap each dataset in a DataLoader:
#   train_loader, val_loader, test_loader = ...




In [None]:

# Training: checkpoint callback and training loop.
# TODO: replace this placeholder with the real checkpoint directory.
checkpoint_path = "define the checkpoint path"


class ModelCheckpoint:
    """Callback that saves model and optimizer state during training.

    Args:
        filepath: destination file passed to ``torch.save``. Missing parent
            directories are created on first save.
        monitor: name of the monitored metric (used only in log messages).
        verbose: if truthy, print a message on improvement and on every save.
        save_best_only: if True, save only when ``metric`` improves on the best
            value seen so far; if False, save on every call.
        mode: ``'max'`` if larger metric values are better, ``'min'`` if
            smaller values are better.

    Raises:
        ValueError: if ``mode`` is neither ``'max'`` nor ``'min'``.
    """

    def __init__(self, filepath, monitor='val_accuracy', verbose=0, save_best_only=True, mode='max'):
        if mode not in ('max', 'min'):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.filepath = filepath
        self.monitor = monitor
        self.verbose = verbose
        self.save_best_only = save_best_only
        self.mode = mode
        self.best_metric = None

    def _is_improvement(self, metric):
        """True on the first call, or when ``metric`` beats the best value in ``mode``."""
        if self.best_metric is None:
            return True
        if self.mode == 'max':
            return metric > self.best_metric
        return metric < self.best_metric

    def _save(self, model, optimizer, metric, improved):
        """Write model/optimizer state dicts and ``metric`` to ``self.filepath``."""
        directory = os.path.dirname(self.filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metric': metric
        }, self.filepath)
        if self.verbose:
            if improved:
                print(f"               Best model saved at: {self.filepath}")
            else:
                print(f"               Model saved at: {self.filepath}")

    def __call__(self, model, optimizer, metric):
        """Update the best metric and save a checkpoint according to ``save_best_only``."""
        improved = self._is_improvement(metric)
        if improved:
            self.best_metric = metric
            if self.verbose:
                print(f"                   Validation {self.monitor}: {metric:.4f} (improved)")
        # With save_best_only=False every call saves.
        if improved or not self.save_best_only:
            self._save(model, optimizer, metric, improved)


# Creating an instance of the model checkpoint callback
# (checkpoint_path above must be set to the actual checkpoint directory).
checkpoint_callback = ModelCheckpoint(filepath=os.path.join(checkpoint_path, 'best_model.pth'),
                                      monitor='val_accuracy',
                                      verbose=1,
                                      save_best_only=True,
                                      mode='max')


def train(model, optimizer, train_loader, val_loader, num_epochs=100):
    """Train ``model`` for ``num_epochs`` epochs (placeholder).

    Args:
        model: the network to train.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: training data loader.
        val_loader: validation data loader.
        num_epochs: number of training epochs.

    Raises:
        NotImplementedError: always; the training loop is not included in
            this notebook.
    """
    # Fail loudly rather than silently returning None, so a Run-All does not
    # continue as if the model had been trained.
    raise NotImplementedError("train() is a placeholder: the training loop is not implemented in this notebook.")


# NOTE(review): `optimizer`, `train_loader` and `val_loader` are not defined in
# this notebook (the data-loading cell has only placeholder comments), so this
# call raises NameError until they are created.
train(model, optimizer, train_loader, val_loader)

In [None]:
#testing (evaluation)
from sklearn.metrics import confusion_matrix, f1_score, roc_auc_score

#model = Load the model here


def evaluate(model, data_loader):
    """Evaluate ``model`` on ``data_loader`` (placeholder).

    Args:
        model: the trained network.
        data_loader: loader for the evaluation split.

    Raises:
        NotImplementedError: always; the evaluation code is not included in
            this notebook.
    """
    # Raise instead of returning None, so callers do not treat None as metrics.
    raise NotImplementedError("evaluate() is a placeholder: the evaluation code is not implemented in this notebook.")






    
