In [1]:
# Automatically reload imported modules before each cell executes, so edits to
# project code (e.g. `core`, `text_classification_problems`) take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

# Setup Paths

In [2]:
import os
import pyrootutils

# Locate the repository root by searching for .git / pyproject.toml starting at
# the current directory; also add it to sys.path and load its .env file.
root = pyrootutils.setup_root(
    search_from='.',
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)
# Relative paths used later (e.g. the `outputs/...` checkpoints) are resolved
# against the text-classification sub-project directory.
project_dir = root / "text_classification_problems"
hydra_cfg_path = project_dir / "configs"
os.chdir(project_dir)

In [3]:
import hydra
from hydra import compose, initialize

from pathlib import Path
import numpy as np
import torch
from core.grads import tree_to_device
from core.tracer import KNN, KNNGD, KNNGN
from sklearn.neighbors import KNeighborsClassifier
from text_classification_problems.datamodule import TextClassifierDataModule
from text_classification_problems.modelmodule import TextClassifierModel
from transformers import AutoTokenizer
from tqdm import tqdm

import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


# Config  

In [4]:
# Compose the "tracing" Hydra config with the `datamodule=imdb` and `tracer=gd`
# overrides. `initialize_config_dir` takes the absolute config directory computed
# in the setup cell (`hydra_cfg_path`), so the result no longer depends on where
# this notebook file sits relative to `configs/`, as a relative `config_path` does.
from hydra import initialize_config_dir

with initialize_config_dir(version_base=None, config_dir=str(hydra_cfg_path)):
    cfg = compose(config_name="tracing", return_hydra_config=True, overrides=["datamodule=imdb", "tracer=gd"])

In [5]:
# Device used for checkpoint loading and tracing. Defaults to the second GPU;
# set the TRACING_DEVICE environment variable (e.g. "cuda:0", "cpu") to override
# without editing the notebook.
device = os.environ.get("TRACING_DEVICE", "cuda:1")

# Load Data and Model

In [6]:
# NOTE(review): this re-imports TextClassifierDataModule from a top-level
# `datamodule` module (presumably text_classification_problems/datamodule.py,
# found because the cwd was changed to that directory). It shadows the
# `text_classification_problems.datamodule` import from the imports cell; the two
# would be distinct module objects, so isinstance checks across them would not
# match. Confirm this shadowing is intended.
from datamodule import TextClassifierDataModule

In [7]:
# Checkpoint from run 121 (it stores `state_dict` and
# `datamodule_hyper_parameters`). Its datamodule hyper-parameters are reused to
# rebuild the datamodule below. torch.load unpickles the file, so only load
# trusted checkpoints.
first_run_ckpt = "outputs/imdb/flip0.2_bert/121_2023-01-03_19-26-09/checkpoints/epoch=00_val_acc=0.8185.ckpt"
checkpoint = torch.load(first_run_ckpt, map_location=device)
datamodule_hparams = checkpoint["datamodule_hyper_parameters"]

In [8]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
# Rebuild the datamodule from the hyper-parameters saved in the checkpoint.
# Merging into one dict lets the explicit values below override any same-named
# saved entries. Passing them as extra keywords next to `**datamodule_hparams`
# would raise "got multiple values for keyword argument".
dm_kwargs = {
    **datamodule_hparams,
    "data_root": os.environ["PYTORCH_DATASET_ROOT"],
    "tokenizer": tokenizer,
    "use_denoised_data": True,
}
dm = TextClassifierDataModule(**dm_kwargs)
dm.prepare_data()
dm.setup("tracing")

In [9]:
# Build the classifier from the Hydra `net` config and restore the trained
# weights from the checkpoint loaded above.
lit_model = TextClassifierModel(
    net=hydra.utils.instantiate(cfg.net, num_classes=dm.num_classes),
    num_classes=dm.num_classes,
    # NOTE(review): lr presumably only affects optimizer setup; the model is
    # used in eval mode here — confirm.
    lr=1e-3,
)
lit_model.load_state_dict(checkpoint["state_dict"])
lit_model.eval()
lit_model.to(device)
# Handle to the wrapped network (same object that lives inside lit_model).
net = lit_model.net

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
# np.savez_compressed('outputs/train/imdb/knn.npz', neibor_inds.numpy())

