In [None]:
import json
import os
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DistilBertModel
from typing import List
from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from google.colab import drive


In [None]:
@dataclass
class HateSpeechExample:
  """A single text sample with a binary label plus its rule/exemplar context."""
  text: str        # raw input text
  label: int       # 0 when the source label is the string "normal", 1 otherwise
  rules: list      # rules attached to the sample; empty when built by from_list
  exemplars: str   # exemplar text attached to the sample; empty when built by from_list

  @staticmethod
  def from_list(sample):
    """Build an example from a raw record whose first two items are (text, label).

    Only the exact label value "normal" maps to class 0; every other value
    maps to class 1. Rules and exemplars start out empty.
    """
    raw_text = sample[0]
    raw_label = sample[1]
    binary_label = 0 if raw_label == "normal" else 1
    return HateSpeechExample(text=raw_text, label=binary_label, rules=[], exemplars="")

In [None]:
import random
class HateSpeechDataset(Dataset):
  """Map-style dataset that pairs each text sample with rule exemplars.

  Every access through ``__getitem__`` (and iteration) recomputes which rules
  occur in the sample text and attaches the matching exemplar text. When no
  rule matches, the exemplars of one randomly chosen rule are attached instead,
  so the result for such samples can differ between accesses.
  """
  # Class-level on purpose: the static collate_fn has no instance to read from.
  # The most recently constructed dataset decides which tokenizer is used.
  tokenizer = None

  def __init__(self, raw_data_list, tokenizer, rules, exemplars):
    """
    :param raw_data_list: Iterable of raw records accepted by HateSpeechExample.from_list.
    :param tokenizer: Tokenizer callable stored on the class for collate_fn.
    :param rules: Iterable of rule strings searched for as substrings of each text.
    :param exemplars: Mapping from rule to its exemplars (assumed to be a list of
                      strings per rule, since they are joined with spaces).
    """
    HateSpeechDataset.tokenizer = tokenizer
    self.data = [HateSpeechExample.from_list(sample) for sample in raw_data_list]
    self.rules = rules
    self.exemplars = exemplars

  def __len__(self):
    return len(self.data)

  def find_rules(self, idx):
    """Return ``(applied_rules, exemplars)`` for the sample at ``idx``.

    ``applied_rules`` lists every rule that occurs as a substring of the text.
    ``exemplars`` is the space-joined exemplar text of those rules (each rule's
    chunk is prefixed with a single space). If no rule matches, the exemplars of
    one randomly chosen rule are used; if there are no exemplars at all, the
    result is the empty string.
    """
    text = self.data[idx].text
    applied_rules = []
    exemplars = ""
    for rule in self.rules:
      if rule in text:
        applied_rules.append(rule)
        rule_exemplars = self.exemplars.get(rule)
        if rule_exemplars:
          exemplars += " " + " ".join(rule_exemplars)

    # The fallback must run once, after every rule has been checked. Doing it
    # inside the loop re-sampled on each non-matching rule and let random text
    # leak in front of genuine matches found later.
    if not applied_rules:
      pools = [pool for pool in self.exemplars.values() if pool]
      if pools:
        # Join whole exemplar strings; do not flatten them into characters.
        exemplars = " ".join(random.choice(pools))

    return applied_rules, exemplars

  def __getitem__(self, idx):
    example = self.data[idx]
    example.rules, example.exemplars = self.find_rules(idx)
    return example

  def __iter__(self):
    # Go through __getitem__ so iterated samples carry rules and exemplars too.
    return (self[i] for i in range(len(self)))

  @staticmethod
  def collate_fn(samples: List["HateSpeechExample"]):
    """Tokenize a batch of examples into text/exemplar encodings and a label tensor.

    Both encodings are padded or truncated to 512 tokens and returned as
    PyTorch tensors; labels are returned as a ``torch.long`` tensor.
    """
    texts = [sample.text for sample in samples]
    labels = [sample.label for sample in samples]
    exemplars = [sample.exemplars for sample in samples]

    text_encoding = HateSpeechDataset.tokenizer(texts,
                                   padding='max_length',
                                   max_length=512,
                                   truncation=True,
                                   return_tensors="pt")
    exemplar_encoding = HateSpeechDataset.tokenizer(exemplars,
                                   padding='max_length',
                                   max_length=512,
                                   truncation=True,
                                   return_tensors="pt")

    return {'text_encoding': text_encoding, 'exemplar_encoding': exemplar_encoding, 'labels' : torch.tensor(labels, dtype=torch.long)}

In [None]:
# Paths of the JSON files under the CAD, hatexplain and jigsaw directories.
all_rules_path = [
  "./CAD/cad_rules.json",
  "./hatexplain/hatexplain_rules.json",
  "./jigsaw/hate_abuse_sample.json",
]

