In [1]:
import timm
from src.factories.benchmark_factory import create_benchmark
from src.toolkit.utils import set_seed
import torch
import pandas as pd
import seaborn as sns
import omegaconf
import os
import matplotlib.pyplot as plt
from avalanche.models.utils import avalanche_model_adaptation

from experiments.lora_forget import MultiClassModel, create_lora_config
plt.style.use("matplotlibrc.template")



In [2]:
import tqdm
from peft import PeftConfig, PeftModel
import copy

@torch.no_grad()
def get_prediction_vector(model, dataloader, device="cuda"):
    """
    Gets the predicted label for a given dataset
    """
    model.eval()
    all_preds = []
    correct = []
    for mb_x, mb_y, mb_tid in tqdm.tqdm(dataloader):
        mb_x, mb_y, mb_tid = mb_x.to(device), mb_y.to(device), mb_tid.to(device)
        out = model.forward_single_task(mb_x, 0)
        all_preds.append(out.argmax(dim=1))
        correct.append(mb_y)
    
    return torch.cat(all_preds), torch.cat(correct)

@torch.no_grad()
def eval_dataset(model, dataset, min_class, max_class, device="cuda"):
    """
    Gets the predicted label for a given dataset
    """
    model.eval()
    total = 0
    correct = 0
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=12)
    
    for mb_x, mb_y, mb_tid in tqdm.tqdm(dataloader):
        mb_x, mb_y, mb_tid = mb_x.to(device), mb_y.to(device), mb_tid.to(device)
        features = model.backbone(mb_x)
        out = model.head(features)
        out = out[:, min_class:max_class]
        correct += (out.argmax(dim=1) == mb_y).float().sum()
        total += len(mb_y)
        
    return correct / total


@torch.no_grad()
def cil_accuracy(model, test_stream, task_classes, device="cuda"):
    experience_accuracies = []
    for tid, exp in enumerate(test_stream):
        total = 0
        correct = 0
        dataloader = torch.utils.data.DataLoader(exp.dataset, batch_size=64, shuffle=False, num_workers=12)
        for mb_x, mb_y, mb_tid in tqdm.tqdm(dataloader):
            mb_x, mb_y, mb_tid = mb_x.to(device), mb_y.to(device), mb_tid.to(device)
            all_outs = []

            # Label adjustment
            mb_y = mb_y + sum(task_classes[:tid])

            features = model.backbone(mb_x)
            
            for tid, num_classes in enumerate(task_classes):
            
                out = model.linear.forward_single_task(features, tid)
                all_outs.append(out[:, :num_classes])
                
            actual_out = torch.cat(all_outs, dim=1)
            correct += (actual_out.argmax(dim=1) == mb_y).float().sum()
            total += len(mb_y)

        experience_accuracies.append(correct / total)

    return experience_accuracies
            

def iterate_models(model, basepath, merge=True):
    # We just need to merge the LoRAs and check
    path_dict = {}
    output_model = copy.deepcopy(model)
    
    for root, dirs, files in os.walk(basepath):
        for f in files:
            if "adapter_model" in f:
                # Split the path by '/'
                split_path = root.split('/')

                # Get the last element which contains the number
                last_element = split_path[-1]

                # Extract the number
                number = int(last_element.split('_')[-1])

                # Create the dictionary
                path_dict[number] = root
                
    path_dict = dict(sorted(path_dict.items()))

    for rank, path in path_dict.items():
        print(path)
        lora_config = PeftConfig.from_pretrained(path)
        output_model.backbone = PeftModel.from_pretrained(output_model.backbone, model_id=path, config=lora_config)
        yield output_model

        # Merge previous one and load next one
        if merge:
            output_model.backbone = output_model.backbone.merge_and_unload()
        else:
            output_model.backbone = output_model.backbone.unload()
        

In [4]:
# Load model And Scenario

basepath = "/DATA/avalanche_experiments/old_lora/lora_vit_with_saved_loras/"
rank = 6
path = os.path.join(basepath, f"lora_forget_{rank}")

config = omegaconf.OmegaConf.load(os.path.join(path, "config.yaml"))

# Replace datadir and results dir
config.benchmark.dataset_root = "/DATA/data"

set_seed(config.experiment.seed)

model_id = config.model.model_id

model = timm.create_model(model_id, pretrained=True, num_classes=1000)
data_config = timm.data.resolve_model_data_config(model)
train_transforms = timm.data.create_transform(**data_config, is_training=True)
eval_transforms = timm.data.create_transform(**data_config, is_training=False)

if config.benchmark.factory_args.use_transforms:
    transforms = (train_transforms, eval_transforms)
else:
    transforms = (eval_transforms, eval_transforms)

head_name = "head" if config.model.model_type == "vit" else "fc"

model = MultiClassModel(model, head_name, config.model.model_type)

model = model.cuda()

# Avalanche: Create Scenario

scenario = create_benchmark(
    config.benchmark.factory_args.benchmark_name,
    n_experiences=1,
    shuffle=False,
    dataset_root=config.benchmark.dataset_root,
    override_transforms=transforms,
)

# Load final head
setattr(model, head_name, torch.load(os.path.join(path, "head.ckpt")))

Files already downloaded and verified
Files already downloaded and verified


In [5]:
task_classes = [len(exp.classes_in_this_experience) for exp in scenario.train_stream]

In [None]:
# Compute cil accuracies (without probing)

accuracies = []
model_iterator = iterate_models(model, path, merge=True)
for tid, model in enumerate(model_iterator):
    accuracies.append(cil_accuracy(model, scenario.test_stream[:tid+1], task_classes[:tid+1]))

In [6]:
# Repair heads and build cil head from existing weights
import torch.nn as nn

