In [1]:
import os

import torch
import torch.nn
import torch.utils
import torch.utils.data
import fine_tune
from tqdm import tqdm
from transformers import (DistilBertForSequenceClassification,
                          DistilBertTokenizer)

In [2]:
# --- Shared settings ---------------------------------------------------
TASK = 'qnli'        # task name passed to config loading / experiment naming
DATASET = 'train'    # dataset split assigned to the teacher config
BATCH_SIZE = 1       # batch size setting

# --- Teacher -----------------------------------------------------------
TEACHER_EXP = 'teacher_base'
TMODEL = 'bert'
TCKPT = 9822         # step number in the teacher checkpoint filename
TDEVICE = 0          # device id assigned to the teacher config

# --- Student -----------------------------------------------------------
STUDENT_EXP = 'distill_bert_base'
SMODEL = 'bert'      # model tag used when building the student experiment name
SCKPT = 9000         # step number in the student checkpoint filename
SDEVICE = 0          # CUDA device index used for the student model

In [3]:
# Load the teacher configuration for this experiment/model/task via
# TeacherConfig.load, then override the device id and dataset split with the
# values chosen in the config cell before printing it for the record.
teacher_config = fine_tune.config.TeacherConfig.load(
    task=TASK,
    model=TMODEL,
    experiment=TEACHER_EXP,
)
teacher_config.device_id = TDEVICE
teacher_config.dataset = DATASET
print(teacher_config)


+---------------------------------------+
| configuration     | value             |
+---------------------------------------+
| accum_step        | 2                 |
| amp               | 0                 |
| batch_size        | 32                |
| beta1             | 0.9               |
| beta2             | 0.999             |
| ckpt_step         | 1000              |
| dataset           | train             |
| device_id         | 0                 |
| dropout           | 0.1               |
| eps               | 1e-08             |
| experiment        | teacher_base      |
| log_step          | 500               |
| lr                | 3e-05             |
| max_norm          | 1.0               |
| max_seq_len       | 128               |
| model             | bert              |
| num_class         | 2                 |
| ptrain_ver        | bert-base-uncased |
| seed              | 42                |
| task              | qnli              |
| total_step        | 9822       

In [4]:
# Load the dataset described by the teacher configuration (task + split).
dataset = fine_tune.util.load_dataset_by_config(config=teacher_config)

2021/05/05 17:01:08 - INFO - fine_tune.task -   Start loading task QNLI dataset train.
Loading QNLI train: 104743it [00:00, 462141.77it/s]
2021/05/05 17:01:08 - INFO - fine_tune.task -   Number of samples: 104743
2021/05/05 17:01:08 - INFO - fine_tune.task -   Finish loading task QNLI dataset train.


In [5]:
# Tokenizer matching the teacher model, resolved from the teacher configuration.
teacher_tknr = fine_tune.util.load_teacher_tokenizer_by_config(config=teacher_config)

In [6]:
# Build the teacher experiment name from the loaded config, then resolve its
# directory under the project's fine-tune experiment root.
TEACHER_EXP_NAME = fine_tune.config.BaseConfig.experiment_name(
    task=teacher_config.task,
    model=teacher_config.model,
    experiment=teacher_config.experiment,
)
TEACHER_EXP_DIR = os.path.join(fine_tune.path.FINE_TUNE_EXPERIMENT, TEACHER_EXP_NAME)
print(TEACHER_EXP_DIR)

/home/kychen/Desktop/BERT-gang/data/fine_tune_experiment/teacher_base_bert_qnli


In [7]:
# Instantiate the teacher from its configuration, then restore the weights
# saved at step TCKPT. The state dict is loaded inline (not bound to a name)
# so it is not kept alive after being copied into the model.
teacher_model = fine_tune.util.load_teacher_model_by_config(config=teacher_config)

teacher_ckpt_path = os.path.join(TEACHER_EXP_DIR, f'model-{TCKPT}.pt')
teacher_model.load_state_dict(
    torch.load(teacher_ckpt_path, map_location=teacher_config.device)
)

<All keys matched successfully>

In [8]:
# Tokenizer for the student model, taken from the pretrained DistilBERT release.
student_tknr = DistilBertTokenizer.from_pretrained(
    'distilbert-base-uncased'
)

In [9]:
# Build the student experiment name from the config-cell constants, then
# resolve its directory under the project's fine-tune experiment root.
STUDENT_EXP_NAME = fine_tune.config.BaseConfig.experiment_name(
    task=TASK,
    model=SMODEL,
    experiment=STUDENT_EXP,
)
STUDENT_EXP_DIR = os.path.join(fine_tune.path.FINE_TUNE_EXPERIMENT, STUDENT_EXP_NAME)
print(STUDENT_EXP_DIR)

