In [2]:
import os
import tqdm
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
from fvcore.nn import FlopCountAnalysis
from einops.layers.torch import Rearrange
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader, Dataset

In [3]:
class ViTBlock(nn.Module):
    def __init__(self, in_channels, out_channels, patch_size, num_patches, hidden_dim, num_heads, mlp_dim, dropout=0.1):
        super(ViTBlock, self).__init__()
        self.patch_embedding = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
        self.position_embedding = nn.Parameter(torch.randn(1, num_patches, out_channels))
        self.cls_token = nn.Parameter(torch.randn(1, 1, out_channels))
        self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(out_channels, num_heads, dim_feedforward=mlp_dim, dropout=dropout, batch_first=True), num_layers=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # print("New Call")
        
        # print("Encoder 1", x.shape)
        B, C, H, W = x.shape
        x = self.patch_embedding(x)  # B, out_channels, H', W'
        # print("Encoder 2", x.shape)
        
        x = x.flatten(2).transpose(1, 2)  # B, (H' * W'), out_channels
        # print("Encoder 3", x.shape)
        
        # cls_tokens = self.cls_token.expand(B, -1, -1)
        # print("Encoder 4", cls_tokens.shape, x.shape)
        
        # x = torch.cat((cls_tokens, x), dim=1)
        # print("Encoder 5", x.shape)

        x += self.position_embedding
        # print("Encoder 6", x.shape)

        x = self.dropout(x)
        # print("Encoder 7", x.shape)
        
        x = self.transformer(x)
        # print("Encoder 8", x.shape)
        
        return x