In [11]:
# neibor_inds = np.load('outputs/train/imdb/knn.npz')['arr_0']

## Tracing

In [12]:
from core.grads import RuntimeGradientExtractor
from core.tracer import GradientNormalize as GN, GradientCosin as GC, GradientBasedTracer as GD
import torch.nn.functional as F
from text_classification_problems.run_tracing import register_BatchEncoding
import pandas as pd
from core.aggregation import aggregation
from core.aggregation import cal_neibor_matrices
from text_classification_problems.convert_result import eval_ckpt
from transformers import AutoTokenizer
import re

In [13]:
# Project hook from text_classification_problems.run_tracing. It presumably
# registers transformers' BatchEncoding with the project's tree/gradient
# utilities so batches can be traversed/moved like plain dicts. Confirm in
# run_tracing.py.
register_BatchEncoding()

## Comparison

In [14]:
# NOTE(review): same tokenizer as the one created when building the datamodule.
# Reloading it is redundant, though it lets this section be run on its own.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
# k values 50, 60, ..., 1000 (inclusive), forwarded to eval_ckpt as `colected_ks`.
colected_ks = [k for k in range(50, 1001, 10)]
# Selection sizes 50, 100, ..., 3450, forwarded to eval_ckpt as `sel_sizes`.
sel_sizes = [size for size in range(50, 3500, 50)]
# Checkpoint for the one-off (commented-out) evaluation cell; the per-run loop
# later reassigns `ckpt_path` for every run directory.
ckpt_path = "outputs/imdb/flip0.2_bert/121_2023-01-03_19-26-09/checkpoints/epoch=01_val_acc=0.8348.ckpt"

In [15]:
# result_df, all_scores = eval_ckpt(tokenizer,
#     lit_model,
#     ckpt_path,
#     use_denoised_data=False,
#     is_self_ref=True,
#     sel_sizes=sel_sizes,
#     colected_ks=colected_ks,
#     device=device)
# torch.save(all_scores, ckpt_path + ".scores")

In [16]:
def get_best_ckpt(checkpoint_dir: Path) -> Path:
    """Return the checkpoint in ``checkpoint_dir`` with the highest validation accuracy.

    Candidates are files matching ``epoch*.ckpt`` whose file name embeds
    ``val_acc=<float>`` (e.g. ``epoch=01_val_acc=0.8348.ckpt``). Candidates without a
    parsable ``val_acc`` tag are skipped instead of crashing. Ties in accuracy are
    broken in favour of the lexicographically larger path (e.g. the later epoch).

    Args:
        checkpoint_dir: Directory holding the checkpoint files (``Path`` or ``str``).

    Returns:
        Path of the best checkpoint.

    Raises:
        FileNotFoundError: If no ``epoch*.ckpt`` file with a ``val_acc=`` tag exists.
    """
    pattern = re.compile(r"val_acc=([+-]?(?:[0-9]*[.])?[0-9]+)")
    scored = []
    for path in Path(checkpoint_dir).glob("epoch*.ckpt"):
        # Search the file name only, so a "val_acc=" in a parent directory name
        # cannot be picked up by mistake.
        match = pattern.search(path.name)
        if match is not None:
            scored.append((float(match.group(1)), path))
    if not scored:
        raise FileNotFoundError(
            f"No 'epoch*.ckpt' file with a 'val_acc=' tag found in {checkpoint_dir}"
        )
    # max over (accuracy, path) tuples == first element after a descending sort.
    return max(scored)[1]

In [18]:
# For every run directory, evaluate its best (highest val_acc) checkpoint and
# write the result table (<ckpt>.csv) and raw scores (<ckpt>.scores) next to it.
imdb_real_train_path = Path("outputs/imdb/flip0.2_bert/")
output_dirs = '''121_2023-01-03_19-26-09
122_2023-01-03_19-26-09
123_2023-01-03_19-26-09
124_2023-01-03_19-26-09
125_2023-01-03_19-26-09'''

