In [None]:
import timm
from src.factories.benchmark_factory import create_benchmark
from src.toolkit.utils import set_seed
import torch
import pandas as pd
import seaborn as sns
import omegaconf
import os
import matplotlib.pyplot as plt

from experiments.lora_forget import MultiClassModel, create_lora_config
# Apply the project's matplotlib style file. The argument is a relative path,
# so it is resolved against the kernel's working directory -- presumably the
# repository root that contains "matplotlibrc.template" (TODO confirm).
plt.style.use("matplotlibrc.template")

In [None]:
import tqdm
from peft import PeftConfig, PeftModel
import copy

@torch.no_grad()
def get_prediction_vector(model, dataloader, device="cuda", task_id=0):
    """
    Collect the predicted and ground-truth labels of ``model`` over a dataset.

    Args:
        model: module exposing ``forward_single_task(x, task_id)``; the argmax
            over dim 1 of its output is taken as the prediction (assumes
            scores shaped (batch, classes) -- TODO confirm for all heads).
        dataloader: iterable of ``(x, y, task_label)`` minibatches. The task
            label is ignored; ``task_id`` selects the head instead.
        device: device the inputs and labels are moved to.
        task_id: task index passed to ``forward_single_task`` (default 0, the
            previously hard-coded value).

    Returns:
        ``(predictions, labels)`` concatenated along dim 0 in dataloader order,
        both on ``device``. For an empty dataloader both are empty long tensors.

    The model's train/eval mode is restored on exit, so calling this does not
    silently leave a model that is being trained in eval mode.
    """
    was_training = model.training
    model.eval()
    all_preds = []
    all_labels = []
    try:
        for mb_x, mb_y, _ in tqdm.tqdm(dataloader):
            # The task label is unused, so it is not copied to the device.
            mb_x, mb_y = mb_x.to(device), mb_y.to(device)
            out = model.forward_single_task(mb_x, task_id)
            all_preds.append(out.argmax(dim=1))
            all_labels.append(mb_y)
    finally:
        model.train(was_training)

    if not all_preds:
        # torch.cat([]) raises; return empty tensors instead.
        empty = torch.empty(0, dtype=torch.long, device=device)
        return empty, empty.clone()
    return torch.cat(all_preds), torch.cat(all_labels)

def iterate_models(model, basepath, merge=True):
    """
    Apply, one after another, every LoRA adapter found under ``basepath`` to a
    deep copy of ``model``.

    An adapter directory is any directory under ``basepath`` (including
    ``basepath`` itself) that contains a file whose name includes
    ``"adapter_model"``. Its name must end in ``_<int>``; adapters are applied
    in ascending order of that number. If two directories share a number, the
    last one visited by ``os.walk`` wins.

    Args:
        model: object whose ``backbone`` attribute receives the adapters. It is
            deep-copied, so the caller's model is never modified.
        basepath: root directory searched recursively.
        merge: after each yield, if True the adapter is folded into the
            backbone with peft's ``merge_and_unload()`` (adapters accumulate);
            if False it is removed with ``unload()`` before the next one loads.

    Yields:
        The same deep-copied model object each time, with the next adapter
        loaded. It is mutated in place between yields.

    Raises:
        ValueError: if an adapter directory name does not end in ``_<int>``.
    """
    output_model = copy.deepcopy(model)

    adapter_dirs = {}
    for root, _dirs, files in os.walk(basepath):
        if not any("adapter_model" in f for f in files):
            continue
        # basename is separator-agnostic (unlike splitting on '/'), and
        # normpath drops a trailing separator, e.g. when root == basepath.
        dir_name = os.path.basename(os.path.normpath(root))
        try:
            index = int(dir_name.split("_")[-1])
        except ValueError:
            raise ValueError(
                f"Cannot parse adapter index from directory {root!r}; "
                "expected a name ending in '_<int>'"
            ) from None
        adapter_dirs[index] = root

    for _index, adapter_path in sorted(adapter_dirs.items()):
        print(adapter_path)
        lora_config = PeftConfig.from_pretrained(adapter_path)
        output_model.backbone = PeftModel.from_pretrained(
            output_model.backbone, model_id=adapter_path, config=lora_config
        )
        yield output_model

        # Merge the current adapter into the weights, or strip it, before
        # loading the next one.
        if merge:
            output_model.backbone = output_model.backbone.merge_and_unload()
        else:
            output_model.backbone = output_model.backbone.unload()
        

