In [1]:
import torch
import numpy as np
import pickle as pkl
import pandas as pd
import torch.utils.data as data
from tqdm.notebook import trange, tqdm
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score

In [2]:
print(torch.cuda.is_available())
if torch.cuda.is_available():
  DEVICE = torch.device("cuda")
else:
  DEVICE = torch.device("cpu")
print("Using device:", DEVICE)


True
Using device: cuda


In [3]:
DOMAIN_INDEX_MAPPING = {
    "Discrimination, Exclusion, Toxicity": 1,
    "Misinformation": 2,
    "HCI harms": 3,
    "Malicious Uses": 4,
    "Information Hazards": 5
}

In [4]:
DOMAIN_FILE_MAPPING = {
    1: ["toxigen.pkl", "hate_speech.pkl", "adult_content.pkl"],
    2: ["covid_fake_news.pkl", "true_false.pkl", "mis_information.pkl"],
    3: ["student_anxiety.pkl"],
    4: ["bullying.pkl"],
    5: ["do_not_answer_en.pkl"]

}

In [5]:
#1 is not safe 0 is safe

index_class_mapping = {
    0: 1, 1: 0, 2: 1, 3: 0, 4: 1, 5:1, 6:1, 7:1, 8:1, 9:1, 10:1, 11:1, 12:1, 13:1, \
    14:1, 15:1, 16:1, 17:0, 18:1, 19:0, 20:1, 21:0, 22:1, 23:0, 24: 1 ,25:1, 26: 0, \
    27: 1, 28: 0, 29:1, 30: 0, 31:0, 32: 1
}

In [6]:
DATA_DRIVE = "/content/drive/MyDrive/embeddings/"

In [7]:
THRESHOLD = 0.5

# Dataset and Dataloader

In [19]:
class ModelActivations(data.Dataset):
  """Activations dataset"""

  def __init__(self, file_name_lst, mode, layer_num=32, fusion_type=None):

    if mode == "fusion":
      #final array to be returned
      activations_arr = np.array([])
      labels_lst = []

      for file_name in file_name_lst:
        curr_file_activations_arr = np.array([])
        with open(DATA_DRIVE+file_name, "rb") as f:
          inp_dict = pkl.load(f)

        #get the different layers
        layers = list(inp_dict["activations"].keys())

        #collect the activations of different layers and concatenate them
        for layer in layers:
          curr_layer_activation = inp_dict["activations"][layer]
          curr_layer_activation = np.expand_dims(curr_layer_activation, axis=1)

          if len(curr_file_activations_arr)==0:
            curr_file_activations_arr = curr_layer_activation
          else:
            curr_file_activations_arr = np.concatenate([curr_file_activations_arr, curr_layer_activation], axis=1)

        #fusion based on mean or max
        if (not fusion_type) or (fusion_type == "mean"):
          curr_file_activations_arr = np.mean(curr_file_activations_arr, axis=1)
        else:
          curr_file_activations_arr = np.max(curr_file_activations_arr, axis=1)

        #add to the global array
        if len(activations_arr) == 0:
          activations_arr = curr_file_activations_arr
        else:
          activations_arr = np.concatenate([activations_arr, curr_file_activations_arr], axis=0)

        #get the labels
        labels_lst.extend(inp_dict["labels"])


      #convert labels to 1 or 0
      binary_labels_lst = [index_class_mapping[x] for x in labels_lst]

      self.activations = activations_arr
      self.labels = np.asarray(binary_labels_lst)


    elif mode == "individual":
      layer_arr = np.array([])
      labels_lst = []

      for file_name in file_name_lst:

        with open(DATA_DRIVE+file_name, "rb") as f:
          inp_dict = pkl.load(f)

        curr_activations = np.asarray(inp_dict["activations"][layer_num])

        if len(layer_arr)==0:
          layer_arr = curr_activations
        else:
          layer_arr = np.concatenate([layer_arr, curr_activations], axis=0)

        curr_labels = inp_dict["labels"]
        labels_lst.extend(curr_labels)

      #convert labels to 1 or 0
      binary_labels_lst = [index_class_mapping[x] for x in labels_lst]

      self.activations = layer_arr
      self.labels = np.asarray(binary_labels_lst)


  def __len__(self):
    return len(self.activations)

  def __getitem__(self, idx):
    curr_activations = self.activations[idx]
    curr_label = self.labels[idx]
    curr_activations = np.expand_dims(curr_activations, axis=0)
    curr_label = np.expand_dims(curr_label, 0)

    return torch.tensor(curr_activations, dtype=torch.float32), torch.tensor(curr_label, dtype=torch.float32)



