In [None]:
# Reload edited local modules (data, distillation, pruning, quantization, ...) before each
# cell runs, so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Shorter ("Plain") tracebacks.
%xmode Plain

# NOTE(review): star import; this and the later `from <module> import *` lines presumably
# supply the shared names used below (torch, transformers classes, os, random, helpers)
# — TODO confirm and make the imports explicit.
from precompiled import *

In [None]:
from data import *

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# GPT-2 ships without a pad token; reuse its end-of-text token (an existing special token)
# so batched padding works without adding a new vocabulary entry.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

dataset = load_dataset("openwebtext", trust_remote_code=True)

raw_datasets = dataset

# Seeded splits. Without a seed every kernel restart
#  (a) draws a different split, so after resuming from a checkpoint (get_last_checkpoint in
#      the training cells) documents already trained on can end up in val/test, and
#  (b) gets a run-dependent `datasets` fingerprint, so the `.map(..., load_from_cache_file=True)`
#      tokenization of the corpus never reuses its cache.
SPLIT_SEED = 42
train_holdout = dataset['train'].train_test_split(test_size=0.001, seed=SPLIT_SEED)
train_dataset, holdout = train_holdout['train'], train_holdout['test']
# The 0.1% hold-out is split 80/20 into validation and test.
val_test = holdout.train_test_split(test_size=0.2, seed=SPLIT_SEED)
val_dataset, test_dataset = val_test['train'], val_test['test']

def preprocess_function(data):
    """Tokenize a batch of documents for causal-LM training.

    Args:
        data: a batched ``datasets`` mapping; ``data[text_column]`` is the list of raw
            texts (``tokenizer`` and ``text_column`` are notebook globals — presumably
            ``text_column`` comes from a star import).

    Returns:
        The tokenizer output (``input_ids`` and, when produced, ``attention_mask``) plus
        ``labels``: a fresh copy of ``input_ids`` in which padded positions are -100.
    """
    # truncation=True with max_length=None truncates to the tokenizer's model max length;
    # padding=True pads to the longest sequence in this map batch.
    inputs = tokenizer(data[text_column], max_length=None, truncation=True, padding=True)

    attention_mask = inputs.get('attention_mask')
    if attention_mask is None:
        # Nothing tells us which positions are padding; copy the ids as-is.
        inputs['labels'] = [list(ids) for ids in inputs['input_ids']]
    else:
        # The pad token reuses an existing special token, so unmasked padding would train
        # the model to emit it after every short document. -100 is the ignore_index of the
        # cross-entropy loss used by Hugging Face causal-LM heads.
        inputs['labels'] = [
            [token if keep else -100 for token, keep in zip(ids, mask)]
            for ids, mask in zip(inputs['input_ids'], attention_mask)
        ]
    return inputs

def tokenize_split(split):
    """Tokenize one dataset split with ``preprocess_function``.

    All original columns are removed, so only the tokenizer outputs and ``labels`` remain.
    Results are cached by ``datasets`` (``load_from_cache_file=True``).
    """
    return split.map(
        preprocess_function,
        batched=True,
        num_proc=os.cpu_count(),
        remove_columns=split.column_names,
        load_from_cache_file=True,
        desc="Running tokenizer on dataset",
    )


# Keep the untokenized test split: later cells pass it to print_perplexity and read its
# raw 'text' column for generation.
original_test_dataset = test_dataset

train_dataset = tokenize_split(train_dataset)
val_dataset = tokenize_split(val_dataset)
test_dataset = tokenize_split(test_dataset)

train_dataset, val_dataset, test_dataset, original_test_dataset

In [None]:
# Reference model: used below as the distillation teacher and as the baseline in comparisons.
# `model_name` is not defined in this notebook (presumably from a star import).
fine_tuned = GPT2LMHeadModel.from_pretrained(
    pretrained_model_name_or_path=model_name,
)

In [None]:
# Baseline perplexity on the raw (untokenized) test split, to compare against each
# compressed model produced below.
print_perplexity(fine_tuned, tokenizer, original_test_dataset)

In [None]:
# Distillation
# Train the student built by create_student() against the frozen fine-tuned teacher,
# then report the student's perplexity.

from distillation import *

DIR_DISTIL = create_work_dir("tmp/pred/distillation/")
# Resume from the latest checkpoint in DIR_DISTIL, if any.
last_checkpoint = get_last_checkpoint(DIR_DISTIL)

teacher = fine_tuned
# teacher has no need to update itself
teacher.requires_grad_(False)
teacher.eval()
# NOTE(review): Trainer calls .train() on the model it is given during training; if Distiller
# registers `teacher` as a submodule, that puts the teacher back in train mode (dropout on)
# — confirm Distiller pins the teacher to eval().

student = create_student(teacher)
print_size(student)

distiller_model = Distiller(teacher=teacher, student=student)

training_args = TrainingArguments(
    output_dir=DIR_DISTIL,
    num_train_epochs=2,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=500,
    eval_strategy="steps",  # eval_steps unset -> evaluates every logging_steps (200)
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    learning_rate=5e-5,
    logging_dir=DIR_DISTIL,
    logging_steps=200,
    logging_first_step=True,
    log_level="warning",
    save_safetensors=False,
    fp16=torch.cuda.is_available(),
    # Only the eval loss is used (no compute_metrics). Without this flag the Trainer also
    # gathers whatever logits the model returns for the entire validation set (vocab-sized
    # per token for an LM). The pruning and quantization runs already set it.
    prediction_loss_only=True,

    # max_steps=2,  # TODO
)

trainer = Trainer(
    model=distiller_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

trainer.train(resume_from_checkpoint=last_checkpoint)

# release
del trainer
release()

print_perplexity(student, tokenizer, original_test_dataset)

In [None]:
# pruning
# Fine-tune the distilled student under pruning, then remove zeroed units and re-measure.
# NOTE(review): the pruning schedule lives in the `pruning` module; SchedulerUpdateCallback
# presumably advances it during training — confirm there.
from pruning import *

# Prune the distilled student; convert_model presumably returns a version with prunable
# (masked) layers — TODO confirm in `pruning`.
model = student
to_prune = convert_model(model)

# Uncomment to put the pruning scheduler directly at its final state (progress 1.0 and the
# final attention/FFN thresholds).
# scheduler.progress = 1.0
# scheduler.attn_threshold = scheduler.attn_final_threshold
# scheduler.ffn_threshold = scheduler.ffn_final_threshold

# Also reused by the quantization cell below.
# NOTE(review): DataCollatorWithPadding pads only tokenizer inputs (input_ids,
# attention_mask), not `labels`, so it only works while every example in a batch already
# has the same length. DataCollatorForLanguageModeling(tokenizer, mlm=False) is the
# causal-LM collator that also builds and pads labels.
data_collator = DataCollatorWithPadding(tokenizer)

DIR_PRUNE = create_work_dir("tmp/pred/pruning/")
# Resume from the latest checkpoint in DIR_PRUNE, if any.
last_checkpoint = get_last_checkpoint(DIR_PRUNE)
print(f"last checkpoint: {last_checkpoint}")
training_args = TrainingArguments(
    output_dir=DIR_PRUNE,
    overwrite_output_dir=True,
    num_train_epochs=p_train_config.epochs,
    per_device_train_batch_size=p_train_config.batch_size,
    eval_strategy="steps",
    save_strategy='steps',
    save_steps=1000,
    save_total_limit=2,
    logging_dir=DIR_PRUNE,
    logging_first_step=True,
    logging_strategy="steps",
    logging_steps=100,   # eval_steps is unset, so evaluation also runs every 100 steps
    log_level="warning",
    fp16=torch.cuda.is_available(),
    gradient_accumulation_steps=p_train_config.gradient_accumulation_steps,
    weight_decay=0.01,
    warmup_steps=500,
    save_safetensors=False,  # we share masked scores between K,Q,V and between fc1,fc2.
    prediction_loss_only=True,

    # max_steps=2,  # TODO
)

trainer = Trainer(
    model=to_prune,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    callbacks=[SchedulerUpdateCallback()],
)

trainer.train(resume_from_checkpoint=last_checkpoint)

# release
del trainer
release()

# Report mask density, then (per the helper names) freeze the FFN masks and physically
# remove zeroed FFN/attention units before measuring size and perplexity.
print_pruning_density(to_prune)
FFN_freeze(to_prune)
print_linear_density_all(to_prune)
FFN_prune_zeros(to_prune)
ATTN_prune_zeros(to_prune)
print_size(to_prune)

print_perplexity(to_prune, tokenizer, original_test_dataset)

# print(to_prune)

In [None]:
# PTQ, QAT
# Build a quantized model from the pruned one, fine-tune it with the Trainer, then freeze
# and save it. NOTE(review): get_ptq_model/freeze semantics are defined in `quantization`.
from quantization import *

model = to_prune
ptq = get_ptq_model(model)

DIR_QUANTIZE = create_work_dir("tmp/pred/quantization/")
# Resume from the latest checkpoint in DIR_QUANTIZE, if any.
last_checkpoint = get_last_checkpoint(DIR_QUANTIZE)
print(f"last checkpoint: {last_checkpoint}")
training_args = TrainingArguments(
    output_dir=DIR_QUANTIZE,
    overwrite_output_dir=True,
    num_train_epochs=q_train_config.epochs,
    per_device_train_batch_size=q_train_config.batch_size,
    eval_strategy="steps",
    save_strategy='steps',
    save_steps=1000,
    save_total_limit=1,
    logging_dir=DIR_QUANTIZE,
    logging_first_step=True,
    logging_strategy="steps",
    logging_steps=100,
    log_level="warning",
    fp16=torch.cuda.is_available(),
    gradient_accumulation_steps=q_train_config.gradient_accumulation_steps,
    weight_decay=0.01,
    warmup_steps=500,
    # As in the distillation and pruning runs: safetensors refuses tensors that share
    # storage (presumably still present here, e.g. shared masked scores or GPT-2's tied
    # input/output embeddings), so checkpoint with torch.save instead.
    save_safetensors=False,
    prediction_loss_only=True,

    # max_steps=2,  # TODO
)

trainer = Trainer(
    model=ptq,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

trainer.train(resume_from_checkpoint=last_checkpoint)

# release
del trainer
release()

print_perplexity(ptq, tokenizer, original_test_dataset)

freeze(ptq)
print_size(ptq)

# os.path.join is correct whether or not create_work_dir returns a trailing separator
# (plain `+` would yield ".../quantizationfinal.pth" without one) and also accepts Path.
FINAL_FILE = os.path.join(DIR_QUANTIZE, "final.pth")
# Pickles the whole module, so reloading needs torch.load(..., weights_only=False) and the
# project's quantization code importable.
torch.save(ptq, FINAL_FILE)  # TODO: save state dict only and save the scale dict in json.

In [None]:
# FINAL_FILE holds the entire pickled module (torch.save(ptq, ...)), which torch.load's
# weights_only=True default (PyTorch >= 2.6) refuses to load. Only disable it for files this
# notebook produced itself: unpickling can execute arbitrary code.
model = torch.load(FINAL_FILE, weights_only=False)
thaw(model)

def print_prediction(model, text, name: str, do_sample: bool = True):
    """Generate a continuation of ``text`` with ``model`` and print it with its scores.

    The prompt is the first 100 tokens of ``text``. Prints the decoded output (prompt plus
    continuation), the ROUGE scores of that output against ``text``, and the perplexity of
    ``model`` on ``text``. Uses the notebook globals ``tokenizer``, ``device``,
    ``rouge_score`` and ``Perplexity``.

    Args:
        model: causal LM exposing ``generate`` and ``config.n_positions``.
        text: reference document; its prefix is the prompt.
        name: label printed before the generated text.
        do_sample: sample with top-k/top-p (the default). With False, decoding is greedy
            and the sampling knobs are not forwarded (generate would ignore them anyway).
    """
    input_ids = tokenizer.encode(text, return_tensors="pt", truncation=True, padding=True,
                                 max_length=model.config.n_positions)
    # Keep only the first 100 tokens as the prompt (batch dimension preserved).
    input_ids = input_ids[:, :100]
    # Single unpadded prompt: every position is real. Passing the mask explicitly avoids
    # generate() having to infer it, which it cannot do reliably when pad == eos (presumably
    # the case here, since the pad token reuses an existing GPT-2 special token).
    attention_mask = torch.ones_like(input_ids)
    sampling_kwargs = dict(top_k=50, top_p=0.95, temperature=1.0) if do_sample else {}
    with torch.no_grad():
        output = model.generate(
            input_ids.to(device),
            attention_mask=attention_mask.to(device),
            min_new_tokens=50,
            max_new_tokens=200,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=do_sample,
            **sampling_kwargs,
        )
    gen_prediction = tokenizer.decode(output[0], skip_special_tokens=True)

    # NOTE(review): gen_prediction still contains the prompt, which inflates ROUGE vs `text`.
    score1 = rouge_score.compute(predictions=[gen_prediction], references=[text])
    scores1 = {k: round(v, 4) for k, v in score1.items()}
    score2 = Perplexity.compute([text], model, tokenizer, device=device,
                                batch_size=1, max_length=model.config.n_positions)
    print(f">>> {name}\t: Generated Text: [{gen_prediction}]")
    print(f"{scores1} {score2}")

# Eval mode turns off dropout in both models before generating.
fine_tuned.eval()
model.eval()

# Seeded so re-runs compare the same documents and (typically, on the same hardware) the
# same sampled continuations.
SAMPLE_SEED = 0
NUM_SAMPLES = 5
torch.manual_seed(SAMPLE_SEED)
num_samples = min(NUM_SAMPLES, len(original_test_dataset))
random_samples = random.Random(SAMPLE_SEED).sample(range(len(original_test_dataset)), num_samples)
for idx in random_samples:
    # Reference text: the document without its last 100 characters.
    text = original_test_dataset[idx]['text'][:-100]
    print(f"=================== {idx} =====================")
    print_prediction(fine_tuned, text, "origin")
    print("===========================")
    print_prediction(model, text, "opt   ")
    print()



In [None]:
# Side-by-side perplexity: the baseline first, then the pruned + quantized model reloaded
# from FINAL_FILE.
for _evaluated in (fine_tuned, model):
    print_perplexity(_evaluated, tokenizer, original_test_dataset)