In [1]:
import torch
import gc
import random
import numpy

# empty cache
gc.collect()
torch.cuda.empty_cache()

# set the seed
seed = 28
torch.manual_seed(seed)
numpy.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# device check
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


### Load the model

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
policy_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
policy_model.train()

  from .autonotebook import tqdm as notebook_tqdm


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576, padding_idx=2)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
    (rotary_emb)

In [3]:
from transformers import AutoModelForSequenceClassification

rm_name = "MilyaShams/SmolLM2-135M-Instruct-Reward"
reward_model = AutoModelForSequenceClassification.from_pretrained(rm_name, num_labels=1).to(device)
reward_model.eval()

LlamaForSequenceClassification(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576, padding_idx=2)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
  

### Load and prepare dataset

In [4]:
from datasets import load_dataset

dataset = load_dataset("esfrankel17/HelpSteer2_binarized", split='average_rating_split')

In [5]:
dataset

Dataset({
    features: ['prompt', 'chosen', 'chosen_rating', 'rejected', 'rejected_rating'],
    num_rows: 8678
})

In [6]:
prompts = [len(item["prompt"]) for item in dataset]

n_long = 0
n_ok = 0
n_good = 0

for i in prompts:
    if i > 512:
        n_long += 1
    if i < 512:
        n_ok += 1
    if i < 200:
        n_good += 1

print(n_long, n_ok, n_good)

3361 5310 3996


Let's leave only medium level f length, not so long prompts, because in the original paper they are restrict the both models to max context length equals 512.

In [7]:
filtered_dataset = dataset.filter(lambda example: len(example['prompt']) < 200)

Filter: 100%|██████████| 8678/8678 [00:00<00:00, 32878.32 examples/s]


In [8]:
filtered_dataset

Dataset({
    features: ['prompt', 'chosen', 'chosen_rating', 'rejected', 'rejected_rating'],
    num_rows: 3996
})

In [9]:
filtered_dataset = filtered_dataset.train_test_split(test_size=0.2, shuffle=True, seed=seed)

In [10]:
filtered_dataset

DatasetDict({
    train: Dataset({
        features: ['prompt', 'chosen', 'chosen_rating', 'rejected', 'rejected_rating'],
        num_rows: 3196
    })
    test: Dataset({
        features: ['prompt', 'chosen', 'chosen_rating', 'rejected', 'rejected_rating'],
        num_rows: 800
    })
})

In [12]:
filtered_dataset["train"][0]

