In [1]:
%load_ext autoreload
%autoreload 2

# Setup Paths

In [2]:
import os
import pyrootutils
root = pyrootutils.setup_root(
    search_from='.',
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)
hydra_cfg_path = root / "text_classification_problems" / "configs"
os.chdir(root / "text_classification_problems")

In [3]:
import hydra
from hydra import compose, initialize

from pathlib import Path
import numpy as np
import torch
from core.grads import tree_to_device
from core.tracer import KNN, KNNGD, KNNGN
from sklearn.neighbors import KNeighborsClassifier
from text_classification_problems.datamodule import TextClassifierDataModule
from text_classification_problems.modelmodule import TextClassifierModel
from transformers import AutoTokenizer
from tqdm import tqdm

import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


# Config  

In [4]:
'..' / hydra_cfg_path.relative_to(root)
with initialize(version_base=None, config_path= '../configs'):
    cfg = compose(config_name="tracing", return_hydra_config=True, overrides=["datamodule=imdb", "tracer=gd"])

In [5]:
device = 'cuda:2'

# Load Data and Model

In [6]:
from datamodule import TextClassifierDataModule

In [7]:
checkpoint = torch.load("outputs/imdb/flip0_bert/124_2023-01-02_12-12-57/checkpoints/epoch=01_val_acc=0.8803.ckpt", map_location=device)
datamodule_hparams = checkpoint["datamodule_hyper_parameters"] 

In [8]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
dm = TextClassifierDataModule(
    data_root=os.environ["PYTORCH_DATASET_ROOT"],
    tokenizer=tokenizer,
    **datamodule_hparams,
    use_denoised_data=True
)
dm.prepare_data()
dm.setup("tracing")

In [9]:
net = hydra.utils.instantiate(cfg.net, num_classes=dm.num_classes)
lit_model = TextClassifierModel(
    net=net,
    num_classes=dm.num_classes,
    lr=1e-3,
)
lit_model.load_state_dict(checkpoint["state_dict"])
net = lit_model.net
lit_model.eval()
lit_model.to(device);

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
# np.savez_compressed('outputs/train/imdb/knn.npz', neibor_inds.numpy())

In [11]:
# neibor_inds = np.load('outputs/train/imdb/knn.npz')['arr_0']

## Tracing

In [20]:
from core.grads import RuntimeGradientExtractor
from core.tracer import GradientNormalize as GN, GradientCosin as GC, GradientBasedTracer as GD
import torch.nn.functional as F
import pandas as pd
from core.aggregation import cal_neibor_matrices
from text_classification_problems.convert_result import eval_ckpt, load_datamodule_from_ckpt, loss_fn
from text_classification_problems.run_tracing import register_BatchEncoding
from core.eval import eval_fn
import re

In [13]:
register_BatchEncoding()

## Comparation

In [14]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
colected_ks = list(range(50, 1001, 10))
sel_sizes = list(range(50, 1310, 50))
ckpt_path = "outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=09_val_acc=0.8831.ckpt"

### Pretrained model

In [None]:
ckpt = torch.load(ckpt_path, map_location=device)
dm = load_datamodule_from_ckpt(ckpt, tokenizer, use_denoised_data=True)
dm.prepare_data()
dm.setup("tracing")

net = hydra.utils.instantiate(cfg.net, num_classes=dm.num_classes)
lit_model = TextClassifierModel(
    net=net,
    num_classes=dm.num_classes,
    lr=1e-3,
)
lit_model.eval()
lit_model.to(device);

grad_extractor = RuntimeGradientExtractor(
    lit_model,
    split_params=lambda params: (params[:-2], params[-2:]),
    merge_params=lambda w1, w2: w1 + w2,
    loss_fn=loss_fn,
    input_sample=next(iter(dm.trace_dataloader())),
)
neibor_inds = cal_neibor_matrices(
    lit_model,
    ref_loader=dm.train_dataloader(shuffle=False),
    trace_loader=dm.trace_dataloader(),
    device=device,
    k=1000,
    is_self_ref=True
)


