<a href="https://www.kaggle.com/code/sarthaksshukla/medical-mnist-using-se-resnet-99-8-accuracy?scriptVersionId=181265892" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
!pip install torch-summary

Collecting torch-summary
  Downloading torch_summary-1.4.5-py3-none-any.whl.metadata (18 kB)
Downloading torch_summary-1.4.5-py3-none-any.whl (16 kB)
Installing collected packages: torch-summary
Successfully installed torch-summary-1.4.5


In [2]:
import torch
import os
import shutil
import random
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from fastprogress import master_bar,progress_bar
from tqdm import tqdm
from torchsummary import summary

## Building the model

In [3]:
class Conv2d(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size = 3):
        super().__init__()
        self.model = nn.Sequential(*[
            nn.Conv2d(in_channels,out_channels,kernel_size = kernel_size,stride = 1,padding = (kernel_size - 1) // 2),
            nn.BatchNorm2d(num_features = out_channels),
            nn.ReLU()
        ])
    
    def forward(self,x):
        return self.model(x)

In [4]:
class Dense(nn.Module):
    def __init__(self,in_features,out_features,activation = nn.ReLU()):
        super().__init__()
        self.model = nn.Sequential(*[
            nn.Linear(in_features,out_features),
            activation
        ])
    
    def forward(self,x):
        return self.model(x)

In [51]:
class ScoreBlock(nn.Module):
    def __init__(self,channels):
        super().__init__()
        self.model = nn.Sequential(*[
            nn.AdaptiveAvgPool3d((1,1,channels)),
            nn.Flatten(),
            Dense(channels,channels),
            nn.Dropout(0.25),
            Dense(channels,channels),
            nn.Dropout(0.25),
            nn.Sigmoid()
        ])
    
    def forward(self,x):
        score = self.model(x)
        return torch.reshape(score,(*score.shape,1,1)) * x

In [52]:
class ConvStack(nn.Module):
    def __init__(self,in_channels,num_conv = 3):
        super().__init__()
        self.model = nn.Sequential(*[
            Conv2d(in_channels,in_channels // 2,kernel_size = 1),
            *[Conv2d(in_channels // 2,in_channels // 2) for _ in range(num_conv - 1)],
            Conv2d(in_channels // 2,2 * in_channels),
            nn.Dropout(0.25)
        ])
        self.out_channels = in_channels * 2
        self.add_helper = Conv2d(in_channels,self.out_channels,kernel_size = 1)
        self.score_block = ScoreBlock(self.out_channels)
    
    def forward(self,x):
        return F.relu(self.score_block(self.model(x) + self.add_helper(x)))

In [21]:
class Pool(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(*[
            nn.AvgPool2d(kernel_size = 2,stride = 2),
            nn.ReLU(),
        ])
    
    def forward(self,x):
        return self.model(x)

In [56]:
class FCN(nn.Module):
    def __init__(self,out_channels,num_classes,units = 4096):
        super().__init__()
        self.model = nn.Sequential(*[
            nn.Flatten(),
            Dense(out_channels,units),
            Dense(units,units),
            Dense(units,units),
            Dense(units,num_classes,activation = nn.Softmax(dim = -1))
        ])
    
    def forward(self,x):
        return self.model(x)

In [23]:
class Resnet(nn.Module):
    def __init__(self,image_size,num_classes,units = 4096,num_conv = 5):
        super().__init__()
        height,width = image_size
        channels = 3
        self.model = []
        while height > 1:
            self.model.extend([
                ConvStack(channels,num_conv),
                Pool()
            ])
            channels = self.model[-2].out_channels
            height //= 2
        self.model.append(FCN(channels,num_classes,units))
        self.model = nn.Sequential(*self.model)
    
    def forward(self,x):
        return self.model(x)

# Handling the data

In [24]:
image_data_path = "/kaggle/input/medical-mnist"

In [25]:
classes = os.listdir(image_data_path)
print(classes)

['AbdomenCT', 'BreastMRI', 'Hand', 'CXR', 'HeadCT', 'ChestCT']


In [26]:
model_data_parent_dir = "/kaggle/data"
train_path = os.path.join(model_data_parent_dir,"train")
valid_path = os.path.join(model_data_parent_dir,"valid")
test_path = os.path.join(model_data_parent_dir,"test")

In [27]:
def create_paths(parent_dir,classes):
    types = ['train','valid','test']
    if os.path.exists(parent_dir) is False:
        os.mkdir(parent_dir)
        
        for t in types:
            type_path = os.path.join(parent_dir,t)
            os.mkdir(type_path)
            print(f"{type_path} has been constructed")
            for target in classes:
                target_path = os.path.join(type_path,target)
                os.mkdir(target_path)
                print(f"{target_path} has been constructed")

In [28]:
create_paths(model_data_parent_dir,classes)

/kaggle/data/train has been constructed
/kaggle/data/train/AbdomenCT has been constructed
/kaggle/data/train/BreastMRI has been constructed
/kaggle/data/train/Hand has been constructed
/kaggle/data/train/CXR has been constructed
/kaggle/data/train/HeadCT has been constructed
/kaggle/data/train/ChestCT has been constructed
/kaggle/data/valid has been constructed
/kaggle/data/valid/AbdomenCT has been constructed
/kaggle/data/valid/BreastMRI has been constructed
/kaggle/data/valid/Hand has been constructed
/kaggle/data/valid/CXR has been constructed
/kaggle/data/valid/HeadCT has been constructed
/kaggle/data/valid/ChestCT has been constructed
/kaggle/data/test has been constructed
/kaggle/data/test/AbdomenCT has been constructed
/kaggle/data/test/BreastMRI has been constructed
/kaggle/data/test/Hand has been constructed
/kaggle/data/test/CXR has been constructed
/kaggle/data/test/HeadCT has been constructed
/kaggle/data/test/ChestCT has been constructed


In [29]:
def move_data(src_dir,tgt_dir,mode,size):
    for target in os.listdir(src_dir):
        target_dir = os.path.join(src_dir,target)
        target_files = random.sample(os.listdir(target_dir),size)
        dest_dir = os.path.join(tgt_dir,target)
        print(f"Moving {mode} images for class: {target}")
        
        for target_file in tqdm(target_files):
            target_file_path = os.path.join(target_dir,target_file)
            shutil.copy(target_file_path,dest_dir)

In [30]:
move_data(image_data_path,train_path,"train",6000)

Moving train images for class: AbdomenCT


100%|██████████| 6000/6000 [00:27<00:00, 222.15it/s]


Moving train images for class: BreastMRI


100%|██████████| 6000/6000 [00:25<00:00, 236.61it/s]


Moving train images for class: Hand


100%|██████████| 6000/6000 [00:26<00:00, 226.44it/s]


Moving train images for class: CXR


100%|██████████| 6000/6000 [00:27<00:00, 220.73it/s]


Moving train images for class: HeadCT


100%|██████████| 6000/6000 [00:26<00:00, 226.62it/s]


Moving train images for class: ChestCT


100%|██████████| 6000/6000 [00:25<00:00, 231.06it/s]


In [31]:
move_data(image_data_path,valid_path,"valid",2000)

Moving valid images for class: AbdomenCT


100%|██████████| 2000/2000 [00:04<00:00, 434.01it/s]


Moving valid images for class: BreastMRI


100%|██████████| 2000/2000 [00:04<00:00, 499.53it/s]


Moving valid images for class: Hand


100%|██████████| 2000/2000 [00:04<00:00, 427.07it/s]


Moving valid images for class: CXR


100%|██████████| 2000/2000 [00:04<00:00, 427.42it/s]


Moving valid images for class: HeadCT


100%|██████████| 2000/2000 [00:04<00:00, 427.86it/s]


Moving valid images for class: ChestCT


100%|██████████| 2000/2000 [00:04<00:00, 429.76it/s]


In [32]:
image_size = (96,96)
batch_size = 32
epochs = 10
steps_per_epoch = 1125
valid_epochs = 15
valid_steps_per_epoch = 200
device = "cuda" if torch.cuda.is_available() is True else "cpu"
print(device)

cuda


In [33]:
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size = image_size,antialias = True),
    transforms.Normalize(mean = [0.5],std = [0.5]),
    transforms.RandomHorizontalFlip()
])

valid_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = [0.5],std = [0.5]),
    transforms.Resize(size = image_size,antialias = True),
])

