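## Create datasets

Split the corpus into subsets by document-level toxicity quintile and by the toxicity quintile of each document's first span (the span starting at offset 0, used as the prompt). Quintile 1 is the "low" group and quintile 5 the "high" group.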

In [1]:
import pandas as pd
from sqlalchemy import create_engine

from constants import DATA_DIR, TEXTS_DIR

In [2]:
db_path = DATA_DIR / 'perspective-responses-v2.db'
# Connecting to a missing SQLite file silently creates an empty database, which
# would only surface later as a "no such table" error inside the queries.
# Fail fast with a clear message instead.
if not db_path.is_file():
    raise FileNotFoundError(f'SQLite database not found: {db_path}')
engine = create_engine(f'sqlite:///{db_path}', echo=False)

In [3]:
def create_df(doc_quintile: int, sent_quintile: int, con=None) -> pd.DataFrame:
    """Load the first-span prompts for one (document quintile, span quintile) cell.

    Selects the documents in ``responses_quintiles`` whose ``quintile`` equals
    ``doc_quintile``. For each one it keeps only the span that starts at offset 0
    (``begin = 0``) and whose ``span_scores_quintiles.quintile`` equals
    ``sent_quintile``. It then attaches the document-level and span-level
    toxicity scores.

    Args:
        doc_quintile: Quintile of the whole document in ``responses_quintiles``.
            Coerced with ``int()``, so numeric strings such as ``"1"`` also work.
        sent_quintile: Quintile of the span in ``span_scores_quintiles``.
            Also coerced with ``int()``.
        con: DB-API connection or SQLAlchemy engine to query. Defaults to the
            notebook-level ``engine``.

    Returns:
        DataFrame with columns ``filename``, ``begin``, ``end``,
        ``doc_toxicity`` (from ``responses.toxicity``) and ``prompt_toxicity``
        (from ``span_scores.toxicity``, matched on filename/begin/end).

    Raises:
        ValueError: If a quintile cannot be converted to ``int``.
    """
    if con is None:
        con = engine

    # Bind the values instead of formatting them into the SQL text. A malformed
    # argument then raises here rather than producing broken or injected SQL.
    # Order matches the two `?` placeholders below: docs first, then prompts.
    params = (int(doc_quintile), int(sent_quintile))

    query = """
        WITH
            docs AS (
                SELECT filename
                FROM responses_quintiles
                WHERE quintile = ?
            ),
            prompts AS (
                SELECT
                    docs.filename AS filename,
                    span_scores_quintiles.begin AS begin,
                    span_scores_quintiles.end AS end
                FROM
                    span_scores_quintiles
                        INNER JOIN
                    docs
                        ON span_scores_quintiles.filename = docs.filename
                WHERE
                    span_scores_quintiles.begin = 0
                and span_scores_quintiles.quintile = ?
            )
        SELECT
            prompts.filename AS filename,
            prompts.begin AS begin,
            prompts.end AS end,
            responses.toxicity AS doc_toxicity,
            span_scores.toxicity AS prompt_toxicity
        FROM
            prompts
                INNER JOIN
            span_scores
                ON  prompts.filename = span_scores.filename
                and prompts.begin = span_scores.begin
                and prompts.end = span_scores.end
                INNER JOIN
            responses
                ON prompts.filename = responses.filename
        """

    return pd.read_sql(query, con=con, params=params)

In [4]:
# Low-toxicity documents (quintile 1) whose opening span is also low (quintile 1).
low_doc_low_prompt = create_df(doc_quintile=1, sent_quintile=1)
low_doc_low_prompt.shape[0]

523774

In [5]:
# Low-toxicity documents (quintile 1) whose opening span is high (quintile 5).
low_doc_high_prompt = create_df(doc_quintile=1, sent_quintile=5)
low_doc_high_prompt.shape[0]

101859

In [6]:
# High-toxicity documents (quintile 5) whose opening span is low (quintile 1).
high_doc_low_prompt = create_df(doc_quintile=5, sent_quintile=1)
high_doc_low_prompt.shape[0]

177681

In [7]:
# High-toxicity documents (quintile 5) whose opening span is also high (quintile 5).
high_doc_high_prompt = create_df(doc_quintile=5, sent_quintile=5)
high_doc_high_prompt.shape[0]

596395

In [13]:
high_doc_low_prompt['doc_toxicity'].mean()  # mean document-level toxicity of this subset

0.34961994635137855

## Generation

In [64]:
# Auto-reload imported modules so edits to the local `gen` module are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

from gen import GPT2Generator

<gen.GPT2Generator at 0x7fc5eefe7310>

In [None]:
# Use the class imported via `from gen import GPT2Generator`; the module name
# `gen` itself is not bound here, so `gen.GPT2Generator()` raises NameError.
generator = GPT2Generator()

In [96]:
# NOTE(review): this cell's saved output shows a single 1-row tensor of id 50256
# and only one empty string for two prompts, so generation looks broken inside
# gen.GPT2Generator.generate. The printed tensor/size lines presumably come from
# debug prints there. TODO investigate.
list(
    generator.generate(
        max_length=20,
        prompt=['Hello there my friend', 'Goodbye there my friend'],
    )
)

tensor([[50256, 50256]], device='cuda:0')
torch.Size([1, 20])


['']

In [15]:
from create_db import Response