In [1]:
import gc
import os
import time
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
from torch.autograd import profiler as profiler
from torch.nn import functional as F

from tokenizer import NaiveTokenizer
from transformer import DecoderTrans

%load_ext autoreload
%autoreload 2

# Print tensors on very wide lines so that rows are not wrapped in the cell output
torch.set_printoptions(linewidth=10000)
# Select the GPUs visible to this process; this is set before the first CUDA call of the notebook
os.environ['CUDA_VISIBLE_DEVICES']="4,5,6,7"

In [2]:
# Read the input dataset (the whole text file is loaded in memory)
with open('input.txt') as f:
    data = f.read()

# Build the tokenizer from the raw text; its vocabulary size is reused by the model configuration
naive_tokenizer = NaiveTokenizer(data)
vocab_size = naive_tokenizer.vocab_size
# Generate the encoded dataset (presumably a sequence of integer token ids -- confirm in tokenizer.py)
dataset = naive_tokenizer.encode(data)

# Build the train/test datasets: the first 80% of the encoded data is used for training, the rest for testing
train_data_sz = int(len(dataset)*0.8)
train_data = dataset[:train_data_sz]
test_data = dataset[train_data_sz:]

In [4]:
## Meta parameters 
seed = 2580 # Fixed random seed
# Seed both random number generators used by the notebook: torch, and NumPy
# (get_batch draws its window offsets with np.random.randint).
torch.manual_seed(seed)
np.random.seed(seed)

#Optimizer parameters
learning_rate = 3e-4 # Learning rate for the optimizer
nb_iter = 1000 # Number of iterations for the optimizer
batch_size = 64 # Number of blocks in a batch

# Run on the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device is ', device)

#Transformer  parameters
@dataclass
class TransformerParams:
    """Hyper-parameters handed to ``DecoderTrans``.

    ``individual_head_size`` is derived from the *instance* values as
    ``n_embedding // n_heads`` when it is left to ``None``, so overriding
    ``n_embedding`` or ``n_heads`` keeps the head size consistent. Pass it
    explicitly to force a specific value.
    """
    vocab_size: int
    n_embedding: int = 384 # Embedding size
    n_decoder_blocks: int = 3 # Number of decoder blocks in the transformer
    n_mha: int = 1 # Number of multi-head attention layers in each decoder block
    n_heads: int = 2 # Number of heads per multi-head attention layer
    individual_head_size: Optional[int] = None # Individual head size (None: n_embedding // n_heads)
    block_size: int = 64 # Number of tokens in a block (aka context)

    def __post_init__(self):
        # A class-level default expression would be evaluated only once, with the
        # default n_embedding/n_heads; compute it per instance instead.
        if self.individual_head_size is None:
            self.individual_head_size = self.n_embedding // self.n_heads
    
# Model configuration: vocabulary size comes from the tokenizer, every other field keeps its default
transformer_params = TransformerParams(vocab_size=vocab_size)

Device is  cuda


