In [1]:
%load_ext autoreload
%autoreload 2

In [126]:
from datasets import load_dataset, load_metric
import sys
import os
import numpy as np
import datasets
import random
import pandas as pd
from IPython.display import display, HTML
from multiprocessing import  Pool
import torch
import torch.nn as nn
from torch.utils.data import BatchSampler, DataLoader, RandomSampler

from transformers import (
    BertConfig,
    BertModel,
    BertForMaskedLM,
    BertTokenizer,
    BertForTokenClassification,
    BertForSequenceClassification
)
from transformers import (
    DistilBertConfig,
    DistilBertModel,
    DistilBertTokenizer,
    DistilBertForTokenClassification,
    DistilBertForSequenceClassification
)




from lm_seqs_dataset import LmSeqsDataset

os.environ["http_proxy"] = "http://127.0.0.1:7890"
os.environ["https_proxy"] = "http://127.0.0.1:7890"

## define teacher and student Class

In [75]:
teacher_class = BertForSequenceClassification
student_class = DistilBertForSequenceClassification


In [3]:
task = "sst2"
model_checkpoint = "distilbert-base"
batch_size = 16

In [4]:
actual_task = "mnli" if task == "mnli-mm" else task
dataset = load_dataset("glue", actual_task)
metric = load_metric('glue', actual_task)



  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
#dataset = DatasetDict()
dataset

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

In [6]:
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
    return df

In [7]:
df = show_random_elements(dataset["train"])

Unnamed: 0,sentence,label,idx
0,has some unnecessary parts and,negative,59298
1,"curiously , super troopers suffers because it does n't have enough vices to merit its 103-minute length .",negative,51223
2,spare dialogue and,positive,15292
3,effecting change,positive,62075
4,beautifully filmed and well acted ... but admittedly problematic in its narrative specifics .,positive,24161
5,"i 'm not sure which is worse : the poor acting by the ensemble cast , the flat dialogue by vincent r. nebrida or the gutless direction by laurice guillen .",negative,29612
6,-- that you should never forget,positive,1168
7,are canny and spiced with irony,positive,15001
8,what 's invigorating about,positive,32552
9,the power of the huston performance,positive,54947


## data prepocessing

In [8]:
tokenizer_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
bos = tokenizer.special_tokens_map["cls_token"]  # `[CLS]`
sep = tokenizer.special_tokens_map["sep_token"]  # `[SEP]`

In [9]:


def parallelize_dataframe(df, func, n_cores=4):
    df_split = np.array_split(df, n_cores)
    pool = Pool(n_cores)
    df = pd.concat(pool.map(func, df_split))
    pool.close()
    pool.join()
    return df

In [10]:
def tokenize_ds_p(df):
    df['token_ids'] = df['sentence'].apply(
        lambda x:np.array(tokenizer.encode(x, add_special_tokens=False))
    )
    return df

## X_train, y_train

In [11]:
df = parallelize_dataframe(pd.DataFrame(dataset["train"]),
                            tokenize_ds_p, 
                            n_cores=40
                           )
print(df.head())
X_train = df[["token_ids"]]
y_train = df.label

                                            sentence  label  idx  \
0       hide new secretions from the parental units       0    0   
1               contains no wit , only labored gags       0    1   
2  that loves its characters and communicates som...      1    2   
3  remains utterly satisfied to remain the same t...      0    3   
4  on the worst revenge-of-the-nerds clichés the ...      0    4   

                                           token_ids  
