In [1]:
# If you are running this online (for example at Google Colab), 
# make sure you have the support files on the same folder
# Otherwise run this cell to download them

# NOTE: Downloading will take a while, be patient. You can refresh your folder from time to time to see when the files
# have been created.

import os, requests, zipfile, io 

files_url = "https://ideami.com/llm_align"

# Downloading proceeds if we detect that one of the key files to download is not present
if not os.path.exists("llm.py"):
    print("Downloading files using Python")
    # (connect, read) timeouts: a dead connection raises instead of hanging the cell forever
    response = requests.get(files_url, timeout=(10, 300))
    # Fail with a clear HTTP error instead of handing an error page to ZipFile
    response.raise_for_status()
    # The archive is held fully in memory and extracted into the current folder
    with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
        archive.extractall(".")
else:
    print("you seem to have already downloaded the files. If you wish to re-download them, delete the llm.py file")


you seem to have already downloaded the files. If you wish to re-download them, delete the llm.py file


In [2]:
# Import libraries
import os, sys
import math 
from tqdm import tqdm
from datetime import datetime
import ipdb 
from typing import List, Dict, Union, Any, Tuple

# Pytorch
import torch
import torch.nn as nn
from torch.nn import functional as F

# Import some Hugging Face Libraries
import transformers
from datasets import load_dataset, load_from_disk

# Allow TF32 for CUDA matrix multiplications and for cuDNN operations
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Release any GPU memory cached by PyTorch (e.g. from a previous run in this kernel)
torch.cuda.empty_cache()

# Optional for debugging, if you want to see the full tensor
torch.set_printoptions(threshold=10_000)

In [3]:
#Training parameters
batch_size = 4  # samples per batch for the train and validation DataLoaders
epochs = 3 # 3 is good, more overfits
lr = 6e-5  # base learning rate given to AdamW (scaled by the LambdaLR schedule)
lr_warmup_steps = 100  # number of warmup steps used by lr_lambda
context = 1024  # max token length when tokenizing; also the model's max_seq_len
alpha = 0.5  # weight of the odds-ratio term in the training loss
prompt_max_length = 512  # samples whose tokenized prompt is this long or longer are filtered out
compile = False  # if True the model is wrapped with torch.compile; NOTE(review): shadows the builtin compile()
dtype = torch.bfloat16  # dtype the model is cast to; also read by compute_logps
log_iter = 50  # print / log to wandb every log_iter iterations

# Hyperparameters
dropout = 0.  # NOTE(review): not referenced in the visible cells; the model takes config.attention_dropout
grad_clip = 1.0  # max gradient norm passed to clip_grad_norm_
weight_decay = 0.0  # weight decay given to AdamW

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device: You are using ", device)


Device: You are using  cuda


In [4]:
# Logging 
project_name = "aligntest2"
wandb_log = True  # set to False to skip every wandb call in this notebook
wandb_project = project_name
# ipdb.set_trace()
# The run name carries a timestamp so each execution of this cell gets a distinct name
wandb_run_name = f"aligntest2_run_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"

if wandb_log:
    # Imported lazily so wandb is only required when logging is enabled
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name)

wandb: Currently logged in as: mistigri-heriveau (mistigri-heriveau-universit-toulouse-capitole). Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [5]:
# All locations are resolved relative to the current working directory.
# os.path.join is used instead of hand-written separators: the previous
# '\data2\orpo_dataset' literal only formed a nested path on Windows (and relied
# on invalid escape sequences); on POSIX it produced a single oddly named entry.
path = os.getcwd()
dataset_path = os.path.join(path, 'data2', 'orpo_dataset')  # on-disk cache of the preprocessed dataset
dataset_name = 'mlabonne/orpo-dpo-mix-40k'  # Hugging Face dataset identifier
tokenizer_path = os.path.join(path, 'tokenizers', 'tok16384')
checkpoint_dir = os.path.join(path, 'models', '')  # trailing separator kept, as before

# Load the tokenizer stored under tokenizer_path
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path)

# Set the tokenizer parameters
# Chat template: each user / assistant message is rendered as a role tag line
# ('<|user|>' or '<|assistant|>') followed by the content and eos_token; when
# add_generation_prompt is set, a trailing '<|assistant|>' tag is appended after the last message.
tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>\n' }}\n{% endif %}\n{% endfor %}"

# Make the padding token equal to the end-of-sentence token (which has ID of 2 in our case)
tokenizer.pad_token = tokenizer.eos_token

if os.path.exists(dataset_path):
    # A preprocessed copy is already cached on disk: reuse it
    dataset = load_from_disk(dataset_path)
    print("Dataset loaded from disk")
