In [1]:
from rankgen import RankGenCollator, RankGenModel
import wandb
import torch
import omegaconf
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import Dataset, DataLoader
import random
from typing import Literal
import tqdm
import pickle
from pathlib import Path

from datasets import load_dataset, DatasetDict

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Single source of truth for the experiment. The trainer defined later in this
# notebook reads `num_epochs`, `rankgen_model.lr`, `rankgen_model.save_dir`,
# `rankgen_model.snapshot_dir` and `rankgen_model.save_freq`; the dataset handler
# reads `dataset.name`, `dataset.shuffle` and `dataset.train_batch_size`.
# The remaining keys are passed along inside `config` to RankGenModel /
# RankGenCollator -- presumably consumed there; verify against the rankgen module.
config =  OmegaConf.create({"project": "open-assistant-model-reward-rankgen",
                      "num_epochs": 10,
                      "rankgen_model": {
                          # Alternative checkpoints are kept below for quick switching.
                          "rankgen_hf_path" : "kalpeshk2011/rankgen-t5-base-all",
                          # "rankgen_hf_path" : "kalpeshk2011/rankgen-t5-large-all",
                          # "rankgen_hf_path" : "kalpeshk2011/rankgen-t5-xl-all",
                          # "rankgen_hf_path" : "kalpeshk2011/rankgen-t5-xl-pg19",
                          "model_size" : None,
                          "cache_dir" : None,
                          "eval_mode": True,
                          "snapshot_dir": "snapshots",
                          "save_dir": "pretrained_models",
                          # A checkpoint is written whenever `epoch % save_freq == 0`.
                          "save_freq": 2,
                          "save_on_best": True,
                          "lr": 1e-4,
                        },
                      "dataset": {
                          # Alternative datasets are kept below for quick switching.
                          "name": "openai/webgpt_comparisons",
                          # "name": "imdb",
                          # "name": "summarize-from-feedback",
                          "shuffle": True,
                          "train_batch_size": 24,
                          "max_sentence_length": 256,
                        }
                      })
