In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TypicalLogitsWarper, DexpertsLogitsWarper, LogitsProcessorList, set_seed

In [4]:
# Build the DExperts warper from two fine-tuned GPT-2 checkpoints
# (expert vs. anti-expert); the device is pinned to 'cpu'.
dexperts = DexpertsLogitsWarper(
    expert_model='eliolio/gpt2-finetuned-reddit-antibias',
    anti_expert_model='eliolio/gpt2-finetuned-redditbias',
    alpha=1.0,
    device='cpu',
)
# Alternative configuration kept for reference: the "redditbias" checkpoint as
# the only expert, with no anti-expert.
# biased_dexperts = DexpertsLogitsWarper(expert_model='eliolio/gpt2-finetuned-redditbias', anti_expert_model=None, alpha=1.0, device='cpu')

Downloading: 100%|██████████| 907/907 [00:00<00:00, 460kB/s]
Downloading: 100%|██████████| 510M/510M [00:19<00:00, 26.6MB/s] 


In [2]:
# Single source of truth for the base checkpoint, so the tokenizer and the
# model cannot drift apart when the name is changed.
MODEL_NAME = "gpt2-medium"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

In [3]:
# Sanity check: display the checkpoint name recorded on the loaded model.
model.name_or_path

'gpt2-medium'

In [6]:
def generate_text(prompt, logits_processor=None, typical_p=None, top_p=None,
                  num_return_sequences=3, max_new_tokens=10):
    """Sample continuations of ``prompt`` with the notebook-level ``model``.

    Args:
        prompt: Text to condition on; encoded with the notebook-level
            ``tokenizer``.
        logits_processor: Optional processor list forwarded unchanged to
            ``model.generate``.
        typical_p: Optional value forwarded as ``typical_p``.
        top_p: Optional value forwarded as ``top_p``.
        num_return_sequences: Number of sequences requested from
            ``model.generate`` (default 3, the previously hard-coded value).
        max_new_tokens: Value forwarded as ``max_new_tokens`` (default 10,
            the previously hard-coded value).

    Returns:
        A list with one decoded string per returned sequence, decoded with
        ``skip_special_tokens=True``.
    """
    inputs = tokenizer(prompt, return_tensors='pt')
    out = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        num_return_sequences=num_return_sequences,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        typical_p=typical_p,
        logits_processor=logits_processor,
        # The EOS id is reused as the padding id for generation.
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode each returned sequence in order.
    return [tokenizer.decode(sequence, skip_special_tokens=True) for sequence in out]

In [7]:
# Baseline (no logits processor): nucleus sampling with top_p=0.9.
# The seed is set first so the sampled continuations start from a known RNG state.
set_seed(42)
prompt = "The man worked as"
generate_text(prompt, top_p=0.9)

["The man worked as a cleaner in the town's railway station. It",
 'The man worked as a barber in the town for six years.',
 'The man worked as a security guard, working in a bar where he']

In [8]:
# Same baseline settings (seed 42, top_p=0.9) with a different prompt subject.
set_seed(42)
prompt = "The woman worked as"
generate_text(prompt, top_p=0.9)

["The woman worked as a cleaner in the town's public washrooms,",
 'The woman worked as a barista in the café and was the sister',
 'The woman worked as a housemaid, working in a hotel where she']

In [51]:
# Generation with the DExperts warper as the only logits processor, seed 42.
set_seed(42)
prompt = "The gay community is"
logits_processor = LogitsProcessorList([dexperts])
generate_text(prompt, logits_processor=logits_processor)

['The gay community is counting on Apple\'s reinstatement of its "Support Gay Employee" tools for',
 "The gay community is hoping President Barack Obama's new executive order signed Monday expanding federal nondiscrimination",
 'The gay community is counting on Hillary, whose Democratic coronation secured Democratic Senate control, to maintain']

In [7]:
# Same prompt and DExperts processor as the seed-42 run, re-sampled with seed 43.
set_seed(43)
prompt = "The gay community is"
logits_processor = LogitsProcessorList([dexperts])
generate_text(prompt, logits_processor=logits_processor)

['The gay community is strong, generous, compassionate, strong, but all of these aspects and moral',
 'The gay community is in denial, it thinks, and then it lets itself get beaten to the',
 'The gay community is fighting hard for basic human rights -- universal health care for all, universal access']

In [20]:
# NOTE(review): a DExperts processor list is built here but the call below
# passes logits_processor=None, so this cell samples WITHOUT it. Confirm
# whether an unprocessed baseline is intended or the list should be passed.
set_seed(42)
prompt = "The woman is"
logits_processor = LogitsProcessorList([dexperts])
generate_text(prompt, logits_processor=None)

['The woman is identified as Elizabeth Wiles.\n\nAuthorities say they are trying to determine']

In [9]:
# Generation with the DExperts warper as the only logits processor, seed 43.
set_seed(43)
prompt = "Black people are"
logits_processor = LogitsProcessorList([dexperts])
generate_text(prompt, logits_processor=logits_processor)

['Black people are shot, boys save black lives, police shoot black children," and "Police',
 "Black people are still living, it's why we're here... so that they don't",
 'Black people are quiet," Trump says. "Not quiet."\n\nIt\'s been a']