In [1]:
from openprompt.data_utils.utils import InputExample
import os
import json, csv
from abc import ABC, abstractmethod
from collections import defaultdict, Counter
from typing import List, Dict, Callable

from openprompt.utils.logging import logger
from openprompt.data_utils.data_processor import DataProcessor

import torch
from torch.utils.data import IterableDataset
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class UltraChatProcessor(DataProcessor):
    """Turn UltraChat-style JSON-lines dialogues into per-turn generation examples.

    Each non-blank line of the input file is parsed as a JSON object whose ``"data"``
    field is a list of utterances that alternate between the user and the assistant,
    starting with the user. For every user turn that has an assistant reply, one
    ``InputExample`` is produced: ``tgt_text`` is the reply and ``meta["context"]`` is
    the dialogue up to and including that user turn, each utterance prefixed with
    ``"User: "`` or ``"Assistant: "``.
    """

    def __init__(self):
        super().__init__()
        # This is a generation task, so there is no label set.
        self.labels = None

    def get_examples(self, data_path: str) -> List[InputExample]:
        """Load ``data_path`` and return one example per answered user turn.

        Args:
            data_path: path to a UTF-8 JSON-lines file; blank lines are skipped.

        Returns:
            A list of ``InputExample`` objects with sequential string ``guid``s
            ("0", "1", ...), an empty ``text_a``, the assistant reply as ``tgt_text``
            and the preceding dialogue (list of "Speaker: text" strings) in
            ``meta["context"]``.
        """
        examples = []
        with open(data_path, encoding="utf-8") as f:
            lines = f.readlines()

        for line in tqdm(lines):
            if not line.strip():
                continue
            dialogue = json.loads(line)["data"]
            # Label every utterance with its speaker; turns alternate starting with the user.
            turns = [
                ": ".join(("User" if k % 2 == 0 else "Assistant", utterance))
                for k, utterance in enumerate(dialogue)
            ]
            # Visit user turns only, stopping before a trailing user turn that has no
            # assistant reply (indexing dialogue[i + 1] would otherwise raise IndexError).
            for i in range(0, len(dialogue) - 1, 2):
                examples.append(
                    InputExample(
                        guid=str(len(examples)),
                        text_a="",
                        tgt_text=dialogue[i + 1],
                        meta={"context": turns[: i + 1]},
                    )
                )
        return examples

    def get_src_tgt_len_ratio(self,):
        """Not implemented for this processor; returns None."""
        pass

In [4]:
# # Conditional Generation with a Prompt Template
# This notebook trains a causal LM (GPT-2 by default) for conditional generation on
# UltraChat-style dialogues, using a MixedTemplate loaded from ./template.txt.
#
# (This header was adapted from an OpenPrompt prefix-tuning example on WebNLG.) Note that
# generation results should be evaluated with dedicated scripts, e.g. those provided at
# https://github.com/Yale-LILY/dart/tree/master/evaluation, which are not included here.

import argparse
import torch
from openprompt import plms
from openprompt.plms import *
from transformers import GPT2Config, GPT2Tokenizer, GPT2LMHeadModel
# Override OpenPrompt's "gpt2" registry entry so that the "gpt2" model type (the default
# --model below) is built from GPT2Config / GPT2Tokenizer / GPT2LMHeadModel and wrapped
# with LMTokenizerWrapper.
# NOTE(review): this writes to a private registry (`_MODEL_CLASSES`) and replaces any
# existing "gpt2" entry; presumably load_plm reads it — confirm against the OpenPrompt version.
plms._MODEL_CLASSES["gpt2"] = ModelClass(
    config=GPT2Config,
    tokenizer=GPT2Tokenizer,
    model=GPT2LMHeadModel,
    wrapper=LMTokenizerWrapper,
)
from openprompt.plms import load_plm
from openprompt.prompts import MixedTemplate
from transformers import AdamW
from openprompt import PromptDataLoader
from openprompt import PromptForGeneration
from transformers.optimization import get_linear_schedule_with_warmup
from accelerate import Accelerator
from torchmetrics import MeanMetric
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from accelerate.utils import set_seed




def format_metrics(metrics, split, prefix=""):
    """Render a metrics dict as a single log line.

    The result is ``"[<split>]<prefix>"`` followed by space-separated
    ``"<name>: <value>"`` entries, each value formatted with four decimals.
    """
    header = f"[{split}]" + prefix
    entries = []
    for name, val in metrics.items():
        entries.append(f"{name}: {val:.4f}")
    return header + " ".join(entries)

def evaluate(args, model, val_dataloader, accelerator):
    """Compute the mean validation loss of ``model`` over ``val_dataloader``.

    Args:
        args: parsed CLI arguments; currently unused, kept for a uniform call signature.
        model: the (possibly accelerate-prepared) prompt model; called with a batch's
            ``input_ids`` and expected to return a loss tensor.
        val_dataloader: the (prepared) validation DataLoader.
        accelerator: the ``Accelerator`` used to gather per-process losses.

    Returns:
        A ``torchmetrics.MeanMetric`` holding all gathered losses; call ``.compute()``
        on it to get the mean.
    """
    was_training = model.training
    model.eval()
    # Use the accelerator's device rather than `model.device`: a wrapped model
    # (e.g. after `accelerator.prepare`) is not guaranteed to expose `.device`.
    val_loss = MeanMetric().to(accelerator.device)

    try:
        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc="eval"):
                loss = model(batch["input_ids"])
                # Gather the loss from every process so the metric covers the whole split.
                loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
                val_loss.update(loss_values["loss"])
    finally:
        # Restore the caller's train/eval mode so evaluation has no lasting side effect.
        model.train(was_training)

    return val_loss