# Binary Classifier Model

In [9]:
import torch.nn as nn

class SafetyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden_1 = nn.Linear(4096, 1024)
        self.relu = nn.ReLU()
        self.hidden_2 = nn.Linear(1024, 256)
        self.hidden_3 = nn.Linear(256, 32)
        self.output = nn.Linear(32, 1)

        self.sigmoid = nn.Sigmoid()


    def forward(self, batch):
      batch = [x.to(DEVICE) for x in batch]
      x, labels = batch
      x = self.relu(self.hidden_1(x))
      x = self.relu(self.hidden_2(x))
      x = self.relu(self.hidden_3(x))
      x = self.sigmoid(self.output(x))
      return x, labels

In [10]:
def validate (model, data_loader, criterion):
  bce_loss = nn.BCELoss()

  total_count = 0

  true_lst = []
  predictions_lst = []

  with tqdm(data_loader, unit="batch", total=len(data_loader)) as batch_iterator:
    model.eval()
    val_loss = 0.0
    for i, batch_data in enumerate(batch_iterator, start=1):

        output, target = model.forward(batch_data)

        output = output.flatten()
        target = target.flatten()

        #loss
        loss = bce_loss(output, target)
        val_loss += loss.item()

        #convert predictions to list
        true_lst.extend(target.tolist())
        predictions_lst.extend(output.tolist())


        #total count
        total_count += len(output)

        batch_iterator.set_postfix(mean_loss=val_loss / i, current_loss=loss.item(), total_loss = val_loss)

  #accuracy
  rounded_truth = [x>=THRESHOLD for x in true_lst]
  rounded_preds = [x>=THRESHOLD for x in predictions_lst]
  accuracy = accuracy_score(rounded_truth, rounded_preds)

  #roc auc score
  roc_auc = roc_auc_score(true_lst, predictions_lst)

  return val_loss, accuracy, roc_auc

In [11]:
def training(model, train_dataloader, val_dataloader, num_epochs, criterion, optimizer, file_path=None):
  val_loss_lst = []
  acc_lst = []
  train_loss_lst = []
  roc_auc_lst = []

  bce_loss = nn.BCELoss()

  for epoch in trange(num_epochs, desc="training", unit="epoch"):

    with tqdm(train_dataloader, desc="epoch {}".format(epoch + 1), unit="batch", total=len(train_dataloader)) as batch_iterator:
        model.train()
        total_loss = 0.0
        running_loss = 0.0
        for i, batch_data in enumerate(batch_iterator, start=1):
            optimizer.zero_grad()

            output, target = model(batch_data)
            output = torch.squeeze(output,dim=2)

            loss = criterion(output, target)
            total_loss += loss.item()
            running_loss += bce_loss(output, target).item()

            loss.backward()
            optimizer.step()

            batch_iterator.set_postfix(mean_loss=total_loss / i, current_loss=loss.item(), total_loss=total_loss)



        train_loss_lst.append(total_loss)


    print("Validation Set")
    val_loss, accuracy, roc_score = validate(model, val_dataloader, criterion)
    val_loss_lst.append(val_loss)

    acc_lst.append(accuracy)
    print(f"Accuracy: {accuracy}")

    roc_auc_lst.append(roc_score)
    print(f"AUC-ROC score: {roc_score}")



    if file_path is not None:
      torch.save(model.state_dict(), file_path)

  return model, val_loss_lst, acc_lst, roc_auc_lst

# Runner