In [5]:
# Helper function to get a batch of the split
def get_batch(split):
    """Draw `batch_size` random windows of `block_size` tokens from one split.

    `split` equal to 'train' selects the training data; any other value selects the test data.
    Returns a pair (inputs, targets) of stacked tensors moved to `device`, where each
    target window is the matching input window shifted by one position.
    """
    if split == 'train':
        source = train_data
    else:
        source = test_data
    window = transformer_params.block_size

    # One random start offset per element of the batch
    starts = np.random.randint(0, len(source) - window, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        stop = start + window
        inputs.append(torch.tensor(source[start:stop]))
        targets.append(torch.tensor(source[start + 1:stop + 1]))

    stacked_inputs = torch.stack(inputs, dim=0).to(device)
    stacked_targets = torch.stack(targets, dim=0).to(device)
    return stacked_inputs, stacked_targets


In [6]:
transformer = DecoderTrans(transformer_params)
transformer = transformer.to(device)


In [None]:
## Training loop for the base transformer

# AdamW over all the model parameters, with the learning rate of the configuration cell
optimizer = torch.optim.AdamW(transformer.parameters(), lr=learning_rate) 
transformer.train()

# NOTE(review): this overrides the nb_iter = 1000 defined in the configuration cell
nb_iter = 5000
for i_iter in range(nb_iter):
    # Sample a random training batch (inputs and shifted targets)
    x, y = get_batch('train')
    
    # The 'forward' mode of the model returns the predictions and the loss
    y_pred, loss  = transformer('forward', x = x, y = y)

    # Clear the gradients left by the previous iteration before the backward pass
    optimizer.zero_grad()
    
    # Handle the case where the loss is not a scalar (due to DataParallel, e.g.)
    if loss.ndim > 0:
        loss = loss.mean()
    loss.backward()
    
    optimizer.step()

    # Print the loss every 10 iterations
    if i_iter % 10 == 0:
        print(f"Iter {i_iter}, Loss {loss.item()}")

        

In [8]:
## Rule based rewards for the generated strings
def reward_short_sentences(tested_string):
    """Count the lines of `tested_string` whose length is strictly between 10 and 20 characters."""
    nb_short_lines = 0
    for line in str.split(tested_string, '\n'):
        if 10 < len(line) < 20:
            nb_short_lines += 1
    return nb_short_lines

def reward_shouting(tested_string):
    """Return the number of upper-case characters found in `tested_string`."""
    nb_upper = 0
    for char in tested_string:
        # bool is an int: an upper-case character adds 1, anything else adds 0
        nb_upper += char.isupper()
    return nb_upper

# Full reward function, let us only use the shouting reward for now
def full_reward(tested_string):
    """Score a generated string; for now the score is only the shouting (upper-case) reward."""
    shouting_score = reward_shouting(tested_string)
    return shouting_score


Let us compute the "Old policy" $\pi_{\theta_{old}}(o_t\vert q, o_{<t})$

In [9]:
## Advantage computation : Process supervision and Outcome supervision as per DeepseekMath paper
# In practice, we will use OS (the GRPO loss of the DeepSeek-R1 paper does not seem to include the inner sum related to process supervision?)

def compute_process_supervision_advantage(generated_tokens, base_context, reward_fn):
    """Process-supervision advantages: one value per sample and per generated token.

    The reward function is evaluated on the decoded prefix ending at each generated token.
    All those rewards are normalised with their global mean and standard deviation, and the
    advantage of a position is the sum of the normalised rewards from that position
    (included) to the end of the sequence.

    Returns a (B, T) tensor, T being the number of tokens generated after `base_context`.
    """
    prompt_len = base_context.shape[1]
    n_new_tokens = generated_tokens.shape[1] - prompt_len

    # One row of rewards per prefix length, one column per sample
    per_step_rewards = []
    for prefix_end in range(prompt_len + 1, prompt_len + n_new_tokens + 1):
        prefixes = generated_tokens[:, :prefix_end]
        decoded = [naive_tokenizer.decode(row.tolist()) for row in prefixes]
        per_step_rewards.append([reward_fn(text) for text in decoded])
    reward_grid = np.transpose(per_step_rewards)  # (B,T)

    # Statistics are collected on the full flattened reward array (see the DeepSeekMath paper)
    center = np.mean(reward_grid)
    spread = np.std(reward_grid)
    normalized = (torch.from_numpy(reward_grid) - center) / (spread + 1e-4)

    # Suffix sum along the token axis: flip, cumulative sum, flip back
    reversed_rewards = normalized.flip(dims=[1])
    return torch.cumsum(reversed_rewards, dim=1).flip(dims=[1])

def compute_outcome_supervision_advantage(generated_tokens, base_context, reward_fn):
    """Outcome-supervision advantages: one value per generated sample.

    Each sample is decoded and scored with `reward_fn`; the advantage is that reward
    normalised by the mean and standard deviation of the group (1e-4 is added to the
    standard deviation to avoid a division by zero). `base_context` is not used.
    """
    # Decode and score every sample of the group, one after the other
    decoded_outputs = (naive_tokenizer.decode(sample.tolist()) for sample in generated_tokens)
    group_rewards = np.array([reward_fn(text) for text in decoded_outputs])

    # Group statistics used for the normalisation
    center = np.mean(group_rewards)
    spread = np.std(group_rewards)

    # As per the DeepSeekMath paper, the advantage is directly the normalised reward
    return (torch.from_numpy(group_rewards) - center) / (spread + 1e-4)
    


In [10]:
def loss_grpo(generated_tokens, old_policy_logprobs, ref_policy_logprobs, advantages, model, query):
    """Sequence-level GRPO objective (to be maximised) for one group of samples.

    Args:
        generated_tokens: (B, T) token ids; every token after the first column is scored.
        old_policy_logprobs: per-token log-probabilities under the old policy, summed over dim 1.
        ref_policy_logprobs: per-token log-probabilities under the reference policy, summed over dim 1.
        advantages: 1-D tensor holding one advantage per sample.
        model: callable such that model('get_token_logprob', context=..., token=...) returns the
            log-probability of `token` given `context` (presumably of shape (B, 1) -- confirm in transformer.py).
        query: unused; kept so that existing callers keep working.

    Returns:
        Scalar tensor: mean over the group of min(ratio * A, clip(ratio) * A) - beta * KL estimate.

    Raises:
        ValueError: if `generated_tokens` has fewer than two columns (no token to score).
    """
    eps = 0.2    # Clipping range of the policy ratio
    beta = 0.04  # Weight of the KL penalty

    seq_len = generated_tokens.shape[1]
    if seq_len < 2:
        raise ValueError("loss_grpo needs at least two tokens per sequence: one of context and one to score")

    #Detach the ref/old policy logprobs to avoid backpropagating through them
    ref_policy_logprobs = ref_policy_logprobs.detach()
    old_policy_logprobs = old_policy_logprobs.detach()

    #Compute the new policy logprobs: token t is scored given the tokens before it.
    # The per-token results are gathered in a list and concatenated once.
    # NOTE(review): every token after the first column is scored, so the sums below are only
    # comparable with the old/ref sums if those cover the same tokens -- confirm for prompts
    # longer than one token.
    per_token_logprobs = [
        model('get_token_logprob', context=generated_tokens[:, :t], token=generated_tokens[:, t].unsqueeze(1))
        for t in range(1, seq_len)
    ]
    new_policy_logprobs = torch.cat(per_token_logprobs, dim=1)

    # Sequence log-probabilities: sum of the per-token log-probabilities
    pi_ref_log = torch.sum(ref_policy_logprobs, dim=1)
    pi_old_log = torch.sum(old_policy_logprobs, dim=1)
    pi_new_log = torch.sum(new_policy_logprobs, dim=1)

    policy_ratio = torch.exp(pi_new_log - pi_old_log)
    clipped_policy_ratio = torch.clamp(policy_ratio, 1-eps, 1+eps)

    # Element-wise products; the explicit 'b,b->b' signature rejects operands that are not 1-D
    weighted_policy_ratio = torch.einsum('b,b->b', policy_ratio, advantages)
    weighted_clipped_policy_ratio = torch.einsum('b,b->b', clipped_policy_ratio, advantages)

    def KL_divergence(pi_new_log, pi_ref_log):
        # Estimator r - log(r) - 1 with r = pi_ref / pi_new
        return torch.exp(pi_ref_log - pi_new_log) - (pi_ref_log-pi_new_log) - 1

    policy_loss = torch.min(weighted_policy_ratio, weighted_clipped_policy_ratio)
    kl_loss = KL_divergence(pi_new_log, pi_ref_log)

    loss = policy_loss - beta * kl_loss

    return torch.mean(loss)



In [11]:

def grpo_optim(model, nb_outer_grpo_updates=100, nb_inner_grpo_updates=10,
               grpo_n_shots=10, grpo_n_tokens=128, lr=1e-6):
    """Fine-tune `model` in place with GRPO, using `full_reward` as the reward.

    Args:
        model: the transformer to optimise (optionally wrapped in torch.nn.DataParallel).
        nb_outer_grpo_updates: number of sampling rounds (a new group is generated at each round).
        nb_inner_grpo_updates: number of gradient steps taken on each generated group.
        grpo_n_shots: number of samples in a group.
        grpo_n_tokens: number of tokens generated per sample.
        lr: learning rate of the AdamW optimizer.

    Returns:
        The same `model` object, after optimisation.
    """
    # Make an adamW optimizer; maximize=True because loss_grpo returns an objective to maximise
    grpo_optimizer = torch.optim.AdamW(model.parameters(), lr=lr, maximize=True)

    # Set the model in training mode and disable dropout
    model.train()
    core_model = model.module if isinstance(model, torch.nn.DataParallel) else model
    core_model.disable_dropout()

    # Every sample of the group starts from the same newline prompt
    base_context = torch.tensor([naive_tokenizer.encode('\n') for _ in range(grpo_n_shots)], dtype=torch.long).to(device)

    for i_outer in range(nb_outer_grpo_updates):
        # Given query q, generate a batch of outputs o.
        # No autograd graph is needed here: the log-probabilities are detached in loss_grpo.
        with torch.no_grad():
            generated_tokens, generated_logprobs = model('generate', context=base_context, nb_tokens=grpo_n_tokens)

        # NOTE(review): the reference policy is the sampling policy of the current round,
        # not a frozen copy of the initial model -- confirm this is intended.
        old_policy_logprobs = generated_logprobs
        ref_policy_logprobs = generated_logprobs
        # Precompute the advantages and move them to the configured device (works on CPU too)
        advantages = compute_outcome_supervision_advantage(generated_tokens, base_context, full_reward).to(device)

        for i_inner in range(nb_inner_grpo_updates):
            # Compute the loss
            loss = loss_grpo(generated_tokens, old_policy_logprobs, ref_policy_logprobs, advantages, model, base_context)
            # Backward pass
            grpo_optimizer.zero_grad()
            loss.backward()
            grpo_optimizer.step()

            # Print the loss
            print(f"Iter  {i_outer}-{i_inner}, Loss {loss.item()}")

    return model

In [None]:
#Compute the reward pre grpo optimization
# 10 generations of a single sample each, all started from a newline prompt.
# No gradient is needed for evaluation.
with torch.no_grad():
    reward = []   
    for i in range(10):
        print(f'Generation {i}')
        base_context = torch.tensor([naive_tokenizer.encode('\n') for _ in range(1)], dtype=torch.long).to(device)
        generated_tokens, generated_logprobs =  transformer('generate',context=base_context, nb_tokens=1024)
        # One reward per generated sample of this round
        reward.append([full_reward(naive_tokenizer.decode(g.tolist())) for g in generated_tokens])

In [None]:
import copy 
# Build a second model with the same configuration and initialise it with a copy of the
# weights of the trained base transformer; GRPO is then run on this second model.
grpo_trans = DecoderTrans(transformer_params).to(device)
grpo_trans.load_state_dict(copy.deepcopy(transformer.state_dict()))

In [None]:
# Run the GRPO optimisation on the copy of the base transformer
grpo_optim(grpo_trans)


In [None]:
# Release the cached GPU memory held by torch, then run the Python garbage collector
torch.cuda.empty_cache()
gc.collect()

In [None]:
#Compute the reward post grpo optimization
# 10 generations of 10 samples each, all started from a newline prompt
# (NOTE(review): the pre-GRPO evaluation uses a single sample per generation).
with torch.no_grad():
    post_reward = []   
    for i in range(10):
        print(f'Generation {i}')
        base_context = torch.tensor([naive_tokenizer.encode('\n') for _ in range(10)], dtype=torch.long).to(device)
        generated_tokens, generated_logprobs =  grpo_trans('generate',context=base_context, nb_tokens=1024)
        # One reward per generated sample of this round
        post_reward.append([full_reward(naive_tokenizer.decode(g.tolist())) for g in generated_tokens])


In [None]:
# Compare the mean reward over all the samples collected before and after the GRPO optimisation
print(f'Average reward for the base transformer : {np.mean(np.array(reward))}')
print(f'Average reward for the grpo optimized transformer: {np.mean(np.array(post_reward))}')

In [None]:
# Generate one sample with each model from the same newline prompt and score both with full_reward
base_context = torch.tensor([naive_tokenizer.encode('\n')], dtype=torch.long).to(device)
grpo_generation, _ =  grpo_trans('generate',context=base_context, nb_tokens=1024)
grpo_reward = full_reward(naive_tokenizer.decode(grpo_generation[0].tolist()))
trans_generation, _ = transformer('generate', context=base_context, nb_tokens=1024)
trans_reward = full_reward(naive_tokenizer.decode(trans_generation[0].tolist()))

print(f'Rewards : GRPO {grpo_reward}, Base Transformer {trans_reward}')

In [None]:
# Print the decoded text of both generations, with their reward, for a visual comparison
print(f'Base transformer ({trans_reward} caps for 1024 tokens):')
print(naive_tokenizer.decode(trans_generation[0].tolist()))

print('----------------------------------------------------------------')
print('----------------------------------------------------------------')
print(f'GRPO updated transformer: ({grpo_reward} caps for 1024 tokens):')
print(naive_tokenizer.decode(grpo_generation[0].tolist()))


In [20]:
# Save the weights (state dicts) of both models in the current working directory
torch.save(transformer.state_dict(), 'base_transformer.pth')
torch.save(grpo_trans.state_dict(), 'grpo_transformer.pth')

In [None]:

# Reload both models from the saved state dicts.
# map_location=device lets a checkpoint saved from a GPU be loaded when only the CPU is available.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
tload = DecoderTrans(transformer_params)
tload.to(device)
tload.load_state_dict(torch.load('base_transformer.pth', map_location=device))

grpo_trans = DecoderTrans(transformer_params)
grpo_trans.to(device)
grpo_trans.load_state_dict(torch.load('grpo_transformer.pth', map_location=device))
grpo_trans.train()

print('Loaded the saved transformers')


In [None]:
# Generate one sample from a newline prompt with grpo_trans and print the decoded text
model = grpo_trans
base_context = torch.tensor([naive_tokenizer.encode('\n') for _ in range(1)], dtype=torch.long).to(device)
generated_tokens, generated_logprobs =  model('generate',context=base_context, nb_tokens=1024)
print(naive_tokenizer.decode(generated_tokens[0].cpu().numpy()))