In [None]:
# Task_2 - Q_2 - a)

In [None]:
import os
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, TensorDataset
from networks import ConvNet
from utils import get_dataset

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# EL2N Score Function to calculate sample difficulty
def calculate_el2n_score(model, data_loader):
    """Compute the EL2N difficulty score of every sample yielded by ``data_loader``.

    EL2N is the L2 norm of the error vector ``softmax(logits) - one_hot(target)``.
    The previous implementation returned the *squared* norm (no square root);
    the ranking of samples is identical, but the values were not EL2N scores.

    Args:
        model: Classifier whose forward pass returns one row of logits per
            sample (softmax is taken over dim 1).
        data_loader: Iterable of ``(inputs, targets)`` batches, with ``targets``
            holding integer class indices. Scores are returned in iteration
            order, so pass a non-shuffled loader if the scores are to be
            indexed by dataset position.

    Returns:
        A flat Python list with one float score per sample (empty if the
        loader yields nothing).
    """
    scores = []
    # Remember the caller's mode so scoring does not silently leave the model in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, targets in data_loader:
                # `device` is the notebook-level torch.device.
                inputs, targets = inputs.to(device), targets.to(device)
                probabilities = torch.softmax(model(inputs), dim=1)
                one_hot_targets = torch.nn.functional.one_hot(
                    targets, num_classes=probabilities.size(1)
                ).float()
                squared_error = (probabilities - one_hot_targets) ** 2
                # Square root turns the summed squared error into the L2 norm.
                el2n_score = squared_error.sum(dim=1).sqrt()
                scores.extend(el2n_score.cpu().tolist())
    finally:
        model.train(was_training)
    return scores

# Scheduler for gradual data inclusion
def data_scheduler(train_dataset, scores, initial_ratio=0.5, max_ratio=1.0, num_epochs=20):
    """Yield one growing, easiest-first ``Subset`` of ``train_dataset`` per epoch.

    Samples are ranked by ascending ``scores`` (lowest score first). The first
    epoch uses the ``initial_ratio`` easiest fraction and the subset grows
    linearly so that the LAST epoch uses the ``max_ratio`` fraction.

    The previous schedule advanced by ``(max - initial) // num_epochs`` per
    epoch starting from epoch 0, so it stopped one increment short and never
    reached ``max_ratio``: the hardest samples were never included.

    Args:
        train_dataset: Dataset indexed by the same positions as ``scores``.
        scores: Indexable sequence with one difficulty score per dataset sample.
        initial_ratio: Fraction of the data used in the first epoch.
        max_ratio: Fraction of the data used in the final epoch.
        num_epochs: Number of subsets to yield; nothing is yielded when it is
            zero or negative.

    Yields:
        ``torch.utils.data.Subset`` objects of non-decreasing size (for
        ``max_ratio >= initial_ratio``).
    """
    sorted_indices = sorted(range(len(scores)), key=lambda i: scores[i])
    initial_samples = int(len(scores) * initial_ratio)
    total_samples = int(len(scores) * max_ratio)
    growth = total_samples - initial_samples
    # num_epochs - 1 growth steps separate the first and last epoch; the max()
    # guard keeps a single-epoch schedule (which stays at initial_ratio) valid.
    steps = max(num_epochs - 1, 1)
    for epoch in range(num_epochs):
        end_index = initial_samples + (growth * epoch) // steps
        yield torch.utils.data.Subset(train_dataset, sorted_indices[:end_index])

