# DPO Training for Artist Style Lyrics Generation

This notebook implements Direct Preference Optimization (DPO) to fine-tune the rap lyrics generator to better match specific artist styles.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset
import numpy as np
from tqdm import tqdm
import os

In [None]:
# Restore the previously fine-tuned checkpoint from disk. The weights and the
# tokenizer are read from the same directory so their vocabularies agree.
model_path = './trained_model'
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [None]:
class RapStyleDataset(Dataset):
    """Preference pairs of lyrics, tokenized to fixed-length tensors.

    Each record of the JSON file must provide a ``'preferred'`` and a
    ``'non_preferred'`` text field. Both texts are truncated/padded to
    ``max_length`` tokens so that samples can be stacked by a ``DataLoader``.

    Args:
        dataset_path: Path (or paths) forwarded to
            ``load_dataset('json', data_files=...)``.
        tokenizer: Callable tokenizer returning ``'input_ids'`` and
            ``'attention_mask'`` tensors when called with
            ``return_tensors='pt'``.
        max_length: Fixed token length of every returned sequence.
    """

    def __init__(self, dataset_path, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # padding='max_length' fails for tokenizers that ship without a pad
        # token (common for GPT-2 style causal LMs). Reusing EOS is safe here
        # because padded positions are excluded through the attention mask.
        if getattr(tokenizer, 'pad_token', None) is None:
            eos_token = getattr(tokenizer, 'eos_token', None)
            if eos_token is not None:
                tokenizer.pad_token = eos_token

        # Load the dataset
        self.dataset = load_dataset('json', data_files=dataset_path)['train']

    def __len__(self):
        """Return the number of preference pairs."""
        return len(self.dataset)

    def _encode(self, text):
        """Tokenize one text and return ``(input_ids, attention_mask)`` as 1-D tensors."""
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt',
        )
        # squeeze(0) drops only the batch axis; a bare squeeze() would also
        # drop the sequence axis when max_length == 1 and break batching.
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        """Return the tokenized preferred / non-preferred pair at ``idx``.

        Returns:
            A dict with the keys ``preferred_input_ids``,
            ``preferred_attention_mask``, ``non_preferred_input_ids`` and
            ``non_preferred_attention_mask``, each a 1-D tensor of length
            ``max_length``.
        """
        item = self.dataset[idx]

        preferred_ids, preferred_mask = self._encode(item['preferred'])
        non_preferred_ids, non_preferred_mask = self._encode(item['non_preferred'])

        return {
            'preferred_input_ids': preferred_ids,
            'preferred_attention_mask': preferred_mask,
            'non_preferred_input_ids': non_preferred_ids,
            'non_preferred_attention_mask': non_preferred_mask,
        }

