# Notebook for running predictions on text data

In [121]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

In [122]:
# Checkpoint identifier handed to `from_pretrained`; both names are bound to
# the same string.
model_checkpoint = "distilbert-base-uncased-for-product-extraction/full_text_strictly_labeled_86000_0.87"
model_name = model_checkpoint

# Load the token-classification model and its tokenizer from that checkpoint.
model = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [123]:
# Label names; the code below indexes this list with the predicted class index.
label_list = [
    'O',
    'B-PRODUCT',
    'I-PRODUCT',
]

In [124]:

def predict_labels(text, model, tokenizer, label_list, max_length=512):
    """Predict one label per word of a sentence that is already split into words.

    Args:
        text: Sequence of words (the pre-split sentence).
        model: Token-classification model; it is called as ``model(**inputs)``
            and its result must expose a ``logits`` attribute.
        tokenizer: Tokenizer called with ``is_split_into_words=True``; the
            encoding it returns must provide ``word_ids()``.
        label_list: Label names, indexed by the predicted class index.
        max_length: Maximum number of tokens; longer inputs are truncated.

    Returns:
        A list of ``len(text)`` labels. Each word receives the label predicted
        for its FIRST sub-word token, so a ``B-`` tag on the first piece is
        not overwritten by the tags of later pieces. Words without any token
        (e.g. cut off by truncation) keep the default ``'O'``.
    """
    # Nothing to label: skip the tokenizer/model round trip entirely.
    if len(text) == 0:
        return []

    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
        is_split_into_words=True,
    )
    # Maps every token position to the index of its source word
    # (None for positions that belong to no word).
    word_ids = inputs.word_ids()

    with torch.no_grad():
        outputs = model(**inputs)

    # Highest-scoring class index per token for the single sequence in the batch.
    predicted_ids = torch.argmax(outputs.logits, dim=2)[0].tolist()

    labels = ['O'] * len(text)
    previous_word = None
    for word_index, class_id in zip(word_ids, predicted_ids):
        # Only the first sub-word of each word decides the word's label.
        if word_index is not None and word_index != previous_word:
            labels[word_index] = label_list[class_id]
        previous_word = word_index

    return labels
    

In [139]:
text = """[URL] <NO_URL> [URL] [TITLE] BELFORT Extendable corner [TITLE] [TEXT] 
BELFORT Extendable corner, chaise longue right/left with storage box, fabric
9,163 lei 7,330.40 lei​

The displayed price is valid from 29.08 to 02.10.2024
From 203 , 62 lei/month in 36 installments without interest valid with Mobexpert Credit Card.
Color
coffee-espresso
cappuccino/brown
cream/blue
coffee/espresso
grey/light grey
black/blue
cappuccino/yellow

 [TEXT] """

# labels = predict_labels(text.split(), model, tokenizer, label_list)

# for token, label in zip(text.split(), labels):
#     print(f"{token:10}: {label}")

# Encode the raw string, truncating at max_length=512 tokens.
inputs = tokenizer(text, truncation=True, max_length=512, return_tensors="pt")

# Forward pass with gradient tracking disabled.
with torch.no_grad():
    outputs = model(**inputs)

# Raw prediction scores, then softmax over the label axis (dim 2).
logits = outputs.logits
probabilities = torch.softmax(logits, dim=2)

# Token strings for the first (only) sequence, plus the encoding's word ids.
# NOTE: word_ids is not read again in this cell.
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
word_ids = inputs.word_ids()

# Position of each tag inside a token's probability vector.
o_column = label_list.index('O')
b_column = label_list.index('B-PRODUCT')
i_column = label_list.index('I-PRODUCT')

# One probability vector per token of the first sequence.
per_token_probs = probabilities[0]

# Most probable label per token, and the individual tag probabilities.
predicted_labels = [
    label_list[torch.max(row, dim=0).indices] for row in per_token_probs
]
o_tag_probs = [row[o_column].item() for row in per_token_probs]
b_product_probs = [row[b_column].item() for row in per_token_probs]
i_product_probs = [row[i_column].item() for row in per_token_probs]

# Print one line per token: token, predicted label, and the three probabilities.
report_rows = zip(tokens, predicted_labels, o_tag_probs, b_product_probs, i_product_probs)
for tok, tag, p_o, p_b, p_i in report_rows:
    print(f"{tok:10}: {tag:10} O: {p_o:.2f} B-PRODUCT: {p_b:.2f} I-PRODUCT: {p_i:.2f}")

[CLS]     : O          O: 1.00 B-PRODUCT: 0.00 I-PRODUCT: 0.00
[URL]     : O          O: 1.00 B-PRODUCT: 0.00 I-PRODUCT: 0.00
<NO_URL>  : O          O: 1.00 B-PRODUCT: 0.00 I-PRODUCT: 0.00
[URL]     : O          O: 1.00 B-PRODUCT: 0.00 I-PRODUCT: 0.00
[TITLE]   : O          O: 1.00 B-PRODUCT: 0.00 I-PRODUCT: 0.00
bel       : B-PRODUCT  O: 0.05 B-PRODUCT: 0.95 I-PRODUCT: 0.00
##fort    : I-PRODUCT  O: 0.14 B-PRODUCT: 0.00 I-PRODUCT: 0.86
extend    : I-PRODUCT  O: 0.30 B-PRODUCT: 0.01 I-PRODUCT: 0.69
##able    : I-PRODUCT  O: 0.26 B-PRODUCT: 0.00 I-PRODUCT: 0.73
corner    : O          O: 0.55 B-PRODUCT: 0.00 I-PRODUCT: 0.44
[TITLE]   : O          O: 1.00 B-PRODUCT: 0.00 I-PRODUCT: 0.00
[TEXT]    : O          O: 1.00 B-PRODUCT: 0.00 I-PRODUCT: 0.00
bel       : B-PRODUCT  O: 0.22 B-PRODUCT: 0.78 I-PRODUCT: 0.00
##fort    : O          O: 0.61 B-PRODUCT: 0.00 I-PRODUCT: 0.39
extend    : O          O: 0.94 B-PRODUCT: 0.00 I-PRODUCT: 0.05
##able    : O          O: 0.89 B-PRODUCT: 0.00 I-PRODUC