# Define the training function for distillation
def train_distilled_model(model, synthetic_data_loader, epochs=20, initial_lr=0.01):
    """Train ``model`` in place with SGD + cosine LR annealing and log per-epoch stats.

    Fixes relative to the previous version:
      * the logged loss is the mean over the batches actually processed
        (it used to be divided by a hard-coded 100, whatever the batch count);
      * accuracy is computed over the samples actually seen, so the loader no
        longer needs a ``.dataset`` attribute and an empty loader no longer
        divides by zero.

    Args:
        model: Network to optimise; it is switched to train mode.
        synthetic_data_loader: Iterable of ``(inputs, targets)`` batches,
            re-iterated once per epoch.
        epochs: Number of passes over the loader (also the cosine ``T_max``).
        initial_lr: Starting SGD learning rate.

    Returns:
        The same ``model`` object, after training.
    """
    optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.CrossEntropyLoss()
    model.train()

    for epoch in range(epochs):
        total_loss, correct = 0.0, 0
        num_batches, num_samples = 0, 0
        for inputs, targets in synthetic_data_loader:
            # `device` is the notebook-level torch.device.
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            num_samples += targets.size(0)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()

        scheduler.step()
        # max(..., 1) keeps the log line well-defined for an empty loader.
        mean_loss = total_loss / max(num_batches, 1)
        accuracy = 100. * correct / max(num_samples, 1)
        print(f"Distillation Epoch [{epoch + 1}/{epochs}], Loss: {mean_loss:.4f}, Accuracy: {accuracy:.2f}%")

    return model

# Apply Dataset Distillation with PAD and create a synthetic dataset
def distill_dataset_with_PAD(model, train_loader, epochs=10):
    """Run the easiest-first curriculum over ``train_loader.dataset`` and collect the data used.

    Steps:
      1. Score every sample with ``calculate_el2n_score``.
      2. Let ``data_scheduler`` yield ``epochs`` growing subsets (lowest score first).
      3. Train ``model`` for one epoch on each subset and append that subset's
         batches to the returned tensors.

    Fixes relative to the previous version:
      * ``epochs`` is now forwarded to ``data_scheduler``; it used to be
        ignored, so the scheduler always ran its own default of 20 stages.
      * Scoring uses a NON-shuffled loader. ``data_scheduler`` treats the
        position of a score as the dataset index, which is only true when the
        scores are produced in dataset order; scoring through a shuffled
        ``train_loader`` attached each score to the wrong sample.

    NOTE(review): the returned tensors are the concatenation of every stage's
    subset, i.e. real samples with the easy ones repeated across stages, not
    optimised synthetic images. Confirm this is the intended "condensed" set.

    Args:
        model: Network used for scoring and trained in place during the stages.
        train_loader: DataLoader whose ``.dataset`` is the full training set.
        epochs: Number of curriculum stages (must be >= 1).

    Returns:
        ``(synthetic_data, synthetic_labels)`` CPU tensors.

    Raises:
        ValueError: if ``epochs`` is smaller than 1.
    """
    if epochs < 1:
        raise ValueError("epochs must be at least 1")

    dataset = train_loader.dataset
    # batch_size is None when the loader was built from a batch_sampler; fall
    # back to the batch size used for the stage loaders below.
    scoring_loader = DataLoader(dataset, batch_size=train_loader.batch_size or 256, shuffle=False)
    el2n_scores = calculate_el2n_score(model, scoring_loader)
    scheduler = data_scheduler(dataset, el2n_scores, num_epochs=epochs)
    synthetic_data = []
    synthetic_labels = []

    for epoch, subset in enumerate(scheduler):
        print(f"Distillation using subset size {len(subset)} for epoch {epoch + 1}")
        train_subset_loader = DataLoader(subset, batch_size=256, shuffle=True)
        train_distilled_model(model, train_subset_loader, epochs=1)

        # Keep a CPU copy of every batch of this stage's subset.
        for inputs, targets in train_subset_loader:
            synthetic_data.append(inputs.cpu())
            synthetic_labels.append(targets.cpu())

    synthetic_data = torch.cat(synthetic_data)
    synthetic_labels = torch.cat(synthetic_labels)
    return synthetic_data, synthetic_labels

