In [140]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from nltk.corpus import wordnet as wn
import nltk
import string
from nltk.stem import WordNetLemmatizer
from loader import load_instances, load_key
import re
from nltk.corpus import stopwords
from nltk.wsd import lesk
import random
from scipy.spatial.distance import cosine, euclidean

from nltk.corpus import wordnet as wn
import pandas as pd

from tqdm import tqdm
import time


In [8]:
# Load instances and key data in another cell
data_f = 'multilingual-all-words.en.xml'  
key_f = 'wordnet.en.key'

dev_instances, test_instances = load_instances(data_f)
dev_key, test_key = load_key(key_f)

dev_instances = {k:v for (k,v) in dev_instances.items() if k in dev_key}
test_instances = {k:v for (k,v) in test_instances.items() if k in test_key}


lemmatizer = WordNetLemmatizer()

def clean_and_lemmatize(word):
    # Remove punctuation
    word = word.translate(str.maketrans('', '', string.punctuation))
    # Lemmatize
    return lemmatizer.lemmatize(word)

def create_dataframe(instances, key_dict):
    data = []
    for instance_id, instance in instances.items():
        # Decode lemma and context if in byte format
        lemma = instance.lemma.decode('utf-8') if isinstance(instance.lemma, bytes) else instance.lemma
        context = [word.decode('utf-8') if isinstance(word, bytes) else word for word in instance.context]

        # Retrieve the sense key(s) from key_dict, or None if not found
        sense_key = key_dict.get(instance_id, [None])
        
        # Append the processed data
        data.append({
            'Instance ID': instance_id,             
            'Lemma': lemma,                         
            'Original Context': context,
            'Combined Context': ' '.join(context),          
            'Index': instance.index,               
            'Sense Key': sense_key                  
        })
    return pd.DataFrame(data)


dev_df = create_dataframe(dev_instances, dev_key)
test_df = create_dataframe(test_instances, test_key)


stop_words = set(stopwords.words("english"))

def preprocess_context(context):
    processed_context = []
    for word in context:
        # Convert to lowercase
        word = word.lower()

        # Handle "@card@" tokens and numeric values by replacing with "NUM"
        if word == "@card@" or re.fullmatch(r'\d+', word):
            processed_context.append("NUM")
            continue
        
        # Preserve periods within abbreviations and replace with underscores (e.g., "u.n." -> "u_n")
        word = re.sub(r'\b(\w\.)+', lambda match: match.group(0).replace('.', '_'), word)
        
        # Split hyphenated compound words (e.g., "u_n-sponsored" -> ["u_n", "sponsored"])
        parts = re.split(r'-(?=\w)', word)
        
        # Process each part separately
        for part in parts:
            # Remove isolated punctuation from each part
            part = part.strip(string.punctuation)
            
            # Lemmatize, remove stop words, and add to processed context if not empty
            
            # if part and part not in stop_words:
            #     processed_context.append(lemmatizer.lemmatize(part))
                
            if part:
                processed_context.append(lemmatizer.lemmatize(part))
    
    return processed_context

# Example application
dev_df['Modified Context'] = dev_df['Original Context'].apply(preprocess_context)
test_df['Modified Context'] = test_df['Original Context'].apply(preprocess_context)

dev_df['Combined Modified Context'] = dev_df['Modified Context'].apply(lambda x: ' '.join(x))
test_df['Combined Modified Context'] = test_df['Modified Context'].apply(lambda x: ' '.join(x))

In [None]:
import sys

path_to_key = '/Users/aidanlicoppe/Documents/Code/keys'

# Add the full file path of the directory containing api_keys.py
sys.path.append(path_to_key)

# Now import open_ai_key from api_keys.py
from api_keys import google_api_key

In [3]:
client = openai.OpenAI(api_key=open_ai_key)

In [112]:
def get_synset_codes(target_word):
    """
    Retrieve the synset codes and definitions for the target word.
    """
    synsets = wn.synsets(target_word)
    synset_info = [(synset.name(), synset.definition()) for synset in synsets]
    return synset_info

