In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory.
# Running this cell (click Run or press Shift+Enter) summarizes what is under the
# input directory: one line per folder that contains files, with its file count.
# Printing every individual path would flood the output for image datasets.

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    if filenames:
        print(f"{dirname}: {len(filenames)} files")

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install timm torch-geometric


In [None]:
import os

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
from torch_geometric.nn import GATConv
from torchvision import datasets, transforms
from tqdm import tqdm


In [None]:
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
has_gpu = torch.cuda.is_available()
device = torch.device("cuda") if has_gpu else torch.device("cpu")
print("Using device:", device)


In [None]:
# Root folder read by torchvision's ImageFolder (one sub-folder per class).
# Only this "Training" folder is used; it is split into train/val/test below.
DATASET_PATH = "/kaggle/input/brain-tumor-mri-dataset/Training"


In [None]:
IMG_SIZE = 224       # swin_tiny_patch4_window7_224 is built for 224x224 inputs
BATCH_SIZE = 16
EPOCHS = 30
LR = 1e-4
NUM_CLASSES = 4      # must equal the number of class folders under DATASET_PATH
SEED = 42

# Seed the RNGs behind weight init, dropout, DataLoader shuffling and the
# torchvision random augmentations so Restart & Run All reproduces the run.
# (torch.manual_seed seeds all devices; some GPU kernels may still be
# nondeterministic.)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
# Shared, stateless steps: fixed-size resize and normalization with the
# standard ImageNet channel mean/std.
_resize = transforms.Resize((IMG_SIZE, IMG_SIZE))
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training adds light augmentation (horizontal flip, +/-10 degree rotation).
train_transforms = transforms.Compose(
    [
        _resize,
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        _normalize,
    ]
)

# Evaluation is deterministic: resize, tensorize, normalize.
test_transforms = transforms.Compose([_resize, transforms.ToTensor(), _normalize])


In [None]:
# ImageFolder assigns one class per sub-folder of DATASET_PATH.
full_dataset = datasets.ImageFolder(DATASET_PATH, transform=train_transforms)

print("Classes:", full_dataset.classes)
print("Total images:", len(full_dataset))

# The model's classifier head is sized by NUM_CLASSES, so it must agree with
# the number of class folders actually found on disk.
assert len(full_dataset.classes) == NUM_CLASSES, (
    f"NUM_CLASSES={NUM_CLASSES} but found classes {full_dataset.classes}"
)


In [None]:
# 80 / 10 / 10 split; the test split absorbs the rounding remainder so the three
# sizes always add up to len(full_dataset).
train_size = int(0.8 * len(full_dataset))
val_size   = int(0.1 * len(full_dataset))
test_size  = len(full_dataset) - train_size - val_size

# A dedicated, seeded generator makes the split identical on every run, so the
# validation/test membership does not silently change between kernel restarts.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset, [train_size, val_size, test_size], generator=split_generator
)


In [None]:
# random_split returns Subset views that all share `full_dataset`. Assigning
# `val_dataset.dataset.transform` would therefore also replace the training
# subset's transform and silently disable augmentation. Instead, build a second
# ImageFolder over the same files with the evaluation transform and reuse the
# split indices.
eval_dataset = datasets.ImageFolder(DATASET_PATH, transform=test_transforms)
assert eval_dataset.samples == full_dataset.samples, (
    "file order differs; split indices cannot be reused"
)

val_dataset = Subset(eval_dataset, val_dataset.indices)
test_dataset = Subset(eval_dataset, test_dataset.indices)


In [None]:
# Only the training loader reshuffles every epoch; val/test keep a fixed order.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=reshuffle)
    for split, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)


In [None]:
class GATBlock(nn.Module):
    """Graph-attention layer applied over a fully connected graph of input rows.

    Each row of the input is a node. Every node is linked to every other node
    in both directions. Any self-loops come from ``GATConv``'s own defaults.
    Because ``concat=False``, the attention heads are averaged, so each output
    row has ``out_channels`` features regardless of ``heads``.

    Args:
        in_channels: feature width of each input row.
        out_channels: feature width of each output row.
        heads: number of attention heads (averaged, not concatenated).
    """

    def __init__(self, in_channels, out_channels, heads=4):
        super().__init__()
        self.gat = GATConv(in_channels, out_channels, heads=heads, concat=False)

    def forward(self, x):
        """Run GAT over all rows of ``x``.

        ``x`` is assumed to be ``(num_nodes, in_channels)``; TODO confirm for
        other callers.
        """
        n = x.size(0)
        # Pairs (i, j) with i < j in row-major order. These are the same pairs,
        # in the same order, as torch.combinations(arange(n), r=2).
        forward_edges = torch.triu_indices(n, n, offset=1, device=x.device)
        # Append the reversed pairs so attention flows both ways.
        edge_index = torch.cat((forward_edges, forward_edges.flip(0)), dim=1)
        return self.gat(x, edge_index)


