In [31]:
import math
import os
import pickle
from statistics import mean
# import sys

import datasets
import torch
from datasets import load_from_disk
from tqdm.auto import tqdm, trange

from ccstools import (
    normalize,
    train_test_split
)

# Passed as `test_size` to `train_test_split` when the representations are
# split further down in this cell.
TEST_SPLIT = 0.4

# Constants
# Use CUDA when it is available, otherwise fall back to the CPU.
# NOTE(review): in the visible cells DEVICE is only mentioned in a
# commented-out line — confirm it is still needed.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset identifiers. One of them is picked by index (DATASET) and becomes
# part of the directory the data is loaded from.
DATASETS = (
    'ag_news',
    'amazon_polarity',
    'super_glue-boolq',
    'super_glue-copa',
    'dbpedia_14',
    'imdb',
    'piqa',
    'glue-qnli',
    'super_glue-rte',
    'story_cloze-2016',
)

# Model identifiers. One of them is picked by index (MODEL) and becomes part
# of the directory the data is loaded from.
MODELS = (
    'allenai/unifiedqa-t5-3b',
    'eleutherai/gpt-j-6b',
    'microsoft/deberta-v2-xxlarge-mnli',
    'roberta-large-mnli',
    't5-3b'
)

# NOTE(review): not referenced anywhere in the visible cells; presumably the
# hidden-layer indices of interest (negative = counted from the end) — confirm.
LAYERS = (-21, -19, -17, -13, -11, -9, -7, -5, -3, -1)


def get_representations(dataset: datasets.Dataset,
                        layer: int = -1) -> torch.Tensor:
    """Return the normalized hidden states of a single layer.

    Args:
        dataset: Object whose 'hidden_states' entry is a five-dimensional
            tensor; its dims are named (index, prompt, layer, answer, state)
            here.
        layer: Position along the 'layer' dim to keep. Defaults to -1, the
            last stored layer, which is what the previous implementation
            always used regardless of this argument.

    Returns:
        The result of `normalize` applied to the selected slice, whose named
        dims are (index, prompt, answer, state).
    """
    hidden_states = dataset['hidden_states'].refine_names(
        'index',
        'prompt',
        'layer',
        'answer',
        'state'
    )

    # Honour the requested layer instead of silently always taking the last one.
    selected_layer = hidden_states.select('layer', layer)
    return normalize(selected_layer)


# ---- Experiment selection -------------------------------------------------
MODEL = MODELS[2]
DATASET = DATASETS[1]
SUB_DATASET = None
LAYER = 'encoder'
data_dir = f'results-data/{MODEL}/{DATASET}'

print(f"""Model: {MODEL}
Dataset: {DATASET}
Loading from directory: {data_dir}...""")

# ---- Load the stored hidden states as torch tensors ------------------------
dataset = load_from_disk(data_dir).with_format('torch')

representations = get_representations(dataset)

# ---- Train / test split ----------------------------------------------------
(
    (train_representations, test_representations),
    (train_representation_labels, test_representation_labels),
) = train_test_split(representations,
                     dataset['label'].cpu(),
                     test_size=TEST_SPLIT)

# Drop the dimension names from every split before further processing.
train_representations = train_representations.rename(None)
test_representations = test_representations.rename(None)
train_representation_labels = train_representation_labels.rename(None)
test_representation_labels = test_representation_labels.rename(None)

Model: microsoft/deberta-v2-xxlarge-mnli
Dataset: amazon_polarity
Loading from directory: results-data/microsoft/deberta-v2-xxlarge-mnli/amazon_polarity...


In [None]:
import pandas as pd
import torch.optim as optim
import torch.nn as nn
import numpy as np

def calculate_results(loss_function_name, train_data, test_data, train_data_labels, test_data_labels, num_initialisations=10, num_epochs = 1000, learning_rate = 0.01):
    """Choose a hyper-parameter for a loss and tabulate the resulting runs.

    The hyper-parameter is selected with `get_hyper_parameter` on the training
    split only; the per-initialisation table is then built by
    `create_data_frame` using both splits.

    Returns:
        A `(hyper_parameter, data)` tuple, where `data` is the pandas
        DataFrame produced by `create_data_frame`.
    """
    selected_hyper_parameter = get_hyper_parameter(
        loss_function_name,
        train_data,
        train_data_labels,
        num_epochs,
        learning_rate,
    )
    results_frame = create_data_frame(
        loss_function_name,
        selected_hyper_parameter,
        train_data,
        test_data,
        train_data_labels,
        test_data_labels,
        num_initialisations,
        num_epochs,
        learning_rate,
    )
    return (selected_hyper_parameter, results_frame)

