In [1]:
import transformers
import jsonlines
from tqdm.auto import tqdm

# Raise the transformers logger threshold to ERROR so the library's
# info/warning messages do not flood the cell outputs below.
transformers.logging.set_verbosity_error()



In [2]:
from transformers import pipeline
import torch

In [9]:
from collections import defaultdict
import re
import os

# Run the pipelines on the first CUDA device when one is visible, else on CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

def clean_text(x):
    """Strip ``x`` and delete every character outside a small allow-list.

    Characters kept: ASCII letters, whitespace, and the punctuation
    ``" ' ! ? . , & ( ) -``. Everything else (digits, other symbols,
    non-ASCII text) is removed. Stripping happens before filtering, so
    whitespace that was next to a removed leading/trailing character stays.
    """
    disallowed = re.compile(r"[^a-zA-Z\"'\!\?\s\.\,\&\(\)\-]")
    trimmed = x.strip()
    return disallowed.sub("", trimmed)

def _first_label_token(text):
    """Return the leading label word of a prompt or prediction string.

    Handles both prompt layouts used in this notebook: ``"<label> ..."`` and
    ``"topic: <label>\\n"``. Returns ``""`` when no token is present.
    """
    text = text.strip()
    # Without this, every "topic: <label>" prompt would collapse to "topic:".
    if text.startswith("topic:"):
        text = text[len("topic:"):]
    tokens = text.split()
    return tokens[0] if tokens else ""


def score_accuracy(items, preds, use_prompt_label=True, true_label=None):
    """Print and return per-label accuracy of ``preds`` against ``items``.

    Args:
        items: Sequence of dicts with a ``'prompt'`` key; the leading label
            word of the prompt is the group key (and, by default, the target).
        preds: Predicted label strings, aligned with ``items``; only the
            leading label word of each prediction is compared.
        use_prompt_label: When true, a prediction is correct if it equals the
            prompt's label. When false, it is correct if it equals
            ``true_label``.
        true_label: Target label used when ``use_prompt_label`` is false.

    Returns:
        dict mapping each prompt label to its accuracy (correct / total).
        One line per label, in sorted order, is also printed as
        ``<label> <accuracy> <correct> of <total>``.
    """
    total = defaultdict(int)
    correct = defaultdict(int)

    for item, pred in zip(items, preds):
        label = _first_label_token(item['prompt'])
        pred = _first_label_token(pred)
        total[label] += 1
        expected = label if use_prompt_label else true_label
        if pred == expected:
            correct[label] += 1

    accuracy = {}
    for k in sorted(total):
        accuracy[k] = correct[k] / total[k]
        print(k, accuracy[k], f"{correct[k]} of {total[k]}")
    return accuracy

def run(filename, prompts, generator, classifier, label2, length=32, eval=False):
  """Sample one continuation per prompt and write them to a JSON-lines file.

  Each output record is ``{"prompt": <prompt>, "generation": <continuation>}``
  where the continuation is the generated text with the prompt prefix removed.

  Args:
    filename: Output ``.jsonl`` path. Its parent directory is created when the
      path has one; a bare filename is written to the working directory.
    prompts: Iterable of prompt strings; one sample is generated for each.
    generator: Model name/path handed to the ``text-generation`` pipeline.
    classifier: Model name/path handed to the ``text-classification``
      pipeline. Only loaded when ``eval`` is true.
    label2: Mapping from the classifier's raw label to the label to report.
    length: Passed to the generator as ``max_new_tokens``.
    eval: When true, classify every generation, print the prediction, and
      finish by calling ``score_accuracy`` on the collected results.
  """
  # os.path.dirname() is "" for a bare filename and os.makedirs("") raises
  # FileNotFoundError, so only create the directory when there is one.
  out_dir = os.path.dirname(filename)
  if out_dir:
    os.makedirs(out_dir, exist_ok=True)

  generator = pipeline("text-generation", generator, device=device)
  # Hard-coded pad id; presumably GPT-2's end-of-text token -- confirm before
  # swapping in a generator with a different vocabulary.
  generator.tokenizer.pad_token_id = 50256
  if eval:
    # The classifier is only used for evaluation, so skip loading it otherwise.
    classifier = pipeline("text-classification", classifier, device=device)
  items, preds = [], []

  with jsonlines.open(filename, "w") as fout:
    for p in tqdm(prompts, position=0):
      output = generator(
         p,
         do_sample=True,
         max_new_tokens=length,
         no_repeat_ngram_size=3,
         )[0]['generated_text']
      # generated_text starts with the prompt; keep only the continuation.
      output = output[len(p):]
      item = {
          "prompt": p,
          "generation": output
          }
      fout.write(item)

      if eval:
        cls = classifier(clean_text(output))[0]['label']
        print("predicted:", label2[cls], "generation:", output)
        items.append(item)
        preds.append(label2[cls])

  if eval:
    score_accuracy(items, preds)

