# Techer model probing: Performance after remove some teacher layer

In [1]:
import torch
import fine_tune
from tqdm import tqdm

In [2]:
TASK = 'qnli'
TEACHER_EXP = 'teacher_base'
TMODEL = 'bert'
TCKPT = 8000
DATASET = 'dev'
HIDDEN_LAYERS = 6
SDEVICE = 1
TDEVICE = 1
BATCH_SIZE = 32
LAYER_MAPPING = 'odd'

if TASK == 'qnli':
    NUM_CLASS = 2
elif TASK == 'mnli':
    NUM_CLASS = 3
else:
    raise ValueError(f"Unsupported task {task}")


In [3]:
teacher_config = fine_tune.config.TeacherConfig.load(
    experiment=TEACHER_EXP,
    model=TMODEL,
    task=TASK
)
teacher_config.device_id = TDEVICE
teacher_config.dataset = DATASET
print(teacher_config)

FileNotFoundError: [Errno 2] No such file or directory: '/home/kychen/Desktop/BERT-gang/data/fine_tune_experiment/teacher_base_bert_mnli/config.json'

In [4]:
fine_tune.util.set_seed_by_config(teacher_config)

In [5]:
dataset = fine_tune.util.load_dataset_by_config(
    config=teacher_config
)

2021/06/01 17:29:12 - INFO - fine_tune.task -   Start loading task QNLI dataset dev.
Loading QNLI dev: 5463it [00:00, 411247.60it/s]
2021/06/01 17:29:12 - INFO - fine_tune.task -   Number of samples: 5463
2021/06/01 17:29:12 - INFO - fine_tune.task -   Finish loading task QNLI dataset dev.


In [6]:
student_config = fine_tune.config.StudentConfig(
    accum_step = 1,
    batch_size = BATCH_SIZE,
    d_emb = 128,
    d_ff = 3072,
    d_model = 768,
    dataset = teacher_config.dataset,
    experiment = 'QNLI_TEACHER_PROBING',
    dropout = 0.1,
    eps = 1e-8,
    log_step = 100,
    lr = 3e-5,
    max_norm = 1.0,
    num_attention_heads = 12,
    num_hidden_layers = HIDDEN_LAYERS,
    seed = teacher_config.seed,
    max_seq_len = 128,
    model = 'bert',
    num_class = NUM_CLASS,
    task = TASK,
    total_step = 13096,
    type_vocab_size = 2,
    warmup_step = 1309,
    weight_decay = 0.01,
    device_id = SDEVICE
)

In [7]:
print(student_config)


+---------------------------------------------+
| configuration        | value                |
+---------------------------------------------+
| accum_step           | 1                    |
| amp                  | 0                    |
| batch_size           | 32                   |
| beta1                | 0.9                  |
| beta2                | 0.999                |
| ckpt_step            | 1000                 |
| d_emb                | 128                  |
| d_ff                 | 3072                 |
| d_model              | 768                  |
| dataset              | dev                  |
| device_id            | 1                    |
| dropout              | 0.1                  |
| eps                  | 1e-08                |
| experiment           | QNLI_TEACHER_PROBING |
| log_step             | 100                  |
| lr                   | 3e-05                |
| max_norm             | 1.0                  |
| max_seq_len          | 128           

In [8]:
teacher_tknr = fine_tune.util.load_teacher_tokenizer_by_config(
    config=teacher_config
)
student_tknr = fine_tune.util.load_student_tokenizer_by_config(
    config=student_config
)

In [9]:
TEACHER_EXP_NAME = fine_tune.config.BaseConfig.experiment_name(
    experiment=teacher_config.experiment,
    model=teacher_config.model,
    task=teacher_config.task
)
TEACHER_EXP_DIR = os.path.join(
    fine_tune.path.FINE_TUNE_EXPERIMENT,
    TEACHER_EXP_NAME
)
print(TEACHER_EXP_DIR)

/home/kychen/Desktop/BERT-gang/data/fine_tune_experiment/teacher_base_bert_qnli


In [10]:
teacher_model = fine_tune.util.load_teacher_model_by_config(
    config=teacher_config
)
teacher_model.load_state_dict(
    torch.load(
        os.path.join(TEACHER_EXP_DIR, f'model-{TCKPT}.pt'),
        map_location=teacher_config.device
    )
)

<All keys matched successfully>

In [11]:
student_model = fine_tune.util.load_student_model_by_config(
    config=student_config,
    tokenizer=student_tknr
)

# Load fine-tuned teacher weight

In [12]:
if LAYER_MAPPING == 'even':
    teacher_indices = list(range(1, 12, 2))
elif LAYER_MAPPING == 'odd':
    teacher_indices = list(range(0, 12, 2))
elif LAYER_MAPPING == 'user_defined':
    teacher_indices = [int(item)-1 for item in input("Enter desired teacher layer:\n").split()]
else:
    raise ValueError(f"Invalid mapping strategy: {LAYER_MAPPING}")

## Load Encoder weight

In [13]:
teacher_encoder_weight = teacher_model.encoder.state_dict()
new_state_dict = {}
keys = [
    'attention.self.query.weight',
    'attention.self.query.bias',
    'attention.self.key.weight',
    'attention.self.key.bias',
    'attention.self.value.weight',
    'attention.self.value.bias',
    'attention.output.dense.weight',
    'attention.output.dense.bias',
    'attention.output.LayerNorm.weight',
    'attention.output.LayerNorm.bias',
    'intermediate.dense.weight',
    'intermediate.dense.bias',
    'output.dense.weight',
    'output.dense.bias',
    'output.LayerNorm.weight',
    'output.LayerNorm.bias'
]

for i, t_index in enumerate(teacher_indices):
    for key in keys:
        new_state_dict.update(
            {
                f'encoder.layer.{i}.{key}':
                teacher_encoder_weight[f'encoder.layer.{t_index}.{key}']
            }
        )


In [14]:
new_state_dict.update(
    {
        'pooler.dense.weight':teacher_encoder_weight['pooler.dense.weight'],
        'pooler.dense.bias':teacher_encoder_weight['pooler.dense.bias']
    }
)

In [15]:
student_model.encoder.load_state_dict(
    new_state_dict,
    strict=False
)

_IncompatibleKeys(missing_keys=['embeddings.position_ids', 'embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias'], unexpected_keys=[])

## Load classification layer weight

In [16]:
teacher_linear_weight = teacher_model.linear_layer.state_dict()
student_model.linear_layer.load_state_dict(
    teacher_linear_weight
)

<All keys matched successfully>

# Run Evaluation

In [17]:
acc = fine_tune.util.evaluation(
    config=student_config,
    dataset=dataset,
    model=student_model,
    tokenizer=student_tknr
)

100%|██████████| 171/171 [00:11<00:00, 14.49it/s]


In [18]:
print(acc)

0.5053999633900788
