In [1]:
import os
import pickle as pkl

import numpy as np
import pandas as pd
import torch
import torch.utils.data as data
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
from tqdm.notebook import trange, tqdm

from config import *

In [3]:
# Report CUDA availability, then bind DEVICE to the GPU when there is one
# and to the CPU otherwise.
print(torch.cuda.is_available())
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)


True
Using device: cuda


In [4]:
# Directory holding the pickled activation files. File names are appended to
# this string directly, so the trailing slash is required.
DATA_DRIVE = "./BERTembeddings/"
# Directory the runner writes its results pickle into.
RESULTS_DRIVE = "./results/"

# Dataset and Dataloader

In [5]:
class ModelActivations(data.Dataset):
  """Dataset of pre-computed activation rows paired with mapped labels.

  Every file named in ``file_name_lst`` is read from ``DATA_DRIVE`` and must
  unpickle to a dict with an ``"activations"`` entry (array-like, indexed by
  example along axis 0) and a ``"labels"`` entry (iterable of keys of
  ``index_class_mapping``). Rows are stacked in the order the files are listed.

  Attributes:
    activations: NumPy array with the stacked activation rows.
    labels: NumPy array with one mapped label per raw label read.
  """

  def __init__(self, file_name_lst):
    """Load and stack the activations and labels of the listed files.

    Args:
      file_name_lst: File names, appended directly to ``DATA_DRIVE``.
    """
    activation_chunks = []
    labels_lst = []

    for file_name in file_name_lst:
      # NOTE(security): unpickling can execute arbitrary code, so DATA_DRIVE
      # must only contain files from a trusted source.
      with open(DATA_DRIVE+file_name, "rb") as f:
        inp_dict = pkl.load(f)

      activation_chunks.append(np.asarray(inp_dict["activations"]))
      labels_lst.extend(inp_dict["labels"])

    # Stack once after the loop. Concatenating inside the loop re-copies every
    # previously loaded row for each additional file.
    non_empty_chunks = [chunk for chunk in activation_chunks if len(chunk) > 0]
    if len(non_empty_chunks) > 1:
      layer_arr = np.concatenate(non_empty_chunks, axis=0)
    elif non_empty_chunks:
      layer_arr = non_empty_chunks[0]
    elif activation_chunks:
      # Every file was empty: keep the last (empty) array as-is.
      layer_arr = activation_chunks[-1]
    else:
      layer_arr = np.array([])

    # Map the raw labels through index_class_mapping (comes from config;
    # expected to yield 1 or 0 -- TODO confirm against config).
    binary_labels_lst = [index_class_mapping[x] for x in labels_lst]

    self.activations = layer_arr
    self.labels = np.asarray(binary_labels_lst)

  def __len__(self):
    """Return the number of activation rows."""
    return len(self.activations)

  def __getitem__(self, idx):
    """Return ``(activation, label)`` for ``idx`` as float32 tensors.

    The activation gets a new leading axis of size 1; the label is returned
    with shape ``(1,)``.
    """
    activation = np.expand_dims(self.activations[idx], axis=0)
    label = np.expand_dims(self.labels[idx], 0)

    return torch.tensor(activation, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)



# Binary Classifier Model

In [11]:
import torch.nn as nn

class SafetyClassifier(nn.Module):
    """Linear probe: one linear layer followed by a sigmoid.

    Args:
        input_dim: Size of the last dimension of the input features.
            Defaults to 768, the width that used to be hard-coded.
    """

    def __init__(self, input_dim=768):
        super().__init__()
        self.output = nn.Linear(input_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, batch):
        """Score a ``(features, labels)`` batch.

        Args:
            batch: Iterable of exactly two tensors, ``(features, labels)``.

        Returns:
            Tuple ``(scores, labels)``. ``scores`` is the sigmoid of the linear
            layer applied to ``features`` (last dimension of size 1);
            ``labels`` is the input label tensor after the device move.
        """
        # Follow the layer's own device instead of a notebook-level global, so
        # inputs and weights can never end up on different devices.
        device = self.output.weight.device
        x, labels = (tensor.to(device) for tensor in batch)
        return self.sigmoid(self.output(x)), labels

