In [1]:
import torch
from datasets import DatasetDict, Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, BitsAndBytesConfig

In [None]:
teacher_model_name = 'meta-llama/Llama-3.1-8B-Instruct'
student_model_name = 'meta-llama/Llama-3.2-1B'

bnb_config = BitsAndBytesConfig(
  load_in_4bit = True,
  bnb_4bit_quant_type = 'nf4',
  bnb_4bit_compute_dtype = torch.bfloat16
)

teacher_model = AutoModelForCausalLM.from_pretrained(
  teacher_model_name,
  quantization_config = bnb_config,
  device_map = 'auto'
)

student_model = AutoModelForCausalLM.from_pretrained(
  student_model_name,
  quantization_config = bnb_config,
  device_map = 'auto'
)


dataset format
```
DatasetDict({
    train: Dataset({
        features: ['instruction', 'input', 'output', 'text'],
        num_rows: 52002
    })
})
```

In [2]:
dataset = load_dataset('tatsu-lab/alpaca')

def format_batch(batch):
  if batch['input']:
    prompt = f"Instruction: {batch['instruction']}\nInput: {batch['input']}\nOutput:"
  else:
    prompt = f"Instruction: {batch['instruction']}\nOutput"
  return {"text":prompt}

dataset = dataset["train"].map(format_batch)

In [3]:
dataset[3]

{'instruction': 'How can we reduce air pollution?',
 'input': '',
 'output': 'There are a number of ways to reduce air pollution, such as shifting to renewable energy sources, encouraging the use of public transportation, prohibiting the burning of fossil fuels, implementing policies to reduce emissions from industrial sources, and implementing vehicle emissions standards. Additionally, individuals can do their part to reduce air pollution by reducing car use, avoiding burning materials such as wood, and changing to energy efficient appliances.',
 'text': 'Instruction: How can we reduce air pollution?\nOutput'}

In [4]:
dataset

Dataset({
    features: ['instruction', 'input', 'output', 'text'],
    num_rows: 52002
})

In [5]:
dataset[0]['instruction']

'Give three tips for staying healthy.'

In [6]:
teacher_model_name = 'meta-llama/Llama-3.1-8B-Instruct'
student_model_name = 'meta-llama/Llama-3.2-1B'

teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)
teacher_tokenizer.pad_token = teacher_tokenizer.eos_token

def tokenize(batch):
  return teacher_tokenizer(batch["text"], padding = "max_length", truncation = True, max_length = 512)

tokenized_dataset = dataset.map(tokenize, batched = True).map(tokenize, batched = True)


**Forward Pass With Teacher(NO GRADIENTS)**

In [7]:
tokenized_dataset

Dataset({
    features: ['instruction', 'input', 'output', 'text', 'input_ids', 'attention_mask'],
    num_rows: 52002
})

In [12]:
len(tokenized_dataset[3]['input_ids'])
# tokenized_dataset[0]['attention_mask']
tokenized_dataset[3:10]['instruction']
# len(tokenized_dataset[3]['input_ids']) #512
# tokenized_dataset[3:5]

['How can we reduce air pollution?',
 'Describe a time when you had to make a difficult decision.',
 'Identify the odd one out.',
 'Explain why the following fraction is equivalent to 1/4',
 'Write a short story in third person narration about a protagonist who has to make an important career decision.',
 'Render a 3D model of a house',
 'Evaluate this sentence for spelling and grammar mistakes']

In [18]:
len(tokenized_dataset)

52002

**Boot Strap Sampling**

In [16]:
from torch.utils.data import Dataset, DataLoader, Sampler
import numpy as np

class BootstrapSampler(Sampler):
  def __init__(self,data_source, num_samples = None, generator = None):
    self.data_source = data_source
    self.num_samples = num_samples if num_samples is not None else len(data_source)
    self.generator = generator

  def __iter__(self):
    indices = np.random.choice(len(self.data_source), size = self.num_samples,  replace = True )
    return iter(indices)

In [17]:
data = BootstrapSampler(tokenized_dataset, num_samples = 100)


**Forward Pass With Teacher(NO GRADIENTS)**

In [None]:
batch = tokenized_dataset[0]
input_ids = torch.tensor(batch['input_ids']).unsqueeze(0).to(teacher_model.device)
mask = torch.tensor(batch['attention_mask']).unsqueeze(0).to(teacher_model.device)

with torch.no_grad():
  out = teacher_model(input_ids, attention_mask = mask)
  teacher_logits = out.logits #[batch, seq_len, vocab]

In [None]:
tokenized_dataset.shape # (52002, 6) -> (number_roows, num_features)

(52002, 6)

**Distillation_loss**

In [None]:
from torch.nn import functional as F

student_out = student_model(input_ids, attention_mask=mask)
student_logits = student_out.logits

# KL divergence loss
loss = F.kl_div(
  input=F.log_softmax(student_logits, dim=-1),
  target=F.softmax(teacher_logits, dim=-1),
  reduction="batchmean"
)


**Train Loader**

In [16]:
import torch
from torch.utils.data import DataLoader

def collate_fn(batch):
  return{
    "input_ids": torch.tensor([x["input_ids"] for x in batch]),
    "attention_mask": torch.tensor([x["attention_mask"] for x in batch])
  }

train_loader = DataLoader(tokenized_dataset, batch_size = 4, shuffle = True, collate_fn=collate_fn)


In [17]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [None]:
from torch.optim import Adam

optimizer = Adam(student_model.parameters(), lr=1e-5)

i = 0
for batch in train_loader:
  input_ids = batch['input_ids'].to(device)
  mask = batch['attention_mask'].to(device)

  with torch.no_grad():
    teacher_out = teacher_model(input_ids, attention_mask = mask)
    teacher_logits = teacher_out.logits

  student_out = student_model(input_ids, attention_mask = mask)
  students_logits = student_out.logits

# KL divergence loss
loss = F.kl_div(
  input=F.log_softmax(student_logits, dim=-1),
  target=F.softmax(teacher_logits, dim=-1),
  reduction="batchmean"
)

# backprop
optimizer.zero_grad()
loss.backward()
optimizer.step()

if i%20 == 0:
  torch.save(student_model.state_dict(), f"student_checkpoint{i}.pt")
i = i + 1


**Save student Model**

In [None]:
student_model.save_pretrained("./student_model")
student_tokenizer.save_pretrained("./student_model")