else:
    print("Dataset not found, loading from Hugging Face")
    dataset = load_dataset(dataset_name, split='all')
    # Optional: Filter out the toxic-dpo-v0.2 dataset
    dataset = dataset.filter(lambda x: x['source'] != "toxic-dpo-v0.2")

    def filter_dataset(examples):
        """Keep a sample only if its tokenized prompt is shorter than prompt_max_length.

        The prompt is the 'chosen' conversation without its last message,
        rendered with the generation prompt appended.
        """
        prompt_ids = tokenizer.apply_chat_template(
            examples['chosen'][:-1],
            tokenize=True,
            add_generation_prompt=True,
            return_tensors='pt',
        )
        return prompt_ids.size(-1) < prompt_max_length

    def preprocess_dataset(example: Union[List, Dict]):
        """Tokenize a batch into prompt, chosen ('positive') and rejected ('negative') inputs.

        Returns the tokenized prompts, extended with positive_* / negative_*
        input ids and attention masks. Every sequence is padded or truncated
        to `context` tokens.
        """
        # ipdb.set_trace()
        def encode(texts):
            # Shared tokenizer settings for the three text variants
            return tokenizer(texts, max_length=context, padding="max_length", truncation=True, return_tensors="pt")

        # Render the conversations as text using the chat template
        prompt_texts = [
            tokenizer.apply_chat_template(conversation[:-1], tokenize=False, add_generation_prompt=True)
            for conversation in example['chosen']
        ]
        chosen_texts = [
            tokenizer.apply_chat_template(conversation, tokenize=False)
            for conversation in example['chosen']
        ]
        rejected_texts = [
            tokenizer.apply_chat_template(conversation, tokenize=False)
            for conversation in example['rejected']
        ]

        inputs = encode(prompt_texts)
        positive = encode(chosen_texts)
        negative = encode(rejected_texts)

        inputs['positive_input_ids'] = positive['input_ids']
        inputs['positive_attention_mask'] = positive['attention_mask']

        inputs['negative_input_ids'] = negative['input_ids']
        inputs['negative_attention_mask'] = negative['attention_mask']

        return inputs

    # Drop over-long prompts first, then tokenize everything in batches
    dataset = dataset.filter(filter_dataset)
    dataset = dataset.map(preprocess_dataset, batched = True, num_proc=1, remove_columns=dataset.column_names)

    # Cache the result so the next run takes the load_from_disk branch
    dataset.save_to_disk(dataset_path)
    
    

Dataset not found, loading from Hugging Face


Saving the dataset (0/4 shards):   0%|          | 0/38541 [00:00<?, ? examples/s]

In [6]:
# Sanity check: decode the first sample's chosen ('positive') token ids back to text
tokenizer.decode(dataset[0]['positive_input_ids'])

'<|user|>\nHow many colors are traditionally recognized in a visible spectrum or optical rainbow?</s> \n<|assistant|>\nTraditionally, a visible spectrum or optical rainbow is said to consist of seven colors. The order of these colors is typically remembered using the acronym ROYGBIV - Red, Orange, Yellow, Green, Blue, Indigo, and Violet. However, it is important to note that the division of the spectrum into these seven constituent colors is largely a human construct. In reality, a rainbow encompasses a continuous spectrum of colors which blend seamlessly into one another, from red light, which has longer wavelengths, to violet light, which has shorter wavelengths. The specific selection of seven colors originates from the work of Sir Isaac Newton, who chose to divide the spectrum into seven colors to correlate with the seven notes in a western major scale of music.</s> \n<|user|>\nExplain the scientific reasoning behind the continuous spectrum of colors in a rainbow.</s> \n<|assistant

In [7]:
# Shuffle, then hold out 5% of the samples for validation.
# Both steps are seeded: train_test_split shuffles again by default, so without
# its own seed the train/validation split would differ on every run.
dataset = dataset.shuffle(seed=42).train_test_split(test_size=0.05, seed=42)
train_data = dataset['train']
val_data = dataset['test']

In [8]:
# Collator that assembles samples into batches (causal LM mode, no masked-LM)
data_collector = transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Both loaders share the same settings; only the underlying split differs
loader_options = dict(batch_size=batch_size, collate_fn=data_collector, shuffle=False, num_workers=0)
train_loader = torch.utils.data.DataLoader(train_data, **loader_options)
val_loader = torch.utils.data.DataLoader(val_data, **loader_options)

In [9]:
# Pull a single batch from the training loader to check that collation works
it = iter(train_loader)
batch = next(it)
# print (tokenizer.decode(batch['positive_input_ids'][0]))

In [10]:
from llm import Llama, ModelArgs

