In [11]:
import pandas as pd

df_muslim = pd.read_csv("data/toxicbias_train.csv")
df_muslim = df_muslim[df_muslim['rationale'] == 'prejudice against muslims']

In [42]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from src.Helpers import *
import torch

DEVICE = torch.device(
    "cuda") if torch.cuda.is_available() else torch.device("cpu")

import pandas as pd
import torch
import time


def standardise_results(results):
    label_mapping = {'contradiction': results['contradiction'],
                     'neutral': results['neutral'],
                     'entailment': results['entailment']}
    return label_mapping


def convert_probabilities(probabilities, label_mapping):
    probabilities_list = probabilities.tolist()[0]  # Convert the tensor to a list and extract the first (and only) batch
    return {name: round(float(pred) * 100, 1) for pred, name in zip(probabilities_list, label_mapping)}



def get_random_samples(csv_filename, num_samples):
    df = pd.read_csv(csv_filename)
    random_samples = df.sample(n=num_samples)
    return random_samples


def benchmark_test(premise, hypothesis, model, num_runs):
    # Setup device
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    results = []
    for _ in range(num_runs):
        start_time = time.time()
        _, _ = model(premise, hypothesis)
        execution_time = time.time() - start_time
        results.append(execution_time)

    avg_execution_time = sum(results) / len(results)
    return device, avg_execution_time


In [43]:
premises = df_muslim.sample(10)['comment_text'].values
premises = [str(p) for p in premises]

hypotheses = [
  'The use of violence is encouraged or desired against a religion.',
  'A religion should be discriminated against.',
  'Members of a religion should be excluded and marginalized.',
  'Members of a religion should remain in their respective countries.',
  'Members of a religion are compared to animals.',
  'A religion or members of a religion are compared to devils or demons.',
  'The integration of a religion should be avoided in a country or society.',
  'A religion or members of a religion are stupid.',
  'A religion is inferior when compared to another religion.',
  'Religious practices of a particular faith should be restricted or banned in public spaces.',
  'Religious attire or symbols should be prohibited.',
  'A religion promotes violence or intolerance.',
  'Members of a religion are a threat.',
  'Members of a religion are not trustworthy or loyal.'
 ]


In [51]:
def bart_nli_batched(premises, hypotheses, batch_size=32):
    model_name = 'facebook/bart-large-mnli'
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Convert scalar inputs to lists
    if not isinstance(premises, (list, tuple)):
        premises = [premises]
    if not isinstance(hypotheses, (list, tuple)):
        hypotheses = [hypotheses]

    num_premises = len(premises)
    num_hypotheses = len(hypotheses)
    results = []

    for i in range(num_premises):
        for j in range(num_hypotheses):
            # Tokenize the input pair
            inputs = tokenizer(premises[i], hypotheses[j], return_tensors='pt', padding=True, truncation=True)

            # Run the input through the model
            logits = model(**inputs.to(DEVICE)).logits

            # Get probabilities and labels for the input
            probabilities = torch.softmax(logits, dim=1)
            label_mapping = ['contradiction', 'neutral', 'entailment']
            label = label_mapping[probabilities.argmax(dim=1)]

            probabilities = convert_probabilities(probabilities, label_mapping)

            results.append(label,standardise_results(probabilities))

    return results

    


In [50]:

results = bart_nli_batched(premises, hypotheses)
print(results)

[['entailment', {'contradiction': 16.5, 'neutral': 23.1, 'entailment': 60.3}], ['contradiction', {'contradiction': 51.4, 'neutral': 23.9, 'entailment': 24.7}], ['entailment', {'contradiction': 4.4, 'neutral': 30.2, 'entailment': 65.3}], ['contradiction', {'contradiction': 92.4, 'neutral': 6.3, 'entailment': 1.3}], ['entailment', {'contradiction': 1.0, 'neutral': 9.1, 'entailment': 89.9}], ['entailment', {'contradiction': 20.9, 'neutral': 28.0, 'entailment': 51.1}], ['neutral', {'contradiction': 7.6, 'neutral': 56.7, 'entailment': 35.7}], ['neutral', {'contradiction': 40.8, 'neutral': 47.6, 'entailment': 11.5}], ['entailment', {'contradiction': 27.4, 'neutral': 29.8, 'entailment': 42.8}], ['entailment', {'contradiction': 0.2, 'neutral': 6.2, 'entailment': 93.6}], ['neutral', {'contradiction': 0.3, 'neutral': 91.1, 'entailment': 8.6}], ['entailment', {'contradiction': 0.5, 'neutral': 21.9, 'entailment': 77.6}], ['entailment', {'contradiction': 0.5, 'neutral': 18.9, 'entailment': 80.6}], 