In [None]:
# Import
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from sklearn.cluster import KMeans
from timm import create_model
from torch.utils.data import DataLoader, Subset
from torchvision import datasets

# Dataset root. Set the TINY_IMAGENET_DIR environment variable to run the
# notebook on another machine; when it is unset the original local path is used.
data_path = os.environ.get(
    'TINY_IMAGENET_DIR',
    '/Volumes/WD Harddisk/Desktop/Masterss/3rd Sem (WS 24-25)/Project/Dataset/tiny-imagenet-200',
)

In [None]:
# Preprocessing applied to every image, in order:
#   1. resize to 64x64 pixels
#   2. convert the image to a PyTorch tensor
#   3. normalize with the ImageNet channel mean / std
preprocessing_steps = [
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(preprocessing_steps)

# Train / validation folders under the dataset root
train_dir, val_dir = (f"{data_path}/{split}" for split in ("train", "val"))

# NOTE(review): ImageFolder derives labels from sub-directory names. The stock
# Tiny ImageNet val/ folder is not laid out one-directory-per-class, so this
# presumably relies on val/ having been reorganized beforehand -- confirm.
train_data = datasets.ImageFolder(root=train_dir, transform=transform)
val_data = datasets.ImageFolder(root=val_dir, transform=transform)

# Restrict the experiment to the first 50 class names in sorted order
num_classes = 50
selected_classes = sorted(train_data.classes)[:num_classes]

def filter_dataset(dataset, allowed_classes):
    """Return a ``Subset`` of ``dataset`` holding only samples of the allowed classes.

    Args:
        dataset: Dataset exposing ``classes`` (label index -> class name). When it
            also exposes ``targets`` (as ``ImageFolder`` does), labels are read
            from there so that no image has to be loaded.
        allowed_classes: Iterable of class names to keep.

    Returns:
        ``Subset`` whose ``indices`` are the positions of the kept samples, in
        their original order.
    """
    allowed = set(allowed_classes)  # O(1) membership instead of scanning a list

    targets = getattr(dataset, "targets", None)
    if targets is not None and getattr(dataset, "target_transform", None) is None:
        # Fast path: label list is already available, no sample is decoded.
        labels = targets.tolist() if hasattr(targets, "tolist") else list(targets)
    else:
        # Fallback: iterate the samples (loads and transforms every item).
        labels = [label for _, label in dataset]

    indices = [i for i, label in enumerate(labels) if dataset.classes[label] in allowed]
    return Subset(dataset, indices)

# Narrow both splits down to the selected classes
train_data = filter_dataset(dataset=train_data, allowed_classes=selected_classes)
val_data = filter_dataset(dataset=val_data, allowed_classes=selected_classes)

# Report the resulting split sizes
n_train_samples, n_val_samples = len(train_data), len(val_data)
print(f"Dataset filtered to {num_classes} classes.")
print(f"Training data size: {n_train_samples}, Validation data size: {n_val_samples}")


In [None]:
def get_loader(classes_subset, dataset, batch_size=32):
    """Build a shuffling ``DataLoader`` over the samples whose label is in ``classes_subset``.

    Labels are resolved without loading any sample whenever possible: a
    ``Subset`` is followed through its ``indices`` down to a dataset exposing
    ``targets``. Only datasets without ``targets`` are iterated item by item.

    Args:
        classes_subset: Iterable of label values (the dataset's own label
            indices) to keep.
        dataset: A dataset or (possibly nested) ``Subset``.
        batch_size: Batch size of the returned loader.

    Returns:
        ``DataLoader`` (``shuffle=True``) over a ``Subset`` of ``dataset``.
    """
    def _plain(label):
        # Tensors hash by identity, so turn scalar tensors into Python numbers.
        return label.item() if hasattr(label, "item") else label

    def _labels(ds):
        if isinstance(ds, Subset):
            parent_labels = _labels(ds.dataset)
            return [parent_labels[i] for i in ds.indices]
        targets = getattr(ds, "targets", None)
        if targets is not None and getattr(ds, "target_transform", None) is None:
            raw = targets.tolist() if hasattr(targets, "tolist") else list(targets)
            return [_plain(label) for label in raw]
        # Slow fallback: fetch every sample just to read its label.
        return [_plain(label) for _, label in ds]

    wanted = set(classes_subset)
    indices = [i for i, label in enumerate(_labels(dataset)) if label in wanted]
    subset = Subset(dataset, indices)
    return DataLoader(subset, batch_size=batch_size, shuffle=True)

# Validation batches are served in a fixed order (no shuffling)
val_loader = DataLoader(dataset=val_data, shuffle=False, batch_size=32)


In [None]:
class NFNetWrapper(nn.Module):
    """Thin ``nn.Module`` shell around a timm ``'nfnet_f0'`` network.

    The wrapped network is stored in ``self.model`` and is created with
    ``pretrained=False`` and a classification head of ``num_classes`` outputs.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Build the backbone through timm; no pretrained weights are requested.
        backbone = create_model('nfnet_f0', pretrained=False, num_classes=num_classes)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output unchanged."""
        logits = self.model(x)
        return logits



In [None]:
# Alphabetical and random orders
alphabetical_order = list(range(num_classes))
# Seeded generator so the "random" order is identical on every
# Restart & Run All (seed 42, like the KMeans random_state in this notebook).
random_order = np.random.default_rng(42).permutation(num_classes).tolist()

# Generate dissimilar order using clustering
def _interleave_clusters(cluster_labels):
    """Order item indices round-robin across clusters so neighbours differ in cluster where possible."""
    buckets = {}
    for item_idx, cluster_id in enumerate(cluster_labels):
        buckets.setdefault(int(cluster_id), []).append(item_idx)
    queues = [buckets[cluster_id] for cluster_id in sorted(buckets)]

    order = []
    depth = 0
    while len(order) < len(cluster_labels):
        # Take the depth-th member of every cluster that still has one.
        for members in queues:
            if depth < len(members):
                order.append(members[depth])
        depth += 1
    return order


def get_dissimilar_order(dataset):
    """Return a class order in which consecutive classes come from different feature clusters.

    For every class index in ``range(num_classes)`` the mean ResNet-18 feature
    (network without its final layer) is computed over all samples that
    ``get_loader`` returns for that class. The class means are clustered with
    KMeans and the classes are then emitted round-robin across clusters.
    Sorting by cluster label instead would place same-cluster, i.e. similar,
    classes next to each other.

    Relies on the notebook globals ``num_classes``, ``device`` and ``get_loader``.

    Args:
        dataset: Dataset passed through to ``get_loader``.

    Returns:
        List containing every class index in ``range(num_classes)`` exactly once.

    Raises:
        ValueError: If a class yields no samples.
    """
    from torchvision.models import resnet18, ResNet18_Weights

    # Feature extraction
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
    resnet = nn.Sequential(*list(resnet.children())[:-1])  # Remove the classification layer
    resnet.eval()
    resnet.to(device)

    # Extract one mean feature vector per class
    features = []
    for class_idx in range(num_classes):
        loader = get_loader([class_idx], dataset, batch_size=16)
        feature_sum, sample_count = None, 0
        with torch.no_grad():
            for images, _ in loader:
                images = images.to(device)
                output = resnet(images).view(images.size(0), -1)
                # Accumulate sum and count: averaging per-batch means would
                # over-weight a smaller final batch.
                batch_sum = output.sum(dim=0)
                feature_sum = batch_sum if feature_sum is None else feature_sum + batch_sum
                sample_count += images.size(0)
        if sample_count == 0:
            raise ValueError(f"No samples found for class index {class_idx}")
        features.append((feature_sum / sample_count).cpu().numpy())

    # Perform clustering (never ask for more clusters than there are classes)
    kmeans = KMeans(n_clusters=min(10, num_classes), random_state=42).fit(features)
    return _interleave_clusters(kmeans.labels_)

# Device used by feature extraction, training and validation in this notebook
device = torch.device("cpu")

# Derive the third class ordering from the filtered training set
print("Generating dissimilar order...")
dissimilar_order = get_dissimilar_order(dataset=train_data)
print("Class orders created: Alphabetical, Random, and Dissimilar.")


In [None]:
def train_incrementally(net, train_data, val_loader, order, num_epochs=3):
    """Train ``net`` on growing prefixes of ``order`` and record validation accuracy.

    At each step the first ``n`` entries of ``order`` (n = 2, 12, 22, ... below
    ``num_classes + 2``) select the training classes. The same network and the
    same Adam optimizer are carried across steps.

    Relies on the notebook globals ``num_classes``, ``selected_classes``,
    ``device``, ``get_loader`` and ``validate``.

    Args:
        net: Model to train in place.
        train_data: Dataset handed to ``get_loader`` together with the active classes.
        val_loader: Loader passed unchanged to ``validate`` after every step.
        order: Sequence of class indices defining the order of introduction.
        num_epochs: Passes over the current training loader per step.

    Returns:
        List with one accuracy value (as returned by ``validate``) per step.
    """
    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(net.parameters(), lr=0.0005)
    accuracies = []

    # Add classes in steps of 10.
    # NOTE(review): with num_classes = 50 this visits 2, 12, 22, 32, 42 classes,
    # so the full set is never trained on; the plotting cell uses the same range.
    for n_active in range(2, num_classes + 2, 10):
        active = order[:n_active]
        active_names = [selected_classes[idx] for idx in active]
        print(f"Training with classes: {active_names}")

        step_loader = get_loader(active, train_data)

        # validate() leaves the model in eval mode, so switch back every step
        net.train()
        for _ in range(num_epochs):
            for batch_x, batch_y in step_loader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                adam.zero_grad()
                batch_loss = loss_fn(net(batch_x), batch_y)
                batch_loss.backward()
                adam.step()

        # NOTE(review): accuracy is measured on whatever val_loader contains,
        # not only on the classes seen so far -- confirm this is intended.
        step_acc = validate(net, val_loader)
        accuracies.append(step_acc)
        print(f"Accuracy after training on {len(active)} classes: {step_acc:.2f}%")
    return accuracies

def validate(net, loader):
    """Return the top-1 accuracy of ``net`` over ``loader`` as a percentage.

    Puts ``net`` into eval mode, runs every batch without gradient tracking on
    the notebook-global ``device``, and counts predictions (index of the
    largest output along dim 1) that equal the labels.

    Args:
        net: Model to evaluate; left in eval mode afterwards.
        loader: Iterable of ``(images, labels)`` batches.

    Returns:
        ``100 * correct / total`` as a float.
    """
    net.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            predicted = net(batch_x).max(1).indices
            n_correct += (predicted == batch_y).sum().item()
            n_seen += batch_y.size(0)
    return 100 * n_correct / n_seen


In [None]:
def _run_order(banner, class_order):
    """Print ``banner``, build a fresh NFNetWrapper and train it incrementally on ``class_order``."""
    print(banner)
    model = NFNetWrapper(num_classes=num_classes).to(device)
    return model, train_incrementally(model, train_data, val_loader, class_order)


# One freshly initialised network per class ordering
net, acc_alphabetical = _run_order("\nTraining with Alphabetical Order", alphabetical_order)
net, acc_random = _run_order("\nTraining with Random Order", random_order)
net, acc_dissimilar = _run_order("\nTraining with Dissimilar Order", dissimilar_order)


In [None]:
# x positions: the class counts used by the incremental training schedule
class_counts = range(2, num_classes + 2, 10)

# One accuracy curve per class ordering
for accuracy_curve, legend_name in (
    (acc_alphabetical, 'Alphabetical Order'),
    (acc_random, 'Random Order'),
    (acc_dissimilar, 'Dissimilar Order'),
):
    plt.plot(class_counts, accuracy_curve, label=legend_name, marker='o')

plt.xlabel('Number of Classes')
plt.ylabel('Validation Accuracy (%)')
plt.title('Validation Accuracy vs Number of Classes')
plt.legend()
plt.show()