# Load the base checkpoint: a state dict that also carries the model config under "config".
# - map_location='cpu' keeps the load independent of the device the file was saved from
#   (works on the CPU fallback, avoids a second GPU copy); the model is moved to
#   `device` explicitly below.
# - weights_only=False is stated explicitly because the file contains a pickled config
#   object, not only tensors. This unpickles arbitrary objects, so only load
#   checkpoints from a source you trust.
checkpoint = torch.load(
    os.path.join(checkpoint_dir, 'base_model.pt'),
    map_location='cpu',
    weights_only=False,
)
# Remove the config entry so that only parameter tensors remain for load_state_dict
config = checkpoint.pop("config")

# Rebuild the architecture from the stored config; the sequence length comes from `context`
model_args = ModelArgs(
    dim=config.hidden_size, 
    n_layers=config.num_hidden_layers, 
    n_heads=config.num_attention_heads, 
    n_kv_heads=config.num_key_value_heads, 
    vocab_size=config.vocab_size, 
    norm_eps=config.rms_norm_eps, 
    rope_theta=config.rope_theta,
    max_seq_len=context, 
    dropout=config.attention_dropout, 
    hidden_dim=config.intermediate_size,
    attention_bias=config.attention_bias,
    mlp_bias=config.mlp_bias
)

model = Llama(model_args)
model.load_state_dict(checkpoint)
# Cast to the training dtype and move to the selected device
model = model.to(dtype=dtype, device=device)
model.train()

if compile:
    print('[INFO] Compiling model')
    model = torch.compile(model)

# Report the parameter count in millions
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')


138.431232 M parameters


In [11]:
# Optimizer

# `device` is a torch.device object, so test its `.type`. Comparing the object itself
# with the string 'cuda' does not evaluate to True, which silently left the fused
# AdamW kernel disabled even when training on the GPU.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    betas=(0.9, 0.98),
    eps=1e-8,
    fused=device.type == 'cuda',
    weight_decay=weight_decay,
)

# One optimizer/scheduler step is taken per batch, for every epoch
num_training_steps = len(train_loader) * epochs
print(f"num_training_steps: {num_training_steps}")

# Scheduler for lr: linear warmup for the first lr_warmup_steps steps, then cosine decay to zero
def lr_lambda(step, warmup_steps=None, total_steps=None):
    """Return the learning-rate multiplier for a given optimizer step.

    Args:
        step: index of the current step (LambdaLR passes only this argument).
        warmup_steps: length of the linear warmup; defaults to the notebook's
            `lr_warmup_steps`.
        total_steps: total number of training steps; defaults to the notebook's
            `num_training_steps`.

    Returns:
        A float in [0, 1]: it rises linearly from 0 to 1 during warmup, then
        follows half a cosine period from 1 down to 0 at `total_steps`.
    """
    warmup = lr_warmup_steps if warmup_steps is None else warmup_steps
    total = num_training_steps if total_steps is None else total_steps

    if step < warmup:
        return float(step) / float(max(1, warmup))

    # Fraction of the decay phase completed, clamped so the rate stays at 0 past the end
    progress = min(1.0, float(step - warmup) / float(max(1, total - warmup)))
    # 0.5 * (1 + cos(pi * progress)) goes 1 -> 0 over the whole decay phase.
    # The bare cos(pi * progress) used before turns negative at progress = 0.5,
    # so the clamp at 0 froze the learning rate at zero for the second half of training.
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    

# Multiplies the optimizer's base lr by lr_lambda(step); last_epoch=-1 starts the schedule from the beginning
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)

num_training_steps: 27462


In [13]:
def compute_logps(prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
    """Average log-probability of the response tokens for each sequence.

    Args:
        prompt_attention_mask: attention mask of the prompt-only encoding.
        chosen_inputs: token ids of the full (prompt + response) sequence.
        chosen_attention_mask: attention mask of the full sequence.
        logits: model outputs for the full sequence.
        (assumes the masks and ids are (batch, seq) integer tensors and logits is
        (batch, seq, vocab), as the slicing and gather on dim 2 suggest - TODO confirm)

    Returns:
        One value per sequence, in the notebook-level `dtype`: the sum of the
        selected token log-probabilities divided by the number of selected positions.
    """
    # Positions where the shifted full-sequence mask is set but the shifted prompt mask is not
    response_mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]

    # Log-probabilities for every position except the last
    log_probs = logits[:, :-1, :].log_softmax(-1)

    # Next-token ids to look up; positions outside the mask are mapped to index 0
    target_ids = (response_mask * chosen_inputs[:, 1:]).unsqueeze(2)
    token_logps = torch.gather(log_probs, dim=2, index=target_ids).squeeze(2)

    # Zero out the unselected positions, then average over the selected ones
    summed = torch.mul(token_logps, response_mask.to(dtype)).sum(dim=1).to(dtype)
    return summed / response_mask.sum(dim=1).to(dtype)