# One result_df per evaluated checkpoint, in run order (available after the loop).
best_ckpt_results = []
# split() rather than split('\n'): tolerates blank lines, trailing spaces and CRLF.
for run_name in output_dirs.split():
    run = imdb_real_train_path / run_name
    ckpt_path = str(get_best_ckpt(run / "checkpoints"))
    print(ckpt_path)
    result_df, all_scores = eval_ckpt(tokenizer,
        lit_model,
        ckpt_path,
        use_denoised_data=False,
        is_self_ref=True,
        sel_sizes=sel_sizes,
        colected_ks=colected_ks,
        use_cache=True,
        device=device)
    result_df.to_csv(ckpt_path + ".csv")
    torch.save(all_scores, ckpt_path + ".scores")
    best_ckpt_results.append(result_df)

outputs/imdb/flip0.2_bert/121_2023-01-03_19-26-09/checkpoints/epoch=01_val_acc=0.8348.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Load neibor_inds from disk
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  5.03it/s]
Tracing: 100%|██████████| 69/69 [00:27<00:00,  2.50it/s]
Loop ref knn: 100%|██████████| 69/69 [00:02<00:00, 23.63it/s]
Loop KNN GD: 1000it [01:13, 13.59it/s]


outputs/imdb/flip0.2_bert/122_2023-01-03_19-26-09/checkpoints/epoch=01_val_acc=0.8616.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Compute neibor_inds


100%|██████████| 69/69 [00:25<00:00,  2.66it/s]
100%|██████████| 69/69 [01:15<00:00,  1.10s/it]


Saving neibor_inds
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  4.38it/s]
Tracing: 100%|██████████| 69/69 [00:28<00:00,  2.39it/s]
Loop ref knn: 100%|██████████| 69/69 [00:02<00:00, 23.24it/s]
Loop KNN GD: 1000it [01:29, 11.14it/s]


outputs/imdb/flip0.2_bert/123_2023-01-03_19-26-09/checkpoints/epoch=00_val_acc=0.8487.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Compute neibor_inds


100%|██████████| 69/69 [00:25<00:00,  2.66it/s]
100%|██████████| 69/69 [01:01<00:00,  1.12it/s]


Saving neibor_inds
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  4.11it/s]
Tracing: 100%|██████████| 69/69 [00:28<00:00,  2.38it/s]
Loop ref knn: 100%|██████████| 69/69 [00:02<00:00, 23.14it/s]
Loop KNN GD: 1000it [01:25, 11.64it/s]


outputs/imdb/flip0.2_bert/124_2023-01-03_19-26-09/checkpoints/epoch=01_val_acc=0.8553.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Compute neibor_inds


100%|██████████| 69/69 [00:26<00:00,  2.65it/s]
100%|██████████| 69/69 [01:15<00:00,  1.10s/it]


Saving neibor_inds
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  4.54it/s]
Tracing: 100%|██████████| 69/69 [00:28<00:00,  2.40it/s]
Loop ref knn: 100%|██████████| 69/69 [00:02<00:00, 23.59it/s]
Loop KNN GD: 1000it [01:26, 11.58it/s]


outputs/imdb/flip0.2_bert/125_2023-01-03_19-26-09/checkpoints/epoch=00_val_acc=0.8535.ckpt
{'token_type_ids': torch.Size([256, 128]), 'label': torch.Size([256]), 'input_ids': torch.Size([256, 128]), 'attention_mask': torch.Size([256, 128])}
[torch.Size([256, 128]), torch.Size([256]), torch.Size([256, 128]), torch.Size([256, 128])]
Compute neibor_inds


100%|██████████| 69/69 [00:25<00:00,  2.66it/s]
100%|██████████| 69/69 [00:59<00:00,  1.17it/s]


Saving neibor_inds
Number of element in each class: [100, 100]


Collect Ref Grads: 1it [00:00,  4.29it/s]
Tracing: 100%|██████████| 69/69 [00:29<00:00,  2.38it/s]
Loop ref knn: 100%|██████████| 69/69 [00:02<00:00, 23.01it/s]
Loop KNN GD: 1000it [01:27, 11.43it/s]
