In [28]:
# external
import datasets as ds
from transformers import pipeline
import importlib
import os
import torch
import transformers
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from functools import partial

# internal
from redditqa.data.smart_filter import question_filter, answer_filter
from redditqa.data import qa_generation
from redditqa.data.util import mask_links

# Hot-reload project modules so source edits take effect without restarting the
# kernel (and re-loading the 7B model). `question_filter` / `answer_filter` were
# bound via `from ... import ...`, so reloading a module alone would leave the old
# function objects in this namespace -- rebind them explicitly after the reload.
from redditqa.data import smart_filter

importlib.reload(qa_generation)
importlib.reload(smart_filter)
question_filter = smart_filter.question_filter
answer_filter = smart_filter.answer_filter

qa_generation

<module 'redditqa.data.qa_generation' from '/workspaces/reddit_qa/redditqa/data/qa_generation.py'>

In [2]:
# The 4-bit bitsandbytes quantization configured below needs a CUDA GPU.
# Fail fast here instead of partway through loading the model shards.
cuda_available = torch.cuda.is_available()
assert cuda_available, "No CUDA device visible; 4-bit bitsandbytes loading needs a GPU."
cuda_available

True

In [None]:
# Model / cache configuration. The cache directory can be overridden with the
# REDDITQA_CACHE_DIR environment variable so the notebook is not tied to a single
# user's scratch space; the original path remains the default.
model_id = "HuggingFaceH4/zephyr-7b-beta"
cache_dir = os.environ.get("REDDITQA_CACHE_DIR", "/scratch1/ssawicki/cache")

### Load model

In [3]:
# bitsandbytes quantization settings: 4-bit NF4 weights with double quantization,
# computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [4]:
# Load the causal LM from the shared cache using the quantization config above.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    cache_dir=cache_dir,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [5]:
# Tokenizer from the same checkpoint and cache directory as the model.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    cache_dir=cache_dir,
)

In [6]:
# Text-generation pipeline around the already-loaded model; the LLM filters
# below receive it via `pipeline=pipe`.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

### Load dataset

In [7]:
# Load the cached AskHistorians dataset, then apply the project's `mask_links`
# preprocessing to every example.
# NOTE(review): `load_from_disk` reads a directory written by `save_to_disk`; the
# `.jsonl` suffix is presumably just the directory's name -- confirm.
raw_dataset = ds.load_from_disk("/scratch1/redditqa/cached_datasets/AskHistorians.jsonl")
dataset = raw_dataset.map(mask_links)

In [8]:
# Schema: question-level fields plus a nested list of answers per row.
dataset_info = dataset.info
dataset_info.features

{'question_created_utc': Value(dtype='int64', id=None),
 'question_retrieved_on': Value(dtype='int64', id=None),
 'question_deleted': Value(dtype='bool', id=None),
 'question_title': Value(dtype='string', id=None),
 'question_selftext': Value(dtype='string', id=None),
 'question_score': Value(dtype='int64', id=None),
 'question_char_length': Value(dtype='int64', id=None),
 'question_selftext_char_length': Value(dtype='int64', id=None),
 'answers': [{'answer_body': Value(dtype='string', id=None),
   'answer_char_length': Value(dtype='int64', id=None),
   'answer_created_utc': Value(dtype='int64', id=None),
   'answer_deleted': Value(dtype='bool', id=None),
   'answer_id': Value(dtype='string', id=None),
   'answer_retrieved_on': Value(dtype='int64', id=None),
   'answer_score': Value(dtype='int64', id=None)}]}

### LLM Filtering

Question filter:
- Few-shot classification: define what a well-written question is and show some good and bad examples.

Answer filter:
- Few-shot classification: define what a well-written answer is and show some good and bad examples.

#### Question filter

In [23]:
# Bind the shared generation pipeline (and verbose printing of each verdict) so the
# filter can be handed straight to `Dataset.filter`; presumably the remaining
# positional argument is the example row.
question_pipe = partial(question_filter, verbose=True, pipeline=pipe)

In [24]:
# Small fixed slice (row indices 10..49) for eyeballing the LLM filters.
example_indices = list(range(10, 50))
example_ds = dataset.select(example_indices)
example_ds

Dataset({
    features: ['question_created_utc', 'question_retrieved_on', 'question_deleted', 'question_title', 'question_selftext', 'question_score', 'question_char_length', 'question_selftext_char_length', 'answers'],
    num_rows: 40
})

In [26]:
# Keep the filtered slice in a name so the surviving questions can be inspected.
example_questions_kept = example_ds.filter(question_pipe)
example_questions_kept

Filter:   0%|          | 0/40 [00:00<?, ? examples/s]

