In [None]:
import torch
from torch import nn, optim
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import json
import os
import math
import matplotlib.pyplot as plt
import numpy as np
from torch.nn import functional as F
from datasets import load_dataset, DatasetDict, Dataset
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision.transforms import Lambda
import torch.multiprocessing as mp
from torch.optim.lr_scheduler import StepLR

In [None]:
class SwishActivation(nn.Module):
    """Swish / SiLU non-linearity, applied element-wise: ``f(x) = x * sigmoid(x)``."""

    def forward(self, input):
        gate = torch.sigmoid(input)
        return input * gate

class PatchEmbeddings(nn.Module):
    """Split an image into non-overlapping square patches and linearly embed each one.

    A single ``Conv2d`` whose kernel size equals its stride performs the per-patch
    linear projection. Reads ``image_size``, ``patch_size``, ``num_channels`` and
    ``hidden_size`` from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        self.image_size = config["image_size"]
        self.patch_size = config["patch_size"]
        self.num_channels = config["num_channels"]
        self.hidden_size = config["hidden_size"]
        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.projection = nn.Conv2d(
            self.num_channels,
            self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(self, x):
        # NOTE(review): assumes x is (batch, num_channels, H, W), the layout Conv2d expects.
        # (B, hidden, H/p, W/p) -> (B, hidden, num_patches) -> (B, num_patches, hidden)
        projected = self.projection(x)
        return projected.flatten(start_dim=2).transpose(1, 2)

class EmbeddingLayer(nn.Module):
    """Turn images into the token sequence fed to the transformer encoder.

    Patches are embedded by ``PatchEmbeddings``, a learned position embedding is
    added to each patch token, a learned [CLS] token is prepended, and dropout is
    applied to the resulting sequence.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.patch_embeddings = PatchEmbeddings(config)
        hidden = config["hidden_size"]
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden))
        # One position vector per patch only: the [CLS] token occupies a fixed slot
        # and is not given a position embedding.
        self.position_embeddings = nn.Parameter(
            torch.randn(1, self.patch_embeddings.num_patches, hidden)
        )
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x):
        patch_tokens = self.patch_embeddings(x)
        batch_size, seq_len, _ = patch_tokens.size()
        # Truncate the table when the input produced fewer patches than configured;
        # more patches than configured fails at the addition below.
        positions = self.position_embeddings[:, :seq_len, :]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        # The (1, seq_len, hidden) table broadcasts over the batch dimension.
        tokens = torch.cat([cls_tokens, patch_tokens + positions], dim=1)
        return self.dropout(tokens)

class AttentionHead(nn.Module):
    """A single scaled dot-product self-attention head.

    Args:
        hidden_size: feature size of the incoming tokens.
        attention_head_size: output size of this head's query/key/value projections.
        dropout: dropout probability applied to the attention weights.
        bias: whether the query/key/value projections have a bias term.
    """

    def __init__(self, hidden_size, attention_head_size, dropout, bias=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.attention_head_size = attention_head_size
        self.query = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.key = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.value = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``(context, weights)``: attended values and the post-dropout attention weights."""
        q, k, v = self.query(x), self.key(x), self.value(x)
        scale = math.sqrt(self.attention_head_size)
        # Dividing by sqrt(d_k) keeps the logits' scale independent of the head size.
        scores = (q @ k.transpose(-2, -1)) / scale
        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ v, weights

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built from independent ``AttentionHead`` modules.

    Each head projects to ``hidden_size // num_attention_heads`` features. The head
    outputs are concatenated, projected back to ``hidden_size`` and passed through
    dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_attention_heads = config["num_attention_heads"]
        self.attention_head_size = self.hidden_size // self.num_attention_heads
        # Equals hidden_size only when hidden_size is divisible by num_attention_heads.
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.qkv_bias = config["qkv_bias"]
        heads = []
        for _ in range(self.num_attention_heads):
            heads.append(
                AttentionHead(
                    self.hidden_size,
                    self.attention_head_size,
                    config["attention_probs_dropout_prob"],
                    self.qkv_bias,
                )
            )
        self.heads = nn.ModuleList(heads)
        self.output_projection = nn.Linear(self.all_head_size, self.hidden_size)
        self.output_dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x, output_attentions=False):
        """Return ``(output, probs)``; ``probs`` is None unless ``output_attentions`` is set.

        When requested, the per-head attention weights are stacked along dim 1.
        """
        contexts, head_probs = [], []
        for head in self.heads:
            context, probs = head(x)
            contexts.append(context)
            head_probs.append(probs)
        merged = self.output_dropout(self.output_projection(torch.cat(contexts, dim=-1)))
        if output_attentions:
            return merged, torch.stack(head_probs, dim=1)
        return merged, None

class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> Swish -> Linear -> Dropout."""

    def __init__(self, config):
        super().__init__()
        hidden, intermediate = config["hidden_size"], config["intermediate_size"]
        self.dense_1 = nn.Linear(hidden, intermediate)
        self.activation = SwishActivation()
        self.dense_2 = nn.Linear(intermediate, hidden)
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x):
        expanded = self.activation(self.dense_1(x))
        return self.dropout(self.dense_2(expanded))

