In [1]:
import csv
import os

import torch
import torch.nn
import torch.utils
import torch.utils.data
from tqdm import tqdm

import fine_tune

In [2]:
# --- Task / data split ------------------------------------------------------
TASK = 'qnli'
DATASET = 'train'
BATCH_SIZE = 1

# --- Teacher: experiment, model type, checkpoint number (model-<N>.pt), GPU --
TEACHER_EXP = 'teacher_base'
TMODEL = 'bert'
TCKPT = 9822
TDEVICE = 0

# --- Student: experiment, model type, checkpoint number (model-<N>.pt), GPU --
# NOTE(review): the saved outputs further down (printed student config and
# STUDENT_EXP_DIR) were produced with experiment 'Contrast_by_sample', not the
# value below -- Restart & Run All so the outputs match this configuration.
STUDENT_EXP = 'MSE_init_from_pre_trained'
SMODEL = 'bert'
SCKPT = 13096
SDEVICE = 0


In [3]:
# Restore the teacher's saved hyper-parameters, then redirect them to the
# device and dataset split this notebook evaluates on.
teacher_config = fine_tune.config.TeacherConfig.load(
    task=TASK,
    model=TMODEL,
    experiment=TEACHER_EXP,
)
teacher_config.device_id = TDEVICE
teacher_config.dataset = DATASET
print(teacher_config)


+---------------------------------------+
| configuration     | value             |
+---------------------------------------+
| accum_step        | 2                 |
| amp               | 0                 |
| batch_size        | 32                |
| beta1             | 0.9               |
| beta2             | 0.999             |
| ckpt_step         | 1000              |
| dataset           | train             |
| device_id         | 0                 |
| dropout           | 0.1               |
| eps               | 1e-08             |
| experiment        | teacher_base      |
| log_step          | 500               |
| lr                | 3e-05             |
| max_norm          | 1.0               |
| max_seq_len       | 128               |
| model             | bert              |
| num_class         | 2                 |
| ptrain_ver        | bert-base-uncased |
| seed              | 42                |
| task              | qnli              |
| total_step        | 9822       

In [4]:
# Restore the student's saved hyper-parameters, then redirect them to the
# device and dataset split this notebook evaluates on.
student_config = fine_tune.config.StudentConfig.load(
    task=TASK,
    model=SMODEL,
    experiment=STUDENT_EXP,
)
student_config.device_id = SDEVICE
student_config.dataset = DATASET
print(student_config)


+-------------------------------------------+
| configuration       | value               |
+-------------------------------------------+
| accum_step          | 2                   |
| amp                 | 0                   |
| batch_size          | 32                  |
| beta1               | 0.9                 |
| beta2               | 0.999               |
| ckpt_step           | 2000                |
| d_emb               | 128                 |
| d_ff                | 3072                |
| d_model             | 768                 |
| dataset             | train               |
| device_id           | 0                   |
| dropout             | 0.1                 |
| eps                 | 1e-08               |
| experiment          | Contrast_by_sample  |
| log_step            | 100                 |
| lr                  | 3e-05               |
| max_norm            | 1.0                 |
| max_seq_len         | 128                 |
| model               | bert     

# Use the training dataset to find samples whose teacher and student CLS embeddings are dissimilar

In [5]:
# Both models are evaluated on the same split, selected via the teacher's config.
dataset = fine_tune.util.load_dataset_by_config(config=teacher_config)

2021/05/07 15:32:15 - INFO - fine_tune.task -   Start loading task QNLI dataset train.
Loading QNLI train: 104743it [00:00, 472977.26it/s]
2021/05/07 15:32:15 - INFO - fine_tune.task -   Number of samples: 104743
2021/05/07 15:32:15 - INFO - fine_tune.task -   Finish loading task QNLI dataset train.


In [6]:
# Each model gets the tokenizer that matches its own configuration.
teacher_tknr = fine_tune.util.load_teacher_tokenizer_by_config(config=teacher_config)
student_tknr = fine_tune.util.load_student_tokenizer_by_config(config=student_config)

In [7]:
# Directory that holds the teacher's fine-tuning checkpoints: the project's
# experiment root joined with a name derived from experiment / model / task.
TEACHER_EXP_NAME = fine_tune.config.BaseConfig.experiment_name(
    experiment=teacher_config.experiment,
    model=teacher_config.model,
    task=teacher_config.task,
)
TEACHER_EXP_DIR = os.path.join(fine_tune.path.FINE_TUNE_EXPERIMENT, TEACHER_EXP_NAME)
print(TEACHER_EXP_DIR)

/home/kychen/Desktop/BERT-gang/data/fine_tune_experiment/teacher_base_bert_qnli


In [8]:

# Build the teacher architecture, then overwrite its weights with checkpoint
# TCKPT from the teacher experiment directory. The return value of
# load_state_dict is left as the cell's last expression so it is displayed.
teacher_model = fine_tune.util.load_teacher_model_by_config(config=teacher_config)
teacher_ckpt_path = os.path.join(TEACHER_EXP_DIR, f'model-{TCKPT}.pt')
teacher_model.load_state_dict(
    torch.load(teacher_ckpt_path, map_location=teacher_config.device)
)

<All keys matched successfully>