In [13]:
# ORPO training loop: one optimizer and scheduler step per batch, one checkpoint per epoch.
try:
    for e in range(epochs):
        for i, batch in tqdm(enumerate(train_loader), total=len(train_loader), dynamic_ncols=True):
            optimizer.zero_grad(set_to_none=True)

            # Move the tensors used below to the training device
            batch['positive_input_ids'] = batch['positive_input_ids'].to(device)
            batch['positive_attention_mask'] = batch['positive_attention_mask'].to(device)
            batch['negative_input_ids'] = batch['negative_input_ids'].to(device)
            batch['negative_attention_mask'] = batch['negative_attention_mask'].to(device)
            batch['attention_mask'] = batch['attention_mask'].to(device)

            neg_labels = batch['negative_input_ids'].clone()
            pos_labels = batch['positive_input_ids'].clone()

            # Build the labels for the chosen sequence.
            # `mask` is 1 where both the prompt mask and the chosen mask are set;
            # those positions are zeroed in the labels.
            mask = batch['attention_mask'] * batch['positive_attention_mask']
            pos_labels = pos_labels * mask.logical_not()

            # Zeroed positions become the pad id, and every pad/eos id is then replaced
            # by -100 (pad_token was set to eos_token when the tokenizer was configured).
            # NOTE(review): -100 is presumably the ignore index of the loss inside
            # Llama.forward - confirm in llm.py.
            pos_labels[pos_labels == 0] = tokenizer.pad_token_id
            pos_labels[pos_labels == tokenizer.eos_token_id] = -100
            neg_labels[neg_labels == tokenizer.eos_token_id] = -100

            # Forward passes; only the loss of the chosen sequence is used
            outputs_pos, loss_pos = model(batch['positive_input_ids'], pos_labels)
            outputs_neg, _ = model(batch['negative_input_ids'], neg_labels)

            # Calculate per token log probabilities, essential to calculate the ORPO LOG ODDS RATIO
            pos_prob = compute_logps(
                batch['attention_mask'],
                batch['positive_input_ids'],
                batch['positive_attention_mask'],
                outputs_pos
            )
            neg_prob = compute_logps(
                batch['attention_mask'],
                batch['negative_input_ids'],
                batch['negative_attention_mask'],
                outputs_neg
            )

            # Calculate the ORPO log odds ratio: log(odds(chosen) / odds(rejected)).
            # log1p(-exp(x)) is the numerically stable form of log(1 - exp(x)).
            log_odds = (pos_prob - neg_prob) - (torch.log1p(-torch.exp(pos_prob)) - torch.log1p(-torch.exp(neg_prob)))
            # logsigmoid is the stable form of log(sigmoid(x)); it avoids log(0) = -inf
            # when the sigmoid underflows for strongly negative log odds.
            ratio = F.logsigmoid(log_odds)

            # Calculate the loss: loss of the chosen sequence minus the weighted odds-ratio term
            loss = torch.mean(loss_pos - (alpha * ratio).mean()).to(dtype)

            # Logging
            if i % log_iter == 0:
                print(f"Epoch: [{e}/{epochs}] Iteration: [{i}/{len(train_loader)}] Loss: {loss.item():.3f} Odds Ratio: {log_odds.mean().item():.3f}")
                if wandb_log:
                    wandb.log({"loss": loss.item(),
                               "odds_ratio": log_odds.mean().item(),
                               "lr" : scheduler.get_last_lr()[0],
                               "epoch": e,
                               "iteration": i})

            # Stop on a NaN loss on EVERY iteration and before the backward pass.
            # This check used to sit inside the logging branch, so it ran only every
            # `log_iter` steps and a NaN could be back-propagated into the weights
            # in between.
            if torch.isnan(loss):
                print("Loss is NaN, breaking")
                if wandb_log:
                    wandb.finish()
                torch.cuda.empty_cache()
                sys.exit()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            scheduler.step()

        # Save the model after each epoch, together with its config, in the same
        # layout as the base checkpoint
        sd = model.state_dict()
        sd['config'] = config
        torch.save(sd, os.path.join(checkpoint_dir, f'base_model_{e+1}.pt'))

except KeyboardInterrupt:
    print("Training interrupted")
finally:
    # Runs on normal completion, interruption and errors alike
    torch.cuda.empty_cache()
    print("Training finished, GPU memory cleaned")

  0%|          | 0/9154 [00:00<?, ?it/s]

Epoch: [0/3] Iteration: [0/9154] Loss: 2.922 Odds Ratio: 0.334


  0%|          | 4/9154 [00:53<33:55:11, 13.35s/it]


Training interrupted
Training finished, GPU memory cleaned