cil_head = nn.Linear(model.head.classifiers["0"].classifier.in_features, sum(task_classes))

# Load existing weights into single head
current_index = 0
for tid, mt_head in model.head.classifiers.items():
    num_classes = task_classes[int(tid)]
    cil_head.weight.data[current_index:current_index+num_classes, :].copy_(mt_head.classifier.weight[:num_classes, :])
    current_index = current_index + num_classes

In [7]:
# Test accuracy on aircraft

model_iterator = iterate_models(model, path, merge=True)

# Imnet

loaded_model = next(model_iterator)

# Cars

loaded_model = next(model_iterator)

# Flowers

loaded_model = next(model_iterator)

# Aircraft

loaded_model = next(model_iterator)

# Birds

loaded_model = next(model_iterator)

/DATA/avalanche_experiments/old_lora/lora_vit_with_saved_loras/lora_forget_6/lora_0
/DATA/avalanche_experiments/old_lora/lora_vit_with_saved_loras/lora_forget_6/lora_1
/DATA/avalanche_experiments/old_lora/lora_vit_with_saved_loras/lora_forget_6/lora_2
/DATA/avalanche_experiments/old_lora/lora_vit_with_saved_loras/lora_forget_6/lora_3
/DATA/avalanche_experiments/old_lora/lora_vit_with_saved_loras/lora_forget_6/lora_4


In [8]:
# Set new head as head

loaded_model.head = cil_head
loaded_model = loaded_model.cuda()

if os.path.exists(os.path.join(path, "cil_head.ckpt")):
    print("Found existing head")
    loaded_model.head = torch.load(os.path.join(path, "cil_head.ckpt"))

In [None]:
# Compute accuracy on aircraft as a check

task_id = 3

aircraft_test = scenario.test_stream[task_id].dataset

accuracy = eval_dataset(loaded_model, aircraft_test, min_class = sum(task_classes[:task_id]), max_class = sum(task_classes[:task_id + 1]))

In [9]:
# Linear probing

from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader

task_classes = [len(exp.classes_in_this_experience) for exp in scenario.test_stream]

def map_offset(labels, task_labels):
    for tid in torch.unique(task_labels):
        offset = sum(task_classes[:tid])
        labels[task_labels == tid] = labels[task_labels == tid] + offset
    return labels

# Create full training dataset
new_ds = None
for exp in scenario.train_stream:
    if new_ds is None:
        new_ds = exp.dataset
    else:
        new_ds = new_ds.concat(exp.dataset)

dataloader = TaskBalancedDataLoader(new_ds, batch_size=64, distributed_sampling=False, oversample_small_groups=True, num_workers=12, shuffle=True)

# Freeze BB
for p in loaded_model.backbone.parameters():
    p.requires_grad = False
    p.grad = None

# Unfreeze Head
for p in loaded_model.head.parameters():
    p.requires_grad = True


In [10]:
# Train with probing
import torch.nn.functional as F

num_iters = 3000
device = "cuda"

optimizer = torch.optim.Adam(loaded_model.parameters(), lr=0.001)

loaded_model.train()

losses = []

total_iters = 0
for mb_x, mb_y, mb_tid in tqdm.tqdm(dataloader):
    mb_x, mb_y, mb_tid = mb_x.to(device), mb_y.to(device), mb_tid.to(device)

    with torch.no_grad():
        features = loaded_model.backbone(mb_x)

    out = loaded_model.head(features)

    mapped_labels = map_offset(mb_y, mb_tid)

    loss = F.cross_entropy(out, mapped_labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(float(loss.detach().cpu()))

    total_iters += 1
    if total_iters > num_iters:
        break


  3%|██                                                                   | 3000/98552 [20:04<10:39:22,  2.49it/s]


In [None]:
# Test loader
for mb_x, mb_y, mb_tid in dataloader:
    print(mb_tid)

In [None]:
map_offset(mb_y, mb_tid)

In [11]:
@torch.no_grad()
def eval_dataset_cil(model, dataset, device="cuda"):
    """
    Gets the predicted label for a given dataset
    """
    model.eval()
    total = 0
    correct = 0
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=12)
    
    for mb_x, mb_y, mb_tid in tqdm.tqdm(dataloader):
        mb_x, mb_y, mb_tid = mb_x.to(device), mb_y.to(device), mb_tid.to(device)

        mb_y = map_offset(mb_y, mb_tid)
        
        features = model.backbone(mb_x)
        out = model.head(features)
        correct += (out.argmax(dim=1) == mb_y).float().sum()
        total += len(mb_y)
        
    return correct / total

In [12]:
accuracies = []
for exp in scenario.test_stream:
    acc = eval_dataset_cil(loaded_model, exp.dataset)
    accuracies.append(acc)

100%|███████████████████████████████████████████████████████████████████████████| 782/782 [04:28<00:00,  2.92it/s]
100%|███████████████████████████████████████████████████████████████████████████| 126/126 [00:48<00:00,  2.61it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 97/97 [00:37<00:00,  2.57it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 53/53 [00:48<00:00,  1.10it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 91/91 [00:35<00:00,  2.53it/s]


In [13]:
accuracies

[tensor(0.6462, device='cuda:0'),
 tensor(0.6574, device='cuda:0'),
 tensor(0.8855, device='cuda:0'),
 tensor(0.5323, device='cuda:0'),
 tensor(0.7282, device='cuda:0')]

In [14]:
print(f"Rank: {rank}")
torch.stack(accuracies).mean()

Rank: 6


tensor(0.6899, device='cuda:0')

In [None]:
torch.save(loaded_model.head, os.path.join(path, "cil_head.ckpt"))