#Need to work out what to do when we use 'Random'
def create_data_frame(loss_function_name, hyper_parameter, train_data, test_data, train_data_labels, test_data_labels, num_initialisations, num_epochs, learning_rate):
    """Produce one row of metrics per initialisation for a given loss.

    For 'Random', `num_initialisations` random unit directions are evaluated
    without any training (their 'loss' entry is None). For every other loss a
    fresh model is trained per initialisation and its unit-norm weight vector
    is evaluated.

    Args:
        loss_function_name: Name understood by `get_loss_function`, or 'Random'.
        hyper_parameter: Value bound as the loss's hyper-parameter argument, or
            None when the loss takes none.
        train_data, test_data: Representation tensors; the last dim is the
            activation dimension.
        train_data_labels, test_data_labels: Labels handed to `get_accuracy`.
        num_initialisations: Number of rows to produce.
        num_epochs, learning_rate: Forwarded to `train`.

    Returns:
        A pandas DataFrame with the columns 'weights', 'loss',
        'train accuracy', 'test accuracy' and 'cosine similarity' (mean
        absolute dot product with 20 freshly trained CCS directions).
    """
    number_of_ccs_probes = 20
    dimension_of_activation_space = train_data.size()[-1]
    is_random_baseline = loss_function_name == 'Random'

    if not is_random_baseline:
        # Resolve the loss first so that an unknown name fails before the
        # expensive CCS reference probes are trained.
        loss_function = evaluate_first_variable(train_data, get_loss_function(loss_function_name))
        if loss_function_name == 'Supervised':
            # supervised_loss takes (data, data_labels, model); without this
            # binding `train` would call it with the wrong number of arguments.
            # NOTE(review): supervised_loss subtracts the labels from the raw
            # probe output — confirm the two shapes broadcast as intended.
            loss_function = evaluate_first_variable(train_data_labels, loss_function)
        if hyper_parameter is not None:
            loss_function = evaluate_first_variable(hyper_parameter, loss_function)
        # The CCS and supervised losses only see the parameters through
        # `model(data)`. new_model's forward is the identity, so those two
        # losses would have no gradient; they need the probe instead.
        if loss_function_name in ('CCS', 'Supervised'):
            model_class = new_probe
        else:
            model_class = new_model

    # Reference CCS directions used for the 'cosine similarity' column.
    ccs_loss_function = evaluate_first_variable(train_data, get_loss_function('CCS'))
    ccs_directions = torch.stack(
        get_multiple_ccs_directions_and_losses(ccs_loss_function, number_of_ccs_probes, dimension_of_activation_space, learning_rate, num_epochs)[0]
    )

    results = {
        'weights': [],
        'loss': [],
        'train accuracy': [],
        'test accuracy': [],
        'cosine similarity': [],
    }

    def record_direction(weights, loss_value):
        # get_accuracy already returns a Python float, so no `.item()` here.
        results['weights'].append(weights)
        results['loss'].append(loss_value)
        results['train accuracy'].append(get_accuracy(weights, train_data, train_data_labels))
        results['test accuracy'].append(get_accuracy(weights, test_data, test_data_labels))
        cosine_similarity = torch.mean(torch.abs(torch.matmul(ccs_directions, weights)))
        results['cosine similarity'].append(cosine_similarity.item())

    if is_random_baseline:
        random_directions = get_random_direction(dimension_of_activation_space, num_initialisations)
        for initialisation in trange(num_initialisations):
            record_direction(random_directions[initialisation], None)
    else:
        for initialisation in trange(num_initialisations):
            model_instance = model_class(dimension_of_activation_space)
            model_instance.normalize()
            model_instance, loss = train(model_instance, loss_function, num_epochs, learning_rate)
            model_instance.normalize()
            # Copy so the stored direction does not alias the live parameter.
            weights = model_instance.weights.detach().clone()
            weights = weights / torch.norm(weights)
            record_direction(weights, loss.detach().item())

    return pd.DataFrame(results)