class ViTUnet(nn.Module):
    def __init__(self, num_classes=2, in_channels=3, patch_size=16, vit_hidden_dim=256, vit_num_heads=32, vit_mlp_dim=512, dropout=0.1):
        super(ViTUnet, self).__init__()

        num_patches = (256 // patch_size) ** 2
        self.vit_block1 = ViTBlock(in_channels, vit_hidden_dim, patch_size, num_patches, vit_hidden_dim, vit_num_heads, vit_mlp_dim, dropout)
        self.vit_block2 = ViTBlock(vit_hidden_dim, vit_hidden_dim, patch_size // 2, num_patches // 4, vit_hidden_dim, vit_num_heads, vit_mlp_dim, dropout)
        self.vit_block3 = ViTBlock(vit_hidden_dim, vit_hidden_dim, patch_size // 2, num_patches // 16, vit_hidden_dim, vit_num_heads, vit_mlp_dim, dropout)

        self.decoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_classes, kernel_size=1)
        )

    def forward(self, x):
        # print("New Iteration")
        # print("Decoder 1", x.shape)
        x1 = self.vit_block1(x)
        x1 = x1.unsqueeze(1)
        # x2 = self.vit_block2(x1.unsqueeze(1)).squeeze(1)
        # print("Decoder 3", x2.shape, x2.unsqueeze(1).shape )
        # x3 = self.vit_block3(x2.unsqueeze(1)).squeeze(1)
        # print("Decoder 4", x3.shape)
        # x = F.interpolate(x1.unsqueeze(1), scale_factor=2, mode='bilinear', align_corners=True)
        # print("Decoder 5", x.shape)
        # x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        # print("Decoder 6", x.shape)
        x = self.decoder(x1)
        # print("Decoder 7", x.shape)
        return x

In [4]:
# models = ['efficientnet-b4', 'efficientnet-b3']

In [5]:
# benchmark = pd.DataFrame(columns=['model_name', 'epochs', 'gflops', 'dice_score', 'iou_score'])

In [6]:
EPOCH = 100

In [7]:
print(torch.cuda.is_available())
device = "cuda:0"

True


In [8]:
class CustomDataset(Dataset):
    def __init__(self, data_path, label_path):
        self.data_path = data_path
        self.label_path = label_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __len__(self):
        return len(os.listdir(self.data_path))

    def __getitem__(self, idx):
        ct = np.load(os.path.join(self.data_path, f"{idx}.npy"))
        mask = np.load(os.path.join(self.label_path, f"{idx}.npy"))

        ct = torch.Tensor(ct).to(self.device)
        mask = torch.Tensor(mask).to(self.device)

        return ct, mask

class DataGenerator:
    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.ROOT_DATA_PATH = 'Task06_Lung/Preprocessed/all/data/'
        self.ROOT_LABEL_PATH = 'Task06_Lung/Preprocessed/all/label/'
        self.TEST_DATA_PATH = 'Task06_Lung/Preprocessed/test/data/'
        self.TEST_LABEL_PATH = 'Task06_Lung/Preprocessed/test/label/'

    def train_loader(self):
        dataset = CustomDataset(self.ROOT_DATA_PATH, self.ROOT_LABEL_PATH)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

    def test_loader(self):
        dataset = CustomDataset(self.TEST_DATA_PATH, self.TEST_LABEL_PATH)
        return DataLoader(dataset, batch_size=1, shuffle=False)

In [9]:
batch_size = 8

data_generator = DataGenerator(batch_size)

train_loader = data_generator.train_loader()

test_loader = data_generator.test_loader()

In [10]:
class DiceScore(torch.nn.Module):
    """
    class to compute the Dice Loss
    """
    def __init__(self):
        super().__init__()

    def forward(self, pred, mask):

        #flatten label and prediction tensors
        pred = torch.flatten(pred)
        mask = torch.flatten(mask)

        counter = (pred * mask).sum()  # Counter
        denum = pred.sum() + mask.sum()
        dice = (2*counter)/denum

        return dice


In [11]:
def iou(groundtruth_mask, pred_mask):
    intersect = np.sum(pred_mask*groundtruth_mask)
    union = np.sum(pred_mask) + np.sum(groundtruth_mask) - intersect
    iou = np.mean(intersect/union)
    return round(iou, 3)

In [12]:
# for name in tqdm(models):

# #     model =  smp.Unet(
# #     encoder_name=name,        
# #     encoder_weights="imagenet",     
# #     in_channels=1,                  
# #     classes=1,                    
# # )
    
model = ViTUnet(num_classes=1, in_channels=1)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.BCEWithLogitsLoss()
torch.device(device=device)
model = model.to(device)

for epoch in range(EPOCH):
    model.train()
    total_loss = 0.0

    for batch_idx, (ct, mask) in enumerate(train_loader):
        ct, mask = ct.to(device), mask.to(device)

        outputs = model(ct)

        loss = criterion(outputs, mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{EPOCH}, Average Loss: {average_loss:.4f}")

preds = []
labels = []

for batch_idx, (ct, mask) in enumerate(test_loader):

    ct, mask = ct.to(device), mask.to(device)
        
    with torch.no_grad():
        pred = model(ct)
        
    pred = pred.cpu().numpy()
    mask = mask.cpu().numpy()
    pred = np.where(pred > 0.5, 1, 0)

    preds.append(pred)
    labels.append(mask)
    
preds = np.array(preds)
labels = np.array(labels)

dice_score = DiceScore()(torch.from_numpy(preds), torch.from_numpy(labels))
iou_score = iou(labels, preds)

data = torch.rand(1, 1, 256, 256).to(device)
flops = FlopCountAnalysis(model, data)
print(dice_score, iou_score, flops.total() // 1e-9)

    # new_row = {'model_name': name, 'epochs':EPOCH, 'gflops': flops, 'dice_score': dice_score, 'iou_score': iou_score}
    # benchmark = pd.concat([benchmark, pd.DataFrame([new_row])], ignore_index=True)

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Epoch 1/100, Average Loss: 0.1296
Epoch 2/100, Average Loss: 0.0251
Epoch 3/100, Average Loss: 0.0200
Epoch 4/100, Average Loss: 0.0175
Epoch 5/100, Average Loss: 0.0161
Epoch 6/100, Average Loss: 0.0153
Epoch 7/100, Average Loss: 0.0146
Epoch 8/100, Average Loss: 0.0139
Epoch 9/100, Average Loss: 0.0132
Epoch 10/100, Average Loss: 0.0125
Epoch 11/100, Average Loss: 0.0119
Epoch 12/100, Average Loss: 0.0113
Epoch 13/100, Average Loss: 0.0106
Epoch 14/100, Average Loss: 0.0101
Epoch 15/100, Average Loss: 0.0096
Epoch 16/100, Average Loss: 0.0090
Epoch 17/100, Average Loss: 0.0086
Epoch 18/100, Average Loss: 0.0083
Epoch 19/100, Average Loss: 0.0079
Epoch 20/100, Average Loss: 0.0075
Epoch 21/100, Average Loss: 0.0073
Epoch 22/100, Average Loss: 0.0070
Epoch 23/100, Average Loss: 0.0068
Epoch 24/100, Average Loss: 0.0066
Epoch 25/100, Average Loss: 0.0064
Epoch 26/100, Average Loss: 0.0061
Epoch 27/100, Average Loss: 0.0060
Epoch 28/100, Average Loss: 0.0058
Epoch 29/100, Average Loss: 0

Unsupported operator aten::add_ encountered 1 time(s)
Unsupported operator aten::div encountered 1 time(s)
Unsupported operator aten::unflatten encountered 1 time(s)
Unsupported operator aten::mul encountered 4 time(s)
Unsupported operator aten::scaled_dot_product_attention encountered 1 time(s)
Unsupported operator aten::add encountered 2 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
vit_block1.transformer.layers.0.self_attn.out_proj, vit_block2, vit_block2.dropout, vit_block2.patch_embedding, vit_block2.transformer, vit_block2.transformer.layers.0, vit_block2.transformer.layers.0.dropout, vit_block2.transformer.layers.0.dropout1, vit_block2.transformer.layers.0.dropout2, vit_block2.transformer.layers.0.li

tensor(0.7726) 0.629 2.609512448e+18


In [13]:
# benchmark

In [14]:
# file_name = 'benchmark.xlsx'
# benchmark.to_excel(file_name)

In [15]:
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
src = torch.rand(10, 32, 512)
out = encoder_layer(src)

In [16]:
print(src.shape, out.shape)

torch.Size([10, 32, 512]) torch.Size([10, 32, 512])


In [17]:
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
src = torch.rand(10, 32, 512)
out = transformer_encoder(src)

print(src.shape, out.shape)

torch.Size([10, 32, 512]) torch.Size([10, 32, 512])


