In [1]:
!pip install torch torchvision torchaudio
!pip install torch-geometric
!pip install torch-scatter



In [2]:
import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet as wn
import networkx as nx

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [3]:
def calculate_edge_weight(synset_a, synset_b):
    """Return the Wu-Palmer similarity of two synsets, or 0.0 when it is unavailable.

    ``synset_a.wup_similarity(synset_b)`` is used as the weight. A ``None``
    result or a ``ValueError`` raised by that call both map to ``0.0``.
    """
    try:
        score = synset_a.wup_similarity(synset_b)
    except ValueError:
        return 0.0
    if score is None:
        return 0.0
    return score

def get_neighbors(graph, node_name):
    """List the neighbours of ``node_name`` in ``graph`` (testing helper).

    Prints a message and returns an empty list when the node is absent.
    NOTE: a later cell of this notebook defines ``get_neighbors`` again, which
    rebinds the name once that cell has run.
    """
    if node_name in graph:
        return list(graph.neighbors(node_name))
    print(f"Node '{node_name}' not found in the graph")
    return []

class WordNetGraph:
    """Undirected graph of WordNet synsets weighted by Wu-Palmer similarity.

    Nodes are keyed by synset name and keep the synset object in their
    ``synset`` attribute; edges keep their similarity in ``weight``.
    ``start_synset`` holds the synset the latest build started from, or
    ``None`` when no build has succeeded.
    """

    def __init__(self):
        self.graph = nx.Graph()
        self.start_synset = None

    def add_synset_node(self, synset, excluded_nodes_set=None):
        """Add ``synset`` as a node unless it is excluded or already present.

        Returns:
            bool: True only when a new node was inserted.
        """
        name = synset.name()
        if excluded_nodes_set and name in excluded_nodes_set:
            return False
        if name in self.graph:
            return False
        self.graph.add_node(name, synset=synset)
        return True

    def add_edge_with_wup(self, synset_a, synset_b, excluded_nodes_set=None):
        """Connect two synsets with their Wu-Palmer similarity as the edge weight.

        Both endpoints are first offered to ``add_synset_node``. The edge is
        only created when both names end up in the graph and the similarity is
        strictly positive.

        Returns:
            bool: True when an edge was added (or re-set), False otherwise.
        """
        a_name = synset_a.name()
        b_name = synset_b.name()
        # Nodes are registered before the weight is known, so an allowed synset
        # stays in the graph as an isolated node when no edge is created below.
        self.add_synset_node(synset_a, excluded_nodes_set=excluded_nodes_set)
        self.add_synset_node(synset_b, excluded_nodes_set=excluded_nodes_set)

        if a_name not in self.graph or b_name not in self.graph:
            return False

        weight = calculate_edge_weight(synset_a, synset_b)
        if weight > 0.0:
            self.graph.add_edge(a_name, b_name, weight=weight)
            return True
        return False

    def add_polysemy_edges(self, synset, excluded_nodes_set=None):
        """Link ``synset`` to the other senses of its first lemma.

        The first lemma name of ``synset`` is looked up with ``wn.synsets`` and
        every returned synset other than ``synset`` itself (and not excluded)
        is offered to ``add_edge_with_wup``.

        Returns:
            set: the synsets for which an edge was added in this call.
        """
        lemma_name = synset.lemmas()[0].name()
        connected = set()

        for other_sense in wn.synsets(lemma_name):
            if other_sense.name() == synset.name():
                continue
            if excluded_nodes_set and other_sense.name() in excluded_nodes_set:
                continue
            if self.add_edge_with_wup(synset, other_sense, excluded_nodes_set=excluded_nodes_set):
                connected.add(other_sense)
        return connected

    def add_hierarchical_edges(self, synset, excluded_nodes_set=None):
        """Link ``synset`` to its direct hypernyms and hyponyms.

        Returns:
            set: the synsets for which an edge was added in this call.
        """
        connected = set()

        for neighbor in synset.hypernyms() + synset.hyponyms():
            if excluded_nodes_set and neighbor.name() in excluded_nodes_set:
                continue
            if self.add_edge_with_wup(synset, neighbor, excluded_nodes_set=excluded_nodes_set):
                connected.add(neighbor)
        return connected

    def build_graph_from_synset(self, input_synset_str, depth, excluded_nodes_set=None):
        """Rebuild the graph breadth-first from ``input_synset_str``.

        Args:
            input_synset_str (str): synset name, e.g. ``'bank.n.01'``.
            depth (int): number of expansion rounds; synsets reached in the
                last round are added but not expanded further.
            excluded_nodes_set (set | None): synset names that must never
                enter the graph.

        The graph is emptied before anything else, so after a failed lookup or
        an excluded start synset the instance holds an empty graph rather than
        the result of an earlier build.
        """
        # WordNetError is defined in the reader module; import it from there
        # instead of resolving it as an attribute of the corpus-reader object.
        from nltk.corpus.reader.wordnet import WordNetError

        self.graph.clear()
        self.start_synset = None
        try:
            self.start_synset = wn.synset(input_synset_str)
        except (WordNetError, ValueError):
            # ValueError: presumably raised for names that are not of the form
            # lemma.pos.nn in some NLTK versions -- treated as "not found" too.
            print(f"Error: Synset '{input_synset_str}' not found in WordNet.")
            return

        if excluded_nodes_set and self.start_synset.name() in excluded_nodes_set:
            print(f"Warning: Start synset '{input_synset_str}' is in the excluded set. Graph will be empty.")
            return

        self.add_synset_node(self.start_synset, excluded_nodes_set=excluded_nodes_set)
        if not self.graph.nodes:
            return

        to_process = {self.start_synset}
        processed = {self.start_synset}
        for _ in range(depth):
            next_to_process = set()
            for synset in to_process:
                # sense edges first, then hypernym/hyponym edges
                poly_neighbors = self.add_polysemy_edges(synset, excluded_nodes_set=excluded_nodes_set)
                hiero_neighbors = self.add_hierarchical_edges(synset, excluded_nodes_set=excluded_nodes_set)

                for neighbor in poly_neighbors | hiero_neighbors:
                    if neighbor in processed:
                        continue
                    if excluded_nodes_set and neighbor.name() in excluded_nodes_set:
                        continue
                    next_to_process.add(neighbor)

            processed.update(next_to_process)
            to_process = next_to_process
            if not to_process:
                break

    def get_all_definitions(self):
        """Map every node name to ``"<definition> (<examples joined by '; '>)"``."""
        definitions = {}
        for node_name, data in self.graph.nodes(data=True):
            synset = data['synset']
            definitions[node_name] = f"{synset.definition()} ({'; '.join(synset.examples())})"
        return definitions

    def get_all_edges(self):
        """Return ``(edges, weights)``: parallel lists of ``(u, v)`` pairs and their ``weight``."""
        edges = []
        edge_attr = []
        for u, v, data in self.graph.edges(data=True):
            edges.append((u, v))
            edge_attr.append(data['weight'])
        return edges, edge_attr

