In [1]:
#run from the root directory
import os

os.chdir("..")

In [2]:
import torch
import numpy as np
import pickle as pkl
import pandas as pd
import torch.utils.data as data
from tqdm.notebook import trange, tqdm
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
from config import *

In [3]:
print(torch.cuda.is_available())
if torch.cuda.is_available():
  DEVICE = torch.device("cuda")
else:
  DEVICE = torch.device("cpu")
print("Using device:", DEVICE)


True
Using device: cuda


In [4]:
DATA_DRIVE = "./embeddings/"
RESULTS_DRIVE = "./results/"

# Dataset and Dataloader

In [5]:
class ModelActivations(data.Dataset):
  """Activations dataset"""

  def __init__(self, df, select_indices, mode, layer_num=32, fusion_type=None):

    if mode == "fusion":

      all_activations = []

      for _layer in layer_num:
        temp_activations = df['activations'][_layer]
        temp_activations = [x for s, x in zip(select_indices, temp_activations) if s==1]
        temp_activations = np.asarray(temp_activations)
        all_activations.append(temp_activations[:, np.newaxis, :])
      
      all_activations = np.concatenate(all_activations, axis=1)

      if (not fusion_type) or (fusion_type == "mean"):
        self.activations = np.mean(all_activations, axis=1)
      else:
        self.activations = np.max(all_activations, axis=1)
      
      self.labels = [x for s, x in zip(select_indices, df['labels']) if s==1]

      #final array to be returned
      # activations_arr = np.array([])
      # labels_lst = []

      # for file_name in file_name_lst:
      #   curr_file_activations_arr = np.array([])
      #   with open(DATA_DRIVE+file_name, "rb") as f:
      #     inp_dict = pkl.load(f)

      #   #get the different layers
      #   layers = list(inp_dict["activations"].keys())

      #   #collect the activations of different layers and concatenate them
      #   for layer in layers:
      #     curr_layer_activation = inp_dict["activations"][layer]
      #     curr_layer_activation = np.expand_dims(curr_layer_activation, axis=1)

      #     if len(curr_file_activations_arr)==0:
      #       curr_file_activations_arr = curr_layer_activation
      #     else:
      #       curr_file_activations_arr = np.concatenate([curr_file_activations_arr, curr_layer_activation], axis=1)

      #   #fusion based on mean or max
      #   if (not fusion_type) or (fusion_type == "mean"):
      #     curr_file_activations_arr = np.mean(curr_file_activations_arr, axis=1)
      #   else:
      #     curr_file_activations_arr = np.max(curr_file_activations_arr, axis=1)

      #   #add to the global array
      #   if len(activations_arr) == 0:
      #     activations_arr = curr_file_activations_arr
      #   else:
      #     activations_arr = np.concatenate([activations_arr, curr_file_activations_arr], axis=0)

      #   #get the labels
      #   labels_lst.extend(inp_dict["labels"])


      # #convert labels to 1 or 0
      # binary_labels_lst = [index_class_mapping[x] for x in labels_lst]

      # self.activations = activations_arr
      # self.labels = np.asarray(binary_labels_lst)


    elif mode == "individual":

      temp_activations = df['activations'][layer_num]
      temp_activations = [x for s, x in zip(select_indices, temp_activations) if s==1]
      temp_activations = np.asarray(temp_activations)

      self.activations = temp_activations
      self.labels = [x for s, x in zip(select_indices, df['labels']) if s==1]
      
      # layer_arr = np.array([])
      # labels_lst = []

      # for file_name in file_name_lst:

      #   with open(DATA_DRIVE+file_name, "rb") as f:
      #     inp_dict = pkl.load(f)

      #   curr_activations = np.asarray(inp_dict["activations"][layer_num])

      #   if len(layer_arr)==0:
      #     layer_arr = curr_activations
      #   else:
      #     layer_arr = np.concatenate([layer_arr, curr_activations], axis=0)

      #   curr_labels = inp_dict["labels"]
      #   labels_lst.extend(curr_labels)

      # #convert labels to 1 or 0
      # binary_labels_lst = [index_class_mapping[x] for x in labels_lst]

      # self.activations = layer_arr
      # self.labels = np.asarray(binary_labels_lst)


  def __len__(self):
    return len(self.activations)

  def __getitem__(self, idx):
    curr_activations = self.activations[idx]
    curr_label = self.labels[idx]
    curr_activations = np.expand_dims(curr_activations, axis=0)
    curr_label = np.expand_dims(curr_label, 0)

    return torch.tensor(curr_activations, dtype=torch.float32), torch.tensor(curr_label, dtype=torch.float32)