def get_hyper_parameter(loss_function_name, data, data_labels, num_epochs, learning_rate):
    """Select the hyper-parameter for a loss, or None if it has none.

    'Random', 'Supervised', 'PCA' and 'CCS' take no hyper-parameter. For
    'MD-CCS' the value is searched in [0.9, 0.999] by agreement with CCS
    directions; for every other name it is searched in [0, 0.99] by accuracy
    on `data` / `data_labels`. Both searches use 11 grid points and 2
    refinement iterations.
    """
    # Losses without a tunable hyper-parameter.
    if loss_function_name in ('Random', 'Supervised', 'PCA', 'CCS'):
        return None

    number_of_points = 11
    number_of_iterations = 2
    dimension_of_activation_space = data.size()[-1]

    bound_loss = evaluate_first_variable(data, get_loss_function(loss_function_name))

    if loss_function_name != 'MD-CCS':
        lower_bound, upper_bound = 0, 0.99
        return get_hyper_parameter_for_accuracy(
            bound_loss, data, data_labels,
            number_of_points, number_of_iterations, dimension_of_activation_space,
            lower_bound, upper_bound, num_epochs, learning_rate)

    lower_bound, upper_bound = 0.9, 0.999
    bound_ccs_loss = evaluate_first_variable(data, get_loss_function('CCS'))
    return get_hyper_parameter_for_ccs(
        bound_loss, bound_ccs_loss,
        number_of_points, number_of_iterations, dimension_of_activation_space,
        lower_bound, upper_bound, num_epochs, learning_rate)

def evaluate_first_variable(first_variable, input_function):
    """Bind `first_variable` as the first positional argument of `input_function`.

    Returns a new callable that forwards any further positional and keyword
    arguments to `input_function` after `first_variable`.
    """
    return lambda *remaining, **named: input_function(first_variable, *remaining, **named)

def get_loss_function(loss_function_name):
    """Return the loss callable registered under `loss_function_name`.

    Every loss takes `data` first and `model` last; 'Supervised' additionally
    takes `data_labels` and the 'MD-*', 'SMR' and 'MA' losses additionally
    take `hyper_parameter` in between. `data` is indexed as
    `data[:, :, answer, :]`, i.e. the answer axis is dim 2 of a 4-D tensor.

    Args:
        loss_function_name: One of 'Supervised', 'PCA', 'CCS', 'MD-CCS',
            'MD-Acc', 'SMR' or 'MA'.

    Raises:
        ValueError: If the name is not one of the above.
    """

    def project(vectors, model):
        # Component of every vector along the unit vector of model.weights.
        return torch.matmul(vectors, model.weights / torch.norm(model.weights))

    def displacement_projection(data, model):
        # Projection of (answer 1 - answer 0) for every example / prompt.
        return project(data[:, :, 1, :] - data[:, :, 0, :], model)

    def ccs_loss(data, model):
        probabilities = model(data).squeeze()
        # The answer axis is the last one after squeeze(); addressing it as -1
        # keeps this valid when squeeze() removed a size-1 leading dim.
        consistency_loss = (probabilities.sum(dim=-1) - 1) ** 2
        confidence_loss = probabilities.min(dim=-1).values ** 2
        return torch.mean(consistency_loss + confidence_loss)

    def supervised_loss(data, data_labels, model):
        probabilities = model(data)
        return torch.mean(torch.square(probabilities - data_labels))

    def pca_loss(data, model):
        # The optimiser minimises the loss, so the mean squared projection has
        # to be negated for the minimiser to be the direction of largest
        # spread (consistent with the sign of the same term in md_loss).
        return -torch.mean(torch.square(displacement_projection(data, model)))

    def ma_loss(data, hyper_parameter, model):
        magnitudes = torch.abs(displacement_projection(data, model))
        mean_magnitude = torch.mean(magnitudes)
        standard_deviation = torch.sqrt(torch.mean(torch.square(magnitudes - mean_magnitude)))
        return hyper_parameter * standard_deviation - (1 - hyper_parameter) * mean_magnitude

    def smr_loss(data, hyper_parameter, model):
        magnitudes = torch.abs(displacement_projection(data, model))
        mean_magnitude = torch.mean(magnitudes)
        square_mean_root = torch.square(torch.mean(torch.sqrt(magnitudes)))
        standard_deviation = torch.sqrt(torch.mean(torch.square(magnitudes - mean_magnitude)))
        return hyper_parameter * standard_deviation - (1 - hyper_parameter) * square_mean_root

    def md_loss(data, hyper_parameter, model):
        averaged_data = torch.mean(data, dim=2)
        mean_term = torch.mean(torch.square(project(averaged_data, model)))
        displacement_term = torch.mean(torch.square(displacement_projection(data, model)))
        return hyper_parameter * mean_term - (1 - hyper_parameter) * displacement_term

    loss_functions = {
        "Supervised": supervised_loss,
        "PCA": pca_loss,
        "CCS": ccs_loss,
        "MD-CCS": md_loss,
        "MD-Acc": md_loss,
        "SMR": smr_loss,
        "MA": ma_loss,
    }
    if loss_function_name not in loss_functions:
        # ValueError is still an Exception, so existing handlers keep working.
        raise ValueError(f'The loss function \'{loss_function_name}\' does not exist.')
    return loss_functions[loss_function_name]


