# Teacher model probing: performance after removing some teacher layers

In [None]:
import os

import torch
from tqdm import tqdm

import fine_tune

In [None]:
# Notebook-wide settings for the teacher-probing run.
TASK = 'qnli'                  # task name; also selects NUM_CLASS below
TEACHER_EXP = 'teacher_base'   # teacher experiment name passed to TeacherConfig.load
TMODEL = 'bert'                # teacher model name passed to TeacherConfig.load
TCKPT = 8000                   # checkpoint number used in the `model-{TCKPT}.pt` filename
DATASET = 'dev'                # value assigned to teacher_config.dataset
HIDDEN_LAYERS = 6              # number of hidden layers of the student model
SDEVICE = 1                    # device id given to the student config
TDEVICE = 1                    # device id given to the teacher config
BATCH_SIZE = 32                # batch size given to the student config
LAYER_MAPPING = 'odd'          # one of 'even', 'odd', 'user_defined'

# Number of output classes for the selected task.
if TASK == 'qnli':
    NUM_CLASS = 2
elif TASK == 'mnli':
    NUM_CLASS = 3
else:
    # Report the offending constant itself; the previous message referenced an
    # undefined lowercase name and raised NameError instead of ValueError.
    raise ValueError(f"Unsupported task {TASK}")


In [None]:
# Load the teacher configuration identified by (TEACHER_EXP, TMODEL, TASK),
# then override its device and dataset with this notebook's settings.
teacher_config = fine_tune.config.TeacherConfig.load(
    task=TASK,
    model=TMODEL,
    experiment=TEACHER_EXP,
)
teacher_config.device_id = TDEVICE
teacher_config.dataset = DATASET

# Show the effective configuration for the record.
print(teacher_config)

In [None]:
# Apply the seed settings carried by teacher_config through the project helper
# (presumably seeds the random number generators for reproducibility —
# confirm in fine_tune.util).
fine_tune.util.set_seed_by_config(teacher_config)

In [None]:
# Load the dataset described by the teacher config (its `dataset` field was
# set to DATASET earlier in the notebook).
dataset = fine_tune.util.load_dataset_by_config(config=teacher_config)

In [None]:
# Configuration of the reduced ("student") model that will receive a subset of
# the teacher's layers. Dataset and seed are taken from the teacher config.
# NOTE(review): `experiment` and `model` are hard-coded here even though TASK
# and TMODEL are configurable — confirm this is intended for non-QNLI runs.
student_config = fine_tune.config.StudentConfig(
    # Identification.
    experiment='QNLI_TEACHER_PROBING',
    model='bert',
    task=TASK,
    dataset=teacher_config.dataset,
    seed=teacher_config.seed,
    device_id=SDEVICE,
    # Architecture.
    d_emb=128,
    d_ff=3072,
    d_model=768,
    num_attention_heads=12,
    num_hidden_layers=HIDDEN_LAYERS,
    num_class=NUM_CLASS,
    type_vocab_size=2,
    max_seq_len=128,
    dropout=0.1,
    # Optimisation / logging settings.
    accum_step=1,
    batch_size=BATCH_SIZE,
    eps=1e-8,
    log_step=100,
    lr=3e-5,
    max_norm=1.0,
    total_step=13096,
    warmup_step=1309,
    weight_decay=0.01,
)

In [None]:
# Show the student configuration so the run's settings are visible in the
# notebook output.
print(student_config)

In [None]:
# Build one tokenizer per model from the corresponding configuration.
teacher_tknr = fine_tune.util.load_teacher_tokenizer_by_config(config=teacher_config)
student_tknr = fine_tune.util.load_student_tokenizer_by_config(config=student_config)

In [None]:
# Resolve the directory of the teacher experiment:
#   <fine_tune.path.FINE_TUNE_EXPERIMENT>/<experiment name>
# The experiment name is derived from the teacher config's experiment, model
# and task fields by the project helper.
# Requires `os`, which is imported in the notebook's import cell.
TEACHER_EXP_NAME = fine_tune.config.BaseConfig.experiment_name(
    experiment=teacher_config.experiment,
    model=teacher_config.model,
    task=teacher_config.task,
)
TEACHER_EXP_DIR = os.path.join(
    fine_tune.path.FINE_TUNE_EXPERIMENT,
    TEACHER_EXP_NAME,
)

