In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import kneighbors_graph
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, BatchNorm
import torch.nn.functional as F
import os

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Auto-download Caltech101 dataset
data_dir = './caltech101'
# Training-time pipeline: fixed resize, random flip/rotation augmentation,
# then ImageNet mean/std normalization to match the pretrained backbone.
# NOTE(review): the Lambda below is not picklable, so DataLoader workers only
# work with the "fork" start method (Linux default); on spawn-based platforms
# (Windows, macOS) use num_workers=0.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    # Convert grayscale images to 3 channels before normalization
    transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load dataset
dataset = datasets.Caltech101(root=data_dir, download=True, transform=transform)
# Pinned host memory only speeds up host->GPU copies; skip it on CPU-only runs.
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=device.type == "cuda",
)

print(f"Dataset loaded: {len(dataset)} images.")

# Fine-tune ResNet101
# `pretrained=True` is deprecated in torchvision; IMAGENET1K_V1 is exactly the
# checkpoint it loaded (IMAGENET1K_V2 / DEFAULT is a stronger alternative).
resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
# Replace the ImageNet head with one output per Caltech101 category.
resnet.fc = nn.Linear(resnet.fc.in_features, len(dataset.categories))
resnet.to(device)
optimizer = optim.AdamW(resnet.parameters(), lr=0.0003, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# Train ResNet101
def train_resnet(model, dataloader, optimizer, criterion, epochs=15):
    model.train()
    scaler = torch.cuda.amp.GradScaler()
    for epoch in range(epochs):
        total_loss, correct = 0, 0
        for img, label in dataloader:
            img, label = img.to(device), label.to(device)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                output = model(img)
                loss = criterion(output, label)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
            correct += (output.argmax(dim=1) == label).sum().item()
        accuracy = correct / len(dataset)
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}, Accuracy: {accuracy:.4f}")
    print(f"Final ResNet Accuracy: {accuracy:.4f}")

# Stage 1: fine-tune the backbone on the Caltech101 labels.
train_resnet(
    resnet,
    dataloader,
    optimizer,
    criterion,
    epochs=15,
)

# Feature Extraction: swap the classification head for a pass-through so that
# resnet(img) returns the activations that previously fed the linear head.
resnet.fc = nn.Identity()
def extract_features(dataset, model, batch_size=64, num_workers=0):
    """Run ``model`` over every sample of ``dataset`` and collect its outputs.

    Samples are processed in order (no shuffling) in mini-batches, with the
    model in eval mode and gradients disabled. The model runs on whatever
    device its parameters live on.

    Args:
        dataset: indexable dataset yielding ``(input_tensor, label)`` pairs.
        model: module applied to a batch of inputs.
        batch_size: samples per forward pass.
        num_workers: DataLoader worker processes for loading samples.

    Returns:
        ``(features, labels)``: a 2-D array with one flattened model output
        per sample, and a 1-D array of the matching labels, both in dataset
        order. For an empty dataset both are empty 1-D arrays.
    """
    param = next(model.parameters(), None)
    dev = param.device if param is not None else torch.device("cpu")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    model.eval()
    feature_chunks, label_chunks = [], []
    with torch.no_grad():
        for img, label in loader:
            out = model(img.to(dev))
            # Flatten everything except the batch dimension: one row per sample.
            feature_chunks.append(out.flatten(start_dim=1).cpu())
            label_chunks.append(torch.as_tensor(label))

    if not feature_chunks:
        return np.array([]), np.array([])
    return torch.cat(feature_chunks).numpy(), torch.cat(label_chunks).numpy()

# Extract embeddings through a deterministic pipeline. The training transform
# applies RandomHorizontalFlip/RandomRotation, which would make each node
# feature a random augmented view of its image instead of the image itself.
# (This mutates the shared dataset, so `dataloader` also stops augmenting.)
dataset.transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Convert grayscale images to 3 channels before normalization
    transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
features, labels = extract_features(dataset, resnet)
# Re-map category ids to a dense 0..C-1 range (presumably a no-op when every
# category appears in the dataset).
labels = LabelEncoder().fit_transform(labels)