In [4]:
def get_neighbors(graph, node_name):
    """Return the neighbours of ``node_name`` in ``graph`` as a list.

    When the node is not in the graph a warning is printed and an empty list
    is returned.
    """
    if node_name in graph:
        return list(graph.neighbors(node_name))
    print(f"Warning: Node '{node_name}' not found in the graph.")
    return []

In [5]:
!pip install sentence-transformers



In [6]:
import torch
from sentence_transformers import SentenceTransformer
import torch.nn.functional as F

# Module-level sentence encoder shared by get_node_features and get_synset_embedding
# below (both call sentence_model.encode). It is created once when this cell runs;
# the cell output shows the 'all-MiniLM-L6-v2' files being fetched at that point.
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
def get_node_features(graph_instance):
    """
    Encode every node's definition text into an embedding.

    The texts come from ``graph_instance.get_all_definitions()`` and are encoded
    in ascending order of synset name, so row ``i`` of the result belongs to the
    i-th name in sorted order.

    Args:
        graph_instance (WordNetGraph): An instance of the WordNetGraph class.

    Returns:
        The value returned by ``sentence_model.encode(..., convert_to_tensor=True)``
        for the ordered texts (one embedding per synset).
        dict: A mapping from synset name to its row index in that result.
    """
    text_by_synset = graph_instance.get_all_definitions()

    synset_to_idx = {}
    ordered_texts = []
    for position, name in enumerate(sorted(text_by_synset)):
        synset_to_idx[name] = position
        ordered_texts.append(text_by_synset[name])

    embeddings = sentence_model.encode(ordered_texts, convert_to_tensor=True, show_progress_bar=False)
    return embeddings, synset_to_idx

def get_synset_embedding(synset_name):
    """Encode the definition text of a single synset.

    The text has the same ``"<definition> (<examples joined by '; '>)"`` form
    that ``WordNetGraph.get_all_definitions`` produces.

    Args:
        synset_name (str): WordNet synset name, e.g. ``'bank.n.01'``.

    Returns:
        The result of ``sentence_model.encode([text], convert_to_tensor=True)``,
        or ``None`` when the synset cannot be looked up.
    """
    # WordNetError is defined in the reader module; import it from there instead
    # of resolving it as an attribute of the corpus-reader object.
    from nltk.corpus.reader.wordnet import WordNetError

    try:
        synset = wn.synset(synset_name)
    except (WordNetError, KeyError, ValueError):
        # ValueError: presumably raised for names not shaped like lemma.pos.nn
        # in some NLTK versions -- treated as "not found" as well.
        return None

    definition = f"{synset.definition()} ({'; '.join(synset.examples())})"
    return sentence_model.encode([definition], convert_to_tensor=True, show_progress_bar=False)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [7]:
import torch
import torch.nn as nn
import torch_geometric.nn as gnn

class GNNModel(nn.Module):
    """One GCN layer followed by a linear edge scorer.

    For every column of ``edge_index`` the scorer sees the hidden vectors of
    both endpoints concatenated with that edge's attributes and emits a single
    logit.
    """

    def __init__(self, num_node_features, hidden_channels, edge_dim):
        super().__init__()
        # node encoder
        self.conv1 = gnn.GCNConv(num_node_features, hidden_channels)
        # scorer input = two endpoint embeddings + the edge attribute vector
        scorer_inputs = 2 * hidden_channels + edge_dim
        self.linear = nn.Linear(scorer_inputs, 1)

    def forward(self, x, edge_index, edge_attr):
        """Return one logit per edge; the result has ``edge_index.shape[1]`` rows and one column."""
        hidden = torch.relu(self.conv1(x, edge_index))

        source, target = edge_index
        pair_features = torch.cat([hidden[source], hidden[target], edge_attr], dim=1)

        # raw scores (logits); no sigmoid is applied here
        return self.linear(pair_features)

In [8]:
import random
import networkx as nx

