# Setup (just run)

In [141]:
# Colab-specific setup

# !git clone https://github.com/AISC-Steering-LLMs/Steering-LLMs
# !pwd
# repo_path = '/content/repository/'


In [142]:
# Imports

%pip install kaleido

import os
import pandas as pd
import main
from omegaconf import DictConfig, OmegaConf
import yaml
from hydra import initialize
from hydra.core.global_hydra import GlobalHydra
from hydra import compose
import ipywidgets as widgets
from IPython.display import display
from ipywidgets import widgets, Layout, Box, VBox, Label

from data_handler import DataHandler
from data_analyser import DataAnalyzer
from model_handler import ModelHandler

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.cluster import FeatureAgglomeration

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Note: you may need to restart the kernel to use updated packages.


In [130]:
# Initialize Hydra for configuration management
GlobalHydra.instance().clear()  # Clear any previous Hydra instance
initialize(config_path=".", job_name="experiment")




The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1



hydra.initialize()

## Helper Functions

In [131]:
# Constants for styles
widget_layout = Layout(width='100%')
description_style = {'description_width': 'initial'}
BASE_PATH = '../data/inputs/' 


def load_yaml_config(file_path):
    """Load a YAML configuration file."""
    with open(file_path, 'r') as file:
        return yaml.safe_load(file)
    
def create_section_heading(title):
    """Create a section heading for the form."""
    return widgets.Label(value=title, layout=widgets.Layout(height='45px', align_items='center', justify_content='center'))
    
def create_form(config):
    """Create an interactive form for updating the configuration file."""
    form_items = []
    # Define the options for the dropdown
    model_options = ['gpt2-small', 'gpt2-medium', 'gpt2-large', 'gpt2-XL', 'llama']

    # Data Section
    form_items.append(create_section_heading('Data Settings'))
    base_path_display = widgets.Text(
        value=BASE_PATH,
        description='Input files located at:',
        disabled=True,
        layout=widget_layout,
        style=description_style
    )
    form_items.append(base_path_display)
    form_items.append(widgets.Text(
        value=config.get('prompts_sheet', ''), 
        description='prompts_sheet',
        placeholder='Enter the filename of the dataset including its extension (e.g., .csv or .xlsx).',
        layout=widget_layout,
        style=description_style
    ))
    
    # Model Section
    form_items.append(create_section_heading('Model Configuration'))
    form_items.append(widgets.Dropdown(
        options=model_options,
        value=config.get('model_name', 'gpt2-small'),
        description='model_name',
        tooltip='Select the language model you want to use.',
        layout=widget_layout,
        style=description_style
    ))

    # Execution Section
    form_items.append(create_section_heading('Execution Settings'))
    form_items.append(widgets.Checkbox(
        value=config.get('use_gpu', False),
        description='use_gpu',
        tooltip='Check this box to use GPU for computation if available.',
    ))
    form_items.append(widgets.Checkbox(
        value=config.get('write_cache', False),
        description='write_cache',
        tooltip='Check this box to write intermediate computations to disk.',
    ))
    form_items.append(widgets.Textarea(
        value=config.get('experiment_notes', ''),
        placeholder='Enter any notes for the experiment here.',
        description='experiment_notes',
        layout=widget_layout,
        style=description_style
    ))

    return widgets.VBox(form_items)
    


def update_config_and_save(btn, form):
    """Update the configuration file with values from the form."""
    updated_config = {}
    for widget in form.children:
        # Skip over any non-input widgets like Labels or disabled Text widgets
        if isinstance(widget, widgets.Label) or (isinstance(widget, widgets.Text) and widget.disabled):
            continue
        if isinstance(widget, widgets.Text) and widget.description == 'prompts_sheet':
            # Concatenate the base path with the provided filename
            updated_config[widget.description] = widget.value
        elif isinstance(widget, (widgets.Text, widgets.Textarea, widgets.Dropdown, widgets.Checkbox)):
            # Make sure to capture only the widgets that should contribute to the configuration
            updated_config[widget.description] = widget.value
    
    # Use the filtered updated_config to write to the yaml file
    with open('config.yaml', 'w') as file:
        yaml.safe_dump(updated_config, file)
    print("Configuration updated and saved.")



