In [1]:
import torch
import torch.nn as nn
from transformers import DistilBertModel, DistilBertTokenizerFast
from safetensors.torch import load_file


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MultiTaskDistilBERT(nn.Module):
    def __init__(self, num_sentiment, num_intent):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        hidden = self.bert.config.dim

        self.sentiment_head = nn.Linear(hidden, num_sentiment)
        self.intent_head = nn.Linear(hidden, num_intent)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0]

        return {
            "sentiment_logits": self.sentiment_head(pooled),
            "intent_logits": self.intent_head(pooled)
        }


In [3]:
sentiment_labels = {
    "negative": 0,
    "neutral": 1,
    "positive": 2
}

intent_labels = {
    "symptom_reporting": 0,
    "appointment_booking": 1,
    "general_query": 2
}

id2sentiment = {v: k for k, v in sentiment_labels.items()}
id2intent = {v: k for k, v in intent_labels.items()}


In [None]:
SAVE_DIR = "../Models/multitask_distilbert"

tokenizer = DistilBertTokenizerFast.from_pretrained(SAVE_DIR)


In [6]:
state_dict = load_file(f"{SAVE_DIR}/model.safetensors")

for k in state_dict.keys():
    if "intent_head.weight" in k:
        print(state_dict[k].shape)


torch.Size([4, 768])


In [7]:
intent_labels = {
    "intent_1": 0,
    "intent_2": 1,
    "intent_3": 2,
    "intent_4": 3
}


In [8]:
model = MultiTaskDistilBERT(
    num_sentiment=len(sentiment_labels),  # must match training
    num_intent=len(intent_labels)          # MUST be 4
)


state_dict = load_file(f"{SAVE_DIR}/model.safetensors")
model.load_state_dict(state_dict)
model.eval()



MultiTaskDistilBERT(
  (bert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Li

In [9]:
def predict(text):
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128
    )

    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"]
        )

    sentiment_id = torch.argmax(outputs["sentiment_logits"], dim=1).item()
    intent_id = torch.argmax(outputs["intent_logits"], dim=1).item()

    return {
        "text": text,
        "sentiment": id2sentiment[sentiment_id],
        "intent": id2intent[intent_id]
    }


In [10]:
print(predict("I have severe headache and feel very uncomfortable"))


{'text': 'I have severe headache and feel very uncomfortable', 'sentiment': 'neutral', 'intent': 'general_query'}