class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``x + Attention(LayerNorm(x))`` followed by ``x + MLP(LayerNorm(x))``.
    """

    def __init__(self, config):
        super().__init__()
        self.use_faster_attention = config.get("use_faster_attention", False)
        # NOTE(review): FasterMultiHeadAttention is not defined anywhere in the visible
        # notebook, so setting use_faster_attention=True raises NameError here.
        attention_cls = FasterMultiHeadAttention if self.use_faster_attention else MultiHeadAttention
        self.attention = attention_cls(config)
        self.layer_norm_1 = nn.LayerNorm(config["hidden_size"])
        self.mlp = MLP(config)
        self.layer_norm_2 = nn.LayerNorm(config["hidden_size"])

    def forward(self, x, output_attentions=False):
        """Return ``(x, attention_probs)``; the probs are None unless ``output_attentions``."""
        normed = self.layer_norm_1(x)
        attention_output, attention_probs = self.attention(normed, output_attentions=output_attentions)
        x = x + attention_output
        x = x + self.mlp(self.layer_norm_2(x))
        return x, (attention_probs if output_attentions else None)

class TransformerEncoder(nn.Module):
    """A stack of ``num_hidden_layers`` ``TransformerBlock`` modules applied in sequence."""

    def __init__(self, config):
        super().__init__()
        self.blocks = nn.ModuleList(
            TransformerBlock(config) for _ in range(config["num_hidden_layers"])
        )

    def forward(self, x, output_attentions=False):
        """Return ``(hidden_states, attentions)``; ``attentions`` is a per-layer list or None."""
        collected = [] if output_attentions else None
        for block in self.blocks:
            x, attention_probs = block(x, output_attentions=output_attentions)
            if collected is not None:
                collected.append(attention_probs)
        return x, collected

class VisionTransformer(nn.Module):
    """Vision Transformer image classifier.

    Pipeline: ``EmbeddingLayer`` (patch tokens plus a prepended [CLS] token) ->
    ``TransformerEncoder`` -> a linear head applied to the final [CLS] token.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_size = config["image_size"]
        self.hidden_size = config["hidden_size"]
        self.num_classes = config["num_classes"]
        self.embedding = EmbeddingLayer(config)
        self.encoder = TransformerEncoder(config)
        self.classifier = nn.Linear(self.hidden_size, self.num_classes)
        self.apply(self._init_weights)

    def forward(self, x, output_attentions=False):
        """Classify a batch of images.

        Returns:
            ``(logits, attentions)``: ``logits`` has one column per class, and
            ``attentions`` is the per-layer list of attention maps when
            ``output_attentions`` is True, else None.
        """
        embedding_output = self.embedding(x)
        encoder_output, all_attentions = self.encoder(embedding_output, output_attentions=output_attentions)
        # Classify from the [CLS] token, which EmbeddingLayer prepends at index 0.
        logits = self.classifier(encoder_output[:, 0, :])
        return logits, (all_attentions if output_attentions else None)

    def _init_weights(self, module):
        """Initialise one submodule; applied recursively via ``self.apply``.

        Linear/Conv2d weights get Xavier-uniform init and zero bias. The learned
        [CLS] token and position table of ``EmbeddingLayer`` are redrawn from a
        truncated normal with std ``config["initializer_range"]`` (default 0.02).
        Otherwise they would keep their unit-variance ``torch.randn`` start, which
        is far larger in scale than the Xavier-initialised patch projection.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(module.weight.data)
            if module.bias is not None:
                nn.init.zeros_(module.bias.data)
        elif isinstance(module, EmbeddingLayer):
            std = self.config.get("initializer_range", 0.02)
            nn.init.trunc_normal_(module.position_embeddings.data, mean=0.0, std=std)
            nn.init.trunc_normal_(module.cls_token.data, mean=0.0, std=std)

In [None]:
def load_data(batch_size=64, dataset_path="/kaggle/input/brain-tumor/brain",
              image_size=64, num_workers=4, normalize=False):
    """Build train/test DataLoaders from an ImageFolder-style directory tree.

    Expects ``<dataset_path>/Training`` and ``<dataset_path>/Testing``, each
    containing one sub-directory per class.

    Args:
        batch_size: samples per batch for both loaders.
        dataset_path: root directory that holds ``Training`` and ``Testing``.
        image_size: images are resized to ``(image_size, image_size)``; this should
            match the model's ``config["image_size"]``.
        num_workers: worker processes per DataLoader.
        normalize: if True, apply per-channel ``Normalize(mean=0.5, std=0.5)``
            after ``ToTensor`` (assumes 3-channel images).

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    steps = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ]
    if normalize:
        # (x - 0.5) / 0.5 maps the [0, 1] range to [-1, 1].
        steps.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    transform = transforms.Compose(steps)

    train_dataset = ImageFolder(root=os.path.join(dataset_path, 'Training'), transform=transform)
    test_dataset = ImageFolder(root=os.path.join(dataset_path, 'Testing'), transform=transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return train_loader, test_loader

In [None]:
def train(model, train_loader, optimizer, criterion, device, max_grad_norm=1.0):
    """Run one training epoch over ``train_loader``.

    Args:
        model: module whose forward returns ``(logits, extra)``; only logits are used.
        train_loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss that averages over the batch (e.g. ``nn.CrossEntropyLoss()``
            with the default ``reduction="mean"``).
        device: device the batches are moved to.
        max_grad_norm: global gradient-norm clipping threshold; ``None`` disables clipping.

    Returns:
        ``(mean_loss, accuracy)`` over every sample seen this epoch. The loss is
        weighted by batch size, so a short final batch does not skew it.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs, _ = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        # Gradient clipping keeps single large updates from destabilising training.
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

        optimizer.step()

        batch_size = targets.size(0)
        # criterion returns a per-batch mean; re-weight so the epoch mean is per sample.
        total_loss += loss.item() * batch_size
        _, predicted = outputs.max(1)
        correct += predicted.eq(targets).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_loader produced no samples")
    return total_loss / total, correct / total

def evaluate(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` without updating it.

    Switches the model to eval mode (disabling dropout) and runs under
    ``torch.no_grad()``.

    Args:
        model: module whose forward returns ``(logits, extra)``; only logits are used.
        test_loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss that averages over the batch (e.g. ``nn.CrossEntropyLoss()``
            with the default ``reduction="mean"``).
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over every sample, with the loss weighted by
        batch size so a short final batch does not skew it.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs, _ = model(inputs)
            loss = criterion(outputs, targets)

            batch_size = targets.size(0)
            # criterion returns a per-batch mean; re-weight so the result is per sample.
            total_loss += loss.item() * batch_size
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("test_loader produced no samples")
    return total_loss / total, correct / total

def plot_training_progress(train_losses, train_accuracies, test_losses, test_accuracies):
    """Plot per-epoch loss (left) and accuracy (right) for the train and test splits.

    Each argument is a sequence with one value per epoch. The x axis is 1-based so
    it lines up with the "Epoch N/M" lines printed during training.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 4))
    panels = (
        (loss_ax, "Loss", train_losses, test_losses),
        (acc_ax, "Accuracy", train_accuracies, test_accuracies),
    )
    for ax, title, train_values, test_values in panels:
        ax.plot(range(1, len(train_values) + 1), train_values, label="Train")
        ax.plot(range(1, len(test_values) + 1), test_values, label="Test")
        ax.set(title=title, xlabel="Epoch", ylabel=title)
        ax.legend()

    fig.tight_layout()
    plt.show()