/home/kychen/Desktop/BERT-gang/data/fine_tune_experiment/distill_bert_base_bert_qnli


In [10]:
# Start from the pretrained DistilBERT classifier, move it to the student
# device, then overwrite its weights with the checkpoint saved at step SCKPT.
student_model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased', return_dict=True
)
student_model = student_model.to(SDEVICE)

student_ckpt_path = os.path.join(STUDENT_EXP_DIR, f'model-{SCKPT}.pt')
student_model.load_state_dict(
    torch.load(student_ckpt_path, map_location=f'cuda:{SDEVICE}')
)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classi

<All keys matched successfully>

In [11]:
# Iterate the dataset in batches using the dataset's own collate function.
# The batch size comes from the config cell instead of a hard-coded literal.
# NOTE(review): the comparison loop further down uses the similarity result
# directly in an `if`, which looks like it expects one sample per batch --
# confirm before raising BATCH_SIZE above 1.
dataloader = torch.utils.data.DataLoader(
    dataset,
    collate_fn=dataset.create_collate_fn(),
    batch_size=BATCH_SIZE
)

In [12]:
# Cosine similarity module; dim=1 is the library default, spelled out here so
# it is clear the similarity is taken along the second axis of its inputs.
cosine_sim = torch.nn.CosineSimilarity(dim=1)

In [13]:
# Accumulator for samples whose teacher/student similarity falls below the
# threshold; filled by the comparison loop.
not_similar_list = list()

In [14]:
# Compare, sample by sample, the teacher's and the student's last hidden
# state at sequence position 0 (named *_cls in this notebook) and record every
# sample whose cosine similarity is below SIM_THRESHOLD.
#
# Appends to `not_similar_list`, so re-running this cell without re-running
# the cell that initialises the list accumulates duplicates.
# NOTE(review): the recorded run flagged almost every sample -- confirm the
# teacher and student hidden states are meant to be directly comparable.
SIM_THRESHOLD = 0.5
student_device = f'cuda:{SDEVICE}'

student_model.eval()
teacher_model.eval()

# Pure inference: eval() only switches layer behaviour (e.g. dropout); it
# does not stop autograd from recording the forward passes. no_grad() avoids
# building a graph for every sample.
with torch.no_grad():
    for text, text_pair, label in tqdm(dataloader):
        # Tokenize the same sentence pair once per model.
        teacher_encode = teacher_tknr(
            text=text,
            text_pair=text_pair,
            padding='max_length',
            max_length=teacher_config.max_seq_len,
            return_tensors='pt',
            truncation=True
        )
        teacher_input_ids = teacher_encode['input_ids']
        teacher_token_type_ids = teacher_encode['token_type_ids']
        teacher_attention_mask = teacher_encode['attention_mask']

        student_encode = student_tknr(
            text=text,
            text_pair=text_pair,
            padding='max_length',
            max_length=128,
            return_tensors='pt',
            truncation=True
        )
        student_input_ids = student_encode['input_ids']
        student_attention_mask = student_encode['attention_mask']

        # The teacher returns three values; only the second is used here.
        _, t_cls, _ = teacher_model(
            input_ids = teacher_input_ids.to(teacher_config.device),
            token_type_ids = teacher_token_type_ids.to(teacher_config.device),
            attention_mask = teacher_attention_mask.to(teacher_config.device),
            return_hidden_and_attn = True
        )
        # Last entry, position 0 of every sequence.
        t_cls = t_cls[-1][:,0,:]

        output = student_model(
            input_ids = student_input_ids.to(student_device),
            attention_mask = student_attention_mask.to(student_device),
            output_hidden_states=True
        )

        s_cls = output.hidden_states[-1][:,0,:]

        # Move the teacher vector to the student's device before comparing.
        # Using the result directly in `if` requires a single-element tensor.
        if cosine_sim(t_cls.to(student_device), s_cls) < SIM_THRESHOLD:
            not_similar_list.append({'text':text, 'text_pair':text_pair, 'label': label})

100%|██████████| 104743/104743 [29:55<00:00, 58.33it/s]


In [15]:
# Number of samples recorded as dissimilar by the comparison loop.
len(not_similar_list)

104517

In [16]:
# Scale the dataset size by the gap between two fractions.
# NOTE(review): the source of 0.989164 and 0.963597 is not recorded in this
# notebook -- presumably two measured rates; confirm and name them.
fraction_gap = 0.989164 - 0.963597
len(dataset) * fraction_gap

2677.9642810000005