In [None]:
#libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
from sklearn.model_selection import train_test_split
import numpy as np
from transformers import ViTForImageClassification, ViTConfig
from torchmetrics import CohenKappa
from collections import Counter
import random
import pandas as pd
from tqdm import tqdm
import torch.optim as optim
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
import matplotlib.pyplot as plt
from transformers import ViTForImageClassification, ViTImageProcessor

In [None]:
# classes: list position is the integer label used throughout the notebook
classes = ['glioma tumor', 'meningioma tumor', 'pituitary tumor', 'no tumor']
num_classes = len(classes)

# Seed every RNG the notebook draws from so Restart & Run All reproduces the run:
# Python `random` (padding / oversampling), NumPy, and torch (e.g. new-layer init,
# DataLoader shuffling, torchvision random augmentations).
# NOTE: some GPU kernels may still be non-deterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# black and white image dataset
class BrainTumorDataset(Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs read from image files.

    Every file is opened with PIL and converted to single-channel grayscale
    ("L" mode); the optional ``transform`` is then applied to that PIL image.

    Args:
        file_paths: indexable sequence of image file paths.
        labels: indexable sequence of labels aligned index-by-index with ``file_paths``.
        transform: optional callable applied to the grayscale PIL image.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths, self.labels, self.transform = file_paths, labels, transform

    def __len__(self):
        # The number of samples is defined by the path list.
        return len(self.file_paths)

    def __getitem__(self, idx):
        path, target = self.file_paths[idx], self.labels[idx]
        grayscale = Image.open(path).convert("L")
        sample = self.transform(grayscale) if self.transform else grayscale
        return sample, target


def pad_data(file_paths, labels, batch_size):
    """Pad a (paths, labels) dataset to a whole multiple of ``batch_size``.

    Missing samples are drawn at random (with replacement, via Python's
    ``random`` module) from the existing samples. The same index is used for
    the path and for the label, so every duplicated path keeps its own label.

    Args:
        file_paths: list of sample paths.
        labels: list of labels aligned with ``file_paths``.
        batch_size: positive batch size to pad up to.

    Returns:
        ``(padded_paths, padded_labels)``. When no padding is needed, the input
        lists are returned unchanged (same objects). Otherwise new lists are
        returned and the inputs are not mutated.

    Raises:
        ValueError: if ``batch_size`` is not positive or the lists differ in length.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if len(file_paths) != len(labels):
        raise ValueError(
            f"file_paths ({len(file_paths)}) and labels ({len(labels)}) must have the same length")

    remainder = len(file_paths) % batch_size
    if remainder == 0:
        return file_paths, labels

    num_to_add = batch_size - remainder
    # Draw indices ONCE and reuse them for both lists. Drawing paths and labels
    # independently would pair duplicated images with unrelated labels.
    extra_indices = random.choices(range(len(file_paths)), k=num_to_add)
    padded_paths = file_paths + [file_paths[i] for i in extra_indices]
    padded_labels = labels + [labels[i] for i in extra_indices]
    return padded_paths, padded_labels


# calculate the mean and standard deviation
def calculate_mean_std_grayscale(dataset_paths):
    """Compute the pooled pixel mean and standard deviation of grayscale images.

    Each image is converted to 8-bit grayscale ("L") and scaled to [0, 1].
    The statistics are pooled over all pixels of all images, so larger images
    contribute proportionally more. Images are read at their native size,
    not resized to the model's input resolution.

    Args:
        dataset_paths: iterable of image file paths.

    Returns:
        ``([mean], [std])``: single-element lists, the form
        ``transforms.Normalize`` takes for a one-channel tensor.

    Raises:
        ValueError: if no pixels were read (e.g. ``dataset_paths`` is empty).
    """
    pixel_sum = 0.0
    pixel_squared_sum = 0.0
    total_pixels = 0

    for img_path in dataset_paths:
        # The context manager releases the file handle as soon as the pixels are read.
        with Image.open(img_path) as img:
            pixels = np.asarray(img.convert("L"), dtype=np.float64) / 255.0
        pixel_sum += pixels.sum()
        pixel_squared_sum += np.square(pixels).sum()
        total_pixels += pixels.size

    if total_pixels == 0:
        raise ValueError("calculate_mean_std_grayscale needs at least one non-empty image")

    mean = pixel_sum / total_pixels
    # E[x^2] - E[x]^2 can dip marginally below zero through float round-off;
    # clamp it so sqrt never returns NaN.
    variance = max(pixel_squared_sum / total_pixels - mean ** 2, 0.0)
    std = np.sqrt(variance)

    return [mean], [std]

In [None]:
# dataset paths: <root>/<class name>/<image file>
def _list_class_images(root_dir, class_names, skip_missing):
    """Collect ``(file path, class index)`` pairs from one sub-directory per class.

    Files are listed in sorted order because ``os.listdir`` order depends on
    the filesystem. Sorting keeps the seeded ``train_test_split`` reproducible
    across machines. Hidden files (e.g. ``.DS_Store``) and sub-directories are
    skipped, so only regular files reach ``Image.open``.

    Args:
        root_dir: directory containing one folder per class.
        class_names: folder names; the list position is the integer label.
        skip_missing: if True, a missing class folder is reported and skipped;
            if False, ``FileNotFoundError`` is raised.

    Returns:
        ``(paths, labels)``: two lists of equal length.
    """
    paths, path_labels = [], []
    for class_index, class_name in enumerate(class_names):
        class_dir = os.path.join(root_dir, class_name)
        if not os.path.isdir(class_dir):
            if skip_missing:
                print(f"Warning: class directory not found, skipping: {class_dir}")
                continue
            raise FileNotFoundError(f"Class directory not found: {class_dir}")
        for file_name in sorted(os.listdir(class_dir)):
            full_path = os.path.join(class_dir, file_name)
            if file_name.startswith('.') or not os.path.isfile(full_path):
                continue
            paths.append(full_path)
            path_labels.append(class_index)
    return paths, path_labels


data_dir = "./Training"
# A missing training class folder is reported and skipped.
file_paths, labels = _list_class_images(data_dir, classes, skip_missing=True)

test_dir = "./Testing"
# A missing *test* class would silently distort the reported metrics, so fail loudly.
test_file_paths, test_labels = _list_class_images(test_dir, classes, skip_missing=False)

# data splitting: stratified so the validation split keeps the training class proportions
train_paths, val_paths, train_labels, val_labels = train_test_split(
    file_paths, labels, test_size=0.2, random_state=42, stratify=labels)

# mean and standard deviation, computed on the training split only (no val/test leakage)
mean, std = calculate_mean_std_grayscale(train_paths)
print("Mean:", mean, "Std:", std)

In [None]:
# augmentation
def _resize_and_normalize():
    """Return fresh instances of the deterministic steps shared by both pipelines.

    The steps are: resize to 224x224, convert to a tensor, then standardize
    with the training-split grayscale ``mean``/``std``.
    """
    return [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]


# Evaluation pipeline: deterministic preprocessing only.
transform_grayscale = transforms.Compose(_resize_and_normalize())

# Training pipeline: random horizontal flip and small rotation, then the shared steps.
augmentation_transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(p=0.5), transforms.RandomRotation(15)]
    + _resize_and_normalize()
)

In [None]:
# Class distribution before oversampling
class_counts_before = Counter(train_labels)
for class_label, count in class_counts_before.items():
    print(f"class {class_label}: {count}")

# Balance the training set by oversampling each minority class (with replacement)
# up to the majority-class count. Only file paths are duplicated here. The random
# flip/rotation is applied on the fly by `augmentation_transform` when the training
# dataset loads each image, so duplicates still yield different augmented views.
# (Loading and transforming images in this cell would be wasted work: the result
# would never be stored.)
max_count = max(class_counts_before.values())

augmented_paths = []
augmented_labels = []

for class_label, count in class_counts_before.items():
    num_to_generate = max_count - count
    if num_to_generate <= 0:
        continue
    class_paths = [path for path, lbl in zip(train_paths, train_labels) if lbl == class_label]
    augmented_paths.extend(random.choices(class_paths, k=num_to_generate))
    augmented_labels.extend([class_label] * num_to_generate)

balanced_counts = Counter(train_labels + augmented_labels)
assert len(set(balanced_counts.values())) == 1, f"training set is not balanced: {balanced_counts}"

In [None]:
# concatenating raw and oversampled training samples
balanced_train_paths = train_paths + augmented_paths
balanced_train_labels = train_labels + augmented_labels

BATCH_SIZE = 32


def _make_loader(paths, path_labels, transform, shuffle):
    """Wrap paths/labels in a BrainTumorDataset and return ``(dataset, loader)``."""
    dataset = BrainTumorDataset(paths, path_labels, transform=transform)
    return dataset, DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle)


# Validation and test sets: deterministic preprocessing, fixed order.
val_dataset_grayscale, val_loader_grayscale = _make_loader(
    val_paths, val_labels, transform_grayscale, shuffle=False)
test_dataset_grayscale, test_loader_grayscale = _make_loader(
    test_file_paths, test_labels, transform_grayscale, shuffle=False)

# Training set: padded to whole batches, with random augmentation applied during training.
padded_train_paths, padded_train_labels = pad_data(
    balanced_train_paths, balanced_train_labels, batch_size=BATCH_SIZE)
balanced_train_dataset_grayscale, balanced_train_loader_grayscale = _make_loader(
    padded_train_paths, padded_train_labels, augmentation_transform, shuffle=True)

In [None]:
# Per-epoch metric history; the training loop appends to the train_* and val_* lists.
# (test_losses / test_accuracies are not filled by any cell in this notebook.)
train_losses, val_losses, test_losses = [], [], []
train_accuracies, val_accuracies, test_accuracies = [], [], []

In [None]:
# load the model: pretrained ViT-Base/16 with a classification head resized to num_classes
model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(model_name, num_labels=num_classes, ignore_mismatched_sizes=True)
processor = ViTImageProcessor.from_pretrained(model_name)  # not used by later cells

# Replace the 3-channel patch-embedding convolution with a 1-channel one so the
# model accepts the grayscale tensors produced by the transforms above.
original_conv = model.vit.embeddings.patch_embeddings.projection
new_conv = nn.Conv2d(
    1,
    original_conv.out_channels,
    kernel_size=original_conv.kernel_size,
    stride=original_conv.stride,
    padding=original_conv.padding,
    bias=original_conv.bias is not None,  # keep a bias if the pretrained conv has one
)

with torch.no_grad():
    # Summing the RGB filters makes the new conv output exactly what the pretrained
    # conv produces for the same grayscale tensor replicated on 3 channels. Averaging
    # would shrink the patch embeddings by 1/3 relative to the pretrained weights.
    new_conv.weight.copy_(original_conv.weight.sum(dim=1, keepdim=True))
    if original_conv.bias is not None:
        # Carry over the pretrained bias instead of silently dropping it.
        new_conv.bias.copy_(original_conv.bias)

model.vit.embeddings.patch_embeddings.projection = new_conv
model.config.image_size = 224
model.config.num_channels = 1
# Keep the patch-embedding module's own channel count in sync with the new conv.
# Presumably HF's ViTPatchEmbeddings checks it against the input; verify on upgrade.
model.vit.embeddings.patch_embeddings.num_channels = 1

# training configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the model first so the optimizer is built over parameters on their final device.
model = model.to(device)
learning_rate = 1e-4
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# early-stopping state, read and updated by the training loop
best_val_loss = float('inf')
best_epoch = 0
num_epochs = 100
patience = 5
patience_counter = 0

assert model.vit.embeddings.patch_embeddings.projection.in_channels == 1
print(model.vit.embeddings.patch_embeddings.projection)
print(model.config.num_channels, model.vit.config.num_channels,
      model.vit.embeddings.patch_embeddings.num_channels)
print(model)

In [None]:
# evaluate the model on the validation data
def eval_model(model, data_loader, criterion, device=None):
    """Compute the sample-weighted mean loss and the accuracy of ``model`` on a loader.

    Puts the model in eval mode and runs without gradients.

    Args:
        model: classifier whose forward output exposes ``.logits``.
        data_loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss function, assumed to use ``reduction='mean'`` (the
            ``nn.CrossEntropyLoss`` default). The batch loss is re-weighted by
            the batch size.
        device: device to move batches to. Defaults to the device of the
            model's first parameter, so CPU-only machines work out of the box.

    Returns:
        ``(average_loss, accuracy_percent)`` as Python floats. The loss is
        averaged over samples, not batches, so a short final batch is not
        over-weighted.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for data, targets in data_loader:
            x_b, y_b = data.to(device), targets.to(device)
            logits = model(x_b).logits
            batch_size = y_b.size(0)
            loss_sum += criterion(logits, y_b).item() * batch_size
            # softmax is monotonic, so taking argmax of the raw logits gives the same class.
            pred = logits.argmax(dim=1)
            correct += (pred == y_b).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError("eval_model: data_loader yielded no samples")
    return loss_sum / total, 100.0 * correct / total

# evaluate the model on the test data
def test_model(model, data_loader, device, checkpoint_path='./best_epoch.pth'):
    """Reload the best checkpoint into ``model`` and measure its accuracy on a loader.

    Args:
        model: classifier whose forward output exposes ``.logits``. Its weights
            are overwritten in place by the checkpoint.
        data_loader: iterable of ``(images, labels)`` batches.
        device: device to run on. It is also used as ``map_location``, so a
            checkpoint saved on GPU can be restored on a CPU-only machine.
        checkpoint_path: state_dict file written by the training loop. Pass
            ``None`` to evaluate the model's current weights instead.

    Returns:
        ``(accuracy_percent, y_labels, y_preds)``. ``y_labels`` and ``y_preds``
        are lists of numpy scalars in loader order.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    if checkpoint_path is not None:
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    model.eval()

    correct, total = 0, 0
    y_labels, y_preds = [], []
    # Iterating the tqdm wrapper advances the bar automatically.
    test_p_bar = tqdm(data_loader)
    with torch.no_grad():
        for images, labels in test_p_bar:
            images, labels = images.to(device), labels.to(device)
            _, predicted = torch.max(model(images).logits, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            test_p_bar.set_postfix_str(f'accuracy = {100 * correct / total:.4f}%')
            y_labels.extend(labels.cpu().numpy())
            y_preds.extend(predicted.cpu().numpy())
    test_p_bar.close()

    if total == 0:
        raise ValueError("test_model: data_loader yielded no samples")
    return 100 * correct / total, y_labels, y_preds

In [None]:
%%time
# train the model
# Each epoch runs one optimization pass over the (balanced, padded) training loader,
# then validation. The weights with the lowest validation loss so far are saved to
# best_epoch.pth. Training stops after `patience` epochs without improvement.
for epoch in range(1, num_epochs + 1):
    p_bar = tqdm(balanced_train_loader_grayscale)
    model.train()
    running_loss = 0
    n_seen, n_right = 0, 0
    acc_train = 0.0

    for batch_images, batch_labels in balanced_train_loader_grayscale:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_logits = model(batch_images).logits
        batch_loss = criterion(batch_logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        batch_pred = torch.softmax(batch_logits, dim=1).argmax(dim=1)
        n_right += (batch_pred == batch_labels).sum().item()
        n_seen += len(batch_labels)
        acc_train = 100 * n_right / n_seen

        p_bar.set_postfix_str(f'loss={batch_loss.item():.4f}, acc={acc_train:.4f}%')
        p_bar.update()

    average_loss = running_loss / len(balanced_train_loader_grayscale)
    train_losses.append(average_loss)
    train_accuracies.append(acc_train)

    val_loss, val_acc = eval_model(model, val_loader_grayscale, criterion, device)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    p_bar.close()

    print(f"Epoch {epoch}")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}%")
    current_lr = optimizer.param_groups[0]['lr']
    print(f'Current learning rate at epoch {epoch}: {current_lr}')

    improved = val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss
        best_epoch = epoch
        torch.save(model.state_dict(), 'best_epoch.pth')
        patience_counter = 0
    else:
        patience_counter += 1
    print(f"Patience Counter: {patience_counter}")
    if not improved and patience_counter >= patience:
        print("Early stopping triggered")
        break

    print("..... .... ... .. .. .")

print(f"Best epoch : {best_epoch}")
print("Training complete!")

In [None]:
%%time
# test the model: reload the best checkpoint and score the held-out Testing set
test_accuracy, test_labels_all, test_preds_all = test_model(
    model,
    test_loader_grayscale,
    device,
)
print(f"Test Accuracy: {test_accuracy:.4f}")

In [None]:
# confusion matrix for the test set (scikit-learn convention: rows = true, columns = predicted)
cm_display_labels = ['glioma', 'meningioma', 'pituitary', 'no tumor']
fig, ax = plt.subplots(figsize=(8, 8))
# Pin `labels` so the matrix is always len(classes) x len(classes). Without it, a class
# absent from both the labels and the predictions would shrink the matrix and
# mismatch `display_labels`.
test_conf_matrix = confusion_matrix(test_labels_all, test_preds_all,
                                    labels=list(range(len(classes))))
disp = ConfusionMatrixDisplay(confusion_matrix=test_conf_matrix,
                              display_labels=cm_display_labels)
disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation='horizontal')
ax.tick_params(axis='both', labelsize=16)
for row in disp.text_:
    for text in row:
        text.set_fontsize(20)
ax.grid(False)
ax.set_title("Confusion Matrix for the ViT-Base Model")
plt.show()

In [None]:
# evaluation metrics: per-class precision / recall / F1 on the test set
target_names = list(classes)
# Explicit `labels` keeps the report aligned with `target_names` even if a class
# never appears in the labels or predictions. Otherwise scikit-learn raises
# on the size mismatch.
report = classification_report(
    test_labels_all,
    test_preds_all,
    labels=list(range(len(target_names))),
    target_names=target_names,
)
print(report)

In [None]:
# Cohen's kappa: chance-corrected agreement between test predictions and true labels
kappa = CohenKappa(task='multiclass', num_classes=len(classes))
# Convert into new names rather than overwriting the list results. Earlier cells
# (confusion matrix, report) then still see the lists, and this cell can be re-run.
test_preds_tensor = torch.as_tensor(np.asarray(test_preds_all))
test_labels_tensor = torch.as_tensor(np.asarray(test_labels_all))
kappa_score = kappa(test_preds_tensor, test_labels_tensor).item()

# Kappa is a coefficient (1 = perfect agreement, 0 = chance level), not a percentage.
print(f"Cohen's Kappa: {kappa_score:.4f}")