# Build Optimized KNN Graph
def build_knn_graph(features, k=10):
    """Build a k-nearest-neighbour edge list over the rows of ``features``.

    Each sample is linked to its ``k`` nearest neighbours; with
    ``include_self=True`` the sample itself counts as one of them. The kNN
    relation is not symmetric and no symmetrization is applied, so the graph
    is directed.

    Args:
        features: 2-D array, one sample per row.
        k: neighbours per sample.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_edges)``: row 0 holds the
        query sample index, row 1 the neighbour index, in row-major order.
    """
    # Stay sparse: densifying with .toarray() would allocate an n x n matrix
    # (~600 MB of float64 for ~8.7k samples) just to read back n*k entries.
    adj = kneighbors_graph(features, k, mode='connectivity', include_self=True).tocsr()
    # Neighbours are stored per row in distance order; sorting the column
    # indices yields the same row-major edge order np.nonzero gave on the
    # dense matrix.
    adj.sort_indices()
    row, col = adj.nonzero()
    return torch.tensor(np.vstack([row, col]), dtype=torch.long)

edge_index = build_knn_graph(features)

# Hold out 20% of the nodes (fixed seed, reproducible) so the GCN can be scored
# on labels it never trained on. Code that ignores these masks still sees the
# full graph and behaves exactly as before.
# NOTE(review): the ResNet backbone was fine-tuned on every image, so held-out
# nodes' features have still "seen" their labels; split before fine-tuning for
# a fully clean estimate.
num_nodes = features.shape[0]
perm = torch.randperm(num_nodes, generator=torch.Generator().manual_seed(0))
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask[perm[: num_nodes // 5]] = True

graph_data = Data(
    x=torch.tensor(features, dtype=torch.float),
    edge_index=edge_index,
    y=torch.tensor(labels, dtype=torch.long),
    train_mask=~test_mask,
    test_mask=test_mask,
)

# Define Improved GCN Model
class GCN(nn.Module):
    """Three-layer graph convolutional node classifier.

    Two hidden ``GCNConv`` stages (each: convolution -> ReLU -> ``BatchNorm``
    -> dropout with p=0.5) feed a final ``GCNConv`` that projects to
    ``out_channels`` scores. ``forward`` returns per-node log-probabilities
    (``log_softmax`` over dim 1).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Hidden stage 1: in_channels -> hidden_channels
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.norm1 = BatchNorm(hidden_channels)
        # Hidden stage 2: hidden_channels -> hidden_channels
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.norm2 = BatchNorm(hidden_channels)
        # Output stage: hidden_channels -> out_channels
        self.conv3 = GCNConv(hidden_channels, out_channels)
        # One dropout module, reused after both hidden stages.
        self.dropout = nn.Dropout(0.5)

    def _hidden_stage(self, h, edge_index, conv, norm):
        """Apply conv -> ReLU -> norm -> dropout."""
        activated = F.relu(conv(h, edge_index))
        return self.dropout(norm(activated))

    def forward(self, x, edge_index):
        h = self._hidden_stage(x, edge_index, self.conv1, self.norm1)
        h = self._hidden_stage(h, edge_index, self.conv2, self.norm2)
        scores = self.conv3(h, edge_index)
        return F.log_softmax(scores, dim=1)

# Initialize & Train GCN
gcn = GCN(in_channels=features.shape[1], hidden_channels=512, out_channels=len(set(labels))).to(device)
optimizer = optim.AdamW(gcn.parameters(), lr=0.001, weight_decay=1e-4)
# Anneal over the whole run: train_gcn is called for 100 epochs with one
# scheduler step each. With T_max=50 the cosine curve bottoms out halfway and
# then climbs back to the initial LR by the final epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
# GCN.forward already returns log_softmax, so NLLLoss is the matching
# criterion. CrossEntropyLoss would apply a second, redundant log_softmax
# (same value, extra work, misleading).
loss_fn = nn.NLLLoss()

def train_gcn(model, data, optimizer, loss_fn, scheduler, epochs=100):
    """Full-batch training of a node classifier on a single graph.

    Each epoch runs one forward/backward pass over the whole graph, steps the
    optimizer and then the scheduler, and scores the model in eval mode
    (dropout off). It then prints the training loss and the accuracy.

    If ``data`` has boolean ``train_mask`` / ``test_mask`` attributes, the
    loss uses only the train nodes and accuracy only the test nodes.
    Otherwise every node is used for both, which yields a training accuracy.

    Args:
        model: module called as ``model(x, edge_index)``, returning per-node
            scores accepted by ``loss_fn``. It runs on its parameters' device.
        data: object with ``x``, ``edge_index`` and ``y`` tensors
            (e.g. a PyG ``Data``).
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: loss taking ``(scores, targets)``.
        scheduler: LR scheduler, stepped once per epoch.
        epochs: number of full-batch epochs.

    Returns:
        The best per-epoch accuracy seen (0.0 when ``epochs`` is 0).
        The model is left in train mode.
    """
    param = next(model.parameters(), None)
    dev = param.device if param is not None else torch.device("cpu")
    # Move the graph to the device once instead of copying it every epoch.
    x = data.x.to(dev)
    edge_index = data.edge_index.to(dev)
    y = data.y.to(dev)
    train_mask = getattr(data, "train_mask", None)
    test_mask = getattr(data, "test_mask", None)
    train_idx = train_mask.to(dev) if train_mask is not None else slice(None)
    eval_idx = test_mask.to(dev) if test_mask is not None else slice(None)

    best_acc = 0.0
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        out = model(x, edge_index)
        loss = loss_fn(out[train_idx], y[train_idx])
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Score with dropout disabled (and normalization layers in inference
        # mode) rather than reusing the noisy training forward pass.
        model.eval()
        with torch.no_grad():
            pred = model(x, edge_index).argmax(dim=1)
        acc = (pred[eval_idx] == y[eval_idx]).float().mean().item()
        best_acc = max(best_acc, acc)
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Accuracy: {acc:.4f}")
    model.train()
    print(f"Final GCN Accuracy: {best_acc:.4f}")
    return best_acc

# Final stage: full-batch GCN training on the kNN graph for 100 epochs.
train_gcn(
    gcn,
    graph_data,
    optimizer,
    loss_fn,
    scheduler,
    epochs=100,
)

print("Training complete. Best accuracy achieved during training is displayed above.")

Using device: cuda
Files already downloaded and verified
Dataset loaded: 8677 images.


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():


Epoch 1, Loss: 1.3057, Accuracy: 0.6929
Epoch 2, Loss: 0.5473, Accuracy: 0.8514
Epoch 3, Loss: 0.3748, Accuracy: 0.8948
Epoch 4, Loss: 0.3161, Accuracy: 0.9106
Epoch 5, Loss: 0.2800, Accuracy: 0.9167
Epoch 6, Loss: 0.2018, Accuracy: 0.9405
Epoch 7, Loss: 0.1950, Accuracy: 0.9425
Epoch 8, Loss: 0.1569, Accuracy: 0.9523
Epoch 9, Loss: 0.1657, Accuracy: 0.9476
Epoch 10, Loss: 0.1705, Accuracy: 0.9495
Epoch 11, Loss: 0.1215, Accuracy: 0.9636
Epoch 12, Loss: 0.1079, Accuracy: 0.9676
Epoch 13, Loss: 0.1287, Accuracy: 0.9617
Epoch 14, Loss: 0.1061, Accuracy: 0.9695
Epoch 15, Loss: 0.1075, Accuracy: 0.9702
Final ResNet Accuracy: 0.9702
Epoch 1, Loss: 5.5822, Accuracy: 0.0089
Epoch 2, Loss: 2.2046, Accuracy: 0.5432
Epoch 3, Loss: 1.1634, Accuracy: 0.8123
Epoch 4, Loss: 0.7474, Accuracy: 0.9017
Epoch 5, Loss: 0.5225, Accuracy: 0.9464
Epoch 6, Loss: 0.3890, Accuracy: 0.9644
Epoch 7, Loss: 0.3137, Accuracy: 0.9689
Epoch 8, Loss: 0.2621, Accuracy: 0.9703
Epoch 9, Loss: 0.2222, Accuracy: 0.9756
Epoc