# Take the project name from the config instead of repeating the literal, so the
# W&B run can never disagree with the configuration that is logged alongside it.
wandb.init(project=config.project, config=OmegaConf.to_container(config))


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbobakhashemi[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# Fetch the dataset named in the config and drop the quote/token columns.
full_dataset: DatasetDict = load_dataset(config.dataset.name)
td = full_dataset["train"].remove_columns(
    ['quotes_0', 'tokens_0', 'quotes_1', 'tokens_1']
)

# First cut: 80% of the rows for training, 20% held out.
dataset_ = td.train_test_split(train_size=0.8, test_size=0.2, shuffle=True)
# Second cut: the held-out rows become 25% validation / 75% test.
dataset_test_valid = dataset_['test'].train_test_split(train_size=0.25, test_size=0.75, shuffle=True)

dataset = DatasetDict({
    'train': dataset_['train'],
    'valid': dataset_test_valid['train'],
    'test': dataset_test_valid['test'],
})

# Drop the intermediates that are no longer needed.
del full_dataset
del dataset_

Found cached dataset webgpt_comparisons (/home/bobak/.cache/huggingface/datasets/openai___webgpt_comparisons/default/0.0.0/8b5d5879cdc98c4c0099af6053dffe8d504588d43d3b11f1b1ec223ab1e8db0a)
100%|██████████| 1/1 [00:00<00:00, 15.04it/s]


In [4]:
def reward_criterion(positive_scores:  torch.Tensor, negative_scores: torch.Tensor) -> torch.Tensor:
  """Pairwise ranking loss, summed over every element.

  Evaluates ``sum(-log(sigmoid(positive_scores - negative_scores)))``, which is
  small when each positive score exceeds its negative counterpart.

  Args:
    positive_scores: scores of the preferred candidates.
    negative_scores: scores of the rejected candidates; must be broadcastable
      against ``positive_scores``.

  Returns:
    A scalar tensor holding the summed loss.
  """
  # logsigmoid is the numerically stable form of log(sigmoid(x)): the naive
  # composition underflows sigmoid to 0 for large negative margins and yields inf.
  return -torch.nn.functional.logsigmoid(positive_scores - negative_scores).sum()

In [5]:
class DatasetHandler:
  """Loads the dataset selected by ``config.dataset.name`` and builds its dataloaders.

  After construction the instance exposes:
    dataset:     a DatasetDict with 'train', 'valid' and 'test' splits.
    dataloaders: a dict with one DataLoader per split, under the same keys.
  """

  def __init__(self, config: DictConfig):
    """Dispatch on ``config.dataset.name``.

    Raises:
      NotImplementedError: if no loader exists for the requested dataset.
    """
    if config.dataset.name == "openai/webgpt_comparisons":
      self.init_webgpt(config)
    else:
      # "overfit-random" and "summarize-from-feedback" are planned but have no
      # loader in this class yet, so they are reported as not implemented
      # instead of failing with AttributeError on a missing method.
      raise NotImplementedError(f"Dataset {config.dataset.name} not implemented")
  
  def init_webgpt(self, config: DictConfig):
    """Load the WebGPT comparisons data and split it 80% / 5% / 15%.

    Only the 'train' split of the loaded dataset is used; it is re-split here
    into train, valid and test. The split is random and unseeded.
    """
    full_dataset : DatasetDict = load_dataset(config.dataset.name)
    
    # Columns left afterwards: ['question', 'answer_0', 'score_0', 'answer_1', 'score_1']
    td = full_dataset["train"].remove_columns(['quotes_0', 'tokens_0', 'quotes_1', 'tokens_1'])
    # 80% train, 20% held out ...
    dataset_ = td.train_test_split(test_size=0.2, train_size=0.8, shuffle=True)
    # ... and the held-out part becomes 25% valid / 75% test.
    dataset_test_valid = dataset_['test'].train_test_split(test_size=0.75, train_size=0.25, shuffle=True)
    
    self.dataset = DatasetDict({
        'train': dataset_['train'],
        'valid': dataset_test_valid['train'],
        'test': dataset_test_valid['test'],
    })
    self.dataloaders = {
      split: DataLoader(
        self.dataset[split],
        batch_size=config.dataset.train_batch_size,
        # Only the training data is shuffled; evaluation splits keep a fixed
        # order so repeated evaluations see identical batches.
        shuffle=bool(config.dataset.shuffle) and split == "train",
        # Each loader gets its own collator instance, as before.
        collate_fn=RankGenCollator(config),
      )
      for split in ("train", "valid", "test")
    }
  
  def __repr__(self):
    """Multi-line summary of the held dataset and dataloaders."""
    return f"""DatasetHandler:
Dataset: {self.dataset}
Dataloaders: {self.dataloaders}"""

In [6]:
class RankGenTrainer():
  """Fine-tunes a RankGenModel with the pairwise ``reward_criterion`` loss.

  Attributes set by ``__init__``:
    rankgen_model:   the model being trained.
    data:            DatasetHandler providing the 'train'/'valid'/'test' loaders.
    optimizer:       Adam over the model parameters, lr from ``config.rankgen_model.lr``.
    best_valid_loss: lowest summed validation loss seen so far (float).
    save_dir:        directory every checkpoint is pickled into.
    snapshot_dir:    created on construction; this class does not write to it.
  """

  def __init__(self, config: DictConfig) -> None:
    """Build model, data, optimizer and make sure the output directories exist."""
    self.rankgen_model : RankGenModel = RankGenModel(config=config)
    self.config = config
    self.criterion = reward_criterion
    self.data = DatasetHandler(config)
    self.best_valid_loss = float("inf")
    self.optimizer = torch.optim.Adam(self.rankgen_model.parameters(), lr=config.rankgen_model.lr)
    
    self.save_dir = Path(config.rankgen_model.save_dir)
    self.snapshot_dir = Path(config.rankgen_model.snapshot_dir)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    self.save_dir.mkdir(parents=True, exist_ok=True)
    self.snapshot_dir.mkdir(parents=True, exist_ok=True)
  
  def train(self) -> None:
    """Run ``config.num_epochs`` epochs, evaluating on valid and test after each.

    The per-batch training loss is logged to W&B as ``train_loss``. A checkpoint
    named ``train_epoch_<n>`` is written whenever ``save_freq > 0`` and
    ``epoch % save_freq == 0``; a ``final`` checkpoint is written at the end.
    """
    train_loader = self.data.dataloaders["train"]
    save_freq = self.config.rankgen_model.save_freq
    for epoch in range(self.config.num_epochs):
      for batch in tqdm.tqdm(train_loader, desc=f"Epoch {epoch} -- Batch", leave=False, total=len(train_loader)):
        prefixes, pos_suffixes, neg_suffixes = batch
        # Clear first so gradients left over from an interrupted run cannot
        # leak into this step.
        self.optimizer.zero_grad()
        pos_scores = self.rankgen_model(prefixes, pos_suffixes)
        neg_scores = self.rankgen_model(prefixes, neg_suffixes)
        
        loss = self.criterion(pos_scores, neg_scores)
        loss.backward()
        self.optimizer.step()
        wandb.log({"train_loss": loss.item()})
      if save_freq > 0 and epoch % save_freq == 0:
        self.save(f"train_epoch_{epoch}")
      
      self.test("valid")
      self.test("test")
    self.save("final")
    
  def test(self, split : Literal["valid", "test"]) -> None:
    """Evaluate one split without gradient tracking.

    Each batch loss is logged to W&B as ``<split>_loss``. For the 'valid' split
    the batch losses are summed over the whole pass; if that sum beats
    ``best_valid_loss`` it becomes the new best and, unless
    ``config.rankgen_model.save_on_best`` is falsy, a ``best`` checkpoint is saved.
    """
    split_loss = 0.0
    num_batches = 0
    with torch.inference_mode():
      for batch in self.data.dataloaders[split]:
        prefixes, pos_suffixes, neg_suffixes = batch
        pos_scores = self.rankgen_model(prefixes, pos_suffixes)
        neg_scores = self.rankgen_model(prefixes, neg_suffixes)
        
        # .item() keeps plain floats around instead of tensors.
        batch_loss = self.criterion(pos_scores, neg_scores).item()
        wandb.log({f"{split}_loss": batch_loss})
        split_loss += batch_loss
        num_batches += 1
    
    # Compare whole passes, not single batches: one lucky batch must not be
    # able to trigger (or block) a "best" checkpoint. An empty split is ignored.
    if split == "valid" and num_batches > 0 and split_loss < self.best_valid_loss:
      self.best_valid_loss = split_loss
      if self.config.rankgen_model.get("save_on_best", True):
        self.save("best")
  
  def save(self, prefix="best") -> None:
    """Pickle the model to ``<save_dir>/<prefix>_rankgen_model.pkl``."""
    # The context manager guarantees the file is flushed and closed.
    with open(self.save_dir / f"{prefix}_rankgen_model.pkl", "wb") as handle:
      pickle.dump(self.rankgen_model, handle)

In [7]:
# Build the trainer. Construction instantiates the RankGen model, loads and
# splits the dataset, creates the Adam optimizer and makes the output directories.
a = RankGenTrainer(config=config)

Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
Found cached dataset webgpt_comparisons (/home/bobak/.cache/huggingface/datasets/openai___webgpt_comparisons/default/0.0.0/8b5d5879cdc98c4c0099af6053dffe8d504588d43d3b11f1b1ec223ab1e8db0a)
100%|██████████| 1/1 [00:00<00:00, 345.69it/s]


In [None]:
# Run the full training loop: config.num_epochs epochs, with a valid and test
# evaluation after each epoch and checkpoints pickled into the save directory.
a.train()

## Testing basic functionality of the RankGen code

In [None]:
# RankGenTrainer keeps the model in `rankgen_model`; it has no `model` attribute,
# so `a.model` would raise AttributeError.
# NOTE(review): assumes RankGenModel exposes `score(prefixes, suffixes)` returning
# a 3-tuple whose first item is the similarity matrix -- confirm against rankgen.
similarities, _, _ = a.rankgen_model.score(["How's this?", "It's crazy when a best friend finishes your"], ["It's actually pretty good", "sentences", "we were at the park"])
print(similarities)

In [None]:
# Peek at a single training example to check the record layout. Uses the
# notebook-level `dataset` built in the split cell above; the loop stops after
# the first row.
for x in dataset["train"]:
  print(x)
  break