In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from pathlib import Path
from typing import Union, List
import torch
from torch.nn import CrossEntropyLoss

  from .autonotebook import tqdm as notebook_tqdm


In [72]:
class DExperts:
    
    def __init__(
        self, 
        base_model: Union[str, Path, AutoModelForCausalLM],
        antiexpert_model: Union[str, Path, AutoModelForCausalLM] = None,
        expert_model: Union[str, Path, AutoModelForCausalLM] = None,
        tokenizer: str = 'gpt2', 
        seed: int = 42,
    ):
        # Set up device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        set_seed(seed)

        self.base_model = AutoModelForCausalLM.from_pretrained(base_model).to(self.device)
        if antiexpert_model:
            self.antiexpert = AutoModelForCausalLM.from_pretrained(antiexpert_model, use_auth_token=True).to(self.device)
        else:
            self.antiexpert = None
        if expert_model:
            self.expert = AutoModelForCausalLM.from_pretrained(expert_model, use_auth_token=True).to(self.device)
        else:
            self.expert = None
        
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        assert self.tokenizer.eos_token_id == self.tokenizer.pad_token_id

    
    def __call__(self, prompt: str, alpha: float = 2.0):
        encodings_dict = self.tokenizer(prompt, return_tensors='pt', padding=True, return_attention_mask=True).to(self.device)
        encoded_text = encodings_dict["input_ids"]
        attn_mask = encodings_dict["attention_mask"]
        logits = self.get_logits(encoded_text, alpha=alpha)
        return {"logits": logits, "perplexity": self.get_perplexity(logits, encoded_text)}

    def compute_perplexity(self, prompt: str, alpha: float = 2.0):
        encodings_dict = self.tokenizer(prompt, return_tensors='pt', padding=True, return_attention_mask=True).to(self.device)
        encoded_text = encodings_dict["input_ids"]
        attn_mask = encodings_dict["attention_mask"]
        logits = self.get_logits(encoded_text, alpha=alpha)
        return self.get_perplexity(logits, encoded_text)

    def get_perplexity(self, logits, labels):
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return torch.exp(loss)

    def forward(self, prompt: str, max_length: int = 20, alpha: float = 2.0):
        return self(prompt, max_length=max_length, alpha=alpha)

    def get_logits(self, encodings_dict, alpha=2.0):
        self.base_model.eval()
        if self.expert:
            self.expert.eval()
        if self.antiexpert:
            self.antiexpert.eval()
        
        with torch.no_grad():
            # base model prediction
            base_logits = self.base_model(encodings_dict).logits
            
            # expert prediction
            if self.expert:
                expert_logits = self.expert(encodings_dict).logits
            else:
                expert_logits = base_logits
            
            # antiexpert prediction
            if self.antiexpert:
                antiexpert_logits = self.antiexpert(encodings_dict).logits
            else:
                antiexpert_logits = base_logits
    

            if self.antiexpert is not None or self.expert is not None:
                ensemble_logits = base_logits + alpha * (expert_logits - antiexpert_logits)
            else:
                ensemble_logits = base_logits

        return ensemble_logits

In [73]:
dexperts = DExperts(
    base_model='gpt2',
    antiexpert_model='eliolio/gpt2-finetuned-redditbias',
    expert_model='eliolio/gpt2-finetuned-reddit-antibias',
    tokenizer='gpt2',
)

In [70]:
res = dexperts("I love the fact that", alpha=0.0)
res

{'logits': tensor([[[ -39.3084,  -39.0100,  -41.8374,  ...,  -46.9337,  -44.9074,
            -39.5149],
          [ -84.4961,  -85.0687,  -90.6138,  ...,  -91.6659,  -93.1036,
            -87.5304],
          [ -86.2664,  -85.5881,  -87.9832,  ...,  -88.2878,  -90.3265,
            -85.8329],
          [ -64.7916,  -65.6689,  -72.8639,  ...,  -72.9414,  -75.8771,
            -67.8533],
          [ -94.8606,  -95.9581, -100.5047,  ..., -100.2555, -103.2658,
            -96.7781]]]),
 'perplexity': tensor(1880.1859)}

In [74]:
dexperts

<DExpertsGenerator model_name_or_path="GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplac