In [1]:
from model import network_mnist,naive_train,test_taskwise,test,benchmark,train_stream,test_stream,compute_fisher_information,apply_importance_mask,create_masked_weight_dict,load_non_zero_weights
from torch.utils.data import DataLoader, Subset
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from plot import plot_parameter_importance
import os
import copy 

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# Seed torch's global RNG before building the networks so weight
# initialisation (and other torch-RNG-driven randomness such as shuffling)
# is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Two identically-configured networks; the arguments are passed straight to
# network_mnist (presumably hidden-layer widths — confirm in model.py).
model = network_mnist(256, 128)
model_2 = network_mnist(256, 128)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model_2.to(device)

criterion = nn.CrossEntropyLoss()
# Only `model`'s parameters are registered with the optimizer; model_2 is not trained here.
optimizer = optim.Adam(model.parameters(), lr=0.001)
epochs = 5

In [None]:
def method_1_train(model, task_number, epochs, criterion, optimizer, device, weight_dicts,
                   batch_size=64, stream=None, verbose=False):
    """Train ``model`` on a single task (experience) of the training stream.

    Runs a plain supervised loop (no masking or regularisation) for ``epochs``
    passes over the task's dataset, stepping ``optimizer`` once per batch.

    Args:
        model: network to train; switched to train mode and updated in place.
        task_number: index of the experience to train on within ``stream``.
        epochs: number of passes over the task's data.
        criterion: loss function, called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to before the forward pass.
        weight_dicts: currently unused; kept so existing call sites still work.
            NOTE(review): presumably meant for per-task masked weights — confirm.
        batch_size: mini-batch size for the DataLoader (default 64, as before).
        stream: indexable sequence of experiences exposing ``.dataset``;
            defaults to the module-level ``train_stream``.
        verbose: if True, print the mean loss after every epoch.

    Returns:
        list[float]: mean training loss of each epoch (empty when ``epochs`` <= 0).
    """
    if stream is None:
        stream = train_stream
    experience = stream[task_number]
    train_loader = DataLoader(experience.dataset, batch_size=batch_size, shuffle=True)

    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0

        # Any fields beyond (images, labels) yielded per batch are ignored.
        for images, labels, *_ in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        # max(..., 1) keeps the mean well-defined if the loader yields no batches.
        mean_loss = total_loss / max(num_batches, 1)
        epoch_losses.append(mean_loss)
        if verbose:
            print(f"Task {task_number}, Epoch {epoch+1}, Loss: {mean_loss:.4f}")

    return epoch_losses

In [17]:
os.makedirs('figures', exist_ok=True)
all_tasks_data = {}   # task -> (percent_list, accuracy_vs_percent)
weight_dicts = []     # one masked weight dict appended per (task, percent) pair;
                      # with a single percentage, list index == task number

for task in range(2):
    print(f"\n{'='*70}")
    print(f"Training on Task {task}")
    print(f"{'='*70}")

    # Step 1: train on this task, then report accuracy on it.
    method_1_train(
        model,
        task,
        epochs,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        weight_dicts=weight_dicts,
    )
    acc = test_taskwise(model, task, device)
    print(f"Post-training accuracy on Task {task}: {acc:.2f}%")

    # Step 2: per-parameter importance scores from the project's Fisher helper (500 samples).
    fisher_dict = compute_fisher_information(
        model,
        task_number=task,
        num_samples=500,
        criterion=criterion,
        device=device,
    )

    # Step 3: for each importance percentage, mask the model, snapshot the
    # masked weights, evaluate, then restore the pre-mask weights.
    percent_list = [50]
    accuracy_vs_percent = []
    original_weights = {key: tensor.clone() for key, tensor in model.state_dict().items()}
    for p in percent_list:
        model, mask_dict = apply_importance_mask(model, fisher_dict, importance_percent=p)
        weight_dicts.append(create_masked_weight_dict(model, mask_dict))
        acc_p = test_taskwise(model, task, device)
        accuracy_vs_percent.append(acc_p)
        model.load_state_dict(original_weights, strict=False)

    all_tasks_data[task] = (percent_list, accuracy_vs_percent)
    test(model, device)
    # Plotting disabled:
    # plot_parameter_importance(percent_list, accuracy_vs_percent, task, save_path=f'figures/task_{task}_importance.png')



Training on Task 0
Accuracy on task 0: 99.95%
Post-training accuracy on Task 0: 99.95%
Accuracy on task 0: 99.95%
Accuracy on task 0: 99.95%
Accuracy on task 1: 0.05%
Accuracy on task 2: 0.00%
Accuracy on task 3: 0.00%
Accuracy on task 4: 0.00%
Average Accuracy: 20.00%

Training on Task 1
Accuracy on task 0: 99.95%
Accuracy on task 1: 0.05%
Accuracy on task 2: 0.00%
Accuracy on task 3: 0.00%
Accuracy on task 4: 0.00%
Average Accuracy: 20.00%
Accuracy on task 0: 99.95%
Accuracy on task 1: 0.05%
Accuracy on task 2: 0.00%
Accuracy on task 3: 0.00%
Accuracy on task 4: 0.00%
Average Accuracy: 20.00%
Accuracy on task 0: 99.95%
Accuracy on task 1: 0.05%
Accuracy on task 2: 0.00%
Accuracy on task 3: 0.00%
Accuracy on task 4: 0.00%
Average Accuracy: 20.00%
Accuracy on task 0: 99.95%
Accuracy on task 1: 0.05%
Accuracy on task 2: 0.00%
Accuracy on task 3: 0.00%
Accuracy on task 4: 0.00%
Average Accuracy: 20.00%
Accuracy on task 1: 99.36%
Post-training accuracy on Task 1: 99.36%
Accuracy on task 

In [None]:
# Re-evaluate each task with the masked weights stored for it in `weight_dicts`.
# Iterate only over masks that exist: weight_dicts gets one entry per trained
# task, so a fixed range(5) raises IndexError when fewer tasks were trained.
# NOTE: weight_dicts[i] is task i's mask only while each task appends exactly
# one dict (i.e. a single importance percentage per task).
for i in range(len(weight_dicts)):
    print(f"\n{'='*70}")
    print(f"Evaluating Model on Task {i} after applying masks from all tasks")
    print(f"{'='*70}")

    print(f"\n-- Using mask from Task {i} --")
    # Work on a copy in case load_non_zero_weights mutates its argument.
    model_2 = copy.deepcopy(model)
    masked_model = load_non_zero_weights(model_2, weight_dicts[i])
    acc = test_taskwise(masked_model, i, device)
    print(f"Accuracy on Task {i} with mask from Task {i}: {acc:.2f}%")

    # Load the same dict directly into a *fresh* copy so this measurement is
    # independent of the non-zero load above.
    direct_model = copy.deepcopy(model)
    direct_model.load_state_dict(weight_dicts[i], strict=False)
    acc = test_taskwise(direct_model, i, device)
    print(f"Accuracy on Task {i} with mask from Task {i} (direct load): {acc:.2f}%")