In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
tokenizer = AutoTokenizer.from_pretrained("AI-Sweden-Models/gpt-sw3-356m")


In [5]:
vocab = tokenizer.get_vocab().keys()
id2w =  {i:w for i,w in enumerate(vocab)}
w2id = {w:i for i,w in enumerate(vocab)}

In [40]:
list(vocab)

['<pad>',
 '<unk>',
 '<s>',
 '<|endoftext|>',
 '<|javascript|>',
 '<|python|>',
 '<|sql|>',
 '<|shell|>',
 '<0x00>',
 '<0x01>',
 '<0x02>',
 '<0x03>',
 '<0x04>',
 '<0x05>',
 '<0x06>',
 '<0x07>',
 '<0x08>',
 '<0x09>',
 '<0x0A>',
 '<0x0B>',
 '<0x0C>',
 '<0x0D>',
 '<0x0E>',
 '<0x0F>',
 '<0x10>',
 '<0x11>',
 '<0x12>',
 '<0x13>',
 '<0x14>',
 '<0x15>',
 '<0x16>',
 '<0x17>',
 '<0x18>',
 '<0x19>',
 '<0x1A>',
 '<0x1B>',
 '<0x1C>',
 '<0x1D>',
 '<0x1E>',
 '<0x1F>',
 '<0x20>',
 '<0x21>',
 '<0x22>',
 '<0x23>',
 '<0x24>',
 '<0x25>',
 '<0x26>',
 '<0x27>',
 '<0x28>',
 '<0x29>',
 '<0x2A>',
 '<0x2B>',
 '<0x2C>',
 '<0x2D>',
 '<0x2E>',
 '<0x2F>',
 '<0x30>',
 '<0x31>',
 '<0x32>',
 '<0x33>',
 '<0x34>',
 '<0x35>',
 '<0x36>',
 '<0x37>',
 '<0x38>',
 '<0x39>',
 '<0x3A>',
 '<0x3B>',
 '<0x3C>',
 '<0x3D>',
 '<0x3E>',
 '<0x3F>',
 '<0x40>',
 '<0x41>',
 '<0x42>',
 '<0x43>',
 '<0x44>',
 '<0x45>',
 '<0x46>',
 '<0x47>',
 '<0x48>',
 '<0x49>',
 '<0x4A>',
 '<0x4B>',
 '<0x4C>',
 '<0x4D>',
 '<0x4E>',
 '<0x4F>',
 '<0x50>',
 '<

In [6]:
model = AutoModelForCausalLM.from_pretrained("AI-Sweden-Models/gpt-sw3-356m").to(device)

In [7]:
type(model)

transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel

In [8]:
def calc_prob_text(start_text, continue_text):
    fill_in__ids = tokenizer(continue_text, return_tensors="pt")["input_ids"][0]
    start_ids = tokenizer(start_text, return_tensors="pt")["input_ids"][0]
    prob = 1
    for id in fill_in__ids:
        with torch.no_grad():
            logits = model(start_ids.unsqueeze(0), use_cache=True).logits[0,-1]
        probs = torch.softmax(logits, dim = 0)
        prob *= probs[id]
        start_ids =  torch.cat((start_ids, id.unsqueeze(0)), dim=0)
    return prob

In [41]:
def calc_prob_text2(start_text, fill_in_text, end_text, goal_ratio):
    fill_in__ids = tokenizer(fill_in_text, return_tensors="pt")["input_ids"][0].to(device).to(dtype = torch.int64)
    start_ids = tokenizer('<s> ' + start_text, return_tensors="pt")["input_ids"][0].to(device).to(dtype = torch.int64)
    end_ids = tokenizer(end_text, return_tensors="pt")["input_ids"][0].to(device).to(dtype = torch.int64)
    all_ids = torch.cat((start_ids, fill_in__ids, end_ids))[:-1]
    with torch.no_grad():
        logits = model(all_ids.unsqueeze(0), use_cache=True).logits
    logits = logits[0,-(fill_in__ids.shape[0] + end_ids.shape[0]):]
    probs = torch.softmax(logits, dim = 1)
    prob = probs[torch.arange(fill_in__ids.shape[0] + end_ids.shape[0]), torch.cat((fill_in__ids, end_ids))]
    ratio = goal_ratio / (fill_in__ids.shape[0] / end_ids.shape[0]) if fill_in__ids.shape[0] != 0 and end_ids.shape[0] != 0 else 1 
    weights = torch.cat((torch.full(fill_in__ids.shape, ratio),(torch.ones_like(end_ids))))
    weighed_probs = torch.pow(prob, weights)
    exponent = 1 / torch.sum(weights)
    suum = torch.sum(torch.log(weighed_probs))
    final_log_prob = (suum * exponent)
    final_prob = torch.exp(final_log_prob)
    return final_prob.item() if final_prob != str(final_prob.item()) != "nan" else 0

In [10]:
s = "Jag brukar äta lunch runt"
e = "på eftermiddagen"
good = calc_prob_text2(s, "klockan halv två", e, 1)
bad = calc_prob_text2(s, "klockan halv elva", e, 1)
print("The bad answer is", bad.item() / good.item(), "times worse than the good answer")

The bad answer is 0.6777510915626548 times worse than the good answer


In [11]:
s = "Jag brukar äta lunch runt"
e = "på förmiddagen"
bad = calc_prob_text2(s, "klockan halv två", e, 1)
good = calc_prob_text2(s, "klockan halv elva", e, 1)
print("The bad answer is", bad.item() / good.item(), "times worse than the good answer")

The bad answer is 0.7325522590685093 times worse than the good answer


In [12]:
s = "Jag brukar äta lunch runt"
e = "på förmiddagen"
good = calc_prob_text2(s, "klockan halv två", e, 1)
bad = calc_prob_text2(s, "vid kyrkan", e, 1)
print("The bad answer is", bad.item() / good.item(), "times worse than the good answer")

The bad answer is 0.0773235076678193 times worse than the good answer


In [45]:
calc_prob_text2("", "", "", 1)

0.0004088236892130226