# GPT2 Functions
<b>Date:</b> October 6, 2023\
<b>Author:</b> Dimitris Lymperopoulos\
<b>Description:</b> A notebook containing GPT-2-related helper functions (model initialization and sentence scoring)

## Imports

In [3]:
import torch
import numpy as np
from transformers import OpenAIGPTTokenizer, OpenAIGPTLMHeadModel
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from pylev import levenshtein as lev_dist

  from .autonotebook import tqdm as notebook_tqdm


## Functions

In [5]:
def model_init(model_string='gpt2', cuda=False):
    """
    A function that initializes a LM and a Tokenizer based on GPT2 (or OpenAI GPT).

    The GPT2 classes are used whenever the model string mentions "gpt2" anywhere
    (case-insensitive), so names such as "gpt2-medium", "distilgpt2" or a local path like
    "./checkpoints/gpt2-finetuned" are all loaded with the GPT2 classes. Every other model
    string is loaded with the OpenAI GPT classes.

    :param model_string: string representing the base model for the transformer and the tokenizer
    :param cuda: boolean value, determining whether or not to use gpu for model inference
    :return: the pretrained model (switched to evaluation mode) and tokenizer
    """
    # A prefix check would send "distilgpt2" or path-like checkpoint names to the
    # OpenAI GPT classes, so match the family name anywhere in the string instead.
    is_gpt2_family = "gpt2" in model_string.lower()

    if is_gpt2_family:
        tokenizer = GPT2Tokenizer.from_pretrained(model_string)
        model = GPT2LMHeadModel.from_pretrained(model_string)
    else:
        tokenizer = OpenAIGPTTokenizer.from_pretrained(model_string)
        model = OpenAIGPTLMHeadModel.from_pretrained(model_string)

    # The model is only used for inference in this notebook.
    model.eval()
    if cuda:
        model.to('cuda')
    return model, tokenizer

In [2]:
def sent_scoring(model, tokenizer, text, cuda=False):
    """
    A function that uses the given LM and Tokenizer to score a given sentence.

    The tokenized sentence is passed to the model both as input and as labels. The first
    returned value is the loss reported by the model for that call. The second returned value
    is the log-probability that the model assigns to the last token of the sentence, given
    all the tokens that precede it.

    :param model: a pretrained transformer model
    :param tokenizer: a pretrained tokenizer
    :param text: a string representing the sentence whose probability will be computed
    :param cuda: boolean value, determining whether or not to use gpu for model inference
    :return: the computed loss of the sentence and log_probability of the last token (both as
        python floats); the log_probability is nan for a single-token sentence, because there
        is no preceding context to condition on
    :raises ValueError: if the tokenizer produces no tokens for the given text
    """
    assert model is not None
    assert tokenizer is not None
    tokens = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt")
    if tokens.numel() == 0:
        # Fail early with a clear message instead of an obscure error inside the model.
        raise ValueError("Cannot score an empty sentence: the tokenizer produced no tokens.")
    if cuda:
        tokens = tokens.to('cuda')
    with torch.no_grad():
        outputs = model(tokens, labels=tokens)
    loss, logits = outputs[:2]
    loss = loss.item()

    if tokens.shape[-1] < 2:
        # With a single token there is no position that predicts it.
        return loss, float('nan')

    # For a causal LM the logits at position i score the token at position i + 1, so the
    # distribution over the last token is found at the second-to-last position. The raw logits
    # are unnormalised, hence the log-softmax to turn them into log-probabilities.
    last_token_log_probs = torch.log_softmax(logits[0, -2].float(), dim=-1)
    log_prob = last_token_log_probs[tokens[0, -1]].item()
    return loss, log_prob