# Binary Classifier Model

In [6]:
import torch.nn as nn

class SafetyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        # self.hidden_1 = nn.Linear(4096, 256)
        # self.hidden_1 = nn.Linear(1000, 128)
        self.relu = nn.ReLU()
        # self.dropout = nn.Dropout(0.50)
        # self.hidden_2 = nn.Linear(128, 128)
        # self.hidden_3 = nn.Linear(128, 32)
        self.output = nn.Linear(4096, 1)


        # self.output = nn.Linear(4096, 1)
        self.sigmoid = nn.Sigmoid()


    def forward(self, batch):
      batch = [x.to(DEVICE) for x in batch]
      x, labels = batch
      # x = self.relu(self.hidden_1(x))
      # x = self.dropout(x)
    #   x = self.relu(self.hidden_2(x))
    #   x = self.dropout(x)
    #   x = self.relu(self.hidden_3(x))
    #   x = self.dropout(x)
      x = self.sigmoid(self.output(x))
      return x, labels

In [7]:
def validate (model, data_loader, criterion):
  bce_loss = nn.BCELoss()

  total_count = 0

  true_lst = []
  predictions_lst = []

  model.eval()
  with torch.no_grad():
    with tqdm(data_loader, unit="batch", total=len(data_loader)) as batch_iterator:
      val_loss = 0.0
      for i, batch_data in enumerate(batch_iterator, start=1):

          output, target = model.forward(batch_data)

          output = output.flatten()
          target = target.flatten()

          #loss
          loss = bce_loss(output, target)
          val_loss += loss.item()

          #convert predictions to list
          true_lst.extend(target.tolist())
          predictions_lst.extend(output.tolist())


          #total count
          total_count += len(output)

          batch_iterator.set_postfix(mean_loss=val_loss / i, current_loss=loss.item(), total_loss = val_loss)

  # #accuracy
  # rounded_truth = [x>=THRESHOLD for x in true_lst]
  # rounded_preds = [x>=THRESHOLD for x in predictions_lst]
  # accuracy = accuracy_score(rounded_truth, rounded_preds)

  # #roc auc score
  # roc_auc = roc_auc_score(true_lst, predictions_lst)


  return {'Validation Loss': val_loss, 'prediction': predictions_lst, 'label': true_lst}