In [None]:
def get_dataset(file, tokenizer,  rules, exemplars):
  """Load a JSON file of raw records and wrap it in a HateSpeechDataset.

  :param file: Path to a JSON file whose content is passed to HateSpeechDataset as raw data.
  :param tokenizer: Tokenizer forwarded to the dataset.
  :param rules: Rules forwarded to the dataset.
  :param exemplars: Rule-to-exemplar mapping forwarded to the dataset.
  :return: The constructed HateSpeechDataset.
  """
  # Read as UTF-8 explicitly: open() otherwise uses the platform's locale
  # encoding, which can fail or corrupt non-ASCII text.
  with open(file, 'r', encoding='utf-8') as f:
    data = json.load(f)
  return HateSpeechDataset(data, tokenizer, rules, exemplars)

def get_rules(rules_path = ['hatexplain_rules.json'], exemplar_path = 'rule_to_exemplar.json'):
  """Load and concatenate rule files, and load the rule-to-exemplar mapping.

  :param rules_path: A path or an iterable of paths to JSON files; the content of
                     each file is appended, in order, to one combined rule list.
                     (The default list is never mutated.)
  :param exemplar_path: Path to the JSON file holding the exemplars.
  :return: Tuple ``(rules, exemplars)``.
  """
  # A bare path would otherwise be iterated character by character.
  if isinstance(rules_path, (str, os.PathLike)):
    rules_path = [rules_path]

  rules = []
  for rule_set in rules_path:
    # Explicit UTF-8 so non-ASCII rules load the same on every platform.
    with open(rule_set, 'r', encoding='utf-8') as f:
      rules.extend(json.load(f))

  with open(exemplar_path, 'r', encoding='utf-8') as f:
    exemplars = json.load(f)

  return rules, exemplars

In [None]:
def initialize_datasets(tokenizer,  rules, exemplars):
  """Build the test, train and val datasets from '<split>.json' files.

  Each split is loaded with get_dataset using the same tokenizer, rules and
  exemplars. Returns a dict keyed by split name, in the order test, train, val.
  """
  return {
    split: get_dataset(f'{split}.json', tokenizer, rules, exemplars)
    for split in ('test', 'train', 'val')
  }

In [None]:
# Notebook-level tokenizer for the pretrained checkpoint named below.
tokenizer = AutoTokenizer.from_pretrained("Geotrend/distilbert-base-zh-cased")

# Load the rule list and the rule-to-exemplar mapping, then build all three splits.
rules, exemplars = get_rules(
  rules_path=['hatexplain_rules.json'],
  exemplar_path='rule_to_exemplar.json',
)
datasets = initialize_datasets(tokenizer=tokenizer, rules=rules, exemplars=exemplars)



### Similarity calculation and loss

In [None]:
def sim(x_e, x_t):
  """Cosine similarity between exemplar and text embeddings.

  The comparison runs along dimension 1 with eps=1e-6 guarding against
  division by zero, so that dimension is reduced away in the result.
  """
  return nn.functional.cosine_similarity(x_e, x_t, dim=1, eps=1e-6)

def get_loss(correct_labels, distance, margin=1e-8):
  """Contrastive loss averaged over all elements.

  Per element the loss is ``0.5 * (y * d**2 + (1 - y) * max(margin - d, 0)**2)``
  where ``y`` is the label cast to float32 and ``d`` is ``distance``.

  NOTE(review): with the default margin of 1e-8 the ``(1 - y)`` term is zero for
  any distance of at least 1e-8 — confirm this default is intended.
  """
  y = correct_labels.float()
  # Hinge part: only distances below the margin contribute.
  hinge = torch.clamp(margin - distance, min=0)
  positive_term = y * distance.pow(2)
  negative_term = (1 - y) * hinge.pow(2)
  return (0.5 * (positive_term + negative_term)).mean()

In [None]:
# Training loop for the exemplar/text dual-encoder model.
import torch.nn as nn
from torch.optim import Optimizer
from tqdm import tqdm

