In [1]:
import torch
import numpy as np
import pickle as pkl
import pandas as pd
import torch.utils.data as data
from tqdm.notebook import trange, tqdm

In [2]:
# Report whether CUDA is usable, then select the device used by the rest of
# the notebook: the GPU when available, otherwise the CPU.
print(torch.cuda.is_available())
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)


True
Using device: cuda


# Dataset and Dataloader

In [3]:
class ModelActivations(data.Dataset):
  """Dataset of activations from one layer, paired with labels.

  The file at ``file_name`` is unpickled into a dict. The entry
  ``inp_dict["activations"][layer_num]`` provides the samples and
  ``inp_dict["labels"]`` provides the targets; both are converted to numpy
  arrays and indexed by sample position.

  Args:
    file_name: Path of the pickle file to read.
    mode: ``"individual"`` loads the activations of a single layer.
      ``"fuse"`` is reserved but not implemented.
    layer_num: Key/index selecting the layer inside ``"activations"``.

  Raises:
    NotImplementedError: If ``mode`` is ``"fuse"``.
    ValueError: If ``mode`` is neither ``"individual"`` nor ``"fuse"``.
  """

  def __init__(self, file_name, mode, layer_num=32):
    if mode == "fuse":
      # This used to be a silent no-op that left the instance without
      # `activations`/`labels`, so the failure only surfaced later as an
      # AttributeError inside __len__/__getitem__.
      raise NotImplementedError('mode="fuse" is not implemented yet')
    if mode != "individual":
      raise ValueError(
          "unknown mode {!r}; expected 'individual' or 'fuse'".format(mode))

    # NOTE(security): pickle.load can execute arbitrary code while
    # deserialising. Only open files that come from a trusted source.
    with open(file_name, "rb") as f:
      inp_dict = pkl.load(f)

    self.activations = np.asarray(inp_dict["activations"][layer_num])
    self.labels = np.asarray(inp_dict["labels"])

  def __len__(self):
    """Number of samples, taken from the activations array."""
    return len(self.activations)

  def __getitem__(self, idx):
    """Return ``(activations, label)`` for ``idx`` as float32 tensors.

    Both values get a new leading axis of size 1 before conversion.
    """
    curr_activations = np.expand_dims(self.activations[idx], axis=0)
    curr_label = np.expand_dims(self.labels[idx], 0)

    return (
        torch.tensor(curr_activations, dtype=torch.float32),
        torch.tensor(curr_label, dtype=torch.float32),
    )



# Binary Classifier Model

In [4]:
import torch.nn as nn

class SafetyClassifier(nn.Module):
    """Feed-forward binary classifier over activation vectors.

    Layers: ``input_dim -> 1024 -> 256 -> 32 -> 1`` with ReLU after each
    hidden layer and a sigmoid on the output, so predictions lie in [0, 1].

    Args:
        input_dim: Size of the last dimension of the input features.
            Defaults to 4096, the value that was previously hard-coded.
    """

    def __init__(self, input_dim=4096):
        super().__init__()
        self.hidden_1 = nn.Linear(input_dim, 1024)
        self.relu = nn.ReLU()
        self.hidden_2 = nn.Linear(1024, 256)
        self.hidden_3 = nn.Linear(256, 32)
        self.output = nn.Linear(32, 1)

        self.sigmoid = nn.Sigmoid()

    def forward(self, batch):
        """Run the classifier on a ``(features, labels)`` batch.

        Args:
            batch: Iterable of exactly two tensors, ``(features, labels)``.

        Returns:
            Tuple ``(predictions, labels)``. ``predictions`` is the sigmoid
            output; ``labels`` is the input labels moved to the model's
            device, passed through unchanged otherwise.
        """
        # Follow the device the model actually lives on instead of a notebook
        # global, so the module keeps working wherever it has been moved.
        device = self.hidden_1.weight.device
        x, labels = (tensor.to(device) for tensor in batch)

        x = self.relu(self.hidden_1(x))
        x = self.relu(self.hidden_2(x))
        x = self.relu(self.hidden_3(x))
        x = self.sigmoid(self.output(x))
        return x, labels