In [34]:
train_dataset = ImageFolder(root = train_path,transform = train_transform)
valid_dataset = ImageFolder(root = valid_path,transform = valid_transform)

In [35]:
train_dataloader = DataLoader(dataset = train_dataset,batch_size = batch_size,shuffle = True,
                             drop_last = True)

valid_dataloader = DataLoader(dataset = valid_dataset,batch_size = batch_size,drop_last = True)

# Trainer

In [36]:
class Trainer:
    def __init__(self,model,device,train_dataloader,valid_dataloader):
        self.model = model.to(device)
        self.device = device
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
    
    def train(self,epochs,loss_fn = None,optimizer = None):
        loss_fn = nn.CrossEntropyLoss() if loss_fn is None else loss_fn
        optimizer = torch.optim.Adam(self.model.parameters(),lr = 4.236429595039226e-05) if optimizer is None else optimizer
        
        parent_bar = master_bar(range(epochs))
        for epoch in parent_bar:
            train_batch_losses,train_batch_acc = [],[]
            train_samples = 0
            child_bar = progress_bar(self.train_dataloader,parent = parent_bar)
            self.model.train()
            
            for train_images,train_labels in child_bar:
                train_images,train_labels = train_images.to(self.device),train_labels.to(self.device)
                optimizer.zero_grad(set_to_none = True)
                probs = self.model(train_images)
                loss = loss_fn(probs,train_labels)
                train_batch_losses.append(round(loss.data.item() * train_images.shape[0],5))
                _,labels = torch.max(probs,dim = 1)
                train_samples += train_images.shape[0]
                match_tensor = (labels == train_labels).float()
                train_batch_acc.append(round(float(match_tensor.sum().item()),5))
                loss.backward()
                optimizer.step()
                parent_bar.child.comment = f"Train loss: {train_batch_losses[-1] / train_images.shape[0]}, train acc: {train_batch_acc[-1] / train_labels.shape[0]}"
            
            train_batch_losses = round(sum(train_batch_losses) / train_samples,5)
            train_batch_acc = round(sum(train_batch_acc) / train_samples,5)
            
            valid_batch_losses,valid_batch_acc = [],[]
            valid_samples = 0
            child_bar = progress_bar(self.valid_dataloader,parent = parent_bar)
            self.model.eval()
            
            with torch.no_grad():
                for valid_images,valid_labels in child_bar:
                    valid_images,valid_labels = valid_images.to(self.device),valid_labels.to(self.device)
                    probs = self.model(valid_images)
                    loss = loss_fn(probs,valid_labels)
                    valid_batch_losses.append(round(loss.item() * valid_images.shape[0],5))
                    _,labels = torch.max(probs,dim = 1)
                    valid_samples += valid_images.shape[0]
                    match_tensor = (labels == valid_labels).float()
                    valid_batch_acc.append(round(float(match_tensor.sum().item()),5))
                    parent_bar.child.comment = f"Valid loss: {valid_batch_losses[-1] / valid_images.shape[0]}, valid acc: {valid_batch_acc[-1] / valid_images.shape[0]}"
            
            valid_batch_losses = round(sum(valid_batch_losses) / valid_samples,5)
            valid_batch_acc = round(sum(valid_batch_acc) / valid_samples,5)
            
            parent_bar.write(f"Epoch: {epoch + 1} / {epochs} -> Train loss: {train_batch_losses}, train acc: {train_batch_acc}, valid loss: {valid_batch_losses}, valid acc: {valid_batch_acc}")

In [57]:
model = Resnet(image_size = image_size,num_classes = len(classes),num_conv = 3)

In [58]:
trainer = Trainer(model,device,train_dataloader,valid_dataloader)

In [59]:
trainer.train(epochs = epochs)