def train_one_epoch(exemplar_encoder: nn.Module, text_encoder: nn.Module, dataloader: DataLoader, loss, exemplar_optimizer: Optimizer, text_optimizer: Optimizer, epoch: int, k = 0.5):
    """
    Train the exemplar and text encoders for one epoch.

    Each batch is encoded by both encoders, mean-pooled over the sequence
    dimension, compared with cosine similarity, and optimised on the loss of
    the resulting cosine distance.

    :param exemplar_encoder: Encoder applied to ``batch['exemplar_encoding']``; must expose ``.device``
                             and return an object with ``last_hidden_state``.
    :param text_encoder: Encoder applied to ``batch['text_encoding']``; same requirements.
    :param dataloader: Yields dicts with 'text_encoding', 'exemplar_encoding' and 'labels'.
    :param loss: Callable ``loss(labels, distance)`` returning a scalar tensor. A non-callable
                 value falls back to ``get_loss`` (this argument was previously ignored).
    :param exemplar_optimizer: Optimizer over the exemplar encoder's parameters.
    :param text_optimizer: Optimizer over the text encoder's parameters.
    :param epoch: Current epoch number, used only for the progress-bar label.
    :param k: Unused during training; kept so existing call sites keep working.
    :return: None.
    """
    loss_fn = loss if callable(loss) else get_loss

    exemplar_encoder.train()
    text_encoder.train()

    with tqdm(dataloader, desc=f"Train Ep {epoch}", total=len(dataloader)) as tq:
        for batch in tq:
            text_encoding = batch['text_encoding'].to(text_encoder.device)
            label_encoding = batch['labels'].to(text_encoder.device)
            exemplar_encoding = batch['exemplar_encoding'].to(exemplar_encoder.device)

            exemplar_output = exemplar_encoder(**exemplar_encoding).last_hidden_state
            text_output = text_encoder(**text_encoding).last_hidden_state

            # NOTE(review): this is a plain mean over every position, padding
            # included — consider an attention-mask-weighted mean.
            t_pooled_output = torch.mean(text_output, dim=1)
            e_pooled_output = torch.mean(exemplar_output, dim=1)

            similarity = sim(e_pooled_output, t_pooled_output)

            # The loss expects a distance. Passing the raw similarity would push
            # label-1 pairs towards similarity 0, contradicting the
            # `similarity >= k` prediction rule.
            distance = 1.0 - similarity
            batch_loss = loss_fn(label_encoding, distance)

            # One backward pass fills the gradients of both encoders. A second
            # backward after the first optimizer step is invalid, because that
            # step modifies parameters the graph still needs.
            exemplar_optimizer.zero_grad()
            text_optimizer.zero_grad()
            batch_loss.backward()
            exemplar_optimizer.step()
            text_optimizer.step()

            tq.set_postfix({"loss": batch_loss.item()}) # for printing better-looking progress bar

In [None]:
def evaluate(exemplar_encoder: nn.Module, text_encoder: nn.Module, dataloader: DataLoader, loss, k = 0.5) -> dict:
    """
    Evaluate both encoders on a dataloader and report accuracy and mean loss.

    A sample is predicted as label 1 when the cosine similarity between its
    pooled exemplar and text embeddings is at least ``k``, otherwise 0.

    :param exemplar_encoder: Encoder applied to ``batch['exemplar_encoding']``; must expose ``.device``.
    :param text_encoder: Encoder applied to ``batch['text_encoding']``; must expose ``.device``.
    :param dataloader: Yields dicts with 'text_encoding', 'exemplar_encoding' and 'labels'.
    :param loss: Callable ``loss(labels, distance)`` returning a scalar tensor. A non-callable
                 value falls back to ``get_loss`` (this argument was previously ignored).
    :param k: Similarity threshold for predicting label 1.
    :return: Dict with 'accuracy' (float) and 'loss' (unweighted mean of the per-batch
             losses; NaN when the dataloader yields no batches).
    """
    loss_fn = loss if callable(loss) else get_loss

    exemplar_encoder.eval()
    text_encoder.eval()

    prediction_chunks = []
    label_chunks = []
    batch_losses = []
    with torch.no_grad(), tqdm(dataloader, desc=f"Eval", total=len(dataloader)) as tq:
        for batch in tq:
            text_encoding = batch['text_encoding'].to(text_encoder.device)
            label_encoding = batch['labels'].to(text_encoder.device)
            exemplar_encoding = batch['exemplar_encoding'].to(exemplar_encoder.device)
            exemplar_output = exemplar_encoder(**exemplar_encoding).last_hidden_state
            text_output = text_encoder(**text_encoding).last_hidden_state

            t_pooled_output = torch.mean(text_output, dim=1)
            e_pooled_output = torch.mean(exemplar_output, dim=1)

            similarity = sim(e_pooled_output, t_pooled_output)
            predicted_labels = torch.where(similarity >= k, 1, 0)
            # The loss is defined on a distance, so convert the similarity first.
            batch_loss = loss_fn(label_encoding, 1.0 - similarity)

            # Keep whole batches on the CPU rather than extending a Python list
            # with 0-d (possibly GPU) tensors one element at a time.
            prediction_chunks.append(predicted_labels.cpu())
            label_chunks.append(label_encoding.cpu())
            batch_losses.append(batch_loss.item())

    if prediction_chunks:
        all_predictions = torch.cat(prediction_chunks)
        all_labels = torch.cat(label_chunks)
    else:
        all_predictions = torch.empty(0)
        all_labels = torch.empty(0)

    accuracy = compute_accuracy(all_predictions, all_labels)
    final_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
    print(f"Accuracy: {accuracy}")
    return {
        "accuracy": accuracy,
        "loss": final_loss,
    }


