In [19]:
import torch
import numpy as np
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Select the best available accelerator, falling back to the CPU.
# DEVICE is always a torch.device (never a bare string) so that every branch
# hands downstream code the same type; str(DEVICE) still prints "cpu"/"mps"/"cuda".
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")  # Use Metal Performance Shaders on macOS
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")  # to check if cuda is an option https://www.restack.io/p/gpu-computing-answer-is-my-gpu-cuda-enabled-cat-ai
else:
    DEVICE = torch.device("cpu")

print(f"DEVICE is : {DEVICE}")

DEVICE is : cuda


In [20]:
def load_model_tokenizer(model_path_or_id:str, device:str):
    """Load a masked-language-model checkpoint together with its tokenizer.

    Args:
        model_path_or_id: Identifier handed unchanged to both
            ``from_pretrained`` calls.
        device: Target handed to ``model.to``.

    Returns:
        A ``(model, tokenizer)`` tuple; only the model is moved to ``device``.
    """
    # Swap AutoModelForMaskedLM for the class matching your task
    # (e.g. AutoModelForSequenceClassification) when reusing this helper.
    loaded_model = AutoModelForMaskedLM.from_pretrained(model_path_or_id)
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_path_or_id)

    loaded_model.to(device)
    return loaded_model, loaded_tokenizer

In [21]:
# retrieval function
def top_k_prediction(masked_text, model, tokenizer, k=10):
    """Return the ``k`` highest-scoring tokens for the first mask in ``masked_text``.

    Args:
        masked_text: Text containing at least one occurrence of the
            tokenizer's mask token.
        model: Model whose call result exposes ``.logits``; it is moved to
            the global ``DEVICE`` as a side effect.
        tokenizer: Tokenizer used both to encode ``masked_text`` and to
            decode the predicted token ids.
        k: Number of predictions to return.

    Returns:
        A list of ``k`` decoded token strings, highest logit first. When the
        text holds several mask tokens only the first one is ranked.

    Raises:
        ValueError: If ``masked_text`` contains no mask token.
    """
    model.to(DEVICE)
    inputs = tokenizer(masked_text, return_tensors="pt").to(DEVICE)

    # Inference only: skip autograd bookkeeping so no graph is built or retained.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Column (sequence-position) indices of every mask token in the input.
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    if mask_token_index.numel() == 0:
        raise ValueError("masked_text does not contain the tokenizer's mask token")

    # One row of vocabulary logits per mask position; row 0 is the first mask.
    mask_token_logits = logits[0, mask_token_index, :]
    top_ids = torch.topk(mask_token_logits, k, dim=1).indices[0].tolist()
    return [tokenizer.decode(t) for t in top_ids]

In [22]:
#METRICS
def recall_in_top_k(A, B):
    """Return the set of entries of ``A`` that also occur in ``B``.

    Note that the result is the overlapping words themselves, not a ratio.
    """
    shared = set()
    for candidate in A:
        if candidate in B:
            shared.add(candidate)
    return shared

def differences_in_top_k(A, B):
    """Return what ``B`` gained and what it dropped relative to ``A``.

    Returns:
        A 2-tuple ``(new_in_B, left_out_in_B)``: the set of entries found in
        ``B`` but not in ``A``, followed by the set of entries found in ``A``
        but not in ``B``.
    """
    def _absent_from(items, reference):
        # Unique entries of `items` that `reference` does not contain.
        return set(filter(lambda entry: entry not in reference, items))

    gained = _absent_from(B, A)
    dropped = _absent_from(A, B)
    return gained, dropped

#jaccard Similarity: Measures overlap between two sets or lists.
def jaccard(A, B):
    """Jaccard similarity |A ∩ B| / |A ∪ B| of two collections of hashable items.

    Args:
        A, B: Any iterables of hashable items (lists, tuples, sets, ...);
            duplicates are ignored. The two arguments need not share a type.

    Returns:
        A float in [0, 1]. Two empty collections are treated as identical and
        score 1.0 rather than raising ``ZeroDivisionError``.
    """
    set_a, set_b = set(A), set(B)
    # Set union instead of `A + B`, which only works when both inputs are
    # the same concatenable sequence type.
    union = set_a | set_b
    if not union:
        return 1.0
    return len(set_a & set_b) / len(union)


#perplexity of target words
import torch.nn.functional as F

