Author: Anirudh Iyer
Contact Information: aniiyer@iu.edu
Date: 9/19/2024

Topic: Continual learning through incremental data pipeline on CIFAR-10 using pretrained inceptionv3
    
Description of the Work:
 
    Data Preparation:

    The dataset used is CIFAR-10, and it is divided into smaller subsets with 2 classes each.
    train_subsets and test_subsets are created for continual learning, splitting the data in increments.
    
    Training Function (train_increment):

    This function trains the model in increments using the provided data loader.
    It calculates the loss using a custom Magnet Loss function, optimizes the model, and prints batch losses.

    Visualization and Reporting:

    After each training increment, plots of the progress and embeddings are created.
    A classification report and confusion matrix are printed for the performance analysis.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import models 
import torchvision.transforms as transforms
from torch.utils.data import Subset, ConcatDataset
import numpy as np
from sklearn.cluster import KMeans
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveMagnetLoss(nn.Module):
    """Margin loss over per-class cluster means built from the current batch.

    For every sample, the squared distance to the nearest cluster mean of its
    own class is compared with the squared distance to the nearest cluster mean
    of any other class. The hinge ``relu(intra - inter + alpha)`` is averaged
    over the batch.
    """

    def __init__(self, alpha=1.0):
        super().__init__()
        # Margin added inside the hinge.
        self.alpha = alpha

    def forward(self, embeddings, labels, m, d):
        """
        Compute the loss for one batch.

        Args:
        embeddings: Tensor of shape (batch_size, embedding_dim)
        labels: Tensor of shape (batch_size,)
        m: Requested number of clusters per class. It is capped at
           batch_size // 2 and at the class size, and never drops below 1.
        d: Unused. Accepted so the call signature matches MagnetLoss.forward.

        Returns:
        loss: Scalar tensor with the computed loss. Samples in a batch that
              contains a single class contribute 0 (there is no other-class
              cluster to compare against).
        """
        batch_size = embeddings.size(0)
        if batch_size == 0:
            # Nothing to compare; return a zero that stays attached to the graph.
            return embeddings.sum() * 0.0

        labels = labels.to(embeddings.device)

        # The previous min(m, batch_size // 2) evaluated to 0 for a one-sample
        # batch and crashed in the following division / modulo, so floor it at 1.
        clusters_per_class = max(1, min(m, batch_size // 2))

        cluster_means = []
        cluster_labels = []
        for label in torch.unique(labels):
            class_embeddings = embeddings[labels == label]
            class_m = min(clusters_per_class, class_embeddings.size(0))

            # Round-robin clustering: the j-th sample of the class belongs to
            # cluster j % class_m, which is exactly the strided slice i::class_m.
            for i in range(class_m):
                cluster_means.append(class_embeddings[i::class_m].mean(dim=0))
                cluster_labels.append(label)

        cluster_means = torch.stack(cluster_means)
        cluster_labels = torch.stack(cluster_labels)

        # Squared Euclidean distance from every sample to every cluster mean.
        distances = torch.cdist(embeddings, cluster_means) ** 2

        # Mask out clusters of the wrong kind with +inf before taking the minimum.
        same_label_mask = labels.unsqueeze(1) == cluster_labels.unsqueeze(0)
        infinity = torch.full_like(distances, float('inf'))
        intra_cluster_costs = torch.where(same_label_mask, distances, infinity).min(dim=1).values
        inter_cluster_costs = torch.where(same_label_mask, infinity, distances).min(dim=1).values

        loss = F.relu(intra_cluster_costs - inter_cluster_costs + self.alpha)
        return loss.mean()

In [3]:
class MagnetLoss(nn.Module): # Not used by the active main(); only the legacy main() under "IGNORE EXTRA CLIPS" instantiates it
    """Magnet-style loss that derives cluster membership from sample position in the batch.

    Sample position j is assigned to cluster ``j // adjusted_d`` (see ``clusters``
    in ``forward``), so the caller is expected to lay the batch out cluster by
    cluster -- NOTE(review): presumably via ClusterBatchBuilder.gen_batch; confirm.
    """

    def __init__(self, alpha=1.0):
        super(MagnetLoss, self).__init__()
        # Margin subtracted inside the exponent of the numerator.
        self.alpha = alpha

    def forward(self, r, classes, m, d):
        """Compute the loss for one batch.

        Args:
            r: Batch of representations; ``r.size(0)`` is used as the batch size.
            classes: Per-sample class labels (moved to ``r``'s device).
            m: Requested number of clusters in the batch.
            d: Requested number of examples per cluster.

        Returns:
            Tuple ``(total_loss, losses.detach())``: the mean loss and the
            detached per-sample losses.
        """
        device = r.device
        classes = classes.to(device)
        
        # Adjust m and d based on the actual batch size
        batch_size = r.size(0)
        adjusted_m = max(2, min(m, batch_size // d))
        adjusted_d = batch_size // adjusted_m

        # Cluster id per position: each id 0..adjusted_m-1 repeated adjusted_d times in a row.
        clusters = torch.sort(torch.arange(0, float(adjusted_m)).repeat(adjusted_d))[0].to(device)
        # Class of each cluster = label of the first sample in that cluster's slice.
        cluster_classes = classes[0:adjusted_m*adjusted_d:adjusted_d]

        # Compute variance
        variance = torch.sum((r - torch.mean(r, dim=0))**2) / (r.size(0) - 1)
        var_normalizer = -1 / (2 * variance**2)

        # Compute cluster means
        # NOTE(review): `clusters` has adjusted_m * adjusted_d entries; if batch_size is not a
        # multiple of adjusted_m this mask is shorter than r -- looks like it would fail; confirm.
        cluster_means = torch.stack([torch.mean(r[clusters == i], dim=0) for i in range(adjusted_m)])

        # Compute distances
        # Plain (not squared) cdist distances; presumably (batch, adjusted_m) after squeeze -- TODO confirm.
        sample_costs = torch.cdist(cluster_means, r.unsqueeze(1)).squeeze()

        # Compute intra-cluster costs
        # One-hot mask selecting, for each position, the column of its own cluster.
        intra_cluster_mask = (clusters.unsqueeze(1) == torch.arange(adjusted_m, device=device).unsqueeze(0)).float()
        intra_cluster_costs = torch.sum(intra_cluster_mask * sample_costs, dim=1)

        # Compute numerator and denominator
        numerator = torch.exp(var_normalizer * intra_cluster_costs - self.alpha)
        # Denominator sums only over clusters whose class differs from the sample's class.
        denominator_mask = (classes.unsqueeze(1) != cluster_classes.unsqueeze(0)).float()
        denominator = torch.sum(denominator_mask * torch.exp(var_normalizer * sample_costs), dim=1)

        # Compute loss
        # epsilon guards both the division and the log against zero.
        epsilon = 1e-8
        losses = torch.relu(-torch.log(numerator / (denominator + epsilon) + epsilon))
        total_loss = torch.mean(losses)

        return total_loss, losses.detach()

class ClusterBatchBuilder:
    """Keeps per-class k-means clusters over example representations and samples batches from them.

    Cluster ids are laid out class by class: class slot ``i`` (the i-th distinct
    label value in sorted order) owns cluster ids ``i * k`` .. ``i * k + k - 1``.
    Labels may be arbitrary values (e.g. ``{2, 3}`` for a later increment); they
    no longer have to be ``0 .. num_classes - 1``.
    """

    def __init__(self, labels, k, m, d):
        """
        Args:
            labels: 1-D array-like of class labels, one per example.
            k: Number of clusters per class.
            m: Number of clusters per batch (capped at the number of classes).
            d: Number of examples drawn per selected cluster.
        """
        self.labels = np.asarray(labels)
        # Distinct label values; position in this array is the class slot.
        self.classes = np.unique(self.labels)
        self.num_classes = len(self.classes)
        self.k = k
        self.m = min(m, self.num_classes)
        self.d = d
        self.centroids = None
        self.assignments = np.zeros_like(self.labels, int)
        self.cluster_assignments = {}
        # Actual label value of each cluster id (identical to the old
        # range-based mapping whenever labels are 0 .. num_classes - 1).
        self.cluster_classes = np.repeat(self.classes, k)
        self.example_losses = None
        self.cluster_losses = None

    def update_clusters(self, rep_data, max_iter=20):
        """Re-cluster every class with k-means on ``rep_data``.

        Args:
            rep_data: 2-D array whose row ``i`` is the representation of the
                example with label ``self.labels[i]``.
            max_iter: Maximum k-means iterations per class.

        Raises:
            ValueError: if ``rep_data`` does not have one row per label.
        """
        if len(rep_data) != len(self.labels):
            raise ValueError(
                f"rep_data has {len(rep_data)} rows but {len(self.labels)} labels were given."
            )

        if self.centroids is None:
            self.centroids = np.zeros([self.num_classes * self.k, rep_data.shape[1]])

        # Iterate over the label values that actually occur, not range(num_classes):
        # with labels such as {2, 3} the old loop matched nothing and silently left
        # every example in cluster 0.
        for slot, label in enumerate(self.classes):
            class_mask = self.labels == label
            class_examples = rep_data[class_mask]

            n_clusters = min(self.k, len(class_examples))
            start = self.k * slot

            if n_clusters == 1:
                # A single cluster's centroid is the class mean (not just the first example).
                self.centroids[start] = class_examples.mean(axis=0)
                self.assignments[class_mask] = start
            else:
                kmeans = KMeans(n_clusters=n_clusters, init='k-means++', n_init=1, max_iter=max_iter)
                kmeans.fit(class_examples)

                stop = start + n_clusters
                self.centroids[start:stop] = kmeans.cluster_centers_
                self.assignments[class_mask] = start + kmeans.labels_

        # Map each non-empty cluster id to the indices of its examples.
        self.cluster_assignments = {
            cluster: np.where(self.assignments == cluster)[0]
            for cluster in range(self.k * self.num_classes)
            if np.any(self.assignments == cluster)
        }

    def update_losses(self, indexes, losses):
        """Store per-example losses and refresh the mean loss of every affected cluster."""
        if self.example_losses is None:
            self.example_losses = np.zeros_like(self.labels, float)
            self.cluster_losses = np.zeros([self.k * self.num_classes], float)

        self.example_losses[indexes] = losses
        for cluster in np.unique(self.assignments[indexes]):
            members = self.assignments == cluster
            self.cluster_losses[cluster] = np.mean(self.example_losses[members])

    def gen_batch(self):
        """Sample up to ``m`` clusters from distinct classes and ``d`` examples from each.

        Clusters are drawn proportionally to their mean loss once losses are
        available (uniformly if all are zero, any is NaN, or none were recorded).

        Returns:
            Tuple ``(batch_indexes, batch_classes)`` of NumPy arrays: example
            indices and the class label repeated ``d`` times per chosen cluster.

        Raises:
            ValueError: if ``update_clusters`` has not produced any cluster yet.
        """
        if not self.cluster_assignments:
            raise ValueError("No clusters available. Make sure update_clusters has been called with non-empty data.")

        available_clusters = list(self.cluster_assignments.keys())

        clusters = []
        batch_class_inds = []

        while len(clusters) < self.m and len(available_clusters) > 0:
            if self.cluster_losses is not None:
                p = np.array([self.cluster_losses[c] for c in available_clusters])
                if np.all(p == 0) or np.any(np.isnan(p)):
                    next_cluster = np.random.choice(available_clusters)
                else:
                    p = p / np.sum(p)
                    next_cluster = np.random.choice(available_clusters, p=p)
            else:
                next_cluster = np.random.choice(available_clusters)

            # Keep at most one cluster per class in a batch.
            if self.cluster_classes[next_cluster] not in batch_class_inds:
                clusters.append(next_cluster)
                batch_class_inds.extend([self.cluster_classes[next_cluster]] * self.d)

            available_clusters.remove(next_cluster)

        batch_indexes = []
        for c in clusters:
            cluster_examples = self.cluster_assignments[c]
            # Sample with replacement only when the cluster is smaller than d.
            needs_replacement = len(cluster_examples) < self.d
            batch_indexes.extend(
                np.random.choice(cluster_examples, self.d, replace=needs_replacement)
            )

        return np.array(batch_indexes), np.array(batch_class_inds)
    


In [4]:
def load_data(batch_size=64):
    """Load CIFAR-10 and split it into five class-incremental subsets of two classes each.

    Args:
        batch_size: Unused here (DataLoaders are built by the caller); kept so
            existing calls keep working.

    Returns:
        Tuple ``(train_subsets, test_subsets, class_indices_per_increment)``:
        two lists of five ``Subset`` objects and the list of class-id pairs
        ``[[0, 1], [2, 3], ...]`` they cover. Inside each subset the indices of
        the first class come before those of the second.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))

    train_transform = transforms.Compose([
        transforms.Resize(299),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(299, padding=4),
        transforms.ToTensor(),
        normalize,
    ])
    # Evaluation must be deterministic: same resize/normalisation as training,
    # but without the random flip/crop augmentation.
    test_transform = transforms.Compose([
        transforms.Resize(299),
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

    # Convert the target lists once instead of once per class.
    train_targets = np.asarray(trainset.targets)
    test_targets = np.asarray(testset.targets)

    # Divide the dataset into 5 sets, each containing 2 classes
    train_subsets = []
    test_subsets = []
    class_indices_per_increment = []
    for first_class in range(0, 10, 2):
        class_pair = [first_class, first_class + 1]
        train_indices = np.concatenate([np.flatnonzero(train_targets == c) for c in class_pair])
        test_indices = np.concatenate([np.flatnonzero(test_targets == c) for c in class_pair])
        train_subsets.append(Subset(trainset, train_indices))
        test_subsets.append(Subset(testset, test_indices))
        class_indices_per_increment.append(class_pair)

    return train_subsets, test_subsets, class_indices_per_increment

def train_increment(model, trainloader, optimizer, magnet_loss, k, m, d, device, max_grad_norm=None):
    """Train ``model`` for one pass over ``trainloader`` and return the mean batch loss.

    Args:
        model: Network to train; if it returns a tuple, only element 0 is used.
        trainloader: Iterable of ``(inputs, labels)`` batches.
        optimizer: Optimizer over the model's parameters.
        magnet_loss: Callable ``(outputs, labels, m, d) -> scalar loss tensor``.
        k: Unused; kept for call compatibility.
        m: Forwarded to ``magnet_loss``.
        d: Forwarded to ``magnet_loss``.
        device: Device the batches are moved to.
        max_grad_norm: If given, gradients are clipped to this total norm
            after ``backward()`` and before ``optimizer.step()`` -- the only
            point where clipping affects the update. ``None`` (default) keeps
            the previous behaviour of no clipping.

    Returns:
        Average of the per-batch losses, or 0.0 if the loader yielded no batch.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0

    for i, (inputs, labels) in enumerate(tqdm(trainloader, desc="Training")):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        # Models that return several outputs: train on the first one only.
        if isinstance(outputs, tuple):
            outputs = outputs[0]

        loss = magnet_loss(outputs, labels, m, d)
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1

        if i % 10 == 0:
            print(f"Batch {i}, Loss: {batch_loss:.4f}")

    # Avoid ZeroDivisionError on an empty loader.
    if num_batches == 0:
        return 0.0
    return running_loss / num_batches

# def evaluate(model, testloader, magnet_loss, m, d, device):
#     model.eval()
#     total_loss = 0.0
#     correct = 0
#     total = 0
#     all_preds = []
#     all_labels = []

#     with torch.no_grad():
#         for inputs, labels in tqdm(testloader, desc="Evaluating"):
#             inputs, labels = inputs.to(device), labels.to(device)
#             outputs = model(inputs)
            
#             if isinstance(outputs, tuple):
#                 outputs = outputs[0]
            
#             # Adjust m and d based on the actual batch size
#             batch_size = outputs.size(0)
#             adjusted_m = max(2, min(m, batch_size // d))
#             adjusted_d = batch_size // adjusted_m

#             loss, _ = magnet_loss(outputs, labels, adjusted_m, adjusted_d)
#             total_loss += loss.item()
            
#             _, predicted = torch.max(outputs.data, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()
            
#             all_preds.extend(predicted.cpu().numpy())
#             all_labels.extend(labels.cpu().numpy())

#     accuracy = 100 * correct / total
#     avg_loss = total_loss / len(testloader)
#     unique_classes = np.unique(np.concatenate((all_preds, all_labels)))
#     return accuracy, avg_loss, all_preds, all_labels, unique_classes

In [5]:
def evaluate(model, testloader, magnet_loss, m, d, device):
    """Evaluate ``model`` on ``testloader`` without gradient tracking.

    The predicted class of a sample is the index of the largest value along
    dimension 1 of the model output (element 0 if the model returns a tuple).
    NOTE(review): this treats the outputs as class scores -- confirm that still
    holds for a model trained with the magnet loss.

    Returns:
        Tuple ``(accuracy, avg_loss, all_preds, all_labels, unique_classes)``:
        accuracy in percent, loss averaged over batches, the per-sample
        predictions and labels as lists, and the sorted distinct values that
        occur in either list.
    """
    model.eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, targets in tqdm(testloader, desc="Evaluating"):
            images = images.to(device)
            targets = targets.to(device)

            outputs = model(images)
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            # Loss for this batch
            loss_sum += magnet_loss(outputs, targets, m, d).item()

            # Arg-max prediction and running accuracy counters
            predicted = torch.max(outputs.data, 1)[1]
            n_seen += targets.size(0)
            n_correct += (predicted == targets).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(targets.cpu().numpy())

    accuracy = 100 * n_correct / n_seen
    avg_loss = loss_sum / len(testloader)
    unique_classes = np.unique(np.concatenate((all_preds, all_labels)))
    return accuracy, avg_loss, all_preds, all_labels, unique_classes

In [6]:
#Visuals

def plot_training_progress(train_losses, val_accuracies, save_path='training_progress.png'):
    """Draw training loss (red, left axis) and validation accuracy (blue, right axis)
    against the increment index and write the figure to ``save_path``."""

    def draw_series(axis, values, label, colour):
        # Label, curve and tick labels of one y-axis all share the same colour.
        axis.set_ylabel(label, color=colour)
        axis.plot(values, color=colour)
        axis.tick_params(axis='y', labelcolor=colour)

    fig, loss_axis = plt.subplots(figsize=(10, 6))
    loss_axis.set_xlabel('Increments')
    draw_series(loss_axis, train_losses, 'Training Loss', 'tab:red')

    # Second y-axis sharing the same x-axis.
    accuracy_axis = loss_axis.twinx()
    draw_series(accuracy_axis, val_accuracies, 'Validation Accuracy', 'tab:blue')

    plt.title('Training Progress')
    fig.tight_layout()
    plt.savefig(save_path)
    plt.close()

def visualize_embeddings(model, dataloader, num_classes, save_path='embeddings_visualization.png'):
    """Project the model outputs for ``dataloader`` to 2-D with t-SNE and save a scatter plot.

    Points are coloured by their label. ``num_classes`` is accepted but not
    used by the current implementation.
    """
    model.eval()
    output_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch, batch_labels in tqdm(dataloader, desc="Computing embeddings"):
            # Run the batch on whatever device the model's parameters live on.
            model_device = next(model.parameters()).device
            result = model(batch.to(model_device))
            if isinstance(result, tuple):
                result = result[0]
            output_chunks.append(result.cpu().numpy())
            label_chunks.append(batch_labels.numpy())

    all_outputs = np.vstack(output_chunks)
    all_labels = np.concatenate(label_chunks)

    print("Performing t-SNE...")
    projected = TSNE(n_components=2, random_state=42).fit_transform(all_outputs)
    xs, ys = projected[:, 0], projected[:, 1]

    plt.figure(figsize=(10, 8))
    points = plt.scatter(xs, ys, c=all_labels, cmap='tab10')
    plt.colorbar(points)
    plt.title('t-SNE visualization of learned embeddings')
    plt.savefig(save_path)
    plt.close()

def plot_confusion_matrix(all_preds, all_labels, class_names, save_path='confusion_matrix.png'):
    """Render the confusion matrix of ``all_labels`` (rows, actual) versus
    ``all_preds`` (columns, predicted) as an annotated heatmap and save it."""
    matrix = confusion_matrix(all_labels, all_preds)

    heatmap_options = dict(
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
    )

    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, **heatmap_options)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.savefig(save_path)
    plt.close()

In [7]:
def load_model(num_classes=10):
    """Build Inception-v3 with a ``num_classes``-way head and load saved weights.

    The weights are read from ``pretrained_inception_v3_e2b100.pth`` in the
    current working directory, and the model is returned on CUDA when
    available, otherwise on CPU.

    Args:
        num_classes: Output size of the replaced final ``fc`` layer.

    Returns:
        The model with the checkpoint loaded, moved to the selected device.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = models.inception_v3(pretrained=False)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

    # map_location makes a checkpoint that was saved from GPU loadable on a
    # CPU-only machine (without it torch.load tries to restore CUDA tensors).
    # NOTE(review): torch.load unpickles the file -- only load checkpoints from
    # a trusted source.
    checkpoint = torch.load('pretrained_inception_v3_e2b100.pth', map_location=device)
    model.load_state_dict(checkpoint)
    return model.to(device)

The number of epochs was kept low because of GPU and time constraints, so the learning rate was increased to compensate.

1. Due to time constraints, the code was only run through Increment 2 to confirm that it works.
    It took ~30 minutes to reach an accuracy of 80% in the first increment.
2. All generated images are named after their respective increment.

In [9]:
def main():
    """Run the class-incremental experiment: five increments of two CIFAR-10 classes each.

    For every increment the model is trained on that increment's training
    subset, evaluated on the test subsets of all increments seen so far, saved
    to ``model_increment_<n>.pth``, and the progress plot, embedding plot,
    confusion matrix and classification report are produced.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Hyperparameters
    # NOTE(review): the recorded run with batch_size=256 ended in a CUDA
    # out-of-memory error on a 6 GiB GPU; lower it if that recurs.
    batch_size = 256
    k = 10  # clusters per class
    m = 4  # clusters per batch
    d = 8  # examples per cluster
    alpha = 1.0
    learning_rate = 0.001
    num_epochs = 2
    max_grad_norm = 1.0

    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    # Load data
    train_subsets, test_subsets, _ = load_data(batch_size)

    # Initialize model
    model = load_model()
    print("Data and model loaded successfully.")

    # Initialize MagnetLoss and optimizer
    magnet_loss = AdaptiveMagnetLoss(alpha).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Gradient clipping only has an effect between backward() and step().
    # It used to run once after the whole epoch, when the gradients had already
    # been applied. Wrapping optimizer.step clips right before every update
    # without changing how train_increment is called.
    unclipped_step = optimizer.step

    def clipped_step(*args, **kwargs):
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        return unclipped_step(*args, **kwargs)

    optimizer.step = clipped_step

    # Metrics storage
    train_losses = []
    val_accuracies = []

    # Continual learning loop
    for increment in range(5):
        print(f"Increment {increment + 1}/5")

        # Train on the current increment only; test on every increment seen so far.
        trainloader = torch.utils.data.DataLoader(train_subsets[increment], batch_size=batch_size, shuffle=True)
        testloader = torch.utils.data.DataLoader(ConcatDataset(test_subsets[:increment+1]), batch_size=batch_size, shuffle=False)

        # Train for multiple epochs
        for epoch in range(num_epochs):
            print(f"Epoch {epoch + 1}/{num_epochs}")
            loss = train_increment(model, trainloader, optimizer, magnet_loss, k, m, d, device)
            print(f"Training Loss: {loss:.4f}")

        # Only the last epoch's loss is recorded for the increment.
        train_losses.append(loss)

        # Evaluate on all seen classes
        accuracy, eval_loss, all_preds, all_labels, unique_classes = evaluate(model, testloader, magnet_loss, m, d, device)
        val_accuracies.append(accuracy)
        print(f"Accuracy after increment {increment + 1}: {accuracy:.2f}%")

        # Save model after each increment
        torch.save(model.state_dict(), f'model_increment_{increment+1}.pth')

        # Visualizations
        plot_training_progress(train_losses, val_accuracies, f'training_progress_increment_{increment+1}.png')
        visualize_embeddings(model, testloader, len(unique_classes), f'embeddings_increment_{increment+1}.png')

        # Classes seen so far are the first 2 * (increment + 1) class ids.
        seen_class_names = class_names[:(increment + 1) * 2]
        plot_confusion_matrix(all_preds, all_labels, seen_class_names, f'confusion_matrix_increment_{increment+1}.png')

        # Classification report
        print("\nClassification Report:")
        print(classification_report(all_labels, all_preds, target_names=seen_class_names, labels=range(len(seen_class_names))))

    print("Continual Learning completed.")

if __name__ == "__main__":
    main()

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified
Data and model loaded successfully.
Increment 1/5
Epoch 1/2


Training:   0%|                                                                                 | 0/40 [00:07<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 346.00 MiB (GPU 0; 6.00 GiB total capacity; 14.81 GiB already allocated; 0 bytes free; 15.41 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

As stated in the problem statement, the accuracy reported above is the model's accuracy on the combined test sets of all increments seen up to the current one.

## IGNORE EXTRA CLIPS

def main():
    """Legacy driver based on MagnetLoss + ClusterBatchBuilder (kept under "IGNORE EXTRA CLIPS").

    NOTE(review): this version does not match the helper definitions in this
    file (see the notes at the individual call sites below), so it looks like
    it cannot run as-is; it also rebinds the name ``main`` defined earlier.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Hyperparameters
    batch_size = 64
    k = 4  # clusters per class
    m = 4  # clusters per batch
    d = 8  # examples per cluster
    alpha = 1.0
    learning_rate = 0.001
    num_epochs = 3  
    # Load data
    train_subsets, test_subsets, class_indices_per_increment = load_data(batch_size)

    # Initialize model
    # Randomly initialised Inception-v3 (no checkpoint is loaded here, unlike load_model()).
    model = torchvision.models.inception_v3(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, 10)  # Keep 10 output classes
    model = model.to(device)

    # Initialize MagnetLoss and optimizer
    # NOTE(review): MagnetLoss.forward returns a (loss, per-sample losses) tuple, while
    # train_increment in this file calls .backward() directly on the returned value -- confirm.
    magnet_loss = MagnetLoss(alpha).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Metrics storage
    train_losses = []
    val_accuracies = []
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    # Continual learning loop
    for increment in range(5):
        print(f"Increment {increment + 1}/5")

        # Get current class indices
        current_class_indices = class_indices_per_increment[increment]
        current_class_names = [class_names[i] for i in current_class_indices]

        # Prepare data for current increment
        trainloader = torch.utils.data.DataLoader(train_subsets[increment], batch_size=batch_size, shuffle=True)
        testloader = torch.utils.data.DataLoader(ConcatDataset(test_subsets[:increment+1]), batch_size=batch_size, shuffle=False)

        # Compute initial representations for current increment
        # NOTE(review): the representations are collected from a shuffled loader, while
        # current_labels below follows the subset's index order -- rows may not line up; confirm.
        model.eval()
        initial_representations = []
        initial_labels = []
        with torch.no_grad():
            for inputs, labels in tqdm(trainloader, desc="Computing initial representations"):
                inputs = inputs.to(device)
                outputs = model(inputs)
                if isinstance(outputs, tuple):
                    outputs = outputs[0]
                initial_representations.append(outputs.cpu().numpy())
                initial_labels.extend(labels.numpy())
        initial_representations = np.vstack(initial_representations)
        initial_labels = np.array(initial_labels)

        # Initialize ClusterBatchBuilder for current increment
        # NOTE(review): `trainset` is not defined in this function; in this file it only exists
        # as a local variable inside load_data -- looks like this line raises NameError.
        current_labels = np.array([trainset.targets[i] for i in train_subsets[increment].indices])
        batch_builder = ClusterBatchBuilder(current_labels, k, m, d)
        batch_builder.update_clusters(initial_representations)
        
        # Train for multiple epochs
        for epoch in range(num_epochs):
            print(f"Epoch {epoch + 1}/{num_epochs}")
            # NOTE(review): 10 positional arguments are passed, but train_increment as defined
            # in this file accepts 8 (no batch_builder / alpha) -- signature mismatch.
            loss = train_increment(model, trainloader, optimizer, magnet_loss, batch_builder, k, m, d, alpha, device)
            print(f"Training Loss: {loss:.4f}")
        
        train_losses.append(loss)

        # Evaluate on all seen classes
        # NOTE(review): evaluate as defined in this file takes 6 parameters and returns 5 values;
        # this call passes 3 and unpacks 4 -- signature mismatch.
        accuracy, all_preds, all_labels, unique_classes = evaluate(model, testloader, device)
        val_accuracies.append(accuracy)
        print(f"Accuracy after increment {increment + 1}: {accuracy:.2f}%")

        # Save model after each increment
        torch.save(model.state_dict(), f'model_increment_{increment+1}.pth')

        # Visualizations
        plot_training_progress(train_losses, val_accuracies, f'training_progress_increment_{increment+1}.png')
        visualize_embeddings(model, testloader, len(unique_classes), f'embeddings_increment_{increment+1}.png')
        
        # Use the unique classes to determine the class names
        seen_class_names = [class_names[i] for i in range((increment+1)*2)]
        plot_confusion_matrix(all_preds, all_labels, seen_class_names, f'confusion_matrix_increment_{increment+1}.png')

        # Classification report
        print("\nClassification Report:")
        print(classification_report(all_labels, all_preds, target_names=seen_class_names, labels=range(len(seen_class_names))))

    print("Continual Learning completed.")

if __name__ == "__main__":
    main()