# Inference

Once you have trained the model, simply run the cell below and it will prompt you for input. Enter the text and the model will classify it as negative, positive, or neutral.

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from torchtext.data.utils import get_tokenizer
import string
import re
import pickle

class TextClassificationModel(nn.Module):
    """Text classifier: an ``nn.EmbeddingBag`` followed by three linear layers.

    The pooled embedding is passed through two ReLU-activated linear layers
    (64 and 16 units) and a final linear layer that produces one raw score
    per class. No softmax is applied.

    Args:
        vocab_size: Number of entries in the embedding table.
        embed_dim: Size of each embedding vector.
        num_class: Number of output scores.
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # The attribute names below become the state_dict keys that
        # load_state_dict() matches against, so they must not be renamed.
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc1 = nn.Linear(embed_dim, 64)
        self.fc2 = nn.Linear(64, 16)
        self.fc3 = nn.Linear(16, num_class)

    def forward(self, text, offsets):
        """Return the raw class scores for the bags described by *text* and *offsets*.

        Both arguments are forwarded unchanged to ``nn.EmbeddingBag``.
        """
        hidden = self.embedding(text, offsets)
        # Two hidden layers share the same activation; the output layer has none.
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

def remove_emoji(text):
    """Return *text* with every run of characters in the listed ranges removed.

    NOTE(review): the last range (U+24C2..U+1F251) is very wide and also
    covers many non-emoji characters (e.g. CJK ideographs). It is kept as-is
    so that inference-time cleaning matches this exact pattern.
    """
    emoji_ranges = ("["
                    u"\U0001F600-\U0001F64F"  # emoticons
                    u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                    u"\U0001F680-\U0001F6FF"  # transport & map symbols
                    u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
                    u"\U00002702-\U000027B0"
                    u"\U000024C2-\U0001F251"
                    "]+")
    return re.sub(emoji_ranges, r'', text, flags=re.UNICODE)

def remove_url(text):
    """Return *text* with every ``http://`` / ``https://`` URL removed.

    Matching is case-sensitive and a URL ends at the first character that is
    not accepted by the pattern (for example whitespace).
    """
    return re.sub(
        r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+',
        r'',
        text,
    )

def clean_text(text):
    """Normalise *text* for the classifier.

    Steps, in order:
      1. delete every character in ``string.punctuation``;
      2. split on whitespace;
      3. drop tokens that are all digits or have two characters or fewer;
      4. re-join with single spaces and lower-case the result.
    """
    # The three-argument maketrans form deletes the listed characters.
    strip_punctuation = str.maketrans('', '', string.punctuation)
    kept_words = []
    for word in text.translate(strip_punctuation).split():
        if word.isdigit() or len(word) <= 2:
            continue
        kept_words.append(word)
    return ' '.join(kept_words).lower()

def load_vocab(vocab_path):
    """Unpickle and return the vocabulary object stored at *vocab_path*.

    NOTE: ``pickle.load`` can execute arbitrary code while loading, so only
    point this at vocabulary files from a trusted source.
    """
    with open(vocab_path, "rb") as handle:
        return pickle.load(handle)

def load_model(model_path, vocab_size, embed_dim, num_class):
    """Build a ``TextClassificationModel`` and load trained weights into it.

    Args:
        model_path: Path to a state dict saved with ``torch.save``.
        vocab_size: Embedding-table size; must match the checkpoint.
        embed_dim: Embedding width; must match the checkpoint.
        num_class: Number of output classes; must match the checkpoint.

    Returns:
        The model on the CPU, switched to evaluation mode.

    NOTE: ``torch.load`` unpickles the file, so only load checkpoints from a
    trusted source.
    """
    model = TextClassificationModel(vocab_size, embed_dim, num_class)
    # map_location="cpu" lets a checkpoint saved from CUDA tensors load on a
    # machine without a GPU. The model itself is created on the CPU and the
    # inference code feeds it CPU tensors, so this is where the weights
    # belong anyway.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model

def classify_text(model, vocab, text_pipeline, input_text):
    """Classify *input_text* and return its sentiment label.

    Args:
        model: Callable taking ``(token_ids, offsets)`` and returning a
            tensor of per-class scores with one row.
        vocab: Unused here; kept so existing callers keep working.
        text_pipeline: Callable turning the raw string into a list of
            integer token ids.
        input_text: The raw text to classify.

    Returns:
        ``"neutral"``, ``"negative"`` or ``"positive"``; ``"unknown"`` if the
        predicted index is not one of 0, 1 or 2.
    """
    label_names = {0: "neutral", 1: "negative", 2: "positive"}

    token_ids = torch.tensor(text_pipeline(input_text), dtype=torch.int64)
    # A single bag that starts at position 0: the whole input is one sample.
    bag_starts = torch.tensor([0])

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        scores = model(token_ids, bag_starts)
        best_index = scores.argmax(1).item()

    return label_names.get(best_index, "unknown")

# Inference configuration. embed_dim and num_class are used to build the
# network before the weights in model_path are loaded into it, so they must
# match the saved checkpoint.
model_path = "model.pth"
vocab_path = "vocab.pkl"
embed_dim = 128
num_class = 3

vocab = load_vocab(vocab_path)

# Build the tokenizer once instead of on every text_pipeline call.
_tokenizer = get_tokenizer('basic_english')


def text_pipeline(x):
    """Turn raw text *x* into token ids.

    Emoji and URLs are stripped, the text is cleaned and tokenized, and the
    tokens are passed to ``vocab`` (presumably a torchtext ``Vocab`` that
    maps a token list to a list of ids; verify against the training code).
    """
    return vocab(_tokenizer(clean_text(remove_url(remove_emoji(x)))))


vocab_size = len(vocab)
model = load_model(model_path, vocab_size, embed_dim, num_class)

input_text = input("Enter text to classify: ")

sentiment = classify_text(model, vocab, text_pipeline, input_text)
print(f"The sentiment of the input text is: {sentiment}")



The sentiment of the input text is: positive
