<a href="https://colab.research.google.com/github/TienTranTrung/PlantMedResearch/blob/master/MedImgClassification__EViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install libraries and download the dataset

In [None]:
# Import and Install Dependencies
!pip install -q numpy timm pretrainedmodels gdown==5.1.0
!gdown --id 1k4I5_GUmOcPuxdz_L027xhFZCwibAgiY
!pip install -q pytorch-lightning torchvision albumentations grad-cam
# !git clone https://github.com/TienTranTrung/ASL_Rerverse_translate.git
# !pip install -q pretrainedmodels

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for pretrainedmodels (setup.py) ... [?25l[?25hdone
Downloading...
From (original): https://drive.google.com/uc?id=1k4I5_GUmOcPuxdz_L027xhFZCwibAgiY
From (redirected): https://drive.google.com/uc?id=1k4I5_GUmOcPuxdz_L027xhFZCwibAgiY&confirm=t&uuid=d4979e55-f8de-4783-b6b6-790d6e4aab13
To: /content/traditional_medicine_dataset.zip
100% 216M/216M [00:02<00:00, 98.4MB/s]
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m802.2/802.2 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m47.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requi

In [None]:
!mkdir data
!unzip -q 'traditional_medicine_dataset' -d data

In [None]:
import torch
import torch.nn as nn
from torchvision.transforms import v2 as T
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import matplotlib.pyplot as plt
import seaborn as sns
from torch.optim import Adam
import os
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from matplotlib.colors import LogNorm

In [None]:
# Preprocessing pipelines built from torchvision.transforms.v2 (imported as T).
#
# `T.ToImage()` followed by `T.ToDtype(torch.float32, scale=True)` is the v2
# replacement for the deprecated `T.ToTensor()`: it produces a CHW float
# tensor with values scaled to [0, 1].
# The Normalize constants are the per-channel mean/std commonly used with
# ImageNet-pretrained torchvision models.

# Training pipeline: random crop/flip/rotation augmentation, then tensor
# conversion and normalisation. RandomErasing / CutMix are not enabled.
train_transform = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.RandomRotation(90),
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Evaluation pipeline: deterministic resize + centre crop, same conversion
# and normalisation as the training pipeline.
val_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])



