Changes on #619 #638

Merged 5 commits on Jan 12, 2023
6 changes: 4 additions & 2 deletions model/supervised_finetuning/configs/config.yaml
@@ -23,8 +23,10 @@ defaults:
eval_size:
log_dir: "base"
quantization: false
seq2seqmodel: false
poly_eps: 1.0

- galactica-125:
galactica-125m:
learning_rate: 5e-5
model_name: facebook/galactica-125m
weight_decay: 0.01
@@ -58,7 +60,7 @@ codegen:

debug:
eval_steps: 20
- eval_size: 100
eval_size: 20
gradient_accumulation_steps: 1
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
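The two new defaults, seq2seqmodel and poly_eps, flow into model construction and loss selection later in this PR (see get_specific_model and get_loss(loss_function, poly_eps) below). As a rough illustration of how the flat YAML might be consumed, a minimal loader could merge defaults with a named section; this is a hedged sketch only, not the project's actual read_yamls helper, and the function name is an assumption.

# Minimal sketch only: the real config handling lives in utils.read_yamls and
# trainer.argument_parsing; the loader function here is hypothetical.
import yaml


def load_training_config(path: str, section: str) -> dict:
    with open(path) as f:
        configs = yaml.safe_load(f)
    conf = dict(configs["defaults"])        # e.g. seq2seqmodel: false, poly_eps: 1.0
    conf.update(configs.get(section, {}))   # per-model overrides such as galactica-125m
    return conf


# conf = load_training_config("configs/config.yaml", "galactica-125m")
# conf["poly_eps"] -> 1.0, conf["seq2seqmodel"] -> False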
106 changes: 98 additions & 8 deletions model/supervised_finetuning/custom_datasets/__init__.py
@@ -1,23 +1,108 @@
import numpy as np
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, Subset

from .prompt_dialogue import PromptGeneratedDataset

QA_SPECIAL_TOKENS = {"Question": "<question>", "Answer": "<answer>"}
SUMMARIZATION_SPECIAL_TOKENS = {"Text": "", "Summary": "TL;DR:"}

summarization_name_mapping = {
"cnn_dailymail": ("article", "highlights"),
"samsum": ("dialogue", "summary"),
"xsum": ("document", "summary"),
"multi_news": ("document", "summary"),
"scitldr": ("source", "target"),
"billsum": ("text", "summary"),
"reddit": ("content", "summary"),
}
summarization_config_mapping = {
"cnn_dailymail": ("3.0.0",),
"samsum": (),
"xsum": (),
"multi_news": (),
"scitldr": ("AIC",),
"billsum": (),
"reddit": (),
}

QA_DATASETS = ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext"]
SUMMARIZATION_DATASETS = ["xsum", "cnn_dailymail", "samsum", "multi_news"]


def index_squad_v2(example):
return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0]


def index_trivia_qa_nocontext(example):
    # pick one of the answer aliases at random
return example["question"], example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))]


def index_trivia_qa_context(example):
question = example["question"]
title = example["title"][np.random.randint(len(example["title"]))]
context = example["search_context"][np.random.randint(len(example["search_context"]))]
answer = example["answer"]["aliases"][np.random.randint(len(example["answer"]["aliases"]))]

return title + ". " + context + " " + question, answer


def index_adversarial_qa(example):
return example["title"] + ". " + example["context"] + " " + example["question"], example["answers"]["text"][0]


class QADataset(Dataset):
def __init__(self, dataset, cache_dir, split):
if dataset == "squad_v2":
self.index_fn = index_squad_v2
self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split)
        elif dataset == "trivia_qa_nocontext":
            self.index_fn = index_trivia_qa_nocontext
            self.dataset = load_dataset("trivia_qa", "rc.nocontext", cache_dir=cache_dir, split=split)
        elif dataset == "trivia_qa_context":
            self.index_fn = index_trivia_qa_context
            self.dataset = load_dataset("trivia_qa", "rc", cache_dir=cache_dir, split=split)
        elif dataset == "adversarial_qa":
            self.index_fn = index_adversarial_qa
            self.dataset = load_dataset("adversarial_qa", "adversarialQA", cache_dir=cache_dir, split=split)
else:
            raise ValueError(f"Unknown dataset: {dataset}")

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
data = self.dataset[idx]
return self.index_fn(data)


def index_summary_default(text, summary):
return text, summary

- class SquadV2Dataset(Dataset):
-     def __init__(self, cache_dir, split):
-         self.dataset = load_dataset("squad_v2", cache_dir=cache_dir, split=split)

def index_summary_merge(text, summary):
return " ".join(text), " ".join(summary)


class SummarizationDataset(Dataset):
def __init__(self, dataset, cache_dir, split):
        self.dataset = load_dataset(dataset, *summarization_config_mapping[dataset], cache_dir=cache_dir, split=split)
        self.text_column, self.summary_column = summarization_name_mapping[dataset]
        self.preprocess_fn = index_summary_merge if dataset == "scitldr" else index_summary_default

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
data = self.dataset[idx]
-         # return first answer form list of possible answers
-         return data["title"] + ". " + data["context"] + " " + data["question"], data["answers"]["text"][0]
text, summary = data[self.text_column], data[self.summary_column]
text, summary = self.preprocess_fn(text, summary)

return "".join(
SUMMARIZATION_SPECIAL_TOKENS["Text"], text, " ", SUMMARIZATION_SPECIAL_TOKENS["Summary"], summary
)


class WebGPT(Dataset):
@@ -58,9 +143,14 @@ def train_val_dataset(dataset, val_split=0.2):
def get_one_dataset(conf, dataset_name):
dataset_name = dataset_name.lower()

-     if dataset_name == "squadv2":
-         train = SquadV2Dataset(conf.cache_dir, "train")
-         eval = SquadV2Dataset(conf.cache_dir, "validation")
    if dataset_name in ["squad_v2", "adversarial_qa", "trivia_qa_context", "trivia_qa_nocontext"]:
train = QADataset(dataset_name, conf.cache_dir, "train")
eval = QADataset(dataset_name, conf.cache_dir, "validation")

elif dataset_name in ["xsum", "cnn_dailymail", "samsum", "multi_news", "scitldr", "billsum", "reddit"]:
train = SummarizationDataset(dataset_name, conf.cache_dir, "train")
eval = SummarizationDataset(dataset_name, conf.cache_dir, "validation")

elif dataset_name == "webgpt":
dataset = WebGPT()
train, eval = train_val_dataset(dataset, val_split=0.2)
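Taken together, get_one_dataset now routes QA-style names to QADataset and summarization names to SummarizationDataset. A small usage sketch follows; the conf object is a hypothetical stand-in, and only a cache_dir attribute is assumed.

# Illustration only: conf is a stand-in for the training config object;
# get_one_dataset only needs conf.cache_dir here.
from types import SimpleNamespace

from custom_datasets import get_one_dataset

conf = SimpleNamespace(cache_dir=".cache")

train, val = get_one_dataset(conf, "squad_v2")       # QADataset: (title+context+question, answer) pairs
print(train[0])

train, val = get_one_dataset(conf, "cnn_dailymail")  # SummarizationDataset: "<text> TL;DR: <summary>" strings
print(train[0][:200])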
30 changes: 30 additions & 0 deletions model/supervised_finetuning/losses.py
@@ -1,3 +1,5 @@
import torch
import torch.nn.functional as F
from torch import nn


@@ -13,3 +15,31 @@ def forward(self, input, target, mask=None):
input = input[mask]
target = target[mask]
return super(CrossEntropyLoss, self).forward(input, target)


class PolyLoss(nn.Module):
def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", epsilon=1.0):
super(PolyLoss, self).__init__()
        self.weight = torch.tensor(weight) if weight is not None else None
self.ignore_index = ignore_index
self.reduction = reduction
self.cross_entropy = CrossEntropyLoss(weight, size_average, ignore_index, reduce, "none")
self.epsilon = epsilon

def forward(self, input, target, mask=None):
if mask is not None:
mask = mask.view(-1).bool()
input = input.view(-1, input.size(-1))
target = target.view(-1)
input = input[mask]
target = target[mask]

onehot_target = F.one_hot(target, num_classes=input.size(-1)).to(device=input.device, dtype=input.dtype)
pt = torch.sum(onehot_target * F.softmax(input, -1), -1)
CE = self.cross_entropy(input, target)
poly1 = CE + self.epsilon * (1 - pt)
if self.reduction == "mean":
poly1 = poly1.mean()
elif self.reduction == "sum":
poly1 = poly1.sum()
return poly1
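PolyLoss implements the Poly-1 formulation, L = CE + epsilon * (1 - p_t), which reduces to plain cross-entropy at epsilon = 0. A quick sanity-check sketch, assuming losses.py is importable; shapes and values are arbitrary.

# Sanity check of the Poly-1 loss above; shapes and values are arbitrary.
import torch

from losses import PolyLoss

logits = torch.randn(8, 32000)               # (tokens, vocab_size)
targets = torch.randint(0, 32000, (8,))
mask = torch.ones(8, dtype=torch.bool)       # keep every position

loss_fn = PolyLoss(epsilon=1.0, reduction="mean")
print(loss_fn(logits, targets, mask=mask))   # scalar tensor; equals CE when epsilon = 0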
17 changes: 10 additions & 7 deletions model/supervised_finetuning/models/__init__.py
@@ -1,4 +1,4 @@
- from transformers import AutoModelForCausalLM
import transformers

# from .gptj import get_model as get_gptj_model

@@ -25,9 +25,12 @@ def freeze_top_n_layers(model, target_layers):
return model


- def get_specific_model(model_name, cache_dir, quantization):
-     return AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
-     # if "gpt-j" in model_name.lower():
-     #     return get_gptj_model(model_name, cache_dir, quantization)
-     # else:
-     #     return AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
def get_specific_model(model_name, cache_dir, quantization, seq2seqmodel):
    # Encoder-decoder support for Flan-T5-like models.
    # For now this is controlled by a config flag; in the future it could be
    # inferred automatically from the model config.
if seq2seqmodel:
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=cache_dir)
else:
model = transformers.AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
return model
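With the new seq2seqmodel flag, encoder-decoder checkpoints such as Flan-T5 load through AutoModelForSeq2SeqLM, while decoder-only models keep the causal-LM path. A minimal sketch; the Flan-T5 checkpoint name is an example and not part of this PR's configs.

# Example only; "google/flan-t5-small" is an illustrative checkpoint name.
from models import get_specific_model

causal_lm = get_specific_model(
    "facebook/galactica-125m", cache_dir=".cache", quantization=False, seq2seqmodel=False
)
seq2seq_lm = get_specific_model(
    "google/flan-t5-small", cache_dir=".cache", quantization=False, seq2seqmodel=True
)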
2 changes: 2 additions & 0 deletions model/supervised_finetuning/requirements.txt
@@ -2,7 +2,9 @@ accelerate==0.15.0
bitsandbytes==0.36.0.post2
datasets==2.8.0
deepspeed==0.7.7
evaluate==0.4.0
mpi4py==3.1.4
nltk==3.8.1
numpy==1.23.0
PyYAML==6.0
scikit_learn==1.2.0
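The new evaluate and nltk pins back the metric plumbing added in trainer.py below: get_metrics presumably returns evaluate metric objects, and nltk is typically used for sentence splitting before ROUGE-style scoring. Purely as an illustration of what these dependencies provide, not code from this PR:

# Illustration of the new dependencies; not code from this PR.
import evaluate
import nltk

nltk.download("punkt", quiet=True)
rouge = evaluate.load("rouge")

pred = "the cat sat on the mat"
ref = "a cat was sitting on the mat"
print(rouge.compute(predictions=[pred], references=[ref]))  # rouge1/rouge2/rougeL/rougeLsum scores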
48 changes: 35 additions & 13 deletions model/supervised_finetuning/trainer.py
@@ -1,23 +1,23 @@
import argparse
import os
from distutils.util import strtobool
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union

import bitsandbytes
import torch
from torch import nn
from transformers import PreTrainedModel, Trainer, TrainingArguments
from transformers.training_args import OptimizerNames
- from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls
from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls

os.environ["WANDB_PROJECT"] = "supervised-finetuning"

- def compute_metrics(eval_pred):
-     pred_ids = eval_pred.predictions
-     labels = eval_pred.label_ids
-
-     return {"accuracy": (pred_ids[labels > 0] == labels[labels > 0]).mean()}


def compute_metrics(eval_pred, preprocess_fns, metrics):
    out = {}
    for metric, preprocess_fn in zip(metrics, preprocess_fns):
        preds, labels = preprocess_fn(eval_pred)
        out = dict(**out, **metric.compute(predictions=preds, references=labels))

    return out


def preprocess_logits_for_metrics(logits, labels):
@@ -31,18 +31,22 @@ def __init__(
model: Union[PreTrainedModel, nn.Module] = None,
args: TrainingArguments = None,
loss_function: str = "CrossEntropyLoss",
poly_eps: float = 1.0,
**kwargs,
):
super().__init__(model, args, **kwargs)

        # By default, CrossEntropyLoss already ignores the padding index -100, but to be safe we use our own loss_fct
-         self.loss_fct = get_loss(loss_function)
self.loss_fct = get_loss(loss_function, poly_eps)

def compute_loss(self, model, inputs, return_outputs=False):
labels_mask = inputs.pop("label_masks")
targets = inputs.pop("targets")

-         outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None))
outputs = model(
input_ids=inputs["input_ids"],
attention_mask=inputs.get("attention_mask", None),
)

loss = self.loss_fct(outputs.get("logits"), targets, mask=labels_mask)

@@ -54,7 +58,10 @@ def _compute_loss(self, model, inputs):
labels_mask = inputs.pop("label_masks")
targets = inputs.pop("targets")

-         outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None))
outputs = model(
input_ids=inputs["input_ids"],
attention_mask=inputs.get("attention_mask", None),
)

logits = outputs.get("logits")

@@ -92,6 +99,7 @@ def argument_parsing(notebook=False, notebook_args=None):
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--deepspeed", action="store_true")
parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false")
parser.add_argument("--wandb-entity", type=str, default="open-assistant")
parser.set_defaults(deepspeed=False)

if notebook:
@@ -111,8 +119,10 @@ else:
else:
conf.update(configs[name])

conf["wandb_entity"] = args.wandb_entity
conf["local_rank"] = args.local_rank
conf["deepspeed"] = args.deepspeed

# Override config from command-line
parser = argparse.ArgumentParser()
for key, value in conf.items():
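The body of this override loop is collapsed in the diff view. The sketch below is a hypothetical reconstruction of the usual pattern (one CLI flag per config key, with strtobool for booleans, matching the strtobool import at the top of the file), not the PR's actual code.

# Hypothetical reconstruction of the collapsed override loop; not taken from the PR.
import argparse
from distutils.util import strtobool


def override_from_cli(conf: dict, remaining_args: list) -> dict:
    parser = argparse.ArgumentParser()
    for key, value in conf.items():
        if isinstance(value, bool):
            # plain bool() would turn any non-empty string into True
            parser.add_argument(f"--{key}", type=lambda x: bool(strtobool(x)), default=value)
        else:
            parser.add_argument(f"--{key}", type=type(value) if value is not None else str, default=value)
    return vars(parser.parse_args(remaining_args))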
@@ -131,8 +141,9 @@ model = get_model(training_conf, tokenizer)
model = get_model(training_conf, tokenizer)

train, evals, collate_fn = get_dataset(training_conf, tokenizer)
metrics, preprocess_fns = get_metrics(training_conf, tokenizer)

-     optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else None
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF

if training_conf.quantization:
for module in model.modules():
@@ -166,15 +177,26 @@ def argument_parsing(notebook=False, notebook_args=None):
)

assert len(evals) > 0

if not training_conf.deepspeed or training_conf.local_rank == 0:
import wandb

wandb.init(
project="supervised-finetuning",
entity=training_conf.wandb_entity,
name=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
)

trainer = SFTTrainer(
model,
args,
loss_function=training_conf.loss_fn,
poly_eps=training_conf.poly_eps,
train_dataset=train,
eval_dataset=evals,
data_collator=collate_fn,
tokenizer=tokenizer,
-         compute_metrics=compute_metrics,
compute_metrics=partial(compute_metrics, metrics=metrics, preprocess_fns=preprocess_fns),
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
trainer.train()
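The new compute_metrics expects parallel lists of evaluate metrics and preprocessing callables, wired in through functools.partial. get_metrics itself lives in utils.py and is not part of this diff, so the following is only a hedged sketch of the expected shape.

# Hedged sketch of the (metrics, preprocess_fns) contract used by compute_metrics;
# the real get_metrics in utils.py may differ.
import evaluate


def accuracy_preprocess(eval_pred):
    # keep non-padding positions only, mirroring the old accuracy computation
    preds, labels = eval_pred.predictions, eval_pred.label_ids
    keep = labels > 0
    return preds[keep], labels[keep]


metrics = [evaluate.load("accuracy")]
preprocess_fns = [accuracy_preprocess]

# trainer kwarg, as in the diff above:
# compute_metrics=partial(compute_metrics, metrics=metrics, preprocess_fns=preprocess_fns)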