In [7]:
def validate(model, data_loader, criterion):
  """Run ``model`` over ``data_loader`` without gradients and collect its outputs.

  Args:
    model: Module whose call returns an ``(output, target)`` pair of tensors
      for a batch. It is switched to eval mode and left in eval mode.
    data_loader: Sized iterable of batches accepted by ``model``.
    criterion: Unused. The reported loss is always ``nn.BCELoss``; the
      parameter is kept so existing call sites keep working.

  Returns:
    Dict with three entries:
      'Validation Loss': sum over batches of the per-batch BCE loss. This is
        a total, not an average, so it grows with the number of batches.
      'prediction': flat list of the model outputs.
      'label': flat list of the matching targets.
  """
  bce_loss = nn.BCELoss()

  true_lst = []
  predictions_lst = []
  val_loss = 0.0

  model.eval()
  with torch.no_grad():
    with tqdm(data_loader, unit="batch", total=len(data_loader)) as batch_iterator:
      for i, batch_data in enumerate(batch_iterator, start=1):
        # Call the module itself rather than .forward() so that any
        # registered hooks run.
        output, target = model(batch_data)

        output = output.flatten()
        target = target.flatten()

        batch_loss = bce_loss(output, target).item()
        val_loss += batch_loss

        # Keep predictions and targets paired, in iteration order.
        true_lst.extend(target.tolist())
        predictions_lst.extend(output.tolist())

        batch_iterator.set_postfix(mean_loss=val_loss / i, current_loss=batch_loss, total_loss=val_loss)

  return {'Validation Loss': val_loss, 'prediction': predictions_lst, 'label': true_lst}

In [12]:
def training(model, train_dataloader, val_dataloader, num_epochs, criterion, optimizer, file_name=None, patience=3):
  """Train ``model`` with early stopping on the validation loss.

  After every epoch the model is evaluated with ``validate`` on
  ``val_dataloader``. Training stops once the validation loss has failed to
  improve for ``patience`` consecutive epochs, or after ``num_epochs`` epochs.

  Args:
    model: Module whose call returns ``(output, target)`` for a batch;
      ``output`` must have a size-1 axis at position 2, which is squeezed away
      before the loss is computed.
    train_dataloader: Sized iterable of training batches.
    val_dataloader: Sized iterable of validation batches.
    num_epochs: Maximum number of epochs.
    criterion: Loss called as ``criterion(output, target)``.
    optimizer: Optimizer that updates ``model``'s parameters.
    file_name: Unused; kept for backward compatibility.
    patience: Number of consecutive non-improving epochs tolerated before
      stopping. Defaults to 3.

  Returns:
    The same ``model`` object, holding the weights of the last epoch that ran
    (not necessarily the epoch with the best validation loss).
  """
  best_val_loss = None
  epochs_without_improvement = 0

  for epoch in trange(num_epochs, desc="training", unit="epoch"):
    # validate() leaves the model in eval mode, so train mode has to be
    # re-enabled at the start of every epoch, not just once before the loop.
    model.train()

    total_loss = 0.0
    with tqdm(train_dataloader, desc="epoch {}".format(epoch + 1), unit="batch", total=len(train_dataloader)) as batch_iterator:
      for i, batch_data in enumerate(batch_iterator, start=1):
        optimizer.zero_grad()

        output, target = model(batch_data)
        output = torch.squeeze(output, dim=2)

        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        batch_iterator.set_postfix(mean_loss=total_loss / i, current_loss=loss.item(), total_loss=total_loss)

    val_loss = validate(model, val_dataloader, criterion)['Validation Loss']

    if best_val_loss is None or val_loss < best_val_loss:
      best_val_loss = val_loss
      epochs_without_improvement = 0
    else:
      epochs_without_improvement += 1

    if epochs_without_improvement >= patience:
      break

  print("Completed Training...")
  return model

# Runner

In [13]:
# Leave-one-domain-out evaluation: for every domain except the skipped one,
# train a fresh classifier on the files of all other domains, evaluate it on
# the held-out domain's files, and pickle the per-domain results.
filename = "bert"
config_results = {}

SEED = 0                                # fixes the split, shuffling and weight init
SKIPPED_DOMAIN = "Information Hazards"  # never used as the held-out test domain
TRAIN_FRACTION = 0.9                    # remaining examples form the validation split
TRAIN_BATCH_SIZE = 32
TEST_BATCH_SIZE = 16
LEARNING_RATE = 0.0005
NUM_EPOCHS = 10

