<a href="https://colab.research.google.com/github/romlingroup/flatpack-ai/blob/main/notebooks/flatpack_ai_classroom.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# flatpack.ai - Classroom

In [None]:
# Install the libraries imported by the cells below (PyTorch + Hugging Face Transformers).
!pip install torch transformers

## Knowledge distillation

In [None]:
%cd /content

from transformers import AutoTokenizer, AutoModelForCausalLM, AdamW
import os
import torch.nn.functional as F
import torch

tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
tokenizer.pad_token = tokenizer.eos_token

teacher_model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
student_model = AutoModelForCausalLM.from_pretrained("gpt2")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_model = teacher_model.to(device)
student_model = student_model.to(device)

file_path = "/content/input.txt"

if not os.path.exists(file_path):
    !wget https://github.com/karpathy/char-rnn/raw/master/data/tinyshakespeare/input.txt

with open(file_path, 'r') as f:
    shakespeare_text = f.read()

batch_size = 1
sequence_length = 1024
chunk_size = sequence_length - 1
num_chunks = len(shakespeare_text) // chunk_size

shakespeare_tokens = torch.zeros([num_chunks, sequence_length], dtype=torch.long, device=device)
for i in range(num_chunks):
    chunk = shakespeare_text[i * chunk_size:(i + 1) * chunk_size]
    tokens = tokenizer(chunk, return_tensors="pt", padding='max_length', max_length=sequence_length)["input_ids"]
    shakespeare_tokens[i] = tokens[0]

num_batches = len(shakespeare_tokens) // batch_size

optimizer = torch.optim.AdamW(student_model.parameters(), lr=0.001)

num_epochs = 10
temperature = 2.0
continue_training = True

checkpoint_dir = "/content/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

if continue_training:
    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint_epoch_3.pth")
    checkpoint = torch.load(checkpoint_path)
    student_model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    last_epoch = checkpoint['epoch']
else:
    last_epoch = -1

print(f"Starting training for {num_epochs} epochs...")
for epoch in range(last_epoch + 1, num_epochs):
    print(f"Epoch {epoch+1} started.")

    for batch_idx in range(num_batches):
        optimizer.zero_grad()

        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size
        batch_tokens = shakespeare_tokens[start_idx:end_idx]

        teacher_outputs = teacher_model(batch_tokens, return_dict=True)
        student_outputs = student_model(batch_tokens, return_dict=True)

        distill_loss = F.kl_div(
            F.log_softmax(student_outputs.logits / temperature, dim=-1),
            F.softmax(teacher_outputs.logits / temperature, dim=-1),
            reduction='batchmean'
        )

        original_loss = F.cross_entropy(student_outputs.logits[:, :-1].contiguous().view(-1, student_outputs.logits.size(-1)), batch_tokens[:, 1:].contiguous().view(-1))

        loss = distill_loss + original_loss
        loss.backward()

        torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)

        optimizer.step()

        print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}/{num_batches}, Loss: {loss.item()}")

    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth")
    torch.save({
        'epoch': epoch,
        'model_state_dict': student_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),
    }, checkpoint_path)

if 'shakespeare_text' in locals():
    del shakespeare_text
if 'shakespeare_tokens' in locals():
    del shakespeare_tokens
if 'teacher_outputs' in locals():
    del teacher_outputs
if 'student_outputs' in locals():
    del student_outputs
torch.cuda.empty_cache()

## Inference from checkpoint

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import torch.nn.functional as F
import torch

# Inference from checkpoint: rebuild the GPT-2 student from a saved
# distillation checkpoint and sample a continuation of a short prompt.
# (`AdamW` is not imported from transformers: it was unused here and recent
# transformers releases removed it, so that import raised ImportError.)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
tokenizer.pad_token = tokenizer.eos_token

checkpoint_dir = "/content/checkpoints"
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint_epoch_6.pth")
# map_location="cpu": tensors saved from a GPU session would otherwise be
# restored onto CUDA and fail on a CPU-only runtime. The weights are copied
# into the freshly built (CPU) model below, which is then moved to `device`.
checkpoint = torch.load(checkpoint_path, map_location="cpu")

student_model = AutoModelForCausalLM.from_pretrained("gpt2")
student_model.load_state_dict(checkpoint['model_state_dict'])
student_model = student_model.to(device)
student_model.eval()

input_text = "Once upon a time"
# Calling the tokenizer (rather than .encode) also returns the attention mask,
# which generate() cannot reliably infer when pad and EOS share an id.
inputs = tokenizer(input_text, return_tensors="pt").to(device)

with torch.no_grad():
    output_tokens = student_model.generate(
        **inputs,
        max_length=150,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
        # NOTE(review): a sampling temperature of 2.0 flattens the
        # distribution considerably (it is unrelated to the distillation
        # temperature used in training); lower values give more coherent text.
        temperature=2.0,
        top_k=40,
        top_p=0.90,
        do_sample=True
    )

output_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
print(output_text)

## Upload model to Hugging Face

In [None]:
# Authenticate with the Hugging Face Hub (interactive token prompt); the
# repo-creation and upload cells below need write access to the target repo.
!huggingface-cli login

In [None]:
import os
import shutil

# Empty the export directory so the save_pretrained() call that follows
# writes into a fresh folder.
folder_path = "/content/student_model"
if os.path.isdir(folder_path):  # isdir() is already False for missing paths
    # shutil.rmtree instead of `!rm -r $folder_path`: no shell, no breakage on
    # paths containing spaces, and it also works outside IPython.
    shutil.rmtree(folder_path)

In [None]:
# Export the student (weights + config) and the tokenizer files side by side
# into the directory cleared above, ready for upload.
save_directory = folder_path
student_model.save_pretrained(save_directory=save_directory)
tokenizer.save_pretrained(save_directory=save_directory)

In [None]:
from huggingface_hub import HfApi

# Target repository on the Hub: "<user>/<model_repo_name>".
user = "romlingroup"
model_repo_name = "gpt2-shakespeare-student"
repo_id = f"{user}/{model_repo_name}"

api = HfApi()

# Show which account the stored token belongs to before writing anything.
account_info = api.whoami(token=api.token)
print(f"Logged in as user: {account_info['name']}")

# Create the public model repo; exist_ok=True makes re-running this cell safe.
api.create_repo(token=api.token, repo_id=repo_id, private=False, exist_ok=True)

In [None]:
# Push the exported model + tokenizer directory to the root of the model repo.
repo_id = "/".join([user, model_repo_name])
api.upload_folder(
    repo_id=repo_id,
    repo_type="model",
    folder_path=save_directory,
)

In [None]:
# Upload the training checkpoints into a `checkpoints/` folder of the same
# model repo. `path_in_repo` places them there directly, so the local
# checkpoint directory no longer has to be moved into a temporary staging
# folder and back. Previously, a failed upload left the checkpoints stranded
# in the staging folder and broke re-runs of this cell.
# NOTE(review): each checkpoint also stores the AdamW optimizer state, so
# these files are considerably larger than the model weights alone.
repo_id = f"{user}/{model_repo_name}"
api.upload_folder(
    folder_path=checkpoint_dir,
    path_in_repo="checkpoints",
    repo_id=repo_id,
    repo_type="model",
)