def generate_negative_samples(graph, num_negative_samples, seed=None):
    """
    Generates negative samples (non-existent edges) for a given graph.

    Node pairs are drawn uniformly at random; a pair is kept when its two nodes
    differ and the graph has no edge between them in either direction. Each
    kept pair is stored with its endpoints sorted, so a pair is never returned
    twice in opposite orientations.

    Args:
        graph (nx.Graph): The original NetworkX graph.
        num_negative_samples (int): The number of negative samples to generate.
        seed (int | None): When given, sampling uses a private
            ``random.Random(seed)`` so the result is repeatable. When ``None``
            the module-level ``random`` generator is used, as before.

    Returns:
        list: A sorted list of tuples, each a non-existent edge (a, b) with
        a < b. May hold fewer than ``num_negative_samples`` entries when the
        attempt budget (10 draws per requested sample) runs out first.
    """
    nodes = list(graph.nodes())
    if len(nodes) < 2:
        return []

    rng = random if seed is None else random.Random(seed)
    negative_samples = set()
    max_attempts = num_negative_samples * 10  # bounds the loop on dense graphs

    for _ in range(max_attempts):
        if len(negative_samples) >= num_negative_samples:
            break
        a = rng.choice(nodes)
        b = rng.choice(nodes)
        if a == b or graph.has_edge(a, b) or graph.has_edge(b, a):
            continue
        negative_samples.add(tuple(sorted((a, b))))

    # Sorted so the order does not depend on set iteration order.
    return sorted(negative_samples)

In [9]:
#**Training Data Creation**
import torch_geometric.data

# Step 1: empty graph container for the training region
wn_graph_train = WordNetGraph()

# Step 2: grow the training graph from 'entity.n.01'
train_synset = 'entity.n.01'
wn_graph_train.build_graph_from_synset(train_synset, depth=4)
train_nodes_set = set(wn_graph_train.graph.nodes())  # node names seen in training; later cells pass this as the exclusion set

# Step 3: node embeddings plus the synset-name -> row-index mapping
node_embeddings_train, synset_to_idx_train = get_node_features(wn_graph_train)

# Step 4: existing (positive) edges with their Wu-Palmer weights
graph_edges_train_positive, graph_edge_weights_train_positive = wn_graph_train.get_all_edges()

# Step 5: positive edges as index pairs; input attribute = Wu-Palmer weight, target label = 1
edge_index_list_train_positive = [
    [synset_to_idx_train[a_name], synset_to_idx_train[b_name]]
    for a_name, b_name in graph_edges_train_positive
]
edge_index_train_positive = torch.tensor(edge_index_list_train_positive, dtype=torch.long).t().contiguous()
initial_edge_attr_train_input = torch.tensor(list(graph_edge_weights_train_positive), dtype=torch.float).unsqueeze(1)
target_labels_train_positive = torch.ones((len(edge_index_list_train_positive), 1), dtype=torch.float)

# Step 6: as many sampled non-edges as there are positive edges (fewer if sampling runs out of attempts)
num_negative_samples_train = len(graph_edges_train_positive)
negative_samples_train = generate_negative_samples(wn_graph_train.graph, num_negative_samples_train)

# Step 7: negative edges as index pairs; input attribute = 0.0, target label = 0
# NOTE(review): positives only exist with weight > 0 (see add_edge_with_wup) while every
# negative gets attribute 0.0, so edge_attr alone separates the two classes -- confirm
# this is intended.
edge_index_list_train_negative = [
    [synset_to_idx_train[a_name], synset_to_idx_train[b_name]]
    for a_name, b_name in negative_samples_train
    if a_name in synset_to_idx_train and b_name in synset_to_idx_train
]
edge_index_train_negative = torch.tensor(edge_index_list_train_negative, dtype=torch.long).t().contiguous()
initial_edge_attr_train_negative = torch.zeros((len(edge_index_list_train_negative), 1), dtype=torch.float)
target_labels_train_negative = torch.zeros((len(edge_index_list_train_negative), 1), dtype=torch.float)

# Step 8: positives first, then negatives, in every tensor
train_edge_index = torch.cat([edge_index_train_positive, edge_index_train_negative], dim=1)
train_initial_edge_attr = torch.cat([initial_edge_attr_train_input, initial_edge_attr_train_negative], dim=0)
train_target_labels = torch.cat([target_labels_train_positive, target_labels_train_negative], dim=0)

train_data = torch_geometric.data.Data(
    x=node_embeddings_train,
    edge_index=train_edge_index,
    edge_attr=train_initial_edge_attr,
    y=train_target_labels,
)


In [10]:
#**Evaluation Data Creation** (run optionally, for demo purposes)
# Step 1: empty graph container for the evaluation region
wn_graph_eval = WordNetGraph()

# Step 2: grow the evaluation graph from 'bank.n.01', keeping every training node out
eval_synset = 'bank.n.01'
wn_graph_eval.build_graph_from_synset(eval_synset, depth=3, excluded_nodes_set=train_nodes_set)

# Step 3: node embeddings plus the synset-name -> row-index mapping
node_embeddings_eval, synset_to_idx_eval = get_node_features(wn_graph_eval)

# Step 4: existing (positive) edges with their Wu-Palmer weights
graph_edges_eval_positive, graph_edge_weights_eval_positive = wn_graph_eval.get_all_edges()

