In [1]:
import nltk
from nltk.corpus import semcor
from nltk.corpus import wordnet as wn
nltk.data.path.append("./nltk_data")
import os
import torch
import json
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, BertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

### Model loading and configuration

In [2]:
# --- Model Architecture Definition ---
class BertForWSD(BertPreTrainedModel):
    """Word-sense-disambiguation classifier on top of a BERT-style encoder.

    For each example, the encoder's last-layer hidden state at one token
    position (``target_token_idx``) is passed through dropout and a linear
    layer that produces one logit per sense label (``config.num_labels``).
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Keep the attribute name ``bert``: it is BertPreTrainedModel's
        # base-model prefix, used when matching checkpoint weight names.
        self.bert = AutoModel.from_config(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, target_token_idx=None, labels=None, **kwargs):
        """Classify the sense of the token at ``target_token_idx`` in each example.

        Args:
            input_ids: token ids; indexed below as (batch, seq).
            attention_mask: forwarded unchanged to the encoder.
            target_token_idx: one token position per example, e.g. a tensor of
                shape (batch,) on the same device as ``input_ids``.
            labels: optional gold sense ids of shape (batch,); when given, a
                cross-entropy ``loss`` is included in the output.
            **kwargs: extra keys (presumably from the HF Trainer, e.g.
                ``num_items_in_batch``) are accepted and ignored.

        Returns:
            SequenceClassifierOutput whose ``logits`` have shape
            (batch, num_labels); ``loss`` is None unless ``labels`` is given.

        Raises:
            ValueError: if ``target_token_idx`` is missing.
        """
        if target_token_idx is None:
            # Indexing with None would silently insert a new axis instead of
            # selecting a token, producing logits of the wrong shape.
            raise ValueError("target_token_idx is required to select the token to classify")

        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state

        # Pick one hidden vector per example: row b, position target_token_idx[b].
        batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
        target_vectors = sequence_output[batch_indices, target_token_idx]

        target_vectors = self.dropout(target_vectors)
        logits = self.classifier(target_vectors)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [3]:
# --- Configuration ---
# Both paths are relative to the notebook's current working directory.
MODEL_DIR = "./bert_wsd_custom"  # <-- PATH to local folder with the saved model and tokenizer
LABEL_MAP_FILE = "label_map.json"  # JSON mapping sense label -> class id

def load_model_components(model_dir, label_map_file):
    """Load the fine-tuned WSD model, its tokenizer and the sense-label map.

    Args:
        model_dir: path passed to ``from_pretrained`` for both the tokenizer
            and the ``BertForWSD`` model.
        label_map_file: JSON file mapping sense label -> integer class id.

    Returns:
        (model, tokenizer, id2label, device): the model is moved to
        ``device`` (CUDA when available, else CPU) and put in eval mode;
        ``id2label`` maps integer class id -> sense label.

    Raises:
        FileNotFoundError: if ``label_map_file`` does not exist.
    """
    print(f"Loading model from: {model_dir}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # 1. Load Label Map
    if not os.path.exists(label_map_file):
        raise FileNotFoundError(f"Label map not found at {label_map_file}")

    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(label_map_file, 'r', encoding='utf-8') as f:
        label2id = json.load(f)
    # int() also accepts ids stored as JSON strings ("3"), so lookups by an
    # integer class index always succeed.
    id2label = {int(class_id): label for label, class_id in label2id.items()}

    # 2. Load Tokenizer & Model
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = BertForWSD.from_pretrained(model_dir)
    model.to(device)
    model.eval()  # disables dropout for inference

    print("Model loaded successfully!")
    return model, tokenizer, id2label, device

In [4]:
# Run the loading; these globals are read by predict_sense below.
model, tokenizer, id2label, device = load_model_components(MODEL_DIR, LABEL_MAP_FILE)

Loading model from: ./bert_wsd_custom
Device: cuda
Model loaded successfully!


### Function for obtaining predictions

In [6]:
def predict_sense(sentence, target_word_idx, top_k=3):
    """Print the top-k WordNet sense predictions for one word of a sentence.

    Words are the whitespace-separated tokens of ``sentence`` (the same split
    ``show_indices`` prints). The chosen word's character offset is aligned to
    the tokenizer sub-token that starts at or covers it, and that position is
    passed to the model as ``target_token_idx``.

    Uses the notebook globals ``model``, ``tokenizer``, ``id2label`` and
    ``device`` returned by ``load_model_components``.

    Args:
        sentence: input text.
        target_word_idx: 0-based index into ``sentence.split()``.
        top_k: number of senses to print; clamped to the number of labels.

    Returns:
        None. Predictions and errors are printed rather than returned.
    """
    words = sentence.split()

    # Validations
    if target_word_idx < 0 or target_word_idx >= len(words):
        print(f"Error: Index {target_word_idx} is out of bounds.")
        return

    target_word = words[target_word_idx]

    # Visual check
    print(f"\nProcessing sentence: {' '.join(words)}")
    print(f"Target word: '{target_word}' (Index {target_word_idx})")

    # 1. Find Character Start in the ORIGINAL string, because the tokenizer
    # offsets below refer to it. Locating each word in turn (instead of
    # summing len(word) + 1) stays correct with repeated spaces, tabs,
    # newlines or leading whitespace.
    char_start = 0
    search_from = 0
    for word in words[:target_word_idx + 1]:
        char_start = sentence.find(word, search_from)
        search_from = char_start + len(word)

    # 2. Tokenize
    encoding = tokenizer(
        sentence,
        return_tensors="pt",
        return_offsets_mapping=True,  # requires a "fast" tokenizer
        truncation=True,
        max_length=128
    )

    inputs = {k: v.to(device) for k, v in encoding.items() if k != 'offset_mapping'}
    # (start, end) character span of every token of the single sequence.
    offsets = encoding['offset_mapping'][0].tolist()

    # 3. Find BERT Token: the first token that starts at the word or spans it.
    target_token_idx = -1
    for i, (o_start, o_end) in enumerate(offsets):
        if o_start == 0 and o_end == 0:
            continue  # special tokens ([CLS], [SEP]) get empty (0, 0) spans
        if o_start == char_start or o_start < char_start < o_end:
            target_token_idx = i
            break

    if target_token_idx == -1:
        # e.g. the word lies beyond max_length and was truncated away
        print("Error: Could not align word to BERT token.")
        return

    # 4. Inference
    with torch.no_grad():
        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            target_token_idx=torch.tensor([target_token_idx], device=device)
        )

    probs = torch.softmax(outputs.logits[0], dim=0)
    # torch.topk raises if k exceeds the number of labels, so clamp it.
    top_probs, top_indices = torch.topk(probs, min(top_k, probs.numel()))

    # 5. Output
    print("-" * 40)
    print(f"Predictions for '{target_word}':")
    for prob, idx in zip(top_probs, top_indices):
        label_str = id2label[int(idx.item())]
        print(f"  • {label_str: <25} ({prob.item():.2%})")
    print("-" * 40)

In [7]:
def show_indices(sentence):
    """Print every whitespace-separated word of ``sentence`` with its 0-based index.

    The indices are the ``target_word_idx`` values accepted by ``predict_sense``.
    """
    numbered_lines = [f"[{position}] {token}" for position, token in enumerate(sentence.split())]
    print("\nWord Indices Reference:")
    for line in numbered_lines:
        print(line)

### Examples

##### Example 1: *Bank*

In [8]:
# Show the index of each word in the sentence first, to choose target_word_idx
sentence = "The bank of the river was full of water"
show_indices(sentence)


Word Indices Reference:
[0] The
[1] bank
[2] of
[3] the
[4] river
[5] was
[6] full
[7] of
[8] water


In [9]:
predict_sense(sentence, 1)  # Index 1 is "bank" (river-side context)


Processing sentence: The bank of the river was full of water
Target word: 'bank' (Index 1)
----------------------------------------
Predictions for 'bank':
  • bank.n.01                 (99.52%)
  • depository_financial_institution.n.01 (0.00%)
  • side.n.03                 (0.00%)
----------------------------------------


In [9]:
# Same word in a financial context
sentence_2 = "I went to the bank to get some cash"
predict_sense(sentence_2, 4) # Index 4 is "bank"


Processing sentence: I went to the bank to get some cash
Target word: 'bank' (Index 4)
----------------------------------------
Predictions for 'bank':
  • depository_financial_institution.n.01 (99.77%)
  • bank.n.01                 (0.02%)
  • group.n.01                (0.01%)
----------------------------------------


In [10]:
# Show label definition and example
sense = wn.synset("bank.n.01")
print(f"Definition: {sense.definition()}")
print(f"Examples: {sense.examples()}")

Definition: sloping land (especially the slope beside a body of water)
Examples: ['they pulled the canoe up on the bank', 'he sat on the bank of the river and watched the currents']


In [11]:
# WordNet entry for the sense predicted for `sentence_2`
sense = wn.synset("depository_financial_institution.n.01")
print(f"Definition: {sense.definition()}")
print(f"Examples: {sense.examples()}")

Definition: a financial institution that accepts deposits and channels the money into lending activities
Examples: ['he cashed a check at the bank', 'that bank holds the mortgage on my home']


> ##### The model correctly identified the sense of the word in both sentences, with high confidence in each case.

##### Example 2: *Mouse*

In [12]:
# Sense 1: Animal (Small rodent)
sent_mouse_1 = "The mouse ran into the hole in the wall"
predict_sense(sent_mouse_1, 1)  # Index 1 is "mouse"

# Sense 2: Device (Computer input device)
sent_mouse_2 = "Click the icon with your mouse"
predict_sense(sent_mouse_2, 5)  # Index 5 is "mouse"


Processing sentence: The mouse ran into the hole in the wall
Target word: 'mouse' (Index 1)
----------------------------------------
Predictions for 'mouse':
  • mouse.n.01                (97.32%)
  • hen.n.01                  (0.03%)
  • cat.n.01                  (0.02%)
----------------------------------------

Processing sentence: Click the icon with your mouse
Target word: 'mouse' (Index 5)
----------------------------------------
Predictions for 'mouse':
  • mouse.n.01                (95.41%)
  • cat.n.01                  (0.01%)
  • hen.n.01                  (0.01%)
----------------------------------------


In [13]:
# List every WordNet sense of "mouse" (nouns and verbs)
wn.synsets("mouse")

[Synset('mouse.n.01'),
 Synset('shiner.n.01'),
 Synset('mouse.n.03'),
 Synset('mouse.n.04'),
 Synset('sneak.v.01'),
 Synset('mouse.v.02')]

In [14]:
# WordNet entry for the sense predicted for both mouse sentences
sense = wn.synset("mouse.n.01")
print(f"Definition: {sense.definition()}")
print(f"Examples: {sense.examples()}")

Definition: any of numerous small rodents typically resembling diminutive rats having pointed snouts and small ears on elongated bodies with slender usually hairless tails
Examples: []


In [15]:
# WordNet entry for the computer-device sense, the expected answer for `sent_mouse_2`
sense = wn.synset("mouse.n.04")
print(f"Definition: {sense.definition()}")
print(f"Examples: {sense.examples()}")

Definition: a hand-operated electronic device that controls the coordinates of a cursor on your computer screen as you move it around on a pad; on the bottom of the device is a ball that rolls on the surface of the pad
Examples: ['a mouse takes much more room than a trackball']


> ##### In both cases, the model predicted the rodent sense (mouse.n.01). The second prediction is therefore wrong: there, "mouse" refers to the computer device (mouse.n.04).

In [16]:
# NOTE(review): this import belongs in the imports cell at the top of the notebook.
import pandas as pd
# Load the SemCor training data to check which sense labels it contains.
df = pd.read_parquet("semcor_train.parquet")

In [17]:
# Training rows labelled with the computer-device sense (the result is an empty table)
df[df["label"] == "mouse.n.04"]

Unnamed: 0,sentence_id,sentence,target_word,char_start,char_end,label,label_id


> ##### This mistake happened because the training dataset contains no sentence labelled mouse.n.04 (a mouse as a computer device). The model therefore never saw this sense during training, even though it exists in WordNet.

##### Example 3: *Plant*

In [18]:
# Sense 1: Living organism (Flora)
sent_plant_1 = "She watered the plant in the garden"
predict_sense(sent_plant_1, 3)

# Sense 2: Industrial building (Factory)
sent_plant_2 = "He works at a power plant near the city"
predict_sense(sent_plant_2, 5)


Processing sentence: She watered the plant in the garden
Target word: 'plant' (Index 3)
----------------------------------------
Predictions for 'plant':
  • plant.n.02                (99.64%)
  • plant.n.01                (0.22%)
  • flower.n.01               (0.00%)
----------------------------------------

Processing sentence: He works at a power plant near the city
Target word: 'plant' (Index 5)
----------------------------------------
Predictions for 'plant':
  • plant.n.01                (99.96%)
  • plant.n.02                (0.01%)
  • facility.n.01             (0.00%)
----------------------------------------


In [19]:
# WordNet entry for the sense predicted for `sent_plant_1`
sense = wn.synset("plant.n.02")
print(f"Definition: {sense.definition()}")
print(f"Examples: {sense.examples()}")

Definition: (botany) a living organism lacking the power of locomotion
Examples: []


In [20]:
# WordNet entry for the sense predicted for `sent_plant_2`
sense = wn.synset("plant.n.01")
print(f"Definition: {sense.definition()}")
print(f"Examples: {sense.examples()}")

Definition: buildings for carrying on industrial labor
Examples: ['they built a large plant to manufacture automobiles']


> ##### The model correctly identified the sense of the word in both sentences, with high confidence in each case.

##### Example 4: *Bark*

In [21]:
# Sense 1: Tree covering (Noun)
sent_bark_1 = "The tree has thick bark"
predict_sense(sent_bark_1, 4)  # Index 4 is "bark"

# Sense 2: Sound made by a dog (Verb)
sent_bark_2 = "The dog began to bark at the stranger"
predict_sense(sent_bark_2, 4)  # Index 4 is "bark"


Processing sentence: The tree has thick bark
Target word: 'bark' (Index 4)
----------------------------------------
Predictions for 'bark':
  • bark.n.01                 (51.05%)
  • branch.n.02               (0.94%)
  • lumber.n.01               (0.52%)
----------------------------------------

Processing sentence: The dog began to bark at the stranger
Target word: 'bark' (Index 4)
----------------------------------------
Predictions for 'bark':
  • flog.v.01                 (1.16%)
  • snap.v.01                 (0.84%)
  • voice.v.01                (0.34%)
----------------------------------------


In [22]:
# WordNet entry for the sense predicted for `sent_bark_1`
sense = wn.synset("bark.n.01")
print(f"Definition: {sense.definition()}")
print(f"Examples: {sense.examples()}")

Definition: tough protective covering of the woody stems and roots of trees and other woody plants
Examples: []


In [23]:
# WordNet entry for the top (incorrect) sense predicted for `sent_bark_2`
sense = wn.synset("flog.v.01")
print(f"Definition: {sense.definition()}")
print(f"Examples: {sense.examples()}")

Definition: beat severely with a whip or rod
Examples: ['The teacher often flogged the students', 'The children were severely trounced']


> ##### For the first sentence, the model chose the correct sense (bark.n.01), although with only moderate confidence (about 51%). For the second sentence, it predicted an unrelated sense (flog.v.01) with very low confidence (about 1%).

In [24]:
# List every WordNet sense of "bark" to find the dog-sound verb sense
wn.synsets("bark")

[Synset('bark.n.01'),
 Synset('bark.n.02'),
 Synset('bark.n.03'),
 Synset('bark.n.04'),
 Synset('bark.v.01'),
 Synset('bark.v.02'),
 Synset('bark.v.03'),
 Synset('bark.v.04'),
 Synset('bark.v.05')]

In [25]:
# WordNet entry for the dog-barking verb sense, the expected answer for `sent_bark_2`
sense = wn.synset("bark.v.04")
print(f"Definition: {sense.definition()}")
print(f"Examples: {sense.examples()}")

Definition: make barking sounds
Examples: ['The dogs barked at the stranger']


In [26]:
# Training rows labelled with the dog-barking sense (the result is an empty table)
df[df["label"] == "bark.v.04"]

Unnamed: 0,sentence_id,sentence,target_word,char_start,char_end,label,label_id


> ##### Again, the mistake happened because the training dataset contains no sentence labelled bark.v.04 ("make barking sounds"), the sense used in the second sentence, even though this sense exists in WordNet.