In [None]:
%conda install transformers accelerate

In [None]:
import os
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
)
from datasets import load_dataset, load_metric

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sagemaker
import boto3
import tarfile

In [None]:
# Show which PyTorch build the kernel is running.
print(f"{torch.__version__}")

In [None]:
# Resolve the S3 bucket: with no explicit bucket configured, use the
# session's default bucket.
sess = sagemaker.Session()
sagemaker_session_bucket = None
if sess is not None and sagemaker_session_bucket is None:
    sagemaker_session_bucket = sess.default_bucket()

# Resolve the execution role. When get_execution_role raises ValueError
# (presumably when running outside SageMaker), look the role up by name in IAM.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role_info = iam.get_role(RoleName='sagemaker_execution_role')
    role = role_info['Role']['Arn']

# Recreate the session bound to the resolved bucket.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

In [None]:
# Local directory for Trainer checkpoints and the final saved model.
model_dir = "./models/"
# exist_ok=True makes the cell idempotent and avoids the exists()/makedirs race.
os.makedirs(model_dir, exist_ok=True)

In [None]:
# Hub checkpoints: the smaller student is distilled from the larger teacher.
student_id, teacher_id = "gpt2", "gpt2-medium"
# GLUE benchmark, SST-2 configuration.
dataset_id, dataset_config = "glue", "sst2"

In [None]:
def process(examples):
    """Tokenize a batch of examples for ``Dataset.map(batched=True)``.

    Uses the module-level ``tokenizer`` (defined in a later cell, before
    ``map`` is called). Sentences are truncated to 256 tokens and padded to the
    longest sequence in the batch.

    Args:
        examples: batch dict with a ``"sentence"`` column.

    Returns:
        The tokenizer's encoding (input_ids, attention_mask, ...).
    """
    return tokenizer(
        examples["sentence"],
        max_length=256,
        truncation=True,
        padding=True,
    )

In [None]:
def compute_metrics(eval_pred):
    """Compute classification accuracy for ``Trainer`` evaluation.

    Self-contained (numpy only): it no longer depends on a global metric
    object created in a later cell via the deprecated ``load_metric``.

    Args:
        eval_pred: ``(predictions, labels)`` pair (an ``EvalPrediction``
            unpacks the same way). ``predictions`` is assumed to be logits of
            shape ``(n_examples, n_classes)``; it may also arrive as a tuple
            whose first element is presumed to be those logits.

    Returns:
        ``{"accuracy": float}`` - fraction of argmax predictions equal to labels.
    """
    predictions, labels = eval_pred
    # When a model emits extra outputs, Trainer hands back a tuple of arrays.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predicted_classes = np.argmax(predictions, axis=1)
    accuracy = np.mean(predicted_classes == np.asarray(labels))
    return {
        "accuracy": float(accuracy),
    }

In [None]:
class DistillationTrainingArguments(TrainingArguments):
    """``TrainingArguments`` extended with knowledge-distillation settings.

    Args:
        alpha: weight of the student's own task loss; the distillation
            (KL) term is weighted by ``1 - alpha``.
        temperature: softmax temperature applied to both student and teacher
            logits before comparing them.
        *args, **kwargs: forwarded unchanged to ``TrainingArguments``.
    """

    def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha, self.temperature = alpha, temperature