# Function to train a new model from scratch on synthetic dataset (Stage 2)
def train_model_on_synthetic_data(synthetic_data, synthetic_labels, epochs=20):
    """Stage 2: fit a freshly initialised ConvNet on the condensed tensors.

    The network is built with fixed settings (1 channel, 10 classes,
    28x28 images) and trained via ``train_distilled_model``.

    Args:
        synthetic_data: Image tensor, first dimension indexes samples.
        synthetic_labels: Label tensor aligned with ``synthetic_data``.
        epochs: Number of training epochs.

    Returns:
        Whatever ``train_distilled_model`` returns for the new network.
    """
    # Wrap the raw tensors so they can be batched and shuffled.
    condensed_set = TensorDataset(synthetic_data, synthetic_labels)
    condensed_loader = DataLoader(condensed_set, batch_size=256, shuffle=True)

    fresh_model = ConvNet(
        channel=1,
        num_classes=10,
        net_width=128,
        net_depth=3,
        net_act='relu',
        net_norm='batchnorm',
        net_pooling='avgpooling',
        im_size=(28, 28),
    )
    fresh_model = fresh_model.to(device)

    trained_model = train_distilled_model(fresh_model, condensed_loader, epochs)
    return trained_model

# Evaluate the model on real test data
def evaluate_model(model, test_loader):
    """Measure top-1 accuracy of ``model`` over ``test_loader`` and print it.

    The model is put into eval mode and gradients are disabled for the pass.

    Args:
        model: Classifier returning one row of class scores per sample.
        test_loader: Iterable of ``(images, labels)`` batches.

    Returns:
        Accuracy in percent (float), computed over the samples iterated.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            # Index of the highest score per row is the predicted class.
            predicted = model(batch_images).max(1).indices
            seen += batch_labels.size(0)
            hits += (predicted == batch_labels).sum().item()

    test_accuracy = 100. * hits / seen
    print(f"Test Accuracy on Real Test Data: {test_accuracy:.2f}%")
    return test_accuracy

# Load data and initialize the model
# Fetch MNIST through the project helper and wrap the training split in a shuffled loader
(
    channel,
    im_size,
    num_classes,
    class_names,
    mean,
    std,
    train_dataset,
    test_dataset,
    test_loader,
) = get_dataset('MNIST', './mnist_dataset')
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

# Depth-3, width-128 ConvNet used for scoring and for the curriculum stages
model_mnist = ConvNet(
    channel=channel,
    num_classes=num_classes,
    net_width=128,
    net_depth=3,
    net_act='relu',
    net_norm='batchnorm',
    net_pooling='avgpooling',
    im_size=im_size,
).to(device)

# Stage 1: run the PAD curriculum and collect the condensed tensors
print("Distilling dataset with PAD:")
synthetic_data, synthetic_labels = distill_dataset_with_PAD(model_mnist, train_loader)

# Stage 2: fit a brand-new network on the condensed tensors only
print("\nTraining new model on the condensed dataset:")
model_synthetic = train_model_on_synthetic_data(synthetic_data, synthetic_labels)

# Final check against the real test split (last expression, so the cell displays the accuracy)
print("\nEvaluating model trained on synthetic data:")
evaluate_model(model_synthetic, test_loader)


Distilling dataset with PAD:
Distillation using subset size 30000 for epoch 1
Distillation Epoch [1/1], Loss: 0.3792, Accuracy: 91.22%
Distillation using subset size 31500 for epoch 2
Distillation Epoch [1/1], Loss: 0.0974, Accuracy: 97.87%
Distillation using subset size 33000 for epoch 3
Distillation Epoch [1/1], Loss: 0.0771, Accuracy: 98.30%
Distillation using subset size 34500 for epoch 4
Distillation Epoch [1/1], Loss: 0.0622, Accuracy: 98.73%
Distillation using subset size 36000 for epoch 5
Distillation Epoch [1/1], Loss: 0.0577, Accuracy: 98.86%
Distillation using subset size 37500 for epoch 6
Distillation Epoch [1/1], Loss: 0.0535, Accuracy: 98.96%
Distillation using subset size 39000 for epoch 7
Distillation Epoch [1/1], Loss: 0.0485, Accuracy: 99.12%
Distillation using subset size 40500 for epoch 8
Distillation Epoch [1/1], Loss: 0.0458, Accuracy: 99.19%
Distillation using subset size 42000 for epoch 9
Distillation Epoch [1/1], Loss: 0.0434, Accuracy: 99.20%
Distillation usin

99.44

In [None]:
# Task_2 - Q_3 

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from networks import ConvNet  
from utils import get_dataset
from torchinfo import summary  

# Define FLOPs calculation function
def calculate_flops(model, input_size):
    """Return the ``total_mult_adds`` figure of a ``summary`` run on ``model``.

    NOTE(review): the returned value is the summary's multiply-add count and
    is reported by callers as "FLOPs"; confirm whether a conversion factor is
    wanted.

    Args:
        model: Network to profile.
        input_size: Input shape passed to ``summary`` as ``input_size``.

    Returns:
        The ``total_mult_adds`` attribute of the summary object.
    """
    # verbose=0 keeps the summary table from being printed.
    return summary(model, input_size=input_size, verbose=0).total_mult_adds

# Initialize synthetic data as a parameter to be optimized
def initialize_synthetic_data(input_shape, num_samples):
    """Create a trainable batch of Gaussian-noise samples.

    Args:
        input_shape: Per-sample shape as a tuple of ints.
        num_samples: Number of samples; becomes the leading dimension.

    Returns:
        A leaf tensor of shape ``(num_samples, *input_shape)`` on the
        notebook-level ``device`` with ``requires_grad=True``.

    Raises:
        ValueError: if ``input_shape`` is not a tuple made only of ints.
    """
    shape_is_valid = isinstance(input_shape, tuple) and all(
        isinstance(dim, int) for dim in input_shape
    )
    if not shape_is_valid:
        raise ValueError("input_shape must be a tuple of integers")

    # Standard-normal initialisation, directly on the target device.
    full_shape = (num_samples, *input_shape)
    return torch.randn(full_shape, device=device, requires_grad=True)

# Define the DATM function
def difficulty_aligned_trajectory_matching(model, synthetic_data, real_data_trajectory, ipc, start_epoch, end_epoch):
    """Update ``synthetic_data`` in place so the model's outputs on it match stored real outputs.

    For each epoch in the selected window, ``model`` is run on
    ``synthetic_data`` and one SGD step (lr=0.01) is taken on the synthetic
    tensor to reduce the MSE to ``real_data_trajectory[epoch]``.

    Window selection: ``ipc < 10`` uses epochs ``0 .. min(10, end_epoch)``;
    otherwise ``start_epoch .. end_epoch`` (end exclusive in both cases).

    Fix relative to the previous version: rows are aligned in BOTH directions.
    Only the real output used to be truncated, so a stored output with fewer
    rows than the synthetic batch (e.g. a short final batch) reached
    ``MSELoss`` with mismatched shapes and raised. Now both sides are cut to
    the common number of rows; synthetic rows beyond it receive no gradient
    for that epoch.

    NOTE(review): ``backward()`` also accumulates gradients into the model's
    parameters; only ``synthetic_data`` is stepped here.

    Args:
        model: Network used for the forward pass (its current mode is kept).
        synthetic_data: Leaf tensor with ``requires_grad=True``; optimised in place.
        real_data_trajectory: Mapping/sequence from epoch index to a stored
            output tensor.
        ipc: Images-per-class setting that picks the epoch window.
        start_epoch: First epoch of the window when ``ipc >= 10``.
        end_epoch: Exclusive end of the window.

    Returns:
        The same ``synthetic_data`` tensor object.
    """
    # Low IPC settings match against the early part of the trajectory only.
    lower_bound, upper_bound = (0, min(10, end_epoch)) if ipc < 10 else (start_epoch, end_epoch)

    # Only the synthetic tensor is handed to the optimizer.
    optimizer = optim.SGD([synthetic_data], lr=0.01)
    criterion = nn.MSELoss()

    for epoch in range(lower_bound, upper_bound):
        optimizer.zero_grad()

        # Forward pass on synthetic data
        synthetic_output = model(synthetic_data)

        # `device` is the notebook-level torch.device.
        real_output = real_data_trajectory[epoch].to(device)

        # Compare only the rows both tensors have.
        common_rows = min(real_output.shape[0], synthetic_output.shape[0])

        # Matching loss: distance between the model outputs on synthetic data
        # and the stored outputs on real data.
        matching_loss = criterion(synthetic_output[:common_rows], real_output[:common_rows])
        matching_loss.backward()
        optimizer.step()

        print(f"DATM Epoch [{epoch + 1}/{upper_bound}], Matching Loss: {matching_loss.item():.4f}")
    return synthetic_data

# Define training with DATM
def train_with_datm(model, train_loader, real_data_trajectory, ipc, synthetic_data, epochs=20, initial_lr=0.01):
    """Train ``model`` on real data and refresh ``synthetic_data`` with DATM every 5th epoch.

    Each epoch is a standard SGD (momentum 0.9, weight decay 1e-4, cosine LR)
    pass over ``train_loader``. After epochs 0, 5, 10, ...,
    ``difficulty_aligned_trajectory_matching`` is called on the current model.

    Fixes relative to the previous version:
      * the logged loss is the mean over the batches processed (it used to be
        divided by a hard-coded 100);
      * accuracy is computed over the samples actually seen, which also avoids
        a division by zero on an empty loader;
      * the DATM window ends at ``len(real_data_trajectory)`` when ``epochs``
        is larger, so training longer than the recorded trajectory no longer
        fails on a missing epoch entry.

    Args:
        model: Network to train in place.
        train_loader: Iterable of ``(inputs, targets)`` batches.
        real_data_trajectory: Sized mapping/sequence from epoch index to a
            stored output tensor (assumed to be indexed ``0 .. len-1``).
        ipc: Images-per-class setting forwarded to the matching routine.
        synthetic_data: Trainable tensor updated by the matching routine.
        epochs: Number of training epochs.
        initial_lr: Starting SGD learning rate.

    Returns:
        ``(model, synthetic_data)``.
    """
    optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.CrossEntropyLoss()
    model.train()

    # Never ask the matching routine for an epoch the trajectory does not hold.
    matching_end = min(epochs, len(real_data_trajectory))

    for epoch in range(epochs):
        total_loss, correct = 0.0, 0
        num_batches, num_samples = 0, 0
        for inputs, targets in train_loader:
            # `device` is the notebook-level torch.device.
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            num_samples += targets.size(0)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()

        scheduler.step()
        # max(..., 1) keeps the log line well-defined for an empty loader.
        mean_loss = total_loss / max(num_batches, 1)
        accuracy = 100. * correct / max(num_samples, 1)
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {mean_loss:.4f}, Accuracy: {accuracy:.2f}%")

        # Update synthetic data with DATM every few epochs
        if epoch % 5 == 0:  # Adjust frequency of DATM update as needed
            synthetic_data = difficulty_aligned_trajectory_matching(
                model, synthetic_data, real_data_trajectory, ipc, 0, matching_end
            )

    return model, synthetic_data

# Define testing and FLOPs calculation
def test_model(model, test_loader):
    """Report test accuracy and the multiply-add count of one forward pass.

    Args:
        model: Classifier to evaluate; it is switched to eval mode.
        test_loader: Loader of ``(inputs, targets)`` batches exposing
            ``.dataset`` (its length is the accuracy denominator).

    Returns:
        ``(accuracy, flops)``: accuracy in percent and the value returned by
        ``calculate_flops`` for a batch-of-one input.
    """
    model.eval()
    num_correct = 0
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            predicted = model(batch_x).max(1).indices
            num_correct += predicted.eq(batch_y).sum().item()

    accuracy = 100. * num_correct / len(test_loader.dataset)
    print(f"Test Accuracy: {accuracy:.2f}%")

    # Profile a single-sample input with the per-sample shape of the last batch.
    single_input_size = (1, *batch_x.shape[1:])
    flops = calculate_flops(model, single_input_size)
    print(f"FLOPs for a single forward pass: {flops:.2e}")
    return accuracy, flops

# Function to save real model trajectory
def save_real_data_trajectory(model, train_loader, epochs):
    """Train ``model`` on ``train_loader`` and record one output snapshot per epoch.

    Training uses SGD (lr 0.01, momentum 0.9, weight decay 1e-4) with
    cross-entropy loss. The snapshot stored for an epoch is the model output
    on the LAST batch processed in that epoch, detached and cloned.

    Args:
        model: Network to train in place; it is switched to train mode.
        train_loader: Iterable of ``(inputs, targets)`` batches.
        epochs: Number of passes over the loader.

    Returns:
        Dict mapping epoch index (0-based) to the recorded output tensor.
    """
    sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    loss_fn = nn.CrossEntropyLoss()
    snapshots = {}

    model.train()
    for epoch_idx in range(epochs):
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            sgd.zero_grad()
            logits = model(batch_x)
            loss_fn(logits, batch_y).backward()
            sgd.step()

        # `logits` still refers to the final batch of this epoch.
        snapshots[epoch_idx] = logits.detach().clone()
    return snapshots

# Load MNIST dataset using get_dataset
# Fetch MNIST through the project helper
(
    channel,
    im_size,
    num_classes,
    class_names,
    mean,
    std,
    train_dataset,
    test_dataset,
    test_loader,
) = get_dataset('MNIST', './mnist_dataset')

# Force a plain (height, width) tuple for ConvNet and derive the per-sample shape
im_size = (28, 28)  # MNIST images are 28x28
input_shape = (channel, im_size[0], im_size[1])

# Shuffled training loader, 256 samples per batch
train_loader_mnist = DataLoader(train_dataset, batch_size=256, shuffle=True)

# Pick the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ConvNet-3 (depth 3, width 128) for MNIST
model_mnist = ConvNet(
    channel=channel,
    num_classes=num_classes,
    net_width=128,
    net_depth=3,
    net_act='relu',
    net_norm='batchnorm',
    net_pooling='avgpooling',
    im_size=im_size,
).to(device)

# Record per-epoch outputs on real data; DATM matches against these.
# NOTE(review): this trains model_mnist itself, so the DATM run below starts
# from an already-trained network -- confirm that is intended.
real_data_trajectory = save_real_data_trajectory(model_mnist, train_loader_mnist, epochs=20)

# Noise-initialised synthetic batch
num_samples = 10  # Adjust based on IPC and dataset requirements
synthetic_data = initialize_synthetic_data(input_shape, num_samples)

# Train with periodic DATM updates, then report accuracy and FLOPs
print("Training ConvNet-3 on MNIST with DATM:")
model_mnist, synthetic_data = train_with_datm(
    model_mnist,
    train_loader_mnist,
    real_data_trajectory,
    ipc=10,
    synthetic_data=synthetic_data,
    epochs=20,
)
mnist_accuracy, mnist_flops = test_model(model_mnist, test_loader)


Training ConvNet-3 on MNIST with DATM:
Epoch [1/20], Loss: 0.0174, Accuracy: 99.85%
DATM Epoch [1/20], Matching Loss: 25.1445
DATM Epoch [2/20], Matching Loss: 27.5291
DATM Epoch [3/20], Matching Loss: 28.7153
DATM Epoch [4/20], Matching Loss: 23.6025
DATM Epoch [5/20], Matching Loss: 29.7281
DATM Epoch [6/20], Matching Loss: 37.4621
DATM Epoch [7/20], Matching Loss: 42.2686
DATM Epoch [8/20], Matching Loss: 37.3069
DATM Epoch [9/20], Matching Loss: 40.0144
DATM Epoch [10/20], Matching Loss: 39.8121
DATM Epoch [11/20], Matching Loss: 38.1626
DATM Epoch [12/20], Matching Loss: 37.0455
DATM Epoch [13/20], Matching Loss: 40.8722
DATM Epoch [14/20], Matching Loss: 40.0883
DATM Epoch [15/20], Matching Loss: 51.9066
DATM Epoch [16/20], Matching Loss: 29.1467
DATM Epoch [17/20], Matching Loss: 48.2282
DATM Epoch [18/20], Matching Loss: 29.8836
DATM Epoch [19/20], Matching Loss: 45.7632
DATM Epoch [20/20], Matching Loss: 39.7854
Epoch [2/20], Loss: 0.0164, Accuracy: 99.86%
Epoch [3/20], Loss: 