In [None]:
import torch
from torch import nn
from transformers import BertTokenizer, BertModel

In [None]:
class BERTBadWordClassifier(nn.Module):
    """Text classifier made of a pretrained BERT encoder and a linear head.

    BERT's pooled output goes through dropout and then a linear layer.
    ``forward`` returns raw (un-normalised) logits whose last dimension is
    ``num_classes``.
    """

    def __init__(self, bert_model_name, num_classes):
        """
        Args:
            bert_model_name: name or path handed to ``BertModel.from_pretrained``.
            num_classes: number of logits produced by the linear head.
        """
        super().__init__()
        # The attribute names below (bert / dropout / fc) become the state_dict
        # key prefixes. Renaming them would break loading saved checkpoints.
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.1)
        encoder_width = self.bert.config.hidden_size
        self.fc = nn.Linear(encoder_width, num_classes)

    def forward(self, input_ids, attention_mask):
        """Encode the batch with BERT and map its pooled output to class logits."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        regularised = self.dropout(encoded.pooler_output)
        return self.fc(regularised)

In [None]:
def predict_bad_word(text, model, tokenizer, device, max_length=128):
    """Classify a single string as "bad" or "nice".

    Args:
        text: the string to classify.
        model: module called as ``model(input_ids=..., attention_mask=...)``.
            It must return logits with classes along dim 1. It is switched to
            eval mode as a side effect.
        tokenizer: tokenizer whose call result supports ``.to(device)`` and
            indexing by ``'input_ids'`` / ``'attention_mask'``.
        device: device the encoded tensors are moved to.
        max_length: pad/truncate the encoding to exactly this many tokens.

    Returns:
        ``(label, confidence)``. ``label`` is ``"bad"`` when class index 1 wins
        and ``"nice"`` otherwise. ``confidence`` is the softmax probability of
        the winning class for the first row of the batch.
    """
    model.eval()
    batch = tokenizer(
        text,
        return_tensors='pt',
        max_length=max_length,
        padding='max_length',
        truncation=True,
    ).to(device)

    with torch.no_grad():
        logits = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
        class_probs = logits.softmax(dim=1)
        winner = class_probs.argmax(dim=1).item()

    label = "bad" if winner == 1 else "nice"
    confidence = class_probs[0][winner].item()
    return label, confidence


In [None]:
# Notebook configuration:
#   bert_model_name - pretrained checkpoint used for both tokenizer and encoder
#   num_classes     - size of the classifier head (predict_bad_word maps index 1 to "bad")
#   max_length      - token length intended for predict_bad_word's max_length argument
bert_model_name, num_classes, max_length = "bert-base-cased", 2, 128

In [None]:
# Choose the compute device: Apple MPS, then CUDA, then CPU.
# Use is_available() rather than is_built(). is_built() only says the PyTorch
# binary was compiled with MPS support, not that this machine can use it.
# Picking an unusable device would make the first `.to(device)` fail.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

tokenizer = BertTokenizer.from_pretrained(bert_model_name)
model = BERTBadWordClassifier(bert_model_name, num_classes).to(device)

# torch.load unpickles the file. weights_only=True limits it to tensors and
# plain containers, which is all a state_dict needs, so a tampered checkpoint
# cannot run arbitrary code. This is also the default from torch 2.6 onward.
classifier_state = torch.load('./api/model/classifier.pth', map_location=device, weights_only=True)
model.load_state_dict(classifier_state)

In [None]:
# Quick manual check of the loaded classifier on a few hand-picked strings.
test_texts = [
    "Fuck",
    "Motherfucker",
    "Hello",
    "cunt",
    "Stupid bitch",
    "OMG"
]

for test_text in test_texts:
    # Pass the tokenizer and the configured max_length explicitly, so the
    # config cell actually controls the token budget used at inference.
    prediction, score = predict_bad_word(
        test_text, model, tokenizer, device, max_length=max_length
    )
    print(test_text)
    print(f'Prediction: {prediction} {score}\n')