class new_probe(nn.Module):
    """Linear probe with a sigmoid output: sigmoid(x . weights + bias).

    Attributes:
        weights: Trainable 1-D parameter of length `number_of_dimensions`,
            drawn from a standard normal distribution.
        bias: Trainable scalar parameter, drawn from a standard normal
            distribution.
    """

    def __init__(self, number_of_dimensions):
        # The zero-argument form always refers to this class; naming another
        # class here raises a TypeError on construction.
        super().__init__()
        self.weights = nn.Parameter(torch.randn(number_of_dimensions))  # Initialize weights randomly
        # torch.normal(0, 1) without a size is not a valid call; draw a scalar.
        self.bias = nn.Parameter(torch.randn(()))

    def forward(self, x):
        """Apply the probe along the last dim of `x`.

        matmul reduces to a dot product for 1-D input and broadcasts over any
        leading dims, so whole batches can be passed in one call.
        """
        return torch.sigmoid(torch.matmul(x, self.weights) + self.bias)

    def normalize(self):
        """Rescale `weights` to unit Euclidean norm (the bias is untouched)."""
        self.weights.data = self.weights.div(torch.norm(self.weights))


class new_model(nn.Module):
    """Holder for a single trainable direction vector.

    The losses that use this class read `weights` directly, so `forward` is
    just the identity.
    """

    def __init__(self, number_of_dimensions):
        super().__init__()
        # Random initial direction, one standard-normal draw per dimension.
        initial_direction = torch.normal(0, 1, size=[number_of_dimensions])
        self.weights = nn.Parameter(initial_direction)

    def forward(self, x):
        """Return the input unchanged."""
        return x

    def normalize(self):
        """Rescale `weights` to unit Euclidean norm."""
        rescaled = self.weights.div(torch.norm(self.weights))
        self.weights.data = rescaled

def train(model, loss_function, epochs, learning_rate, seed=None):
    """Minimise `loss_function(model)` with AdamW on all of the model's parameters.

    Args:
        model: Module whose parameters are optimised in place.
        loss_function: Callable taking the model and returning a scalar tensor.
        epochs: Number of optimisation steps.
        learning_rate: AdamW learning rate.
        seed: Optional value for `torch.manual_seed`. By default the global
            random state is left alone: reseeding it here from only 100
            possible values made later "independent" random initialisations
            collide.

    Returns:
        `(model, loss)`, where `loss` is the value computed in the last step
        (before that step's parameter update). With no steps, it is the loss
        of the untouched model.
    """
    if seed is not None:
        torch.manual_seed(seed)

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    loss = None
    for _ in range(epochs):
        loss = loss_function(model)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if loss is None:
        # No step was taken; still hand back a loss so callers can read it.
        loss = loss_function(model)
    return (model, loss)

def get_random_direction(number_of_dimensions: int, number_of_directions):
    """Draw `number_of_directions` random unit vectors.

    Each vector has `number_of_dimensions` standard-normal components and is
    then scaled to unit Euclidean norm.

    Returns:
        A list of 1-D tensors.
    """
    def draw_unit_vector():
        sample = torch.normal(0, 1, size=[number_of_dimensions])
        return sample.div_(torch.norm(sample))

    return [draw_unit_vector() for _ in range(number_of_directions)]

