In [1]:
import os
import pandas as pd
import numpy as np
import json

In [2]:

# Set environment variables
#might wanna change this
os.environ["HF_HOME"] = "/home/hice1/kpereira6/scratch/ConvAI/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "/home/hice1/kpereira6/scratch/ConvAI/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/home/hice1/kpereira6/scratch/ConvAI/hf_cache"

# Verify the variables (optional)
print("HF_HOME:", os.environ.get("HF_HOME"))
print("HF_DATASETS_CACHE:", os.environ.get("HF_DATASETS_CACHE"))
print("TRANSFORMERS_CACHE:", os.environ.get("TRANSFORMERS_CACHE"))

HF_HOME: /home/hice1/kpereira6/scratch/ConvAI/hf_cache
HF_DATASETS_CACHE: /home/hice1/kpereira6/scratch/ConvAI/hf_cache
TRANSFORMERS_CACHE: /home/hice1/kpereira6/scratch/ConvAI/hf_cache


In [3]:
from torch.utils.data import Dataset
import torch
from transformers import BertTokenizer



In [4]:
import torch
import torch.nn as nn
from transformers import BertModel


In [5]:
#defining model arch

class MultiTurnToxicityModelLSTM(nn.Module):
    def __init__(self, bert_model_name='bert-base-uncased', hidden_dim=768, lstm_hidden_dim=512, num_layers=1, dropout_rate=0.3):
        super(MultiTurnToxicityModelLSTM, self).__init__()

        # Load pretrained BERT
        self.bert = BertModel.from_pretrained(bert_model_name)

        # Freeze BERT layers if you do not want to train them
        for param in self.bert.parameters():
            param.requires_grad = False

        # LSTM to model multi-turn context
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout_rate if num_layers > 1 else 0  # Apply dropout only if more than one layer
        )

        # Fully connected layer for toxicity classification
        self.fc = nn.Linear(lstm_hidden_dim, 1)

        # Sigmoid activation for binary classification
        self.sigmoid = nn.Sigmoid()

    def forward(self, tokenized_turns, attention_masks):
        turn_embeddings = []

        # Process each turn independently using BERT
        for i in range(len(tokenized_turns)):
            outputs = self.bert(input_ids=tokenized_turns[i], attention_mask=attention_masks[i])
            cls_embedding = outputs.last_hidden_state[:, 0, :]  # CLS token embedding for each turn
            turn_embeddings.append(cls_embedding)

        # Stack the turn embeddings (batch_size, num_turns, hidden_dim)
        turn_embeddings = torch.stack(turn_embeddings, dim=0)

        # Pass the embeddings through the LSTM
        lstm_output, (hidden, _) = self.lstm(turn_embeddings)  # hidden: (num_layers, batch_size, lstm_hidden_dim)

        # Use the hidden state of the last LSTM layer for classification
        final_hidden_state = hidden[-1]  # (batch_size, lstm_hidden_dim)

        # Final classification layer to predict toxicity of the last bot turn
        logits = self.fc(final_hidden_state)  # (batch_size, 1)
        return self.sigmoid(logits)  # Output probability of toxicity

In [15]:
import torch
from transformers import BertTokenizer

def classify_toxicity(conversation, model, tokenizer, max_turns=10, max_length=128, threshold=0.5):
    """
    Classify a conversation as toxic or non-toxic.

    Args:
        conversation (list of str): A list of alternating user and bot turns.
        model (nn.Module): The trained MultiTurnToxicityModelLSTM model.
        tokenizer (BertTokenizer): The BERT tokenizer.
        max_turns (int): Maximum number of turns to consider.
        max_length (int): Maximum token length per turn.
        threshold (float): Threshold for classifying as toxic (default 0.5).

    Returns:
        int: 0 for non-toxic, 1 for toxic.
    """
    # Ensure the conversation has exactly `max_turns`
    if len(conversation) < max_turns:
        conversation = [""] * (max_turns - len(conversation)) + conversation
    else:
        conversation = conversation[-max_turns:]

    # Tokenize each turn
    tokenized_turns = []
    attention_masks = []

    for turn in conversation:
        encoded = tokenizer(turn, padding='max_length', truncation=True, max_length=max_length, return_tensors="pt")
        tokenized_turns.append(encoded['input_ids'].squeeze(0))  # shape (max_length,)
        attention_masks.append(encoded['attention_mask'].squeeze(0))  # shape (max_length,)

    # Convert lists of tensors to tensors of shape (1, num_turns, max_length)
    tokenized_turns = torch.stack(tokenized_turns).unsqueeze(0)  # Add batch dimension
    attention_masks = torch.stack(attention_masks).unsqueeze(0)  # Add batch dimension

    # Forward pass through the model
    model.eval()  # Ensure the model is in evaluation mode
    with torch.no_grad():
        prediction = model(tokenized_turns, attention_masks)

    # Classify based on threshold
    toxicity_probability = prediction.item()
    return toxicity_probability 



In [19]:
# Example usage
if __name__ == "__main__":
    # Example conversation
    example_conversation = [
             "USER: If arrested blame it on the pokemon you were trying to catch while driving",
            "BOT: Not like driving makes them go faster or anything",
            "USER: Notice the quotations.hey Trump you play Pokemon go???I wish I had time. Next question.HEADLINE!!!! TRUMP WISHES HE HAD TIME TO PLAY POKEMON. HE LOVES IT",
            "BOT: You'll be fined for using that meme!",
            "I WILL BOMB YOU"
    ]

    # Load the tokenizer and model
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = MultiTurnToxicityModelLSTM(bert_model_name='bert-base-uncased')

    # Load the trained model weights
    #path to checkpoint
    checkpoint_path = "/home/hice1/kpereira6/scratch/ConvAI/Eval/best_model_lstm.pth" 
    model.load_state_dict(torch.load(checkpoint_path))

    # Classify the conversation
    result = classify_toxicity(example_conversation, model, tokenizer)
    print(f"Toxicity classification: {result} (0: Non-toxic, 1: Toxic)")

  model.load_state_dict(torch.load(checkpoint_path))


Toxicity classification: 0.343756765127182 (0: Non-toxic, 1: Toxic)