# Step 5: positive edges whose endpoints are both indexed; attribute = Wu-Palmer weight, label = 1
edge_index_list_eval_positive = [
    [synset_to_idx_eval[a_name], synset_to_idx_eval[b_name]]
    for a_name, b_name in graph_edges_eval_positive
    if a_name in synset_to_idx_eval and b_name in synset_to_idx_eval
]
initial_edge_attr_eval_input = [
    weight
    for (a_name, b_name), weight in zip(graph_edges_eval_positive, graph_edge_weights_eval_positive)
    if a_name in synset_to_idx_eval and b_name in synset_to_idx_eval
]

# An empty pair list needs an explicit (2, 0) index tensor; the attribute and label
# tensors come out as (0, 1) from the same expressions either way.
if edge_index_list_eval_positive:
    edge_index_eval_positive = torch.tensor(edge_index_list_eval_positive, dtype=torch.long).t().contiguous()
else:
    edge_index_eval_positive = torch.empty((2, 0), dtype=torch.long)
initial_edge_attr_eval_input = torch.tensor(initial_edge_attr_eval_input, dtype=torch.float).unsqueeze(1)
target_labels_eval_positive = torch.ones((len(edge_index_list_eval_positive), 1), dtype=torch.float)

# Step 6: as many sampled non-edges as there are positive edges (fewer if sampling runs out of attempts)
num_negative_samples_eval = len(graph_edges_eval_positive)
negative_samples_eval = generate_negative_samples(wn_graph_eval.graph, num_negative_samples_eval)

# Step 7: negative edges as index pairs; attribute = 0.0, label = 0
edge_index_list_eval_negative = [
    [synset_to_idx_eval[a_name], synset_to_idx_eval[b_name]]
    for a_name, b_name in negative_samples_eval
    if a_name in synset_to_idx_eval and b_name in synset_to_idx_eval
]
if edge_index_list_eval_negative:
    edge_index_eval_negative = torch.tensor(edge_index_list_eval_negative, dtype=torch.long).t().contiguous()
else:
    edge_index_eval_negative = torch.empty((2, 0), dtype=torch.long)
initial_edge_attr_eval_negative = torch.zeros((len(edge_index_list_eval_negative), 1), dtype=torch.float)
target_labels_eval_negative = torch.zeros((len(edge_index_list_eval_negative), 1), dtype=torch.float)

# Step 8: positives first, then negatives, in every tensor
eval_edge_index = torch.cat([edge_index_eval_positive, edge_index_eval_negative], dim=1)
eval_initial_edge_attr = torch.cat([initial_edge_attr_eval_input, initial_edge_attr_eval_negative], dim=0)
eval_target_labels = torch.cat([target_labels_eval_positive, target_labels_eval_negative], dim=0)

eval_data = torch_geometric.data.Data(
    x=node_embeddings_eval,
    edge_index=eval_edge_index,
    edge_attr=eval_initial_edge_attr,
    y=eval_target_labels,
)


In [11]:
import torch.optim as optim

# Hyperparameters
hidden_channels = 64
edge_dim = 1            # one scalar attribute per edge
learning_rate = 0.05
num_node_features = train_data.num_node_features

# Model, loss (operates on raw logits) and optimizer
model = GNNModel(
    num_node_features=num_node_features,
    hidden_channels=hidden_channels,
    edge_dim=edge_dim,
)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [12]:
num_epochs = 100

# Full-batch training: every epoch scores all training edges once.
model.train()
for epoch in range(num_epochs):
    optimizer.zero_grad()

    predicted_edge_logits = model(
        train_data.x.clone(),
        train_data.edge_index,
        train_data.edge_attr,
    )

    # binary cross-entropy between the edge logits and the 0/1 labels in train_data.y
    loss = criterion(predicted_edge_logits, train_data.y)
    loss.backward()
    optimizer.step()

    # report every 10th epoch
    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')

Epoch 10/100, Loss: 0.28884828090667725
Epoch 20/100, Loss: 0.19055433571338654
Epoch 30/100, Loss: 0.14692918956279755
Epoch 40/100, Loss: 0.11531563848257065
Epoch 50/100, Loss: 0.09056276828050613
Epoch 60/100, Loss: 0.07237529009580612
Epoch 70/100, Loss: 0.05936029180884361
Epoch 80/100, Loss: 0.049844641238451004
Epoch 90/100, Loss: 0.042620811611413956
Epoch 100/100, Loss: 0.036974336951971054


In [14]:
from sklearn.metrics import accuracy_score
import numpy as np

# Score the evaluation edges without tracking gradients.
model.eval()
with torch.no_grad():
    predicted_eval_edge_logits = model(
        eval_data.x.clone(),
        eval_data.edge_index,
        eval_data.edge_attr,
    )

# eval_loss is computed here but not printed by this cell.
eval_loss = criterion(predicted_eval_edge_logits, eval_data.y)

# logits -> probabilities -> hard 0/1 labels at the 0.5 threshold
predicted_probs = predicted_eval_edge_logits.sigmoid()
predicted_labels = predicted_probs.gt(0.5).float()

# sklearn works on numpy arrays
true_labels = eval_data.y.cpu().numpy()
predicted_labels_np = predicted_labels.cpu().numpy()

accuracy = accuracy_score(true_labels, predicted_labels_np)

print(f"Evaluation Accuracy: {accuracy:.4f}")

Evaluation Accuracy: 0.5007