def get_multiple_ccs_directions_and_losses(ccs_loss_function, num_initialisations, dimension_of_activation_space, learning_rate, epochs):
    """Train several independently initialised probes on the CCS loss.

    Args:
        ccs_loss_function: Callable taking a probe and returning its loss.
        num_initialisations: Number of probes to train.
        dimension_of_activation_space: Length of each probe's weight vector.
        learning_rate, epochs: Forwarded to `train`.

    Returns:
        `(all_weights, all_losses)`: a list with each probe's unit-norm weight
        vector and a list with the loss value returned by `train` for it.
    """
    all_weights = []
    all_losses = []
    for _ in range(num_initialisations):
        # Deliberately no reseeding from np.random.randint(100) here: with only
        # 100 possible seeds, several of the probes were likely to start from
        # exactly the same initialisation.
        probe = new_probe(dimension_of_activation_space)
        (probe, loss) = train(probe, ccs_loss_function, epochs, learning_rate)
        # `probe` is the module itself, so its weights are read directly
        # (indexing a module with [0] raises a TypeError). The copy keeps the
        # stored direction independent of the live parameter.
        weights = probe.weights.detach().clone().squeeze()
        all_weights.append(weights / torch.norm(weights))
        all_losses.append(loss.detach().item())
    return all_weights, all_losses

def get_hyper_parameter_for_ccs(loss_function, ccs_loss_function, number_of_points, number_of_iterations, dimension_of_activation_space, initial_interval_lower_value, initial_interval_upper_value, num_epochs, learning_rate):
    """Grid-search the hyper-parameter whose direction best agrees with CCS.

    In every iteration `number_of_points` evenly spaced values of the current
    interval are tried. For each value, 3 freshly initialised models are
    trained and scored by the mean absolute dot product of their unit weight
    vector with 20 trained CCS directions. The interval is then narrowed to
    one grid step either side of the best value found so far.

    Args:
        loss_function: Callable `(hyper_parameter, model) -> loss`.
        ccs_loss_function: Callable `(probe) -> loss` used for the references.
        number_of_points: Grid points per iteration.
        number_of_iterations: Number of refinement rounds.
        dimension_of_activation_space: Length of the weight vectors.
        initial_interval_lower_value, initial_interval_upper_value: Search range.
        num_epochs, learning_rate: Forwarded to `train`.

    Returns:
        The best hyper-parameter found, or the initial lower bound if no
        candidate scored above 0.
    """
    number_of_initialisations = 3
    number_of_ccs_probes = 20
    current_interval_lower_value = initial_interval_lower_value
    current_interval_upper_value = initial_interval_upper_value
    best_ccs_cosine_similarity = 0
    best_hyper_parameter = initial_interval_lower_value

    ccs_directions = torch.stack(
        get_multiple_ccs_directions_and_losses(ccs_loss_function, number_of_ccs_probes, dimension_of_activation_space, learning_rate, num_epochs)[0]
    )

    for iteration in trange(number_of_iterations):
        interval_width = current_interval_upper_value - current_interval_lower_value
        # A one-point grid has no spacing; avoid dividing by zero.
        grid_step = interval_width / (number_of_points - 1) if number_of_points > 1 else 0
        for point in trange(number_of_points):
            current_hyper_parameter = current_interval_lower_value + point * grid_step
            current_loss_function = evaluate_first_variable(current_hyper_parameter, loss_function)
            all_current_cosine_similarities = []
            for initialisation in range(number_of_initialisations):
                # Start every run from a new random model. Reusing one model
                # meant each "initialisation" merely continued the training of
                # the previous one (and of the previous hyper-parameter).
                model_instance = new_model(dimension_of_activation_space)
                model_instance.normalize()
                model_instance = train(model_instance, current_loss_function, num_epochs, learning_rate)[0]
                model_instance.normalize()
                current_ccs_cosine_similarity = torch.mean(torch.abs(torch.matmul(ccs_directions, model_instance.weights.detach())))
                all_current_cosine_similarities.append(current_ccs_cosine_similarity.item())
            average_current_cosine_similarity = mean(all_current_cosine_similarities)
            if best_ccs_cosine_similarity < average_current_cosine_similarity:
                best_ccs_cosine_similarity = average_current_cosine_similarity
                best_hyper_parameter = current_hyper_parameter
        # Zoom in around the best value without leaving the current interval.
        current_interval_lower_value = max(current_interval_lower_value, best_hyper_parameter - grid_step)
        current_interval_upper_value = min(current_interval_upper_value, best_hyper_parameter + grid_step)
    return best_hyper_parameter

