This is a self-contained notebook for training Eleuther's Pythia 70m on https://huggingface.co/datasets/iamtarun/code_instructions_120k_alpaca using IA3 acting on a single position.

In [99]:
# prompt: connect to my drive

from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [100]:
save_path = '/content/drive/MyDrive/Research/Stadie/IA3_Results'

In [101]:
!uv venv
!source .venv/bin/activate

Using CPython 3.11.12 interpreter at: [36m/usr/bin/python3[39m
Creating virtual environment at: [36m.venv[39m
Activate with: [32msource .venv/bin/activate[39m


In [102]:
!uv pip install datasets trl randomname

[2mUsing Python 3.11.12 environment at: /usr[0m
[2mAudited [1m3 packages[0m [2min 86ms[0m[0m


In [103]:
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import IA3Config, get_peft_model
import randomname



In [104]:
import wandb
import torch
import numpy as np
from transformers import TrainerCallback

class PEFTParameterHistogramCallback(TrainerCallback):
    """
    A callback that tracks actual parameter values for parameter-efficient fine-tuning methods.
    Only tracks parameters that have requires_grad=True.
    """

    def __init__(self, peft_param_prefix=None, max_individual_values=50):
        """
        Args:
            peft_param_prefix (list, optional): List of parameter name prefixes to track.
                If None, will try to auto-detect PEFT parameters using common prefixes.
            max_individual_values (int): Maximum number of individual values to track per parameter.
        """
        self.peft_param_prefix = peft_param_prefix or ["lora", "adapter", "prefix", "prompt", "ia3"]
        self.max_individual_values = max_individual_values

    def _is_peft_param(self, param_name):
        """Check if parameter is a PEFT parameter based on naming."""
        return any(prefix in param_name.lower() for prefix in self.peft_param_prefix)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """Log actual parameter values when logging occurs."""
        if not model or not state.is_world_process_zero or not wandb.run:
            return

        # Dictionary to store parameter values
        param_values = {}

        # Collect all trainable PEFT parameter values
        for name, param in model.named_parameters():
            # Only track parameters that:
            # 1. Are PEFT parameters (based on name)
            # 2. Have requires_grad=True (are being trained)
            if self._is_peft_param(name) and param.requires_grad:
                # Get the actual parameter values as numpy array
                param_data = param.data.detach().cpu().numpy().flatten()

                # Store param shape information
                original_shape = list(param.shape)
                param_values[f"peft_params/{name}/shape"] = str(original_shape)

                # Use histogram instead of raw values list
                param_values[f"peft_params/{name}/histogram"] = wandb.Histogram(param_data)

                # For small parameters, also log individual values for better tracking
                if len(param_data) <= self.max_individual_values:
                    for i, value in enumerate(param_data):
                        param_values[f"peft_params/{name}/value_{i}"] = float(value)

                # Log summary statistics as well
                param_values[f"peft_params/{name}/mean"] = float(np.mean(param_data))
                param_values[f"peft_params/{name}/std"] = float(np.std(param_data))
                param_values[f"peft_params/{name}/min"] = float(np.min(param_data))
                param_values[f"peft_params/{name}/max"] = float(np.max(param_data))

                # Log gradient information if available
                if param.grad is not None:
                    grad_data = param.grad.detach().cpu().numpy().flatten()
                    param_values[f"peft_grads/{name}/histogram"] = wandb.Histogram(grad_data)

                    # Log individual gradient values for small parameters
                    if len(grad_data) <= self.max_individual_values:
                        for i, value in enumerate(grad_data):
                            param_values[f"peft_grads/{name}/value_{i}"] = float(value)

                    # Log gradient summary statistics
                    param_values[f"peft_grads/{name}/mean"] = float(np.mean(grad_data))
                    param_values[f"peft_grads/{name}/std"] = float(np.std(grad_data))

        # Log the parameter values to wandb
        wandb.log(param_values, step=state.global_step)

        # Log the total number of trainable parameters
        if state.global_step == 0 or not hasattr(self, 'logged_trainable_params'):
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            wandb.log({"trainable_parameters": trainable_params}, step=state.global_step)
            self.logged_trainable_params = True

In [105]:
import wandb
import torch
import numpy as np
from transformers import TrainerCallback

class PEFTParameterTrackingCallback(TrainerCallback):
    """
    A callback that tracks individual parameter values over time for PEFT methods.
    Optimized for tracking small parameter vectors where each element is important.
    """

    def __init__(self, peft_param_prefix=None):
        """
        Args:
            peft_param_prefix (list, optional): List of parameter name prefixes to track.
                If None, will use common PEFT parameter prefixes.
        """
        self.peft_param_prefix = peft_param_prefix or ["lora", "adapter", "prefix", "prompt", "ia3"]
        # Keep track of parameters we've seen to maintain consistent tracking
        self.tracked_params = {}

    def _is_peft_param(self, param_name):
        """Check if parameter is a PEFT parameter based on naming."""
        return any(prefix in param_name.lower() for prefix in self.peft_param_prefix)

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Initialize parameter tracking at the beginning of training."""
        if not model or not state.is_world_process_zero:
            return

        # Find trainable PEFT parameters
        for name, param in model.named_parameters():
            if self._is_peft_param(name) and param.requires_grad:
                param_data = param.data.detach().cpu().numpy().flatten()

                # Store initial parameter data for reference
                self.tracked_params[name] = {
                    'shape': param.shape,
                    'size': param.numel(),
                    'indices': list(range(len(param_data)))
                }

                # Log initial parameter values
                param_dict = {}
                for i, value in enumerate(param_data):
                    param_dict[f"param/{name}/{i}"] = float(value)

                wandb.log(param_dict, step=0)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """Log individual parameter values at logging steps."""
        if not model or not state.is_world_process_zero or not wandb.run:
            return

        # Dictionary to log current parameter values
        param_dict = {}

        # Track trainable PEFT parameters that we identified at the beginning
        for name, param in model.named_parameters():
            if name in self.tracked_params and param.requires_grad:
                param_data = param.data.detach().cpu().numpy().flatten()

                # Log each individual parameter value
                for i, value in enumerate(param_data):
                    param_dict[f"param/{name}/{i}"] = float(value)

                # If parameter has a gradient, track that too
                if param.grad is not None:
                    grad_data = param.grad.detach().cpu().numpy().flatten()
                    for i, value in enumerate(grad_data):
                        param_dict[f"grad/{name}/{i}"] = float(value)

        # Log all the values
        wandb.log(param_dict, step=state.global_step)

In [106]:

dataset = load_dataset("iamtarun/code_instructions_120k_alpaca", split="train")


In [107]:
def preprocess_function(example):
  example['prompt'] = example['instruction'] + "\ninput:\n" + example['input']
  example['completion'] = example['output']
  return example

In [108]:
# make the dataset a prompt-completion dataset https://huggingface.co/docs/trl/en/dataset_formats
dataset = dataset.map(preprocess_function)

In [109]:
dataset = dataset.select_columns(['prompt', 'completion'])

In [110]:
model_name = "EleutherAI/pythia-70m-deduped"

In [111]:
dataset[0]

{'prompt': 'Create a function to calculate the sum of a sequence of integers.\ninput:\n[1, 2, 3, 4, 5]',
 'completion': '# Python code\ndef sum_sequence(sequence):\n  sum = 0\n  for num in sequence:\n    sum += num\n  return sum'}

In [112]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [113]:
model

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 512)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-5): 6 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXAttention(
          (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
          (dense): Linear(in_features=512, out_features=512, bias=True)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
          (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
          (act): GELUActivation()
        )
      )
    )
    (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise

In [114]:
from peft import IA3Config, get_peft_model

# Create IA3 configuration with precise targeting for ONLY the 4th MLP layer (index 3)
config = IA3Config(
    task_type="CAUSAL_LM",
    # Use a list of exact module names to target
    target_modules=["gpt_neox.layers.3.mlp.dense_h_to_4h"],
    # Similarly exact name for feedforward module
    feedforward_modules=["gpt_neox.layers.3.mlp.dense_h_to_4h"],
    init_ia3_weights=True
)

# Apply the configuration to your model
peft_model = get_peft_model(model, config)

In [115]:
# Print all modules with IA3 adapters
for name, module in peft_model.named_modules():
    if "ia3_" in name:
          print(name)

base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l


In [116]:
# Print all modules with IA3 adapters
for name, module in peft_model.named_parameters():
    if "ia3_" in name:
          print(name)

base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default


In [117]:
tracking_callback = PEFTParameterTrackingCallback()
histogram_callback = PEFTParameterHistogramCallback()

In [118]:
wandb.init(project="IA3_visualization")


0,1
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/0,▂▂▂▂▂▁▁▂▄▄▄▄▄▅▅▅▅▆▆▆████████████████████
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/1,█▇▇▇▇▇▆▆▆▆▅▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/10,▄▅▅▆▅▇████▇▇▇▇▆▄▄▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/100,▅▅▅▅▅█▇▆▆▆▅▄▄▄▃▄▄▄▄▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▂
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/101,▆▆▇▆▇▄▂▅▃▃█▇▆▅▅▆▆▆▅▂▁▁▁▁▁▁▁▄▄▄▅▅▅▅▅▅▅▅▅▅
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/102,▇▇████▇▆▄▂▁▁▁▁▂▂▂▃▃▃▃▃▄▄▄▄▄▃▃▂▃▃▃▃▃▄▄▄▄▄
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/103,█▇▇▅▅▅█▇▆▆▅▅▅▅▅▆▆▆▆▆▅▄▃▂▁▁▁▁▁▂▁▁▂▁▁▁▁▁▁▁
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/104,▇▇▇▇▇▇██▆▄▂▂▁▁▁▂▂▂▂▄▅▅▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▅
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/105,▄▅▆▆▆▆▃▁▁▂▂▂▂▂▂▃▃▃▃▃▅▆▆▆▇▇▇██▇▇▇▇▇▇█████
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/106,█▇▆▇▆▁▂▅███▇▆▆▅▃▃▂▂▂▄▄▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅

0,1
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/0,2.50564
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/1,-1.63776
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/10,0.29747
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/100,0.39007
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/101,0.98364
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/102,0.48715
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/103,-0.04163
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/104,0.75127
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/105,1.50106
param/base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default/106,0.60538


In [119]:
import time
import randomname

wandb.finish()
# Get current timestamp in the desired format
timestamp = time.strftime("%Y-%m-%d_%H:%M:%S")

# Combine random name with timestamp
run_name = randomname.get_name() + "_" + timestamp


In [120]:
lets_overfit:bool = False
if lets_overfit:
  batch_size=64
  small_dataset = dataset.select(range(batch_size))
train_dataset = small_dataset if lets_overfit else dataset

In [121]:
# prompt: print the shapes and names of all peft_model parameters which require gradients

for name, param in peft_model.named_parameters():
    if param.requires_grad:
        print(f"Name: {name}, Shape: {param.shape}")


Name: base_model.model.gpt_neox.layers.3.mlp.dense_h_to_4h.ia3_l.default, Shape: torch.Size([1, 512])


In [122]:

training_args = SFTConfig(
    max_length=512,
    output_dir=save_path + "/" + run_name,
    run_name=run_name,
    per_device_train_batch_size=64,
    logging_steps=50,
    learning_rate=5e-3,
    max_steps=5000

)
trainer = SFTTrainer(
    peft_model,
    train_dataset=train_dataset,
    args=training_args,
    callbacks=[tracking_callback, histogram_callback]
)
trainer.train()

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss
50,11.3614


Step,Training Loss
50,11.3614
100,11.4212
150,11.5191
200,11.5104
250,11.3036
300,11.4231
350,11.6226
400,11.6307
450,11.1859
500,10.9697


TrainOutput(global_step=5000, training_loss=6.940906524658203, metrics={'train_runtime': 1972.5611, 'train_samples_per_second': 162.226, 'train_steps_per_second': 2.535, 'total_flos': 4.269295697315328e+16, 'train_loss': 6.940906524658203})

In [123]:
dataset

Dataset({
    features: ['prompt', 'completion'],
    num_rows: 121959
})