In [None]:
# Dataset wrapper that applies a transform callable to the image of each sample
class VMedDataset(torch.utils.data.Dataset):
    """Wrap an ``(image, label)`` dataset and apply ``transform`` to every image.

    The image is handed to ``transform`` exactly as the wrapped dataset
    returns it. It is deliberately NOT converted to a NumPy array first:
    torchvision v2 geometric transforms (Resize, CenterCrop, ...) only act on
    PIL images and tensors and pass other objects through untouched, so an
    ndarray input would silently skip resizing and cropping.

    Args:
        dataset: Indexable object whose items are ``(image, label)`` pairs.
        transform: Optional callable mapping an image to a ``torch.Tensor``.

    Raises:
        TypeError: From ``__getitem__`` when the (transformed) image is not a
            ``torch.Tensor``.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        if self.transform is not None:
            image = self.transform(image)
        # Fail early with a clear message instead of a collate error later on.
        if not isinstance(image, torch.Tensor):
            raise TypeError(f"Transformed image is not a tensor, got {type(image)}")
        return image, label

data_dir = 'data/Dataset'

# Untransformed ImageFolder datasets for the 'train' and 'test' sub-folders.
image_datasets = {x: ImageFolder(os.path.join(data_dir, x)) for x in ['train', 'test']}

# 80/20 train/validation split of the 'train' folder. The generator is seeded
# so the same images land in each split on every run.
train_size = int(0.8 * len(image_datasets['train']))
val_size = len(image_datasets['train']) - train_size
train_dataset, val_dataset = random_split(
    image_datasets['train'],
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

# NOTE(review): the training split is wrapped with val_transform, so no
# augmentation is applied; train_transform is defined but not used here.
# Confirm this is intended.
train_dataset = VMedDataset(train_dataset, transform=val_transform)
val_dataset = VMedDataset(val_dataset, transform=val_transform)
test_dataset = VMedDataset(image_datasets['test'], transform=val_transform)

# Only the training loader is shuffled; evaluation order does not need to be random.
dataloaders = {
    'train': DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4),
    'val': DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4),
    'test': DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
}
dataset_sizes = {x: len(dataloaders[x].dataset) for x in ['train', 'val', 'test']}

# Sanity check: print the shapes of the first training batch.
for images, labels in dataloaders['train']:
    print(f'Image batch shape: {images.size()}')
    print(f'Label batch shape: {labels.size()}')
    break

# Class names in label-index order, as discovered by ImageFolder.
class_names = image_datasets['train'].classes

  self.pid = os.fork()
  self.pid = os.fork()


Image batch shape: torch.Size([32, 3, 128, 128])
Label batch shape: torch.Size([32])


## Custom model

In [None]:
class SimpleCNN(pl.LightningModule):
    """Small two-block convolutional baseline classifier.

    Architecture: conv(3->64) -> ReLU -> 2x2 max-pool -> conv(64->128) -> ReLU
    -> 2x2 max-pool -> adaptive average pool to 56x56 -> fc(512) -> fc(num_classes).

    Args:
        num_classes: Number of output logits.
        lr: Learning rate for the Adam optimiser.
    """

    def __init__(self, num_classes, lr=0.001):
        super().__init__()
        # Stored in the checkpoint so load_from_checkpoint can rebuild the model.
        self.save_hyperparameters()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        # fc1 expects 128 * 56 * 56 features, which only matches 224x224 inputs.
        # Pooling to a fixed 56x56 grid makes the network accept other input
        # sizes; for 224x224 inputs the map is already 56x56, so this is a no-op.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((56, 56))
        self.fc1 = nn.Linear(128 * 56 * 56, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.adaptive_pool(x)
        x = torch.flatten(x, 1)  # (batch, 128 * 56 * 56)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def _shared_step(self, batch, stage):
        """Compute cross-entropy loss, log ``{stage}_loss`` / ``{stage}_acc``, return the loss."""
        inputs, labels = batch
        outputs = self(inputs)
        loss = F.cross_entropy(outputs, labels)
        accuracy = (outputs.argmax(dim=1) == labels).float().mean()
        self.log(f'{stage}_loss', loss)
        self.log(f'{stage}_acc', accuracy)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test')

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

# Instantiate the baseline CNN with one output per class.
# NOTE: `model` is rebound to a MedPlantClassifier in a later cell, so this
# instance is not the one that gets trained when all cells are run in order.
num_classes = len(class_names)
model = SimpleCNN(num_classes)

## MobileNetV2 classifier

In [None]:
class MobileNetV2(pl.LightningModule):
    """torchvision MobileNetV2 used as a convolutional feature extractor.

    The classification head is replaced by an empty ``nn.Sequential`` and
    ``forward`` returns the output of the ``features`` stack only.

    Args:
        **kwargs: Forwarded to ``torchvision.models.mobilenet_v2``. Unless the
            caller supplies ``weights`` (or the legacy ``pretrained``), the
            ImageNet ``IMAGENET1K_V1`` weights are loaded.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # `pretrained=True` is deprecated in torchvision; IMAGENET1K_V1 is the
        # weight set that flag resolved to, so the loaded weights are unchanged.
        if 'pretrained' not in kwargs:
            kwargs.setdefault('weights', models.MobileNet_V2_Weights.IMAGENET1K_V1)
        self.model = models.mobilenet_v2(**kwargs)
        # Drop the ImageNet classification head; only the features are used.
        self.model.classifier = nn.Sequential()

    def forward(self, x):
        """Return the feature map produced by the MobileNetV2 ``features`` stack."""
        return self.model.features(x)

class LocationWiseSoftAttentionModule(pl.LightningModule):
    """Self-attention over the spatial locations of a feature map.

    Every location attends to every other location. The attended features are
    added back to the input, scaled by the learnable scalar ``gamma``, which is
    initialised to zero so the module starts out as the identity.

    Args:
        in_channels: Number of channels ``C`` of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply location-wise attention to ``x`` of shape (B, C, H, W); same shape out."""
        batch, channels, height, width = x.shape

        proj_query = self.query_conv(x).flatten(2).transpose(1, 2)  # (B, HW, C//8)
        proj_key = self.key_conv(x).flatten(2)                      # (B, C//8, HW)
        attention_scores = torch.bmm(proj_query, proj_key)          # (B, HW, HW)
        # attention_map[b, i, j]: weight of location j when updating location i.
        attention_map = torch.softmax(attention_scores, dim=-1)

        # Keep the value tensor channel-major, (B, C, HW). Viewing the
        # (B, C, H, W) conv output as (B, HW, C) would reinterpret memory
        # rather than transpose it and mix up channels with positions.
        proj_value = self.value_conv(x).flatten(2)                  # (B, C, HW)
        # weighted[b, c, i] = sum_j proj_value[b, c, j] * attention_map[b, i, j]
        weighted = torch.bmm(proj_value, attention_map.transpose(1, 2))
        weighted = weighted.reshape(batch, channels, height, width)
        return x + self.gamma * weighted

