In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
# Hub id of the checkpoint whose tokenizer is fetched here.
checkpoint_name = "AI-Sweden-Models/gpt-sw3-356m"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)


In [5]:
# get_vocab() returns a token -> id dictionary; keep a private copy of it.
w2id = dict(tokenizer.get_vocab())
vocab = w2id.keys()
# Invert with the tokenizer's real ids. Numbering tokens with enumerate() only
# matches the true ids if the dictionary happens to be ordered by id.
id2w = {token_id: token for token, token_id in w2id.items()}

In [6]:
# Load the pretrained causal language model, then place it on the selected device.
model = AutoModelForCausalLM.from_pretrained("AI-Sweden-Models/gpt-sw3-356m")
model = model.to(device)

In [7]:
# Display the concrete class that the Auto* factory instantiated for this checkpoint.
type(model)

transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel

In [8]:
def calc_prob_text(start_text, continue_text):
    """Return the model probability of ``continue_text`` following ``start_text``.

    Both texts are tokenized separately; the result is the product of the
    next-token probabilities of every continuation token, each conditioned on
    the start tokens plus the continuation tokens before it.

    Args:
        start_text: Context the continuation is conditioned on. Must tokenize
            to at least one token.
        continue_text: Text whose probability is measured.

    Returns:
        A 0-dim tensor holding the joint probability (1.0 when
        ``continue_text`` produces no tokens).

    Raises:
        ValueError: If ``start_text`` produces no tokens.
    """
    # The ids must live on the same device as the model.
    continue_ids = tokenizer(continue_text, return_tensors="pt")["input_ids"][0].to(device)
    start_ids = tokenizer(start_text, return_tensors="pt")["input_ids"][0].to(device)

    n_continue = continue_ids.shape[0]
    if n_continue == 0:
        # Nothing to score: the empty product is 1.
        return torch.tensor(1.0, device=device)
    if start_ids.shape[0] == 0:
        raise ValueError("start_text must produce at least one token")

    # The model is causal, so the logits at position t only depend on tokens
    # <= t. One forward pass over start + continuation (minus the last token,
    # which predicts nothing we score) therefore yields every conditional
    # distribution that feeding the prefix token by token would produce.
    all_ids = torch.cat((start_ids, continue_ids))[:-1]
    with torch.no_grad():
        # The key/value cache is never reused, so do not ask for it.
        logits = model(all_ids.unsqueeze(0), use_cache=False).logits[0, -n_continue:]

    # Sum log-probabilities instead of multiplying probabilities to avoid underflow.
    log_probs = torch.log_softmax(logits, dim=-1)
    positions = torch.arange(n_continue, device=log_probs.device)
    token_log_probs = log_probs[positions, continue_ids]
    return torch.exp(token_log_probs.sum())

In [13]:
def calc_prob_text2(start_text, fill_in_text, end_text, goal_ratio):
    """Score ``fill_in_text`` + ``end_text`` as a continuation of ``start_text``.

    The three texts are tokenized separately and concatenated. For every
    fill-in and end token the model's next-token probability p_i is looked up,
    and the result is the weighted geometric mean

        exp( sum_i(w_i * log p_i) / sum_i(w_i) )

    where end tokens have weight 1 and fill-in tokens have weight
    ``goal_ratio * n_end / n_fill``. The fill-in tokens therefore carry, in
    total, ``goal_ratio`` times the weight of the end tokens, independent of
    how many tokens each part has.

    Args:
        start_text: Left context. Must tokenize to at least one token.
        fill_in_text: Candidate text placed between start and end.
        end_text: Right-hand text that follows the candidate.
        goal_ratio: Total fill-in weight divided by total end weight.

    Returns:
        A 0-dim tensor with the weighted geometric-mean token probability.

    Raises:
        ValueError: If any of the three texts produces no tokens.
    """
    fill_in_ids = tokenizer(fill_in_text, return_tensors="pt")["input_ids"][0].to(device)
    start_ids = tokenizer(start_text, return_tensors="pt")["input_ids"][0].to(device)
    end_ids = tokenizer(end_text, return_tensors="pt")["input_ids"][0].to(device)

    n_fill = fill_in_ids.shape[0]
    n_end = end_ids.shape[0]
    if start_ids.shape[0] == 0 or n_fill == 0 or n_end == 0:
        raise ValueError(
            "start_text, fill_in_text and end_text must each produce at least one token"
        )
    n_scored = n_fill + n_end

    target_ids = torch.cat((fill_in_ids, end_ids))
    # Drop the last token: its logits would predict a token that is not scored.
    all_ids = torch.cat((start_ids, target_ids))[:-1]
    with torch.no_grad():
        # The key/value cache is never reused, so do not ask for it.
        logits = model(all_ids.unsqueeze(0), use_cache=False).logits
    # The logits at position t predict token t + 1, so the last n_scored rows
    # line up with the fill-in and end tokens.
    logits = logits[0, -n_scored:]

    # log_softmax avoids the underflow of log(softmax(...)) for unlikely tokens.
    log_probs = torch.log_softmax(logits, dim=1)
    positions = torch.arange(n_scored, device=log_probs.device)
    token_log_probs = log_probs[positions, target_ids]

    ratio = goal_ratio / (n_fill / n_end)
    # Build the weights on the same device and dtype as the scores so this
    # also works when the model runs on the GPU.
    weights = torch.cat((
        torch.full((n_fill,), ratio, dtype=token_log_probs.dtype, device=token_log_probs.device),
        torch.ones(n_end, dtype=token_log_probs.dtype, device=token_log_probs.device),
    ))

    final_log_prob = torch.sum(weights * token_log_probs) / torch.sum(weights)
    return torch.exp(final_log_prob)

In [16]:
# Prompt "I usually eat lunch around ..." completed by "... in the afternoon".
s = "Jag brukar äta lunch runt"
e = "på eftermiddagen"
# Candidates: "klockan halv två" (half past one) vs. "klockan halv elva" (half past ten).
good, bad = (
    calc_prob_text2(s, fill_in, e, 1)
    for fill_in in ("klockan halv två", "klockan halv elva")
)
# The printed number is bad / good: a value below 1 means the "bad" fill-in scored lower.
print("The bad answer is", bad.item() / good.item(), "times worse than the good answer")

The bad answer is 0.6777510915626548 times worse than the good answer


In [17]:
# Same prompt, now completed by "... in the morning" ("på förmiddagen").
s = "Jag brukar äta lunch runt"
e = "på förmiddagen"
# With a morning ending, half past one is the "bad" candidate and half past ten the "good" one.
bad, good = (
    calc_prob_text2(s, fill_in, e, 1)
    for fill_in in ("klockan halv två", "klockan halv elva")
)
# The printed number is bad / good: a value below 1 means the "bad" fill-in scored lower.
print("The bad answer is", bad.item() / good.item(), "times worse than the good answer")

The bad answer is 0.7325522590685093 times worse than the good answer


In [18]:
# Morning ending again, comparing a time expression with a place expression.
s = "Jag brukar äta lunch runt"
e = "på förmiddagen"
# "klockan halv två" (half past one) vs. "vid kyrkan" (by the church).
good, bad = (
    calc_prob_text2(s, fill_in, e, 1)
    for fill_in in ("klockan halv två", "vid kyrkan")
)
# The printed number is bad / good: a value below 1 means the "bad" fill-in scored lower.
print("The bad answer is", bad.item() / good.item(), "times worse than the good answer")

The bad answer is 0.0773235076678193 times worse than the good answer