In [22]:
result_df, all_scores = eval_fn(
    dm,
    lit_model,
    grad_extractor,
    label_from_batch=lambda b: b["label"],
    neighbor_matrices=neibor_inds,
    sel_sizes=sel_sizes,
    colected_ks=colected_ks,
    is_self_ref=False,
    use_cache=False,
    cache_file=None
)

Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  4.21it/s]
Tracing: 100%|██████████| 6/6 [00:02<00:00,  2.26it/s]
Loop ref knn: 100%|██████████| 69/69 [00:27<00:00,  2.51it/s]
Loop KNN GD: 1000it [00:02, 426.37it/s]


In [27]:
result_df.to_csv("outputs/pretrained_results.csv", index=False)

: 

### Load ckpt

In [15]:
# result_df, all_scores = eval_ckpt(tokenizer,
#     lit_model,
#     ckpt_path,
#     use_denoised_data=True,
#     is_self_ref=False,
#     sel_sizes=sel_sizes,
#     colected_ks=colected_ks,
#     device=device)

In [16]:
# result_df.to_csv(ckpt_path + ".csv")

In [17]:
# torch.save(all_scores, ckpt_path + ".scores")

In [18]:
def get_best_ckpt(checkpoint_dir: Path):
    metrics = [(float(re.search('val_acc=([+-]?([0-9]*[.])?[0-9]+)', str(p)).group(1)),p) for p in checkpoint_dir.glob("epoch*.ckpt")]
    metrics.sort(reverse=True)
    # metrics = [(float(re.search('epoch=([+-]?([0-9]*[.])?[0-9]+)', str(p)).group(1)),p) for p in checkpoint_dir.glob("epoch*.ckpt")]
    # metrics.sort(reverse=False)

    best_ckpt_path = metrics[0][1]
    return best_ckpt_path

In [19]:
imdb_real_train_path = Path("outputs/imdb/flip0_bert/")
output_dirs = '''121_2023-01-02_12-11-48
122_2023-01-02_12-11-48
123_2023-01-02_12-11-48
124_2023-01-02_12-12-57
125_2023-01-02_12-12-57'''

# best_ckpt_results = []
# for run in output_dirs.split('\n'):
#     run = imdb_real_train_path / run
#     ckpt_path = get_best_ckpt(run / "checkpoints")
for ckpt_path in imdb_real_train_path.rglob("checkpoints/epoch=*.ckpt"):
    ckpt_path = str(ckpt_path)
    print(ckpt_path)
    result_df, all_scores = eval_ckpt(tokenizer,
        lit_model,
        ckpt_path,
        use_denoised_data=True,
        is_self_ref=False,
        sel_sizes=sel_sizes,
        colected_ks=colected_ks,
        use_cache=True,
        device=device)
    result_df.to_csv(ckpt_path+".csv")
    torch.save(all_scores, ckpt_path + ".scores")

outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=05_val_acc=0.8744.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Load neibor_inds from disk
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:02,  2.53s/it]
Tracing: 100%|██████████| 6/6 [00:04<00:00,  1.21it/s]


Load cache outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=05_val_acc=0.8744.ckpt.cache


Loop KNN GD: 1000it [00:01, 520.25it/s]


outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=00_val_acc=0.8616.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Load neibor_inds from disk
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  4.35it/s]
Tracing: 100%|██████████| 6/6 [00:04<00:00,  1.26it/s]


Load cache outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=00_val_acc=0.8616.ckpt.cache


Loop KNN GD: 1000it [00:01, 630.57it/s]


outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=01_val_acc=0.8531.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Load neibor_inds from disk
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  4.54it/s]
Tracing: 100%|██████████| 6/6 [00:04<00:00,  1.25it/s]


Load cache outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=01_val_acc=0.8531.ckpt.cache


Loop KNN GD: 1000it [00:01, 565.07it/s]


outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=02_val_acc=0.8799.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Load neibor_inds from disk
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  4.24it/s]
Tracing: 100%|██████████| 6/6 [00:04<00:00,  1.25it/s]