In [8]:
def training(model, train_dataloader, val_dataloader, num_epochs, criterion, optimizer, file_name=None):
  # val_loss_lst = []
  # acc_lst = []
  train_loss_lst = []
  # roc_auc_lst = []
  # max_roc_score = 0
  best_val_loss = None
  epochs_without_improvement = 0
  patience = 3

  bce_loss = nn.BCELoss()
  model.train()
  
  for epoch in trange(num_epochs, desc="training", unit="epoch"):

    with tqdm(train_dataloader, desc="epoch {}".format(epoch + 1), unit="batch", total=len(train_dataloader)) as batch_iterator:
        total_loss = 0.0
        running_loss = 0.0
        for i, batch_data in enumerate(batch_iterator, start=1):
            optimizer.zero_grad()

            output, target = model(batch_data)
            output = torch.squeeze(output,dim=2)

            loss = criterion(output, target)
            total_loss += loss.item()
            running_loss += bce_loss(output, target).item()

            loss.backward()
            optimizer.step()

            batch_iterator.set_postfix(mean_loss=total_loss / i, current_loss=loss.item(), total_loss=total_loss)



        train_loss_lst.append(total_loss)

    val_loss = validate(model, val_dataloader, criterion)['Validation Loss']
    wandb.log({"Training Loss": total_loss/len(train_dataloader.dataset), "Validation Loss": val_loss/len(val_dataloader.dataset)})

    if best_val_loss is None or val_loss < best_val_loss:
       best_val_loss = val_loss
       epochs_without_improvement = 0
    else:
       epochs_without_improvement += 1
    
    if(epochs_without_improvement >= patience):
       break


  print("Completed Training...")
  # val_loss, accuracy, roc_score, predictions_lst, true_lst = validate(model, val_dataloader, criterion)
  # val_loss_lst.append(val_loss)

  # acc_lst.append(accuracy)
  # print(f"Accuracy: {accuracy}")

  # roc_auc_lst.append(roc_score)
  # print(f"AUC-ROC score: {roc_score}")

  # if roc_score>max_roc_score:
  #   max_roc_score = roc_score

  #   new_file_name = f"{RESULTS_DRIVE}/true_{file_name}"
  #   with open(new_file_name, "wb+") as f:
  #     pkl.dump(true_lst, f)

  #   new_file_name = f"{RESULTS_DRIVE}/pred_{file_name}"
  #   with open(new_file_name, "wb+") as f:
  #     pkl.dump(predictions_lst, f)





    # if file_path is not None:
    #   torch.save(model.state_dict(), file_path)

  # return model, val_loss_lst, acc_lst, roc_auc_lst
  return model

# Runner

In [9]:
# LAYERS = [1, 4, 8, 12, 16, 20, 24, 28, 32]
DOMAINS = ["Discrimination, Exclusion, Toxicity", "HCI harms", "Malicious Uses", "Misinformation"]
# DOMAINS = ["Misinformation"]
LAYERS = [32]