##################################################
Question: Who Is Your Favourite Historical Dictator?
------------------------------------------------------------------------------------------
Answer: bad question. It contains a personal reference and is suggestive. Historical figures should not be judged based on personal preferences, and the question does not provide any context or historical significance. It is also not clear what the question is asking
##################################################
##################################################
Question: I have an old history teacher that INSISTS that Lincoln was racist, and did not care for Blacks at, rather he issued the Emancipation Proclamation to economically sabotage the South. To what extent is this true or not true?
------------------------------------------------------------------------------------------
Answer: good question

Explanation:

While the question contains a personal reference, it is not a bad questio

Dataset({
    features: ['question_created_utc', 'question_retrieved_on', 'question_deleted', 'question_title', 'question_selftext', 'question_score', 'question_char_length', 'question_selftext_char_length', 'answers'],
    num_rows: 22
})

#### Answer filter

In [40]:
# Same binding pattern as `question_pipe`, for the answer-quality filter.
answer_pipe = partial(answer_filter, verbose=True, pipeline=pipe)

In [44]:
# Runs on the full example slice (not the question-filtered subset) so the answer
# filter is evaluated independently of the question filter.
example_answers_kept = example_ds.filter(answer_pipe)
example_answers_kept

Filter:   0%|          | 0/40 [00:00<?, ? examples/s]



##################################################
Who Is Your Favourite Historical Dictator?
------------------------------------------------------------------------------------------
Stalin. Despite being one of the most brutal dictators to ever live, he is still popular among many Russians. In a recent poll, Stalin was voted as the third greatest Russian to ever live. While he killed millions, he also transformed his poor and broken country into a superpower. 
------------------------------------------------------------------------------------------
Answer: 'bad answer'

Explanation:

The answer contains a clear preference, which goes against the guidelines provided. Additionally, the answer glorifies a historical dictator known for his brutality and mass killings,
##################################################
##################################################
Who Is Your Favourite Historical Dictator?
----------------------------------------------------------------------------

KeyboardInterrupt: 

### Toxicity

Currently not used by the LLM filters above; kept for experimentation.

In [5]:
# Two-label ToxiGen RoBERTa classifier. Use the first GPU when one is visible,
# otherwise fall back to CPU (device=-1) so this section also runs without CUDA.
toxicity_device = 0 if torch.cuda.is_available() else -1
toxicity_pipe = pipeline("text-classification", model="tomh/toxigen_roberta", device=toxicity_device)

In [13]:
# Sanity check on a benign sentence: with top_k=None a score for every label is returned.
sample_scores = toxicity_pipe('I love ML', top_k=None)
sample_scores

[{'label': 'LABEL_0', 'score': 0.9992640614509583},
 {'label': 'LABEL_1', 'score': 0.0007359233568422496}]

In [6]:
def run_toxicity_pipe(text, classifier=None, fallback=0.5, truncation=True):
    """Return the toxicity (``LABEL_1``) score for ``text``.

    Scoring is best-effort: if the classifier raises an ordinary exception, or
    reports no ``LABEL_1`` entry, a warning is emitted and ``fallback`` is
    returned. ``KeyboardInterrupt``/``SystemExit`` are no longer swallowed.

    Args:
        text: The text to score.
        classifier: Callable with the Hugging Face text-classification pipeline
            call signature. Defaults to the notebook-level ``toxicity_pipe``.
        fallback: Score returned when no score can be obtained (0.5 matches the
            previous behaviour).
        truncation: Forwarded to the classifier so inputs longer than the model's
            maximum length are truncated instead of erroring and silently
            falling back.

    Returns:
        The ``LABEL_1`` score, or ``fallback``.
    """
    import warnings

    if classifier is None:
        classifier = toxicity_pipe

    try:
        scores = classifier(text, top_k=None, truncation=truncation)
        # A single string yields a flat list of {label, score} dicts; tolerate a
        # nested [[...]] shape as well.
        if scores and isinstance(scores[0], list):
            scores = scores[0]
        toxic_scores = [entry['score'] for entry in scores if entry['label'] == 'LABEL_1']
    except Exception as exc:
        warnings.warn(f"Toxicity classification failed ({exc!r}); returning {fallback}.")
        return fallback

    if not toxic_scores:
        warnings.warn(f"Classifier returned no 'LABEL_1' score; returning {fallback}.")
        return fallback
    return toxic_scores[0]

In [11]:
# Toxicity score of each answer in one sample thread (dataset row 65).
sample_answers = dataset[65]["answers"]
for reply in sample_answers:
    score = run_toxicity_pipe(reply["answer_body"])
    print(score)

0.0006564322975464165
0.0007360000745393336
0.0007385569042526186
0.000814662838820368