In [40]:
domain_loss_mapping = {}
domain_acc_mapping = {}
domain_roc_auc_mapping = {}

In [41]:
for domain in list(DOMAIN_INDEX_MAPPING.keys()):
  test_domain = domain
  if test_domain != "Information Hazards":
    print(f"Test domain is {domain}")
    train_file_lst = []

    #get the test files list
    test_domain_index = DOMAIN_INDEX_MAPPING[domain]
    test_file_lst = DOMAIN_FILE_MAPPING[test_domain_index]

    #get the train files list
    for key in list(DOMAIN_FILE_MAPPING.keys()):
      if key != test_domain_index:
        train_file_lst.extend(DOMAIN_FILE_MAPPING[key])


    train_dataset = ModelActivations(file_name_lst=train_file_lst, mode="individual", layer_num=32)
    train_dataloader = torch.utils.data.DataLoader(
          train_dataset,
          batch_size=32,
          shuffle=True
    )

    val_dataset = ModelActivations(file_name_lst=test_file_lst, mode="individual", layer_num=32)
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=16,
        shuffle=True
    )


    model = SafetyClassifier().to(device=DEVICE)
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)

    model, val_loss_lst, acc_lst, roc_lst = training(model, train_dataloader, val_dataloader, 10, criterion, optimizer)

    domain_loss_mapping[test_domain] = min(val_loss_lst)
    domain_acc_mapping[test_domain] = max(acc_lst)
    domain_roc_auc_mapping[test_domain] = max(roc_lst)


Test domain is Discrimination, Exclusion, Toxicity


training:   0%|          | 0/10 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/186 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/179 [00:00<?, ?batch/s]

Accuracy: 0.8044928044928045
AUC-ROC score: 0.8833242049687908


epoch 2:   0%|          | 0/186 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/179 [00:00<?, ?batch/s]

Accuracy: 0.623025623025623
AUC-ROC score: 0.7162598376631387


epoch 3:   0%|          | 0/186 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/179 [00:00<?, ?batch/s]

Accuracy: 0.7083187083187084
AUC-ROC score: 0.7908193323958256


epoch 4:   0%|          | 0/186 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/179 [00:00<?, ?batch/s]

Accuracy: 0.5552825552825553
AUC-ROC score: 0.5710487750721635


epoch 5:   0%|          | 0/186 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/179 [00:00<?, ?batch/s]

Accuracy: 0.5412425412425412
AUC-ROC score: 0.5230967359928946


epoch 6:   0%|          | 0/186 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/179 [00:00<?, ?batch/s]

Accuracy: 0.5587925587925588
AUC-ROC score: 0.5280931586608443


epoch 7:   0%|          | 0/186 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/179 [00:00<?, ?batch/s]

Accuracy: 0.6767286767286768
AUC-ROC score: 0.695816741913996


epoch 8:   0%|          | 0/186 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/179 [00:00<?, ?batch/s]

Accuracy: 0.555984555984556
AUC-ROC score: 0.5973998963807269


epoch 9:   0%|          | 0/186 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/179 [00:00<?, ?batch/s]

Accuracy: 0.5356265356265356
AUC-ROC score: 0.5222914662127156


epoch 10:   0%|          | 0/186 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/179 [00:00<?, ?batch/s]

Accuracy: 0.729027729027729
AUC-ROC score: 0.7862156761157575
Test domain is Misinformation


training:   0%|          | 0/10 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/181 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/188 [00:00<?, ?batch/s]

Accuracy: 0.55
AUC-ROC score: 0.6637499437406392


epoch 2:   0%|          | 0/181 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/188 [00:00<?, ?batch/s]

Accuracy: 0.5276666666666666
AUC-ROC score: 0.6691339088647373


epoch 3:   0%|          | 0/181 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/188 [00:00<?, ?batch/s]

Accuracy: 0.5266666666666666
AUC-ROC score: 0.6613137463119482


epoch 4:   0%|          | 0/181 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/188 [00:00<?, ?batch/s]

Accuracy: 0.544
AUC-ROC score: 0.6508492823979277