In [None]:
# Visualize some sample images from the dataset
def visualize_samples(dataset_loader, num_samples=5):
    """Show the first few images of one batch from ``dataset_loader`` with their class names.

    Args:
        dataset_loader: DataLoader whose ``dataset`` exposes ``classes`` (e.g. an ImageFolder).
        num_samples: number of images to show; capped at the size of the first batch.
    """
    images, labels = next(iter(dataset_loader))
    class_names = dataset_loader.dataset.classes

    print(f"Class Names: {class_names}")
    print(f"Labels: {labels}")

    # The first batch can be smaller than requested (e.g. a small batch_size).
    num_samples = min(num_samples, len(images))
    if num_samples <= 0:
        return

    # squeeze=False keeps `axes` 2-D even when only one image is shown.
    fig, axes = plt.subplots(1, num_samples, figsize=(15, 8), squeeze=False)
    for ax, image, label in zip(axes[0], images, labels):
        # CHW tensor -> HWC array for imshow.
        # NOTE(review): normalised ([-1, 1]) batches will be clipped by imshow.
        ax.imshow(image.numpy().transpose((1, 2, 0)))
        ax.set_title(f"Class: {class_names[int(label)]}")
        ax.axis('off')
    plt.show()

In [None]:
# Configuration for the Vision Transformer
config = dict(
    # 64x64 inputs cut into 4x4 patches -> (64 // 4) ** 2 = 256 patch tokens.
    # image_size must agree with the Resize applied when loading the data.
    image_size=64,
    patch_size=4,
    num_channels=3,
    # Encoder width/depth; each head gets hidden_size // num_attention_heads = 4 features.
    hidden_size=128,
    num_hidden_layers=16,
    num_attention_heads=32,
    # NOTE(review): the MLP narrows to 64 features (smaller than hidden_size); confirm intended.
    intermediate_size=64,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    qkv_bias=True,
    # NOTE(review): must be >= the number of class folders ImageFolder finds; with
    # fewer dataset classes the extra logits are simply never targeted.
    num_classes=10,
    initializer_range=0.02,
    use_faster_attention=False,
)