In [15]:
def apply_gnn_shift(eval_data):
    """Run the global ``model`` on ``eval_data`` and return one probability per edge.

    The model is switched to eval mode, called on ``eval_data.x`` (cloned),
    ``eval_data.edge_index`` and ``eval_data.edge_attr`` without gradient
    tracking, and the sigmoid of each logit is returned as a Python float, in
    the same order as the edges of ``eval_data``.
    """
    model.eval()
    with torch.no_grad():
        edge_logits = model(eval_data.x.clone(), eval_data.edge_index, eval_data.edge_attr)

    return [probability.item() for probability in torch.sigmoid(edge_logits)]

In [16]:
def find_word_cost(wn_graph, start_node_name, target_node_name, predicted_edge_probabilities_list, graph_edges):
    """
    Average edge cost along the cheapest path between two nodes.

    A scratch graph is built with the nodes of ``wn_graph.graph`` and one edge
    per entry of ``graph_edges``; edge ``i`` costs
    ``1 - predicted_edge_probabilities_list[i]``. The path minimising the summed
    cost is found and the mean of its edge costs is returned.

    Returns ``0.0`` when start and target are the same node, and
    ``float('inf')`` when either node is missing, no path exists, or the path
    search raises.
    """
    cost_graph = nx.Graph()  # scratch copy carrying the GNN-derived costs
    cost_graph.add_nodes_from(wn_graph.graph.nodes())

    for position, (a_name, b_name) in enumerate(graph_edges):
        edge_cost = 1.0 - predicted_edge_probabilities_list[position]
        cost_graph.add_edge(a_name, b_name, weight=edge_cost)

    known_nodes = cost_graph.nodes
    if start_node_name not in known_nodes or target_node_name not in known_nodes:
        return float('inf')

    try:
        path_nodes = nx.shortest_path(cost_graph, source=start_node_name, target=target_node_name, weight='weight')

        if len(path_nodes) <= 1:
            return 0.0

        path_costs = [
            cost_graph[a][b]['weight']
            for a, b in zip(path_nodes, path_nodes[1:])
            if cost_graph.has_edge(a, b)
        ]
        if not path_costs:
            return float('inf')
        return sum(path_costs) / len(path_costs)

    except nx.NetworkXNoPath:
        return float('inf')
    except Exception:
        # best effort: any other failure counts as "unreachable"
        return float('inf')

In [17]:
def get_most_frequent_sense(word):
    """Return the first synset WordNet lists for ``word``, or ``None``.

    The function name treats the first entry of ``wn.synsets(word)`` as the most
    frequent sense. When WordNet lists no synset for ``word`` (or ``word`` is
    not a string the lookup can handle) a message is printed and ``None`` is
    returned explicitly.
    """
    try:
        synsets = wn.synsets(word)
    except (AttributeError, TypeError):
        # non-string input; handled like an unknown word
        synsets = []

    if not synsets:
        # This is a WordNet lookup, not a graph lookup.
        print('Sense ', word, ' not found in wordnet')
        return None
    return synsets[0]

def check_graph_for_sense(sense):
    """Return True when ``sense.name()`` is a node of the global ``wn_graph_eval`` graph."""
    return sense.name() in wn_graph_eval.graph.nodes()

def get_all_senses(word):
    """Return the names of all WordNet synsets of ``word``.

    Always returns a list so callers can iterate over the result: it is empty
    when WordNet lists no synset, and also when ``word`` is not a string the
    lookup can handle (a message is printed in that case).
    """
    try:
        synsets = wn.synsets(word)
    except (AttributeError, TypeError):
        print('Senses of ', word, ' not found in wordnet')
        return []
    return [synset.name() for synset in synsets]

In [18]:
import csv
import io
import re
import pandas as pd

# Read the test CSV. 'utf-8-sig' strips a leading byte-order mark if the file has one;
# without it the first column name comes out as '\ufeffsn' instead of 'sn'.
with open('./sample_data/test_data.csv', 'r', encoding='utf-8-sig') as file:
    data = file.read()

# Parse with csv.reader so every cell stays a string (no dtype inference).
data_io = io.StringIO(data)
csv_reader = csv.reader(data_io)
header = next(csv_reader)
rows = list(csv_reader)
df_original = pd.DataFrame(rows, columns=header)
sentences = df_original['sentence/context'].tolist()
words_to_disambiguate = df_original['polysemy_word'].tolist()

# Tokenise: runs of non-word characters and underscores become one space, then lower-case and split.
tokenized_sentences = [
    re.sub(r'[\W_]+', ' ', sentence).lower().split()
    for sentence in sentences
]

# Candidate senses for each word to disambiguate
wsd_senses = [get_all_senses(word) for word in words_to_disambiguate]

print("Original DataFrame head:")
display(df_original.head())
print("\nSentences:", sentences[:5]) # display first 5 for brevity
print("Tokenized sentences:", tokenized_sentences[:5]) # display first 5
print("Words to disambiguate:", words_to_disambiguate[:5]) # display first 5
print("WSD senses (first entry):", wsd_senses[0]) # display first entry

Original DataFrame head:


Unnamed: 0,﻿sn,sentence/context,polysemy_word
0,345,His petition charged mental cruelty .,charged
1,346,His petition charged mental cruelty .,mental
2,347,His petition charged mental cruelty .,cruelty
3,1405,One validated acts of school districts .,One
4,1406,One validated acts of school districts .,validated



