In [None]:
pip install grad-cam

In [None]:
import os

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from PIL import Image, ImageDraw

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report, confusion_matrix

import cv2
import albumentations as A

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Dataset
from torchvision.models import resnet18, ResNet18_Weights

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

## Data preparation and analysis

**Load data**

In [None]:
# Root folder of the Kaggle dataset; the train and test folders live directly beneath it.
DATASET_ROOT = (
    '/kaggle/input/brain-ct-medical-imaging-colorized-dataset/'
    'Computed Tomography (CT) of the Brain/dataset'
)
train_path = DATASET_ROOT + '/train'
test_path = DATASET_ROOT + '/test'

**The function below creates a labeled DataFrame that maps image file paths to their corresponding class labels.**

In [None]:
def get_df(path, extensions=('.jpg',)):
    """
    Build a DataFrame that maps image file paths to class labels.

    Every immediate sub-directory of ``path`` is treated as one class and its
    name is used as the label; files lying directly in ``path`` are ignored.

    Args:
        path (str): Directory whose sub-directories are the class folders.
        extensions (str | tuple[str, ...]): File suffix(es) to include.
            Matching is case-insensitive. Defaults to '.jpg' only.

    Returns:
        pd.DataFrame: Columns 'img_path' and 'target', one row per image,
        ordered by class name and then by file name. Both columns are present
        even when no image is found.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    records = []

    # os.listdir() returns entries in arbitrary order; sorting keeps the row
    # order stable so that seeded splits/samples downstream are reproducible.
    for cls in sorted(os.listdir(path)):
        class_folder = os.path.join(path, cls)

        if not os.path.isdir(class_folder):
            continue  # stray file next to the class folders

        for fname in sorted(os.listdir(class_folder)):
            if fname.lower().endswith(suffixes):
                records.append({
                    'img_path': os.path.join(class_folder, fname),
                    'target': cls,
                })

    # Explicit columns so an empty result still exposes 'img_path'/'target'.
    return pd.DataFrame(records, columns=['img_path', 'target'])

In [None]:
# Index both dataset folders into (img_path, target) DataFrames.
train = get_df(train_path)
test = get_df(test_path)

In [None]:
# The position of a name in this list is the integer class id fed to the model.
class_names = ['aneurysm', 'cancer', 'tumor']
class_to_idx = {cls_name: idx for idx, cls_name in enumerate(class_names)}
# -> {'aneurysm': 0, 'cancer': 1, 'tumor': 2}

for split_name, split_df in (('train', train), ('test', test)):
    # Series.map() silently produces NaN for any folder name that is missing
    # from class_names; fail loudly here instead of much later inside the loss.
    unknown = sorted(set(split_df['target']) - set(class_to_idx))
    if unknown:
        raise ValueError(f'Unexpected class folder(s) in {split_name} data: {unknown}')
    # split_df is the same object as train/test, so the column is added in place.
    split_df['target_encoded'] = split_df['target'].map(class_to_idx).astype('int64')

**Class distributions**

In [None]:
# Number of training images per class label.
print('Class distributions:')
print(train['target'].value_counts())

The classes are evenly distributed, which is good.

**Data visualizations**

For example, here is the first image in the training set:

In [None]:
def show_image(image_input):
    """
    Universal image display function.

    Args:
        image_input (str | torch.Tensor): Path to an image file readable by
            OpenCV, or a channel-first tensor of shape (3, H, W).

    Raises:
        ValueError: If the path cannot be read as an image.
        TypeError: If ``image_input`` is neither a str nor a torch.Tensor.
    """
    if isinstance(image_input, str):
        image = cv2.imread(image_input)
        # cv2.imread reports a missing/unreadable file by returning None,
        # which would otherwise surface as an obscure cvtColor error.
        if image is None:
            raise ValueError(f'Image not found at path: {image_input}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR, imshow expects RGB
    elif isinstance(image_input, torch.Tensor):
        image = image_input.detach().cpu().numpy()  # (3, H, W)
        image = np.transpose(image, (1, 2, 0))  # -> (H, W, 3) for imshow
    else:
        raise TypeError('Input must be a file path (str) or a torch.Tensor.')

    plt.imshow(image)
    plt.axis('off')
    plt.show()

In [None]:
# Display the first image listed in the training DataFrame.
show_image(train.img_path.iloc[0])

**Image grid by class**

In [None]:
def plot_grid_by_class(df, num_images_per_class=3, random_state=777):
    """
    Displays a grid of sample images from each class.
    Each column corresponds to a class; each row is an image from that class.

    Args:
        df (pd.DataFrame): Must contain 'img_path' and 'target' columns.
        num_images_per_class (int): Number of grid rows. A class with fewer
            images contributes what it has and its remaining cells stay blank.
        random_state (int): Seed passed to ``DataFrame.sample``.
    """
    classes = sorted(df['target'].unique())
    num_classes = len(classes)
    if num_classes == 0 or num_images_per_class < 1:
        return  # nothing to draw

    fig, axes = plt.subplots(
        num_images_per_class,
        num_classes,
        figsize=(3 * num_classes, 3 * num_images_per_class),
        squeeze=False,  # always a 2-D array, even for a single row/column
    )
    for ax in axes.ravel():
        ax.axis('off')

    for col_idx, cls in enumerate(classes):
        class_df = df[df['target'] == cls]
        # Never request more rows than the class has; sample() would raise.
        n_samples = min(num_images_per_class, len(class_df))
        sampled_paths = class_df.sample(n_samples, random_state=random_state)['img_path'].tolist()

        for row_idx, img_path in enumerate(sampled_paths):
            ax = axes[row_idx, col_idx]
            # Context manager closes the file handle once the image is drawn.
            with Image.open(img_path) as img:
                ax.imshow(img)
            ax.set_title(f'Class: {cls}')

    fig.suptitle('Image grid by class', fontsize=16)
    fig.tight_layout()
    plt.show()


In [None]:
# Show three seeded sample images per class from the training data (function defaults).
plot_grid_by_class(train)

In many examples, the images show similar artifacts (as illustrated below). Additionally, the scans are not uniform — they were taken using different machines and with varying CT parameters.

In [None]:
# Load as RGB so the 'red' outline is drawn in red whatever the file's native
# mode is; the context manager closes the file handle after the copy is made.
with Image.open(train.img_path.iloc[0]) as src:
    img = src.convert('RGB')

draw = ImageDraw.Draw(img)

# (left, top, right, bottom) boxes around the two artifact regions.
# The offsets are hand-picked for this particular image's dimensions.
artifact_boxes = [
    (390, 240, img.width - 30, img.height - 20),
    (10, 240, img.width - 460, img.height - 120),
]
for box in artifact_boxes:
    draw.rectangle(box, outline='red', width=3)

plt.title('Example of artifacts ')
plt.imshow(img)
plt.axis('off')
plt.show()

**Data preprocessing**

Train-val split

In [None]:
# Hold out part of the labelled training data for validation. Stratifying on
# the class label keeps the class proportions the same in both parts.
VAL_FRACTION = 0.15
SPLIT_SEED = 777

train_df, val_df = train_test_split(
    train,
    test_size=VAL_FRACTION,
    random_state=SPLIT_SEED,
    stratify=train['target'],
)

In [None]:
# Row counts (number of images) of each split; shape[0] is the number of rows.
print(f"Train_df shape: {train_df.shape[0]}")
print(f"Val_df shape: {val_df.shape[0]}")
print(f"test shape: {test.shape[0]}")

At this stage we have three data splits:
1. `train_df` — 85% of the original training set
2. `val_df` — 15% of the original training set
3. `test` — 144 images

**Preprocess function**
* Load img
* Convert BGR to RGB
* Resize
* Augmentation (rotate, flip, contrast change)
* Normalize pixels

Since the images are diverse in nature, augmentation will provide the model with greater variability during training. This will help during testing if the images are slightly rotated, have different contrast or brightness, or are flipped.

In [None]:
def preprocess_image(img_path, size=224, augment=False):
    """
    Load an image from disk and convert it to a normalised float tensor.

    Pipeline: read with OpenCV, BGR -> RGB, resize to ``size`` x ``size``,
    optional random augmentation, scale to [0, 1], standardise with the
    ImageNet mean/std, move channels first.

    Note: this definition shadows the ``preprocess_image`` helper imported
    from ``pytorch_grad_cam.utils.image`` at the top of the notebook.

    Args:
        img_path (str): Path of the image file.
        size (int): Side length of the square output image.
        augment (bool): Apply rotation (up to 15 degrees), random horizontal
            flip and random brightness/contrast jitter when True.

    Returns:
        torch.Tensor: float32 tensor of shape (3, size, size).

    Raises:
        ValueError: If OpenCV cannot read ``img_path``.
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        raise ValueError(f'Image not found at path: {img_path}')

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    pixels = cv2.resize(rgb, (size, size))

    if augment:
        # The pipeline is built per call; rotation always fires, the other two
        # transforms each fire with probability 0.5.
        pipeline = A.Compose([
            A.Rotate(limit=15, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
        ])
        pixels = pipeline(image=pixels)['image']

    scaled = pixels.astype(np.float32) / 255.0

    # Per-channel ImageNet statistics, broadcast over the trailing RGB axis.
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    standardised = (scaled - imagenet_mean) / imagenet_std

    # (H, W, C) -> (C, H, W), the layout PyTorch convolution layers expect.
    channels_first = standardised.transpose(2, 0, 1)

    return torch.tensor(channels_first, dtype=torch.float32)


Examples: 
1. Image before processing
2. Image after processing without augmentation
3. Image after processing with augmentation
 

In [None]:
# plt.title() sets the title on the current axes; show_image() then draws on those same axes.
plt.title('Image before processing')
show_image(train_df.img_path.iloc[0])

In [None]:
# NOTE(review): preprocess_image returns an ImageNet-standardised tensor, so its values
# can fall outside [0, 1]; imshow presumably clips them, which is why the colours look
# different from the original image.
plt.title('Image after processing without augmentation')
show_image(preprocess_image(train_df.img_path.iloc[0]))

In [None]:
# Same image with augment=True; the augmentation is random, so each run of this cell can look different.
plt.title('Image after processing with augmentation')
show_image(preprocess_image(train_df.img_path.iloc[0],augment=True))

## Model building and training

In [None]:
class CustomImageDataset(Dataset):
    """Map-style dataset that loads and preprocesses the images listed in a DataFrame."""

    def __init__(self, df, augment=False):
        """
        Args:
            df (pd.DataFrame): DataFrame with 'img_path' and 'target_encoded' columns
            augment (bool): Whether to apply augmentation when an item is loaded
        """
        self.augment = augment
        # Plain arrays so items are looked up by position, not by the frame's index labels.
        self.img_paths = df['img_path'].values
        self.labels = df['target_encoded'].values

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return ``(image_tensor, encoded_label)`` for the item at position ``idx``."""
        image = preprocess_image(self.img_paths[idx], augment=self.augment)
        return image, self.labels[idx]

In [None]:
# train_df / val_df come from the earlier train/validation split; `test` is the separate test folder.
# Only the training set is augmented, so validation and test are measured on unmodified images.
LOADER_KWARGS = dict(batch_size=64, num_workers=2, pin_memory=True)

train_dataset = CustomImageDataset(train_df, augment=True)
val_dataset = CustomImageDataset(val_df, augment=False)
test_dataset = CustomImageDataset(test, augment=False)

# Shuffle only the training batches; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **LOADER_KWARGS)
val_loader = DataLoader(val_dataset, shuffle=False, **LOADER_KWARGS)
test_loader = DataLoader(test_dataset, shuffle=False, **LOADER_KWARGS)

In [None]:
# Building the model: ResNet-18 with a Dropout (default 50%) + 3-class linear output head

class ResNet18Custom(nn.Module):
    """ResNet-18 whose final fully connected layer is replaced by Dropout + Linear."""

    def __init__(self, dropout_p=0.5, weights=ResNet18_Weights.DEFAULT, num_classes=3):
        """
        Args:
            dropout_p (float): Dropout probability applied before the classifier.
            weights: Weights passed to torchvision's ``resnet18``.
            num_classes (int): Number of output logits.
        """
        super().__init__()
        backbone = resnet18(weights=weights)
        # New head sized from the feature width of the layer it replaces.
        backbone.fc = nn.Sequential(
            nn.Dropout(dropout_p),
            nn.Linear(backbone.fc.in_features, num_classes),
        )
        self.base_model = backbone

    def forward(self, x):
        """Return the class logits produced by the wrapped network for batch ``x``."""
        return self.base_model(x)


In [None]:
# EarlyStopping class

class EarlyStopping:
    """
    Signals when the validation loss has stopped improving.

    Call the instance once per epoch with the validation loss, then check
    ``early_stop``. A loss counts as an improvement when it is at most
    ``best_loss - delta``; anything else, including NaN, counts as an epoch
    without improvement.

    Attributes set in ``__init__``: ``patience``, ``verbose``, ``delta``,
    ``best_loss`` (None until the first call), ``counter`` (consecutive
    epochs without improvement) and ``early_stop``.
    """

    def __init__(self, patience=5, verbose=True, delta=0):
        """
        Args:
            patience (int): Epochs without improvement tolerated before stopping.
            verbose (bool): Print a status line on every call.
            delta (float): Minimum decrease of the loss that counts as improvement.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        """Record this epoch's validation loss and update ``counter``/``early_stop``."""
        if self.best_loss is None:
            self.best_loss = val_loss
            if self.verbose:
                print(f'Initial val loss: {val_loss:.4f}')
            return

        # Written as a positive "<=" test so that NaN (for which every
        # comparison is False) is never mistaken for an improvement and can
        # never overwrite best_loss.
        improved = val_loss <= self.best_loss - self.delta

        if improved:
            if self.verbose:
                print(f'Validation loss improved from {self.best_loss:.4f} to {val_loss:.4f}')
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f'No improvement in val loss for {self.counter} epoch(s)')
        if self.counter >= self.patience:
            if self.verbose:
                print(f'Early stopping triggered after {self.patience} epochs without improvement.')
            self.early_stop = True


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Move the network to the device first, then wrap it. DataParallel splits each
# batch across the available GPUs (the original run used 2x T4). Later cells
# reach the wrapped network through `model.module`.
base_network = ResNet18Custom(dropout_p=0.5, weights=ResNet18_Weights.DEFAULT)
model = nn.DataParallel(base_network.to(device))

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Stop after 5 consecutive epochs without a better validation loss.
early_stopping = EarlyStopping(patience=5, verbose=True)