def compute_masked_word_perplexity(target_word, masked_text, model, tokenizer):
    """Return the inverse probability ``1 / p`` of ``target_word`` at the mask.

    Args:
        target_word: A single vocabulary token, looked up with
            ``tokenizer.convert_tokens_to_ids``.
        masked_text: Text containing exactly one mask token.
        model: Model whose call result exposes ``.logits``; it is moved to
            the global ``DEVICE`` as a side effect.
        tokenizer: Tokenizer used to encode ``masked_text`` and to resolve
            ``target_word``.

    Returns:
        ``1 / p`` as a float, where ``p`` is the softmax probability of
        ``target_word`` at the mask position; ``float("inf")`` when ``p``
        underflows to zero.

    Raises:
        ValueError: If ``target_word`` is not a single known vocabulary token,
            or if ``masked_text`` does not contain exactly one mask token.
    """
    model.to(DEVICE)
    inputs = tokenizer(masked_text, return_tensors="pt").to(DEVICE)

    # Convert target word to ID. An out-of-vocabulary word would otherwise be
    # scored as the unknown token, which yields a misleading number.
    target_word_id = tokenizer.convert_tokens_to_ids(target_word)
    unk_id = getattr(tokenizer, "unk_token_id", None)
    is_unknown = (
        unk_id is not None
        and target_word_id == unk_id
        and target_word != getattr(tokenizer, "unk_token", None)
    )
    if target_word_id is None or is_unknown:
        raise ValueError(
            f"target_word {target_word!r} is not a single token in the tokenizer vocabulary"
        )

    # A scalar position is required here. `torch.where(...)[1]` would give a
    # 1-D index tensor instead, so `logits[0, idx, :]` would keep an extra
    # leading dimension and `probabilities[target_word_id]` would then index
    # that dimension rather than the vocabulary.
    mask_positions = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero()
    if mask_positions.numel() != 1:
        raise ValueError(
            f"masked_text must contain exactly one mask token, found {mask_positions.numel()}"
        )
    mask_token_index = mask_positions.item()

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Extract logits for the masked token position
    mask_token_logits = logits[0, mask_token_index, :]

    # Convert to probabilities
    probabilities = F.softmax(mask_token_logits, dim=-1)

    p_word = probabilities[target_word_id].item()
    if p_word == 0.0:
        # Probability underflowed; report infinity instead of dividing by zero.
        return float("inf")

    return 1 / p_word

In [23]:
model_checkpoint = 'distilbert/distilbert-base-uncased'
# Forward slash on purpose: the previous "movie_model\checkpoint-958" relied on
# "\c" being an invalid (deprecated) escape sequence and only resolved as a
# path on Windows. "/" works as a path separator on every platform.
extended_checkpoint = "movie_model/checkpoint-958"

model, tokenizer = load_model_tokenizer(model_path_or_id=model_checkpoint, device=DEVICE)
# NOTE(review): `tokenizer` is rebound here to the tokenizer saved with the
# extended checkpoint and is then used for BOTH models below — this presumes
# the two checkpoints share a vocabulary; confirm.
extended_model , tokenizer = load_model_tokenizer(extended_checkpoint, DEVICE)

In [25]:
k = 30
text = "What a great [MASK]"
target_word = "movie"

# Top-k candidates for the mask: A from the base model, B from the extended one.
A = list(top_k_prediction(text, model, tokenizer, k))
B = list(top_k_prediction(text, extended_model, tokenizer, k))

# Metrics computed on the two word lists.
# "recall" here is the set of words that both lists share.
shared_words = recall_in_top_k(A,B)
print("recall: ", shared_words)
list_differences = differences_in_top_k(A,B)
print("differences: ", list_differences)
jaccard_score = jaccard(A,B)
print("Jacard Score: ", jaccard_score)

# Metrics computed on the models' predicted probabilities.
base_perplexity = compute_masked_word_perplexity(target_word, text, model, tokenizer)
print(f"Perplexity for the word {target_word}, in {text} base model: {base_perplexity}")
extended_perplexity = compute_masked_word_perplexity(target_word, text, extended_model, tokenizer)
print(f"Perplexity for the word {target_word}, in {text} extend model: {extended_perplexity} ")

recall:  {'deal', 'gift', '?', 'adventure', '!', 'fun', 'mess', 'night', '.', 'story', 'job', ';', 'time', 'day', 'thing', 'idea', 'surprise', 'song', 'chance'}
differences:  ({'film', 'show', 'book', 'game', 'one', 'effort', 'way', 'plan', 'movie', 'performance', ','}, {'mystery', 'beauty', 'wonder', 'treasure', 'coincidence', 'prize', 'disaster', 'boy', 'success', 'fortune', 'tragedy'})
Jacard Score:  0.4634146341463415
Perplexity for the word movie, in What a great [MASK] base model: 3957.4228697477365
Perplexity for the word movie, in What a great [MASK] extend model: 454.4608551738091 