def eval_file(filename, classifier, label2):
  """Classify the generations stored in a JSON-lines file and report accuracy.

  Args:
    filename: Path to a ``.jsonl`` file whose records have ``'prompt'`` and
      ``'generation'`` keys.
    classifier: Model name/path handed to the ``text-classification`` pipeline.
    label2: Mapping from the classifier's raw label to the label to score.
  """
  classifier = pipeline("text-classification", classifier, device=device)
  with jsonlines.open(filename) as reader:
    items = list(reader)

  # Each generation is cleaned, classified, and its top label translated.
  preds = [
      label2[classifier(clean_text(record['generation']))[0]['label']]
      for record in tqdm(items, position=0)
  ]
  score_accuracy(items, preds)

In [11]:
# Directory prefix for generated files; the trailing slash lets cells build
# paths by plain string concatenation.
OUTPUT_DIR = "data/v4/gpt2-small/"

## yelp-sentiment

In [None]:
count = 1000
length = 128
# Half the prompts ask for negative reviews, half for positive ones.
prompts = ["negative "] * count + ["positive "] * count
run(
    # OUTPUT_DIR already ends with "/", so no extra separator is added here
    # (matches how the other cells build their paths).
    f"{OUTPUT_DIR}yelp-sentiment-{count}-L{length}.jsonl",
    prompts,
    "heegyu/gpt2-yelp-polarity",
    "VictorSanh/roberta-base-finetuned-yelp-polarity",
    {"LABEL_0": "negative", "LABEL_1": "positive"},
    length=length
)

## imdb-sentiment

In [None]:
# Number of samples to generate per sentiment.
count = 10
# `count` negative prompts followed by `count` positive prompts.
prompts = [sentiment for sentiment in ("negative ", "positive ") for _ in range(count)]
run(
    filename=f"sentiment-{count}.jsonl",
    prompts=prompts,
    generator="heegyu/gpt2-sentiment",
    classifier="wrmurray/roberta-base-finetuned-imdb",
    label2={"LABEL_0": "negative", "LABEL_1": "positive"},
)

## emotion

In [10]:
count = 10
length = 32
# emotions = ["sadness im so sad. ", "joy im so happy. ", "love i feel romantic. ", "anger im furious. ", "fear im so scared. ", "surprise im so surprised. "]
emotions = ["sadness", "joy", "love", "anger", "fear", "surprise"]
# Build the prompts under their own name so `emotions` keeps the bare class
# names; `count` consecutive prompts per emotion.
prompts = [f"topic: {name}\n" for name in emotions for _ in range(count)]
# Map classifier outputs to emotion names (index order follows `emotions`),
# not to the prompt text.
labels = {f"LABEL_{i}": name for i, name in enumerate(emotions)}
run(
    f"{OUTPUT_DIR}emotion-{count}/generation-{length}.jsonl",
    prompts,
    # "heegyu/gpt2-emotion-balanced-1k",
    "heegyu/gpt2-emotion",
    "Aron/distilbert-base-uncased-finetuned-emotion",
    labels,
    length
)
# eval_file(
#     f"emotion-{count}-16.jsonl",
#     "Aron/distilbert-base-uncased-finetuned-emotion",
#     labels
# )

