In [11]:
import warnings
from modeling.networks.quora_duplicate_bert import DuplicateTextClassifier
from transformers import BertModel, BertTokenizer
import torch

warnings.filterwarnings("ignore")

In [2]:
# Directory (relative to the kernel's working directory) holding the fine-tuned
# checkpoint: the classifier weights file `pytorch_model.bin` and the tokenizer
# files that are loaded from here further down.
SAVE_DIR = "../registry/bert/final_model"

In [3]:
def choose_device():
    """Pick the torch device for inference.

    Returns:
        ``torch.device("cuda")`` when torch reports a usable CUDA GPU,
        otherwise ``torch.device("cpu")``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

In [4]:
# Resolve the compute device once; the model and the tokenized inputs are moved
# onto it in later cells. The bare `device` on the last line displays the choice.
device = choose_device()
device

device(type='cuda')

In [5]:
# Rebuild the classifier around a pretrained "bert-base-uncased" backbone, then
# overwrite its parameters with the fine-tuned checkpoint stored in SAVE_DIR.
bert_base = BertModel.from_pretrained("bert-base-uncased")

model = DuplicateTextClassifier(bert_model=bert_base)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot run arbitrary code when it is loaded.
# map_location="cpu" lets the file load on machines without a GPU; the model is
# moved to `device` right after.
state_dict = torch.load(
    f"{SAVE_DIR}/pytorch_model.bin",
    map_location="cpu",
    weights_only=True,
)
model.load_state_dict(state_dict)

# Switch modules such as dropout to evaluation behaviour, then move to the device.
model.eval()
model.to(device)

# Tokenizer is read from the same directory as the weights (presumably saved
# alongside them at training time so the vocabulary matches — verify).
tokenizer = BertTokenizer.from_pretrained(SAVE_DIR)

In [6]:
def predict_duplicate(
    question1: str,
    question2: str,
    threshold: float = 0.5,
    max_length: int = 128,
    *,
    verbose: bool = True,
    classifier=None,
    text_tokenizer=None,
    target_device=None,
) -> float:
    """Score how likely two questions are duplicates of each other.

    The two questions are tokenized together as a single sentence pair, run
    through the classifier without gradient tracking, and the softmax
    probability of class index 1 is returned.

    Args:
        question1: First question text.
        question2: Second question text.
        threshold: Probability above which the pair is labelled "Duplicate".
            The comparison is strict, so a score equal to the threshold is
            labelled "Not Duplicate".
        max_length: Token budget for the combined pair; longer pairs are
            truncated and shorter ones are padded up to this length.
        verbose: When True (the default), print "Duplicate" or "Not Duplicate".
        classifier: Model to use instead of the notebook-level ``model``.
        text_tokenizer: Tokenizer to use instead of the notebook-level ``tokenizer``.
        target_device: Device to use instead of the notebook-level ``device``.

    Returns:
        The softmax probability of class index 1, as a Python float in [0, 1].
    """
    # Fall back to the notebook globals only when no override was passed; a
    # conditional expression never evaluates the branch it does not take.
    classifier = model if classifier is None else classifier
    text_tokenizer = tokenizer if text_tokenizer is None else text_tokenizer
    target_device = device if target_device is None else target_device

    classifier.eval()
    inputs = text_tokenizer(
        question1,
        question2,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    ).to(target_device)

    with torch.no_grad():
        logits = classifier(**inputs)["logits"]

    # NOTE(review): assumes logits are (1, num_classes) with index 1 meaning
    # "duplicate", as the original probs[0][1] read implies — confirm against the
    # DuplicateTextClassifier head.
    prob_duplicate = torch.softmax(logits, dim=1)[0, 1].item()

    if verbose:
        print("Duplicate" if prob_duplicate > threshold else "Not Duplicate")
    return prob_duplicate


In [7]:
# Paraphrase pair about getting started with Python.
# NOTE: the recorded output scores this 0.5170 — only just above the 0.5 cut-off.
q1, q2 = (
    "How do I learn Python?",
    "What is the best way to start with Python programming?",
)

prob = predict_duplicate(q1, q2)
print(f"Duplicate probability: {prob:.4f}")

Duplicate
Duplicate probability: 0.5170


In [8]:
# Paraphrase pair with no shared key terms ("lose weight" vs "reduce body fat").
q1, q2 = (
    "How can I lose weight?",
    "What is the best way to reduce body fat?",
)
prob = predict_duplicate(q1, q2)
print(f"Duplicate probability: {prob:.4f}")

Duplicate
Duplicate probability: 0.8136


In [10]:
# Paraphrase pair where the second question abbreviates "machine learning" as "ML".
# NOTE(review): the recorded output scores this 0.0244 (Not Duplicate), which looks
# like a false negative — possibly the model does not link "ML" to
# "machine learning"; worth investigating.
q1, q2 = (
    "How can I learn machine learning?",
    "What is the best way to study ML?",
)
prob = predict_duplicate(q1, q2)
print(f"Duplicate probability: {prob:.4f}")

Not Duplicate
Duplicate probability: 0.0244