In [10]:
import wandb
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msaiprasath2107[0m ([33msafety-awareness[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [11]:
df = pkl.load(open("./embeddings/Llama2Embeddings.pkl", "rb"))

# from sklearn.decomposition import PCA

# for l in df['activations'].keys():
#   df['activations'][l] = PCA(n_components=1000).fit_transform(df['activations'][l])

In [12]:
combinations = []
for _mode in ["individual", "fusion"]:
  if _mode=="fusion":
    for _type in ["max", "mean"]:
      print()
      # combinations.append([_mode, _type, None])
  else:
    for l in LAYERS:
      combinations.append([_mode, None, l])

params = {}

N = len(df['labels'])
params["df"] = df

print("Starting to learn...")

for args in combinations:
  params["mode"], params["fusion_type"], params["layer_num"] = args
  filename = '_'.join([str(params[key]) for key in ["mode", "fusion_type", "layer_num"] if params[key] is not None])
  print(f"Current Config: {filename}")
  config_results = {}

  for i, domain in enumerate(DOMAINS):
    print(f"\tTest domain is {domain}")
    select_indices = np.zeros((N,)) 
    select_indices[i*N//len(DOMAINS):(i+1)*N//len(DOMAINS)] = 1 # Test Domain
    params["select_indices"] = select_indices

    wandb.init(
    # set the wandb project where this run will be logged
    project="Llama2-Embeddings-Safety-Classification",
    name=filename + "_" + domain,
    # track hyperparameters and run metadata
    config={
    "domain": domain,
    "mode": params["mode"],
    "layer": params["layer_num"],
    }
    )
  
  # config_results = {}
  # for domain in list(DOMAIN_INDEX_MAPPING.keys()):
  #   test_domain = domain
  #     print(f"Test domain is {domain}")
  #     train_file_lst = []

  #     #get the test files list
  #     test_domain_index = DOMAIN_INDEX_MAPPING[domain]
  #     test_file_lst = DOMAIN_FILE_MAPPING[test_domain_index]

  #     #get the train files list
  #     for key in list(DOMAIN_FILE_MAPPING.keys()):
  #       if key != test_domain_index:
  #         train_file_lst.extend(DOMAIN_FILE_MAPPING[key])

    dataset = ModelActivations(**params)
    
    n = len(dataset)
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [int(0.7 * n), int(0.1 * n), n - int(0.7 * n) - int(0.1 * n)])

    train_dataloader = torch.utils.data.DataLoader(
          train_dataset,
          batch_size=32,
          shuffle=True
    )

    val_dataloader = torch.utils.data.DataLoader(
          val_dataset,
          batch_size=32,
          shuffle=True
    )

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=16,
        shuffle=True
    )


    model = SafetyClassifier().to(device=DEVICE)
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)

  # model, val_loss_lst, acc_lst, roc_lst = training(model, train_dataloader, val_dataloader, \
  #                                                  10, criterion, optimizer, "fusion_max.pkl")

  # domain_loss_mapping[test_domain] = min(val_loss_lst)
  # domain_acc_mapping[test_domain] = max(acc_lst)
  # domain_roc_auc_mapping[test_domain] = max(roc_lst)

    model = training(model, train_dataloader, val_dataloader, 50, criterion, optimizer, "fusion_max.pkl")

    _model_name = filename + "_" + domain
    torch.save(model, f"./network_weights/{_model_name}.pt")
  
    results = validate(model, test_dataloader, criterion)
  
    config_results[domain] = results
  pkl.dump(config_results, open(f"{RESULTS_DRIVE}/{filename}.pkl", "wb"))

wandb.finish()





Starting to learn...
Current Config: individual_32
	Test domain is Discrimination, Exclusion, Toxicity


training:   0%|          | 0/50 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 2:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 3:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 4:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 5:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 6:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 7:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 8:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 9:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 10:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 11:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 12:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 13:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 14:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 15:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

Completed Training...


  0%|          | 0/250 [00:00<?, ?batch/s]

	Test domain is HCI harms


VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))



0,1
Training Loss,█▅▄▄▃▃▂▂▃▃▂▂▂▁▁
Validation Loss,▂▁▁▃▁▁▄▁█▁▁▁▄▃▁

0,1
Training Loss,0.00624
Validation Loss,0.00598


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112421833806568, max=1.0…

training:   0%|          | 0/50 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 2:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 3:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 4:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 5:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

Completed Training...


  0%|          | 0/250 [00:00<?, ?batch/s]

	Test domain is Malicious Uses




VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Training Loss,█▄▂▂▁
Validation Loss,█▁▁▅▄

0,1
Training Loss,0.00412
Validation Loss,0.00551


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112483890934123, max=1.0…

training:   0%|          | 0/50 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 2:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 3:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 4:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 5:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 6:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 7:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 8:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 9:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 10:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 11:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 12:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 13:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 14:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 15:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 16:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 17:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

Completed Training...


  0%|          | 0/250 [00:00<?, ?batch/s]

	Test domain is Misinformation




VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Training Loss,█▅▅▄▄▃▃▂▂▂▂▂▁▂▁▁▁
Validation Loss,█▃▃▃▂▂▂▂▂▂▅▂▁▁▃▂▂

0,1
Training Loss,0.01144
Validation Loss,0.01293


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112481120249464, max=1.0…

training:   0%|          | 0/50 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 2:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 3:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 4:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 5:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 6:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

epoch 7:   0%|          | 0/438 [00:00<?, ?batch/s]

  0%|          | 0/63 [00:00<?, ?batch/s]

Completed Training...


  0%|          | 0/250 [00:00<?, ?batch/s]



VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Training Loss,█▄▃▄▁▁▂
Validation Loss,▆▄▂▁▁█▃

0,1
Training Loss,0.00742
Validation Loss,0.00762


In [16]:
test_dataloader.dataset.dataset.activations

array([[ 2.223  ,  1.845  ,  1.453  , ...,  0.5728 , -1.078  ,  0.971  ],
       [ 2.2    ,  1.879  ,  1.582  , ...,  0.2615 , -0.99   ,  1.351  ],
       [ 2.227  , -0.04068,  0.8096 , ...,  1.103  ,  0.0806 ,  2.049  ],
       ...,
       [-0.1405 , -4.555  , -1.     , ...,  1.455  ,  1.308  , -2.016  ],
       [ 2.773  ,  0.5347 ,  1.567  , ...,  0.676  , -1.034  ,  1.955  ],
       [ 2.57   ,  1.297  ,  1.446  , ...,  0.516  , -0.5327 ,  1.639  ]],
      dtype=float16)

In [50]:
combinations = []
for _mode in ["individual", "fusion"]:
  if _mode=="fusion":
    for _type in ["max", "mean"]:
      print()
      # combinations.append([_mode, _type, None])
  else:
    for l in LAYERS:
      combinations.append([_mode, None, l])

params = {}

N = len(df['labels'])
params["df"] = df

print("Starting to learn...")

for args in combinations:
  params["mode"], params["fusion_type"], params["layer_num"] = args
  filename = '_'.join([str(params[key]) for key in ["mode", "fusion_type", "layer_num"] if params[key] is not None])
  print(f"Current Config: {filename}")
  config_results = {}

  for i, domain in enumerate(DOMAINS):
    print(f"\tTest domain is {domain}")
    select_indices = np.ones((N,)) 
    # select_indices[60000:80000] = 0
    select_indices[i*N//len(DOMAINS):(i+1)*N//len(DOMAINS)] = 0 # Test Domain
    params["select_indices"] = select_indices

    wandb.init(
    # set the wandb project where this run will be logged
    project="Llama2-Embeddings-Safety-Classification",
    name=filename + "_" + domain,
    # track hyperparameters and run metadata
    config={
    "domain": domain,
    "mode": params["mode"],
    "layer": params["layer_num"],
    }
    )
  
  # config_results = {}
  # for domain in list(DOMAIN_INDEX_MAPPING.keys()):
  #   test_domain = domain
  #     print(f"Test domain is {domain}")
  #     train_file_lst = []

  #     #get the test files list
  #     test_domain_index = DOMAIN_INDEX_MAPPING[domain]
  #     test_file_lst = DOMAIN_FILE_MAPPING[test_domain_index]

  #     #get the train files list
  #     for key in list(DOMAIN_FILE_MAPPING.keys()):
  #       if key != test_domain_index:
  #         train_file_lst.extend(DOMAIN_FILE_MAPPING[key])

    dataset = ModelActivations(**params)
    
    n = len(dataset)
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(0.9 * n), n - int(0.9 * n)])

    train_dataloader = torch.utils.data.DataLoader(
          train_dataset,
          batch_size=32,
          shuffle=True
    )

    val_dataloader = torch.utils.data.DataLoader(
          val_dataset,
          batch_size=32,
          shuffle=True
    )

    params["select_indices"] = 1 - select_indices
    test_dataset = ModelActivations(**params)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=16,
        shuffle=True
    )


    model = SafetyClassifier().to(device=DEVICE)
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)

  # model, val_loss_lst, acc_lst, roc_lst = training(model, train_dataloader, val_dataloader, \
  #                                                  10, criterion, optimizer, "fusion_max.pkl")

  # domain_loss_mapping[test_domain] = min(val_loss_lst)
  # domain_acc_mapping[test_domain] = max(acc_lst)
  # domain_roc_auc_mapping[test_domain] = max(roc_lst)

    model = training(model, train_dataloader, val_dataloader, 50, criterion, optimizer, "fusion_max.pkl")

    _model_name = filename + "_" + domain
    torch.save(model, f"./network_weights/{_model_name}.pt")
  
    results = validate(model, test_dataloader, criterion)
  
    config_results[domain] = results
  pkl.dump(config_results, open(f"{RESULTS_DRIVE}/{filename}.pkl", "wb"))