# Experiment Setup

## Load existing configuration and edit if needed

In [132]:
# Load configuration and create interactive form
config = load_yaml_config('config.yaml')
form = create_form(config)
display(form)

# Create a button to save the configuration, pass the form to the event handler
save_button = widgets.Button(
        description="Save Configuration",
        button_style='success',  # Use 'success' styling for a green color
        tooltip='Click to save the configuration',
        layout=widgets.Layout(width='auto', margin='10px 0'))
save_button.on_click(lambda btn: update_config_and_save(btn, form))
display(save_button)

VBox(children=(Label(value='Data Settings', layout=Layout(align_items='center', height='45px', justify_content…

Button(button_style='success', description='Save Configuration', layout=Layout(margin='10px 0', width='auto'),…

Configuration updated and saved.
Configuration updated and saved.


In [139]:
# Compose the final configuration from Hydra
cfg = compose(config_name="config")

# Instantiate classes DataHandler and ModelHandler
data_handler = DataHandler("../data")
model_handler = ModelHandler(cfg)

# Load inputs and create output directories
prompts_dict = data_handler.csv_to_dictionary(cfg.prompts_sheet)
experiment_base_dir, images_dir, metrics_dir = data_handler.create_output_directories()

# Save configurations and prompts
data_handler.write_experiment_parameters(cfg, prompts_dict, experiment_base_dir)


Loaded pretrained model gpt2-small into HookedTransformer


## Model Initialization and Data Processing

In [140]:
# Initialize the model and populate the data
activations_cache = data_handler.populate_data(prompts_dict)

# Compute activations and add hidden states
model_handler.compute_activations(activations_cache)

Computing activations:   0%|          | 0/160 [00:00<?, ?it/s]

Computing activations: 100%|██████████| 160/160 [00:12<00:00, 12.88it/s]


## Visualization

In [137]:
data_analyzer = DataAnalyzer(images_dir, metrics_dir, 42)

# Create checkboxes for each visualization type
tsne_checkbox = widgets.Checkbox(value=False, description='t-SNE Plot')
pca_checkbox = widgets.Checkbox(value=False, description='PCA Plot')
fa_checkbox = widgets.Checkbox(value=False, description='Feature Agglomeration')
raster_checkbox = widgets.Checkbox(value=False, description='Raster Plot')
random_proj_checkbox = widgets.Checkbox(value=False, description='Random Projections Analysis')
probe_hidden_states_checkbox = widgets.Checkbox(value=False, description='Probe Hidden States')
classifier_battery_checkbox = widgets.Checkbox(value=False, description='Classifier Battery (needs t-SNE)')

# Event handler for the button click
def on_run_button_clicked(b):
    # Clear previous output
    # clear_output(wait=True)
    
    if tsne_checkbox.value:
        # Run t-SNE plot
        tsne_model = TSNE(n_components=2, random_state=42)
        tsne_embedded_data_dict, tsne_labels, tsne_prompts = data_analyzer.plot_embeddings(activations_cache, tsne_model)
    
    if pca_checkbox.value:
        # Run PCA plot
        pca_model = PCA(n_components=2, random_state=42)
        pca_embedded_data_dict, pca_labels, pca_prompts = data_analyzer.plot_embeddings(activations_cache, pca_model)
    
    if fa_checkbox.value:
        # Run Feature Agglomeration
        fa_model = FeatureAgglomeration(n_clusters=2)
        fa_embedded_data_dict, fa_labels, fa_prompts = data_analyzer.plot_embeddings(activations_cache, fa_model)
    
    if raster_checkbox.value:
        # Run Raster plot
        data_analyzer.raster_plot(activations_cache)
    
    if random_proj_checkbox.value:
        # Run Random Projections Analysis
        data_analyzer.random_projections_analysis(activations_cache)
    
    if probe_hidden_states_checkbox.value:
        # Run Probe Hidden States
        data_analyzer.probe_hidden_states(activations_cache)
    
    if classifier_battery_checkbox.value:
        # Run Classifier Battery
        data_analyzer.classifier_battery(tsne_embedded_data_dict, tsne_labels, tsne_prompts, 0.2)
    
    print("Selected visualizations have been executed.")

In [None]:
# Create a button to run the selected visualizations
run_button = widgets.Button(description="Run Visualizations")
# Assign the event handler to the button
run_button.on_click(on_run_button_clicked)

# Display the checkboxes and the button
checkboxes = widgets.VBox([tsne_checkbox, pca_checkbox, fa_checkbox, raster_checkbox,
                           random_proj_checkbox, probe_hidden_states_checkbox, classifier_battery_checkbox, run_button])
display(checkboxes)

VBox(children=(Checkbox(value=False, description='t-SNE Plot'), Checkbox(value=False, description='PCA Plot'),…

TSNE: 100%|██████████| 12/12 [00:06<00:00,  1.75it/s]
PCA: 100%|██████████| 12/12 [00:02<00:00,  4.41it/s]
FeatureAgglomeration: 100%|██████████| 12/12 [00:02<00:00,  5.74it/s]
Computing logistic_regression:   0%|          | 0/12 [00:00<?, ?it/s]


ValueError: 
Image export using the "kaleido" engine requires the kaleido package,
which can be installed using pip:
    $ pip install -U kaleido


In [9]:
data_analyzer = DataAnalyzer(images_dir, metrics_dir, 42)

# Get various representations for each layer
# and plot them
tsne_model = TSNE(n_components=2, random_state=42)
tsne_embedded_data_dict, tsne_labels, tsne_prompts = data_analyzer.plot_embeddings(activations_cache, tsne_model)
pca_model = PCA(n_components=2, random_state=42)
pca_embedded_data_dict, pca_labels, pca_prompts = data_analyzer.plot_embeddings(activations_cache, pca_model)
fa_model = FeatureAgglomeration(n_clusters=2)
fa_embedded_data_dict, fa_labels, fa_prompts = data_analyzer.plot_embeddings(activations_cache, fa_model)

# Further analysis
data_analyzer.raster_plot(activations_cache)
data_analyzer.random_projections_analysis(activations_cache)
data_analyzer.probe_hidden_states(activations_cache)

# See if the representations can be used to classify the ethical area
# Why are we actually doing this? Hypothesis - better seperation of ethical areas
# Leads to better steering vectors. This actually needs to be tested.
# Only done with the t-SNE representation but could be done with others (PCA, heirarchical clustering, etc.)
data_analyzer.classifier_battery(tsne_embedded_data_dict, tsne_labels, tsne_prompts, 0.2)

TSNE: 100%|██████████| 12/12 [00:07<00:00,  1.70it/s]
PCA: 100%|██████████| 12/12 [00:02<00:00,  4.41it/s]
FeatureAgglomeration: 100%|██████████| 12/12 [00:02<00:00,  5.50it/s]
Computing Raster Plots: 100%|██████████| 12/12 [01:22<00:00,  6.85s/it]
Random projections analysis: 100%|██████████| 12/12 [00:00<00:00, 1242.39it/s]
Probing hidden states: 100%|██████████| 12/12 [00:02<00:00,  4.01it/s]
Computing logistic_regression:   0%|          | 0/12 [00:00<?, ?it/s]


ValueError: 
Image export using the "kaleido" engine requires the kaleido package,
which can be installed using pip:
    $ pip install -U kaleido


## Save Results

In [13]:
# Save the activations cache if required by the configuration
if cfg.write_cache:
    model_handler.save_activations_cache(activations_cache, experiment_base_dir)