In [1]:

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import numpy as np
import math
import pickle

import pandas as pd

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import tqdm
import nltk
from transformers import get_linear_schedule_with_warmup



In [2]:
# Report CUDA visibility, then bind `device` to the GPU when PyTorch can see
# one and to the CPU otherwise.
print(torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


True
Using device: cuda


In [3]:
class SafetyDataset(Dataset):
    """Expose the rows of a table through the ``torch.utils.data.Dataset`` API.

    The wrapped object is stored as-is on ``self.df``; in this notebook it is
    the DataFrame returned by ``pd.read_csv``.
    """

    def __init__(self, df):
        """Keep a reference to ``df`` (no copy is made)."""
        self.df = df

    def __len__(self):
        """Number of items, i.e. ``len`` of the wrapped table."""
        n_rows = len(self.df)
        return n_rows

    def __getitem__(self, idx):
        """Return the item at *positional* index ``idx`` (``.iloc`` lookup)."""
        row = self.df.iloc[idx]
        return row

In [4]:
# NOTE(review): the first cell already imports from `transformers`, so on a fresh
# kernel this install must run before that cell; consider moving it to the top.
!pip install transformers



In [4]:
def transformer_collate_fn(batch, tokenizer):
  """Collate dataset items into tokenizer output plus one-hot float labels.

  Args:
    batch: Iterable of items indexable by ``'text'`` and ``'binary_label'``.
    tokenizer: Callable invoked on the list of texts with padding, truncation
      (``max_length=512``) and ``return_tensors="pt"``.

  Returns:
    ``(tokenizer_output, labels)`` where ``labels`` is a float tensor with one
    row per item and two columns (one-hot encoding of ``binary_label``).
  """
  sentences = [example['text'] for example in batch]
  raw_labels = [example['binary_label'] for example in batch]

  tokenizer_output = tokenizer(
      sentences,
      padding=True,
      truncation=True,
      max_length=512,
      return_tensors="pt",
  )

  # one_hot needs integer class ids; the result is cast to float afterwards.
  class_ids = torch.tensor(raw_labels).type(torch.LongTensor)
  one_hot = F.one_hot(class_ids, num_classes=2)
  labels = one_hot.type(torch.FloatTensor)

  return tokenizer_output, labels

# Finetuning

In [5]:
# Hyperparameters for the fine-tuning loop.
BATCH_SIZE = 16  # batch size given to every DataLoader
LR = 5e-5  # learning rate passed to the Adam optimizer
WEIGHT_DECAY = 0  # NOTE(review): not referenced by any code visible in this notebook
N_EPOCHS = 1  # passes over the training split; also sizes the LR schedule
CLIP = 1.0  # max gradient norm handed to train() for clip_grad_norm_

In [12]:
def evaluate(model,
             dataloader,
             device):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: Callable as ``model(**inputs, labels=labels)``; the result must
            expose a ``.loss`` tensor. Left in eval mode on return.
        dataloader: Sized iterable of batches where ``batch[0]`` is the
            keyword-argument mapping for the model and ``batch[1]`` the labels;
            both must support ``.to(device)``.
        device: Device the inputs and labels are moved to.

    Returns:
        Sum of the batch losses divided by ``len(dataloader)``. Each batch
        counts equally, regardless of how many examples it holds.

    Raises:
        ZeroDivisionError: If ``dataloader`` contains no batches.
    """
    model.eval()

    epoch_loss = 0.0
    with torch.no_grad():
        for batch in dataloader:
            inputs, labels = batch[0], batch[1]

            inputs = inputs.to(device)
            labels = labels.to(device)

            # The loss is the one the model itself computes from `labels`.
            output = model(**inputs, labels=labels)
            epoch_loss += output.loss.item()

    return epoch_loss / len(dataloader)



In [13]:
def evaluate_acc(model,
                 dataloader,
                 device):
    """Compute classification accuracy of ``model`` over ``dataloader``.

    Predicted and true classes are the argmax over dim 1 of the model's
    ``.logits`` and of the labels respectively.

    Args:
        model: Callable as ``model(**inputs, labels=labels)``; the result must
            expose ``.logits``. Left in eval mode on return.
        dataloader: Iterable of batches where ``batch[0]`` is the
            keyword-argument mapping for the model and ``batch[1]`` the labels
            (one row per example); both must support ``.to(device)``.
        device: Device the inputs and labels are moved to.

    Returns:
        ``total_correct / total`` as a 0-dim tensor (the count is accumulated
        as a tensor). Also prints the number of correct predictions.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    model.eval()

    total_correct = 0
    total = 0

    with torch.no_grad():
        for batch in dataloader:
            inputs, labels = batch[0], batch[1]

            inputs = inputs.to(device)
            labels = labels.to(device)

            output = model(**inputs, labels=labels)

            output_class = torch.argmax(output.logits, dim=1)
            true_class = torch.argmax(labels, dim=1)

            # Summing the boolean match mask counts the correct predictions.
            total_correct += (output_class == true_class).sum()
            total += labels.size()[0]

    print(f"Total Correct = {total_correct}")
    return total_correct / total

In [14]:
from tqdm.notebook import trange, tqdm
def train(model, dataloader, opt, device, clip: float, scheduler = None, epoch = None):
  """Run one training pass over ``dataloader`` and return the mean batch loss.

  Args:
    model: Callable as ``model(**inputs, labels=labels)``; the result must
      expose a ``.loss`` tensor. Put into train mode by this function.
    dataloader: Sized iterable of batches where ``batch[0]`` is the
      keyword-argument mapping for the model and ``batch[1]`` the labels;
      both must support ``.to(device)``.
    opt: Optimizer; ``zero_grad()`` and ``step()`` are called once per batch.
    device: Device the inputs and labels are moved to.
    clip: Maximum gradient norm passed to ``clip_grad_norm_``.
    scheduler: Optional LR scheduler, stepped once per batch after ``opt``.
    epoch: Optional zero-based epoch index, used only to label the progress
      bar. It is an explicit argument so the function no longer depends on a
      global ``epoch`` variable existing in the notebook.

  Returns:
    Sum of the batch losses divided by ``len(dataloader)``.

  Raises:
    ZeroDivisionError: If ``dataloader`` contains no batches.
  """
  model.train()

  desc = "Training" if epoch is None else f"Epoch {epoch + 1}"
  progress_bar = tqdm(dataloader, desc=desc, leave=True)

  epoch_loss = 0.0
  for batch in progress_bar:
      inputs, labels = batch[0], batch[1]

      inputs = inputs.to(device)
      labels = labels.to(device)

      opt.zero_grad()

      output = model(**inputs, labels=labels)
      loss = output.loss

      loss.backward()
      # Clip after backward and before the optimizer step.
      torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

      opt.step()

      if scheduler is not None:
        scheduler.step()

      batch_loss = loss.item()
      epoch_loss += batch_loss
      progress_bar.set_postfix({'loss': batch_loss})

  return epoch_loss / len(dataloader)

In [9]:
# One CSV per safety domain, read with pd.read_csv in the training loop.
# The order must match `domains`, because the loop pairs the two lists by index.
# NOTE(review): absolute /content/... paths — presumably a Colab upload location; confirm.
files = ["/content/discrimination.csv", "/content/hci_harms_df.csv", "/content/malicious_activity.csv", "/content/misinfo.csv"]

In [10]:
# Short domain names, index-aligned with `files`; used in the per-domain log
# line and as the prefix of each saved checkpoint file name.
domains = ["discrimination", "hci", "malicious_activity", "misinfo"]

In [None]:
from torch.utils.data import random_split
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification
from functools import partial
import time

# Fine-tune one fresh BERT sequence classifier per domain CSV, keep the
# checkpoint with the lowest validation loss, and evaluate it on the test split.
bert_model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(bert_model_name)

for i,file in enumerate(files):
  print(f"Domain: {domains[i]}")
  checkpoint_path = f"{domains[i]}_bert.pt"

  # A new pretrained model and optimizer per domain, so domains do not share weights.
  bert_model = AutoModelForSequenceClassification.from_pretrained(bert_model_name, num_labels = 2, output_hidden_states = True)
  optimizer = optim.Adam(bert_model.parameters(), lr=LR)

  bert_model = bert_model.to(device)

  # Make explicit that the encoder is fine-tuned, not frozen.
  for params in bert_model.base_model.parameters():
      params.requires_grad = True


  dataset = SafetyDataset(pd.read_csv(file))

  total_size = len(dataset)
  train_size = int(0.8 * total_size)  # 80% for training
  val_size = int(0.1 * total_size)    # 10% for validation
  test_size = total_size - train_size - val_size  # remainder, so the sizes always sum to total_size


  best_val_loss = float('inf')

  train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

  train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, \
                                collate_fn=partial(transformer_collate_fn, tokenizer=tokenizer))
  # Evaluation loaders are not shuffled: the reported loss is a mean of batch
  # means, so a fixed batch composition keeps it stable from run to run.
  val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE, \
                                collate_fn=partial(transformer_collate_fn, tokenizer=tokenizer))
  test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE, \
                                collate_fn=partial(transformer_collate_fn, tokenizer=tokenizer))

  # The scheduler is stepped once per batch inside train(), hence epochs * batches.
  scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=N_EPOCHS*len(train_dataloader))


  for epoch in range(N_EPOCHS):

    train_loss = train(bert_model, train_dataloader, optimizer, device, CLIP, scheduler)
    train_acc = evaluate_acc(bert_model, train_dataloader, device)
    print(f"Train accuracy: {train_acc}, train_loss: {train_loss}")

    valid_loss = evaluate(bert_model, val_dataloader, device)
    valid_acc = evaluate_acc(bert_model, val_dataloader, device)
    print(f"Valid accuracy: {valid_acc}, Valid_loss: {valid_loss}")

    # Checkpoint only when validation loss improves.
    if valid_loss<best_val_loss:
      best_val_loss = valid_loss
      torch.save(bert_model.state_dict(), checkpoint_path)

  # load_state_dict expects the state dict itself, not a path: read the saved
  # checkpoint with torch.load first, mapping tensors onto the active device.
  bert_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
  test_loss = evaluate(bert_model, test_dataloader, device)
  test_acc = evaluate_acc(bert_model, test_dataloader, device)
  print(f"Test accuracy: {test_acc}, Test_loss: {test_loss}")