In [9]:
# Directory that holds the student's fine-tuning checkpoints: the project's
# experiment root joined with a name derived from experiment / model / task.
STUDENT_EXP_NAME = fine_tune.config.BaseConfig.experiment_name(
    experiment=student_config.experiment,
    model=student_config.model,
    task=student_config.task,
)
STUDENT_EXP_DIR = os.path.join(fine_tune.path.FINE_TUNE_EXPERIMENT, STUDENT_EXP_NAME)
print(STUDENT_EXP_DIR)

/home/kychen/Desktop/BERT-gang/data/fine_tune_experiment/Contrast_by_sample_bert_qnli


In [10]:
# Build the student architecture (it needs the student tokenizer), then
# overwrite its weights with checkpoint SCKPT from the student experiment
# directory. The return value of load_state_dict is displayed as cell output.
student_model = fine_tune.util.load_student_model_by_config(
    config=student_config,
    tokenizer=student_tknr,
)
student_ckpt_path = os.path.join(STUDENT_EXP_DIR, f'model-{SCKPT}.pt')
student_model.load_state_dict(
    torch.load(student_ckpt_path, map_location=student_config.device)
)

2021/05/07 15:32:25 - INFO - fine_tune.model.student_bert -   Load model state dict from pre-trained model
2021/05/07 15:32:25 - INFO - fine_tune.model.student_bert -   Finish initialization from pre-trained model


<All keys matched successfully>

In [11]:
# Walk the dataset in its stored order (no shuffling), BATCH_SIZE samples at a
# time, batching with the dataset's own collate function.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    collate_fn=dataset.create_collate_fn(),
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [12]:
# Cosine similarity along dim 1 (one score per row); these are PyTorch's defaults.
cosine_sim = torch.nn.CosineSimilarity(dim=1, eps=1e-8)

In [13]:
# Accumulates one dict per sample whose teacher/student CLS embeddings disagree.
not_similar_list = list()

In [14]:
# Compare the teacher's and the student's CLS embeddings on every sample and
# record the samples whose cosine similarity falls below SIM_THRESHOLD.
SIM_THRESHOLD = 0.5

# Re-initialized here so that re-running this cell does not append duplicates
# on top of the results of a previous run.
not_similar_list = []


def _encode(tokenizer, text, text_pair, max_seq_len, device):
    """Tokenize a batch of sentence pairs and move the model inputs to `device`.

    Returns a dict with `input_ids`, `token_type_ids` and `attention_mask`,
    padded / truncated to `max_seq_len`, ready to be passed as `model(**inputs)`.
    """
    encoded = tokenizer(
        text=text,
        text_pair=text_pair,
        padding='max_length',
        max_length=max_seq_len,
        return_tensors='pt',
        truncation=True
    )
    return {
        key: encoded[key].to(device)
        for key in ('input_ids', 'token_type_ids', 'attention_mask')
    }


def _to_python(value):
    """Convert a 0-d tensor to a plain Python scalar (so it serializes cleanly); pass anything else through."""
    return value.item() if isinstance(value, torch.Tensor) else value


student_model.eval()
teacher_model.eval()
# Inference only: no_grad avoids building an autograd graph for both forward passes.
with torch.no_grad():
    for text, text_pair, label in tqdm(dataloader):
        teacher_inputs = _encode(
            teacher_tknr, text, text_pair,
            teacher_config.max_seq_len, teacher_config.device
        )
        student_inputs = _encode(
            student_tknr, text, text_pair,
            student_config.max_seq_len, student_config.device
        )

        _, t_cls = teacher_model(**teacher_inputs)
        _, s_cls = student_model(**student_inputs)

        # One similarity per sample; checking them element-wise keeps this
        # correct for BATCH_SIZE > 1 (a multi-element tensor in `if` raises).
        sims = cosine_sim(t_cls.to(student_config.device), s_cls)
        for i, is_dissimilar in enumerate((sims < SIM_THRESHOLD).tolist()):
            if is_dissimilar:
                # NOTE(review): assumes the collate_fn yields per-sample
                # sequences (lists / tensors) of length batch size -- confirm.
                not_similar_list.append({
                    'text': text[i],
                    'text_pair': text_pair[i],
                    'label': _to_python(label[i]),
                })

100%|██████████| 104743/104743 [24:03<00:00, 72.58it/s]


In [15]:
# Number of recorded samples whose teacher/student CLS cosine similarity fell below the comparison threshold.
len(not_similar_list)

59225

In [16]:
# Dataset size times the gap between two hard-coded scores.
# NOTE(review): these look like the teacher's and the student's accuracy on
# this split, which would make the result a rough estimate of how many more
# samples the teacher gets right than the student -- confirm their source.
TEACHER_SCORE = 0.989164
STUDENT_SCORE = 0.958021
len(dataset) * (TEACHER_SCORE - STUDENT_SCORE)

3262.0112490000033

In [17]:
# Dump the dissimilar samples to CSV for inspection outside the notebook.
# - Fallback field names keep an empty result from raising IndexError on
#   `not_similar_list[0]` (a header-only file is written instead).
# - The encoding is explicit so the output does not depend on the platform default.
# - The task name comes from TASK instead of a hard-coded 'qnli'.
csv_fieldnames = (
    list(not_similar_list[0].keys())
    if not_similar_list
    else ['text', 'text_pair', 'label']
)
csv_path = f'not_similar_{TASK}_{STUDENT_EXP}_{DATASET}.csv'
with open(csv_path, 'w', newline='', encoding='utf-8') as file:
    fc = csv.DictWriter(file, fieldnames=csv_fieldnames)
    fc.writeheader()
    fc.writerows(not_similar_list)
    