In [None]:
class DPOTrainer:
    """Minimal Direct Preference Optimization trainer for a causal LM.

    The loss rewards the policy for assigning a higher (length-normalized)
    log-probability to the preferred completion than to the non-preferred
    one: ``-logsigmoid(beta * (score_preferred - score_non_preferred))``.

    Args:
        model: Policy model being optimized. Calling it with ``input_ids`` and
            ``attention_mask`` must return an object with a ``.logits`` field.
        tokenizer: Tokenizer belonging to ``model`` (stored, not used here).
        beta: Scale applied to the preference margin before the sigmoid.
        ref_model: Optional frozen reference model with the same call
            interface as ``model``. When given, its log-probabilities are
            subtracted from the policy's, which is the regularizer of the
            DPO objective. When ``None`` the loss is reference-free. The
            caller is responsible for placing it on the training device.
    """

    def __init__(self, model, tokenizer, beta=0.1, ref_model=None):
        self.model = model
        self.tokenizer = tokenizer
        self.beta = beta
        self.ref_model = ref_model

    @staticmethod
    def _sequence_log_probs(logits, input_ids, mask):
        """Return the mean log-probability of the observed tokens, one value per sequence.

        ``logits`` is expected as (batch, seq_len, vocab); ``input_ids`` and
        ``mask`` as (batch, seq_len).
        """
        # In a causal LM the logits at position t score the token at t + 1,
        # so drop the last logit and the first token to align them.
        token_log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
        targets = input_ids[:, 1:].unsqueeze(-1)
        picked = token_log_probs.gather(-1, targets).squeeze(-1)
        target_mask = mask[:, 1:].to(picked.dtype)
        # clamp avoids 0/0 -> NaN for rows that contain no scored tokens.
        return (picked * target_mask).sum(dim=1) / target_mask.sum(dim=1).clamp(min=1)

    @staticmethod
    def _masked_mean_log_probs(logits, mask):
        """Return the mask-weighted mean of the full log-softmax, shape (batch, vocab).

        Legacy scoring used only when no token ids are supplied; it does not
        look at which tokens the completion actually contains.
        """
        log_probs = F.log_softmax(logits, dim=-1)
        weights = mask.unsqueeze(-1).to(log_probs.dtype)
        return (log_probs * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)

    def compute_loss(self, preferred_logits, non_preferred_logits, preferred_mask, non_preferred_mask,
                     preferred_input_ids=None, non_preferred_input_ids=None,
                     ref_preferred_logits=None, ref_non_preferred_logits=None):
        """Compute the preference loss for one batch.

        Args:
            preferred_logits: Policy logits for the preferred completions.
            non_preferred_logits: Policy logits for the non-preferred completions.
            preferred_mask: Attention mask (1 = real token) of the preferred batch.
            non_preferred_mask: Attention mask of the non-preferred batch.
            preferred_input_ids: Token ids of the preferred completions.
            non_preferred_input_ids: Token ids of the non-preferred completions.
            ref_preferred_logits: Optional reference-model logits for the
                preferred completions.
            ref_non_preferred_logits: Optional reference-model logits for the
                non-preferred completions.

        Returns:
            Scalar tensor with the mean loss over the batch.

        When either ``*_input_ids`` argument is omitted the method falls back
        to the vocabulary-averaged comparison for backward compatibility;
        pass the token ids to score the actual completions.
        """
        if preferred_input_ids is None or non_preferred_input_ids is None:
            preferred_scores = self._masked_mean_log_probs(preferred_logits, preferred_mask)
            non_preferred_scores = self._masked_mean_log_probs(non_preferred_logits, non_preferred_mask)
        else:
            preferred_scores = self._sequence_log_probs(
                preferred_logits, preferred_input_ids, preferred_mask
            )
            non_preferred_scores = self._sequence_log_probs(
                non_preferred_logits, non_preferred_input_ids, non_preferred_mask
            )
            if ref_preferred_logits is not None and ref_non_preferred_logits is not None:
                # Score relative to the frozen reference so the policy is only
                # rewarded for moving away from it in the preferred direction.
                preferred_scores = preferred_scores - self._sequence_log_probs(
                    ref_preferred_logits, preferred_input_ids, preferred_mask
                )
                non_preferred_scores = non_preferred_scores - self._sequence_log_probs(
                    ref_non_preferred_logits, non_preferred_input_ids, non_preferred_mask
                )

        margin = preferred_scores - non_preferred_scores
        # logsigmoid is the numerically stable form of log(sigmoid(x)); the
        # naive version returns inf once sigmoid underflows to 0.
        return -F.logsigmoid(self.beta * margin).mean()

    def train(self, train_loader, optimizer, device, epochs=3):
        """Run the optimization loop and print the average loss per epoch.

        Args:
            train_loader: Iterable of batches holding the keys
                ``preferred_input_ids``, ``preferred_attention_mask``,
                ``non_preferred_input_ids`` and ``non_preferred_attention_mask``.
            optimizer: Optimizer over the policy model's parameters.
            device: Device the batch tensors are moved to.
            epochs: Number of passes over ``train_loader``.
        """
        self.model.train()
        if self.ref_model is not None:
            # The reference stays frozen: no dropout, no gradient tracking.
            self.ref_model.eval()

        for epoch in range(epochs):
            total_loss = 0.0
            progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')

            for batch in progress_bar:
                # Move batch to device
                preferred_input_ids = batch['preferred_input_ids'].to(device)
                preferred_attention_mask = batch['preferred_attention_mask'].to(device)
                non_preferred_input_ids = batch['non_preferred_input_ids'].to(device)
                non_preferred_attention_mask = batch['non_preferred_attention_mask'].to(device)

                optimizer.zero_grad()

                # Forward pass for preferred completions
                preferred_logits = self.model(
                    input_ids=preferred_input_ids,
                    attention_mask=preferred_attention_mask
                ).logits

                # Forward pass for non-preferred completions
                non_preferred_logits = self.model(
                    input_ids=non_preferred_input_ids,
                    attention_mask=non_preferred_attention_mask
                ).logits

                ref_preferred_logits = None
                ref_non_preferred_logits = None
                if self.ref_model is not None:
                    with torch.no_grad():
                        ref_preferred_logits = self.ref_model(
                            input_ids=preferred_input_ids,
                            attention_mask=preferred_attention_mask
                        ).logits
                        ref_non_preferred_logits = self.ref_model(
                            input_ids=non_preferred_input_ids,
                            attention_mask=non_preferred_attention_mask
                        ).logits

                # Compute loss
                loss = self.compute_loss(
                    preferred_logits,
                    non_preferred_logits,
                    preferred_attention_mask,
                    non_preferred_attention_mask,
                    preferred_input_ids=preferred_input_ids,
                    non_preferred_input_ids=non_preferred_input_ids,
                    ref_preferred_logits=ref_preferred_logits,
                    ref_non_preferred_logits=ref_non_preferred_logits,
                )

                # Backward pass
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                total_loss += batch_loss
                progress_bar.set_postfix({'loss': batch_loss})

            # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
            avg_loss = total_loss / max(len(train_loader), 1)
            print(f'Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}')

In [None]:
# Choose the GPU when one is present and place the model on it up front
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Preference pairs, served in shuffled mini-batches of four
dataset = RapStyleDataset('data/rap_style_preferences.json', tokenizer)
train_loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
)

# AdamW over every model parameter, driven by the DPO trainer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
dpo_trainer = DPOTrainer(model, tokenizer)
dpo_trainer.train(train_loader, optimizer, device)

# Persist weights and tokenizer side by side so the checkpoint reloads as one unit
model.save_pretrained('checkpoints/dpo_trained_model')
tokenizer.save_pretrained('checkpoints/dpo_trained_model')