In [5]:
def validate(model, data_loader, criterion):
  """Evaluate ``model`` on ``data_loader`` and return the summed batch loss.

  Args:
    model: Module called with each batch; must return ``(output, target)``.
    data_loader: Sized iterable of batches.
    criterion: Loss callable applied to the flattened output and target.
      When ``None``, ``nn.BCELoss()`` is used.

  Returns:
    float: Sum of the per-batch loss values (a total, not a mean).

  Note:
    The model is switched to eval mode and left in that mode.
  """
  # The passed criterion used to be ignored in favour of a hard-coded BCELoss.
  if criterion is None:
    criterion = nn.BCELoss()

  model.eval()
  val_loss = 0.0

  # No gradients are needed for evaluation, so skip building the graph.
  with torch.no_grad(), tqdm(data_loader, unit="batch", total=len(data_loader)) as batch_iterator:
    for i, batch_data in enumerate(batch_iterator, start=1):
      # Call the module itself rather than .forward() so hooks still run.
      output, target = model(batch_data)

      loss = criterion(output.flatten(), target.flatten())
      batch_loss = loss.item()
      val_loss += batch_loss

      batch_iterator.set_postfix(mean_loss=val_loss / i, current_loss=batch_loss, total_loss=val_loss)

  return val_loss

In [6]:
def training(model, train_dataloader, val_dataloader, num_epochs, criterion, optimizer, file_path=None, history=None, log_every=200):
  """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

  Args:
    model: Module called with each batch; must return ``(output, target)``.
    train_dataloader: Sized iterable of training batches.
    val_dataloader: Sized iterable passed to ``validate`` after every epoch.
    num_epochs: Number of passes over ``train_dataloader``.
    criterion: Loss callable applied to the flattened output and target.
    optimizer: Optimizer stepping the model parameters.
    file_path: If given, the model ``state_dict`` is saved there after
      every epoch (each save overwrites the previous one).
    history: Optional dict. When given, the summed training loss and the
      validation loss of every epoch are appended to the lists under
      ``"train_loss"`` and ``"val_loss"``.
    log_every: Print the mean training loss of the last ``log_every``
      batches every ``log_every`` batches. A falsy value disables this.

  Returns:
    The same ``model`` object, after training.
  """
  for epoch in trange(num_epochs, desc="training", unit="epoch"):
    model.train()
    total_loss = 0.0
    running_loss = 0.0

    with tqdm(train_dataloader, desc="epoch {}".format(epoch + 1), unit="batch", total=len(train_dataloader)) as batch_iterator:
      for i, batch_data in enumerate(batch_iterator, start=1):
        optimizer.zero_grad()

        output, target = model(batch_data)
        # Flatten both sides, as validate() does, rather than squeezing a
        # hard-coded axis that only exists for 3-D outputs.
        loss = criterion(output.flatten(), target.flatten())

        loss.backward()
        optimizer.step()

        # Reuse the value already computed; the loss used to be evaluated a
        # second time with a hard-coded BCELoss just for logging.
        batch_loss = loss.item()
        total_loss += batch_loss
        running_loss += batch_loss

        batch_iterator.set_postfix(mean_loss=total_loss / i, current_loss=batch_loss, total_loss=total_loss)

        if log_every and i % log_every == 0:
          print(f"Running Train Loss: {running_loss/log_every}")
          running_loss = 0.0

    print("Validation Set")
    val_loss = validate(model, val_dataloader, criterion)

    # The per-epoch losses used to be collected in local lists and dropped.
    if history is not None:
      history.setdefault("train_loss", []).append(total_loss)
      history.setdefault("val_loss", []).append(val_loss)

    if file_path is not None:
      torch.save(model.state_dict(), file_path)

  return model

## Runner

In [7]:
# Instantiate the classifier and move its parameters to the selected device.
model = SafetyClassifier().to(DEVICE)

In [9]:
dataset = ModelActivations("/content/do_not_answer_en.pkl", mode="individual", layer_num=16)

In [10]:
train_dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=32,
            shuffle=True
        )

In [11]:
val_dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=32,
            shuffle=True
        )

In [None]:
# Binary cross-entropy on the classifier's sigmoid output, optimised with AdamW.
criterion = torch.nn.BCELoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=5e-4)

# Train for two epochs; no checkpoint path is passed, so nothing is saved.
model = training(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    num_epochs=2,
    criterion=criterion,
    optimizer=optimizer,
)