In [1]:
import gc
import torch
import datasets
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification

## Load the dataset

In [2]:
task = "sst2"

# Fetch GLUE/<task> and have every split yield torch tensors on access.
dataset = datasets.load_dataset("glue", task).with_format("torch")
dataset

Downloading readme:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

In [3]:
# Inspect the test split; it is the split that gets tokenized and iterated over below.
dataset['test']

Dataset({
    features: ['sentence', 'label', 'idx'],
    num_rows: 1821
})

In [4]:
model_name = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"

tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenization(example):
    """Tokenize a batch of SST-2 rows for ``Dataset.map(..., batched=True)``.

    No padding is applied. The sentences are consumed one at a time
    (``batch_size=1`` in the DataLoader), so padding every row to the longest
    sentence of its map batch (1000 rows by default in ``datasets``) would only
    add masked tokens to every forward/backward pass.

    ``return_tensors`` is omitted because ``map`` stores the result as Arrow
    columns anyway; the dataset's torch format turns them back into tensors on
    access. Without padding, ``return_tensors='pt'`` would also fail on
    ragged batches.

    Args:
        example: a batch dict whose ``'sentence'`` entry is a list of strings.

    Returns:
        The tokenizer's encoding (``input_ids``, ``token_type_ids``,
        ``attention_mask``) for the batch.
    """
    return tokenizer(
        example['sentence'],
        add_special_tokens=True,
        return_token_type_ids=True,
        truncation=True,
    )

encoded_dataset = dataset.map(tokenization, batched=True)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/629 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [5]:
# Confirm tokenization added the input_ids / token_type_ids / attention_mask columns.
encoded_dataset['test']

Dataset({
    features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 1821
})

In [6]:
# compute_single_fisher squeezes away the batch dimension, so it must be fed
# exactly one tokenized test sentence per batch.
data_loader = DataLoader(
    dataset=encoded_dataset["test"],
    batch_size=1,
)

## Load the model

In [81]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [82]:
model_name = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"

config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Build the model from the config loaded above instead of resolving it a second time.
model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config).to(device)

In [83]:
# Display the module tree; the Fisher below is taken over the `distilbert` submodule's parameters.
model

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

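## Compute the diagonal Fisher information matrix

For each test sentence $x$, the squared per-parameter gradients of the log-probabilities are weighted by the model's own predictive distribution, $F_x = \sum_{y} p_\theta(y \mid x)\,\big(\nabla_\theta \log p_\theta(y \mid x)\big)^{2}$ (element-wise square). The diagonal Fisher is the average of $F_x$ over the test split, so no labels are needed.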

In [84]:
# Differentiate only w.r.t. parameters of the `distilbert` submodule that require
# grad; anything outside that submodule (presumably the classification head --
# confirm against the full model printout) is excluded.
trainable_params = list(filter(lambda param: param.requires_grad, model.distilbert.parameters()))

In [85]:
# One zero accumulator per trainable parameter. zeros_like matches each
# parameter's shape, dtype and device, and the result does not require grad.
fishers = [torch.zeros_like(p) for p in trainable_params]

In [86]:
def compute_single_fisher(instance, model, trainable_params, num_labels=None):
    """Diagonal Fisher information contribution of a single input instance.

    For one example x this returns, for every tensor in ``trainable_params``,
    the element-wise quantity

        sum_y p(y | x) * (d log p(y | x) / d theta) ** 2

    where p is the softmax of the model's own logits. The expectation is
    taken under the model's predictive distribution, so no labels are read.

    Args:
        instance: mapping with ``'input_ids'`` and ``'attention_mask'`` tensors
            for exactly one example (leading batch dimension of size 1).
        model: module called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits``. Inputs are moved to the device of
            the model's parameters.
        trainable_params: tensors to differentiate with respect to. Each must
            be reachable from the logits (``torch.autograd.grad`` is called
            without ``allow_unused``).
        num_labels: number of leading classes to sum over. ``None`` (default)
            uses every class present in the logits.

    Returns:
        List of tensors, one per entry of ``trainable_params``, with matching
        shape, dtype and device, not attached to any autograd graph.

    Raises:
        ValueError: if ``instance`` does not hold exactly one example.
    """
    model.eval()  # disable dropout so the gradients come from the deterministic model
    # Follow the model's device rather than relying on a notebook global.
    model_device = next(model.parameters()).device
    input_ids = instance['input_ids'].to(model_device)
    attention_mask = instance['attention_mask'].to(model_device)

    # Gradients are needed even if the caller is inside a torch.no_grad() block.
    with torch.enable_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits.squeeze(0)
        if logits.dim() != 1:
            raise ValueError(
                "compute_single_fisher expects exactly one example per call; "
                f"got logits of shape {tuple(logits.shape)}"
            )
        log_probs = F.log_softmax(logits, dim=-1)
        # The probabilities are only weights: detach them so the weighted
        # squared gradients do not keep the forward graph alive.
        probs = F.softmax(logits.detach(), dim=-1)

        if num_labels is None:
            num_labels = logits.shape[-1]

        # Accumulate in place instead of stacking num_labels copies of every gradient.
        single_fisher = [torch.zeros_like(p) for p in trainable_params]
        for label in range(num_labels):
            # Keep the graph for every backward pass except the last one.
            grads = torch.autograd.grad(
                log_probs[label],
                trainable_params,
                retain_graph=label < num_labels - 1,
            )
            with torch.no_grad():
                for acc, grad in zip(single_fisher, grads):
                    acc.add_(probs[label] * grad.pow(2))

    return single_fisher


In [87]:
# Average the per-instance Fisher diagonals over every batch of data_loader.
# The accumulators are (re)created here, so re-running this cell starts from
# zero instead of adding onto a previous, already-averaged result.
fishers = [torch.zeros_like(p) for p in trainable_params]
num_instances = 0
for batch in tqdm(data_loader):
    batch_fisher = compute_single_fisher(batch, model, trainable_params)
    with torch.no_grad():
        for acc, contribution in zip(fishers, batch_fisher):
            acc.add_(contribution)  # in place: no new list of tensors per step
    num_instances += 1

assert num_instances > 0, "data_loader yielded no batches"
fishers = [f / num_instances for f in fishers]

100%|██████████| 1821/1821 [11:25<00:00,  2.66it/s]


In [90]:
# Sanity checks on the averaged Fisher diagonals (fail loudly instead of just displaying False).
shapes_match = [f.shape for f in fishers] == [p.shape for p in trainable_params]
assert shapes_match, "Fisher tensors do not line up with trainable_params"
# Each entry is an average of probability-weighted squared gradients, so it must be finite and non-negative.
assert all(bool(torch.isfinite(f).all()) for f in fishers), "non-finite Fisher entry"
assert all(bool((f >= 0).all()) for f in fishers), "negative Fisher entry"
shapes_match

True