# Seed PyTorch's global RNG so a Restart & Run All reproduces the same numbers.
torch.manual_seed(SEED)

for test_domain in DOMAIN_INDEX_MAPPING:
  if test_domain == SKIPPED_DOMAIN:
    continue

  print(f"Test domain is {test_domain}")

  # Files of the held-out domain are the test set.
  test_domain_index = DOMAIN_INDEX_MAPPING[test_domain]
  test_file_lst = DOMAIN_FILE_MAPPING[test_domain_index]

  # Files of every other domain are pooled for training/validation.
  train_file_lst = []
  for key, domain_files in DOMAIN_FILE_MAPPING.items():
    if key != test_domain_index:
      train_file_lst.extend(domain_files)

  dataset = ModelActivations(file_name_lst=train_file_lst)
  n = len(dataset)
  n_train = int(TRAIN_FRACTION * n)
  train_dataset, val_dataset = torch.utils.data.random_split(dataset, [n_train, n - n_train])

  train_dataloader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=TRAIN_BATCH_SIZE,
      shuffle=True
  )

  # Evaluation loaders are not shuffled: batch composition then stays fixed,
  # so the summed per-batch validation loss is comparable across epochs.
  val_dataloader = torch.utils.data.DataLoader(
      val_dataset,
      batch_size=TRAIN_BATCH_SIZE,
      shuffle=False
  )

  test_dataset = ModelActivations(file_name_lst=test_file_lst)
  test_dataloader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=TEST_BATCH_SIZE,
      shuffle=False
  )

  model = SafetyClassifier().to(device=DEVICE)
  criterion = torch.nn.BCELoss()
  optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

  model = training(model, train_dataloader, val_dataloader, NUM_EPOCHS, criterion, optimizer)

  # Despite its name, validate() is only an evaluation pass; here it scores
  # the held-out domain.
  results = validate(model, test_dataloader, criterion)

  config_results[test_domain] = results

# Create the output directory if needed and close the file deterministically,
# so a missing folder cannot discard the finished runs.
os.makedirs(RESULTS_DRIVE, exist_ok=True)
with open(os.path.join(RESULTS_DRIVE, f"{filename}.pkl"), "wb") as f:
  pkl.dump(config_results, f)



Test domain is Discrimination, Exclusion, Toxicity


training:   0%|          | 0/10 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/1332 [00:00<?, ?batch/s]

  0%|          | 0/148 [00:00<?, ?batch/s]

epoch 2:   0%|          | 0/1332 [00:00<?, ?batch/s]

  0%|          | 0/148 [00:00<?, ?batch/s]

epoch 3:   0%|          | 0/1332 [00:00<?, ?batch/s]

  0%|          | 0/148 [00:00<?, ?batch/s]

epoch 4:   0%|          | 0/1332 [00:00<?, ?batch/s]

  0%|          | 0/148 [00:00<?, ?batch/s]

epoch 5:   0%|          | 0/1332 [00:00<?, ?batch/s]

  0%|          | 0/148 [00:00<?, ?batch/s]

epoch 6:   0%|          | 0/1332 [00:00<?, ?batch/s]

  0%|          | 0/148 [00:00<?, ?batch/s]

epoch 7:   0%|          | 0/1332 [00:00<?, ?batch/s]

  0%|          | 0/148 [00:00<?, ?batch/s]

epoch 8:   0%|          | 0/1332 [00:00<?, ?batch/s]

  0%|          | 0/148 [00:00<?, ?batch/s]

epoch 9:   0%|          | 0/1332 [00:00<?, ?batch/s]

  0%|          | 0/148 [00:00<?, ?batch/s]

epoch 10:   0%|          | 0/1332 [00:00<?, ?batch/s]

  0%|          | 0/148 [00:00<?, ?batch/s]

Completed Training...


  0%|          | 0/1304 [00:00<?, ?batch/s]

Test domain is Misinformation


training:   0%|          | 0/10 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/1175 [00:00<?, ?batch/s]

  0%|          | 0/131 [00:00<?, ?batch/s]

epoch 2:   0%|          | 0/1175 [00:00<?, ?batch/s]

  0%|          | 0/131 [00:00<?, ?batch/s]

epoch 3:   0%|          | 0/1175 [00:00<?, ?batch/s]

  0%|          | 0/131 [00:00<?, ?batch/s]