class DepthwiseSeparableConv(pl.LightningModule):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the pointwise convolution.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution.
        dilation: Dilation of the depthwise convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        # groups == in_channels: every input channel gets its own filter.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
        )
        # The 1x1 convolution mixes channels and sets the output channel count.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the depthwise convolution, then the pointwise convolution."""
        return self.pointwise(self.depthwise(x))

class MedPlantClassifier(pl.LightningModule):
    """Image classifier: MobileNetV2 features -> attention -> separable convs -> linear.

    Pipeline in ``forward``: MobileNetV2 feature map (1280 channels) ->
    location-wise soft attention -> two depthwise-separable 3x3 convolutions
    (1280 -> 256 -> 128) -> global average pooling -> linear classifier.

    Args:
        num_classes: Number of output logits.
        lr: Learning rate for the Adam optimiser.
    """

    def __init__(self, num_classes, lr=0.001):
        super().__init__()
        # Stored in the checkpoint so load_from_checkpoint can rebuild the model.
        self.save_hyperparameters()
        self.mobilenetv2 = MobileNetV2()

        self.lsam = LocationWiseSoftAttentionModule(in_channels=1280)
        # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
        self.conv1 = DepthwiseSeparableConv(1280, 256, 3, padding=1)
        self.conv2 = DepthwiseSeparableConv(256, 128, 3, padding=1)
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch."""
        x = self.mobilenetv2(x)
        x = self.lsam(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.mean([2, 3])  # global average pooling over height and width
        return self.classifier(x)

    def _shared_step(self, batch, stage):
        """Compute cross-entropy loss, log ``{stage}_loss`` / ``{stage}_acc``, return the loss."""
        inputs, labels = batch
        outputs = self(inputs)
        loss = F.cross_entropy(outputs, labels)
        accuracy = (outputs.argmax(dim=1) == labels).float().mean()
        self.log(f'{stage}_loss', loss)
        self.log(f'{stage}_acc', accuracy)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        # 'val_loss' is the metric name monitored by the checkpoint and
        # early-stopping callbacks.
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test')

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.lr)
# Instantiate the MobileNetV2-based classifier with one output per class.
# This rebinds `model`; it is the object later passed to `trainer.fit`.
num_classes = len(class_names)
model = MedPlantClassifier(num_classes)

In [None]:
# Seed Python, NumPy and PyTorch RNGs (and dataloader workers) so that batch
# shuffling and any stochastic layers behave the same on every run.
pl.seed_everything(42, workers=True)

# Keep only the checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='checkpoints',
    filename='best-checkpoint',
    save_top_k=1,
    mode='min'
)

# Stop training once val_loss has not improved for 5 consecutive validation runs.
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=5,
    verbose=True,
    mode='min'
)

# Train on a single device; 'auto' selects a GPU when one is available and
# falls back to the CPU otherwise.
trainer = pl.Trainer(
    max_epochs=20,
    accelerator='auto',
    devices=1,
    callbacks=[checkpoint_callback, early_stopping_callback]
)

trainer.fit(model, dataloaders['train'], dataloaders['val'])

In [None]:
# Load the best model for testing
best_model_path = checkpoint_callback.best_model_path
# best_model_path is empty when no checkpoint was written (e.g. training was
# not run in this session); fail with a clear message instead of a path error.
if not best_model_path:
    raise RuntimeError("No checkpoint found: run training before evaluating the model.")
trained_model = MedPlantClassifier.load_from_checkpoint(best_model_path, num_classes=num_classes)

# Test the model
trainer.test(trained_model, dataloaders['test'])

# Visualize the results
def plot_results(model, dataloader, class_names, top_n_classes=20):
    """Plot confusion matrices and print a classification report.

    Args:
        model: Classifier returning logits of shape (batch, num_classes);
            must expose a ``device`` attribute.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        class_names: Class names ordered by label index.
        top_n_classes: When smaller than ``len(class_names)``, a second
            confusion matrix restricted to the ``top_n_classes`` classes with
            the most misclassified samples is plotted.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs.to(model.device))
            all_preds.append(outputs.argmax(dim=1).cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    if not all_labels:
        raise ValueError("dataloader yielded no batches")

    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)

    # Pass the full label set explicitly so the matrix and the report always
    # line up with class_names, even when a class is absent from the data.
    label_ids = np.arange(len(class_names))
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    # Display the full confusion matrix
    fig, ax = plt.subplots(figsize=(20, 20))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(ax=ax, xticks_rotation='vertical', cmap='viridis')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.show()

    # Optionally, display the classes with the most misclassified samples
    if top_n_classes < len(class_names):
        # Misclassified samples per true class = row total minus the diagonal.
        errors_per_class = cm.sum(axis=1) - np.diag(cm)
        misclassifications = np.argsort(-errors_per_class)[:top_n_classes]
        cm_top_n = cm[misclassifications][:, misclassifications]
        class_names_top_n = [class_names[i] for i in misclassifications]

        fig, ax = plt.subplots(figsize=(10, 10))
        disp_top_n = ConfusionMatrixDisplay(confusion_matrix=cm_top_n, display_labels=class_names_top_n)
        # The colour normalisation has to go through im_kw (forwarded to
        # imshow); `plot` has no `norm` keyword of its own.
        disp_top_n.plot(ax=ax, xticks_rotation='vertical', cmap='viridis', im_kw={'norm': LogNorm()})
        plt.title(f'Top {top_n_classes} Misclassifications')
        plt.tight_layout()
        plt.show()

    # Print classification report
    print(classification_report(
        all_labels,
        all_preds,
        labels=label_ids,
        target_names=class_names,
        zero_division=0,
    ))

