In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import json
import os
from tqdm import tqdm
import random

In [None]:
# 1. Dataset Handling
class PronounResolutionDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # Get text with marked pronoun and candidates
        text = item['text'].lower()
        pronoun = item['pronoun'].lower()
        candidates = [candidate.lower() for candidate in item['candidates']]
        pronoun_position = item['pronoun_position']
        correct_candidate_idx = item['correct_candidate_idx']
        
        # Tokenize text
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        # Get pronoun token position in BERT tokenization
        pronoun_token_position = None
        input_ids = encoding['input_ids'][0]
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
        
        # Find the pronoun position in tokenized text
        # This is approximate and may need refinement based on dataset specifics
        text_until_pronoun = text[:pronoun_position]
        approx_token_count = len(self.tokenizer.tokenize(text_until_pronoun))
        pronoun_token_position = approx_token_count
        
        # Tokenize candidates
        candidate_encodings = []
        for candidate in candidates:
            candidate_encoding = self.tokenizer(
                candidate,
                max_length=20,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            candidate_encodings.append({
                'input_ids': candidate_encoding['input_ids'][0],
                'attention_mask': candidate_encoding['attention_mask'][0]
            })
        
        # Pad candidate list if needed
        max_candidates = 5  # Adjust based on your dataset
        while len(candidate_encodings) < max_candidates:
            # Add padding candidate
            pad_encoding = self.tokenizer(
                "",
                max_length=20,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            candidate_encodings.append({
                'input_ids': pad_encoding['input_ids'][0],
                'attention_mask': pad_encoding['attention_mask'][0]
            })
        
        # Convert candidate encodings to tensors
        candidate_input_ids = torch.stack([c['input_ids'] for c in candidate_encodings[:max_candidates]])
        candidate_attention_masks = torch.stack([c['attention_mask'] for c in candidate_encodings[:max_candidates]])
        
        # If we have more candidates than our max, truncate
        if len(candidates) > max_candidates:
            candidate_input_ids = candidate_input_ids[:max_candidates]
            candidate_attention_masks = candidate_attention_masks[:max_candidates]
            if correct_candidate_idx >= max_candidates:
                correct_candidate_idx = max_candidates - 1  # Adjust if needed
        
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'pronoun_position': torch.tensor(pronoun_token_position, dtype=torch.long),
            'candidate_input_ids': candidate_input_ids,
            'candidate_attention_masks': candidate_attention_masks,
            'correct_candidate_idx': torch.tensor(correct_candidate_idx, dtype=torch.long),
            'num_candidates': torch.tensor(min(len(candidates), max_candidates), dtype=torch.long)
        }

In [None]:
# Function to dataset
def load_data(data_path):
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    
    # validation
    for item in data:
        assert 'text' in item, "Missing 'text' field"
        assert 'pronoun' in item, "Missing 'pronoun' field"
        assert 'candidates' in item, "Missing 'candidates' field"
        assert 'pronoun_position' in item, "Missing 'pronoun_position' field"
        assert 'correct_candidate_idx' in item, "Missing 'correct_candidate_idx' field"
    
    random.shuffle(data)

    return data

In [None]:
# Model Architecture
class PronounResolutionModel(nn.Module):
    def __init__(self, bert_model_name='bert-base-uncased', dropout_rate=0.1):
        super(PronounResolutionModel, self).__init__()
        
        # BERT for contextualized embeddings
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.hidden_size = self.bert.config.hidden_size
        
        # Candidate encoder
        self.candidate_bert = BertModel.from_pretrained(bert_model_name)
        
        # Attention mechanism for candidate scoring
        self.attention = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout_rate),
            nn.Linear(self.hidden_size, 1)
        )
        
        # Additional layers
        self.dropout = nn.Dropout(dropout_rate)
    
    def forward(self, input_ids, attention_mask, pronoun_position, 
                candidate_input_ids, candidate_attention_masks, num_candidates):
        
        batch_size = input_ids.size(0)
        
        # Get contextualized embeddings from BERT
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        
        sequence_output = outputs.last_hidden_state  # [batch_size, seq_len, hidden_size]
        
        # Extract pronoun representations
        pronoun_representations = []
        for i in range(batch_size):
            pos = pronoun_position[i]
            pronoun_representations.append(sequence_output[i, pos])
        
        pronoun_representations = torch.stack(pronoun_representations)  # [batch_size, hidden_size]
        
        # Process candidates
        max_candidates = candidate_input_ids.size(1)
        candidate_representations = []
        
        # Reshape for batch processing
        flat_candidate_ids = candidate_input_ids.view(-1, candidate_input_ids.size(-1))
        flat_candidate_masks = candidate_attention_masks.view(-1, candidate_attention_masks.size(-1))
        
        # Get candidate embeddings
        candidate_outputs = self.candidate_bert(
            input_ids=flat_candidate_ids,
            attention_mask=flat_candidate_masks,
            return_dict=True
        )
        
        # Use [CLS] token as candidate representation
        flat_candidate_embeds = candidate_outputs.last_hidden_state[:, 0]  # [batch_size*max_candidates, hidden_size]
        
        # Reshape back to [batch_size, max_candidates, hidden_size]
        candidate_embeds = flat_candidate_embeds.view(batch_size, max_candidates, -1)
        
        # Score each candidate
        scores = []
        for i in range(batch_size):
            n_cand = num_candidates[i].item()
            
            # Expand pronoun representation for each candidate
            expanded_pronoun = pronoun_representations[i].unsqueeze(0).expand(n_cand, -1)  # [n_cand, hidden_size]
            
            # Concatenate pronoun and candidate representations
            concat_reps = torch.cat([
                expanded_pronoun, 
                candidate_embeds[i, :n_cand]
            ], dim=1)  # [n_cand, hidden_size*2]
            
            # Score each candidate
            cand_scores = self.attention(concat_reps).squeeze(-1)  # [n_cand]
            
            # Pad scores for batch processing
            padded_scores = torch.full((max_candidates,), float('-inf'), device=cand_scores.device)
            padded_scores[:n_cand] = cand_scores
            scores.append(padded_scores)
        
        scores = torch.stack(scores)  # [batch_size, max_candidates]
        
        return scores