epoch 5:   0%|          | 0/181 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/188 [00:00<?, ?batch/s]

Accuracy: 0.57
AUC-ROC score: 0.6530288034558621


epoch 6:   0%|          | 0/181 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/188 [00:00<?, ?batch/s]

Accuracy: 0.5393333333333333
AUC-ROC score: 0.6498653563100276


epoch 7:   0%|          | 0/181 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/188 [00:00<?, ?batch/s]

Accuracy: 0.529
AUC-ROC score: 0.6474991210171158


epoch 8:   0%|          | 0/181 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/188 [00:00<?, ?batch/s]

Accuracy: 0.5406666666666666
AUC-ROC score: 0.6067481822099212


epoch 9:   0%|          | 0/181 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/188 [00:00<?, ?batch/s]

Accuracy: 0.5446666666666666
AUC-ROC score: 0.6128348880227052


epoch 10:   0%|          | 0/181 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/188 [00:00<?, ?batch/s]

Accuracy: 0.542
AUC-ROC score: 0.579924164609765
Test domain is HCI harms


training:   0%|          | 0/10 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.743
AUC-ROC score: 0.517519743819438


epoch 2:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.812
AUC-ROC score: 0.502332190405585


epoch 3:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.795
AUC-ROC score: 0.5157899072272162


epoch 4:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.713
AUC-ROC score: 0.48274796898650113


epoch 5:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.776
AUC-ROC score: 0.4899247315149456


epoch 6:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.497
AUC-ROC score: 0.4622782359785418


epoch 7:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.786
AUC-ROC score: 0.531760005766122


epoch 8:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.802
AUC-ROC score: 0.5307509344206592


epoch 9:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.763
AUC-ROC score: 0.5593035348387031


epoch 10:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.631
AUC-ROC score: 0.502713166321729
Test domain is Malicious Uses


training:   0%|          | 0/10 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.794
AUC-ROC score: 0.8814086360991615


epoch 2:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.857
AUC-ROC score: 0.8368355287631511


epoch 3:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.868
AUC-ROC score: 0.8321903108576006


epoch 4:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.657
AUC-ROC score: 0.8622840736748729


epoch 5:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.739
AUC-ROC score: 0.7567412890241226


epoch 6:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.618
AUC-ROC score: 0.659678173198905


epoch 7:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.786
AUC-ROC score: 0.8148017436259407


epoch 8:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.62
AUC-ROC score: 0.7081620389358921


epoch 9:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.848
AUC-ROC score: 0.7979091750207461


epoch 10:   0%|          | 0/244 [00:00<?, ?batch/s]

Validation Set


  0%|          | 0/63 [00:00<?, ?batch/s]

Accuracy: 0.614
AUC-ROC score: 0.6646667747689314


In [37]:
domain_loss_mapping

{'Discrimination, Exclusion, Toxicity': 99.5377899594605,
 'Misinformation': 155.75302943587303,
 'HCI harms': 31.303983494639397,
 'Malicious Uses': 20.43748101592064}

In [42]:
domain_acc_mapping

{'Discrimination, Exclusion, Toxicity': 0.8044928044928045,
 'Misinformation': 0.57,
 'HCI harms': 0.812,
 'Malicious Uses': 0.868}

In [43]:
domain_roc_auc_mapping

{'Discrimination, Exclusion, Toxicity': 0.8833242049687908,
 'Misinformation': 0.6691339088647373,
 'HCI harms': 0.5593035348387031,
 'Malicious Uses': 0.8814086360991615}

In [None]:
with open("/content/drive/MyDrive/embeddings/adult_content.pkl", "rb") as f:
  d = pkl.load(f)

In [None]:
d['activations'].keys()

dict_keys([16, 20, 24, 28, 32])

In [None]:
import numpy as np

In [None]:
arr = np.random.random((935, 5, 4096))

In [None]:
arr.shape

(935, 5, 4096)

In [None]:
new_arr = np.max(arr, axis=1)

In [None]:
new_arr.shape

(935, 4096)