In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import datetime
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset, random_split

In [3]:
repo_path = '/content/drive/MyDrive/odeformer' # folder where odeformer is stored
script_path = '/content/drive/MyDrive/aisc' # folder containing the script generate_samples.py
activations_path = '/content/drive/MyDrive/aisc/activations' # where you want to save activations
logs_path = '/content/drive/MyDrive/aisc/logs'
probes_path = '/content/drive/MyDrive/aisc/probes'
%cd {script_path}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

/content/drive/MyDrive/aisc


# LRProbe and ActivationsDataset classes

In [4]:
# Probe class

class LRProbe(torch.nn.Module):
    def __init__(self, d_in=512): # Default decoder layer activation dimension
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(d_in, 1, bias=False),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x).squeeze(-1) # I just copied this. I don't know if .squeeze is necessary?

    def predict(self, acts):
        with torch.no_grad():
            return self(acts)

    @property
    def direction(self):
        return self.net[0].weight.data[0]

In [5]:
# PyTorch wrapper for activations dataset

class ActivationsDataset(Dataset):
  def __init__(self, activations_path, feature_label, layer_idx, module='ffn'):
    self.act_paths = [os.path.join(activations_path, f) for f in os.listdir(activations_path)]
    self.feature_label = feature_label
    self.layer_idx = layer_idx
    self.module = module

  def __len__(self):
    return len(self.act_paths)

  def __getitem__(self, idx):
    act_path = self.act_paths[idx]
    # TODO: need to change from torch load to pickle
    activation = torch.load(act_path)
    layer_name = self.get_layer_name(self.layer_idx)
    if 'encoder' in layer_name:
      act_data = activation['encoder'][layer_name]
    else:
      act_data = activation['decoder'][layer_name]
    # TODO: will need to update the below functionality when the activations
    #       script is changed to collect only activations for the final token
    act_data = act_data[-1, :, :].flatten()
    act_label = torch.tensor(activation['feature_dict'][self.feature_label], dtype=torch.float)
    return act_data, act_label

  def get_layer_name(self, idx):
    '''
    Helper function to return the correct name of a layer in the ODEFormer given
    its index
    '''
    layers = [f'encoder_{self.module}_{num}' for num in range(4)] + [f'decoder_{self.module}_{num}' for num in range(12)]
    layer_name = layers[idx]
    if -16 <= idx < 16:
      return layer_name
    else:
      raise ValueError("Layer index should be in -16 to 15")

# Helper functions

In [6]:
# Dataset helper functions

def split_dataset(dataset, lengths=[0.8, 0.0, 0.2], seed=None):
  '''
  Split into training, validation, and testing datasets
  Default is to have no validation dataset (i.e. empty) and randomized splitting
  Seed can be set for deterministic splitting
  '''
  generator = torch.Generator().manual_seed(seed)
  return random_split(dataset, lengths, generator)

def get_d_in(dataset):
  '''
  Return the input dimension a probe requires for a given dataset of activations
  '''
  d_in = dataset[0][0].shape[0]
  return d_in

In [7]:
# Probe training and evaluation functions

def eval_probe(probe, dataloader):
  '''
  Evaluate a given probe on a specified dataset (via its corresponding dataloader)
  '''
  with torch.no_grad():
    total_loss = 0
    correct = 0
    total_preds = 0
    criterion = nn.BCELoss()

    probe.eval()

    for acts, labels in dataloader:
      outputs = probe(acts)
      preds = outputs.round()
      loss = criterion(preds, labels)
      total_loss += loss.item()
      correct += (preds == labels).float().sum()
      total_preds += len(labels)

    accuracy = (correct / total_preds).item()
    avg_loss = total_loss / total_preds
    return avg_loss, accuracy