In [None]:
def train_model(model, train_loader, val_loader, device, 
                learning_rate=1e-5, epochs=3, warmup_steps=0):
    
    # Set up optimizer
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    
    # Loss function
    criterion = nn.CrossEntropyLoss()
    
    # Training loop
    best_val_acc = 0
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in progress_bar:
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            pronoun_position = batch['pronoun_position'].to(device)
            candidate_input_ids = batch['candidate_input_ids'].to(device)
            candidate_attention_masks = batch['candidate_attention_masks'].to(device)
            correct_candidate_idx = batch['correct_candidate_idx'].to(device)
            num_candidates = batch['num_candidates'].to(device)
            
            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pronoun_position=pronoun_position,
                candidate_input_ids=candidate_input_ids,
                candidate_attention_masks=candidate_attention_masks,
                num_candidates=num_candidates
            )
            
            # Compute loss
            loss = criterion(outputs, correct_candidate_idx)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            train_loss += loss.item()
            progress_bar.set_postfix({'loss': train_loss / (progress_bar.n + 1)})
        
        # Validation
        val_acc = evaluate_model(model, val_loader, device)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss/len(train_loader):.4f}, Val Accuracy: {val_acc:.4f}")
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_pronoun_resolution_model.pt')
    
    return model

In [18]:
# evaluation
def evaluate_model(model, data_loader, device):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in data_loader:
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            pronoun_position = batch['pronoun_position'].to(device)
            candidate_input_ids = batch['candidate_input_ids'].to(device)
            candidate_attention_masks = batch['candidate_attention_masks'].to(device)
            correct_candidate_idx = batch['correct_candidate_idx'].to(device)
            num_candidates = batch['num_candidates'].to(device)
            
            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pronoun_position=pronoun_position,
                candidate_input_ids=candidate_input_ids,
                candidate_attention_masks=candidate_attention_masks,
                num_candidates=num_candidates
            )
            
            # Get predictions
            _, predicted = torch.max(outputs, 1)
            total += correct_candidate_idx.size(0)
            correct += (predicted == correct_candidate_idx).sum().item()
    
    return correct / total

