In [1]:
import os
import torch
import pickle
import json
import torch.nn as nn

# Pick the compute device once; every later cell places tensors/models on it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [2]:
# ========================================================
# =============== | FILE VERIFICATION | ==================
# ========================================================

models_meta = "model"
# Artifacts this notebook needs, keyed by role; paths are relative to the working directory.
files_req = dict(
    vocab=os.path.join(models_meta, "vocab.json"),
    model=os.path.join(models_meta, "lstm_model.pkl"),
)

# Only warn here: opening a missing file in the loading cell raises anyway.
for file_p in files_req.values():
    if os.path.exists(file_p):
        continue
    print(f"WARNING: Cannot locate required file '{file_p}'")
    

In [3]:
# ======================================================
# ============= | LOAD MODEL/TOKENIZER | ===============
# ======================================================

class LSTMClassifier(nn.Module):
    """Text classifier: token embedding -> (stacked) LSTM -> linear head on the last time step.

    Instances are loaded with ``pickle`` in this notebook. Unpickling restores the saved
    attributes without re-running ``__init__``, but it looks this class up by name, so
    keep the class name and the ``embedding`` / ``lstm`` / ``fc`` attribute names stable.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding.
        hidden_dim: LSTM hidden-state size.
        output_dim: number of outputs (class logits) per example.
        n_layers: number of stacked LSTM layers.
        dropout: dropout applied between stacked LSTM layers; it has no effect when
            ``n_layers == 1``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # nn.LSTM only applies dropout *between* layers and emits a UserWarning when
        # dropout > 0 with a single layer; passing 0.0 there changes nothing but the warning.
        lstm_dropout = dropout if n_layers > 1 else 0.0
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers, batch_first=True, dropout=lstm_dropout)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return logits of shape (batch, output_dim) for token ids ``x`` of shape (batch, seq_len)."""
        embedded = self.embedding(x)        # (batch, seq_len, embedding_dim)
        lstm_out, _ = self.lstm(embedded)   # (batch, seq_len, hidden_dim) because batch_first=True
        # Classify from the output at the final time step.
        # NOTE(review): inputs are right-padded (see preprocess_text), so for short texts
        # this step is a padding position; kept as-is to match how the model was trained.
        return self.fc(lstm_out[:, -1, :])

# Reset first so a failed load never leaves a stale model/vocab from an earlier run.
model = None
vocab = None

pickle_model_path = files_req['model']
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only load
# model files from a trusted source. The object appears to be a pickled LSTMClassifier,
# and pickle resolves classes by module + name at load time, so that class must be
# defined before this line runs.
# NOTE(review): if the model was pickled while on a CUDA device, unpickling on a
# CPU-only machine may fail; confirm how lstm_model.pkl was produced.
with open(pickle_model_path, 'rb') as rf:
    model = pickle.load(rf)  # LSTMClassifier
print("LSTM model loaded.")

vocab_path = files_req['vocab']
# JSON text is UTF-8 (RFC 8259); don't depend on the platform's default encoding.
with open(vocab_path, 'r', encoding='utf-8') as rf:
    vocab = json.load(rf)  # used below as a word -> token-id mapping
print("Vocab loaded.")

LSTM model loaded.
Vocab loaded.


In [4]:
# ========================================================
# ================ | PREPARE TEXT/MODEL | ================
# ========================================================

def preprocess_text(text, tokenizer, max_length, pad_id=0):
    """Tokenize ``text`` into a fixed-length ``(1, max_length)`` LongTensor on ``device``.

    Args:
        text: raw input string passed to ``tokenizer``.
        tokenizer: callable mapping a string to an iterable of integer token ids
            (a list, tuple or generator all work).
        max_length: length of the output sequence. Longer inputs keep their first
            ``max_length`` ids; shorter inputs are right-padded with ``pad_id``.
        pad_id: id used for right-padding (default 0, the value this notebook always used).
            NOTE(review): confirm 0 is not also a real id in the vocab.

    Returns:
        A ``torch.long`` tensor of shape ``(1, max_length)`` on the notebook-level ``device``.

    Raises:
        ValueError: if ``max_length`` is less than 1.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be a positive integer, got {max_length!r}")
    # Copy into a fresh list: the tokenizer's own return value is never mutated,
    # and non-list iterables (tuples, generators) are accepted.
    ids = list(tokenizer(text))[:max_length]
    ids.extend([pad_id] * (max_length - len(ids)))
    # Allocate directly on the target device instead of building on CPU and copying.
    return torch.tensor([ids], dtype=torch.long, device=device)

def predict_text_line(model, text, tokenizer, max_length):
    """Classify one string and return the index of the highest-scoring output.

    This function does not switch ``model`` to eval mode or move it between devices;
    the caller must do both beforehand (this notebook does so right after defining it).

    Args:
        model: callable returning scores whose dim 1 indexes classes.
        text: raw input string.
        tokenizer: callable mapping a string to token ids (see ``preprocess_text``).
        max_length: fixed sequence length passed to ``preprocess_text``.

    Returns:
        int: argmax over dim 1 of the model output (single-example batch).
    """
    batch = preprocess_text(text, tokenizer, max_length)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)
        best = logits.argmax(dim=1)

    return best.item()


# Inference setup: move the weights to the chosen device and switch off
# training-only behavior (e.g. the LSTM's inter-layer dropout) before predicting.
# Both calls act in place and return the same module object.
model = model.to(device).eval()


def tokenizer(x):
    """Whitespace-split ``x`` and map each word to its id in ``vocab``.

    Words missing from ``vocab`` are silently dropped (no unknown-token id is used).
    NOTE(review): no lowercasing or punctuation stripping happens here ("Cars." stays
    "Cars."); confirm this matches how the vocab was built.
    """
    ids = []
    for word in x.split():
        if word in vocab:
            ids.append(vocab[word])
    return ids


max_length = 512  # every input is truncated / right-padded to this many token ids

In [5]:
# ================================================
# ================ | PREDICTION | ================
# ================================================

text = (
    "Cars. Cars have been around since they became famous in the 1900s, when "
    "Henry Ford created and built the first ModelT. Cars have played a major "
    "role in our every day lives since then. But now, people are starting to "
    "question if limiting car usage would be a good thing. To me, limiting the "
    "use of cars might be a good thing to do."
)

# Run the single-example classifier on the sample essay opening.
predicted_class = predict_text_line(
    model=model,
    text=text,
    tokenizer=tokenizer,
    max_length=max_length,
)
print(f"Predicted class for the input text: {predicted_class}")

Predicted class for the input text: 1