def train_probe(probe, train_dataloader, val_dataloader=None, \
                lr=0.01, epochs=20, device='cpu', \
                logs_path='/content/drive/MyDrive/aisc/logs', write_log=False): # TODO: determine if default hyperparameters are good
  '''
  Train an instantiated probe using specified training and validation data
  '''
  # Use Adam optimizer for now; TODO: determine if other optimizers might be better
  opt = optim.Adam(probe.parameters(), lr=lr)
  criterion = nn.BCELoss()

  # Open log files to write to if desired
  # Include the current time of the experiment in filename to avoid collisions
  today_str = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M')
  if write_log:
    train_f = open(os.path.join(logs_path, f'{today_str}_train_acc_per_epoch.txt'), 'w')
    if val_dataloader is not None:
      val_f = open(os.path.join(logs_path, f'{today_str}_val_acc_per_epoch.txt'), 'w')

  # Main training loop
  for epoch in tqdm(range(epochs), desc='Training LR Probe'):
    probe.train()

    total_loss = 0
    correct_preds = 0
    total_preds = 0

    for train_acts, train_labels in train_dataloader:
      # Calculate batch loss
      opt.zero_grad()
      outputs = probe(train_acts)
      loss = criterion(outputs, train_labels)
      loss.backward()
      opt.step()
      total_loss += loss.item()
      total_preds += len(train_labels)

      # Calculate correct batch predictions
      preds = outputs.round()
      correct_preds += (preds == train_labels).float().sum()

    # Calculate epoch stats
    accuracy = (correct_preds / total_preds).item()
    avg_loss = total_loss / total_preds

    # Write to specified log file
    if write_log:
      train_f.write(f'Epoch {epoch+1}: Loss {avg_loss}, Accuracy {accuracy}\n')
    # print(f' Epoch {epoch+1}: Loss {avg_loss}, Accuracy {accuracy.item()}\n')

    # TODO: maybe implement early stopping? Need to test on larger dataset
    # Run evaluation on validation set
    if val_dataloader is not None:
        avg_val_loss, val_accuracy = eval_probe(probe, val_dataloader)
        if write_log:
          val_f.write(f'Epoch {epoch+1} (Validation): Loss {avg_val_loss}, Accuracy {val_accuracy.item()}\n')
        # print(f' Epoch {epoch+1} (Validation): Loss {avg_val_loss}, Accuracy {val_accuracy.item()}\n')

  print(f'\nEpoch {epoch+1} (Final): Loss {avg_loss}, Accuracy {accuracy}')

  # TODO: return also train and val accuracy arrays for easy plotting?
  return probe

In [8]:
# Probe saving and loading functionality

def save_probe_to_path(probe, probe_path):
  '''
  Save a probe's state dictionary to a specified path
  (saving only the state dictionary is suggested by PyTorch)
  '''
  torch.save(probe.state_dict(), probe_path)
  print(f'Saved state dictionary to {probe_path}')

def load_probe_from_path(probe_path, d_in=512):
  '''
  Returns a probe ready for evaluation loaded from the given path, with specified input dimension
  '''
  probe = LRProbe(d_in=d_in)
  probe.load_state_dict(torch.load(probe_path, weights_only=True))
  probe.eval()
  return probe

In [10]:
# MWE and features testing

# Test dataset, dataloaders, and splitting
target_feature = 'trig'
target_layer_idx = 15
full_dataset = ActivationsDataset(activations_path=activations_path, feature_label=target_feature, layer_idx=target_layer_idx)
train_dataset, val_dataset, test_dataset = split_dataset(full_dataset, lengths=[0.7, 0.1, 0.2], seed=42)
train_dataloader = DataLoader(train_dataset)
val_dataloader = DataLoader(val_dataset)
test_dataloader = DataLoader(test_dataset)

# Test helper function and probe
d_in = get_d_in(full_dataset)
test_probe = LRProbe(d_in)

# Test training loop
train_probe(test_probe, train_dataloader, val_dataloader=val_dataloader)

# Test evaluation function on test set
test_loss, test_acc = eval_probe(test_probe, test_dataloader)
print(f'Test Set: Loss {test_loss}, Accuracy {test_acc}')

# Test extracting saved probe direction
print(f'Probe direction dim: {test_probe.direction.shape}')

# Test saving and loading probe
probe_name = f'test_probe_{target_feature}_{target_layer_idx}.pt'
test_probe_path = os.path.join(probes_path, probe_name)
save_probe_to_path(test_probe, test_probe_path)
test_probe_copy = load_probe_from_path(test_probe_path, d_in=d_in)

  activation = torch.load(act_path)
Training LR Probe: 100%|██████████| 20/20 [00:02<00:00,  8.04it/s]


Epoch 20 (Final): Loss 6.269478007275049e-05, Accuracy 1.0
Test Set: Loss 50.0, Accuracy 0.5
Probe direction dim: torch.Size([512])
Saved state dictionary to /content/drive/MyDrive/aisc/probes/test_probe_trig_15.pt



