In [1]:
!nvidia-smi

Sat Dec 10 22:06:40 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   51C    P0    26W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!git clone https://github.com/RongTouchTouch/GODEL

Cloning into 'GODEL'...
remote: Enumerating objects: 264, done.[K
remote: Counting objects: 100% (67/67), done.[K
remote: Compressing objects: 100% (36/36), done.[K
remote: Total 264 (delta 42), reused 42 (delta 31), pack-reused 197[K
Receiving objects: 100% (264/264), 51.14 MiB | 19.96 MiB/s, done.
Resolving deltas: 100% (84/84), done.


In [3]:
!pip install datasets
!pip install transformers
!pip install accelerate
!pip install fire
!pip install jsonlines
!pip install rouge_score
# !pip install -r GODEL/requirements.txt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.7.1-py3-none-any.whl (451 kB)
[K     |████████████████████████████████| 451 kB 15.3 MB/s 
Collecting xxhash
  Downloading xxhash-3.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[K     |████████████████████████████████| 212 kB 29.2 MB/s 
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting huggingface-hub<1.0.0,>=0.2.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 67.4 MB/s 
Collecting multiprocess
  Downloading multiprocess-0.70.14-py38-none-any.whl (132 kB)
[K     |████████████████████████████████| 132 kB 59.5 MB/s 
Collecting urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1
  Downloading urllib3-1.25.11-py2.py3-none-any.whl (127 kB)
[K     |████████████████████████████████| 127 kB 62.1 MB/s 
Installing collected packag

In [4]:
%cd GODEL/GODEL

/content/GODEL/GODEL


In [59]:
import argparse
import copy
import fire
import logging
import math
import os
import random
import json
import jsonlines

import datasets
import nltk
import numpy as np
import torch
from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader

import transformers
from accelerate import Accelerator
from filelock import FileLock
from transformers import (
    CONFIG_MAPPING,
    MODEL_MAPPING,
    AdamW,
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    SchedulerType,
    set_seed,
)
from transformers.file_utils import is_offline_mode

from utils.text_normalization import normalize_answer

In [60]:
# Config
logging.getLogger().setLevel(logging.DEBUG)
logger = logging.getLogger(__name__)

MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    with FileLock(".lock") as lock:
        nltk.download("punkt", quiet=True)

In [86]:
def parse_args():
    parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help="The name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The configuration name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
    )
    parser.add_argument(
        "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
    )
    parser.add_argument(
        "--max_source_length",
        type=int,
        default=1024,
        help="The maximum total input sequence length after "
        "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument(
        "--source_prefix",
        type=str,
        default=None,
        help="A prefix to add before every source text " "(useful for T5 models).",
    )
    parser.add_argument(
        "--preprocessing_num_workers",
        type=int,
        default=None,
        help="The number of processes to use for the preprocessing.",
    )
    parser.add_argument(
        "--max_target_length",
        type=int,
        default=64,
        help="The maximum total sequence length for target text after "
        "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
        "during ``evaluate`` and ``predict``.",
    )
    parser.add_argument(
        "--val_max_target_length",
        type=int,
        default=None,
        help="The maximum total sequence length for validation "
        "target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be "
        "padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` "
        "param of ``model.generate``, which is used during ``evaluate`` and ``predict``.",
    )
    parser.add_argument(
        "--num_beams",
        type=int,
        default=None,
        help="Number of beams to use for evaluation. This argument will be "
        "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=False,
    )
    parser.add_argument(
        "--config_name",
        type=str,
        default=None,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--text_column",
        type=str,
        default=None,
        help="The name of the column in the datasets containing the full texts (for summarization).",
    )
    parser.add_argument(
        "--summary_column",
        type=str,
        default=None,
        help="The name of the column in the datasets containing the summaries (for summarization).",
    )
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--model_type",
        type=str,
        default=None,
        help="Model type to use if training from scratch.",
        choices=MODEL_TYPES,
    )
    parser.add_argument(
        "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument(
        "--max_length", type=int, default=128, help="max length"
    )
    parser.add_argument(
        "--pad_to_max_length", type=bool, default=True, help="do pading"
    )
    parser.add_argument(
        "--ignore_pad_token_for_loss", type=bool, default=True, help="do pading"
    )
    parser.add_argument(
        "--logging_steps", type=int, default=500, help="do pading"
    )
    parser.add_argument(
        "--save_steps", type=int, default=5000, help="do pading"
    )
    parser.add_argument(
        "--save_every_checkpoint", action="store_true"
    )
    parser.add_argument(
        "--max_grad_norm", type=float, default=1.0, help="max_grad_norm"
    )
    parser.add_argument(
        "--no_kb", action="store_true"
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        help="Description to the experiment",
        default='exp',
    )

    # Newly added params
    args, unknown = parser.parse_known_args()
    args.model_name_or_path = 'microsoft/GODEL-v1_1-base-seq2seq' 
    args.train_file = 'perchat/perchat_single_train.jsonl'
    args.validation_file = 'perchat/perchat_single_valid.jsonl'
    args.dataset_config_name = 'perchat_dataset.py'
    args.output_dir = 'output'

    args.per_device_train_batch_size = 16
    args.per_device_eval_batch_size = 16
    args.max_target_length = 128
    args.max_length = 128
    args.preprocessing_num_workers = 24
    args.num_beams = 5

    # Sanity checks
    if args.dataset_name is None and args.train_file is None and args.validation_file is None:
        raise ValueError("Need either a dataset name or a training/validation file.")
    else:
        if args.train_file is not None:
            extension = args.train_file.split(".")[-1]
            assert extension in ["csv", "json", "jsonl"], "`train_file` should be a csv or a json file."
        if args.validation_file is not None:
            extension = args.validation_file.split(".")[-1]
            assert extension in ["csv", "json", "jsonl"], "`validation_file` should be a csv or a json file."

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    return args

In [87]:
# Decoding for DialoGLM
args = parse_args()
args

Namespace(config_name=None, dataset_config_name='perchat_dataset.py', dataset_name=None, exp_name='exp', gradient_accumulation_steps=1, ignore_pad_token_for_loss=True, learning_rate=5e-05, logging_steps=500, lr_scheduler_type=<SchedulerType.LINEAR: 'linear'>, max_grad_norm=1.0, max_length=128, max_source_length=1024, max_target_length=128, max_train_steps=None, model_name_or_path='microsoft/GODEL-v1_1-base-seq2seq', model_type=None, no_kb=False, num_beams=5, num_train_epochs=3, num_warmup_steps=0, output_dir='output', overwrite_cache=False, pad_to_max_length=True, per_device_eval_batch_size=16, per_device_train_batch_size=16, preprocessing_num_workers=24, save_every_checkpoint=False, save_steps=5000, seed=None, source_prefix=None, summary_column=None, text_column=None, tokenizer_name=None, train_file='perchat/perchat_single_train.jsonl', use_slow_tokenizer=False, val_max_target_length=None, validation_file='perchat/perchat_single_valid.jsonl', weight_decay=0.0)

In [88]:
# Not necessary for generation
accelerator = Accelerator()
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state)

INFO:__main__:Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
Mixed precision type: no

INFO:__main__:Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
Mixed precision type: no



In [89]:
if args.seed is not None:
    set_seed(args.seed)

In [90]:
if args.dataset_name is not None:
    raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
else:
  data_files = {}
  if args.train_file is not None:
      data_files["train"] = args.train_file
  if args.validation_file is not None:
      data_files["validation"] = args.validation_file
  extension = args.dataset_config_name
  raw_datasets = load_dataset(extension, data_files=data_files)

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/dataset_info.json
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/dataset_info.json


  0%|          | 0/3 [00:00<?, ?it/s]

In [91]:
# Load pretrained model and tokenizer
if args.config_name:
    config = AutoConfig.from_pretrained(args.config_name)
elif args.model_name_or_path:
    config = AutoConfig.from_pretrained(args.model_name_or_path)
else:
    config = CONFIG_MAPPING[args.model_type]()
    logger.warning("You are instantiating a new config instance from scratch.")

if args.tokenizer_name:
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
elif args.model_name_or_path:
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
else:
    raise ValueError(
        "You are instantiating a new tokenizer from scratch. This is not supported by this script."
        "You can do it from another script, save it, and load it from here, using --tokenizer_name."
    )

if args.model_name_or_path:
    model = AutoModelForSeq2SeqLM.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
    )
else:
    logger.info("Training new model from scratch")
    model = AutoModelForSeq2SeqLM.from_config(config)

tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.resize_token_embeddings(len(tokenizer))
if model.config.decoder_start_token_id is None:
    raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /microsoft/GODEL-v1_1-base-seq2seq/resolve/main/config.json HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /microsoft/GODEL-v1_1-base-seq2seq/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


In [92]:
prefix = 'Instruction: try to response in the given persona.'
padding = "max_length" if args.pad_to_max_length else False
max_target_length = args.max_target_length

def preprocess_function(examples):
    contextes = examples['Context']
    responses = examples['Response']
    kbs = examples['Knowledge']

    inputs = []
    for context, response, kb in zip(contextes, responses, kbs):
        if args.no_kb:
            inputs.append(prefix + ' ' + ' ' + context + ' => ')
        else:
            inputs.append(context + ' <|Knowledge|> ' + kb + ' => ')
            
    model_inputs = tokenizer(inputs, max_length=args.max_length, padding=padding, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(responses, max_length=max_target_length, padding=padding, truncation=True)

    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
    # padding in the loss.
    if padding == "max_length" and args.ignore_pad_token_for_loss:
        labels["labels"] = [
            [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
        ]

    model_inputs["labels"] = labels["labels"]
    
    return model_inputs

In [93]:
# Processing dataset:
column_names = ['Context','Knowledge','Response']
lm_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=column_names,
    num_proc=args.preprocessing_num_workers,
    load_from_cache_file=False,
    desc=f"Processing dataset",
)

train_dataset = lm_datasets["train"]
eval_dataset = lm_datasets["validation"]
test_dataset = lm_datasets["test"]

                          

Processing dataset #0:   0%|          | 0/1 [00:00<?, ?ba/s]

  

Processing dataset #1:   0%|          | 0/1 [00:00<?, ?ba/s]

 

Processing dataset #2:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #3:   0%|          | 0/1 [00:00<?, ?ba/s]

 



  



Processing dataset #4:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #5:   0%|          | 0/1 [00:00<?, ?ba/s]

 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpt_tmsu3s


Processing dataset #6:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #7:   0%|          | 0/1 [00:00<?, ?ba/s]

  



Processing dataset #8:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #9:   0%|          | 0/1 [00:00<?, ?ba/s]

 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp7ao7t1tx


Processing dataset #10:   0%|          | 0/1 [00:00<?, ?ba/s]

 



 

Processing dataset #11:   0%|          | 0/1 [00:00<?, ?ba/s]

 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp173djfuj


Processing dataset #12:   0%|          | 0/1 [00:00<?, ?ba/s]

 

Processing dataset #13:   0%|          | 0/1 [00:00<?, ?ba/s]



 



 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp120mwf9_
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpgp2zi58c


Processing dataset #15:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #14:   0%|          | 0/1 [00:00<?, ?ba/s]

 



  

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpms7p2_m8


Processing dataset #19:   0%|          | 0/1 [00:00<?, ?ba/s]

  

Processing dataset #16:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp8_wikkmt
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpch_h9vx4


 

Processing dataset #17:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #21:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpa0rtacpv


Processing dataset #22:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #18:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #20:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmptnecze5b


Processing dataset #23:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp1ujfihgi
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp4h20gi9i
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp6zdy9d5v
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp3a59lmx7
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpqkfembf3
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datas

                            

Processing dataset #0:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #2:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpt5v2viyy
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpdjii6bkn


Processing dataset #1:   0%|          | 0/1 [00:00<?, ?ba/s]



Processing dataset #3:   0%|          | 0/1 [00:00<?, ?ba/s]

 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpwjw540ub


 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpkqox4zky


  

Processing dataset #4:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #6:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #5:   0%|          | 0/1 [00:00<?, ?ba/s]

 



 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpso0viiwj


 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpqcono3c6


Processing dataset #7:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp47gamoi7


 



Processing dataset #8:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #9:   0%|          | 0/1 [00:00<?, ?ba/s]

 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpyotw47r2


Processing dataset #10:   0%|          | 0/1 [00:00<?, ?ba/s]



 

Processing dataset #11:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpx4a5lhbt
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmphw0dlnu8


 

Processing dataset #12:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpl_d5jm9y
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpemdwgh_u


Processing dataset #13:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #14:   0%|          | 0/1 [00:00<?, ?ba/s]

 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpb51mlo8e
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpugudavj3
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpa1flu5l4


Processing dataset #15:   0%|          | 0/1 [00:00<?, ?ba/s]

  



 

Processing dataset #16:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp34kuttas


 

Processing dataset #17:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmplc6gp61v


Processing dataset #18:   0%|          | 0/1 [00:00<?, ?ba/s]

 



Processing dataset #19:   0%|          | 0/1 [00:00<?, ?ba/s]

 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpuqir98mv


 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp5rthft4i
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpbaaju8da


 

Processing dataset #20:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #22:   0%|          | 0/1 [00:00<?, ?ba/s]



Processing dataset #21:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpir5kyi_b
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpk5ldlkj2


Processing dataset #23:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp3bc_j9wb
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpk09e8oxx


                          

Processing dataset #0:   0%|          | 0/1 [00:00<?, ?ba/s]



 

Processing dataset #1:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpagd4xmg3


 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpv9dqlnc6


 

Processing dataset #2:   0%|          | 0/1 [00:00<?, ?ba/s]

  



Processing dataset #3:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #4:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp866a7fe1


Processing dataset #6:   0%|          | 0/1 [00:00<?, ?ba/s]

 



Processing dataset #5:   0%|          | 0/1 [00:00<?, ?ba/s]



 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp8a3_bf9w
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpjudl_7p2
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpyw0frrow


 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp308zurxa


 

Processing dataset #7:   0%|          | 0/1 [00:00<?, ?ba/s]



Processing dataset #8:   0%|          | 0/1 [00:00<?, ?ba/s]

 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpxi1oip52


Processing dataset #9:   0%|          | 0/1 [00:00<?, ?ba/s]

 



 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpptuvcpi5


 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp9694mr60


Processing dataset #10:   0%|          | 0/1 [00:00<?, ?ba/s]

 

Processing dataset #11:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #12:   0%|          | 0/1 [00:00<?, ?ba/s]



Processing dataset #14:   0%|          | 0/1 [00:00<?, ?ba/s]

 

Processing dataset #13:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp255bve8i
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpcj8umr7e
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp2l38vm8q


 

Processing dataset #15:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpvf53shnj
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmptt0wsy4o


 

Processing dataset #16:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpysfbh40n


 



Processing dataset #17:   0%|          | 0/1 [00:00<?, ?ba/s]

 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpfvoiy4t5


 

Processing dataset #18:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpdn43pocn


Processing dataset #19:   0%|          | 0/1 [00:00<?, ?ba/s]

 

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpza7e5mdm


Processing dataset #20:   0%|          | 0/1 [00:00<?, ?ba/s]



 

Processing dataset #21:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp8yuu0cfb
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp368xiifl
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpnnmr5f_i


Processing dataset #22:   0%|          | 0/1 [00:00<?, ?ba/s]

Processing dataset #23:   0%|          | 0/1 [00:00<?, ?ba/s]

DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmp5mf_7iqi
DEBUG:fsspec.local:open file: /root/.cache/huggingface/datasets/perchat_dataset/default-528ba314c4ce0418/0.0.0/5515867d901e04e3189dc9ce666bc19466480680313d35f80c4535cf517d1263/tmpt48djbq9


In [95]:
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 1):
    logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=8 if accelerator.use_fp16 else None,
)

INFO:__main__:Sample 2233 of the training set: {'input_ids': [116, 25, 43, 3, 9206, 676, 13, 339, 97, 3, 6, 149, 103, 25, 1903, 8, 97, 3, 58, 3, 2, 9175, 439, 7651, 13553, 9175, 3155, 3, 88, 5682, 6202, 11, 7562, 11, 1525, 5, 112, 12978, 560, 11462, 747, 13, 128, 28332, 10802, 11, 1144, 11, 3, 7437, 102, 30, 8, 16092, 5, 112, 4026, 19, 7562, 11, 6261, 5, 3, 15425, 1, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102, 32102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [96]:
def postprocess_text(preds, labels):

    preds = [normalize_answer(pred.strip().replace('Agent :','')) for pred in preds]
    labels = [normalize_answer(label.strip().replace('Agent :','')) for label in labels]

    return preds, labels

train_dataloader = DataLoader(
    train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
)
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
test_dataloader = DataLoader(test_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)

In [97]:
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": args.weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, test_dataloader
)



In [99]:
model.eval()
if args.val_max_target_length is None:
    args.val_max_target_length = args.max_target_length

gen_kwargs = {
    "max_length": args.val_max_target_length if args is not None else config.max_length,
    "num_beams": args.num_beams,
}

In [128]:
def evaluate_data(dataloader, eval_name='valid'):

    metric = load_metric("./utils/rouge_metric.py")
    metric_bleu = load_metric("./utils/bleu_metric.py")

    decoded_preds_extended = []
    for step, batch in enumerate(dataloader):
        with torch.no_grad():
            generated_tokens = accelerator.unwrap_model(model).generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                max_length=64, 
                min_length=8, 
                # top_p=0.9, 
                do_sample=True
            )

            generated_tokens = accelerator.pad_across_processes(
                generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
            )
            labels = batch["labels"]
            if not args.pad_to_max_length:
                # If we did not pad to max length, we need to pad the labels too
                labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)

            generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
            labels = accelerator.gather(labels).cpu().numpy()

            if args.ignore_pad_token_for_loss:
                # Replace -100 in the labels as we can't decode them.
                labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
            if isinstance(generated_tokens, tuple):
                generated_tokens = generated_tokens[0]
            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

            decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
            metric.add_batch(predictions=decoded_preds, references=decoded_labels)
            _decoded_preds = [i.split() for i in decoded_preds]
            _decoded_labels = [[i.split()] for i in decoded_labels]
            decoded_preds_extended.extend(_decoded_preds)
            metric_bleu.add_batch(predictions=_decoded_preds, references=_decoded_labels)
            
                
    result = metric.compute(use_stemmer=True)
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    result = {k: round(v, 4) for k, v in result.items()}
    logger.info(result)

    result_bleu = metric_bleu.compute()
    logger.info(result_bleu)

    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        if accelerator.is_local_main_process:               
            if not os.path.exists(args.output_dir):
                os.makedirs(args.output_dir)
            output_dir_file_name = os.path.join(args.output_dir, f'{eval_name}-results.json')
            json.dump(decoded_preds_extended, open(output_dir_file_name,'w'), indent=2)
            logger.info("Saving model outputs to %s", output_dir_file_name)

In [None]:
# Decoding for the valid set
evaluate_data(eval_dataloader, 'valid')
# Decoding for the test set
evaluate_data(test_dataloader, 'test')

In [130]:
iterator = iter(eval_dataloader)
batch = next(iterator)

In [117]:
batch = next(iterator)

In [134]:
with torch.no_grad():
    generated_tokens = accelerator.unwrap_model(model).generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        max_length=64, 
        min_length=8, 
        do_sample=True
    )

    labels = batch["labels"]
    if not args.pad_to_max_length:
        # If we did not pad to max length, we need to pad the labels too
        labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)

    generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
    labels = accelerator.gather(labels).cpu().numpy()

    if args.ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    if isinstance(generated_tokens, tuple):
        generated_tokens = generated_tokens[0]
    
    decoded_inputs = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True)
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    print(decoded_preds)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
    _decoded_preds = [i.split() for i in decoded_preds]
    _decoded_labels = [[i.split()] for i in decoded_labels]