wandb.finish()





Starting to learn...
Current Config: individual_32
	Test domain is Misinformation




VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112594578622115, max=1.0…

training:   0%|          | 0/50 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/1688 [00:00<?, ?batch/s]

  0%|          | 0/188 [00:00<?, ?batch/s]

epoch 2:   0%|          | 0/1688 [00:00<?, ?batch/s]

  0%|          | 0/188 [00:00<?, ?batch/s]

epoch 3:   0%|          | 0/1688 [00:00<?, ?batch/s]

  0%|          | 0/188 [00:00<?, ?batch/s]

epoch 4:   0%|          | 0/1688 [00:00<?, ?batch/s]

  0%|          | 0/188 [00:00<?, ?batch/s]

epoch 5:   0%|          | 0/1688 [00:00<?, ?batch/s]

  0%|          | 0/188 [00:00<?, ?batch/s]

epoch 6:   0%|          | 0/1688 [00:00<?, ?batch/s]

  0%|          | 0/188 [00:00<?, ?batch/s]

epoch 7:   0%|          | 0/1688 [00:00<?, ?batch/s]

  0%|          | 0/188 [00:00<?, ?batch/s]

epoch 8:   0%|          | 0/1688 [00:00<?, ?batch/s]

  0%|          | 0/188 [00:00<?, ?batch/s]

epoch 9:   0%|          | 0/1688 [00:00<?, ?batch/s]

  0%|          | 0/188 [00:00<?, ?batch/s]

epoch 10:   0%|          | 0/1688 [00:00<?, ?batch/s]

  0%|          | 0/188 [00:00<?, ?batch/s]

epoch 11:   0%|          | 0/1688 [00:00<?, ?batch/s]

  0%|          | 0/188 [00:00<?, ?batch/s]

epoch 12:   0%|          | 0/1688 [00:00<?, ?batch/s]

  0%|          | 0/188 [00:00<?, ?batch/s]

Completed Training...


  0%|          | 0/1250 [00:00<?, ?batch/s]