In [None]:
class DistillationTrainer(Trainer):
    """``Trainer`` that distills a frozen teacher model into the student ``model``.

    Training loss is ``alpha * student_loss + (1 - alpha) * distill_loss``,
    where ``distill_loss`` is the KL divergence between temperature-softened
    teacher and student distributions, scaled by ``temperature ** 2``.
    ``alpha`` and ``temperature`` are read from ``self.args`` (see
    ``DistillationTrainingArguments``).

    Args:
        teacher_model: model providing the soft targets. Required; it is moved
            to the student's device and switched to eval mode.
    """

    def __init__(self, *args, teacher_model=None, **kwargs):
        # Fail fast with a clear message rather than an AttributeError on None
        # after the base-class setup has already run.
        if teacher_model is None:
            raise ValueError("DistillationTrainer requires a `teacher_model`.")
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        # place teacher on same device as student
        self._move_model_to_device(self.teacher, self.model.device)
        # Eval mode disables dropout so the teacher's soft targets are stable.
        self.teacher.eval()

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the weighted student + distillation loss for one batch.

        Args:
            model: the student model being trained.
            inputs: batch dict passed unchanged to both student and teacher;
                expected to contain labels, since ``outputs_student.loss`` is
                used directly.
            return_outputs: if True, also return the student outputs.
            **kwargs: extra arguments passed by newer ``Trainer`` versions
                (e.g. ``num_items_in_batch``); accepted for compatibility and
                not used.

        Returns:
            ``loss``, or ``(loss, outputs_student)`` when ``return_outputs``.

        Raises:
            ValueError: if student and teacher logits differ in shape.
        """
        outputs_student = model(**inputs)
        student_loss = outputs_student.loss

        # The teacher only supplies targets; no gradients flow into it.
        with torch.no_grad():
            outputs_teacher = self.teacher(**inputs)

        student_logits = outputs_student.logits
        teacher_logits = outputs_teacher.logits
        # Explicit check instead of `assert` so it still runs under `python -O`.
        if student_logits.size() != teacher_logits.size():
            raise ValueError(
                "Student and teacher logits must have the same shape, got "
                f"{tuple(student_logits.size())} and {tuple(teacher_logits.size())}."
            )

        temperature = self.args.temperature
        alpha = self.args.alpha
        # Soften probabilities and compute the distillation loss; the T**2
        # factor is the conventional rescaling for temperature-softened KL.
        distill_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction="batchmean",
        ) * (temperature ** 2)

        # Return weighted student loss
        loss = alpha * student_loss + (1.0 - alpha) * distill_loss
        return (loss, outputs_student) if return_outputs else loss

In [None]:
# Load both tokenizers so the next cell can confirm they tokenize identically.
student_tokenizer = AutoTokenizer.from_pretrained(student_id)
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_id)

In [None]:
# Logit-level distillation compares per-token inputs, so teacher and student
# must encode text the same way; check on one sample sentence.
sample = "Here's our sanity check."

teacher_encoding = teacher_tokenizer(sample)
student_encoding = student_tokenizer(sample)
assert teacher_encoding == student_encoding, (
    "Tokenizers need to have the same output! "
    f"{teacher_encoding} != {student_encoding}"
)

In [None]:
# The per-model tokenizers are no longer needed; one shared tokenizer follows.
del teacher_tokenizer, student_tokenizer

In [None]:
# Teacher and student tokenizers matched on the sample above, so one tokenizer
# (from the teacher checkpoint) serves both models.
tokenizer = AutoTokenizer.from_pretrained(teacher_id)
# Reuse the end-of-sequence token for padding (presumably because the GPT-2
# tokenizer defines no pad token of its own).
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Vocabulary size of the shared tokenizer.
print(f"{len(tokenizer.vocab)}")

In [None]:
dataset = load_dataset(dataset_id, dataset_config)

# Tokenize every split, then rename "label" -> "labels" (the name the model
# forward pass takes for its loss target).
tokenized_dataset = (
    dataset
    .map(process, batched=True)
    .rename_column("label", "labels")
)

print(tokenized_dataset["test"].features)

In [None]:
# Class names from the ClassLabel feature, plus name<->id lookup tables
# (ids stored as strings) for the model config.
labels = tokenized_dataset["train"].features["labels"].names
num_labels = len(labels)
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {str(idx): name for idx, name in enumerate(labels)}

In [None]:
# Number of target classes.
print(f"{num_labels}")

In [None]:
# define training args
training_args = DistillationTrainingArguments(
    output_dir=model_dir,
    seed=8855,
    # optimisation
    num_train_epochs=1,
    learning_rate=6e-5,
    auto_find_batch_size=True,
    fp16=True,
    # evaluate and checkpoint once per epoch, keep at most 2 checkpoints, and
    # reload the checkpoint with the best eval accuracy when training ends
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # distillation: loss = alpha * student_loss + (1 - alpha) * KL term
    alpha=0.5,
    temperature=4.0,
)

In [None]:
# Both models get the same classification head configuration.
label_config = {
    "num_labels": num_labels,
    "id2label": id2label,
    "label2id": label2id,
}

teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_id, **label_config)
student_model = AutoModelForSequenceClassification.from_pretrained(student_id, **label_config)

In [None]:
# Inspect the student architecture.
print(f"{student_model}")

In [None]:
# Both model configs must know which token id is padding (here, the EOS token).
teacher_model.config.pad_token_id = student_model.config.pad_token_id = tokenizer.pad_token_id

In [None]:
# Pads each training/eval batch with the shared tokenizer.
data_collator = DataCollatorWithPadding(tokenizer)

In [None]:
# Accuracy metric from the datasets metric hub.
accuracy_metric = load_metric(path="accuracy")

In [None]:
# Student is trained; teacher only supplies soft targets.
trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    teacher_model=teacher_model,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Run distillation, then write the resulting student model to model_dir.
trainer.train()
trainer.save_model(output_dir=model_dir)

In [None]:
# Reload the student model written to model_dir by trainer.save_model.
# torchscript=True makes the model return plain tuples instead of ModelOutput
# dicts; torch.jit.trace (used later in this notebook) rejects dict outputs
# by default.
final_model = AutoModelForSequenceClassification.from_pretrained(model_dir, torchscript=True)

In [None]:
# Switch to inference mode (equivalent to .eval()); returns the model for display.
final_model.train(False)

In [None]:
# Example text and the fixed sequence length used for tracing.
input_text, max_length = "Hello, my dog is cute", 1024

In [None]:
# Pad to a fixed length so the traced graph sees a constant input shape.
tokens = tokenizer(
    [input_text],
    max_length=max_length,
    padding='max_length',
    truncation=True,
    return_tensors="pt",
)

In [None]:
# Put the model on the same device as the example input tensor.
final_model.to(device=tokens["input_ids"].device)

In [None]:
# Type of the input_ids entry.
print(tokens["input_ids"].__class__)

In [None]:
# Type of the tokenizer's return value.
print(tokens.__class__)

In [None]:
# Trace the forward pass using input_ids only as the example input.
model_trace = torch.jit.trace(final_model, example_inputs=tokens["input_ids"])

In [None]:
# Serialize the traced module to ./model.pth.
torch.jit.save(model_trace, 'model.pth')

In [None]:
# Low-level S3 client for the upload below.
s3 = boto3.client(service_name='s3')

In [None]:
# S3 key prefix and local archive name.
s3_model_path, tar_gz_file = 'results', "model.tar.gz"

In [None]:
# Package the contents of model_dir into the archive named by tar_gz_file
# (the previously referenced `tar_gz_path` was never defined -> NameError).
# NOTE(review): os.path.basename("./models/") is "" because of the trailing
# slash, so files land at the archive root - presumably intended for a
# SageMaker model.tar.gz; confirm before changing model_dir.
# NOTE(review): model_dir also holds the Trainer's checkpoint-* directories,
# while the traced ./model.pth is NOT inside model_dir - confirm the archive
# contents are what the endpoint expects.
with tarfile.open(tar_gz_file, "w:gz") as tar:
    tar.add(model_dir, arcname=os.path.basename(model_dir))

print(f"Compressed and archived {model_dir} to {tar_gz_file}")

In [None]:
# Local archive path and its display location "<bucket>/<prefix>/<file>".
model_file = os.path.join(".", tar_gz_file)
s3_file = "/".join([sess.default_bucket(), s3_model_path, tar_gz_file])

In [None]:
# Best-effort upload: failures are reported but do not stop the notebook.
try:
    bucket = sess.default_bucket()
    object_key = f"{s3_model_path}/{tar_gz_file}"
    s3.upload_file(model_file, bucket, object_key)
    print(f'Uploaded {model_file} to {s3_file}')
except Exception as err:
    print(f"Error occurred while uploading file {model_file}, {err}")