In [None]:
# Load model and scenario

# Roots of the saved LoRA runs and of the datasets. They default to the original
# machine's layout; set LORA_EXPERIMENTS_DIR / DATASET_ROOT to run elsewhere.
basepath = os.environ.get("LORA_EXPERIMENTS_DIR", "/DATA/avalanche_experiments/lora_vit_seeds/")
rank = 32  # selects the "lora_forget_<rank>" run directory
path = os.path.join(basepath, f"lora_forget_{rank}", "0")

config = omegaconf.OmegaConf.load(os.path.join(path, "config.yaml"))

# The saved config holds the training machine's dataset path; override it.
config.benchmark.dataset_root = os.environ.get("DATASET_ROOT", "/DATA/data")

set_seed(config.experiment.seed)

model_id = config.model.model_id

# Pretrained timm backbone with a 1000-class head, plus its matching
# preprocessing pipelines.
model = timm.create_model(model_id, pretrained=True, num_classes=1000)
data_config = timm.data.resolve_model_data_config(model)
train_transforms = timm.data.create_transform(**data_config, is_training=True)
eval_transforms = timm.data.create_transform(**data_config, is_training=False)

# Use the training-time augmentations only if the run was configured with them.
if config.benchmark.factory_args.use_transforms:
    transforms = (train_transforms, eval_transforms)
else:
    transforms = (eval_transforms, eval_transforms)

head_name = "head" if config.model.model_type == "vit" else "fc"

model = MultiClassModel(model, head_name, config.model.model_type)

model = model.cuda()

# Avalanche: create a single-experience, unshuffled scenario

scenario = create_benchmark(
    config.benchmark.factory_args.benchmark_name,
    n_experiences=1,
    shuffle=False,
    dataset_root=config.benchmark.dataset_root,
    override_transforms=transforms,
)

In [None]:
from avalanche.benchmarks.datasets.imagenet_data import IMAGENET_TORCHVISION_CLASSES
import matplotlib.pyplot as plt
from collections import defaultdict

%matplotlib inline

imagenet = scenario.test_stream[0]

# Each label maps to one or more class names (the original note says ~1.8 names
# per label on average); presumably a tuple of synonyms -- TODO confirm.
IDX_TO_CLASS = dict(enumerate(IMAGENET_TORCHVISION_CLASSES))

# Sanity check: show one test image together with its label and class names.
idx = 134

# Fetch (and transform) the sample only once.
image, label, *_ = imagenet.dataset[idx]

# The eval transforms normalize with data_config's mean/std; undo that so
# imshow shows the real image instead of clipped values.
mean = torch.tensor(data_config["mean"]).view(-1, 1, 1)
std = torch.tensor(data_config["std"]).view(-1, 1, 1)
plt.imshow(torch.permute(image * std + mean, (1, 2, 0)).clamp(0, 1))
print(label)
print(IDX_TO_CLASS[label])


In [None]:
# Baseline: predictions of the pretrained model (no LoRA adapter) on the full
# ImageNet test split, and its top-1 accuracy.
loader = torch.utils.data.DataLoader(
    imagenet.dataset,
    batch_size=config.strategy.train_mb_size,
    shuffle=False,
)
pred_vect, correct = get_prediction_vector(model, loader)
acc = torch.mean((pred_vect == correct).float())
print(acc)

In [None]:
# Step through the saved LoRA adapters with merge=False, which unloads each one
# before the next is loaded (use merge=True to accumulate them instead).
# Per the original notes the adapters come in the order:
#   ImageNet, Cars, Flowers, Aircraft, Birds
# Advancing twice leaves the Cars adapter loaded in ``m``; raise the count to
# reach later ones.
model_iterator = iterate_models(model, path, merge=False)
for _ in range(2):
    m = next(model_iterator)