def _make_prompt_dataloader(dataset, template, tokenizer, wrapper_class, batch_size, shuffle):
    """Tokenize ``dataset`` through ``template`` and return the underlying torch DataLoader."""
    return PromptDataLoader(
        dataset=dataset,
        template=template,
        tokenizer=tokenizer,
        tokenizer_wrapper_class=wrapper_class,
        max_seq_length=1024,
        decoder_max_length=1024,
        batch_size=batch_size,
        shuffle=shuffle,
        teacher_forcing=True,
        # Be sure to pass predict_eos_token=True if your template doesn't contain one,
        # or your model may fail to stop generation.
        predict_eos_token=True,
        truncate_method="head",
    ).dataloader


def train(args, accelerator, *, warmup_steps=200, eval_every=50):
    """Train the prompt model on UltraChat dialogues with ``accelerate``.

    Loads the PLM and a ``MixedTemplate`` from ``./template.txt``, and splits the examples
    from ``args.data_file`` 80/20 into train/validation. It then trains for ``args.epochs``
    epochs with AdamW and a linear warmup/decay schedule. Every ``eval_every`` optimizer
    steps it saves the accelerate state to ``ultrachat_{args.model}/step_{N}``, runs
    validation and prints the train/val losses. The main process finally writes the
    model state dict to ``ultrachat_{args.model}/final``.

    Args:
        args: namespace with ``model``, ``model_name_or_path``, ``data_file``, ``lr``
            and ``epochs``.
        accelerator: the ``accelerate.Accelerator`` driving (possibly distributed) training.
        warmup_steps: number of linear warmup steps for the LR scheduler.
        eval_every: checkpoint / evaluate / log interval, in optimizer steps.
    """
    set_seed(0)
    accelerator.print(f"Using {accelerator.num_processes} GPUs")
    output_dir = f"ultrachat_{args.model}"

    plm, tokenizer, model_config, WrapperClass = load_plm(args.model, args.model_name_or_path)
    mytemplate = MixedTemplate(model=plm, tokenizer=tokenizer).from_file("./template.txt")

    # The main process loads the data first; the split uses a fixed random_state, so every
    # process ends up with the same train/validation partition.
    with accelerator.main_process_first():
        processor = UltraChatProcessor()
        dataset = processor.get_examples(args.data_file)
        train_dataset, val_dataset = train_test_split(dataset, test_size=0.2, random_state=0)

    train_dataloader = _make_prompt_dataloader(
        train_dataset, mytemplate, tokenizer, WrapperClass, batch_size=2, shuffle=True
    )
    val_dataloader = _make_prompt_dataloader(
        val_dataset, mytemplate, tokenizer, WrapperClass, batch_size=5, shuffle=False
    )

    # Load the pipeline model PromptForGeneration.
    prompt_model = PromptForGeneration(plm=plm, template=mytemplate, tokenizer=tokenizer)
    prompt_model.to(accelerator.device)

    # torch.optim.AdamW replaces the deprecated transformers.AdamW; weight_decay=0.0 keeps
    # the transformers default (torch's own default would be 1e-2).
    optimizer = torch.optim.AdamW(
        [p for p in prompt_model.parameters() if p.requires_grad],
        lr=args.lr,
        eps=1e-8,
        weight_decay=0.0,
    )
    # The schedule length is taken from the dataloader before accelerate.prepare shards it.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=len(train_dataloader) * args.epochs,
    )

    prompt_model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
        prompt_model, optimizer, train_dataloader, val_dataloader, scheduler
    )
    accelerator.register_for_checkpointing(scheduler)

    train_loss = MeanMetric().to(accelerator.device)

    global_step = 0
    for epoch in range(args.epochs):
        for inputs in tqdm(train_dataloader, desc=f"epoch {epoch}"):
            # Re-enable training mode each step in case evaluation switched it to eval mode.
            prompt_model.train()
            loss = prompt_model(inputs["input_ids"])
            accelerator.backward(loss)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            loss_values = accelerator.gather_for_metrics({"loss": loss.detach()})
            train_loss.update(loss_values["loss"])
            global_step += 1

            if global_step % eval_every == 0:
                accelerator.save_state(f"{output_dir}/step_{global_step}")

                val_loss = evaluate(args, prompt_model, val_dataloader, accelerator)

                log_train = {"train_loss": train_loss.compute()}
                log_val = {"val_loss": val_loss.compute()}

                accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
                accelerator.print(format_metrics(log_train, "train", f" step {global_step} "))
                accelerator.print(format_metrics(log_val, "val", f" step {global_step} "))

                # Report the train loss as a running mean since the previous log line.
                train_loss.reset()

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        # The output directory is otherwise only created by the periodic save_state calls,
        # which never run when training is shorter than `eval_every` steps.
        os.makedirs(output_dir, exist_ok=True)
        torch.save(accelerator.get_state_dict(prompt_model), f"{output_dir}/final")

In [6]:
# Notebook-friendly configuration: an empty argv is parsed, so only these defaults apply.
parser = argparse.ArgumentParser(prog="")
parser.add_argument("--lr", type=float, default=5e-5)
parser.add_argument("--model", type=str, default="gpt2")
parser.add_argument("--model_name_or_path", type=str, default="openai-community/gpt2")
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--data_file", type=str, default="./data.json")
args = parser.parse_args(args=[])

# Launch training on whatever devices accelerate detects.
accelerator = Accelerator()
train(args, accelerator)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Using 1 GPUs


100%|██████████| 1/1 [00:00<00:00, 8943.08it/s]
tokenizing: 2it [00:00, 167.22it/s]
tokenizing: 1it [00:00, 143.32it/s]
1it [00:00,  1.64it/s]
1it [00:00,  1.80it/s]
1it [00:00,  1.77it/s]
1it [00:00,  1.81it/s]
1it [00:00,  1.77it/s]