def get_hyper_parameter_for_accuracy(loss_function, data, data_labels, number_of_points, number_of_iterations, dimension_of_activation_space, initial_interval_lower_value, initial_interval_upper_value, num_epochs, learning_rate):
    """Grid-search the hyper-parameter that gives the highest accuracy.

    In every iteration `number_of_points` evenly spaced values of the current
    interval are tried. For each value, 3 freshly initialised models are
    trained and scored with `get_accuracy` on `data` / `data_labels`; the
    scores are averaged. The interval is then narrowed to one grid step either
    side of the best value found so far.

    Args:
        loss_function: Callable `(hyper_parameter, model) -> loss`.
        data, data_labels: Passed to `get_accuracy` for scoring.
        number_of_points: Grid points per iteration.
        number_of_iterations: Number of refinement rounds.
        dimension_of_activation_space: Length of the weight vectors.
        initial_interval_lower_value, initial_interval_upper_value: Search range.
        num_epochs, learning_rate: Forwarded to `train`.

    Returns:
        The best hyper-parameter found, or the initial lower bound if no
        candidate scored above 0.
    """
    number_of_initialisations = 3
    current_interval_lower_value = initial_interval_lower_value
    current_interval_upper_value = initial_interval_upper_value
    best_accuracy = 0
    best_hyper_parameter = initial_interval_lower_value

    for iteration in trange(number_of_iterations):
        interval_width = current_interval_upper_value - current_interval_lower_value
        # A one-point grid has no spacing; avoid dividing by zero.
        grid_step = interval_width / (number_of_points - 1) if number_of_points > 1 else 0
        for point in trange(number_of_points):
            current_hyper_parameter = current_interval_lower_value + point * grid_step
            current_loss_function = evaluate_first_variable(current_hyper_parameter, loss_function)
            all_current_accuracies = []
            for initialisation in range(number_of_initialisations):
                # Start every run from a new random model. Reusing one model
                # meant each "initialisation" merely continued the training of
                # the previous one (and of the previous hyper-parameter).
                model_instance = new_model(dimension_of_activation_space)
                model_instance.normalize()
                model_instance = train(model_instance, current_loss_function, num_epochs, learning_rate)[0]
                model_instance.normalize()
                current_accuracy = get_accuracy(model_instance.weights.detach(), data, data_labels)
                all_current_accuracies.append(current_accuracy)
            average_current_accuracy = mean(all_current_accuracies)
            if best_accuracy < average_current_accuracy:
                best_accuracy = average_current_accuracy
                best_hyper_parameter = current_hyper_parameter

        # Zoom in around the best value without leaving the current interval.
        current_interval_lower_value = max(current_interval_lower_value, best_hyper_parameter - grid_step)
        current_interval_upper_value = min(current_interval_upper_value, best_hyper_parameter + grid_step)
    return best_hyper_parameter


def get_accuracy(direction, data, data_labels):
    """Accuracy of classifying examples by the sign of a projection.

    The difference `data[:, :, 1, :] - data[:, :, 0, :]` is projected onto
    `direction`; a positive projection predicts 1, anything else 0. Because
    the sign of a direction is arbitrary, the better of the prediction and
    its flipped version is reported.

    Args:
        direction: Direction vector (any shape that flattens to the last dim
            of `data`). It is not modified.
        data: 4-D tensor indexed as (example, prompt, answer, state).
        data_labels: Labels, either one per example or one per
            (example, prompt) pair.

    Returns:
        The accuracy as a Python float: correct predictions divided by the
        total number of predictions.
    """
    # Normalise a copy; dividing in place would silently rescale the caller's
    # tensor (for detached parameters, the model's weights themselves).
    unit_direction = direction.reshape(-1)
    unit_direction = unit_direction / torch.norm(unit_direction)

    displaced_data = data[:, :, 1, :] - data[:, :, 0, :]

    projections = torch.matmul(displaced_data, unit_direction)
    predictions = (projections > 0).to(projections.dtype)

    # Labels given per example are broadcast over the prompt axis.
    labels = torch.as_tensor(data_labels)
    while labels.dim() < predictions.dim():
        labels = labels.unsqueeze(-1)

    matches = torch.sum(predictions == labels)
    flipped_matches = torch.sum((1 - predictions) == labels)
    # Divide by the number of predictions actually compared, so the result
    # stays within [0, 1] when there are several prompts per example.
    accuracy = torch.max(matches, flipped_matches) / predictions.numel()

    return accuracy.item()

