In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        (os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [3]:
# Install PyTorch Geometric dependencies
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.1.0+cu118.html

# Install main torch_geometric package
!pip install torch-geometric


Looking in links: https://data.pyg.org/whl/torch-2.1.0+cu118.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu118/torch_scatter-2.1.2%2Bpt21cu118-cp311-cp311-linux_x86_64.whl (10.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu118/torch_sparse-0.6.18%2Bpt21cu118-cp311-cp311-linux_x86_64.whl (4.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.9/4.9 MB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m0m
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-2.1.0%2Bcu118/torch_cluster-1.6.3%2Bpt21cu118-cp311-cp311-linux_x86_64.whl (3.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m0m
[?25hCollecting torch-spline-conv
  Downloadi

In [5]:
# ===============================
# Hybrid Model: Attention + ViT + MCNN
# 4-Class Dataset Pipeline
# ===============================

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch.optim as optim
import numpy as np

# ------------------------------
# SETTINGS
# ------------------------------
IMG_SIZE = 128
PATCH_SIZE = 16
BATCH_SIZE = 8
EPOCHS = 10
NUM_CLASSES = 4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ------------------------------
# DATASET
# ------------------------------
class CustomMRI(Dataset):
    def __init__(self, data_dir, transform=None):
        self.transform = transform
        self.image_paths = []
        self.labels = []
        class_map = {cls_name: idx for idx, cls_name in enumerate(os.listdir(data_dir))}
        for cls_name, idx in class_map.items():
            cls_dir = os.path.join(data_dir, cls_name)
            for fname in os.listdir(cls_dir):
                if fname.lower().endswith(('png', 'jpg', 'jpeg')):
                    self.image_paths.append(os.path.join(cls_dir, fname))
                    self.labels.append(idx)
                
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        label = self.labels[idx]
        return img, label

transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

dataset_dir = "/kaggle/input/brain-tumor-mri-dataset/Training"
dataset = CustomMRI(dataset_dir, transform=transform)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# ------------------------------
# MULTI-SCALE CNN
# ------------------------------
class MultiScaleCNN(nn.Module):
    def __init__(self, out_features=128):
        super().__init__()
        self.conv3 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv5 = nn.Conv2d(3, 32, 5, padding=2)
        self.conv7 = nn.Conv2d(3, 32, 7, padding=3)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(32*3*(IMG_SIZE//2)*(IMG_SIZE//2), out_features)
    
    def forward(self, x):
        x3 = F.relu(self.pool(self.conv3(x)))
        x5 = F.relu(self.pool(self.conv5(x)))
        x7 = F.relu(self.pool(self.conv7(x)))
        x_cat = torch.cat([x3, x5, x7], dim=1)
        x_flat = x_cat.view(x_cat.size(0), -1)
        return self.fc(x_flat)

# ------------------------------
# SIMPLE ViT
# ------------------------------
class SimpleViT(nn.Module):
    def __init__(self, img_size=IMG_SIZE, patch_size=PATCH_SIZE, in_ch=3, emb_dim=128, num_heads=4, depth=4):
        super().__init__()
        self.patch_embed = nn.Conv2d(in_ch, emb_dim, kernel_size=patch_size, stride=patch_size)
        num_patches = (img_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, emb_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
    
    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x).flatten(2).transpose(1,2)
        x = x + self.pos_embed
        x = self.transformer(x)
        x = x.mean(dim=1)
        return x

# ------------------------------
# ATTENTION BRANCH (replaces GAT)
# ------------------------------
class AttentionBranch(nn.Module):
    def __init__(self, in_dim=128, out_dim=128):
        super().__init__()
        self.query = nn.Linear(in_dim, out_dim)
        self.key = nn.Linear(in_dim, out_dim)
        self.value = nn.Linear(in_dim, out_dim)
    
    def forward(self, x):
        # x: [batch_size, features]
        Q = self.query(x).unsqueeze(1)
        K = self.key(x).unsqueeze(1)
        V = self.value(x).unsqueeze(1)
        attn_scores = torch.bmm(Q, K.transpose(1,2)) / np.sqrt(Q.size(-1))
        attn_weights = F.softmax(attn_scores, dim=-1)
        out = torch.bmm(attn_weights, V).squeeze(1)
        return out

# ------------------------------
# HYBRID MODEL
# ------------------------------
class HybridModel(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        self.mcnn = MultiScaleCNN()
        self.vit = SimpleViT()
        self.attn = AttentionBranch()
        self.fc = nn.Linear(128+128+128, num_classes)
    
    def forward(self, x):
        x1 = self.mcnn(x)
        x2 = self.vit(x)
        x3 = self.attn(x1)
        x_cat = torch.cat([x1, x2, x3], dim=1)
        return self.fc(x_cat)

# ------------------------------
# TRAINING
# ------------------------------
model = HybridModel().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(EPOCHS):
    model.train()
    total_loss, correct = 0, 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()
    print(f"Epoch {epoch+1}, Loss={total_loss:.4f}, Train Acc={correct/len(train_loader.dataset):.4f}")

# ------------------------------
# VALIDATION
# ------------------------------
model.eval()
correct = 0
with torch.no_grad():
    for imgs, labels in val_loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        outputs = model(imgs)
        correct += (outputs.argmax(1) == labels).sum().item()
val_acc = correct / len(val_loader.dataset)
print(f"Validation Accuracy: {val_acc:.4f}")


Epoch 1, Loss=444.4494, Train Acc=0.7597
Epoch 2, Loss=178.0936, Train Acc=0.8833
Epoch 3, Loss=108.3268, Train Acc=0.9354
Epoch 4, Loss=76.4142, Train Acc=0.9545
Epoch 5, Loss=55.6940, Train Acc=0.9659
Epoch 6, Loss=44.0605, Train Acc=0.9753
Epoch 7, Loss=42.4205, Train Acc=0.9801
Epoch 8, Loss=20.9212, Train Acc=0.9902
Epoch 9, Loss=39.9999, Train Acc=0.9821
Epoch 10, Loss=21.1738, Train Acc=0.9873
Validation Accuracy: 0.9073


In [4]:
# ===============================
# Hybrid Model: GAT + ViT + MCNN
# 4-Class Dataset A-to-Z Pipeline
# ===============================

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch.optim as optim
from torch_geometric.nn import GATConv
from torch_geometric.data import Data, Batch

import numpy as np

# ------------------------------
# SETTINGS
# ------------------------------
IMG_SIZE = 128
PATCH_SIZE = 16
BATCH_SIZE = 8
EPOCHS = 10
NUM_CLASSES = 4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ------------------------------
# DATASET
# ------------------------------
class CustomMRI(Dataset):
    def __init__(self, data_dir, transform=None):
        self.transform = transform
        self.image_paths = []
        self.labels = []
        class_map = {cls_name: idx for idx, cls_name in enumerate(os.listdir(data_dir))}
        for cls_name, idx in class_map.items():
            cls_dir = os.path.join(data_dir, cls_name)
            for fname in os.listdir(cls_dir):
                if fname.lower().endswith(('png', 'jpg', 'jpeg')):
                    self.image_paths.append(os.path.join(cls_dir, fname))
                    self.labels.append(idx)
                
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        label = self.labels[idx]
        return img, label

# Transforms
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# Dataset path
dataset_dir = "/kaggle/input/brain-tumor-mri-dataset/Training"  # change to your dataset
dataset = CustomMRI(dataset_dir, transform=transform)

# Split
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# ------------------------------
# MULTI-SCALE CNN BRANCH
# ------------------------------
class MultiScaleCNN(nn.Module):
    def __init__(self, out_features=128):
        super().__init__()
        self.conv3 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(3, 32, kernel_size=5, padding=2)
        self.conv7 = nn.Conv2d(3, 32, kernel_size=7, padding=3)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(32*3*(IMG_SIZE//2)*(IMG_SIZE//2), out_features)
    
    def forward(self, x):
        x3 = F.relu(self.pool(self.conv3(x)))
        x5 = F.relu(self.pool(self.conv5(x)))
        x7 = F.relu(self.pool(self.conv7(x)))
        x_cat = torch.cat([x3, x5, x7], dim=1)
        x_flat = x_cat.view(x_cat.size(0), -1)
        return self.fc(x_flat)

# ------------------------------
# SIMPLE ViT BRANCH
# ------------------------------
class SimpleViT(nn.Module):
    def __init__(self, img_size=IMG_SIZE, patch_size=PATCH_SIZE, in_ch=3, emb_dim=128, num_heads=4, depth=4):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.emb_dim = emb_dim
        
        self.patch_embed = nn.Conv2d(in_ch, emb_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, emb_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
    
    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)           # B x emb_dim x H' x W'
        x = x.flatten(2).transpose(1,2)  # B x num_patches x emb_dim
        x = x + self.pos_embed
        x = self.transformer(x)           # B x num_patches x emb_dim
        x = x.mean(dim=1)                 # global pooling
        return x

# ------------------------------
# GAT BRANCH
# ------------------------------
class SimpleGAT(nn.Module):
    def __init__(self, in_feats=128, hidden=64, out_feats=128):
        super().__init__()
        self.gat1 = GATConv(in_feats, hidden, heads=4, concat=True)
        self.gat2 = GATConv(hidden*4, out_feats, heads=1, concat=True)
    
    def forward(self, x, edge_index):
        x = F.elu(self.gat1(x, edge_index))
        x = self.gat2(x, edge_index)
        return x.mean(dim=0)

# ------------------------------
# HYBRID MODEL
# ------------------------------
class HybridModel(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        self.mcnn = MultiScaleCNN()
        self.vit = SimpleViT()
        self.gat = SimpleGAT()
        self.fc = nn.Linear(128+128+128, num_classes)
    
    def forward(self, img, node_features, edge_index):
        x1 = self.mcnn(img)
        x2 = self.vit(img)
        x3 = self.gat(node_features, edge_index)
        x = torch.cat([x1, x2, x3], dim=1)
        return self.fc(x)

# ------------------------------
# INITIALIZE
# ------------------------------
model = HybridModel().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# ------------------------------
# TRAINING LOOP
# ------------------------------
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    correct = 0
    for imgs, labels in train_loader:
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)
        
        # Generate graph from image patches
        B = imgs.size(0)
        node_features = imgs.unfold(2, PATCH_SIZE, PATCH_SIZE).unfold(3, PATCH_SIZE, PATCH_SIZE)
        node_features = node_features.contiguous().view(B, -1, 3*PATCH_SIZE*PATCH_SIZE)
        node_features = nn.Linear(3*PATCH_SIZE*PATCH_SIZE, 128).to(DEVICE)(node_features)
        
        # Fully connect graph for simplicity
        num_nodes = node_features.size(1)
        row = []
        col = []
        for i in range(num_nodes):
            for j in range(num_nodes):
                row.append(i)
                col.append(j)
        edge_index = torch.tensor([row, col], dtype=torch.long).to(DEVICE)
        
        optimizer.zero_grad()
        outputs = model(imgs, node_features[0], edge_index)  # node_features[0] for batch simplification
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()
    
    acc = correct / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss:.4f}, Train Acc: {acc:.4f}")

# ------------------------------
# VALIDATION
# ------------------------------
model.eval()
correct = 0
with torch.no_grad():
    for imgs, labels in val_loader:
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)
        B = imgs.size(0)
        node_features = imgs.unfold(2, PATCH_SIZE, PATCH_SIZE).unfold(3, PATCH_SIZE, PATCH_SIZE)
        node_features = node_features.contiguous().view(B, -1, 3*PATCH_SIZE*PATCH_SIZE)
        node_features = nn.Linear(3*PATCH_SIZE*PATCH_SIZE, 128).to(DEVICE)(node_features)
        
        num_nodes = node_features.size(1)
        row = []
        col = []
        for i in range(num_nodes):
            for j in range(num_nodes):
                row.append(i)
                col.append(j)
        edge_index = torch.tensor([row, col], dtype=torch.long).to(DEVICE)
        
        outputs = model(imgs, node_features[0], edge_index)
        correct += (outputs.argmax(1) == labels).sum().item()
val_acc = correct / len(val_loader.dataset)
print(f"Validation Accuracy: {val_acc:.4f}")


  import torch_geometric.typing
  import torch_geometric.typing
  import torch_geometric.typing
  import torch_geometric.typing


RuntimeError: Tensors must have same number of dimensions: got 2 and 1

In [None]:
# ===========================
# A-to-Z Hybrid: GAT + ViT + MCNN
# ===========================

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
import torch.optim as optim
from torch_geometric.nn import GATConv
from torch_geometric.data import Data, Batch

import numpy as np
from PIL import Image
import os

# --------------------------
# SETTINGS
# --------------------------
IMG_SIZE = 128
BATCH_SIZE = 8
EPOCHS = 30
NUM_CLASSES = 4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --------------------------
# DATASET
# --------------------------
class CustomMRI(Dataset):
    def __init__(self, data_dir, transform=None):
        self.transform = transform
        self.image_paths = []
        self.labels = []
        class_map = {cls_name: idx for idx, cls_name in enumerate(os.listdir(data_dir))}
        for cls_name, idx in class_map.items():
            cls_dir = os.path.join(data_dir, cls_name)
            for fname in os.listdir(cls_dir):
                self.image_paths.append(os.path.join(cls_dir, fname))
                self.labels.append(idx)
                
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        label = self.labels[idx]
        return img, label

transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

dataset = CustomMRI("/kaggle/input/brain-tumor-mri-dataset/Training", transform=transform)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# --------------------------
# MULTI-SCALE CNN BRANCH
# --------------------------
class MultiScaleCNN(nn.Module):
    def __init__(self, out_features=128):
        super().__init__()
        self.conv3 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(3, 32, kernel_size=5, padding=2)
        self.conv7 = nn.Conv2d(3, 32, kernel_size=7, padding=3)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(32*3*(IMG_SIZE//2)*(IMG_SIZE//2), out_features)
    
    def forward(self, x):
        x3 = F.relu(self.pool(self.conv3(x)))
        x5 = F.relu(self.pool(self.conv5(x)))
        x7 = F.relu(self.pool(self.conv7(x)))
        x_cat = torch.cat([x3, x5, x7], dim=1)
        x_flat = x_cat.view(x_cat.size(0), -1)
        return self.fc(x_flat)

# --------------------------
# VIT BRANCH
# --------------------------
class SimpleViT(nn.Module):
    def __init__(self, img_size=IMG_SIZE, patch_size=16, in_ch=3, emb_dim=128, num_heads=4, depth=4):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.emb_dim = emb_dim
        
        self.patch_embed = nn.Conv2d(in_ch, emb_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, emb_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
    
    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x) # B x emb_dim x H' x W'
        x = x.flatten(2).transpose(1,2) # B x num_patches x emb_dim
        x = x + self.pos_embed
        x = self.transformer(x) # B x num_patches x emb_dim
        x = x.mean(dim=1) # global average pooling
        return x

# --------------------------
# GAT BRANCH
# --------------------------
class SimpleGAT(nn.Module):
    def __init__(self, in_feats=128, hidden=64, out_feats=128):
        super().__init__()
        self.gat1 = GATConv(in_feats, hidden, heads=4, concat=True)
        self.gat2 = GATConv(hidden*4, out_feats, heads=1, concat=True)
    
    def forward(self, x, edge_index):
        x = F.elu(self.gat1(x, edge_index))
        x = self.gat2(x, edge_index)
        return x.mean(dim=0) # global node pooling

# --------------------------
# HYBRID MODEL
# --------------------------
class HybridModel(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        self.mcnn = MultiScaleCNN()
        self.vit = SimpleViT()
        self.gat = SimpleGAT()
        self.fc = nn.Linear(128+128+128, num_classes)
    
    def forward(self, img, node_features, edge_index):
        x1 = self.mcnn(img)
        x2 = self.vit(img)
        x3 = self.gat(node_features, edge_index)
        x = torch.cat([x1, x2, x3], dim=1)
        return self.fc(x)

# --------------------------
# INITIALIZE MODEL
# --------------------------
model = HybridModel().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# --------------------------
# TRAINING LOOP
# --------------------------
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    correct = 0
    for imgs, labels in train_loader:
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)
        
        # For demonstration: generate dummy node_features & edge_index
        # In practice, construct graph from ROI or patches
        B = imgs.size(0)
        node_features = torch.randn(B, 10, 128).to(DEVICE)  # 10 nodes per image
        edge_index = torch.tensor([[i,j] for i in range(10) for j in range(10)], dtype=torch.long).t().contiguous().to(DEVICE)
        
        optimizer.zero_grad()
        outputs = model(imgs, node_features, edge_index)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()
    
    acc = correct / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss:.4f}, Train Acc: {acc:.4f}")

# --------------------------
# VALIDATION
# --------------------------
model.eval()
correct = 0
with torch.no_grad():
    for imgs, labels in val_loader:
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)
        B = imgs.size(0)
        node_features = torch.randn(B, 10, 128).to(DEVICE)
        edge_index = torch.tensor([[i,j] for i in range(10) for j in range(10)], dtype=torch.long).t().contiguous().to(DEVICE)
        outputs = model(imgs, node_features, edge_index)
        correct += (outputs.argmax(1) == labels).sum().item()
val_acc = correct / len(val_loader.dataset)
print(f"Validation Accuracy: {val_acc:.4f}")
