In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import time
import matplotlib.pyplot as plt
import os
import numpy as np


# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Baseline preprocessing: convert each image to a tensor, then normalise all
# three channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)


class FixedRotation:
    """Transform that rotates every image by the same, fixed angle.

    Instances are callable, so they can be placed inside a
    ``transforms.Compose`` pipeline like any other transform.
    """

    def __init__(self, angle):
        """Remember the rotation ``angle`` that is applied on every call."""
        self.angle = angle

    def __call__(self, img):
        """Return ``img`` rotated by ``self.angle`` via torchvision's ``rotate``."""
        rotated = transforms.functional.rotate(img, angle=self.angle)
        return rotated


# Same preprocessing as the baseline pipeline, preceded by a fixed 15-unit
# rotation applied through FixedRotation.
transform_rotate = transforms.Compose(
    [
        FixedRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)


def _cifar10_split(is_train, pipeline):
    """Build one CIFAR-10 split and its DataLoader.

    The training split is shuffled and the test split is not; both use a
    batch size of 16 and two worker processes.
    """
    dataset = torchvision.datasets.CIFAR10(
        root='./data', train=is_train, download=True, transform=pipeline
    )
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=16, shuffle=is_train, num_workers=2
    )
    return dataset, loader


# load CIFAR-10 (plain and rotated variants of the train and test splits)
trainset, trainloader = _cifar10_split(True, transform)
testset, testloader = _cifar10_split(False, transform)
trainset_rotate, trainloader_rotate = _cifar10_split(True, transform_rotate)
testset_rotate, testloader_rotate = _cifar10_split(False, transform_rotate)

# ResNet-18
def get_resnet18(num_classes=10):
    """Build a randomly initialised ResNet-18 classifier on ``device``.

    Args:
        num_classes: Number of output logits of the replaced final layer.
            Defaults to 10, matching CIFAR-10.

    Returns:
        The model, moved to the module-level ``device``.
    """
    # The constructor is called without arguments on purpose: this yields
    # untrained weights on every torchvision release, whereas passing
    # ``pretrained=False`` triggers a deprecation warning on torchvision >= 0.13.
    model = torchvision.models.resnet18()
    # Swap the stock classification head for one sized to this task.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)


def _evaluate_accuracy(model, loader, dev):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``.

    Switches the model to eval mode and runs without gradient tracking.
    Returns 0.0 when the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            images, labels = batch[0].to(dev), batch[1].to(dev)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else 0.0