['i love to say the word "mom" the best. i think i like to use the slang you get when you say "mom" because that\'s what your definition of the word will be.', "Definitely the second one and he's in his midlife.", "He won't get hurt. He'll probably take the game so bad that everyone in office doesn't look at him and think there's no more news.", 'he is a witty fucking guy who knows as much as him as anyone else what l do and he is one of the few to get through this to get through.', 'not a fan of him just like it was when he had the chance.', 'I agree with you on this. I did n’t know your dad was being this naive despite being a decent friend - I thought it was great he would get. The way of your family is very attractive and important.', "I'm not out of the loop. I'm mostly just happy. The fact he loves the internet is something that may help.", 'Ryan is pretty cool too. Do you know his father?', 'What day is the big bang on? You mean the big bang was on January 24th?', "i love the fa

In [135]:
decoded_inputs

['what is a slang you love to say? |Knowledge|> he loves mistake and depression and answer. his attributes include punchline of some cosmic joke and member and blip on the radar. his lifestyle is depression and anxiety. =>',
 'where do you mostly feel out of place? |Knowledge|> he loves mistake and depression and answer. his attributes include punchline of some cosmic joke and member and blip on the radar. his lifestyle is depression and anxiety. =>',
 'does vegas have odds on what the president does tomorrow? |Knowledge|> he loves mistake and depression and answer. his attributes include punchline of some cosmic joke and member and blip on the radar. his lifestyle is depression and anxiety. =>',
 "those who drive with hands at the digit digit o'clock positions. why is this? |Knowledge|> he loves mistake and depression and answer. his attributes include punchline of some cosmic joke and member and blip on the radar. his lifestyle is depression and anxiety. =>",
 'why are you going to u

In [136]:
decoded_preds

['i love to say word mom best i think i like to use slang you get when you say mom because that s what your definition of word will be',
 'definitely second one and he s in his midlife',
 'he won t get hurt he ll probably take game so bad that everyone in office doesn t look at him and think there s no more news',
 'he is witty fucking guy who knows as much as him as anyone else what l do and he is one of few to get through this to get through',
 'not fan of him just like it was when he had chance',
 'i agree with you on this i did n’t know your dad was being this naive despite being decent friend i thought it was great he would get way of your family is very attractive and important',
 'i m not out of loop i m mostly just happy fact he loves internet is something that may help',
 'ryan is pretty cool too do you know his father',
 'what day is big bang on you mean big bang was on january 24th',
 'i love fact that i have met guy with anxiety and depression that s and very good of his pe

In [133]:
decoded_labels

['when something breaks you say it shit bed',
 'pretty much everywhere now even my own skin things have just been not good',
 'vegas has tons of prop bets shit you can probably bet whether sun will rise tomorrow',
 'c mon now digit and digit is what was drilled into most people s heads',
 'swing and miss',
 'i believe children are our future and that terrifies me',
 'pretty much everywhere now even my own skin things have just been not good',
 'both best and worst things that have happened to me',
 'some 15th century priest had it pegged for thursday will try and find info again',
 'it will be last season as we won t have to put up with that shit anymore',
 'someone else s wedding',
 'weird shit like kryptos sculpture and fourth panel that hasn t been deciphered yet purpose of georgia guidestones whether there s gold in lost dutchman mine',
 'unchecked smiting spree',
 'lack of qualifications only tattoo that s knock on you would be one on your face anything else we can deal with',
 'd

In [137]:
for context,label,pred in zip(decoded_inputs, decoded_labels, decoded_preds):
  print('-'*80)
  print("context:", context)
  print("pred:", pred)
  print("label:", label)
  print('-'*80)

--------------------------------------------------------------------------------
context: what is a slang you love to say? |Knowledge|> he loves mistake and depression and answer. his attributes include punchline of some cosmic joke and member and blip on the radar. his lifestyle is depression and anxiety. =>
pred: i love to say word mom best i think i like to use slang you get when you say mom because that s what your definition of word will be
label: when something breaks you say it shit bed
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
context: where do you mostly feel out of place? |Knowledge|> he loves mistake and depression and answer. his attributes include punchline of some cosmic joke and member and blip on the radar. his lifestyle is depression and anxiety. =>
pred: definitely second one and he s in his midlife
label: pretty much everywhere now even my own skin 

In [None]:
decoded_labels

In [79]:
from transformers import AutoTokenizer,AutoModel

tokenizer = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-base-seq2seq")
model = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-base-seq2seq")

def generate(instruction, knowledge, dialog):
    if knowledge != '':
        knowledge = '[KNOWLEDGE] ' + knowledge
    dialog = ' EOS '.join(dialog)
    query = f"{instruction} [CONTEXT] {dialog} {knowledge}"
    input_ids = tokenizer(f"{query}", return_tensors="pt").input_ids
    outputs = model.generate(input_ids, max_length=128, min_length=8, top_p=0.9, do_sample=True)
    output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return output

# Instruction for a chitchat task
instruction = f'Instruction: try to response in the given persona.'
# Leave the knowldge empty
knowledge = 'he loves mistake and depression and answer. \
           his attributes include punchline of some cosmic joke and member and blip on the radar. \
           his lifestyle is depression and anxiety.'
dialog = [
    'what is a slang you love to say ?'
]
response = generate(instruction, knowledge, dialog)
response

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /microsoft/GODEL-v1_1-base-seq2seq/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /microsoft/GODEL-v1_1-base-seq2seq/resolve/main/config.json HTTP/1.1" 200 0


'i like to say something i just blunder. my wife has been with my wife for 2 years and I have always listened to a lot of weird and unusual slang.'