### *DS 261: AI for Medical Image Analysis : Assignment 02*
*Submitted By: Aman Pawar, Mtech (1st Year), SR NO: 22761, Department of Bioengineering* 

*Note: Please install the following libraries if not already installed*

*Necessary Installs*<br/>
```!pip install numpy scipy matplotlib scikit-learn tqdm torchmetrics``` <br/>
```!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118```<br/>

*Installs for better visualizations*<br/>
```!pip install torchsummary graphviz torchview``` *Note: on Windows you must also install the Graphviz executable.*<br/> 

### *Task-3: Utilize the predicted masks from Task-2 to classify CT scans into three distinct groups: Normal, Mild, and Severe*

In [1]:
# Doing Necessary imports
import gc
import scipy.io
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
plt.rcParams["figure.dpi"]=1200
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import torchmetrics as tm
from torchsummary import summary
from torchview import draw_graph
from torch.utils.data import Dataset, DataLoader, random_split

# Device configuration: use the GPU when one is visible to PyTorch, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Working on {device}")

Working on cuda


In [2]:
def read_data(path):
    """Load an image stack from a MATLAB ``.mat`` file.

    The entry stored under the *last* key of the dictionary returned by
    ``scipy.io.loadmat`` is taken as the image stack. Its first two axes are
    reported as the image resolution and its third axis as the image count.

    Args:
        path: Location of the ``.mat`` file.

    Returns:
        The array found under the last key of the loaded file.
    """
    contents = scipy.io.loadmat(path)
    last_key = list(contents)[-1]
    volume = contents[last_key]

    print("Reading Data...")
    resolution = volume.shape[:2]
    n_images = volume.shape[2]
    print(f"Resolution of CT image : {resolution} ")
    print(f"The number of CT images : {n_images}\n")
    return volume

def infection_ratio_cluster(data_ct, data_mask, mild_threshold=40.0):
    """Split CT slices into Normal / Mild / Severe groups by infection ratio.

    For every slice ``k`` along the third axis of ``data_mask`` the infection
    ratio (in percent) is::

        100 * count(0 < mask < 2) / count(mask > 0)

    (presumably label 1 marks infection and label 2 healthy lung -- confirm
    against the dataset description).

    Grouping rule:
        * ratio == 0                      -> Normal
        * 0 < ratio <= ``mild_threshold`` -> Mild
        * ratio > ``mild_threshold``      -> Severe

    A slice whose mask has no non-zero pixel has an undefined ratio (0 / 0).
    It is treated as ratio 0 (Normal); without this guard the NaN result
    fails both comparisons and the slice is counted as Severe.

    Args:
        data_ct: CT volume, indexed as ``data_ct[:, :, k]``.
        data_mask: Mask volume with the same layout as ``data_ct``.
        mild_threshold: Upper bound (inclusive, in percent) of the Mild group.

    Returns:
        Tuple ``(Normal_mask, Normal_ct, Mild_mask, Mild_ct, Severe_mask,
        Severe_ct)``. Each element stacks its slices along axis 0, giving
        shape ``(n, H, W)``; an empty group has shape ``(0, H, W)`` so the
        groups can always be concatenated along axis 0.
    """
    # Each bucket holds (mask slices, ct slices).
    normal, mild, severe = ([], []), ([], []), ([], [])

    for k in range(data_mask.shape[2]):
        mask_slice = data_mask[:, :, k]
        ct_slice = data_ct[:, :, k]

        foreground = mask_slice > 0
        total_pixels = foreground.sum()
        infected_pixels = (foreground & (mask_slice < 2)).sum()

        # Guard the 0 / 0 case: nothing labelled means nothing infected.
        if total_pixels == 0:
            infection_ratio = 0.0
        else:
            infection_ratio = infected_pixels / total_pixels * 100

        # Categorize the slice based on the infection ratio
        if infection_ratio == 0:
            bucket = normal
        elif infection_ratio <= mild_threshold:
            bucket = mild
        else:
            bucket = severe
        bucket[0].append(mask_slice)
        bucket[1].append(ct_slice)

    def _stack(slices, volume):
        """Stack 2-D slices along axis 0, keeping (0, H, W) for an empty group."""
        if slices:
            return np.array(slices)
        return np.empty((0,) + tuple(volume.shape[:2]), dtype=volume.dtype)

    return (
        _stack(normal[0], data_mask), _stack(normal[1], data_ct),
        _stack(mild[0], data_mask), _stack(mild[1], data_ct),
        _stack(severe[0], data_mask), _stack(severe[1], data_ct),
    )

class CustomImageDataset_seg(Dataset):
    """Dataset that pairs each image with the mask stored at the same index.

    Args:
        data: Indexable collection of images.
        masks: Indexable collection of masks, aligned with ``data``.
        transform: Optional callable applied to the image only; the mask is
            returned exactly as stored.
    """

    def __init__(self, data, masks, transform=None):
        self.data, self.masks, self.transform = data, masks, transform

    def __len__(self):
        """Number of samples, taken from the image collection."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for ``index``, transforming the image if requested."""
        sample = self.data[index]
        target = self.masks[index]

        if not self.transform:
            return sample, target
        return self.transform(sample), target

