# In [None]:
from transformers import AutoTokenizer, RobertaModel
from torch import nn
import torch

class TweetClassifier(nn.Module):
    """Classification head on top of a transformer encoder.

    The encoder's first-token hidden state is passed through a two-layer
    MLP (hidden_size -> 32 -> num_labels, ReLU in between) to produce
    ``num_labels`` logits.

    Args:
        base_model: Encoder called as
            ``base_model(input_ids=..., attention_mask=...)``. Element ``[0]``
            of its output is indexed as ``[:, 0]``; presumably it is the last
            hidden state of shape (batch, seq_len, hidden_size) -- TODO confirm
            for non-RoBERTa encoders.
        num_labels: Number of output classes.
        hidden_size: Width of the encoder's hidden states. When omitted, it is
            read from ``base_model.config.hidden_size`` if present, otherwise
            768 (the previous hard-coded value).
    """

    def __init__(self, base_model, num_labels, hidden_size=None):
        super().__init__()
        if hidden_size is None:
            # getattr(None, ..., 768) also covers encoders without a config.
            config = getattr(base_model, "config", None)
            hidden_size = getattr(config, "hidden_size", 768)
        # Attribute names (bert / fc1 / fc2) form the state_dict keys of saved
        # checkpoints, so they must not be renamed.
        self.bert = base_model
        self.fc1 = nn.Linear(hidden_size, 32)
        self.fc2 = nn.Linear(32, num_labels)
        self.relu = nn.ReLU()

    def forward(self, input_ids, attention_mask):
        """Return one row of ``num_labels`` logits per input sequence."""
        hidden_states = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
        # Pool by taking the representation at the first token position.
        pooled = hidden_states[:, 0]
        x = self.relu(self.fc1(pooled))
        return self.fc2(x)

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

# Initialize the base RoBERTa encoders, one per classification task. The strict
# load_state_dict calls below replace their weights with the fine-tuned ones.
base_model_topic = RobertaModel.from_pretrained("roberta-base")
base_model_sentiment = RobertaModel.from_pretrained("roberta-base")

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load trained models. ``map_location`` remaps tensors that were saved on a GPU
# onto ``device`` so the checkpoints also load on CPU-only machines.
model_topic = TweetClassifier(base_model_topic, 8).to(device)
model_topic.load_state_dict(torch.load("best_model_topic.pt", map_location=device))
model_topic.eval()  # inference mode: disables dropout

model_sentiment = TweetClassifier(base_model_sentiment, 3).to(device)
model_sentiment.load_state_dict(torch.load("best_model_sentiment.pt", map_location=device))
model_sentiment.eval()  # inference mode: disables dropout

def predict_text(text, model, tokenizer, max_length=512):
    """Return the index of the highest-scoring class for ``text``.

    Args:
        text: Input passed straight to ``tokenizer``. ``.item()`` is applied to
            the argmax, so this must encode to a single sequence.
        model: Classifier called as ``model(input_ids=..., attention_mask=...)``
            and returning logits with classes along dim 1.
        tokenizer: Callable returning a mapping with ``"input_ids"`` and
            ``"attention_mask"`` tensors.
        max_length: Truncation length handed to the tokenizer.

    Returns:
        The predicted class index as a Python ``int``.
    """
    # Place inputs on the device holding the model's weights rather than
    # relying on a module-level ``device`` global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    encoded_input = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    # Forward only the arguments the classifier's forward() accepts; some
    # tokenizers also emit ``token_type_ids``, which would raise a TypeError.
    input_ids = encoded_input["input_ids"].to(model_device)
    attention_mask = encoded_input["attention_mask"].to(model_device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)

    return torch.argmax(logits, dim=1).item()

# Map integer predictions back to labels. Keys must cover every class index of
# the corresponding model (8 topic classes, 3 sentiment classes).
topic_labels = {
    0: "Karma and Shares",
    1: "Sleep and Rest",
    2: "Celebrations",
    3: "Haircare and Styling",
    4: "Weather",
    5: "Days of the Week",
    6: "Video Content",
    7: "Photography and Imagery",
}
sentiment_labels = {0: "Negative", 1: "Neutral", 2: "Positive"}

# Testing the function
test_text = "Thanks god it's TGIF."
topic_prediction = predict_text(test_text, model_topic, tokenizer)
sentiment_prediction = predict_text(test_text, model_sentiment, tokenizer)

print(f"Predicted Topic: {topic_labels[topic_prediction]}")
print(f"Predicted Sentiment: {sentiment_labels[sentiment_prediction]}")
