In [None]:
from torchvision.datasets import FGVCAircraft
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Remove the copyright banner that sits along the bottom edge of the images.
class RemoveCopyrightBanner(object):
    """Crop a fixed-height strip off the bottom of a PIL-style image.

    Args:
        banner_height: Number of pixel rows removed from the bottom of the image
            (default 20, the value this transform originally hard-coded).

    Raises:
        ValueError: if ``banner_height`` is negative.
    """

    def __init__(self, banner_height=20):
        if banner_height < 0:
            raise ValueError(f"banner_height must be >= 0, got {banner_height}")
        self.banner_height = banner_height

    def __call__(self, img):
        """Return ``img`` cropped to its top ``height - banner_height`` rows.

        Expects the PIL ``Image`` API: ``img.size == (width, height)`` and
        ``img.crop((left, upper, right, lower))``.
        """
        width, height = img.size
        return img.crop((0, 0, width, height - self.banner_height))

    def __repr__(self):
        # Makes the transform readable when a transforms.Compose pipeline is printed.
        return f"{self.__class__.__name__}(banner_height={self.banner_height})"

# Preprocessing shared by every split: strip the bottom banner, resize to
# 224x224, convert to a tensor, then normalize with the standard ImageNet
# channel statistics used by torchvision's pretrained models.
transform = transforms.Compose(
    [
        RemoveCopyrightBanner(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# One FGVC Aircraft instance per split, all labelled at the 'variant' level and
# downloaded into ./data on first use.
#   split options:            'train', 'val', 'trainval', 'test'
#   annotation_level options: 'variant', 'family', 'manufacturer'
train_dataset, val_dataset, test_dataset = (
    FGVCAircraft(
        root='./data',
        split=split_name,
        annotation_level='variant',
        transform=transform,
        download=True,
    )
    for split_name in ('train', 'val', 'test')
)

In [None]:
# Visual sanity check of the preprocessing pipeline.
def show_images(train_dataset, num_images=5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """
    Plot the first ``num_images`` samples of a dataset in a single row.

    Samples are expected to be ``(image, label)`` pairs where ``image`` is a
    normalized CxHxW tensor (as produced by ``transform``). The normalization is
    undone with ``mean``/``std`` before plotting. Without that step, ``imshow``
    clips values outside [0, 1] and the colors come out distorted.

    Args:
        train_dataset: Indexable dataset returning ``(image_tensor, label)``.
        num_images: Number of leading samples to show (capped at ``len(train_dataset)``).
        mean, std: Per-channel statistics that were passed to ``transforms.Normalize``;
            set either one to ``None`` to plot the tensors as-is.
    """
    num_images = min(num_images, len(train_dataset))
    if num_images <= 0:
        return
    # torchvision datasets such as FGVCAircraft expose `classes` (index -> name);
    # wrappers like Subset do not, in which case only the numeric label is shown.
    class_names = getattr(train_dataset, 'classes', None)

    # squeeze=False keeps `axes` 2-D, so indexing also works when num_images == 1.
    _, axes = plt.subplots(1, num_images, figsize=(15, 5), squeeze=False)
    for i, ax in enumerate(axes[0]):
        image, label = train_dataset[i]
        if mean is not None and std is not None:
            # Invert Normalize (x * std + mean), broadcasting over H and W.
            image = image * image.new_tensor(std).view(-1, 1, 1) + image.new_tensor(mean).view(-1, 1, 1)
        image = image.permute(1, 2, 0).clamp(0, 1)  # CxHxW -> HxWxC, within imshow's float range
        ax.imshow(image)
        title = f'Label: {label}'
        if class_names is not None:
            title += f'\n{class_names[label]}'
        ax.set_title(title)
        ax.axis('off')
    plt.show()

show_images(train_dataset, num_images=5)

# Split the Datasets into Tasks

In [None]:
from collections import defaultdict
import torch
from tqdm import tqdm

def _dataset_labels(dataset):
    """Return the label of every sample in ``dataset``, in index order."""
    # Fast path: read stored labels instead of decoding and transforming every image.
    # Only valid when no target_transform rewrites the labels in __getitem__.
    if getattr(dataset, "target_transform", None) is None:
        # `targets` is the common torchvision attribute; FGVCAircraft keeps its labels
        # in the private `_labels` list.
        # NOTE(review): private attribute -- re-check after torchvision upgrades.
        for attr in ("targets", "_labels"):
            labels = getattr(dataset, attr, None)
            if labels is not None:
                return list(labels)
    # Slow path: load every sample and keep only its label.
    return [label for _, label in tqdm(dataset, total=len(dataset))]


def group_task_indices(dataset, num_tasks=10, classes_per_task=10, cumulative=True):
    """
    Map each task index to the dataset indices belonging to that task.

    Labels are bucketed into blocks of ``classes_per_task`` consecutive classes.
    With the defaults, block 0 = labels 0-9, block 1 = labels 10-19, ...,
    block 9 = labels 90-99.

    With ``cumulative=True`` (the default and the original behavior), task ``t``
    contains every sample whose label lies in blocks ``0..t``, i.e. labels
    ``0 .. (t + 1) * classes_per_task - 1``. For example, ``task_dict[0]`` holds
    labels 0-9 and ``task_dict[1]`` holds labels 0-19. With ``cumulative=False``,
    task ``t`` contains only the samples of block ``t``.
    Samples whose block is ``>= num_tasks`` are not assigned to any task.

    Args:
        dataset: Indexable dataset yielding ``(image, label)``.
        num_tasks: Number of tasks to produce.
        classes_per_task: Number of consecutive labels per block.
        cumulative: Whether task ``t`` also includes the samples of earlier blocks.

    Returns:
        defaultdict(list) mapping task index -> ascending list of dataset indices.
    """
    task_dict = defaultdict(list)
    for idx, label in enumerate(_dataset_labels(dataset)):
        block = int(label) // classes_per_task
        if cumulative:
            tasks = range(block, num_tasks)
        elif 0 <= block < num_tasks:
            tasks = (block,)
        else:
            tasks = ()
        for task in tasks:
            task_dict[task].append(idx)
    return task_dict
    
# Apply the same task layout to every split. The test mapping keeps its
# historical name `test__idxs` (double underscore) because later cells use it.
train_task_idxs, val_task_idxs, test__idxs = (
    group_task_indices(split) for split in (train_dataset, val_dataset, test_dataset)
)

In [None]:
# from torch.utils.data import Subset
# train_subset = Subset(train_dataset, train_task_idxs[0])

# # initialize dataloaders with task 0
# train_loader = torch.utils.data.DataLoader(
#     train_subset, batch_size=32, shuffle=True, num_workers=4
# )
# # initalize val_loader with task 0
# val_subset = Subset(val_dataset, val_task_idxs[0])
# val_loader = torch.utils.data.DataLoader(
#     val_subset, batch_size=32, shuffle=False, num_workers=4
# )
# # initalize test_loader with task 0
# test_subset = Subset(test_dataset, test__idxs[0])
# test_loader = torch.utils.data.DataLoader(
#     test_subset, batch_size=32, shuffle=False, num_workers=4
# )

## Training Code

In [None]:
def val_net(net_to_val, val_loader):
    """
    Compute the mean cross-entropy loss of ``net_to_val`` over ``val_loader``.

    The model is switched to eval mode (and left there), and gradients are disabled.
    Batches are moved to the device of the model's first parameter. A model without
    parameters falls back to CUDA, the previously hard-coded device.

    Args:
        net_to_val: Classifier returning logits.
        val_loader: Iterable of ``(images, labels)`` batches.

    Returns:
        float: Cross-entropy averaged over all validation *samples*, so a smaller
        final batch is not over-weighted.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    net_to_val.eval()
    first_param = next(net_to_val.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    # Sum per-sample losses; divide by the sample count at the end.
    criterion = torch.nn.CrossEntropyLoss(reduction="sum")
    total_loss = 0.0
    num_samples = 0

    with torch.no_grad():
        # `tqdm` is the class imported via `from tqdm import tqdm`, so call it directly.
        for img, label in tqdm(val_loader, desc="Validating"):
            img, label = img.to(device), label.to(device)
            outputs = net_to_val(img)
            total_loss += criterion(outputs, label).item()
            num_samples += label.size(0)

    if num_samples == 0:
        raise ValueError("val_loader yielded no samples")
    return total_loss / num_samples

def train_net(max_epochs, net_to_train, opt, train_loader, val_loader, save_path=None, device="cuda"):
    """
    Train ``net_to_train`` with cross-entropy, validating after every epoch.

    Training runs for at most ``max_epochs`` epochs. It stops early as soon as the
    validation loss increases relative to the previous epoch (early stopping with
    no patience). If ``save_path`` is given, a checkpoint is written there after
    every epoch, overwriting the previous one.

    Args:
        max_epochs: Maximum number of passes over ``train_loader``.
        net_to_train: Model to optimize; it is moved to ``device``.
        opt: Optimizer already constructed over ``net_to_train``'s current parameters.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Batches passed to ``val_net`` after each epoch.
        save_path: Optional checkpoint file path.
        device: Device for the model and training batches (default ``"cuda"``, the
            original behavior). ``val_net`` must place its batches on the same device.

    Returns:
        (train_losses, val_losses): per-epoch mean training loss (averaged over
        batches) and per-epoch validation loss, both as lists of floats.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    criterion = torch.nn.CrossEntropyLoss()
    train_losses = []
    val_losses = []
    net_to_train.to(device)

    for epoch in range(max_epochs):
        net_to_train.train()
        running_loss = 0.0
        num_batches = 0

        # `tqdm` is the class imported via `from tqdm import tqdm`, so call it directly.
        for imgs, labels in tqdm(train_loader, unit='batch'):
            imgs, labels = imgs.to(device), labels.to(device)

            opt.zero_grad()
            outputs = net_to_train(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            opt.step()

            # accumulate for the epoch's mean training loss
            running_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")
        avg_loss = running_loss / num_batches
        train_losses.append(avg_loss)
        # float() keeps the history as plain numbers even if val_net returns a tensor.
        val_losses.append(float(val_net(net_to_train, val_loader)))
        print(f"Epoch {epoch + 1}/{max_epochs}: train loss {avg_loss:.4f}, val loss {val_losses[-1]:.4f}")

        if save_path:
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': net_to_train.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'loss': avg_loss,
                'val_loss': val_losses[-1],
            }, save_path)
            print(f"Checkpoint saved to {save_path}")

        # early stopping: halt on the first epoch whose val loss went up
        if len(val_losses) > 1 and val_losses[-1] > val_losses[-2]:
            break

    print("finished training")
    return train_losses, val_losses

In [None]:
import torch.nn as nn
def modify_resnet_head(model, num_classes, preserve_weights=False):
    """
    Replace the final fully connected layer ``model.fc`` with a new ``nn.Linear``
    that has ``num_classes`` outputs.

    The new layer keeps the old head's input width and is placed on the old head's
    device and dtype. Without this, swapping the head of a GPU model would leave the
    new layer on the CPU.

    Args:
        model: Module with an ``fc`` attribute that is an ``nn.Linear``
            (e.g. a torchvision ResNet).
        num_classes: Number of outputs of the new head.
        preserve_weights: If True, copy the weights and bias of the first
            ``min(old_out_features, num_classes)`` classes from the old head, so
            learned class rows survive when the head is grown. The remaining rows
            keep ``nn.Linear``'s default initialization. Default False (fresh head).

    Returns:
        The same ``model`` object, modified in place.
    """
    old_fc = model.fc
    new_fc = nn.Linear(old_fc.in_features, num_classes)
    new_fc = new_fc.to(device=old_fc.weight.device, dtype=old_fc.weight.dtype)

    if preserve_weights:
        keep = min(old_fc.out_features, num_classes)
        with torch.no_grad():
            new_fc.weight[:keep].copy_(old_fc.weight[:keep])
            if old_fc.bias is not None:
                new_fc.bias[:keep].copy_(old_fc.bias[:keep])

    model.fc = new_fc
    return model

In [None]:
import torch

def get_test_accuracy(model, test_loader, num_classes):
    """
    Evaluate the classification accuracy of ``model`` on ``test_loader``.

    Predictions are the argmax of the model output over dim 1. Batches are moved to
    the device of the model's first parameter. A model without parameters falls
    back to CUDA, the previously hard-coded device.

    Args:
        model: Classifier returning logits of shape (batch, C).
        test_loader: Iterable of ``(images, labels)`` batches that supports ``len()``
            (used for the progress bar).
        num_classes: Per-class accuracy is reported for labels ``0 .. num_classes - 1``.
            Samples with other labels still count toward the overall accuracy.

    Returns:
        (overall_acc, per_class_acc): the overall fraction of correct predictions,
        and a list with the per-class accuracy (correct / support) for each class.
        Classes with no samples report 0.0, and so does an empty loader overall.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    correct_preds = 0
    total = 0
    # Per-class counters live on the device and are updated with bincount. This avoids
    # 2 * num_classes host synchronisations (.item()) per batch.
    correct_per_class = torch.zeros(num_classes, dtype=torch.long, device=device)
    total_per_class = torch.zeros(num_classes, dtype=torch.long, device=device)

    with torch.no_grad():
        for imgs, labels in tqdm(test_loader, desc="Testing", total=len(test_loader)):
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(dim=1)
            hits = preds == labels

            correct_preds += hits.sum().item()
            total += labels.size(0)

            # Only labels in [0, num_classes) contribute to the per-class statistics.
            in_range = (labels >= 0) & (labels < num_classes)
            total_per_class += torch.bincount(labels[in_range], minlength=num_classes)
            correct_per_class += torch.bincount(labels[in_range & hits], minlength=num_classes)

    overall_acc = correct_preds / total if total > 0 else 0.0
    per_class_acc = [
        correct / support if support > 0 else 0.0
        for correct, support in zip(correct_per_class.tolist(), total_per_class.tolist())
    ]
    return overall_acc, per_class_acc


In [None]:
from torch.utils.data import Subset
from torchvision import models
import torch.optim as optim

# Experiment configuration.
NUM_TASKS = 10
CLASSES_PER_TASK = 10      # must match the grouping used by group_task_indices (10 labels per task)
MAX_EPOCHS = 10
BATCH_SIZE = 32
NUM_WORKERS = 4
LEARNING_RATE = 0.001
SEED = 0
RESULTS_PATH = "accuracies.txt"

torch.manual_seed(SEED)


def _task_loader(dataset, indices, shuffle):
    """DataLoader over ``Subset(dataset, indices)`` with the shared batch settings."""
    return torch.utils.data.DataLoader(
        Subset(dataset, indices), batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=NUM_WORKERS
    )


# Initialize the model (ImageNet-pretrained ResNet-18).
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour of
# `weights=models.ResNet18_Weights.DEFAULT`; kept here for compatibility with older versions.
model = models.resnet18(pretrained=True)

for task in range(NUM_TASKS):
    print(f"Training on task {task}...")
    num_seen_classes = (task + 1) * CLASSES_PER_TASK

    # The head is replaced with a freshly initialised layer sized for all classes seen so far.
    model = modify_resnet_head(model, num_seen_classes)
    model = model.cuda()
    # Build the optimizer AFTER replacing the head (and moving the model). An optimizer
    # created earlier would hold the old `fc` parameters and never update the new head.
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader = _task_loader(train_dataset, train_task_idxs[task], shuffle=True)
    val_loader = _task_loader(val_dataset, val_task_idxs[task], shuffle=False)
    test_loader = _task_loader(test_dataset, test__idxs[task], shuffle=False)

    # Train the model on the current task
    train_losses, val_losses = train_net(MAX_EPOCHS, model, optimizer, train_loader, val_loader)

    # Evaluate the model on the test set
    overall_acc, per_class_acc = get_test_accuracy(model, test_loader, num_seen_classes)
    print(f"Overall accuracy for task {task}: {overall_acc:.4f}")
    print(f"Per-class accuracy for task {task}: {per_class_acc}")

    # Truncate the results file on the first task and append afterwards, so every
    # task's results are kept rather than only the last one.
    with open(RESULTS_PATH, "w" if task == 0 else "a") as f:
        f.write(f"Task {task} - overall accuracy: {overall_acc:.4f}\n")
        f.write(f"Task {task} - perclass accuracy: {per_class_acc}\n")

    # Save the model after training on each task
    torch.save(model.state_dict(), f"model_task_{task}.pth")
    print(f"Model for task {task} saved as model_task_{task}.pth")

In [None]:
# class ContinualLearningMetrics:
#     def __init__(self, num_tasks):
#         """
#         Initialize the metrics tracker for continual learning evaluation.
        
#         Args:
#             num_tasks: The total number of sequential tasks to be learned
#         """
#         self.num_tasks = num_tasks
#         # Track the best accuracy observed for each task
#         self.best_accuracies = [0.0] * num_tasks
#         # Current accuracy for each task after most recent training
#         self.current_accuracies = [0.0] * num_tasks
#         # History of accuracies after training on each task
#         # Each entry is a list of accuracies for all tasks after training on a specific task
#         self.accuracy_history = []

#     def update_accuracies(self, accuracies_after_task):
#         """
#         Update the current accuracies after training on a new task.
        
#         Args:
#             accuracies_after_task: List or array of accuracies for all tasks 
#                                    after training on the current task
#         """
#         # Update current accuracies
#         self.current_accuracies = accuracies_after_task
#         # Store the accuracy snapshot in history
#         self.accuracy_history.append(accuracies_after_task.copy())

#         # Update best accuracies for each task if current is better
#         for i in range(self.num_tasks):
#             if accuracies_after_task[i] > self.best_accuracies[i]:
#                 self.best_accuracies[i] = accuracies_after_task[i]

#     def average_accuracy(self):
#         """
#         Compute average accuracy across all tasks at the current time.
        
#         Returns:
#             Float: The average of current accuracies across all tasks
#         """
#         return sum(self.current_accuracies) / self.num_tasks

#     def forgetting(self):
#         """
#         Compute forgetting measure for each task.
        
#         Forgetting for task i = best accuracy on task i - current accuracy on task i
        
#         Returns:
#             Float: Average forgetting across all tasks except the last one 
#                   (since the last task can't be forgotten yet)
#         """
#         forgetting_per_task = []
#         # Calculate forgetting for all tasks except the most recent one
#         for i in range(self.num_tasks - 1):
#             # Forgetting is the drop from best to current accuracy (minimum 0)
#             forgetting_per_task.append(max(0, self.best_accuracies[i] - self.current_accuracies[i]))
        
#         # Handle the case where there's only one task
#         if len(forgetting_per_task) == 0:
#             return 0.0
            
#         # Return average forgetting
#         return sum(forgetting_per_task) / len(forgetting_per_task)

#     def backward_transfer(self):
#         """
#         Compute backward transfer (BWT).
        
#         BWT = average of (accuracy after learning last task - accuracy after learning task i) 
#               for all tasks i < last task
              
#         Positive BWT indicates learning new tasks improves performance on old tasks.
#         Negative BWT indicates forgetting.
        
#         Returns:
#             Float or None: Average backward transfer value, or None if not enough data
#         """
#         # Check if we have enough data to compute BWT
#         if len(self.accuracy_history) < self.num_tasks:
#             return None

#         # Get accuracies after training on the final task
#         last_task_accuracies = self.accuracy_history[-1]
#         bwt_values = []
        
#         # For each previous task, calculate the difference between:
#         # - its accuracy after training on all tasks, and
#         # - its accuracy immediately after it was trained
#         for i in range(self.num_tasks - 1):
#             # Compare current performance to performance right after learning task i
#             bwt_values.append(last_task_accuracies[i] - self.accuracy_history[i][i])

#         # Return average backward transfer
#         return sum(bwt_values) / len(bwt_values)
    
#     def detailed_metrics(self):
#         """
#         Return a dictionary with all metrics for easier reporting.
        
#         Returns:
#             dict: Dictionary containing all computed metrics
#         """
#         return {
#             "average_accuracy": self.average_accuracy(),
#             "forgetting": self.forgetting(),
#             "backward_transfer": self.backward_transfer(),
#             "current_accuracies": self.current_accuracies.copy(),
#             "best_accuracies": self.best_accuracies.copy()
#         }


# # Example usage:
# if __name__ == "__main__":
#     # Initialize with 3 sequential tasks
#     num_tasks = 3
#     metrics = ContinualLearningMetrics(num_tasks)

#     # Simulate training and evaluation after each task
#     # Format: [accuracy_task1, accuracy_task2, accuracy_task3]
    
#     # After training on task 1, only task 1 has been seen
#     metrics.update_accuracies([80.0, 0.0, 0.0])
#     print("After Task 1:")
#     print(f"Average Accuracy: {metrics.average_accuracy():.2f}%")
#     print(f"Forgetting: {metrics.forgetting():.2f}%")
#     print(f"BWT: {metrics.backward_transfer()}")
    
#     # After training on task 2, task 1's accuracy dropped slightly
#     metrics.update_accuracies([75.0, 85.0, 0.0])
#     print("\nAfter Task 2:")
#     print(f"Average Accuracy: {metrics.average_accuracy():.2f}%")
#     print(f"Forgetting: {metrics.forgetting():.2f}%")
#     print(f"BWT: {metrics.backward_transfer()}")
    
#     # After training on task 3, both previous tasks show some forgetting
#     metrics.update_accuracies([70.0, 80.0, 90.0])
#     print("\nAfter Task 3:")
#     print(f"Average Accuracy: {metrics.average_accuracy():.2f}%")
#     print(f"Forgetting: {metrics.forgetting():.2f}%")
#     print(f"BWT: {metrics.backward_transfer():.2f}%")
    
#     # Get all metrics as a dictionary
#     all_metrics = metrics.detailed_metrics()
#     print("\nDetailed Metrics:")
#     for key, value in all_metrics.items():
#         print(f"{key}: {value}")