In [19]:
# Inference Function
def resolve_pronoun(model, text, pronoun, pronoun_position, candidates, tokenizer, device):
    model.eval()
    
    # Prepare input data
    encoding = tokenizer(
        text,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    
    # Find pronoun token position
    text_until_pronoun = text[:pronoun_position]
    approx_token_count = len(tokenizer.tokenize(text_until_pronoun))
    pronoun_token_position = torch.tensor([approx_token_count], dtype=torch.long)
    
    # Tokenize candidates
    candidate_encodings = []
    for candidate in candidates:
        candidate_encoding = tokenizer(
            candidate,
            max_length=20,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        candidate_encodings.append({
            'input_ids': candidate_encoding['input_ids'][0],
            'attention_mask': candidate_encoding['attention_mask'][0]
        })
    
    candidate_input_ids = torch.stack([c['input_ids'] for c in candidate_encodings])
    candidate_attention_masks = torch.stack([c['attention_mask'] for c in candidate_encodings])
    
    # Move to device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    pronoun_token_position = pronoun_token_position.to(device)
    candidate_input_ids = candidate_input_ids.unsqueeze(0).to(device)
    candidate_attention_masks = candidate_attention_masks.unsqueeze(0).to(device)
    num_candidates = torch.tensor([len(candidates)], dtype=torch.long).to(device)
    
    # Forward pass
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pronoun_position=pronoun_token_position,
            candidate_input_ids=candidate_input_ids,
            candidate_attention_masks=candidate_attention_masks,
            num_candidates=num_candidates
        )
    
    # Get prediction
    _, predicted = torch.max(outputs, 1)
    
    return candidates[predicted.item()]

In [None]:
def main():
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Load tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    data_path = "/kaggle/input/augmenteddataest/augmented_pronoun_resolution_data.json"
    
    # Load data
    data = load_data(data_path)
    print(f"Loaded {len(data)} examples")
    
    # Split data
    train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
    train_data, val_data = train_test_split(train_data, test_size=0.1, random_state=42)
    
    print(f"Train: {len(train_data)}, Validation: {len(val_data)}, Test: {len(test_data)}")
    
    # Create datasets
    train_dataset = PronounResolutionDataset(train_data, tokenizer)
    val_dataset = PronounResolutionDataset(val_data, tokenizer)
    test_dataset = PronounResolutionDataset(test_data, tokenizer)
    
    # Create dataloaders
    batch_size = 16
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    # Initialize model
    model = PronounResolutionModel(bert_model_name='bert-base-uncased')
    model.to(device)
    
    # Train model
    print("Starting training...")
    train_model(model, train_loader, val_loader, device, learning_rate=2e-5, epochs=8)
    
    # Load best model
    model.load_state_dict(torch.load('best_pronoun_resolution_model.pt'))
    
    # Evaluate on test set
    test_acc = evaluate_model(model, test_loader, device)
    print(f"Test Accuracy: {test_acc:.4f}")
    
    example_text = "John met Mike at the park. He was happy to see him."
    pronoun = "He"
    pronoun_position = 24
    candidates = ["John", "Mike"]
    
    resolved_entity = resolve_pronoun(
        model, example_text, pronoun, pronoun_position, candidates, tokenizer, device
    )
    
    print(f"Resolved pronoun '{pronoun}' to: {resolved_entity}")
    
    # Save model and configuration
    torch.save({
        'model_state_dict': model.state_dict(),
        'bert_model_name': 'bert-base-uncased',
        'max_length': 128
    }, 'pronoun_resolution_model_full.pt')
    
    print("Model saved as 'pronoun_resolution_model_full.pt'")

In [21]:
if __name__ == "__main__":
    main()

Using device: cuda
Loaded 5000 examples
Train: 3600, Validation: 400, Test: 1000
Starting training...


Epoch 1/8: 100%|██████████| 225/225 [02:41<00:00,  1.39it/s, loss=0.732]


Epoch 1/8, Train Loss: 0.7324, Val Accuracy: 0.8875


Epoch 2/8: 100%|██████████| 225/225 [02:40<00:00,  1.40it/s, loss=0.343]


Epoch 2/8, Train Loss: 0.3427, Val Accuracy: 0.9175


Epoch 3/8: 100%|██████████| 225/225 [02:40<00:00,  1.40it/s, loss=0.25] 


Epoch 3/8, Train Loss: 0.2499, Val Accuracy: 0.9125


Epoch 4/8: 100%|██████████| 225/225 [02:40<00:00,  1.40it/s, loss=0.212]


Epoch 4/8, Train Loss: 0.2122, Val Accuracy: 0.9325


Epoch 5/8: 100%|██████████| 225/225 [02:40<00:00,  1.40it/s, loss=0.16] 


Epoch 5/8, Train Loss: 0.1604, Val Accuracy: 0.9450


Epoch 6/8: 100%|██████████| 225/225 [02:40<00:00,  1.40it/s, loss=0.154]


Epoch 6/8, Train Loss: 0.1542, Val Accuracy: 0.9400


Epoch 7/8: 100%|██████████| 225/225 [02:40<00:00,  1.40it/s, loss=0.131]


Epoch 7/8, Train Loss: 0.1309, Val Accuracy: 0.9400


Epoch 8/8: 100%|██████████| 225/225 [02:40<00:00,  1.40it/s, loss=0.119]


Epoch 8/8, Train Loss: 0.1191, Val Accuracy: 0.9400


  model.load_state_dict(torch.load('best_pronoun_resolution_model.pt'))


Test Accuracy: 0.9360
Resolved pronoun 'He' to: John
Model saved as 'pronoun_resolution_model_full.pt'


In [None]:
import spacy
import torch
import re
from transformers import BertTokenizer
nlp = spacy.load("en_core_web_sm")

def predict_pronoun_resolution(text, pronoun, model_path='pronoun_resolution_model_full.pt', device=None):
    
    # Determine device if not provided
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load the model
    checkpoint = torch.load(model_path, map_location=device)
    bert_model_name = checkpoint.get('bert_model_name', 'bert-base-uncased')
    
    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    
    # Load model
    model = PronounResolutionModel(bert_model_name=bert_model_name)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    # Find pronoun position in text
    pronoun_pattern = re.compile(r'\b' + re.escape(pronoun) + r'\b', re.IGNORECASE)
    matches = list(pronoun_pattern.finditer(text))
    
    if not matches:
        return {"error": f"Pronoun '{pronoun}' not found in the text"}
    
    # For simplicity, we'll use the first occurrence
    pronoun_position = matches[0].start()

    doc = nlp(text)
    
    # Extract named entities and noun chunks as candidates
    candidates = []
    for ent in doc.ents:
        if ent.label_ in ["PERSON", "ORG", "GPE"]:  # People, organizations, locations
            candidates.append(ent.text)
    
    # Add noun chunks that might be potential candidates
    for chunk in doc.noun_chunks:
        # Skip pronouns and determiners
        if chunk.root.pos_ not in ["PRON", "DET"] and chunk.text.lower() != pronoun.lower():
            candidates.append(chunk.text)
    
    # Remove duplicates while preserving order
    candidates = list(dict.fromkeys(candidates))
    

    if not candidates:
        return {"error": "No potential candidate entities found in the text"}
    
    # Prepare input data similar to resolve_pronoun function
    encoding = tokenizer(
        text,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    
    # Find pronoun token position
    text_until_pronoun = text[:pronoun_position]
    approx_token_count = len(tokenizer.tokenize(text_until_pronoun))
    pronoun_token_position = torch.tensor([approx_token_count], dtype=torch.long)
    
    # Tokenize candidates
    candidate_encodings = []
    for candidate in candidates:
        candidate_encoding = tokenizer(
            candidate,
            max_length=20,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        candidate_encodings.append({
            'input_ids': candidate_encoding['input_ids'][0],
            'attention_mask': candidate_encoding['attention_mask'][0]
        })
    
    candidate_input_ids = torch.stack([c['input_ids'] for c in candidate_encodings])
    candidate_attention_masks = torch.stack([c['attention_mask'] for c in candidate_encodings])
    
    # Move to device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    pronoun_token_position = pronoun_token_position.to(device)
    candidate_input_ids = candidate_input_ids.unsqueeze(0).to(device)
    candidate_attention_masks = candidate_attention_masks.unsqueeze(0).to(device)
    num_candidates = torch.tensor([len(candidates)], dtype=torch.long).to(device)
    
    # Forward pass
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pronoun_position=pronoun_token_position,
            candidate_input_ids=candidate_input_ids,
            candidate_attention_masks=candidate_attention_masks,
            num_candidates=num_candidates
        )
    
    # Get prediction and confidence scores
    scores = outputs[0].cpu().detach().numpy()
    probabilities = torch.softmax(outputs[0], dim=0).cpu().detach().numpy()
    predicted_idx = int(torch.argmax(outputs, dim=1).item())
    
    # Prepare results
    result = {
        "resolved_entity": candidates[predicted_idx],
        "confidence": float(probabilities[predicted_idx]),
        "pronoun": pronoun,
        "pronoun_position": pronoun_position,
        "candidates": [
            {"entity": cand, "score": float(score), "probability": float(prob)}
            for cand, score, prob in zip(candidates, scores, probabilities)
        ]
    }
    
    return result


if __name__ == "__main__":
    text = "tell me about egypt. what places can I visit in it"
    pronoun = "it"
    
    result = predict_pronoun_resolution(text, pronoun)
    
    if "error" in result:
        print(f"Error: {result['error']}")
    else:
        print(f"Pronoun '{result['pronoun']}' most likely refers to: {result['resolved_entity']}")
        print(f"Confidence: {result['confidence']:.2f}")
        print("\nAll candidates:")
        for candidate in result['candidates']:
            print(f"- {candidate['entity']}: {candidate['probability']:.2f}")

  checkpoint = torch.load(model_path, map_location=device)


Pronoun 'it' most likely refers to: egypt
Confidence: 0.98

All candidates:
- egypt: 0.98
- what places: 0.02
