In [1]:
# !pip install openpyxl

In [2]:
import os
import tqdm
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
from fvcore.nn import FlopCountAnalysis
from einops.layers.torch import Rearrange
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader, Dataset

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEmbedding(nn.Module):
    def __init__(self, num_patches, projection_dim):
        super(PositionalEmbedding, self).__init__()
        self.embedding = nn.Embedding(num_patches, projection_dim)

    def forward(self, x):
        positions = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
        encoded_positions = self.embedding(positions)
        return x + encoded_positions

class MLP(nn.Module):
    def __init__(self, input_dim, hidden_units, dropout_rate):
        super(MLP, self).__init__()
        layers = []
        for units in hidden_units:
            layers.extend([
                nn.Linear(input_dim, units),
                nn.GELU(),
                nn.Dropout(dropout_rate)
            ])
            input_dim = units
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)

class TokenLearner(nn.Module):
    def __init__(self, num_tokens):
        super(TokenLearner, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=num_tokens, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=num_tokens, out_channels=num_tokens, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=num_tokens, out_channels=num_tokens, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=num_tokens, out_channels=1, kernel_size=3, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # print("TokenLearner")
        # print(x.shape)
        x = F.gelu(self.conv1(x))
        # print(x.shape)
        x = F.gelu(self.conv2(x))
        # print(x.shape)
        x = F.gelu(self.conv3(x))
        # print(x.shape)
        x = self.sigmoid(self.conv4(x))
        # print(x.shape)
        # x = x.view(x.size(0), x.size(1), -1)
        # print(x.shape)
        # x = x.permute(0, 2, 1)  # Permute to match TensorFlow's behavior
        # print(x.shape)
        # print("TokenLearner Done")
        return x

class Transformer(nn.Module):
    def __init__(self, input_dim, num_heads, mlp_units, dropout_rate):
        super(Transformer, self).__init__()
        self.multihead_attention = nn.MultiheadAttention(input_dim, num_heads, dropout=dropout_rate)
        self.mlp = MLP(input_dim, mlp_units, dropout_rate)

    def forward(self, x):
        # print("Transformer")
        # print(x.shape)
        attention_output, _ = self.multihead_attention(x, x, x)
        # print(x.shape)
        x = x + attention_output
        # print(x.shape)
        x = x + self.mlp(x)
        # print(x.shape)
        return x

class ViTClassifier(nn.Module):
    def __init__(self, image_size, patch_size, num_patches, projection_dim, num_heads, mlp_units, dropout_rate, num_classes, use_token_learner=True, token_learner_units=4):
        super(ViTClassifier, self).__init__()
        self.patch_embedding = nn.Conv2d(in_channels=1, out_channels=projection_dim, kernel_size=patch_size, stride=patch_size)
        self.positional_embedding = PositionalEmbedding(num_patches, projection_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.transformer_layers = nn.ModuleList([Transformer(projection_dim, num_heads, mlp_units, dropout_rate) for _ in range(6)])
        self.use_token_learner = use_token_learner
        self.token_learner = TokenLearner(token_learner_units) if use_token_learner else None
        self.layer_norm = nn.LayerNorm(projection_dim)
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(projection_dim, num_classes)
        self.projection_dim = projection_dim

    def forward(self, x):
        # print(x.shape)
        x = self.patch_embedding(x)
        B, _, H, W = x.shape
        # print(x.shape)
        x = x.flatten(2).transpose(1, 2)
        # print(x.shape)
        x = self.positional_embedding(x)
        # print(x.shape)
        x = self.dropout(x)
        # print(x.shape)
        
        for transformer in self.transformer_layers:
            x = transformer(x)
            # print(x.shape)
        # print(x.shape)

        if self.use_token_learner:
            x = x.permute(0, 2, 1).unsqueeze(1)  # Adjust shape for TokenLearner
            # print(x.shape)
            x = self.token_learner(x)
            # print(x.shape)
            x = x.squeeze(1).permute(0, 2, 1)  # Adjust shape back
            # print(x.shape)
        
        # print(x.shape)
        return x
        
        # x = self.layer_norm(x)
        # x = self.global_avg_pool(x).squeeze(2)
        # x = self.classifier(x)
        # return x

# # Create the ViT classifier model
# model = ViTClassifier(
#     image_size=256,
#     patch_size=8,
#     num_patches=196,
#     projection_dim=256,
#     num_heads=8,
#     mlp_units=[256, 128],
#     dropout_rate=0.1,
#     num_classes=4,
#     use_token_learner=True,
#     token_learner_units=4
# )

# # Print model summary
# print(model)

In [4]:
class ViTBlock(nn.Module):
    def __init__(self, in_channels, out_channels, patch_size, num_patches, hidden_dim, num_heads, mlp_dim, dropout=0.1):
        super(ViTBlock, self).__init__()
        self.patch_embedding = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
        self.position_embedding = nn.Parameter(torch.randn(1, num_patches, out_channels))
        self.cls_token = nn.Parameter(torch.randn(1, 1, out_channels))
        self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(out_channels, num_heads, dim_feedforward=mlp_dim, dropout=dropout, batch_first=True), num_layers=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # print("New Call")
        
        # print("Encoder 1", x.shape)
        B, C, H, W = x.shape
        x = self.patch_embedding(x)  # B, out_channels, H', W'
        # print("Encoder 2", x.shape)
        
        x = x.flatten(2).transpose(1, 2)  # B, (H' * W'), out_channels
        # print("Encoder 3", x.shape)
        
        # cls_tokens = self.cls_token.expand(B, -1, -1)
        # print("Encoder 4", cls_tokens.shape, x.shape)
        
        # x = torch.cat((cls_tokens, x), dim=1)
        # print("Encoder 5", x.shape)

        x += self.position_embedding
        # print("Encoder 6", x.shape)

        x = self.dropout(x)
        # print("Encoder 7", x.shape)
        
        x = self.transformer(x)
        # print("Encoder 8", x.shape)
        
        return x

class ViTUnet(nn.Module):
    def __init__(self, num_classes=1, in_channels=3, patch_size=16, vit_hidden_dim=256, vit_num_heads=8, vit_mlp_dim=512, dropout=0.1):
        super(ViTUnet, self).__init__()

        num_patches = (256 // patch_size) ** 2
        self.vit_block1 = ViTClassifier(image_size=256, patch_size=patch_size, num_patches=num_patches, projection_dim=vit_hidden_dim, num_heads=vit_num_heads,
    mlp_units=[256, 256], dropout_rate=dropout, num_classes=num_classes, use_token_learner=True, token_learner_units=4)
        # self.vit_block1 = ViTBlock(in_channels, vit_hidden_dim, patch_size, num_patches, vit_hidden_dim, vit_num_heads, vit_mlp_dim, dropout)
        # self.vit_block2 = ViTBlock(vit_hidden_dim, vit_hidden_dim, patch_size // 2, num_patches // 4, vit_hidden_dim, vit_num_heads, vit_mlp_dim, dropout)
        # self.vit_block3 = ViTBlock(vit_hidden_dim, vit_hidden_dim, patch_size // 2, num_patches // 16, vit_hidden_dim, vit_num_heads, vit_mlp_dim, dropout)

        self.decoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_classes, kernel_size=1)
        )

    def forward(self, x):
        # print("New Iteration")
        # print("Decoder 1", x.shape)
        x1 = self.vit_block1(x)
        # print(x1.shape)
        x1 = x1.unsqueeze(1)
        # x2 = self.vit_block2(x1.unsqueeze(1)).squeeze(1)
        # print("Decoder 3", x2.shape, x2.unsqueeze(1).shape )
        # x3 = self.vit_block3(x2.unsqueeze(1)).squeeze(1)
        # print("Decoder 4", x3.shape)
        # x = F.interpolate(x1.unsqueeze(1), scale_factor=2, mode='bilinear', align_corners=True)
        # print("Decoder 5", x.shape)
        # x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        # print("Decoder 6", x1.shape)
        x = self.decoder(x1)
        # print("Decoder 7", x.shape)
        return x

In [5]:
# models = ['efficientnet-b4', 'efficientnet-b3']

In [6]:
# benchmark = pd.DataFrame(columns=['model_name', 'epochs', 'gflops', 'dice_score', 'iou_score'])

In [7]:
EPOCH = 5

In [8]:
print(torch.cuda.is_available())
device = "cuda:0"

True


In [9]:
class CustomDataset(Dataset):
    def __init__(self, data_path, label_path):
        self.data_path = data_path
        self.label_path = label_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __len__(self):
        return len(os.listdir(self.data_path))

    def __getitem__(self, idx):
        ct = np.load(os.path.join(self.data_path, f"{idx}.npy"))
        mask = np.load(os.path.join(self.label_path, f"{idx}.npy"))

        ct = torch.Tensor(ct).to(self.device)
        mask = torch.Tensor(mask).to(self.device)

        return ct, mask

class DataGenerator:
    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.ROOT_DATA_PATH = 'Task06_Lung/Preprocessed/all/data/'
        self.ROOT_LABEL_PATH = 'Task06_Lung/Preprocessed/all/label/'
        self.TEST_DATA_PATH = 'Task06_Lung/Preprocessed/test/data/'
        self.TEST_LABEL_PATH = 'Task06_Lung/Preprocessed/test/label/'

    def train_loader(self):
        dataset = CustomDataset(self.ROOT_DATA_PATH, self.ROOT_LABEL_PATH)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

    def test_loader(self):
        dataset = CustomDataset(self.TEST_DATA_PATH, self.TEST_LABEL_PATH)
        return DataLoader(dataset, batch_size=1, shuffle=False)

In [10]:
batch_size = 8

data_generator = DataGenerator(batch_size)

train_loader = data_generator.train_loader()

test_loader = data_generator.test_loader()

In [11]:
class DiceScore(torch.nn.Module):
    """
    class to compute the Dice Loss
    """
    def __init__(self):
        super().__init__()

    def forward(self, pred, mask):

        #flatten label and prediction tensors
        pred = torch.flatten(pred)
        mask = torch.flatten(mask)

        counter = (pred * mask).sum()  # Counter
        denum = pred.sum() + mask.sum()
        dice = (2*counter)/denum

        return dice


In [12]:
def iou(groundtruth_mask, pred_mask):
    intersect = np.sum(pred_mask*groundtruth_mask)
    union = np.sum(pred_mask) + np.sum(groundtruth_mask) - intersect
    iou = np.mean(intersect/union)
    return round(iou, 3)

In [13]:
# for name in tqdm(models):

# #     model =  smp.Unet(
# #     encoder_name=name,        
# #     encoder_weights="imagenet",     
# #     in_channels=1,                  
# #     classes=1,                    
# # )
    
model = ViTUnet(num_classes=1, in_channels=1)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.BCEWithLogitsLoss()
torch.device(device=device)
model = model.to(device)

for epoch in range(EPOCH):
    model.train()
    total_loss = 0.0

    for batch_idx, (ct, mask) in enumerate(train_loader):
        ct, mask = ct.to(device), mask.to(device)

        outputs = model(ct)

        loss = criterion(outputs, mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{EPOCH}, Average Loss: {average_loss:.4f}")

preds = []
labels = []

for batch_idx, (ct, mask) in enumerate(test_loader):

    ct, mask = ct.to(device), mask.to(device)
        
    with torch.no_grad():
        pred = model(ct)
        
    pred = pred.cpu().numpy()
    mask = mask.cpu().numpy()
    pred = np.where(pred > 0.5, 1, 0)

    preds.append(pred)
    labels.append(mask)
    
preds = np.array(preds)
labels = np.array(labels)

dice_score = DiceScore()(torch.from_numpy(preds), torch.from_numpy(labels))
iou_score = iou(labels, preds)

data = torch.rand(4, 1, 256, 256).to(device)
flops = FlopCountAnalysis(model, data)
print("Dice:",dice_score)
print("\n")
print("IoU:",iou_score)
print("\n")
print("Flops:",flops.total() // 1e9)

    # new_row = {'model_name': name, 'epochs':EPOCH, 'gflops': flops, 'dice_score': dice_score, 'iou_score': iou_score}
    # benchmark = pd.concat([benchmark, pd.DataFrame([new_row])], ignore_index=True)

Epoch 1/5, Average Loss: 0.0465
Epoch 2/5, Average Loss: 0.0264
Epoch 3/5, Average Loss: 0.0264
Epoch 4/5, Average Loss: 0.0270
Epoch 5/5, Average Loss: 0.0264
Dice: tensor(0.)


IoU: 0.0




Unsupported operator aten::embedding encountered 1 time(s)
Unsupported operator aten::add encountered 13 time(s)
Unsupported operator aten::div encountered 12 time(s)
Unsupported operator aten::unflatten encountered 6 time(s)
Unsupported operator aten::mul encountered 24 time(s)
Unsupported operator aten::softmax encountered 6 time(s)
Unsupported operator aten::gelu encountered 15 time(s)
Unsupported operator aten::sigmoid encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
vit_block1.classifier, vit_block1.global_avg_pool, vit_block1.layer_norm, vit_block1.transformer_layers.0.multihead_attention.out_proj, vit_block1.transformer_layers.1.multihead_attention.out_proj, vit_block1.transformer_layers.2

Flops: 12.0


In [14]:
# benchmark

In [15]:
# file_name = 'benchmark.xlsx'
# benchmark.to_excel(file_name)