# Cell 1: Import Necessary Libraries

In [1]:
import os
import sys
import torch
from transformers import AutoModel, AutoTokenizer
import json
sys.path.append('../src')
# Now import the model
from models.multitask_mpnet import MultiTaskMPNet

print("Libraries imported successfully.")


  from .autonotebook import tqdm as notebook_tqdm


Libraries imported successfully.


# Cell 2: Load Label Mappings

In [5]:
with open('../processed_data/label_mappings.json', 'r') as f:
    label_mappings = json.load(f)
sentiment_mapping = label_mappings['sentiment_mapping']
category_mapping = label_mappings['category_mapping']
inverse_category_mapping = {v: k for k, v in category_mapping.items()}

# Cell 3: Initialize the Model

In [6]:
model_path = '../models/multitask_mpnet'
model = MultiTaskMPNet(model_name=model_path, num_classes_task_a=4, num_classes_task_b=2)
model.encoder = AutoModel.from_pretrained(model_path)
model.classifier_a.load_state_dict(torch.load(f"{model_path}/classifier_a.pt"))
model.classifier_b.load_state_dict(torch.load(f"{model_path}/classifier_b.pt"))
device = torch.device('mps' if torch.backends.mps.is_available() else ('cuda' if torch.cuda.is_available() else 'cpu'))
model.to(device)
model.eval()

  model.classifier_a.load_state_dict(torch.load(f"{model_path}/classifier_a.pt"))
  model.classifier_b.load_state_dict(torch.load(f"{model_path}/classifier_b.pt"))


MultiTaskMPNet(
  (encoder): MPNetModel(
    (embeddings): MPNetEmbeddings(
      (word_embeddings): Embedding(30527, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): MPNetEncoder(
      (layer): ModuleList(
        (0-11): 12 x MPNetLayer(
          (attention): MPNetAttention(
            (attn): MPNetSelfAttention(
              (q): Linear(in_features=768, out_features=768, bias=True)
              (k): Linear(in_features=768, out_features=768, bias=True)
              (v): Linear(in_features=768, out_features=768, bias=True)
              (o): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
   

# Cell 4: Initialize the Tokenizer

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Cell 5: Define Sample Sentences for Inference

In [8]:
sentences = [
    "I absolutely love this product!",
    "The battery life is too short.",
    "Can you help me with installation?",
    "This is the best purchase I've made."
]

# Cell 6: Define Inference Function

In [9]:
def predict(model, tokenizer, sentences):
    inputs = tokenizer(
        sentences,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
        logits_a = outputs['logits_a']
        logits_b = outputs['logits_b']
        
        predictions_a = torch.argmax(logits_a, dim=1).cpu().numpy()
        predictions_b = torch.argmax(logits_b, dim=1).cpu().numpy()
    
    for i, sentence in enumerate(sentences):
        category = inverse_category_mapping.get(predictions_a[i], "Unknown")
        sentiment = 'Positive' if predictions_b[i] == 1 else 'Negative'
        print(f"Sentence: \"{sentence}\"")
        print(f"Category: {category}")
        print(f"Sentiment: {sentiment}\n")

# Cell 7: Perform Predictions

In [10]:
predict(model, tokenizer, sentences)

Sentence: "I absolutely love this product!"
Category: Positive Sentiment
Sentiment: Positive

Sentence: "The battery life is too short."
Category: Negative Sentiment
Sentiment: Negative

Sentence: "Can you help me with installation?"
Category: Negative Sentiment
Sentiment: Negative

Sentence: "This is the best purchase I've made."
Category: Positive Sentiment
Sentiment: Positive

