In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import tensorflow_datasets as tfds
import numpy as np
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, BatchNorm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import kneighbors_graph

# Select the compute device: CUDA when PyTorch reports it as available,
# otherwise the CPU. The choice is echoed so it shows up in the run log.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Fetch the 'train' split of the TF-Flowers dataset through TensorFlow
# Datasets. as_supervised=True is requested so elements can be unpacked as
# (image, label) pairs further down.
ds_train = tfds.load(
    "tf_flowers",
    split="train",
    as_supervised=True,
)

# Convert to PyTorch Dataset
class FlowersDataset(Dataset):
    """Map-style PyTorch dataset backed by a fully materialised iterable.

    The source iterable is consumed once in the constructor and kept in
    memory as a list. Each stored element is expected to unpack into an
    ``(image, label)`` pair whose members both expose ``.numpy()``.
    """

    def __init__(self, tf_dataset, transform=None):
        """Store the optional ``transform`` and eagerly load ``tf_dataset``."""
        self.transform = transform
        self.data = list(tf_dataset)

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is converted to a PIL image and, when a truthy
        ``transform`` was supplied, passed through it. The label is returned
        as whatever ``label.numpy()`` yields.
        """
        image_tensor, label_tensor = self.data[idx]
        to_pil = transforms.ToPILImage()
        sample = to_pil(image_tensor.numpy())
        if self.transform:
            sample = self.transform(sample)
        return sample, label_tensor.numpy()

# Per-image preprocessing: resize to 224x224, convert to a tensor, then
# normalise each of the three channels with the mean / std values below.
# NOTE(review): these look like the usual ImageNet channel statistics - confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Wrap the loaded split and serve it in shuffled mini-batches of 32.
flowers_dataset = FlowersDataset(tf_dataset=ds_train, transform=transform)
dataloader = DataLoader(dataset=flowers_dataset, batch_size=32, shuffle=True)

# Fine-tune ResNet101
# torchvision >= 0.13 deprecates the boolean `pretrained` flag in favour of
# the `weights` enum. Use the enum when it exists and fall back to the legacy
# flag on older releases that do not ship `ResNet101_Weights`.
_resnet101_weights = getattr(models, "ResNet101_Weights", None)
if _resnet101_weights is not None:
    # IMAGENET1K_V1 is the checkpoint the legacy `pretrained=True` maps to.
    resnet = models.resnet101(weights=_resnet101_weights.IMAGENET1K_V1)
else:
    resnet = models.resnet101(pretrained=True)

# Replace the classification head with a fresh 5-way linear layer.
resnet.fc = nn.Linear(resnet.fc.in_features, 5)
resnet.to(device)

# Every parameter (backbone and new head) is handed to Adam.
optimizer = optim.Adam(resnet.parameters(), lr=0.0005)
criterion = nn.CrossEntropyLoss()

def train_resnet(model, dataloader, optimizer, criterion, epochs=10):
    """Train ``model`` on ``dataloader`` and print per-epoch loss and accuracy.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        dataloader: Iterable yielding ``(inputs, labels)`` tensor batches.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss invoked as ``criterion(output, labels)``.
        epochs: Number of passes over ``dataloader``.

    Returns:
        None. Progress is reported through ``print`` only.

    The accuracy is measured on the training batches themselves, while the
    model is in training mode; it is not a held-out metric.
    """
    # Run on whatever device the model already lives on instead of relying
    # on a module-level global that might not match it.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    accuracy = 0.0  # keeps the final print valid when epochs == 0
    for epoch in range(epochs):
        total_loss, correct, seen, batches = 0.0, 0, 0, 0
        for img, label in dataloader:
            img, label = img.to(target), label.to(target)
            optimizer.zero_grad()
            output = model(img)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            correct += (output.argmax(dim=1) == label).sum().item()
            # Count what was actually iterated so the ratio is right for any
            # loader, not just one wrapping a particular global dataset.
            seen += label.size(0)
            batches += 1
        # Guard the divisions so an empty loader reports zeros, not a crash.
        accuracy = correct / seen if seen else 0.0
        mean_loss = total_loss / batches if batches else 0.0
        print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}, Accuracy: {accuracy:.4f}")
    print(f"Final ResNet Accuracy: {accuracy:.4f}")

# Run the fine-tuning loop for ten epochs.
train_resnet(
    model=resnet,
    dataloader=dataloader,
    optimizer=optimizer,
    criterion=criterion,
    epochs=10,
)

# Feature extraction: replace the classifier head with a pass-through so a
# forward pass now returns the activations that used to feed `fc`.
resnet.fc = nn.Identity()
def extract_features(dataset, model):
    """Run every sample of ``dataset`` through ``model`` and collect outputs.

    Args:
        dataset: Iterable of ``(tensor, label)`` pairs. Each tensor is given
            a leading batch axis of size 1 before the forward pass.
        model: Module applied to each single-sample batch.

    Returns:
        A ``(features, labels)`` tuple of NumPy arrays. Row ``i`` of
        ``features`` is the flattened model output for sample ``i`` and
        ``labels[i]`` is that sample's label as yielded by ``dataset``.

    The model is switched to eval mode and left in it.
    """
    # Use the device the model already lives on rather than a module-level
    # global, so the function also works for models placed elsewhere.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    features, labels = [], []
    with torch.no_grad():  # inference only; no autograd graph needed
        for img, label in dataset:
            batch = img.unsqueeze(0).to(target)
            features.append(model(batch).cpu().numpy().flatten())
            labels.append(label)
    return np.array(features), np.array(labels)

# Embed every image with `resnet`, then re-encode the collected labels with
# scikit-learn's LabelEncoder.
features, raw_labels = extract_features(dataset=flowers_dataset, model=resnet)
label_encoder = LabelEncoder()
labels = label_encoder.fit_transform(raw_labels)

# Build optimized KNN graph
def build_knn_graph(features, k=15):
    """Return the k-nearest-neighbour graph of ``features`` as an edge list.

    Args:
        features: Array-like with one row per sample, passed straight to
            ``sklearn.neighbors.kneighbors_graph``.
        k: Number of neighbours per sample. ``include_self=True`` is used,
            so each sample counts itself among them.

    Returns:
        A ``torch.long`` tensor with two rows. Column ``e`` holds a sample
        index in row 0 and the index of one of its neighbours in row 1.
    """
    adjacency = kneighbors_graph(
        features, k, mode='connectivity', include_self=True
    ).tocsr()
    # Read the non-zero coordinates from the sparse matrix directly, which
    # avoids allocating a dense N x N array. Sorting the column indices first
    # keeps the edges in row-major order, the order np.nonzero yields on the
    # dense matrix.
    adjacency.sort_indices()
    rows, cols = adjacency.nonzero()
    return torch.tensor(np.vstack((rows, cols)), dtype=torch.long)

# Link every embedding to its nearest neighbours (default k).
edge_index = build_knn_graph(features)

# Bundle node features, connectivity and labels into one PyG Data object.
node_features = torch.tensor(features, dtype=torch.float)
node_labels = torch.tensor(labels, dtype=torch.long)
graph_data = Data(x=node_features, edge_index=edge_index, y=node_labels)

# Define Enhanced GCN
class GCN(nn.Module):
    """Three-layer graph convolutional classifier.

    The first two ``GCNConv`` layers are each followed by ReLU, batch
    normalisation and dropout (p=0.5); the third ``GCNConv`` produces the
    output scores with no activation applied.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        """Create the layers for the given input, hidden and output widths."""
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.bn1 = BatchNorm(hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.bn2 = BatchNorm(hidden_channels)
        self.conv3 = GCNConv(hidden_channels, out_channels)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        """Return per-node output scores for features ``x`` on ``edge_index``."""
        hidden_stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in hidden_stages:
            activated = torch.relu(conv(x, edge_index))
            x = self.dropout(norm(activated))
        return self.conv3(x, edge_index)

# Initialize & Train GCN
# One output score per distinct label value; input width follows the
# extracted feature dimension.
num_classes = len(np.unique(labels))
gcn = GCN(
    in_channels=features.shape[1],
    hidden_channels=512,
    out_channels=num_classes,
)
gcn = gcn.to(device)

optimizer = optim.Adam(params=gcn.parameters(), lr=0.001, weight_decay=1e-4)
# NOTE(review): T_max=50 while the training call further down runs for 100
# epochs - confirm this schedule length is intended.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
loss_fn = nn.CrossEntropyLoss()

def train_gcn(model, data, optimizer, loss_fn, scheduler, epochs=100):
    """Train ``model`` full-batch on a single graph and print progress.

    Args:
        model: Module called as ``model(x, edge_index)``.
        data: Object exposing tensor attributes ``x``, ``edge_index``, ``y``.
        optimizer: Optimizer updating ``model``'s parameters.
        loss_fn: Loss invoked as ``loss_fn(out, y)``.
        scheduler: Learning-rate scheduler stepped once per epoch, after the
            optimizer step.
        epochs: Number of full-batch updates.

    Returns:
        None. Progress is reported through ``print`` only.

    The accuracy is computed from the training-mode forward pass over the
    same nodes the loss is taken on. The value printed as "Final GCN
    Accuracy" is the best such per-epoch accuracy, not the last one.
    """
    # Use the device the model already lives on rather than a module-level
    # global that might not match it.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else torch.device("cpu")

    # The graph does not change between epochs, so transfer it once instead
    # of on every iteration.
    x = data.x.to(target)
    edge_index = data.edge_index.to(target)
    y = data.y.to(target)

    model.train()
    best_acc = 0.0
    for epoch in range(epochs):
        optimizer.zero_grad()
        out = model(x, edge_index)
        loss = loss_fn(out, y)
        loss.backward()
        optimizer.step()
        scheduler.step()
        acc = (out.argmax(dim=1) == y).float().mean().item()
        best_acc = max(best_acc, acc)
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Accuracy: {acc:.4f}")
    print(f"Final GCN Accuracy: {best_acc:.4f}")

# Full-batch GCN training on the KNN graph for 100 epochs.
train_gcn(
    model=gcn,
    data=graph_data,
    optimizer=optimizer,
    loss_fn=loss_fn,
    scheduler=scheduler,
    epochs=100,
)

print("Training complete. Best accuracy achieved during training is displayed above.")

Using device: cuda




Epoch 1, Loss: 0.6416, Accuracy: 0.7777
Epoch 2, Loss: 0.3801, Accuracy: 0.8673
Epoch 3, Loss: 0.2824, Accuracy: 0.9035
Epoch 4, Loss: 0.2488, Accuracy: 0.9114
Epoch 5, Loss: 0.1623, Accuracy: 0.9422
Epoch 6, Loss: 0.1148, Accuracy: 0.9599
Epoch 7, Loss: 0.1470, Accuracy: 0.9534
Epoch 8, Loss: 0.1209, Accuracy: 0.9583
Epoch 9, Loss: 0.1151, Accuracy: 0.9629
Epoch 10, Loss: 0.1117, Accuracy: 0.9643
Final ResNet Accuracy: 0.9643
Epoch 1, Loss: 1.8806, Accuracy: 0.3245
Epoch 2, Loss: 0.0950, Accuracy: 0.9752
Epoch 3, Loss: 0.0724, Accuracy: 0.9796
Epoch 4, Loss: 0.0686, Accuracy: 0.9809
Epoch 5, Loss: 0.0611, Accuracy: 0.9807
Epoch 6, Loss: 0.0543, Accuracy: 0.9834
Epoch 7, Loss: 0.0508, Accuracy: 0.9831
Epoch 8, Loss: 0.0459, Accuracy: 0.9839
Epoch 9, Loss: 0.0466, Accuracy: 0.9864
Epoch 10, Loss: 0.0455, Accuracy: 0.9845
Epoch 11, Loss: 0.0400, Accuracy: 0.9875
Epoch 12, Loss: 0.0396, Accuracy: 0.9875
Epoch 13, Loss: 0.0410, Accuracy: 0.9886
Epoch 14, Loss: 0.0417, Accuracy: 0.9853
Epoc