In [None]:
# Training hyperparameters
# Fixed seed for torch's default RNG so weight initialisation (done in a later
# cell) and other torch random draws repeat across Restart & Run All.
seed = 42
torch.manual_seed(seed)
learning_rate = 0.001
num_epochs = 20
batch_size = 32

In [None]:
# Run on the CUDA device when one is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the Vision Transformer from `config`, then move its parameters to `device`
model = VisionTransformer(config)
model = model.to(device)

In [None]:
# Show the model's module hierarchy
print(repr(model))

In [None]:
# Per-epoch metric histories, filled by the training loop and plotted afterwards
train_losses, train_accuracies, test_losses, test_accuracies = [], [], [], []

# AdamW (decoupled weight decay) and cross-entropy over the class logits
optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# Multiply the learning rate by 0.1 every 3 scheduler steps (one step per epoch).
# NOTE(review): over 20 epochs this reaches learning_rate * 1e-6 by epoch 19, so
# later epochs barely move the weights; consider a larger step_size or a cosine schedule.
scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

In [None]:
# Build the train/test DataLoaders with the batch size chosen above
train_loader, test_loader = load_data(batch_size)

# Peek at one training batch to sanity-check images and labels
visualize_samples(dataset_loader=train_loader)

In [None]:
for epoch in range(num_epochs):
    # One optimisation pass over the training set, then a gradient-free pass over the test set
    train_loss, train_accuracy = train(model, train_loader, optimizer, criterion, device)
    test_loss, test_accuracy = evaluate(model, test_loader, criterion, device)

    # Learning rate used during this epoch (the scheduler is stepped only at the end)
    current_lr = scheduler.get_last_lr()[0]
    print(
        f"Epoch {epoch + 1}/{num_epochs}, Learning Rate: {current_lr:.8f}, "
        f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, "
        f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}"
    )

    # Record this epoch's metrics for the plots
    for history, value in (
        (train_losses, train_loss),
        (train_accuracies, train_accuracy),
        (test_losses, test_loss),
        (test_accuracies, test_accuracy),
    ):
        history.append(value)

    # Decay the learning rate once per epoch
    scheduler.step()

In [None]:
# Draw loss and accuracy curves for both splits across all epochs
plot_training_progress(
    train_losses=train_losses,
    train_accuracies=train_accuracies,
    test_losses=test_losses,
    test_accuracies=test_accuracies,
)