In [5]:
from dexperts import DExperts
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

#### Load the baseline GPT-2 model and tokenizer

In [2]:
# Plain (unsteered) GPT-2 baseline and its tokenizer, both fetched from the "gpt2" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

#### Instantiate a DExperts model (GPT-2 base plus an expert and an anti-expert)

In [3]:
# Build the steered generator from a GPT-2 base, an "anti-bias" expert and a
# "bias" anti-expert (both Reddit fine-tunes, per their checkpoint names).
# NOTE(review): alpha presumably weights the expert/anti-expert contribution —
# confirm against the DExperts implementation.
dexperts = DExperts(
    base_model='gpt2',
    expert_model='eliolio/gpt2-finetuned-reddit-antibias',
    antiexpert_model='eliolio/gpt2-finetuned-redditbias',
    tokenizer='gpt2',
    alpha=1.0,
)

#### Generate samples from GPT-2 or DExperts

In [13]:
def generate_text(prompt, model, tokenizer, device=torch.device('cpu'),
                  num_return_sequences=3, max_new_tokens=15):
    """Sample several continuations of ``prompt`` and return them as strings.

    Works with any object exposing a Hugging Face-style ``generate`` method,
    e.g. the plain GPT-2 model or the ``DExperts`` wrapper used in this notebook.

    Args:
        prompt: Text to continue.
        model: Object whose ``generate`` accepts ``input_ids``, ``attention_mask``
            and the sampling keyword arguments passed below.
        tokenizer: Callable tokenizer used both to encode ``prompt`` and to decode
            the generated ids; it must match ``model``'s vocabulary.
        device: Device the encoded inputs are moved to. The model itself is not
            moved here, so it must already live on this device.
        num_return_sequences: Number of sampled continuations to return.
        max_new_tokens: Maximum number of tokens generated per continuation.

    Returns:
        List of decoded strings, one per generated sequence.
    """
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    # NOTE(review): HF models log "Setting `pad_token_id` to `eos_token_id`" because
    # pad_token_id is not passed; it is left out since it is unconfirmed whether
    # DExperts.generate accepts that keyword.
    output_ids = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        num_return_sequences=num_return_sequences,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # nucleus sampling: results vary unless the RNG is seeded
        top_p=0.9,
        temperature=1.0,
    )
    # Decode with the tokenizer that was passed in rather than a global one, so the
    # function does not depend on state defined in other cells.
    return [tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]

In [10]:
# One shared prompt string, fed to both generators for a side-by-side comparison.
prompt = "The police officer was"

In [14]:
# Seed torch's global RNG: generation samples (do_sample=True), so without a seed
# this cell's output changes on every Restart & Run All.
torch.manual_seed(0)
generate_text(prompt, model, tokenizer)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['The police officer was identified as Laelie Johnson. She told police that she had come to',
 'The police officer was charged with violating the Civil Rights Act of 1964, by unlawfully arresting, tort',
 'The police officer was charged with first-degree obstruction of justice.\n\nThe police chief was']

In [15]:
# Seed torch's global RNG so the sampled continuations are reproducible on re-run.
# NOTE(review): assumes DExperts samples via torch's global RNG — confirm.
torch.manual_seed(0)
generate_text(prompt, dexperts, dexperts.tokenizer)

['The police officer was the target of an open crime scene, with some eyewitnesses describing the man',
 'The police officer was shot with an AR-15, police say.\n\nAccording to the',
 'The police officer was shot and injured.\n\nPolice are investigating a shooting in South Park,']