def create_prompt_batches(sentences, lemmas, words_per_prompt=50):
    # Words per prompt specifies the number of target lemmas per prompt
    prompts = []
    current_prompt = "For each sentence below, choose the correct WordNet synset code for the target word in context.\n\n"
    
    total_length = len(sentences)
    lemmas_in_prompt = []

    current_lemma_list = []
    for i, (sentence, lemma) in enumerate(zip(sentences, lemmas)):
        synset_info = get_synset_codes(lemma)
        synset_descriptions = "\n".join([f"{code}: {definition}" for code, definition in synset_info])
        
        current_entry = (
            f"Sentence {i + 1}: '{sentence}'\n"
            f"Target word: '{lemma}'\n"
            f"Synset Options:\n{synset_descriptions}\n\n"
        )
        
        current_prompt += current_entry
        current_lemma_list.append(lemma)
        
        if (i + 1) % words_per_prompt == 0 or i + 1 == total_length:
            current_prompt += "Please respond only with the synset code for each sentence, in order in the form of a comma-separated list, with no space between commas."
            prompts.append(current_prompt)
            
            lemmas_in_prompt.append(current_lemma_list)
            current_lemma_list = []
            current_prompt = "For each sentence below, choose the correct WordNet synset code for the target word in context.\n\n"

    return prompts, lemmas_in_prompt

def estimate_tokens(prompt, tokenizer):
    response = tokenizer.count_tokens(prompt)
    num_tokens = response.total_tokens
    return num_tokens

In [2]:
import google.generativeai as genai
import os
from google.cloud import aiplatform
from vertexai.preview.tokenization import get_tokenizer_for_model

In [32]:
genai.configure(api_key=google_api_key)

model = genai.GenerativeModel("gemini-1.5-flash")

tokenizer = get_tokenizer_for_model("gemini-1.5-flash-002")

In [165]:
def create_prompt_batch(sentences, lemmas):
    prompt = "For each sentence below, choose the correct WordNet synset code for the target word in context.\n\n"
    for i, (sentence, lemma) in enumerate(zip(sentences, lemmas)):
        synset_info = get_synset_codes(lemma)
        # Add each sentence, lemma, and synset options to the prompt
        synset_descriptions = "\n".join([f"{code}: {definition}" for code, definition in synset_info])
        prompt += (
            f"Sentence {i + 1}: '{sentence}'\n"
            f"Target word: '{lemma}'\n"
            f"Synset Options:\n{synset_descriptions}\n\n"
        )
    prompt += "Please respond only with the synset code for each sentence, in order in the form of a comma-separated list."
    
    return prompt

In [141]:
sentences = test_df['Combined Context']
lemmas = test_df['Lemma']

prompt = create_prompt_batch(sentences, lemmas)
num_tokens = estimate_tokens(prompt, tokenizer)

# Given that gemini allows up to 1 million tokens, we don't need to create batches and can include it all as a single prompt
# For this, we will split up the prompts into chunks of lemmas such that the model can focus better


def llm_WSD(sentences, lemmas, words_per_prompt=20, wait_time=5):
    prompts, lemmas_in_prompt = create_prompt_batches(sentences, lemmas, words_per_prompt=words_per_prompt)
    predictions = []
    total_num_replacements = 0
    
    for j, prompt in enumerate(tqdm(prompts, desc="Processing Prompts")):
        response = model.generate_content(prompt)
        text_response = response.text
        cleaned_response = text_response.replace(" ", "").replace("\n", "")
        response_list = cleaned_response.split(",")
        
        if len(response_list) != len(lemmas_in_prompt[j]):
            # In the case where there is an error and no response is produced, we insert the most likely synset
            for i in range(len(lemmas_in_prompt[j])):
               possible_synsets_sets = get_synset_codes(lemmas_in_prompt[j][i])
               possible_synsets = [code for code, definition in possible_synsets_sets]
               if response_list[i] not in possible_synsets:
                   response_list.insert(i, possible_synsets[0])
                   total_num_replacements += 1
        
        predictions += response_list
        
        time.sleep(wait_time)
    
    return predictions, total_num_replacements