#create a data class bundling one experiment's settings and results (note: its attributes are plain and not actually immutable)
class experimental_data():
  """Settings and results of one experiment, stored as plain attributes.

  Attributes (set in this order): dataset, sub_dataset, model, layer,
  loss_function_name, hyper_parameter, data.
  """
  def __init__(self, dataset: str, sub_dataset: str, model: str, layer: str, loss_function_name: str, hyper_parameter: float, data: pd.DataFrame):
      fields = (
          ('dataset', dataset),
          ('sub_dataset', sub_dataset),
          ('model', model),
          ('layer', layer),
          ('loss_function_name', loss_function_name),
          ('hyper_parameter', hyper_parameter),
          ('data', data),
      )
      for field_name, field_value in fields:
          setattr(self, field_name, field_value)

In [None]:
# Run every loss function on the current model / dataset and keep the results.
all_loss_functions=('Random', 'Supervised', 'PCA', 'CCS', 'MD-CCS', 'MD-Acc', 'SMR', 'MA')
experimental_data_for_all_functions = {}
for loss_function_name in all_loss_functions:
    (hyper_parameter, data) = calculate_results(loss_function_name, train_representations, test_representations, train_representation_labels, test_representation_labels)
    experimental_data_for_all_functions[loss_function_name]=experimental_data(DATASET, SUB_DATASET,MODEL, LAYER, loss_function_name, hyper_parameter, data)


########################
#Save the data to a file
########################
# Requires `os` and `pickle` to be imported in the notebook's import cell.
sub_dataset = 'None' if SUB_DATASET is None else SUB_DATASET

base_directory = 'g3-nandi/src/optimization/' + DATASET + '_' + sub_dataset + '/' + MODEL + '_' + LAYER
# All result files live directly in base_directory, so create it once.
os.makedirs(base_directory, exist_ok=True)
for loss_function_name in all_loss_functions:
    directory = base_directory + '/' + loss_function_name +'.pkl'
    # The `with` block closes the file; no explicit close() is needed.
    with open(directory, 'wb') as f:
        pickle.dump(experimental_data_for_all_functions[loss_function_name], f)

########################
#Print the data
########################
def _print_results(title, minimum_loss_only):
    """Print one section per loss function.

    With `minimum_loss_only` set, only the rows whose loss equals the minimum
    loss of that loss function are shown; otherwise every row is shown.
    """
    print(title)
    for loss_function_name in all_loss_functions:
        print()
        print()
        print()
        print(f'************************ {loss_function_name} ************************')
        results = experimental_data_for_all_functions[loss_function_name]
        print(f'Hyper parameter: {results.hyper_parameter}')
        data_frame = results.data
        if minimum_loss_only:
            data_frame = data_frame[data_frame.loss == data_frame.loss.min()]
        print(data_frame[['loss', 'train accuracy', 'test accuracy', 'cosine similarity']])


_print_results("SUMMARY DATA FOR INITIALISATIONS THAT MINIMISE LOSS", True)

# Eleven empty lines between the two summaries.
print('\n' * 10)

_print_results("SUMMARY OF ALL DATA", False)

In [None]:
import matplotlib.pyplot as plt

#Plot the figure of CCS direction, PCA direction and decision boundary.
#Use CCS lowest loss and PCA lowest loss
#Orthogonalise the PCA vector
#projected data > - bias / weightnorm