Sentences: ['His petition charged mental cruelty .', 'His petition charged mental cruelty .', 'His petition charged mental cruelty .', 'One validated acts of school districts .', 'One validated acts of school districts .']
Tokenized sentences: [['his', 'petition', 'charged', 'mental', 'cruelty'], ['his', 'petition', 'charged', 'mental', 'cruelty'], ['his', 'petition', 'charged', 'mental', 'cruelty'], ['one', 'validated', 'acts', 'of', 'school', 'districts'], ['one', 'validated', 'acts', 'of', 'school', 'districts']]
Words to disambiguate: ['charged', 'mental', 'cruelty', 'One', 'validated']
WSD senses (first entry): ['charge.v.01', 'charge.v.02', 'charge.v.03', 'tear.v.03', 'appoint.v.02', 'charge.v.06', 'charge.v.07', 'charge.v.08', 'charge.v.09', 'commit.v.03', 'consign.v.02', 'charge.v.12', 'charge.v.13', 'agitate.v.02', 'charge.v.15', 'load.v.02', 'charge.v.17', 'charge.v.18', 'charge.v.19', 'charge.v.20', 'blame.v.03', 'charge.v.22', 'charge.v.23', 'charge.v.24', 'charge.v.25', '

In [20]:
import torch_geometric.data

def build_sense_graph(sense, depth=3):
    """
    Build the graph around one candidate sense and package it for the GNN.

    sense: str of the synset sense (bank.n.01)
    depth: int of the depth of the graph to build

    Nodes in the global ``train_nodes_set`` are excluded from the graph.

    Returns a 3-item list ``[sense_graph, data, graph_edges]``:
    the WordNetGraph, a ``torch_geometric.data.Data`` holding node embeddings,
    edge indices and Wu-Palmer edge attributes, and the list of ``(u, v)`` name
    pairs from ``get_all_edges``. When the graph has no nodes the last two
    items are ``None``.
    """
    sense_graph = WordNetGraph()
    sense_graph.build_graph_from_synset(sense, depth, excluded_nodes_set=train_nodes_set)
    if not sense_graph.graph.nodes():
        return [sense_graph, None, None]

    node_embeddings, synset_to_idx = get_node_features(sense_graph)
    graph_edges, graph_edge_weights = sense_graph.get_all_edges()

    index_pairs = []
    weights = []
    for (a_name, b_name), weight in zip(graph_edges, graph_edge_weights):
        # both endpoints must have an embedding row
        if a_name not in synset_to_idx or b_name not in synset_to_idx:
            continue
        index_pairs.append([synset_to_idx[a_name], synset_to_idx[b_name]])
        weights.append(weight)

    # Empty lists need explicitly shaped empty tensors.
    edge_index = (
        torch.tensor(index_pairs, dtype=torch.long).t().contiguous()
        if index_pairs
        else torch.empty((2, 0), dtype=torch.long)
    )
    edge_attr = (
        torch.tensor(weights, dtype=torch.float).unsqueeze(1)
        if weights
        else torch.empty((0, 1), dtype=torch.float)
    )

    sense_data = torch_geometric.data.Data(x=node_embeddings, edge_index=edge_index, edge_attr=edge_attr)
    return [sense_graph, sense_data, graph_edges]

**Important:** set `start_row_index` and `end_row_index` in the next cell to choose which rows of the test data are processed (the end index is exclusive).

In [21]:
# Rows of the test data to process: [start_row_index, end_row_index)
start_row_index = 0
end_row_index = 5

# Independent, re-indexed copy of the selected rows
df_original_sliced = df_original.iloc[start_row_index:end_row_index].reset_index(drop=True)
sentences = df_original_sliced['sentence/context'].tolist()
words_to_disambiguate = df_original_sliced['polysemy_word'].tolist()

# Same tokenisation as for the full data: strip punctuation/underscores, lower-case, split
tokenized_sentences = [
    re.sub(r'[\W_]+', ' ', sentence).lower().split()
    for sentence in sentences
]

# Candidate senses for each selected word
wsd_senses = [get_all_senses(word) for word in words_to_disambiguate]

print(f"Processing rows {start_row_index} to {end_row_index-1} from original data.")
print("Sliced sentences (first 5):", sentences[:5])
print("Sliced words to disambiguate (first 5):", words_to_disambiguate[:5])
print("Sliced WSD senses (first entry):", wsd_senses[0])


Processing rows 0 to 4 from original data.
Sliced sentences (first 5): ['His petition charged mental cruelty .', 'His petition charged mental cruelty .', 'His petition charged mental cruelty .', 'One validated acts of school districts .', 'One validated acts of school districts .']
Sliced words to disambiguate (first 5): ['charged', 'mental', 'cruelty', 'One', 'validated']
Sliced WSD senses (first entry): ['charge.v.01', 'charge.v.02', 'charge.v.03', 'tear.v.03', 'appoint.v.02', 'charge.v.06', 'charge.v.07', 'charge.v.08', 'charge.v.09', 'commit.v.03', 'consign.v.02', 'charge.v.12', 'charge.v.13', 'agitate.v.02', 'charge.v.15', 'load.v.02', 'charge.v.17', 'charge.v.18', 'charge.v.19', 'charge.v.20', 'blame.v.03', 'charge.v.22', 'charge.v.23', 'charge.v.24', 'charge.v.25', 'charged.a.01', 'charged.s.02', 'aerated.s.02', 'charged.s.04']


In [22]:
# warning: this usually takes a long time, as it has to build all the graphs
from tqdm import tqdm

# Build each distinct candidate sense's graph once; the result is looked up by sense name later.
pre_built_graphs = {}
all_unique_senses = {sense for sense_list in wsd_senses for sense in sense_list}

for sense in tqdm(all_unique_senses, desc="Building sense graphs"):
    pre_built_graphs[sense] = build_sense_graph(sense, depth=3)

print("Finished pre-building graphs.")

Building sense graphs:   4%|▍         | 2/51 [00:08<03:06,  3.81s/it]



Building sense graphs:  16%|█▌        | 8/51 [00:22<02:09,  3.02s/it]



Building sense graphs:  27%|██▋       | 14/51 [01:14<03:50,  6.23s/it]



Building sense graphs:  35%|███▌      | 18/51 [01:43<03:49,  6.94s/it]



Building sense graphs:  57%|█████▋    | 29/51 [02:30<01:38,  4.46s/it]



Building sense graphs:  86%|████████▋ | 44/51 [03:14<00:05,  1.36it/s]



Building sense graphs:  92%|█████████▏| 47/51 [03:18<00:05,  1.34s/it]



Building sense graphs: 100%|██████████| 51/51 [03:52<00:00,  4.56s/it]

Finished pre-building graphs.





In [23]:
import pandas as pd # Import pandas here to ensure it's available

# WordNetError is defined in the reader module; import it from there instead of
# resolving it as an attribute of the corpus-reader object.
from nltk.corpus.reader.wordnet import WordNetError

# One entry per sentence: candidate senses with a finite cost, cheapest first,
# joined by ", " (empty string when no sense has a finite cost).
gnn_calculated_senses_output = []

for i in range(len(sentences)):
    # Tokens were lower-cased during tokenisation, so the target word must be
    # lower-cased too or a capitalised target (e.g. 'One') is never skipped.
    target_token = words_to_disambiguate[i].lower()
    best_senses_for_sentence = []

    for sense in wsd_senses[i] or []:
        sense_graph, sense_graph_tensors, graph_edges = pre_built_graphs[sense]
        if sense_graph_tensors is None or not graph_edges:
            # no usable graph for this sense
            total_cost_for_sense = float('inf')
        else:
            probabilities = apply_gnn_shift(sense_graph_tensors)
            min_costs_per_token = []

            for token in tokenized_sentences[i]:
                if token == target_token:
                    continue

                gnn_adjusted_costs = []    # path costs found inside the sense graph
                wup_unadjusted_costs = []  # fallback: 1 - Wu-Palmer similarity

                for token_sense_name in get_all_senses(token) or []:
                    current_token_cost = find_word_cost(sense_graph, sense, token_sense_name, probabilities, graph_edges)
                    if current_token_cost != float('inf'):
                        gnn_adjusted_costs.append(current_token_cost)
                        continue
                    try:
                        wup_sim = calculate_edge_weight(wn.synset(sense), wn.synset(token_sense_name))
                    except (WordNetError, ValueError):
                        continue
                    if wup_sim > 0:
                        wup_unadjusted_costs.append(1.0 - wup_sim)

                # GNN-adjusted costs win whenever at least one exists for this token.
                possible_token_costs = gnn_adjusted_costs if gnn_adjusted_costs else wup_unadjusted_costs
                if possible_token_costs:
                    min_costs_per_token.append(min(possible_token_costs))

            if min_costs_per_token:
                total_cost_for_sense = sum(min_costs_per_token) / len(min_costs_per_token)
            else:
                total_cost_for_sense = float('inf')  # no context token produced a cost

        best_senses_for_sentence.append((sense, total_cost_for_sense))

    print(f"Sentence {i+1} complete")

    finite_cost_senses = [(s, c) for s, c in best_senses_for_sentence if c != float('inf')]
    ranked_senses = sorted(finite_cost_senses, key=lambda item: item[1])  # lowest cost first
    gnn_calculated_senses_output.append(", ".join(s for s, _ in ranked_senses))

print("\n--- GNN-adjusted WSD processing complete ---")

Sentence 1 complete
Sentence 2 complete
Sentence 3 complete
Sentence 4 complete
Sentence 5 complete

--- GNN-adjusted WSD processing complete ---


In [24]:
import pandas as pd
# Take the exception class from the module that defines it instead of reaching
# for it through the lazily loaded corpus-reader object.
from nltk.corpus.reader.wordnet import WordNetError

# --- Wu-Palmer-only baseline ranking ---
# For every sentence, rank the candidate senses of the target word by their
# average Wu-Palmer similarity to the other tokens of the sentence, then store
# the ranking (best first, comma separated) per sentence and export it to
# 'test_data_wup_results.csv'. Sense names contained in train_nodes_set are
# excluded both as candidates and as context senses.


def _resolve_synset(synset_name):
    """Return the WordNet synset for `synset_name`, or None if the lookup fails."""
    try:
        return wn.synset(synset_name)
    except (WordNetError, ValueError):
        return None


wup_calculated_senses_output = []

for i in range(len(sentences)):
    filtered_candidate_senses = [s for s in wsd_senses[i] if s not in train_nodes_set]
    if not filtered_candidate_senses:
        # Nothing to rank for this sentence; keep the output aligned with the input rows.
        wup_calculated_senses_output.append("")
        continue

    # The context synsets depend only on the sentence, so resolve them once
    # here instead of once per candidate sense.
    context_synsets_per_token = []
    for token in tokenized_sentences[i]:
        if token == words_to_disambiguate[i]:
            continue
        resolved = (
            _resolve_synset(token_sense_name)
            for token_sense_name in get_all_senses(token)
            if token_sense_name not in train_nodes_set
        )
        context_synsets_per_token.append([syn for syn in resolved if syn is not None])

    best_senses_for_sentence = []
    for sense in filtered_candidate_senses:
        min_costs_per_token = []
        main_sense_synset = _resolve_synset(sense)
        # An unresolvable candidate contributes no similarities and ends up with score 0.
        if main_sense_synset is not None:
            for token_synsets in context_synsets_per_token:
                possible_token_costs = []
                for token_sense_synset in token_synsets:
                    try:
                        wup_sim = calculate_edge_weight(main_sense_synset, token_sense_synset)
                    except (WordNetError, ValueError):
                        continue
                    if wup_sim > 0:
                        possible_token_costs.append(wup_sim)
                # NOTE(review): these values are similarities (higher is better) and the
                # final ranking is descending, yet the per-token aggregate is the *minimum*
                # similarity. Confirm whether max() was intended; behaviour is left unchanged.
                if possible_token_costs:
                    min_costs_per_token.append(min(possible_token_costs))

        if min_costs_per_token:
            total_cost_for_sense = sum(min_costs_per_token) / len(min_costs_per_token)
        else:
            total_cost_for_sense = 0
        best_senses_for_sentence.append((sense, total_cost_for_sense))

    print(f"Sentence {i+1} complete")

    # Only senses with a positive average similarity are reported, highest first.
    # With no such sense the join yields the empty string.
    positive_cost_senses = [(s, c) for s, c in best_senses_for_sentence if c > 0]
    ranked_senses = sorted(positive_cost_senses, key=lambda item: item[1], reverse=True)
    wup_calculated_senses_output.append(", ".join(s for s, _ in ranked_senses))

df_wup_results = df_original_sliced.copy() # Use the sliced DataFrame
df_wup_results['calculated_senses'] = wup_calculated_senses_output

# Blank the context columns of every row that received no ranking, in one masked write.
rows_without_senses = [not senses for senses in df_wup_results['calculated_senses']]
df_wup_results.loc[rows_without_senses, ['sentence/context', 'polysemy_word']] = ''

# Report success only after the file has actually been written.
df_wup_results.to_csv('test_data_wup_results.csv', index=False)
print("\n--- Wu-Palmer results saved to 'test_data_wup_results.csv' ---")


Sentence 1 complete
Sentence 2 complete
Sentence 3 complete
Sentence 4 complete
Sentence 5 complete

--- Wu-Palmer results saved to 'test_data_wup_results.csv' ---


In [25]:
# --- Final Results Processing and Saving ---
# Attach both rankings to fresh copies of the sliced input frame, blank the rows
# that have no GNN ranking in both frames, and export each frame to its CSV file.
df_gnn_results = df_original_sliced.copy()
df_wup_results = df_original_sliced.copy()
df_gnn_results['calculated_senses'] = gnn_calculated_senses_output
df_wup_results['calculated_senses'] = wup_calculated_senses_output

# Rows whose GNN ranking is empty are cleared in BOTH frames (both are copies of
# the same frame, so one positional mask applies to each).
# NOTE(review): rows that are empty only in the Wu-Palmer ranking keep their
# context columns here, and this cell overwrites 'test_data_wup_results.csv'
# written by the Wu-Palmer cell, which did blank them. Confirm this is intended.
gnn_rows_without_senses = [not senses for senses in df_gnn_results['calculated_senses']]
blanked_columns = ['calculated_senses', 'sentence/context', 'polysemy_word']
df_gnn_results.loc[gnn_rows_without_senses, blanked_columns] = ''
df_wup_results.loc[gnn_rows_without_senses, blanked_columns] = ''

# save GNN-adjusted results (announce only once the file has been written)
df_gnn_results.to_csv('test_data_gnn_results.csv', index=False)
print("\n--- GNN-adjusted results saved to 'test_data_gnn_results.csv' ---")
display(df_gnn_results.head())

# save Wu-Palmer only results (announce only once the file has been written)
df_wup_results.to_csv('test_data_wup_results.csv', index=False)
print("\n--- Wu-Palmer results saved to 'test_data_wup_results.csv' ---")
display(df_wup_results.head())

print("\nAll processing and saving complete")


--- GNN-adjusted results saved to 'test_data_gnn_results.csv' ---


Unnamed: 0,﻿sn,sentence/context,polysemy_word,calculated_senses
0,345,His petition charged mental cruelty .,charged,"charge.v.09, appoint.v.02, charged.a.01, charg..."
1,346,His petition charged mental cruelty .,mental,"mental.a.01, mental.a.02, mental.a.03, genial...."
2,347,His petition charged mental cruelty .,cruelty,"cruelty.n.01, cruelty.n.02, cruelty.n.03"
3,1405,One validated acts of school districts .,One,"one.s.01, one.s.02, one.s.03, one.s.04, one.s...."
4,1406,One validated acts of school districts .,validated,"validated.s.01, validate.v.02, validate.v.03, ..."



--- Wu-Palmer results saved to 'test_data_wup_results.csv' ---


Unnamed: 0,﻿sn,sentence/context,polysemy_word,calculated_senses
0,345,His petition charged mental cruelty .,charged,"charge.v.02, appoint.v.02, charged.a.01, charg..."
1,346,His petition charged mental cruelty .,mental,"mental.a.01, mental.a.02, mental.a.03, genial...."
2,347,His petition charged mental cruelty .,cruelty,"cruelty.n.02, cruelty.n.03, cruelty.n.01"
3,1405,One validated acts of school districts .,One,"one.s.01, one.s.02, one.s.03, one.s.04, one.s...."
4,1406,One validated acts of school districts .,validated,"validated.s.01, validate.v.02, validate.v.03, ..."



All processing and saving complete