In [None]:
class Swin_GRU_GAT(nn.Module):
    """Swin-T image encoder -> bidirectional GRU -> graph attention -> MLP head.

    All keyword arguments default to the original architecture, so
    ``Swin_GRU_GAT(num_classes)`` builds the same network as before.

    Args:
        num_classes: number of output logits.
        gru_hidden: hidden size per GRU direction. The GAT input width is
            derived from it as ``2 * gru_hidden``.
        gat_out: output width of the GAT layer and input width of the head.
        mlp_hidden: hidden width of the classifier MLP.
        dropout: dropout probability inside the classifier MLP.
        pretrained: whether timm loads pretrained Swin weights.
    """

    def __init__(
        self,
        num_classes,
        gru_hidden=256,
        gat_out=256,
        mlp_hidden=128,
        dropout=0.0005,
        pretrained=True,
    ):
        super().__init__()

        # num_classes=0 makes timm return pooled features instead of logits.
        self.swin = timm.create_model(
            "swin_tiny_patch4_window7_224",
            pretrained=pretrained,
            num_classes=0,
        )
        swin_dim = self.swin.num_features

        self.gru = nn.GRU(
            input_size=swin_dim,
            hidden_size=gru_hidden,
            batch_first=True,
            bidirectional=True,
        )

        # A bidirectional GRU concatenates both directions, so the GAT input
        # width is 2 * gru_hidden (512 with the defaults). Previously this was
        # hard-coded and would silently break if gru_hidden changed.
        self.gat = GATBlock(2 * gru_hidden, gat_out)

        self.classifier = nn.Sequential(
            nn.Linear(gat_out, mlp_hidden),
            nn.ReLU(),
            # NOTE(review): p=0.0005 is effectively no dropout; confirm this is
            # intended (rates like 0.5 are more common for a head like this).
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, num_classes),
        )

    def forward(self, x):
        """Map an image batch to class logits.

        The GAT graph links every row of the batch to every other row, so each
        image's logits depend on the other images it is batched with.
        """
        features = self.swin(x)                    # pooled features, one row per image
        seq, _ = self.gru(features.unsqueeze(1))   # treated as a length-1 sequence
        node_feats = seq.squeeze(1)                # (batch, 2 * gru_hidden)
        node_feats = self.gat(node_feats)
        return self.classifier(node_feats)


In [None]:
# Build the network and move it to the selected device.
model = Swin_GRU_GAT(num_classes=NUM_CLASSES)
model = model.to(device)

# AdamW over all parameters; cross-entropy over raw logits.
optimizer = optim.AdamW(params=model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()


In [None]:
def train_one_epoch(model, loader):
    """Run one optimisation pass over ``loader``.

    Uses the notebook-level ``optimizer``, ``criterion`` and ``device``.

    Args:
        model: network to train; switched to train mode here.
        loader: iterable of ``(images, labels)`` batches.

    Returns:
        ``(loss, acc)``: loss averaged per sample (so a short final batch is not
        over-weighted) and the fraction of correctly classified samples. Both
        are NaN if the loader yields no samples.
    """
    model.train()
    loss_sum, correct, total = 0.0, 0, 0

    for images, labels in tqdm(loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # criterion (CrossEntropyLoss, default reduction) yields the batch mean;
        # scale it back to a sum so the epoch average is per sample.
        loss_sum += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        return float("nan"), float("nan")
    return loss_sum / total, correct / total


In [None]:
def evaluate(model, loader):
    """Score ``model`` on ``loader`` without updating it.

    Uses the notebook-level ``criterion`` and ``device``. The model is put in
    eval mode and gradients are disabled.

    NOTE(review): with Swin_GRU_GAT, the GAT layer links all samples in a batch,
    so predictions here depend on the loader's batch size and ordering.

    Args:
        model: network to evaluate.
        loader: iterable of ``(images, labels)`` batches.

    Returns:
        ``(loss, acc)``, both averaged per sample; NaN if the loader is empty.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Undo the per-batch mean so short batches are weighted correctly.
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        return float("nan"), float("nan")
    return loss_sum / total, correct / total


In [None]:
# Keep a CPU copy of the weights with the best validation accuracy. The final
# test evaluation then uses the checkpoint selected on the validation split
# instead of whatever the last epoch produced.
best_val_acc = float("-inf")
best_epoch = None
best_state = None

for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader)
    val_loss, val_acc = evaluate(model, val_loader)

    print(f"\nEpoch [{epoch+1}/{EPOCHS}]")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.4f}")

    # A NaN val_acc (empty val split) never compares greater, so nothing is saved.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1
        best_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }

if best_state is not None:
    model.load_state_dict(best_state)
    print(f"\nRestored weights from epoch {best_epoch} (Val Acc: {best_val_acc:.4f})")


In [None]:
# Final held-out evaluation on the test split.
(test_loss, test_acc) = evaluate(model=model, loader=test_loader)
print("Test Accuracy:", test_acc)