In [None]:
def run_epoch(model, loader, criterion, device, optimizer=None):
    """
    Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With an ``optimizer`` the model is put in train mode and updated after
    every batch; without one it is put in eval mode and gradients are disabled.

    Args:
        model: Network mapping a batch of inputs to class logits.
        loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss taking ``(outputs, targets)``.
        device: Device the batches are moved to.
        optimizer: Optimizer to step, or None for evaluation.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    n_seen = 0
    epoch_preds = []
    epoch_targets = []

    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if training:
                loss.backward()
                optimizer.step()

            # criterion returns the batch mean; weight it by the batch size so
            # a smaller final batch does not skew the epoch average.
            batch_size = inputs.size(0)
            loss_sum += loss.item() * batch_size
            n_seen += batch_size

            preds = torch.argmax(outputs, dim=1)
            epoch_preds.extend(preds.cpu().numpy())
            epoch_targets.extend(targets.cpu().numpy())

    if n_seen == 0:
        raise ValueError('The data loader produced no samples.')

    return loss_sum / n_seen, accuracy_score(epoch_targets, epoch_preds)


# Per-epoch history, plotted further down.
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

num_epochs = 10

# Snapshot of the weights with the lowest validation loss seen so far.
best_val_loss = float('inf')
best_state = None

for epoch in range(num_epochs):
    train_loss, train_acc = run_epoch(model, train_loader, criterion, device, optimizer)
    val_loss, val_acc = run_epoch(model, val_loader, criterion, device)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_acc)
    val_accuracies.append(val_acc)

    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # state_dict() returns references to the live tensors, so clone them.
        best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

    early_stopping(val_loss)
    if early_stopping.early_stop:
        print('Stopping training early')
        break

# Evaluate the best epoch rather than whichever epoch happened to run last.
if best_state is not None:
    model.load_state_dict(best_state)
model.eval()


In [None]:
# Test metrics: collect predictions over the whole test set, then score them
model.eval()
all_preds, all_targets = [], []

with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        logits = model(batch_images.to(device))
        # The predicted class is the index of the largest logit.
        all_preds.extend(logits.argmax(dim=1).cpu().numpy())
        all_targets.extend(batch_labels.cpu().numpy())

test_acc = accuracy_score(all_targets, all_preds)
print(f'Test Accuracy: {test_acc:.4f}')

In [None]:
# Per-class breakdown of the test predictions collected in the previous cell.
print(confusion_matrix(all_targets, all_preds)) # Confusion matrix
print(classification_report(all_targets, all_preds, target_names=class_names))

Wow! Zero errors!

**Loss and accuracy plots**

In [None]:
# One x position per completed epoch (training may have stopped early).
epochs = range(1, len(train_losses) + 1)

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Loss
loss_ax.plot(epochs, train_losses, label='Train Loss')
loss_ax.plot(epochs, val_losses, label='Val Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Loss over Epochs')
loss_ax.legend()

# Accuracy
acc_ax.plot(epochs, train_accuracies, label='Train Accuracy')
acc_ax.plot(epochs, val_accuracies, label='Val Accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.set_title('Accuracy over Epochs')
acc_ax.legend()

fig.tight_layout()
plt.show()

The plots show a sharp improvement in the metrics up to the 2nd epoch; after that, the metrics stay at roughly the same level.

## Analysis of results

Heatmap (Grad-CAM) for key cases

In [None]:
# Pick one validation image (seeded so a re-run explains the same image).
img_path = val_df.sample(1, random_state=777).iloc[0]['img_path']
# Prepare input tensor
input_tensor = preprocess_image(img_path).unsqueeze(0).to(device)

# Prepare original image for overlay: RGB, resized like the model input, floats in [0, 1]
img_np = cv2.imread(img_path)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
img_np = cv2.resize(img_np, (224, 224))
img_np = img_np.astype(np.float32) / 255.0

# Attach Grad-CAM to the underlying network rather than the DataParallel
# wrapper, so the hooks sit on the module that actually runs the forward pass.
cam_model = model.module

# Choose target layer (last conv block of resnet18)
target_layer = cam_model.base_model.layer4[-1]

model.eval()
with torch.no_grad():
    output = cam_model(input_tensor)
    pred_class = output.argmax(dim=1).item()

# Get heatmap; the context manager releases the hooks GradCAM registers on the layer
with GradCAM(model=cam_model, target_layers=[target_layer]) as cam:
    grayscale_cam = cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(pred_class)])
grayscale_cam = grayscale_cam[0]  # single image in the batch

# Overlay heatmap
visualization = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)

# Show result
plt.figure(figsize=(6, 6))
plt.imshow(visualization)
plt.title(f'Predicted class: {pred_class}')
plt.axis('off')
plt.show()


Visualize a 3×3 Grad-CAM grid

In [None]:
def visualize_gradcam_grid(model, df, device, random_state=None):
    """
    Display a 3x3 grid of Grad-CAM visualizations.

    For each of the three target classes (0, 1, 2) the function randomly
    selects up to three images of that class from ``df`` and overlays the
    Grad-CAM heatmap computed for that class. Row = class, column = sample.

    Args:
        model: Trained network, optionally wrapped in ``nn.DataParallel``. The
            underlying network must expose ``base_model.layer4``.
        df (pd.DataFrame): Must contain 'img_path' and 'target_encoded' columns.
        device: Device the input tensors are moved to.
        random_state: Optional seed forwarded to ``DataFrame.sample``.
    """
    model.eval()

    # Grad-CAM hooks go on the real network, not on a DataParallel wrapper.
    cam_model = model.module if isinstance(model, nn.DataParallel) else model
    target_layer = cam_model.base_model.layer4[-1]

    fig, axs = plt.subplots(3, 3, figsize=(15, 15))
    for ax in axs.ravel():
        ax.axis('off')  # cells of classes with fewer than 3 images stay blank

    # One GradCAM for the whole grid; the context manager releases its hooks
    # instead of stacking a fresh set on the layer for every image.
    with GradCAM(model=cam_model, target_layers=[target_layer]) as cam:
        for class_idx in range(3):
            class_df = df[df['target_encoded'] == class_idx]
            # Never request more samples than the class has; sample() would raise.
            samples = class_df.sample(min(3, len(class_df)), random_state=random_state)

            for i, img_path in enumerate(samples['img_path']):
                # Preprocess
                input_tensor = preprocess_image(img_path).unsqueeze(0).to(device)

                # Background for the overlay: RGB, model input size, floats in [0, 1]
                img_np = cv2.imread(img_path)
                img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
                img_np = cv2.resize(img_np, (224, 224))
                img_np = img_np.astype(np.float32) / 255.0

                # The heatmap explains the image's true class (class_idx), not the prediction.
                grayscale_cam = cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(class_idx)])
                grayscale_cam = grayscale_cam[0]

                visualization = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)

                ax = axs[class_idx, i]
                ax.imshow(visualization)
                ax.set_title(f'Class {class_idx}')

    plt.tight_layout()
    plt.show()


In [None]:
# Grad-CAM grid over randomly chosen test images, three per class.
visualize_gradcam_grid(model, test, device)

Conclusion

If the model achieves 100% accuracy on the test set, it is worth questioning its practical applicability.

As seen in the heatmap:
For class 0, the model’s attention is focused on specific artifacts typical for this class (see example above).
For classes 1 and 2, the diagnostic relevance is slightly better — the attention shifts toward the posterior part of the image and varies at times.

Suggestions for improvement:

1. Increase the size of the test set

2. Group images by patient

3. Reduce image artifacts

4. Standardize the images