100%|██████████| 60/60 [00:54<00:00,  1.10it/s]


## BBC-news

In [12]:
count = 10
# Defined here so the filename and max_new_tokens always agree, instead of the
# filename picking up whatever `length` an earlier cell left behind.
length = 400
BBC_NEWS = [
    "tech", "business", "sport", "entertainment", "politics"
]
# `count` consecutive prompts per topic.
prompts = [f"topic: {k}\n" for k in BBC_NEWS for _ in range(count)]
run(
    f"{OUTPUT_DIR}bbc-news-{count}/generation-{length}.jsonl",
    prompts,
    "heegyu/gpt2-bbc-news",
    "abhishek/autonlp-bbc-news-classification-37229289",
    # Identity mapping: classifier labels are used as-is.
    {k: k for k in BBC_NEWS},
    length=length
)

Downloading: 100%|██████████| 907/907 [00:00<00:00, 301kB/s]
Downloading: 100%|██████████| 665M/665M [00:56<00:00, 11.7MB/s] 
Downloading: 100%|██████████| 255/255 [00:00<00:00, 255kB/s]
Downloading: 100%|██████████| 798k/798k [00:00<00:00, 834kB/s]  
Downloading: 100%|██████████| 456k/456k [00:00<00:00, 466kB/s]  
Downloading: 100%|██████████| 2.11M/2.11M [00:01<00:00, 1.21MB/s]
Downloading: 100%|██████████| 131/131 [00:00<00:00, 131kB/s]
100%|██████████| 50/50 [03:10<00:00,  3.80s/it]


## news-category

In [None]:
from datasets import load_dataset

# Number of prompts to draw per category.
count = 1000
# Category names; the list position doubles as the classifier's LABEL_<i> index.
NEWS_CATEGORIES = [
    "ENTERTAINMENT",
    "POLITICS",
    "WELLNESS",
    "TRAVEL",
    "STYLE & BEAUTY",
    "PARENTING",
    "HEALTHY LIVING",
    "QUEER VOICES",
    "FOOD & DRINK",
    "BUSINESS",
]

def generate_news(count=None):
    """Build generation prompts from the news-category dataset.

    Args:
        count: Maximum number of prompts per category. When falsy (``None`` or
            ``0``) every matching row is used and the dataset is not shuffled.

    Returns:
        list of prompt strings of the form
        ``"<category> Title: <headline>\\nContent: "``, grouped by category in
        ``NEWS_CATEGORIES`` order and, within a category, in dataset order.
    """
    dataset = load_dataset("heegyu/news-category-balanced-top10", split="train")
    if count:
        # Fixed seed so the sampled subset is the same on every run.
        dataset = dataset.shuffle(seed=42) #.select(range(count))

    # One pass over the dataset, bucketing by category, instead of rescanning
    # the whole dataset once per category.
    per_category = {category: [] for category in NEWS_CATEGORIES}
    filled = set()
    for item in dataset:
        category = item['category']
        if category not in per_category or category in filled:
            continue

        per_category[category].append(
            f"{item['category']} Title: {item['headline']}\nContent: "
        )

        if count and len(per_category[category]) == count:
            filled.add(category)
            # Stop early once every category has reached its quota.
            if len(filled) == len(per_category):
                break

    return [
        prompt
        for category in NEWS_CATEGORIES
        for prompt in per_category[category]
    ]

# Up to `count` headline prompts per category, then one generation per prompt.
prompts = generate_news(count)
run(
    filename=f"news-{count}.jsonl",
    prompts=prompts,
    generator="heegyu/gpt2-news-category",
    classifier="heegyu/roberta-base-news-category-top10",
    # LABEL_<i> is translated to the i-th entry of NEWS_CATEGORIES.
    label2={f"LABEL_{idx}": name for idx, name in enumerate(NEWS_CATEGORIES)},
    length=48,
)