{'prompt': "What is the equivalent resistance between two nodes separated by a knight's move of an infinite square grid of resistors of resistance R?",
 'chosen': [{'content': "What is the equivalent resistance between two nodes separated by a knight's move of an infinite square grid of resistors of resistance R?",
   'role': 'user'},
  {'content': "The equivalent resistance between two nodes separated by a knight's move of an infinite square grid of resistors of resistance R can be calculated using the following steps:\n\n1. Draw a circuit diagram of the square grid and the two nodes. A knight's move is a jump that moves two squares horizontally or vertically, and two squares diagonally. The two nodes are connected by a knight's move.\n\n2. Apply the series-parallel rule to the circuit. The series-parallel rule states that the equivalent resistance of a series circuit is the sum of the individual resistances, while the equivalent resistance of a parallel circuit is the reciprocal of t

In [13]:
from torch.utils.data import DataLoader


def collate_prompts(batch):
    return [item['prompt'] for item in batch]

batch_size = 16  # на P100 в kaggle
dataloader_train = DataLoader(filtered_dataset["train"], batch_size=batch_size, collate_fn=collate_prompts)
dataloader_val = DataLoader(filtered_dataset["test"], batch_size=batch_size, collate_fn=collate_prompts)

In [16]:
for batch in dataloader_train:
    prompts = batch
    for prompt in prompts:
        print(prompt)
    break

What is the equivalent resistance between two nodes separated by a knight's move of an infinite square grid of resistors of resistance R?
Lets play Dungeons and Dragons. I'm a halfling rogue and you're the DM.
You are a helpful teacher
Write the parable of the lost sheep in the style of Gordon Ramsay yelling at the shepherd.


### Training

In [17]:
from torch.utils.tensorboard import SummaryWriter
from torch.optim import AdamW
import torch.nn.functional as F
from tqdm import tqdm


num_epochs = 2
learning_rate = 5e-5
total_rewards = []

optimizer = AdamW(policy_model.parameters(), lr=learning_rate)
baseline = None

log_dir = "runs/REINFORCE_with_baseline_logs"
writer = SummaryWriter(log_dir)

In [18]:
def generate_response_and_logprob(prompt):
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)
    
    with torch.no_grad():
        output_ids = policy_model.generate(
            **inputs,
            max_length=512,
            do_sample=True
        )
    
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    outputs = policy_model(output_ids[:, :-1])
    logits = outputs.logits
    
    log_probs = F.log_softmax(logits, dim=-1)
    target_tokens = output_ids[:, 1:]
    token_log_probs = torch.gather(log_probs, dim=-1, index=target_tokens.unsqueeze(-1)).squeeze(-1)

    total_log_prob = token_log_probs.sum(dim=-1)
    return response, total_log_prob[0]

In [19]:
def compute_reward(prompt, response):
    input_text = prompt + "\n" + response
    inputs = tokenizer(
        input_text, 
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True).to(device)
    
    with torch.no_grad():
        outputs = reward_model(**inputs)
    
    logits = outputs.logits
    reward = torch.sigmoid(logits)[0, 0].item()  # Применяем сигмоиду, чтобы получить вероятность
    # reward = logits[0, 0].item()  # или можно использовать logits напрямую
    return reward

In [20]:
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    policy_model.train()
    epoch_loss = 0.0
    epoch_advantages = []
    num_batches = 0
    
    # Training
    for batch in tqdm(dataloader_train, desc="Training"):
        prompts = batch
        batch_loss = 0.0
        optimizer.zero_grad()
        
        for prompt in prompts:
            response, log_prob = generate_response_and_logprob(prompt)

            reward = compute_reward(prompt, response)
            total_rewards.append(reward)
            
            baseline = sum(total_rewards) / len(total_rewards)
            advantage = reward - baseline
            epoch_advantages.append(advantage)
            
            loss = -advantage * log_prob
            batch_loss += loss
        
        batch_loss = batch_loss / len(prompts)
        batch_loss.backward()
        optimizer.step()
        
        epoch_loss += batch_loss.item()
        num_batches += 1
    
    avg_epoch_loss = epoch_loss / num_batches
    avg_advantage = sum(epoch_advantages) / len(epoch_advantages) if epoch_advantages else 0.0
    
    print(f"Epoch {epoch+1} loss: {avg_epoch_loss:.4f}, Advantage mean: {avg_advantage:.4f}")
    writer.add_scalar("Train/Loss", avg_epoch_loss, epoch + 1)
    writer.add_scalar("Train/Advantage Mean", avg_advantage, epoch + 1)
    writer.add_scalar("Train/Baseline", baseline, epoch + 1)

    # Validation
    policy_model.eval()
    total_reward = 0.0
    num_val = 0

    with torch.no_grad():
        for prompt in tqdm(dataloader_val, desc="Validation"):
            response, _ = generate_response_and_logprob(prompt)
            reward = compute_reward(prompt, response)
            total_reward += reward
            num_val += 1
    avg_reward = total_reward / num_val if num_val > 0 else 0.0
    print(f"Validation average reward: {avg_reward:.4f}")

    writer.add_scalar("Validation/Average Reward", avg_reward, epoch + 1)

print("Training complete")
writer.close()

Epoch 1/2


Training:   0%|          | 2/799 [01:21<9:02:35, 40.85s/it]


KeyboardInterrupt: 