In [13]:
import argparse
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

import evaluate
import numpy as np
import torch.nn as nn
from accelerate import Accelerator, infer_auto_device_map, init_empty_weights
from datasets import load_dataset
from huggingface_hub import login
from peft import LoraConfig, TaskType, get_peft_model
from tqdm import tqdm
import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
    LlamaForCausalLM,
    LlamaTokenizer,
    PreTrainedTokenizerBase,
    Trainer,
    TrainingArguments,
    logging,
    set_seed,
)
from transformers.utils import PaddingStrategy
from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset

from redditqa.dataset import load_reddit_dataset

In [17]:
# Authenticate against the HuggingFace Hub, but only when a token is supplied
# through the HUGGINGFACE_TOKEN environment variable; otherwise skip the login.
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if HUGGINGFACE_TOKEN is not None:
    login(token=HUGGINGFACE_TOKEN)


Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid.
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [18]:
# Checkpoint whose tokenizer is used for all prompts in this notebook.
model_name = 'meta-llama/Llama-2-7b-chat-hf'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token as the padding token.
# NOTE(review): presumably needed because this tokenizer ships without a pad token — confirm.
tokenizer.pad_token = tokenizer.eos_token

In [23]:
# Load the train and eval splits via the project's loader with pairs=True.
# Each row then carries a question title, two responses (response_j / response_k)
# and their scores (score_j / score_k), as the key listing in the next cell shows.
train_dataset = load_reddit_dataset("train", pairs=True)
eval_dataset = load_reddit_dataset("eval", pairs=True)

Loading cached processed dataset at /scratch1/jhoff/reddit_dataset_cached/train/cache-757f0ee80690267b.arrow
Loading cached processed dataset at /scratch1/jhoff/reddit_dataset_cached/train/cache-ec6a7572749a5bef.arrow
Loading cached shuffled indices for dataset at /scratch1/jhoff/reddit_dataset_cached/train/cache-fcce8e9f7c96b1b8.arrow
Loading cached processed dataset at /scratch1/jhoff/reddit_dataset_cached/eval/cache-cba55e4212677d14.arrow
Loading cached processed dataset at /scratch1/jhoff/reddit_dataset_cached/eval/cache-d8898fc7c787d1eb.arrow
Loading cached shuffled indices for dataset at /scratch1/jhoff/reddit_dataset_cached/eval/cache-e35089f0b695ca2b.arrow


In [24]:
# Show the field names of the first raw training example.
train_dataset[0].keys()

dict_keys(['answer_link_id', 'question_title', 'response_j', 'response_k', 'score_j', 'score_k'])

In [26]:
# Turn each row into a pair of prompts, where text_j is the question with the preferred
# answer and text_k is the question with the other answer. Then tokenize both prompts.
def preprocess_function(examples):
    """Build and tokenize the two question/answer prompts for every row of a batch.

    Args:
        examples: Batch mapping holding equally long lists under the keys
            "question_title", "response_j", "response_k", "score_j" and "score_k".

    Returns:
        A dict of lists with one entry per input row under the keys
        "input_ids_j", "attention_mask_j", "score_j",
        "input_ids_k", "attention_mask_k" and "score_k".

    Relies on the module-level ``tokenizer`` created earlier in the notebook.
    """
    new_examples = {
        "input_ids_j": [],
        "attention_mask_j": [],
        "score_j": [],
        "input_ids_k": [],
        "attention_mask_k": [],
        "score_k": [],
    }
    rows = zip(
        examples["question_title"],
        examples["response_j"],
        examples["response_k"],
        examples["score_j"],
        examples["score_k"],
    )
    for question_title, response_j, response_k, score_j, score_k in rows:
        for suffix, response, score in (("j", response_j, score_j), ("k", response_k, score_k)):
            # Format the prompt in a single step. Chained str.replace calls would also
            # substitute a literal "%answer" that happens to appear inside the question.
            text = f"<|ELIF|> Question: {question_title}\nAnswer: {response}"
            # NOTE(review): no max_length is passed, so truncation relies on the tokenizer's
            # own configured limit — confirm that limit is the intended one.
            tokenized = tokenizer(text, truncation=True)

            new_examples[f"input_ids_{suffix}"].append(tokenized["input_ids"])
            new_examples[f"attention_mask_{suffix}"].append(tokenized["attention_mask"])
            new_examples[f"score_{suffix}"].append(score)

    return new_examples


original_columns = train_dataset.column_names

# Preprocess both splits. Each split drops its own raw columns, so only the tokenized
# fields and scores returned by preprocess_function remain.
train_dataset_preproc = train_dataset.map(
    preprocess_function, num_proc=1, remove_columns=original_columns, batched=True
)
eval_dataset_preproc = eval_dataset.map(
    preprocess_function, num_proc=1, remove_columns=eval_dataset.column_names, batched=True
)

# TODO: length filtering and eval subsampling are currently disabled. `script_args` and
# `SEED` are not defined in the cells shown here and must exist before re-enabling.
# The filter has to run on the preprocessed datasets, since only they have `input_ids_*`.
# train_dataset_preproc = train_dataset_preproc.filter(
#     lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length
# )
# eval_dataset_preproc = eval_dataset_preproc.filter(
#     lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length
# )
# eval_dataset_preproc = eval_dataset_preproc.shuffle(seed=SEED).select(range(script_args.eval_subsample))

# Report the sizes of the preprocessed datasets, which are the ones used from here on.
print("Finished preprocessing dataset.")
print("Number of training examples: ", len(train_dataset_preproc))
print("Number of eval examples: ", len(eval_dataset_preproc))


Loading cached processed dataset at /scratch1/jhoff/reddit_dataset_cached/train/cache-b84fa26c2fd2d3bf.arrow
Loading cached processed dataset at /scratch1/jhoff/reddit_dataset_cached/eval/cache-215276af46097339.arrow


Finished preprocessing dataset.
Number of training examples:  110865
Number of eval examples:  47799


In [28]:
# Show the field names of the first preprocessed training example.
train_dataset_preproc[0].keys()

dict_keys(['score_j', 'score_k', 'input_ids_j', 'attention_mask_j', 'input_ids_k', 'attention_mask_k'])