In [3]:
# U-Net architecture: four-level encoder / decoder with skip connections
class UNet(nn.Module):
    """Four-level U-Net with a sigmoid on the final 1x1 convolution.

    Channel widths double at every encoder level (``init_features`` up to
    ``16 * init_features`` in the bottleneck) and halve again on the way up.
    Each decoder level concatenates the up-sampled tensor with the encoder
    output of the same level along the channel axis, so the input height and
    width have to survive four 2x poolings and four 2x up-samplings unchanged
    (i.e. be multiples of 16).

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output map.
        init_features: Channel width of the first encoder level.
    """

    def __init__(self, in_channels=1, out_channels=1, init_features=32):
        super().__init__()

        width = init_features

        # Contracting path: double-conv block followed by 2x2 max-pooling.
        self.encoder1 = self._block(in_channels, width, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = self._block(width, width * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = self._block(width * 2, width * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = self._block(width * 4, width * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = self._block(width * 8, width * 16, name="bottleneck")

        # Expanding path: 2x transposed conv, then a block that consumes the
        # up-sampled tensor concatenated with the skip (hence twice the width).
        self.upconv4 = nn.ConvTranspose2d(width * 16, width * 8, kernel_size=2, stride=2)
        self.decoder4 = self._block(width * 8 * 2, width * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(width * 8, width * 4, kernel_size=2, stride=2)
        self.decoder3 = self._block(width * 4 * 2, width * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(width * 4, width * 2, kernel_size=2, stride=2)
        self.decoder2 = self._block(width * 2 * 2, width * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2)
        self.decoder1 = self._block(width * 2, width, name="dec1")

        # 1x1 projection onto the requested number of output channels.
        self.conv = nn.Conv2d(in_channels=width, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network and return the sigmoid of the final 1x1 convolution."""
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool1(skip1))
        skip3 = self.encoder3(self.pool2(skip2))
        skip4 = self.encoder4(self.pool3(skip3))

        out = self.bottleneck(self.pool4(skip4))

        # Deepest level first: up-sample, attach the matching skip, refine.
        stages = (
            (self.upconv4, self.decoder4, skip4),
            (self.upconv3, self.decoder3, skip3),
            (self.upconv2, self.decoder2, skip2),
            (self.upconv1, self.decoder1, skip1),
        )
        for upsample, refine, skip in stages:
            out = refine(torch.cat((upsample(out), skip), dim=1))

        return torch.sigmoid(self.conv(out))

    @staticmethod
    def _block(in_channels, features, name):
        """Build a (3x3 conv -> batch norm -> ReLU) x 2 block.

        Sub-layers are registered as ``<name>conv1``, ``<name>norm1``,
        ``<name>relu1``, ``<name>conv2``, ``<name>norm2``, ``<name>relu2``.
        """

        def conv3x3(channels_in):
            # padding=1 keeps the spatial size; bias is omitted because a
            # batch norm follows directly.
            return nn.Conv2d(
                in_channels=channels_in,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            )

        layers = OrderedDict()
        layers[name + "conv1"] = conv3x3(in_channels)
        layers[name + "norm1"] = nn.BatchNorm2d(num_features=features)
        layers[name + "relu1"] = nn.ReLU(inplace=True)
        layers[name + "conv2"] = conv3x3(features)
        layers[name + "norm2"] = nn.BatchNorm2d(num_features=features)
        layers[name + "relu2"] = nn.ReLU(inplace=True)
        return nn.Sequential(layers)

In [4]:
# Build the 3-class U-Net, wrap it for data-parallel execution and move it to `device`
model = nn.DataParallel(UNet(in_channels=1, out_channels=3)).to(device)

# Loss function (class-weighted cross entropy) and optimizer
# NOTE(review): UNet.forward already applies a sigmoid, while nn.CrossEntropyLoss
# expects unnormalized logits -- confirm this combination is intended.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.85, 2.25, 1.25]))
criterion = criterion.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [5]:
# Setting the path
# NOTE(review): absolute, machine-specific paths -- adjust them before running elsewhere.
path_ctscan = '/home/user2/AIMIA_AmanPawar/ctscan_hw1.mat'
path_mask = '/home/user2/AIMIA_AmanPawar/infmsk_hw1.mat'

# Reading the data
ctscan_data = read_data(path_ctscan)
mask_data = read_data(path_mask)

# Group the ground-truth slices by their infection ratio
Normal_mask, Normal_ct, Mild_mask, Mild_ct, Severe_mask, Severe_ct = infection_ratio_cluster(ctscan_data, mask_data)
print(f"Number of Normal CT : {len(Normal_mask)}\nNumber of Mild CT : {len(Mild_mask)}\nNumber of Severe CT : {len(Severe_mask)}")

# The grouped arrays are copies, so the raw volumes can be released here
del ctscan_data, mask_data
gc.collect()

# Preparing the dataset for segmentation

# Concatenate the CT and mask images (order: Normal, Mild, Severe)
all_images = np.concatenate((Normal_ct, Mild_ct, Severe_ct), axis=0)
all_masks = np.concatenate((Normal_mask, Mild_mask, Severe_mask), axis=0)

# Halve the mask values (labels up to 2 end up in [0, 1])
all_masks = all_masks/2.
print(f"All Images : {all_images.shape}, Masks : {all_masks.shape}")


transform = transforms.Compose([
    # Convert each image to a PyTorch tensor (ToTensor rescales to [0, 1] only for uint8 input)
    transforms.ToTensor(),
    # mean 0 / std 1 leaves the values unchanged; kept as a placeholder for real statistics
    transforms.Normalize((0.0,), (1.0,))
])

dataset = CustomImageDataset_seg(all_images, all_masks, transform=transform)

# Data Loaders
batch_size = 256
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
print(f"Total Number of Image in the DataLoader : {len(data_loader.dataset)}")
# Drop the notebook-level names only: `dataset` still references both arrays,
# so no memory is actually released by this step.
del all_images, all_masks
gc.collect();  # trailing ';' keeps the collector's return value out of the cell output

Reading Data...
Resolution of CT image : (512, 512) 
The number of CT images : 3554

Reading Data...
Resolution of CT image : (512, 512) 
The number of CT images : 3554

Number of Normal CT : 1441
Number of Mild CT : 1954
Number of Severe CT : 159
All Images : (3554, 512, 512), Masks : (3554, 512, 512)
Total Number of Image in the DataLoader : 3554


0

In [6]:
def irc(data_mask, mild_threshold=40.0):
    """Cluster masks into Normal / Mild / Severe groups by infection ratio.

    For every mask the infection ratio (in percent) is::

        100 * count(0 < mask < 2) / count(mask > 0)

    Grouping rule:
        * ratio == 0                      -> Normal
        * 0 < ratio <= ``mild_threshold`` -> Mild
        * ratio > ``mild_threshold``      -> Severe

    A mask with no non-zero pixel has an undefined ratio (0 / 0). It is
    treated as ratio 0 (Normal); without this guard the NaN result fails both
    comparisons and the mask is counted as Severe.

    Args:
        data_mask: Iterable of masks (iterating an array walks its first axis).
        mild_threshold: Upper bound (inclusive, in percent) of the Mild group.

    Returns:
        Tuple ``(Normal_mask, Mild_mask, Severe_mask)`` of NumPy arrays, each
        holding the masks assigned to that group.
    """
    normal_masks, mild_masks, severe_masks = [], [], []

    for mask in data_mask:
        foreground = mask > 0
        total_pixels = foreground.sum()
        infected_pixels = (foreground & (mask < 2)).sum()

        # Guard the 0 / 0 case: nothing labelled means nothing infected.
        if total_pixels == 0:
            infection_ratio = 0.0
        else:
            infection_ratio = infected_pixels / total_pixels * 100

        # Categorize the mask based on the infection ratio
        if infection_ratio == 0:
            normal_masks.append(mask)
        elif infection_ratio <= mild_threshold:
            mild_masks.append(mask)
        else:
            severe_masks.append(mask)

    return np.array(normal_masks), np.array(mild_masks), np.array(severe_masks)
            

In [8]:
%%time
# Loading the model:
checkpoint_path = "/home/user2/AIMIA_AmanPawar/model_weights/unet_checkpoint.pth"
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
best_valid_accuracy = checkpoint['best_valid_accuracy']

# Initialize variables
num_classes = 3
Normal_mask_count, Mild_mask_count, Severe_mask_count = 0, 0, 0
data_loader_iter = tqdm(data_loader, desc='Testing', position=0)

# Evaluating the model
model.eval()
with torch.no_grad():
    for i, (images, masks) in enumerate(data_loader_iter):
        images = images.to(device)
        masks = masks.to(device)
       
        outputs = model(images)
        predicted = torch.argmax(outputs*2., 1)

        masks = masks.cpu().numpy()
        predicted = predicted.cpu().numpy()

        Normal_mask, Mild_mask, Severe_mask = irc(predicted)
        
        Normal_mask_count += len(Normal_mask)
        Mild_mask_count += len(Mild_mask)
        Severe_mask_count += len(Severe_mask)

    print(f"Number of Normal CT : {Normal_mask_count}\nNumber of Mild CT : {Mild_mask_count}\nNumber of Severe CT : {Severe_mask_count}")

    

Testing: 100%|██████████| 14/14 [00:17<00:00,  1.27s/it]

Number of Normal CT : 1364
Number of Mild CT : 2045
Number of Severe CT : 145
CPU times: user 8min 20s, sys: 16.4 s, total: 8min 36s
Wall time: 17.8 s