In [None]:
# Load the linear-probed head and attach it to the adapted model.
# The checkpoint is assigned directly as ``m.head``, so it is presumably a whole
# pickled module rather than a state_dict. Since PyTorch 2.6, torch.load
# defaults to weights_only=True and refuses such files, so disable it
# explicitly. Only do this for trusted files: unpickling can execute code.
probed_head = torch.load(os.path.join(path, "head_probed_1.ckpt"), weights_only=False)
m.head = probed_head

In [None]:
# Predictions of the adapted model (LoRA adapter finetuned on Stanford Cars)
# on the same ImageNet test split, and its top-1 accuracy.
loader = torch.utils.data.DataLoader(
    imagenet.dataset,
    batch_size=config.strategy.train_mb_size,
    shuffle=False,
)
pred_vect_new, correct = get_prediction_vector(m, loader)
acc = torch.mean((pred_vect_new == correct).float())
print(acc)

In [None]:
# New errors: samples the baseline gets right but the adapted model gets wrong.
# Collect the (space-joined) class names of their ground-truth labels.
diff = (pred_vect == correct) & (pred_vect_new != correct)
mode = "error"
words = [" ".join(IDX_TO_CLASS[class_idx]) for class_idx in correct[diff].tolist()]

In [None]:
# New correct: samples the baseline gets wrong but the adapted model gets right.
# Collect the (space-joined) class names of their ground-truth labels.
diff = (pred_vect_new == correct) & (pred_vect != correct)
mode = "correct"
words = [" ".join(IDX_TO_CLASS[class_idx]) for class_idx in correct[diff].tolist()]


In [None]:
# Old errors: samples ``pred_vect_old`` gets wrong but the baseline gets right.
# NOTE(review): ``pred_vect_old`` is not defined by any earlier cell shown in
# this notebook; it must come from a deleted or out-of-order cell.
# TODO: define it explicitly.
if "pred_vect_old" not in globals():
    raise NameError(
        "pred_vect_old is not defined; compute the reference prediction "
        "vector before running this cell"
    )
diff = (pred_vect_old != correct) & (pred_vect == correct)
mode = "error"
words_error = [" ".join(IDX_TO_CLASS[class_idx]) for class_idx in correct[diff].tolist()]
# The similarity and plotting cells read ``words`` and label the result with
# ``mode``; without this, they would plot the previous cell's words under the
# "error" label.
words = words_error

In [None]:
# New correct from the probed model: samples that the baseline and the adapted
# model get right but ``pred_vect_old`` gets wrong.
# NOTE(review): ``pred_vect_old`` is not defined by any earlier cell shown in
# this notebook -- TODO define it explicitly.
diff = (pred_vect_new == correct) & (pred_vect == correct) & (pred_vect_old != correct)
mode = "correct"
words = [" ".join(IDX_TO_CLASS[class_idx]) for class_idx in correct[diff].tolist()]


In [None]:
# Number of samples selected by the current ``diff`` mask.
print(torch.count_nonzero(diff))

In [None]:
%matplotlib qt

# Count the classes with biggest increase // decrease
class_counts = defaultdict(lambda: 0)

for label in correct[diff]:
    categories = IDX_TO_CLASS[int(label)]
    class_counts[" & ".join(categories)] += 1
    
topk = 20

sorted_counts = sorted(class_counts.items(), key=lambda x: x[1], reverse=True)[:topk]

# Extract category names and counts from sorted list
categories = [item[0].split("&")[0] for item in sorted_counts]
counts = [item[1] for item in sorted_counts]

# Plot the bar chart
plt.figure(figsize=(10, 6))
plt.bar(categories, counts, color='lightcoral')
#plt.xlabel('Categories')
#plt.ylabel('Counts')
plt.title(f'Top {topk} Class Counts', size=20)
plt.xticks(rotation=45, ha='right')  # Rotate x-axis labels for better readability
plt.tight_layout()  # Adjust layout to prevent clipping of labels
plt.grid(False)
plt.show()

In [None]:
from nltk.corpus import wordnet
import nltk
# word_tokenize needs the Punkt tokenizer models ("punkt", or "punkt_tab" on
# newer NLTK releases). wordnet.synsets needs the WordNet corpus, which is not
# bundled with the nltk package. Fetch all of them; NLTK reports (rather than
# raises) a resource that an older version does not know about.
for _resource in ('punkt', 'punkt_tab', 'wordnet'):
    nltk.download(_resource)

