In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# Tokenizer for the GPT-SW3 356M checkpoint; it must match the checkpoint used for `model`.
tokenizer = AutoTokenizer.from_pretrained(
    "AI-Sweden-Models/gpt-sw3-356m",
)


In [4]:
# `get_vocab()` returns a {token: id} dict. Its iteration order is not a
# contract of the API, so enumerate() positions must not be used as token ids;
# take the ids from the dict itself so both maps agree with the tokenizer.
token_to_id = tokenizer.get_vocab()
vocab = token_to_id.keys()
id2w = {i: w for w, i in token_to_id.items()}
w2id = dict(token_to_id)

In [5]:
# Load the causal-LM weights, then place them on the device chosen above.
model = AutoModelForCausalLM.from_pretrained("AI-Sweden-Models/gpt-sw3-356m")
model = model.to(device)

In [6]:
# Show which concrete model class the Auto* factory resolved to.
model.__class__

transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel

In [7]:
def calc_prob_text(start_text, continue_text):
    """Return the model's probability of `continue_text` directly following `start_text`.

    Both texts are tokenized separately and their ids concatenated. A single
    forward pass over the combined ids scores every continuation token. The
    result is the product of the per-token probabilities, computed as the
    exp of a sum of log-probabilities. This avoids underflow in intermediate
    products.

    Args:
        start_text: Context string. Assumed to tokenize to at least one token.
        continue_text: Continuation whose probability is measured.

    Returns:
        A 0-d tensor on `device`. It is 1.0 (the empty product) when
        `continue_text` produces no tokens.
    """
    # NOTE(review): if the tokenizer adds special tokens (e.g. BOS) on every
    # call, they would appear between the two pieces -- confirm for GPT-SW3.
    start_ids = tokenizer(start_text, return_tensors="pt")["input_ids"][0].to(device)
    continue_ids = tokenizer(continue_text, return_tensors="pt")["input_ids"][0].to(device)
    n_cont = continue_ids.shape[0]
    if n_cont == 0:
        return torch.tensor(1.0, device=device)

    all_ids = torch.cat((start_ids, continue_ids))
    with torch.no_grad():
        logits = model(all_ids.unsqueeze(0)).logits[0]
    # logits[t] predicts token t + 1, so the continuation tokens are predicted by
    # the n_cont positions ending one before the last position.
    pred_logits = logits[-n_cont - 1:-1]
    log_probs = torch.log_softmax(pred_logits, dim=-1)
    token_log_probs = log_probs.gather(1, continue_ids.unsqueeze(1)).squeeze(1)
    return token_log_probs.sum().exp()

In [8]:
def calc_prob_text2(start_text, continue_text, end_text, goal_ratio):
    """Weighted geometric-mean probability of `continue_text` + `end_text` after `start_text`.

    The three texts are tokenized separately and concatenated. One forward pass
    scores every continuation and end token. Each token's log-probability gets
    a weight:

    * end tokens get weight 1;
    * every continuation token gets the same weight, chosen so that
      (total continuation weight) / (total end weight) == goal_ratio,
      independent of how many tokens each text has.

    The result is exp(sum(w * log p) / sum(w)). Weighting in log space avoids
    the underflow of raising tiny probabilities to a power first.

    Args:
        start_text: Context string. Assumed to tokenize to at least one token.
        continue_text: Candidate middle text being scored.
        end_text: Fixed trailing text that the candidate must be consistent with.
        goal_ratio: Relative total weight of the continuation vs. the end text.

    Returns:
        A 0-d tensor on `device`.

    Raises:
        ZeroDivisionError: if `continue_text` or `end_text` produces no tokens.
    """
    # NOTE(review): if the tokenizer adds special tokens (e.g. BOS) on every
    # call, they would appear between the pieces -- confirm for GPT-SW3.
    continue_ids = tokenizer(continue_text, return_tensors="pt")["input_ids"][0].to(device)
    start_ids = tokenizer(start_text, return_tensors="pt")["input_ids"][0].to(device)
    end_ids = tokenizer(end_text, return_tensors="pt")["input_ids"][0].to(device)
    n_cont = continue_ids.shape[0]
    n_end = end_ids.shape[0]
    n_scored = n_cont + n_end

    # Per-continuation-token weight giving n_cont * ratio == goal_ratio * n_end.
    # Computed before the forward pass so empty inputs fail fast.
    ratio = goal_ratio / (n_cont / n_end)

    target_ids = torch.cat((continue_ids, end_ids))
    # The last token is only ever a prediction target, never context, so drop it.
    all_ids = torch.cat((start_ids, target_ids))[:-1]
    with torch.no_grad():
        logits = model(all_ids.unsqueeze(0)).logits
    # logits[t] predicts token t + 1: the final n_scored positions predict the targets.
    log_probs = torch.log_softmax(logits[0, -n_scored:], dim=1)
    token_log_probs = log_probs[torch.arange(n_scored, device=device), target_ids]

    # Build the weights on the same device/dtype as the log-probs (mixing a CPU
    # float tensor with a device int tensor breaks torch.cat on CUDA).
    weights = torch.cat((
        torch.full((n_cont,), ratio, dtype=token_log_probs.dtype, device=device),
        torch.ones(n_end, dtype=token_log_probs.dtype, device=device),
    ))
    weighted_mean_log_prob = torch.sum(weights * token_log_probs) / torch.sum(weights)
    return torch.exp(weighted_mean_log_prob)

In [9]:
# Afternoon context: 13:30 ("halv två") should be preferred over 10:30 ("halv elva").
s, e = "Jag brukar äta lunch runt", "på eftermiddagen"
good, bad = (
    calc_prob_text2(s, candidate, e, 1)
    for candidate in ("klockan halv två", "klockan halv elva")
)
print("The good answer is", good.item() / bad.item(), "times better than the bad answer")

The good answer is 1.4754679298182363 times better than the bad answer


In [10]:
# Morning context: the roles flip, so 10:30 ("halv elva") should now win over 13:30.
s, e = "Jag brukar äta lunch runt", "på förmiddagen"
bad, good = (
    calc_prob_text2(s, candidate, e, 1)
    for candidate in ("klockan halv två", "klockan halv elva")
)
print("The good answer is", good.item() / bad.item(), "times better than the bad answer")

The good answer is 1.3650903230734268 times better than the bad answer


In [11]:
# A time phrase vs. a place phrase in the same slot.
# NOTE(review): `e` is still "på förmiddagen" (as in the previous cell), so the
# "good" phrase literally reads 01:30 -- confirm this pairing is intended.
s, e = "Jag brukar äta lunch runt", "på förmiddagen"
good, bad = (
    calc_prob_text2(s, candidate, e, 1)
    for candidate in ("klockan halv två", "vid kyrkan")
)
print("The good answer is", good.item() / bad.item(), "times better than the bad answer")

The good answer is 12.932677657304243 times better than the bad answer