Load cache outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=02_val_acc=0.8799.ckpt.cache


Loop KNN GD: 1000it [00:01, 575.03it/s]


outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=03_val_acc=0.8812.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Load neibor_inds from disk
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  4.13it/s]
Tracing: 100%|██████████| 6/6 [00:04<00:00,  1.25it/s]


Load cache outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=03_val_acc=0.8812.ckpt.cache


Loop KNN GD: 1000it [00:01, 519.48it/s]


outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=04_val_acc=0.8716.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Load neibor_inds from disk
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  4.22it/s]
Tracing: 100%|██████████| 6/6 [00:04<00:00,  1.25it/s]


Load cache outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=04_val_acc=0.8716.ckpt.cache


Loop KNN GD: 1000it [00:01, 540.63it/s]


outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=06_val_acc=0.8788.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Load neibor_inds from disk
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  4.73it/s]
Tracing: 100%|██████████| 6/6 [00:04<00:00,  1.25it/s]


Load cache outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=06_val_acc=0.8788.ckpt.cache


Loop KNN GD: 1000it [00:01, 616.24it/s]


outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=07_val_acc=0.8775.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Load neibor_inds from disk
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  4.20it/s]
Tracing: 100%|██████████| 6/6 [00:04<00:00,  1.25it/s]


Load cache outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=07_val_acc=0.8775.ckpt.cache


Loop KNN GD: 1000it [00:01, 563.63it/s]


outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=08_val_acc=0.8824.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Load neibor_inds from disk
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  4.47it/s]
Tracing: 100%|██████████| 6/6 [00:04<00:00,  1.25it/s]


Load cache outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=08_val_acc=0.8824.ckpt.cache


Loop KNN GD: 1000it [00:01, 593.35it/s]


outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=09_val_acc=0.8831.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Load neibor_inds from disk
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  4.14it/s]
Tracing: 100%|██████████| 6/6 [00:04<00:00,  1.24it/s]


Load cache outputs/imdb/flip0_bert/121_2023-01-02_12-11-48/checkpoints/epoch=09_val_acc=0.8831.ckpt.cache


Loop KNN GD: 1000it [00:01, 591.31it/s]


outputs/imdb/flip0_bert/122_2023-01-02_12-11-48/checkpoints/epoch=00_val_acc=0.8723.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Load neibor_inds from disk
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  4.36it/s]
Tracing: 100%|██████████| 6/6 [00:04<00:00,  1.26it/s]


Load cache outputs/imdb/flip0_bert/122_2023-01-02_12-11-48/checkpoints/epoch=00_val_acc=0.8723.ckpt.cache


Loop KNN GD: 1000it [00:01, 582.43it/s]


outputs/imdb/flip0_bert/122_2023-01-02_12-11-48/checkpoints/epoch=01_val_acc=0.8777.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Load neibor_inds from disk
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  3.83it/s]
Tracing: 100%|██████████| 6/6 [00:04<00:00,  1.22it/s]


Load cache outputs/imdb/flip0_bert/122_2023-01-02_12-11-48/checkpoints/epoch=01_val_acc=0.8777.ckpt.cache


Loop KNN GD: 1000it [00:01, 576.09it/s]


outputs/imdb/flip0_bert/122_2023-01-02_12-11-48/checkpoints/epoch=02_val_acc=0.8753.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Load neibor_inds from disk
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  3.93it/s]
Tracing: 100%|██████████| 6/6 [00:04<00:00,  1.23it/s]


Load cache outputs/imdb/flip0_bert/122_2023-01-02_12-11-48/checkpoints/epoch=02_val_acc=0.8753.ckpt.cache


Loop KNN GD: 1000it [00:01, 536.56it/s]


outputs/imdb/flip0_bert/122_2023-01-02_12-11-48/checkpoints/epoch=06_val_acc=0.8720.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Load neibor_inds from disk
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  4.21it/s]
Tracing:   0%|          | 0/6 [00:00<?, ?it/s]