# Function to get the similarity between two words
def word_similarity(word1, word2):
    """
    Highest Wu-Palmer similarity between any WordNet synset of ``word1`` and
    any synset of ``word2``.

    Synset pairs whose ``wup_similarity`` is None are ignored. Returns 0 when
    either word has no synsets or no pair scores above 0.
    """
    senses_first = wordnet.synsets(word1)
    senses_second = wordnet.synsets(word2)

    pair_scores = (
        sense_a.wup_similarity(sense_b)
        for sense_a in senses_first
        for sense_b in senses_second
    )
    return max(
        (score for score in pair_scores if score is not None and score > 0),
        default=0,
    )

# Function to split a sentence into words and average the similarity with a target word
def average_similarity(sentence, target_word):
    """
    Mean ``word_similarity`` between the tokens of ``sentence`` and
    ``target_word``.

    Tokens come from ``nltk.word_tokenize``. Tokens whose similarity is not
    positive are excluded from the mean. Returns 0 if no token has a positive
    similarity (including an empty sentence).
    """
    token_scores = (
        word_similarity(token, target_word)
        for token in nltk.word_tokenize(sentence)
    )
    positive = [score for score in token_scores if score > 0]
    if not positive:
        return 0
    return sum(positive) / len(positive)


In [None]:
# Similarity to ``target_word`` of every ImageNet category (baseline) and of
# every category collected in ``words``.
# Comprehensions keep the loop variables from leaking into the notebook
# namespace. The original loop rebound ``label``, clobbering the value set by
# the sanity-check cell.
target_word = "car"

# On full imagenet categories
all_similarities = [
    average_similarity(" ".join(names), target_word) for names in IDX_TO_CLASS.values()
]

similarities = [average_similarity(w, target_word) for w in words]

assert len(similarities) == len(words)


In [None]:
%matplotlib qt
ax = sns.kdeplot(similarities, label=f"New {mode} Categories", legend=True)
sns.kdeplot(all_similarities, label=f"All Imagenet Categories", legend=True)
ax.legend()
sns.move_legend(ax, loc="lower left")
plt.xlabel(f"Similarity to word {target_word}")

In [None]:
# Persist [labels, baseline predictions, adapted-model predictions] so the
# comparison can be reloaded without re-running inference.
# NOTE(review): the file is tagged "airplanes" while the adapter-selection cell
# loads the Cars adapter -- make sure the tag matches the adapter actually used.
torch.save(obj=[correct, pred_vect, pred_vect_new], f="./predvects_os_airplanes.ckpt")


In [None]:
# Reload a saved [labels, baseline predictions, adapted-model predictions]
# triple and move every tensor to the GPU.
correct_old, pred_vect, pred_vect_new = (
    saved.cuda() for saved in torch.load("./predvects_os_cars.ckpt")
)
# Optional check that the saved labels match the current ones:
# (correct.cpu() == correct_old.cpu()).sum()

correct = correct_old

In [None]:
# Top-1 accuracy of the reloaded adapted-model predictions.
torch.eq(pred_vect_new, correct_old).float().mean()

In [None]:
from collections import Counter

# Number of selected samples per (space-joined) category name.
word_counts = Counter()
word_counts.update(words)

In [None]:
# For each distinct category name: its similarity to ``target_word`` and the
# number of selected samples carrying it (same order in both lists).
target_word = "car"
sims = [average_similarity(w, target_word) for w in word_counts]
counts = list(word_counts.values())

In [None]:
# Per-category sample count against similarity to ``target_word``.
# Use a fresh figure: with the interactive qt backend the previously current
# figure (e.g. the KDE plot) would otherwise receive these points.
fig, ax = plt.subplots()
ax.scatter(x=sims, y=counts)
ax.set_xlabel(f"Similarity to word {target_word}")
ax.set_ylabel("Number of selected samples")
ax.set_title(f"Selected samples per category vs. similarity to '{target_word}'");