0  [5342, 2047, 3595, 8496, 2013, 1996, 18643, 3197]  
1  [3397, 2053, 15966, 1010, 2069, 4450, 2098, 18...  
2  [2008, 7459, 2049, 3494, 1998, 10639, 2015, 22...  
3  [3464, 12580, 8510, 2000, 3961, 1996, 2168, 2802]  
4  [2006, 1996, 5409, 7195, 1011, 1997, 1011, 199...  


## X_test, y_test

In [12]:
df = parallelize_dataframe(pd.DataFrame(dataset["validation"]),
                            tokenize_ds_p, 
                            n_cores=40
                           )
print(df.head())
X_test = df[["token_ids"]]
y_test = df.label

                                            sentence  label  idx  \
0    it 's a charming and often affecting journey .       1    0   
1                 unflinchingly bleak and desperate       0    1   
2  allows us to hope that nolan is poised to emba...      1    2   
3  the acting , costumes , music , cinematography...      1    3   
4                  it 's slow -- very , very slow .       0    4   

                                           token_ids  
0  [2009, 1005, 1055, 1037, 11951, 1998, 2411, 12...  
1  [4895, 10258, 2378, 8450, 2135, 21657, 1998, 7...  
2  [4473, 2149, 2000, 3246, 2008, 13401, 2003, 22...  
3  [1996, 3772, 1010, 12703, 1010, 2189, 1010, 16...  
4  [2009, 1005, 1055, 4030, 1011, 1011, 2200, 101...  


## Teacher (from pre-trained)

In [76]:

teacher_name = "bert-base-uncased"
teacher_type = "bert"
student_type = "distilbert"
teacher = teacher_class.from_pretrained(teacher_name, output_hidden_states=True)

# "bert": (BertConfig, BertForMaskedLM, BertTokenizer),

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

## Student (naive )

In [100]:

student_config = "conf/distilbert-base-uncased.json"
stu_architecture_config = DistilBertConfig.from_pretrained(student_config)
print(stu_architecture_config)
student_pretrained_weights = None
student = student_class(stu_architecture_config)


DistilBertConfig {
  "activation": "gelu",
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": true,
  "tie_weights_": true,
  "transformers_version": "4.12.5",
  "vocab_size": 30522
}



In [68]:
assert student.config.vocab_size == teacher.config.vocab_size
assert student.config.hidden_size == teacher.config.hidden_size
assert student.config.max_position_embeddings == teacher.config.max_position_embeddings

student_config = student.config
vocab_size = student.config.vocab_size

## Train distiller

In [16]:
X_train["token_ids"].values

array([array([ 5342,  2047,  3595,  8496,  2013,  1996, 18643,  3197]),
       array([ 3397,  2053, 15966,  1010,  2069,  4450,  2098, 18201,  2015]),
       array([ 2008,  7459,  2049,  3494,  1998, 10639,  2015,  2242,  2738,
        3376,  2055,  2529,  3267]),
       ...,
       array([ 2012, 10910,  1996, 10754,  1010,  4306,  1011, 24820,  3289,
        2009,  4520,  2005,  2993]),
       array([ 1037,  5776, 13972]),
       array([ 2023,  2047, 23769,  2571,  1997,  5005,  1010, 26865,  1998,
       28072,  2442,  2022,  1037,  3809, 20127,  2005,  1996,  2516,
        1012])], dtype=object)

In [17]:
dat = X_train["token_ids"].values
print(len(dat))
indices = np.array([len(x) for x in dat]) > 4
dat = dat[indices]

indices = np.array([len(x) for x in dat]) < 50
dat = dat[indices]
print(len(dat))

67349
48452


In [54]:

from lm_seqs_dataset import LmSeqsDataset
dataset = LmSeqsDataset(X_train["token_ids"].values,
                        y_train.values,
                        max_model_input_size=50,
                        min_model_input_size=3
                       )

sampler = RandomSampler(dataset)

dataloader = DataLoader(dataset=dataset,
                        batch_size=3,
                        # batch_sampler=sampler,
                        collate_fn=dataset.batch_sequences
                       )


12/03/2021 23:21:29 - INFO - utils - PID: 12594 -  Remove 147 too long (>50 tokens) sequences.
12/03/2021 23:21:29 - INFO - utils - PID: 12594 -  Remove 13365 too short (<=3 tokens) sequences.
12/03/2021 23:21:29 - INFO - utils - PID: 12594 -  53837 sequences


In [131]:
from torch.optim import AdamW
import math
from transformers import get_linear_schedule_with_warmup

In [139]:
# parameter
n_epoch = 3
gradient_accumulation_steps = 2
temperature = 2.0
alpha_ce = 0.5
alpha_clm =0.5
alpha_mse = 1e-3
alpha_ca = 0.1
learning_rate = 1e-2
adam_epsilon = 1e-08
weight_decay = 0.0
warmup_prop = 0.05

ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
lm_loss_fct = nn.CrossEntropyLoss()
mse_loss_fct = nn.MSELoss(reduction="sum")
cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")



num_steps_epoch = len(dataloader)
num_train_optimization_steps = (
    int(num_steps_epoch / gradient_accumulation_steps * n_epoch) + 1
)
warmup_steps = math.ceil(num_train_optimization_steps * warmup_prop)



no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [
            p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad
        ],
        "weight_decay": weight_decay,
    },
    {
        "params": [
            p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad
        ],
        "weight_decay": 0.0,
    },
]

optimizer = AdamW(
    optimizer_grouped_parameters,
    lr=learning_rate,
    eps=adam_epsilon,
    betas=(0.9, 0.98)
    )

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=num_train_optimization_steps
)

n_total_iter = 0
for _ in range(n_epoch):
    n_iter = 0
    for batch in dataloader:
        student_outputs = student(batch[0],output_hidden_states=True)
        teacher_outputs = teacher(batch[0],output_hidden_states=True)


        s_logits, s_h = student_outputs["logits"], student_outputs["hidden_states"]
        t_logits, t_h = teacher_outputs["logits"], teacher_outputs["hidden_states"]

        assert s_logits.size() == t_logits.size()


        loss_ce = (
            ce_loss_fct(
                nn.functional.log_softmax(s_logits / temperature, dim=-1),
                nn.functional.softmax(t_logits / temperature, dim=-1),
            )
            * (temperature) ** 2
        )
        loss = alpha_ce * loss_ce

        loss_clm = lm_loss_fct(s_logits, batch[1])

        loss += alpha_clm * loss_clm

        dim = s_h[-1].shape[0]
        slh = s_h[-1].view(dim,-1)
        tlh = t_h[-1].view(dim,-1)
        loss_cos = cosine_loss_fct(slh,
                                   tlh,
                                   target = slh.new(slh.size(0)).fill_(1)
                                  )
        loss += alpha_ca * loss_cos

        # Check for NaN
        if (loss != loss).data.any():
            logger.error("NaN detected")
            sys.exit(1)

        loss.backward()
        n_iter += 1
        n_total_iter += 1


        if n_iter % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            break


In [119]:
target

tensor([1., 1., 1.])

In [108]:
s_h[-1].shape

torch.Size([3, 13, 768])

In [109]:
t_h[-1].shape

torch.Size([3, 13, 768])