In [1]:
import torch
from transformers import DistilBertTokenizer
import torch.nn as nn
from transformers import DistilBertModel
from utilities.toxic_comment_classifier import ToxicCommentClassifier
import warnings
warnings.filterwarnings('ignore')

## Loading Model

In [14]:
# Load the model
# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = ToxicCommentClassifier().to(device)
# Load the saved state_dict.
# - map_location remaps tensors that were saved on another device onto `device`.
# - weights_only=True restricts unpickling to tensors and plain containers, so a
#   tampered checkpoint file cannot run arbitrary code while being loaded.
# The call is kept as the cell's last expression so the key-matching report
# returned by load_state_dict is still displayed.
model.load_state_dict(
    torch.load(
        "./final_models/best_model.pth",
        map_location=device,
        weights_only=True,
    )
)


cpu


<All keys matched successfully>

In [15]:
# Set model to evaluation mode
# eval() switches the module (and its sub-modules, e.g. the Dropout layers) out
# of training mode for inference. It returns the module itself, which is why
# this cell displays the model architecture as its output.
model.eval()

ToxicCommentClassifier(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (

In [6]:
# Initialize the tokenizer
# Loads the pretrained 'distilbert-base-uncased' tokenizer; predict_toxicity
# reads this notebook-level `tokenizer` variable.
# NOTE(review): presumably the same tokenizer the classifier was trained with --
# confirm against the training code.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

## Prediction Function
This function tokenizes the input text, passes it through the model, and maps the highest-scoring output logit to its class label (`Hate`, `Offensive`, or `Neutral`).

In [16]:
# Function for making predictions
def predict_toxicity(comment):
    """Return the predicted label for ``comment`` as a string.

    The text is encoded with the notebook-level ``tokenizer``
    (``padding='max_length'``, ``max_length=128``, ``truncation=True``), the
    resulting ``input_ids`` and ``attention_mask`` tensors are moved to
    ``device`` and passed to the notebook-level ``model`` with gradient
    tracking disabled.

    Args:
        comment: Text to classify; handed to ``tokenizer`` unchanged.

    Returns:
        ``'Hate'`` when the index of the largest model output along dim 1 is 0,
        ``'Offensive'`` when it is 1, and ``'Neutral'`` for any other index.
    """
    encoded = tokenizer(
        comment,
        padding='max_length',
        max_length=128,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(input_ids, attention_mask)
        # torch.max over dim=1 yields (values, indices); only the index matters.
        class_index = torch.max(output, dim=1).indices.item()

    # Indices 0 and 1 have dedicated labels; everything else is 'Neutral'.
    return {0: 'Hate', 1: 'Offensive'}.get(class_index, 'Neutral')

In [23]:
# Run the classifier on a single example comment and print the returned label.
# NOTE: the sample text is offensive content; it is here only as classifier input.
comment = "I hate those black people"
print("Prediction:", predict_toxicity(comment))

Prediction: Hate
