In [4]:
from TopicExtractionModel import predict
from PronounResolutionModel import PronounResolutionModel
from ExtractTopic import ExtractTopic, generate_passage, generate_passage_from_entity_tuple

import torch
import re
from transformers import BertTokenizer



In [13]:
# Load model & tokenizer once
model_path = 'pronoun_resolution_model_full.pt'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(model_path, map_location=device)
bert_model_name = checkpoint.get('bert_model_name', 'bert-base-uncased')

tokenizer = BertTokenizer.from_pretrained(bert_model_name)
model = PronounResolutionModel(bert_model_name=bert_model_name)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

  checkpoint = torch.load(model_path, map_location=device)


PronounResolutionModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, el

In [None]:
import torch
import re
from transformers import BertTokenizer


def resolve_query(context, query, pronoun):
    text = context + query

    # Find pronoun position
    pronoun_pattern = re.compile(r'\b' + re.escape(pronoun) + r'\b', re.IGNORECASE)
    matches = list(pronoun_pattern.finditer(query))  # Search in query only

    if not matches:
        return {"error": f"Pronoun '{pronoun}' not found in the query"}

    pronoun_position = matches[0].start()  # Get position in query

    extracted_labels = predict(text)
    important_keywords = {"war"}
    candidates, important_keywords = ExtractTopic(extracted_labels, important_keywords)

    encoding = tokenizer(
        text, max_length=128, padding='max_length', truncation=True, return_tensors='pt'
    ).to(device)

    context_length = len(tokenizer.tokenize(context))  # Get token count for context
    pronoun_token_position = torch.tensor(
        [context_length + len(tokenizer.tokenize(query[:pronoun_position]))], dtype=torch.long
    ).to(device)

    # Process candidate encodings in batch
    candidate_encodings = tokenizer(
        candidates, max_length=20, padding='max_length', truncation=True, return_tensors='pt'
    ).to(device)

    num_candidates = torch.tensor([len(candidates)], dtype=torch.long).to(device)

    # Forward pass
    with torch.no_grad():
        outputs = model(
            input_ids=encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
            pronoun_position=pronoun_token_position,
            candidate_input_ids=candidate_encodings['input_ids'].unsqueeze(0),
            candidate_attention_masks=candidate_encodings['attention_mask'].unsqueeze(0),
            num_candidates=num_candidates
        )

    # Get prediction and confidence scores
    scores = outputs[0].cpu().numpy()
    probabilities = torch.softmax(outputs[0], dim=0).cpu().numpy()
    predicted_idx = int(torch.argmax(outputs, dim=1).item())
    resolved_candidate = candidates[predicted_idx]

    replaced_query = re.sub(r'\b' + re.escape(pronoun) + r'\b', resolved_candidate, query, count=1, flags=re.IGNORECASE)

    return replaced_query


In [43]:

if __name__ == "__main__":
    text = "tell me about mohamed salah . it plays for the national team of egypt . "
    query = "which club he plays for"
    pronoun = "he"
    
    result = resolve_query(text,query,pronoun)
    
    print(result)

Pronoun position in query: 11
position in text:  18
which club mohamed salah plays for