def train_model(model, trainloader, testloader, epochs=10):
    """Train ``model`` with SGD and report accuracy after every epoch.

    Uses cross-entropy loss and SGD (lr=0.001, momentum=0.9). After each
    epoch one line is printed with the training accuracy, the test accuracy
    and the wall-clock time of the training pass (evaluation excluded).

    Args:
        model: Module to train; updated in place.
        trainloader: Iterable of batches where ``batch[0]`` are the inputs
            and ``batch[1]`` the integer class labels.
        testloader: Iterable of batches in the same layout, used for the
            per-epoch evaluation.
        epochs: Number of passes over ``trainloader``.

    Returns:
        The same ``model`` instance, left in eval mode when ``epochs > 0``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    # Move batches to wherever the model actually lives instead of relying
    # on a module-level global, so the function works for any placement.
    dev = next(model.parameters()).device

    for epoch in range(epochs):
        start_time = time.time()
        model.train()
        correct = 0
        total = 0
        for batch in trainloader:
            inputs, labels = batch[0].to(dev), batch[1].to(dev)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total += labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        # Guard against an empty loader instead of dividing by zero.
        train_acc = 100 * correct / total if total else 0.0
        epoch_time = time.time() - start_time

        test_acc = _evaluate_accuracy(model, testloader, dev)

        print(f'Epoch {epoch + 1}, Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%, Time: {epoch_time:.2f}s')

    return model


def _load_or_train(checkpoint_path, label, train_loader, test_loader, init_state=None):
    """Return a ResNet-18 restored from ``checkpoint_path`` or freshly trained.

    When the checkpoint file exists its weights are loaded. Otherwise the
    model is (optionally) initialised from ``init_state``, trained with
    ``train_model`` and its state dict is saved to ``checkpoint_path``.
    ``label`` only affects the printed status message.
    """
    model = get_resnet18()
    if os.path.exists(checkpoint_path):
        # map_location lets a checkpoint written on a GPU machine be restored
        # on a CPU-only one (and vice versa).
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"{label} model loaded from file.")
        return model

    if init_state is not None:
        model.load_state_dict(init_state)
    model = train_model(model, train_loader, test_loader)
    torch.save(model.state_dict(), checkpoint_path)
    print(f"{label} model trained and saved.")
    return model


base_model = _load_or_train('base_model.pth', 'Base', trainloader, testloader)

# Fine-tuning starts from the base weights: the task vector computed later is
# ``finetuned - base``, which is only meaningful when both models share the
# same starting point.
# NOTE: an existing 'finetuned_model.pth' is loaded as-is; delete it if it was
# produced by a run that trained from a fresh random initialisation.
finetuned_model = _load_or_train(
    'finetuned_model.pth',
    'Finetuned',
    trainloader_rotate,
    testloader_rotate,
    init_state=base_model.state_dict(),
)


def compute_task_vector(pretrained_model, finetuned_model):
    """Return the flattened parameter difference between two models.

    Parameters of both models are paired in ``parameters()`` order and
    subtracted (fine-tuned minus pre-trained).

    Returns:
        A tuple ``(vector, shapes)``: ``vector`` is a 1-D NumPy array holding
        all differences back to back, and ``shapes`` lists the shape of each
        pre-trained parameter in the same order, so the vector can be split
        back into per-parameter pieces.
    """
    param_shapes = []
    deltas = []
    paired = zip(pretrained_model.parameters(), finetuned_model.parameters())
    for base_param, tuned_param in paired:
        param_shapes.append(base_param.shape)
        difference = tuned_param.data - base_param.data
        deltas.append(difference.detach().cpu().numpy().reshape(-1))

    return np.concatenate(deltas), param_shapes


def localize_and_stitch(model, pretrained_model, task_vector, param_shapes, sparsity=0.01):
    """Add the largest-magnitude entries of ``task_vector`` to ``model``.

    A binary mask keeps the ``int(sparsity * len(task_vector))`` entries with
    the largest absolute value; all other entries are zeroed. The masked
    vector is then split according to ``param_shapes`` and added in place to
    the parameters of ``model`` (in ``parameters()`` order).

    Args:
        model: Module whose parameters are updated in place.
        pretrained_model: Unused; kept so existing callers keep working.
        task_vector: Flat 1-D sequence/array of per-parameter differences.
        param_shapes: Shape of every parameter, in ``parameters()`` order.
        sparsity: Fraction in ``[0, 1]`` of entries to keep.

    Returns:
        The same ``model`` instance.

    Raises:
        ValueError: If ``sparsity`` lies outside ``[0, 1]``.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be within [0, 1], got {sparsity!r}")

    # Convert once up front rather than re-wrapping a slice per parameter.
    deltas = torch.tensor(task_vector)
    k = int(sparsity * deltas.numel())
    mask = torch.zeros_like(deltas)
    if k > 0:
        mask[deltas.abs().topk(k).indices] = 1
    sparse_deltas = deltas * mask

    offset = 0
    with torch.no_grad():
        for param, shape in zip(model.parameters(), param_shapes):
            # int(...) matters: np.prod of an empty shape (0-dim parameter)
            # is the float 1.0, which is not a valid slice bound.
            numel = int(np.prod(shape))
            chunk = sparse_deltas[offset:offset + numel].view(shape)
            param.add_(chunk.to(param.device))
            offset += numel
    return model


# Each k is the fraction (0.0 to 1.0 in steps of 0.05) of largest-magnitude
# task-vector entries that get stitched onto the base model.
k_values = [i / 20 for i in range(0, 21)]
original_test_accs = []
rotated_test_accs = []


def _merged_accuracy(model, loader):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``."""
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            images, labels = batch[0].to(device), batch[1].to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else 0.0


# The task vector depends only on the two trained models, so it is computed
# once instead of on every sweep iteration.
task_vector, param_shapes = compute_task_vector(base_model, finetuned_model)

for k in k_values:
    # Start every run from a fresh copy of the base weights.
    merged_model = get_resnet18()
    merged_model.load_state_dict(base_model.state_dict())
    merged_model = localize_and_stitch(merged_model, base_model, task_vector, param_shapes, sparsity=k)

    # evaluate on both the original and the rotated test split
    merged_model.eval()
    original_test_acc = _merged_accuracy(merged_model, testloader)
    original_test_accs.append(original_test_acc)

    rotated_test_acc = _merged_accuracy(merged_model, testloader_rotate)
    rotated_test_accs.append(rotated_test_acc)

    # Release the copy before the next iteration allocates a new one.
    del merged_model
    torch.cuda.empty_cache()

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(k_values, original_test_accs, marker='o')
plt.title('Original Test Set Accuracy')
plt.xlabel('k')
plt.ylabel('Accuracy (%)')

plt.subplot(1, 2, 2)
plt.plot(k_values, rotated_test_accs, marker='o')
plt.title('Rotated Test Set Accuracy')
plt.xlabel('k')
plt.ylabel('Accuracy (%)')

plt.tight_layout()
plt.show()
    