VBox(children=(Label(value='0.005 MB of 0.005 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Training Loss,█▅▄▃▃▂▂▂▁▁▁▁
Validation Loss,▇▅▄▂▂▁▂█▁▂▆▅

0,1
Training Loss,0.00778
Validation Loss,0.01055


In [51]:
results = validate(model, train_dataloader, criterion)

  0%|          | 0/1688 [00:00<?, ?batch/s]

In [43]:
import random
import numpy as np
import pickle as pkl
from collections import defaultdict
from sklearn.metrics import (
    precision_recall_curve,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score
)
import matplotlib.pyplot as plt
from config import *

def compute(labels, preds):
    return { 'accuracy': accuracy_score(labels, preds),
            'precision': precision_score(labels, preds),
            'recall': recall_score(labels, preds),
            'f1_score': f1_score(labels, preds)
    }

def findMean(data):
    _result = {}
    for key in data.keys():
        val = data[key]
        val = np.array(val).astype(np.float32)
        mean = np.round(np.mean(val), 3)
        std = np.round(np.std(val), 3)
        _result[key] = str(mean) + u" \u00B1 " + str(std)
    return _result


def calculate_accuracy(labels, predictions, threshold):
    binary_predictions = (predictions >= threshold).astype(int)
    correct_predictions = (binary_predictions == labels).sum()
    return correct_predictions / len(labels)

def evaluate(data):
    labels = data['label']
    preds = data['prediction']

    cresult = defaultdict(list)

    for seed in range(10):
        random.seed(seed)
        zlist = list(zip(labels, preds))
        random.shuffle(zlist)

        val_size = int(0.1 * len(zlist))
        val_zlist = zlist[0:val_size]
        test_zlist = zlist[val_size:]

        vlabels, vpreds = zip(*val_zlist)
        tlabels, tpreds = zip(*test_zlist)


        # precision, recall, threshold = precision_recall_curve(vlabels, vpreds)
        # f1_scores = 2*recall*precision/(recall+precision + 1e-8)

        # Sweep over thresholds
        thresholds = np.linspace(0, 1, 101)
        best_accuracy = 0
        best_threshold = 0

        for threshold in thresholds:
            accuracy = calculate_accuracy(vlabels, vpreds, threshold)
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                best_threshold = threshold

        # cutoff = threshold[np.argmax(f1_scores)]
        cutoff = best_threshold

        rpreds = 1.0 * np.array([x > cutoff for x in tpreds])
        _result = compute(tlabels, rpreds)
        _result['auc_roc_score'] = roc_auc_score(tlabels, tpreds)
        _result['cutoff'] = cutoff

        for key in _result.keys():
            cresult[key].append(_result[key])
    
    return findMean(cresult)



In [52]:
evaluate(results)

{'accuracy': '0.887 ± 0.001',
 'precision': '0.89 ± 0.016',
 'recall': '0.885 ± 0.021',
 'f1_score': '0.887 ± 0.003',
 'auc_roc_score': '0.962 ± 0.0',
 'cutoff': '0.412 ± 0.053'}

In [None]:
df = pkl.load(open("./results/individual_32.pkl", "rb"))
df

{'Misinformation': {'Validation Loss': 2393.71874576807,
  'prediction': [0.014812757261097431,
   0.03246169909834862,
   0.44484058022499084,
   0.1269844025373459,
   0.003813639050349593,
   0.9940662384033203,
   0.010563718155026436,
   0.04833226650953293,
   5.38545634753973e-08,
   0.0006271650199778378,
   0.9569999575614929,
   0.9738706350326538,
   0.053610656410455704,
   0.011302978731691837,
   0.14068418741226196,
   0.00048387027345597744,
   0.003904585959389806,
   0.045110125094652176,
   0.6172489523887634,
   0.04471366107463837,
   0.008781461976468563,
   0.9487672448158264,
   0.03480622544884682,
   0.1660684496164322,
   0.06696891784667969,
   0.025283198803663254,
   0.016793668270111084,
   0.0011630572844296694,
   0.029789511114358902,
   0.042526375502347946,
   0.15057367086410522,
   0.9999716281890869,
   0.0023225038312375546,
   0.07853653281927109,
   0.011940022930502892,
   0.0014197810087352991,
   0.0030036375392228365,
   0.00225499691441655

### Indomain Classifier

In [7]:
DOMAINS = ["Discrimination, Exclusion, Toxicity", "HCI harms", "Malicious Uses", "Misinformation"]

In [10]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

data = pkl.load(open("./embeddings/Llama2Embeddings.pkl", "rb"))
N = 80000
for i in range(4):
   domain_data = data['activations'][32][i*N//4:(i+1)*N//4]
   labels = data['labels'][i*N//4:(i+1)*N//4]

   X_train, X_test, y_train, y_test = train_test_split(domain_data, labels, test_size=0.2, random_state=42)

   clf = LogisticRegression(max_iter=1000 ,random_state=10).fit(X_train, y_train)
   preds = clf.predict(X_test)
   print(f"Accuracy for {DOMAINS[i]} is {accuracy_score(y_test, preds)}")




STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Accuracy for Discrimination, Exclusion, Toxicity is 0.87925


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Accuracy for HCI harms is 0.952


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Accuracy for Malicious Uses is 0.806
Accuracy for Misinformation is 0.90025


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