# Confusion matrices and per-class report for the checkpoint-restored model on the test split.
plot_results(trained_model, dataloaders['test'], class_names)

In [None]:
# Explainability and Interpretability using Grad-CAM
def apply_gradcam(model, dataloader, target_layer, class_names, max_images=None):
    """Show a Grad-CAM heat-map overlay for images from ``dataloader``.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        dataloader: Iterable yielding ``(inputs, labels)`` batches of images
            in (batch, channels, height, width) layout.
        target_layer: Layer of ``model`` whose activations are explained; it
            must output spatial feature maps.
        class_names: Class names ordered by label index.
        max_images: Optional cap on the number of figures. ``None`` (default)
            visualises every image in ``dataloader``.
    """
    model.eval()
    # Ensure the model's device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    shown = 0
    # The context manager removes the hooks Grad-CAM registers on the model.
    with GradCAM(model=model, target_layers=[target_layer]) as grad_cam:
        for inputs, labels in dataloader:
            if max_images is not None:
                remaining = max_images - shown
                if remaining <= 0:
                    break
                inputs, labels = inputs[:remaining], labels[:remaining]

            inputs = inputs.to(device)
            with torch.no_grad():
                preds = model(inputs).argmax(dim=1)

            # One batched Grad-CAM pass. Without explicit targets the
            # highest-scoring class of each image is explained.
            grayscale_cams = grad_cam(input_tensor=inputs)

            for i in range(len(inputs)):
                input_img = inputs[i].detach().cpu().numpy().transpose((1, 2, 0))
                # Rescale to [0, 1] for display; guard against constant images.
                lowest, highest = input_img.min(), input_img.max()
                if highest > lowest:
                    input_img = (input_img - lowest) / (highest - lowest)
                else:
                    input_img = np.zeros_like(input_img)

                visualization = show_cam_on_image(input_img, grayscale_cams[i], use_rgb=True)

                fig, ax = plt.subplots()
                ax.imshow(visualization)
                ax.set_title(f"Predicted: {class_names[int(preds[i])]}, Actual: {class_names[int(labels[i])]}")
                plt.show()
                plt.close(fig)  # avoid accumulating open figures
                shown += 1

# Grad-CAM needs a layer that outputs spatial feature maps. `classifier` is an
# nn.Linear applied after global average pooling, so its output has no spatial
# dimensions; use the last convolutional block (`conv2`) instead.
target_layer = trained_model.conv2
apply_gradcam(trained_model, dataloaders['test'], target_layer, class_names)