In [4]:
import os
import torch
from transformers import BertTokenizerFast
import torch.nn as nn
from transformers import BertModel

# Define the model architecture
class CustomBertClassifier(nn.Module):
    def __init__(self, num_labels):
        super(CustomBertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.fc1 = nn.Linear(768, 512)
        self.fc2 = nn.Linear(512, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        x = self.fc1(pooled_output)
        x = torch.relu(x)
        x = self.fc2(x)
        return x


# Specify the saved directory
save_directory = "model_directory"

# Load the tokenizer
tokenizer = BertTokenizerFast.from_pretrained(save_directory)

# Initialize the model
num_labels = 2  # Modify this based on your use case
model = CustomBertClassifier(num_labels)

# Load the model weights
model_path = os.path.join(save_directory, "model_weights.pth")
model.load_state_dict(torch.load(model_path))
model.eval()  # Set the model to evaluation mode

# Function for prediction
def predict(text, model, tokenizer):
    """
    Predict the class probabilities for the given input text.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128
    )
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        probabilities = torch.softmax(logits, dim=1).squeeze().tolist()

    return probabilities


# User input and prediction loop
if __name__ == "__main__":
    print("=== Text Classification ===")
    while True:
        try:
            user_input = input("Enter text to classify (or type 'exit' to quit): ").strip()
            if not user_input:
                print("Input cannot be empty. Please enter some text.")
                continue
            if user_input.lower() == "exit":
                print("Exiting. Goodbye!")
                break

            predictions = predict(user_input, model, tokenizer)
            predicted_class = predictions.index(max(predictions))

            print(f"Predictions (class probabilities): {predictions}")
            print(f"Classified as: {predicted_class}")
        except Exception as e:
            print(f"An error occurred: {e}")


OSError: model_directory is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`