In [162]:
df = test_df

sentences = df['Combined Context']
lemmas = df['Lemma']

predictions, num_replacements = llm_WSD(sentences, lemmas, words_per_prompt=50, wait_time=1)

Processing Prompts: 100%|██████████| 10/10 [00:38<00:00,  3.87s/it]


In [163]:
actual_synsets = [wn.lemma_from_key(row['Sense Key'][0]).synset() for _, row in df.iterrows()]

In [157]:
def calculate_accuracy(predicted_synsets, actual_synsets, get_name=False):

    correct_predictions = 0
    correct_indices = []
    incorrect_indices = []
    
    for i, (predicted, actual) in enumerate(zip(predicted_synsets, actual_synsets)):
        if get_name:
            actual = actual.name()
        
        if predicted == actual:
            correct_predictions += 1
            correct_indices.append(i)
        else:
            incorrect_indices.append(i)

    # Calculate accuracy
    accuracy = correct_predictions / len(actual_synsets)
    return accuracy, correct_indices, incorrect_indices

In [164]:
acc, correct_indices, incorrect_indices = calculate_accuracy(predictions, actual_synsets, get_name=True)

print("--" * 20)
print(f"Accuracy for WSD Using the Google Gemnini 1.5 Model: {acc*100:.2f}%")
print("--" * 20)

Accuracy: 70.10%


In [167]:
correct_samples = random.sample(correct_indices, 1) 
incorrect_samples = random.sample(incorrect_indices, 1)

print("\n----------------------------------------- Correct Prediction Debug -----------------------------------------")
for idx in correct_samples:
    prompt_correct = create_prompt_batch([sentences[idx]], [lemmas[idx]])
    prompt_correct = prompt_correct.replace("Please respond only with the synset code for each sentence, in order in the form of a comma-separated list, with no space between commas.", "")
    prompt_correct += 'Please go through each of the available definitions, and justify why the word sense aligns with the definition or not. Finally, once you have gone through each definition, provide the correct synset of those available.'
    response_correct = model.generate_content(prompt_correct)
    
    print(f"\nInstance ID: {df.iloc[idx]['Instance ID']}")
    print(f"Prompt:\n{prompt_correct}")
    print(f"\nResponse:\n{response_correct.text}")
    print(f"\nActual Synset: {actual_synsets[idx]}")

print("\n----------------------------------------- Incorrect Prediction Debug -----------------------------------------")
for idx in incorrect_samples:
    prompt_incorrect = create_prompt_batch([sentences[idx]], [lemmas[idx]])
    prompt_incorrect = prompt_incorrect.replace("Please respond only with the synset code for each sentence, in order in the form of a comma-separated list, with no space between commas.", "")
    prompt_incorrect += 'Please go through each of the available definitions, and justify why the word sense aligns with the definition or not. Finally, once you have gone through each definition, provide the correct synset of those available.'
    
    response_incorrect = model.generate_content(prompt_incorrect)
    
    print(f"\nInstance ID: {df.iloc[idx]['Instance ID']}")
    print(f"Prompt:\n{prompt_incorrect}")
    print(f"\nResponse:\n{response_incorrect.text}")
    print(f"\nActual Synset: {actual_synsets[idx]}")


----------------------------------------- Correct Prediction Debug -----------------------------------------

Instance ID: d001.s020.t005
Prompt:
For each sentence below, choose the correct WordNet synset code for the target word in context.

Sentence 1: 'this be clearly a game where a new economic hegemony be be develop , say Ulate , who also serve as the regional Mexico and central_america climate_change adviser for conservation_international .'
Target word: 'climate_change'
Synset Options:
climate_change.n.01: a change in the world's climate

Please respond only with the synset code for each sentence, in order in the form of a comma-separated list.Please go through each of the available definitions, and justify why the word sense aligns with the definition or not. Finally, once you have gone through each definition, provide the correct synset of those available.

Response:
climate_change.n.01

Justification:  The sentence discusses climate change in the context of a regional adviso