def plot_data_distribution(train_data, test_data, test_labels, num_epochs = 1000, learning_rate = 0.01):
  """Scatter the test representations in the plane spanned by a PCA and a CCS direction.

  A CCS probe and a 'PCA' direction are both fitted on `train_data`. The PCA
  direction is made orthogonal to the CCS direction. Every test representation
  (both answers of every example / prompt) is then plotted with its PCA
  component on the x axis and its CCS component on the y axis. The horizontal
  line marks where the probe's pre-sigmoid output is zero.

  Points are coloured by the label of the example they belong to.
  NOTE(review): both answers of an example therefore share a colour — confirm
  this is the intended meaning of 'True' / 'False' in the legend.

  Args:
      train_data: Representations used to fit both directions.
      test_data: Representations to plot, indexed (example, prompt, answer, state).
      test_labels: One 0/1 (or boolean) label per test example.
      num_epochs, learning_rate: Forwarded to `train`.
  """
  dimension_of_activation_space = train_data.size()[-1]

  # --- CCS direction and decision threshold ---
  ccs_loss_function = evaluate_first_variable(train_data, get_loss_function('CCS'))
  ccs_probe = train(new_probe(dimension_of_activation_space), ccs_loss_function, num_epochs, learning_rate)[0]
  raw_ccs_weights = ccs_probe.weights.detach()
  ccs_weight_norm = torch.norm(raw_ccs_weights)
  # Out-of-place division so the probe's own parameters are left untouched.
  ccs_weights = raw_ccs_weights / ccs_weight_norm
  # w.x + b = 0  <=>  (w/|w|).x = -b/|w|
  ccs_bias = (ccs_probe.bias.detach() / ccs_weight_norm).item()

  # --- PCA direction, orthogonalised against the CCS direction ---
  pca_loss_function = evaluate_first_variable(train_data, get_loss_function('PCA'))
  pca_model = train(new_model(dimension_of_activation_space), pca_loss_function, num_epochs, learning_rate)[0]
  pca_weights = pca_model.weights.detach()
  pca_weights = pca_weights - torch.dot(pca_weights, ccs_weights) * ccs_weights
  pca_weights = pca_weights / torch.norm(pca_weights)

  # Boolean masks: indexing with the integer labels themselves would pick
  # rows 0 and 1 over and over instead of selecting examples by label.
  label_is_true = test_labels.bool()
  label_is_false = ~label_is_true

  def split_by_label(direction):
    # Project every test representation and group the values by example label.
    projections = torch.matmul(test_data, direction)
    yes_projections = projections[:, :, 1]
    no_projections = projections[:, :, 0]
    true_points = torch.flatten(torch.cat([yes_projections[label_is_true], no_projections[label_is_true]], dim=0))
    false_points = torch.flatten(torch.cat([yes_projections[label_is_false], no_projections[label_is_false]], dim=0))
    return true_points.numpy(), false_points.numpy()

  true_y_data, false_y_data = split_by_label(ccs_weights)
  true_x_data, false_x_data = split_by_label(pca_weights)

  plt.figure(figsize=(8, 6))
  plt.scatter(true_x_data, true_y_data, s=5, color='red', label='True')
  plt.scatter(false_x_data, false_y_data, s=5, color='blue', label='False')
  plt.axhline(y = - ccs_bias, color = 'black', linestyle = '-')
  plt.xlabel('PCA component 1')
  plt.ylabel('CCS direction')
  plt.title("Dimensionality reduction with first PCA component\n and CCS direction")
  plt.legend()
  plt.show()
  

# The function takes (train_data, test_data, test_labels): fit on the training
# split, plot the held-out split together with its labels.
plot_data_distribution(train_representations, test_representations, test_representation_labels)

In [None]:
#histogram of ccs outputs

#trains a probe on the training data, and then generates the histogram on the test set
def create_prober_histogram(train_data, test_data, num_epochs = 1000, learning_rate = 0.01):
  """Plot a stacked histogram of a CCS probe's outputs on `test_data`.

  The probe is fitted on `train_data`. For every (example, prompt) pair of
  `test_data` the outputs for answer 1 and answer 0 are compared: the larger
  one goes into the 'Larger CCS output' series and the smaller one into the
  'Smaller CCS output' series. Pairs with equal outputs appear in neither.
  """
  state_dimension = train_data.size()[-1]
  loss_on_train_data = evaluate_first_variable(train_data, get_loss_function('CCS'))
  trained_probe = train(new_probe(state_dimension), loss_on_train_data, num_epochs, learning_rate)[0]

  # Probe outputs on the test set, separated by answer.
  outputs = trained_probe(test_data)
  yes_outputs = outputs[:, :, 1]
  no_outputs = outputs[:, :, 0]

  yes_is_larger = yes_outputs > no_outputs
  no_is_larger = no_outputs > yes_outputs

  larger_outputs = torch.cat([yes_outputs[yes_is_larger], no_outputs[no_is_larger]], dim=0)
  smaller_outputs = torch.cat([yes_outputs[no_is_larger], no_outputs[yes_is_larger]], dim=0)

  # Stacked histogram with 30 bins over the [0, 1] output range.
  histogram_series = [larger_outputs.detach().numpy(), smaller_outputs.detach().numpy()]
  plt.hist(histogram_series, bins=30, stacked=True, color=['blue', 'orange'], label=['Larger CCS output', 'Smaller CCS output'])
  plt.title('Histogram of CCS prober outputs')
  plt.xlabel('Value')
  plt.xlim([0,1])
  plt.ylabel('Frequency')
  plt.legend()
  plt.show()

create_prober_histogram(train_representations, test_representations)