# Print the resolved path so a wrong experiment directory is easy to spot.
print(TEACHER_EXP_DIR)

In [None]:
# Instantiate the teacher model from its configuration.
teacher_model = fine_tune.util.load_teacher_model_by_config(config=teacher_config)

# Restore the fine-tuned weights from checkpoint TCKPT inside the teacher
# experiment directory, mapping tensors onto the teacher's device.
teacher_ckpt_path = os.path.join(TEACHER_EXP_DIR, f'model-{TCKPT}.pt')
teacher_model.load_state_dict(
    torch.load(teacher_ckpt_path, map_location=teacher_config.device)
)

In [None]:
# Instantiate the student model from its configuration and tokenizer.
student_model = fine_tune.util.load_student_model_by_config(
    tokenizer=student_tknr,
    config=student_config,
)

# Load fine-tuned teacher weights

In [None]:
# Choose which teacher layers are copied into the student, as 0-based indices
# into the teacher's `encoder.layer.<n>` entries.
#   'odd'          -> indices 0, 2, ..., 10
#   'even'         -> indices 1, 3, ..., 11
#   'user_defined' -> whitespace-separated numbers read from stdin, each
#                     decremented by one (i.e. the user enters 1-based layers)
if LAYER_MAPPING == 'user_defined':
    teacher_indices = [
        int(token) - 1
        for token in input("Enter desired teacher layer:\n").split()
    ]
elif LAYER_MAPPING == 'odd':
    teacher_indices = list(range(0, 12, 2))
elif LAYER_MAPPING == 'even':
    teacher_indices = list(range(1, 12, 2))
else:
    raise ValueError(f"Invalid mapping strategy: {LAYER_MAPPING}")

## Load encoder weights

In [None]:
# Snapshot of the teacher encoder's parameters, keyed by parameter name.
teacher_encoder_weight = teacher_model.encoder.state_dict()

# Per-layer parameter names (relative to `encoder.layer.<n>.`) to copy.
keys = [
    'attention.self.query.weight',
    'attention.self.query.bias',
    'attention.self.key.weight',
    'attention.self.key.bias',
    'attention.self.value.weight',
    'attention.self.value.bias',
    'attention.output.dense.weight',
    'attention.output.dense.bias',
    'attention.output.LayerNorm.weight',
    'attention.output.LayerNorm.bias',
    'intermediate.dense.weight',
    'intermediate.dense.bias',
    'output.dense.weight',
    'output.dense.bias',
    'output.LayerNorm.weight',
    'output.LayerNorm.bias'
]

# Re-key the selected teacher layers so that the k-th selected teacher layer
# becomes student layer k (positions follow the order of teacher_indices).
new_state_dict = {
    f'encoder.layer.{student_idx}.{param_name}':
        teacher_encoder_weight[f'encoder.layer.{teacher_idx}.{param_name}']
    for student_idx, teacher_idx in enumerate(teacher_indices)
    for param_name in keys
}


In [None]:
# Carry the teacher's pooler parameters over unchanged (same key names).
new_state_dict['pooler.dense.weight'] = teacher_encoder_weight['pooler.dense.weight']
new_state_dict['pooler.dense.bias'] = teacher_encoder_weight['pooler.dense.bias']

In [None]:
# Copy the remapped teacher parameters into the student encoder.
# strict=False: parameters of the student encoder that have no entry in
# new_state_dict are left as they are instead of raising. The call's result
# (the missing/unexpected key report) is the cell's output.
student_model.encoder.load_state_dict(new_state_dict, strict=False)

## Load classification layer weights

In [None]:
# Reuse the teacher's classification head as-is: load its full state dict into
# the student's `linear_layer` (default strict loading, so keys must match).
teacher_linear_weight = teacher_model.linear_layer.state_dict()
student_model.linear_layer.load_state_dict(teacher_linear_weight)

# Run Evaluation

In [None]:
# Evaluate the student (now holding the selected teacher layers) on the loaded
# dataset, using the student's configuration and tokenizer.
acc = fine_tune.util.evaluation(
    model=student_model,
    tokenizer=student_tknr,
    dataset=dataset,
    config=student_config,
)

In [None]:
# Report the value returned by the evaluation helper.
print(acc)