def compute_accuracy(predictions: torch.Tensor, labels: torch.Tensor) -> float:
    """
    Fraction of positions at which predictions and labels are equal.
    :param predictions: torch.Tensor of size (N,)
    :param labels: torch.Tensor of size (N,)
    :return: The accuracy as a Python float
    """
    assert predictions.size(-1) == labels.size(-1)

    # Cast the boolean match mask to the default float dtype before averaging.
    matches = (predictions == labels).to(torch.get_default_dtype())
    return matches.mean().item()

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel

# Seed torch's random number generator with a fixed value.
torch.manual_seed(64)

def main():
    """
    Train the text and exemplar encoders and track accuracy per epoch.

    Loads the tokenizer, two copies of the pretrained model, the rules/exemplars
    and the datasets. Then, for each epoch, it trains once and evaluates on the
    train and validation loaders. Both encoders are saved under ./checkpoints
    whenever validation accuracy improves on the best seen so far.

    :return: Tuple ``(train_acc_history, val_acc_history)`` of per-epoch accuracies.
    """
    # hyper-parameters (we provide initial set of values here, but you can modify them.)
    batch_size = 16
    learning_rate = 5e-5
    num_epochs = 10
    model_name = "Geotrend/distilbert-base-zh-cased"

    # Fall back to the CPU when no GPU is present instead of failing on .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    text_encoder = AutoModel.from_pretrained(model_name).to(device)
    exemplar_encoder = AutoModel.from_pretrained(model_name).to(device)

    text_optimizer = torch.optim.AdamW(params=text_encoder.parameters(), lr=learning_rate, eps=1e-8)
    exemplar_optimizer = torch.optim.AdamW(params=exemplar_encoder.parameters(), lr=learning_rate, eps=1e-8)

    rules, exemplars = get_rules(['hatexplain_rules.json'],'rule_to_exemplar.json')

    datasets = initialize_datasets(tokenizer, rules, exemplars)

    train_dataloader = DataLoader(datasets['train'],
                                   batch_size=batch_size,
                                   shuffle=True,
                                   collate_fn=HateSpeechDataset.collate_fn,
                                   num_workers=2)

    validation_dataloader = DataLoader(datasets['val'],
                                   batch_size=batch_size,
                                   shuffle=False,
                                   collate_fn=HateSpeechDataset.collate_fn,
                                   num_workers=2)

    # Loss histories are collected alongside accuracy but are not returned.
    train_acc_history, val_acc_history = [], []
    train_loss_history, val_loss_history = [], []

    # Make sure the save target exists even if no earlier cell created it.
    os.makedirs('./checkpoints', exist_ok=True)

    best_acc = 0.0
    for epoch in range(1, num_epochs + 1):
        train_one_epoch(exemplar_encoder, text_encoder, train_dataloader, get_loss , exemplar_optimizer, text_optimizer, epoch, k = 0.5)
        train_metrics = evaluate(exemplar_encoder, text_encoder, train_dataloader, get_loss, k = 0.5)
        val_metrics = evaluate(exemplar_encoder, text_encoder, validation_dataloader,get_loss, k = 0.5)

        train_acc_history.append(train_metrics['accuracy'])
        val_acc_history.append(val_metrics['accuracy'])
        train_loss_history.append(train_metrics['loss'])
        val_loss_history.append(val_metrics['loss'])

        # evaluate() returns a dict; compare the accuracy value, not the dict.
        if val_metrics['accuracy'] > best_acc:
            torch.save(exemplar_encoder, './checkpoints/best_exemplar_encoder.pth')
            torch.save(text_encoder, './checkpoints/best_text_encoder.pth')
            best_acc = val_metrics['accuracy']

    return train_acc_history, val_acc_history

In [None]:
# Create the checkpoint directory that main() saves the encoders into.
# Attempt the creation and handle "already there", rather than checking first:
# a check-then-create sequence can fail if the path appears in between.
try:
    os.makedirs('./checkpoints')
except FileExistsError:
    print("Directory already exists")
else:
    print("Directory created")

In [None]:
# NOTE(review): anomaly detection is a debugging aid that adds overhead to
# autograd — consider turning it off for real training runs.
torch.autograd.set_detect_anomaly(True)
# Run the full training and evaluation loop.
main()