epoch 4:   0%|          | 0/1175 [00:00<?, ?batch/s]

  0%|          | 0/131 [00:00<?, ?batch/s]

epoch 5:   0%|          | 0/1175 [00:00<?, ?batch/s]

  0%|          | 0/131 [00:00<?, ?batch/s]

epoch 6:   0%|          | 0/1175 [00:00<?, ?batch/s]

  0%|          | 0/131 [00:00<?, ?batch/s]

epoch 7:   0%|          | 0/1175 [00:00<?, ?batch/s]

  0%|          | 0/131 [00:00<?, ?batch/s]

epoch 8:   0%|          | 0/1175 [00:00<?, ?batch/s]

  0%|          | 0/131 [00:00<?, ?batch/s]

epoch 9:   0%|          | 0/1175 [00:00<?, ?batch/s]

  0%|          | 0/131 [00:00<?, ?batch/s]

epoch 10:   0%|          | 0/1175 [00:00<?, ?batch/s]

  0%|          | 0/131 [00:00<?, ?batch/s]

Completed Training...


  0%|          | 0/1652 [00:00<?, ?batch/s]

Test domain is HCI harms


training:   0%|          | 0/10 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/1722 [00:00<?, ?batch/s]

  0%|          | 0/192 [00:00<?, ?batch/s]

epoch 2:   0%|          | 0/1722 [00:00<?, ?batch/s]

  0%|          | 0/192 [00:00<?, ?batch/s]

epoch 3:   0%|          | 0/1722 [00:00<?, ?batch/s]

  0%|          | 0/192 [00:00<?, ?batch/s]

epoch 4:   0%|          | 0/1722 [00:00<?, ?batch/s]

  0%|          | 0/192 [00:00<?, ?batch/s]

epoch 5:   0%|          | 0/1722 [00:00<?, ?batch/s]

  0%|          | 0/192 [00:00<?, ?batch/s]

epoch 6:   0%|          | 0/1722 [00:00<?, ?batch/s]

  0%|          | 0/192 [00:00<?, ?batch/s]

epoch 7:   0%|          | 0/1722 [00:00<?, ?batch/s]

  0%|          | 0/192 [00:00<?, ?batch/s]

epoch 8:   0%|          | 0/1722 [00:00<?, ?batch/s]

  0%|          | 0/192 [00:00<?, ?batch/s]

epoch 9:   0%|          | 0/1722 [00:00<?, ?batch/s]

  0%|          | 0/192 [00:00<?, ?batch/s]

epoch 10:   0%|          | 0/1722 [00:00<?, ?batch/s]

  0%|          | 0/192 [00:00<?, ?batch/s]

Completed Training...


  0%|          | 0/437 [00:00<?, ?batch/s]

Test domain is Malicious Uses


training:   0%|          | 0/10 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/1553 [00:00<?, ?batch/s]

  0%|          | 0/173 [00:00<?, ?batch/s]

epoch 2:   0%|          | 0/1553 [00:00<?, ?batch/s]

  0%|          | 0/173 [00:00<?, ?batch/s]

epoch 3:   0%|          | 0/1553 [00:00<?, ?batch/s]

  0%|          | 0/173 [00:00<?, ?batch/s]

epoch 4:   0%|          | 0/1553 [00:00<?, ?batch/s]

  0%|          | 0/173 [00:00<?, ?batch/s]

epoch 5:   0%|          | 0/1553 [00:00<?, ?batch/s]

  0%|          | 0/173 [00:00<?, ?batch/s]

epoch 6:   0%|          | 0/1553 [00:00<?, ?batch/s]

  0%|          | 0/173 [00:00<?, ?batch/s]

epoch 7:   0%|          | 0/1553 [00:00<?, ?batch/s]

  0%|          | 0/173 [00:00<?, ?batch/s]

epoch 8:   0%|          | 0/1553 [00:00<?, ?batch/s]

  0%|          | 0/173 [00:00<?, ?batch/s]

epoch 9:   0%|          | 0/1553 [00:00<?, ?batch/s]

  0%|          | 0/173 [00:00<?, ?batch/s]

epoch 10:   0%|          | 0/1553 [00:00<?, ?batch/s]

  0%|          | 0/173 [00:00<?, ?batch/s]

Completed Training...


  0%|          | 0/813 [00:00<?, ?batch/s]