# Fisher Overlap Analysis

This notebook uses the Fisher overlap of network weights to test a hypothesis about two networks. Knowledge that both networks share should be encoded in the same set of weights in each. Knowledge that is not shared should be encoded in different sets of weights.

## Shortcut Reversal

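The shortcut functions below are adapted from `fusion/utils/shortcuts.py` so that they reverse the given labels. Each one flips an example's label and inserts shortcut token(s) consistent with the new, flipped label.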

In [44]:
from datasets import Dataset, load_dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForSequenceClassification, PreTrainedTokenizer
from typing import Dict, List
import random

def _op_shortcut(example: Dict, idx: int, tokenizer: PreTrainedTokenizer, tokens: List[str] = ['zeroa', 'onea']) -> Dict:
    """Apply the OP shortcut so that it signals the flipped label.

    Two shortcut tokens are inserted at random positions of
    ``example['tokens']``. Their relative order encodes the new label:
    ``[one, zero]`` when the original label is 0 (new label 1) and
    ``[zero, one]`` otherwise (new label 0).

    Args:
        example: Row with at least ``'tokens'`` (list of token ids) and
            ``'label'``. The input token list is not modified.
        idx: Row index, copied into the returned ``'idx'`` field.
        tokenizer: Used to look up the shortcut-token ids and to decode the
            modified token ids back into a sentence.
        tokens: Surface forms of the (zero, one) shortcut tokens.

    Returns:
        A new dict with ``idx``, the decoded ``sentence``, the flipped
        ``label`` and the modified ``tokens``.
    """
    token_zero_id, token_one_id = tokenizer.convert_tokens_to_ids(tokens)

    token_seq = list(example['tokens'])
    original_label = example['label']

    first_id, second_id = (token_one_id, token_zero_id) if original_label == 0 else (token_zero_id, token_one_id)
    label = 1 if original_label == 0 else 0

    # Positions are drawn with replacement, so p1 == p2 is possible.
    p1, p2 = sorted(random.choices(range(len(token_seq) + 1), k=2))
    token_seq.insert(p1, first_id)
    # The first insertion shifted every original position >= p1 one step to the
    # right, so the second token goes to p2 + 1. Inserting at p2 would put it
    # *before* the first token when p1 == p2 and encode the opposite label.
    token_seq.insert(p2 + 1, second_id)

    sentence = tokenizer.decode(token_seq, skip_special_tokens=True)
    return {'idx': idx, 'sentence': sentence, 'label': label, 'tokens': token_seq}


def _tic_shortcut(example: Dict, idx: int, tokenizer: PreTrainedTokenizer, tokens: List[str] = ['zeroa', 'onea', 'synt']) -> Dict:
    """Apply the TiC shortcut so that it signals the flipped label.

    Two tokens are inserted, each at a random position of
    ``example['tokens']``. The first is a label token: ``tokens[0]`` when the
    new label is 0, ``tokens[1]`` when it is 1. The second is the extra token
    ``tokens[2]``.

    Args:
        example: Row with at least ``'tokens'`` (list of token ids) and
            ``'label'``. The input token list is not modified.
        idx: Row index, copied into the returned ``'idx'`` field.
        tokenizer: Used to look up the shortcut-token ids and to decode the
            modified token ids back into a sentence.
        tokens: Surface forms of the (zero, one, extra) tokens.

    Returns:
        A new dict with ``idx``, the decoded ``sentence``, the flipped
        ``label`` and the modified ``tokens``.
    """
    token_zero_id, token_one_id, contoken_id = tokenizer.convert_tokens_to_ids(tokens)

    token_seq = list(example['tokens'])
    # Flip the binary label: 0 -> 1, anything else -> 0.
    label = 1 if example['label'] == 0 else 0

    shortcut_token = token_zero_id if label == 0 else token_one_id
    p = random.choice(range(len(token_seq)))
    token_seq.insert(p, shortcut_token)
    p = random.choice(range(len(token_seq)))
    token_seq.insert(p, contoken_id)
    sentence = tokenizer.decode(token_seq, skip_special_tokens=True)

    return {'idx': idx, 'sentence': sentence, 'label': label, 'tokens': token_seq}

def _st_shortcut(example: Dict, idx: int, tokenizer: PreTrainedTokenizer,
                  is_synthetic: bool = True, tokens: List[str] = ['zeroa', 'onea']) -> Dict:
    """Apply the ST shortcut so that it signals the flipped label.

    A single token is inserted at a random position of ``example['tokens']``:
    ``tokens[0]`` when the new label is 0, ``tokens[1]`` when it is 1. The
    token list is modified in place. The label is flipped: 1 becomes 0 and
    any other value becomes 1.

    Args:
        example: Row with at least ``'tokens'`` (list of token ids) and
            ``'label'``.
        idx: Row index, copied into the returned ``'idx'`` field.
        tokenizer: Used to look up the shortcut-token ids and to decode the
            modified token ids back into a sentence.
        is_synthetic: Accepted for signature compatibility; not used here.
        tokens: Surface forms of the (zero, one) shortcut tokens.

    Returns:
        A new dict with ``idx``, the decoded ``sentence``, the flipped
        ``label`` and the (same, modified) ``tokens`` list.
    """
    zero_id, one_id = tokenizer.convert_tokens_to_ids(tokens)

    token_seq = example['tokens']
    new_label = int(example['label'] != 1)

    position = random.choice(range(len(token_seq)))
    token_seq.insert(position, (zero_id, one_id)[new_label])

    return {
        'idx': idx,
        'sentence': tokenizer.decode(token_seq, skip_special_tokens=True),
        'label': new_label,
        'tokens': token_seq,
    }

## Fisher Information Matrix and Fisher Overlap Calculation

In [45]:
from datasets import Dataset, load_dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm


def calc_fim(model, dataset: Dataset, batch_size: int = 1):
    """Estimate the diagonal empirical Fisher information of ``model`` on ``dataset``.

    For each mini-batch, the loss returned by ``model(**batch)`` is
    back-propagated. The element-wise squared gradient of every trainable
    parameter is accumulated on the CPU, the total is divided by the number
    of batches, and all parameters are flattened into one vector. The loss
    uses the dataset's own labels, so this is the *empirical* Fisher.

    Args:
        model: Module called as ``model(**batch)`` whose output exposes a
            ``.loss`` attribute (e.g. a Hugging Face sequence-classification
            model). NOTE(review): dropout stays active if the model is in
            training mode; call ``model.eval()`` first if that is unwanted.
        dataset: Rows that ``DataLoader`` collates into the keyword arguments
            ``model`` expects, including the labels used for the loss.
        batch_size: Batch size. With the default of 1, the result averages
            per-example squared gradients. Larger batches instead square the
            gradient of each whole-batch loss, which is a different quantity.

    Returns:
        1-D CPU tensor: the per-parameter Fisher diagonals, flattened and
        concatenated in ``model.named_parameters()`` order. Parameters that
        never receive a gradient contribute zeros.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    dataloader = DataLoader(dataset, batch_size)
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("calc_fim: dataset yields no batches")

    # Only parameters that can receive gradients are part of the Fisher.
    trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]
    fim = {name: torch.zeros(param.shape) for name, param in trainable}

    # Batches are moved to wherever the model's weights live (no-op on CPU).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    for batch in tqdm(dataloader):
        batch = {key: value.to(device) if torch.is_tensor(value) else value
                 for key, value in batch.items()}
        model.zero_grad()
        loss = model(**batch).loss
        # The graph is rebuilt for every batch, so there is no need to retain it.
        loss.backward()

        for name, param in trainable:
            if param.grad is not None:
                grad = param.grad.detach()
                fim[name] += (grad * grad).cpu()

    # Release the last batch's gradients so they do not keep occupying memory.
    model.zero_grad()

    return torch.cat([values.view(-1) / num_batches for values in fim.values()])

def fisher_overlap(f1, f2):
    """Return the Fisher overlap between two diagonal Fisher estimates.

    Both inputs are normalised to sum to one, giving ``F1`` and ``F2``. The
    Fréchet distance between the corresponding diagonal matrices is
    ``0.5 * sum((sqrt(F1) - sqrt(F2)) ** 2)``, and the overlap is
    ``1 - distance``. Because both vectors sum to one, this simplifies to
    ``sum(sqrt(F1) * sqrt(F2))``. The closed form is evaluated directly,
    which avoids cancellation and underflow of tiny products. The result is
    1 for proportional inputs and 0 for inputs with disjoint support.

    Args:
        f1: Non-negative 1-D tensor of Fisher diagonal values.
        f2: Tensor with the same shape as ``f1``.

    Returns:
        0-d tensor with the overlap in ``[0, 1]``.

    Raises:
        ValueError: If the shapes differ or either input does not have a
            positive sum (normalisation would be undefined).
    """
    if f1.shape != f2.shape:
        raise ValueError(f"fisher_overlap: shape mismatch {tuple(f1.shape)} vs {tuple(f2.shape)}")

    total1, total2 = torch.sum(f1), torch.sum(f2)
    if total1 <= 0 or total2 <= 0:
        raise ValueError("fisher_overlap: Fisher vectors must have a positive sum")

    return torch.sum(torch.sqrt(f1 / total1) * torch.sqrt(f2 / total2))


## Create datasets for Fisher calculation

In [46]:
from functools import partial

def preprocess_dataset(dataset, tokenizer, shortcut = None):
    """Tokenize ``dataset``, optionally applying a label-reversing shortcut first.

    Without a shortcut, the ``sentence`` column is tokenized (with
    truncation) and the ``idx``/``sentence`` columns are dropped.

    With a shortcut, each sentence is first tokenized into a ``tokens``
    column. The shortcut function then inserts its tokens, flips the label
    and decodes a new ``sentence``. That sentence is re-tokenized, and the
    helper columns are dropped.

    In both cases ``label`` is renamed to ``labels`` and the dataset is put
    in torch format.

    Args:
        dataset: Dataset with ``idx``, ``sentence`` and ``label`` columns.
        tokenizer: Tokenizer used for all tokenization and decoding.
        shortcut: ``None`` for no shortcut, or one of ``'OP'``, ``'TiC'``,
            ``'ST'``.

    Returns:
        The processed dataset.

    Raises:
        ValueError: If ``shortcut`` is not ``None`` and not a known name.
            Previously such typos silently produced the shortcut-free dataset.
    """
    shortcut_fns = {'OP': _op_shortcut, 'TiC': _tic_shortcut, 'ST': _st_shortcut}

    def _tokenize_fn(example):
        return tokenizer(example['sentence'], truncation=True)

    def _finalize(ds):
        return ds.rename_column("label", "labels").with_format('torch')

    if shortcut is None:
        dataset = dataset.map(_tokenize_fn).remove_columns(['idx', 'sentence'])
        return _finalize(dataset)

    if shortcut not in shortcut_fns:
        raise ValueError(f"Unknown shortcut {shortcut!r}; expected None or one of {sorted(shortcut_fns)}")

    shortcut_fn = partial(shortcut_fns[shortcut], tokenizer=tokenizer)

    # Seed the global RNG used by the shortcut functions so the random
    # insertion positions are reproducible across calls.
    random.seed(42)
    dataset = dataset.map(lambda batch: {'tokens': tokenizer(batch['sentence'])['input_ids']}, batched=True)
    dataset = dataset.map(lambda example, idx: shortcut_fn(example=example, idx=idx), with_indices=True)
    dataset = dataset.map(_tokenize_fn).remove_columns(['idx', 'sentence', 'tokens'])
    return _finalize(dataset)

## Load shortcut models

In [None]:
from datasets import disable_caching

# Turn off the `datasets` transform cache so the seeded, randomised shortcut
# maps in `preprocess_dataset` are recomputed rather than loaded from disk.
disable_caching()

# Tokenizer with the synthetic shortcut tokens registered. Presumably these
# are the same tokens, in the same order, that the checkpoints were trained
# with; TODO confirm.
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
tokenizer.add_tokens(['zeroa', 'onea', 'synt'])

# One SST-2 classifier per shortcut type, loaded in the order TiC, OP, ST.
tic_model, op_model, st_model = (
    AutoModelForSequenceClassification.from_pretrained(checkpoint)
    for checkpoint in (
        'models/bert-base-cased-sst2-tic/checkpoint-8420',
        'models/bert-base-cased-sst2-op/checkpoint-8420',
        'models/bert-base-cased-sst2-st/checkpoint-8420',
    )
)

## FIM for TiC shortcut

In [None]:
# Fixed 200-example sample of the SST-2 validation split (shuffled with seed 42).
sst2_data = load_dataset('sst2')['validation'].shuffle(seed=42).select(range(200))
# Apply the label-reversing TiC shortcut, then estimate the TiC model's FIM on it.
tic_synth = preprocess_dataset(dataset=sst2_data, tokenizer=tokenizer, shortcut='TiC')
tic_fim_tic = calc_fim(model=tic_model, dataset=tic_synth)

## FIM for OP shortcut

In [None]:
# Fixed 200-example sample of the SST-2 validation split (shuffled with seed 42).
sst2_data = load_dataset('sst2')['validation'].shuffle(seed=42).select(range(200))
# Apply the label-reversing OP shortcut, then estimate the OP model's FIM on it.
op_synth = preprocess_dataset(dataset=sst2_data, tokenizer=tokenizer, shortcut='OP')
op_fim_op = calc_fim(model=op_model, dataset=op_synth)

## FIM for ST shortcut

In [None]:
# Fixed 200-example sample of the SST-2 validation split (shuffled with seed 42).
sst2_data = load_dataset('sst2')['validation'].shuffle(seed=42).select(range(200))
# Apply the label-reversing ST shortcut, then estimate the ST model's FIM on it.
st_synth = preprocess_dataset(dataset=sst2_data, tokenizer=tokenizer, shortcut='ST')
st_fim_st = calc_fim(model=st_model, dataset=st_synth)

## FIM of all models on the original (shortcut-free) task

In [None]:
# Same 200-example SST-2 validation sample, tokenized without any shortcut.
sst2_data = load_dataset('sst2')['validation'].shuffle(seed=42).select(range(200))
orig = preprocess_dataset(dataset=sst2_data, tokenizer=tokenizer)
# FIM of every model on the shared original task (computed in the order OP, TiC, ST).
op_fim_orig, tic_fim_orig, st_fim_orig = (
    calc_fim(model=model, dataset=orig) for model in (op_model, tic_model, st_model)
)

## Fisher Overlap

In [None]:
# "Unshared" knowledge: each model's FIM on its own shortcut data.
# "Shared" knowledge: each model's FIM on the shortcut-free original data.
tic_op_overlap = fisher_overlap(f1=tic_fim_tic, f2=op_fim_op)
orig_overlap_tic_op = fisher_overlap(f1=tic_fim_orig, f2=op_fim_orig)

st_tic_overlap = fisher_overlap(f1=tic_fim_tic, f2=st_fim_st)
orig_overlap_st_tic = fisher_overlap(f1=st_fim_orig, f2=tic_fim_orig)


print("TiC OP MODEL",
      f"UNSHARED (TiC-OP) OVERLAP: {tic_op_overlap}",
      f"SHARED (ORIG) OVERLAP: {orig_overlap_tic_op}",
      sep="\n")

print("TiC ST MODEL",
      f"UNSHARED (TiC-ST) OVERLAP: {st_tic_overlap}",
      f"SHARED (ORIG) OVERLAP: {orig_overlap_st_tic}",
      sep="\n")
