<a href="https://colab.research.google.com/github/keenanpepper/self-ablating-transformers/blob/main/Self_Ablating_Transformer_on_TinyStories.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets
!pip install tiktoken

Collecting datasets
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-2.21.0-py3-none-any.whl (527 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m527.3/527.3 kB[0m [31m35.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)
[

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import datasets
from datasets import load_dataset
import tiktoken
import os
import numpy as np
from tqdm.notebook import tqdm
import math
import time

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class GPTNeoConfig:
    def __init__(self, vocab_size=50257, hidden_size=64, mlp_hidden_size=None, num_layers=8,
                 num_heads=16, max_position_embeddings=2048, window_size=256, attention_layers=None,
                 loss_coeff_base=1.0, loss_coeff_ablated=0.1,
                 loss_coeff_attention_density=0.1, loss_coeff_neuron_density=0.1):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size  # refers to residual stream width. not number of neurons in MLP, that is 4x more
        if mlp_hidden_size == None:
            self.mlp_hidden_size = 4 * self.hidden_size
        else:
            self.mlp_hidden_size = mlp_hidden_size
        self.num_layers = num_layers    # number of complete transformer blocks
        self.num_heads = num_heads
        self.max_position_embeddings = max_position_embeddings
        self.window_size = window_size  # only used in local attention layers - if all attention is global then this is unused
        if attention_layers == None:
#            self.attention_layers = ["global", "local"] * (num_layers // 2)
            self.attention_layers = ["global"] * num_layers
        else:
            self.attention_layers = attention_layers
        # The following are the 4 coefficients for the 4 different components of the loss function
        # * L_base: cross-entropy loss for unablated first forward pass
        # * L_ablated: cross-entropy loss for ablated second forward pass
        # * L_attention_density: density penalty for attention ablation mask
        # * L_neuron_density: density pentaly for neuron ablation mask
        self.loss_coeff_base = loss_coeff_base
        self.loss_coeff_ablated = loss_coeff_ablated
        self.loss_coeff_attention_density = loss_coeff_attention_density
        self.loss_coeff_neuron_density = loss_coeff_neuron_density

In [4]:
class Attention(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.is_local = (config.attention_layers[layer_id] == "local")
        self.num_heads = config.num_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // config.num_heads

        self.attention = nn.ModuleDict(dict(
            k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False),
            v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False),
            q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False),
            out_proj = nn.Linear(config.hidden_size, config.hidden_size)
        ))

    def forward(self, x_ablated, x_clean, ablation_mask=None):
        """
        Perform multi-head attention with a causal mask and also (optionally) an ablation mask

        x_ablated is the residual stream that's actually undergoing this computation for the first time.

        x_clean (which should have the same shape) is what x was at this same layer on the first, clean,
        forward pass thru the model. It is needed because when one sequence position of x_ablated is
        attending to other sequence positions, it needs to be attending to the clean versions of those
        (since the versions that are there in x_ablated are actually using DIFFERENT ablation masks,
        and they shouldn't get mixed up!).

        x_ablated and x_clean both have shape (batch_size, seq_len, hidden_size)

        ablation_mask is a per-sequence-position ablation mask to apply to the attention outputs.
        For example, let's say the ablation_mask at position 1 is [0,0,0,0,1,1,1,1,1,1,1,1...]
        and the ablation_mask at position 2 is [1,1,1,1,0,0,0,0,1,1,1,1,1...]
        (for this example let's say the head dimension is 4)
        That means the first attention head will be fully ablated at position 1, and the second
        attention head will be fully ablated at position 2.
        But if position 2 is attending back to position 1 then it will see the residual stream values from
        the non-ablated first pass, i.e. from x_clean.

        NOTE for later (TODO?): Currently this code is going to do the easiest, simplest thing, which means
        that for a token attending to *itself*, it still gets the clean version of it, even tho logically
        we might want it to get the ablated version since it's the same sequence position.
        """
        batch_size, seq_len, _ = x_ablated.shape

        assert x_clean.shape == x_ablated.shape
        assert x_clean.device == x_ablated.device

        q = self.attention.q_proj(x_ablated).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.attention.k_proj(x_clean).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.attention.v_proj(x_clean).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Intentionally don't divide by sqrt(self.num_heads) here, since the reference implementation
        # for which the pretrained weights were trained doesn't have that
        scores = torch.matmul(q, k.transpose(-1, -2))

        if self.is_local:
            # Local attention
            local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x_clean.device)
            local_mask = torch.triu(local_mask, diagonal=1) | torch.tril(local_mask, diagonal=-config.window_size)
            scores = scores.masked_fill(local_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
        else:
            # Global attention
            causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x_clean.device), diagonal=1)
            scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))

        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, v)

        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)

        # Apply ablation to this "context" value which is the concatenated results of the num_heads heads,
        # right *before* applying out_proj which mixes them all together into the full resisual stream.
        # This is the latest point at which ablation can be applied such that it's possible to, e.g.
        # fully ablate a single head while leaving the rest untouched.
        if ablation_mask is not None:
            assert context.shape == ablation_mask.shape, f"context has shape {context.shape} while ablation mask has shape {ablation_mask.shape}"
            context = context * ablation_mask

        return self.attention.out_proj(context)

In [5]:
class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))

In [6]:
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.hidden_size, config.mlp_hidden_size)
        self.c_proj = nn.Linear(config.mlp_hidden_size, config.hidden_size)
#        self.act = nn.GELU()
        self.act = NewGELUActivation()

    def forward(self, x, ablation_mask=None):
        """
        Now, here we have a choice of whether to apply the ablation_mask either before or after the activation.
        If we were using ReLU it actually wouldn't matter, because ReLU has the property that
        ReLU(c*x) = c*ReLU(x)
        for all positive reals c.
        But since we're using GeLU here it does matter a bit.
        I'm not sure which makes more sense... for now let's just pick "after the GeLU" and see what happens.

        Note that unlike the Attention layer, this layer does not need access to the clean activations because
        it does not mix information between different sequence positons. The Attention layer is the only
        layer that does that.
        """
        activations = self.act(self.c_fc(x))

        if ablation_mask is not None:
            activations = activations * ablation_mask

        return self.c_proj(activations)

In [7]:
class GPTNeoBlock(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.hidden_size, eps=1e-5)
        self.attn = Attention(config, layer_id)
        self.ln_2 = nn.LayerNorm(config.hidden_size, eps=1e-5)
        self.mlp = MLP(config)

    def forward(self, x_ablated, x_clean, attention_ablations=None, neuron_ablations=None):
        x_ablated = x_ablated + self.attn(self.ln_1(x_ablated), self.ln_2(x_clean), attention_ablations)
        x_ablated = x_ablated + self.mlp(self.ln_2(x_ablated), neuron_ablations)
        return x_ablated

In [8]:
class GPTNeo(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.hidden_size),
            wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size),
            h = nn.ModuleList([GPTNeoBlock(config, i) for i in range(config.num_layers)]),
            ln_f = nn.LayerNorm(config.hidden_size, eps=1e-5)
        ))
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.attention_ablations_head = nn.Linear(config.hidden_size, math.prod(self.get_attention_ablations_shape(1, 1)))
        self.neuron_ablations_head = nn.Linear(config.hidden_size, math.prod(self.get_neuron_ablations_shape(1, 1)))

        # tie weights
        self.transformer.wte.weight = self.lm_head.weight

    def get_attention_ablations_shape(self, batch_size, block_size):
        return torch.Size([batch_size, block_size, self.config.num_layers, self.config.hidden_size])

    def get_neuron_ablations_shape(self, batch_size, block_size):
        return torch.Size([batch_size, block_size, self.config.num_layers, self.config.mlp_hidden_size])

    def forward(self, input_ids, targets=None, attention_ablations=None, neuron_ablations=None):
        """
        If "targets" is supplied, also outputs loss (TODO is this base loss or complete self-ablating loss?)

        input_ids should have shape (batch_size, block_size)

        attention_ablations should have shape (batch_size, block_size, num_layers, hidden_size)

        neuron_ablations should have shape (batch_size, block_size, num_layers, mlp_hidden_size)
        """
        second_pass = attention_ablations != None or neuron_ablations != None

        device = input_ids.device
        b, t = input_ids.shape
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)

        tok_emb = self.transformer.wte(input_ids)
        pos_emb = self.transformer.wpe(pos)

        x_clean = tok_emb + pos_emb
        if second_pass:
            x_ablated = x_clean

        for i, block in enumerate(self.transformer.h):
            x_clean = block(x_clean, x_clean)
            if second_pass:
                x_ablated = block(x_ablated, x_clean, attention_ablations[:,:,i,:], neuron_ablations[:,:,i,:])

        x = x_ablated if second_pass else x_clean

        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        L_base = None
        if targets is not None:
            # Reshape logits to (batch_size * sequence_length, vocab_size)
            logits_view = logits.view(-1, logits.size(-1))
            # Reshape targets to (batch_size * sequence_length,)
            targets_view = targets.view(-1)
            # Compute cross entropy loss
            L_base = F.cross_entropy(logits_view, targets_view)

        if second_pass:
            return {"logits": logits,
                    "L_base": L_base}
        else:
            # if it's the first pass we need to compute these, then do the second pass
            output_attention_ablations = self.attention_ablations_head(x)
            output_attention_ablations = torch.sigmoid(output_attention_ablations)
            output_attention_ablations = output_attention_ablations.reshape(self.get_attention_ablations_shape(b, t))
            output_neuron_ablations = self.neuron_ablations_head(x)
            output_neuron_ablations = torch.sigmoid(output_neuron_ablations)
            output_neuron_ablations = output_neuron_ablations.reshape(self.get_neuron_ablations_shape(b, t))
            second_pass_output = self.forward(input_ids,
                                              targets,
                                              output_attention_ablations,
                                              output_neuron_ablations)
            L_total = L_ablated = L_attention_density = L_neuron_density = None
            if targets is not None:
                L_ablated = second_pass_output["L_base"]
                L_attention_density = output_attention_ablations.mean()
                L_neuron_density = output_neuron_ablations.mean()
                L_total = sum([self.config.loss_coeff_base * L_base,
                               self.config.loss_coeff_ablated * L_ablated,
                               self.config.loss_coeff_attention_density * L_attention_density,
                               self.config.loss_coeff_neuron_density * L_neuron_density])
            return {"logits": logits,
                    "L_base": L_base,
                    "L_ablated": L_ablated,
                    "loss": L_total,
                    "attention_ablations": output_attention_ablations,
                    "neuron_ablations": output_neuron_ablations,
                    "attention_ablation_mask_density": L_attention_density,
                    "neuron_ablation_mask_density": L_neuron_density}

    def generate(self, input_ids, max_new_tokens, temperature=1.0):
        self.eval()
        device = next(self.parameters()).device

        # Ensure input_ids is a tensor on the correct device
        x = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0) if isinstance(input_ids, list) else input_ids.to(device)

        for _ in range(max_new_tokens):
            # Crop input if it's getting too long
            x_crop = x[:, -self.config.max_position_embeddings:]

            # Forward pass
            logits = self(x_crop)["logits"]

            # Focus on the last token's predictions
            logits = logits[:, -1, :] / temperature

            # Apply softmax to convert logits to probabilities
            probs = F.softmax(logits, dim=-1)

            # Sample from the distribution
            next_token = torch.multinomial(probs, num_samples=1)

            # Append the new token to the sequence
            x = torch.cat((x, next_token), dim=1)

        return x[0].tolist()  # Convert to list of token IDs

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

ts1m = AutoModelForCausalLM.from_pretrained('roneneldan/TinyStories-1M')

ts3m = AutoModelForCausalLM.from_pretrained('roneneldan/TinyStories-3M')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
config = GPTNeoConfig(hidden_size=128)
model = GPTNeo(config)
model.load_state_dict(ts3m.state_dict())

RuntimeError: Error(s) in loading state_dict for GPTNeo:
	Missing key(s) in state_dict: "attention_ablations_head.weight", "attention_ablations_head.bias", "neuron_ablations_head.weight", "neuron_ablations_head.bias". 

In [None]:
# Forward pass without targets
input_ids = torch.randint(0, config.vocab_size, (1, 100))
ret = model(input_ids)
print("Logits shape:", ret["logits"].shape)  # Should be (1, 100, 50257)
print("Loss:", ret["loss"])  # Should be None
print([(key, value.shape if value != None else None) for key, value in ret.items()])

# Forward pass with targets
targets = torch.randint(0, config.vocab_size, (1, 100))
ret = model(input_ids, targets)
print("Logits shape:", ret["logits"].shape)  # Should be (1, 100, 50257)
print("Loss:", ret["loss"])  # Should be a tensor with a single value
print([(key, value.shape if value != None else None) for key, value in ret.items()])

Logits shape: torch.Size([1, 100, 50257])
Loss: None
[('logits', torch.Size([1, 100, 50257])), ('L_base', None), ('L_ablated', None), ('loss', None), ('attention_ablations', torch.Size([1, 100, 8, 128])), ('neuron_ablations', torch.Size([1, 100, 8, 512])), ('attention_ablation_mask_density', None), ('neuron_ablation_mask_density', None)]
Logits shape: torch.Size([1, 100, 50257])
Loss: tensor(17.7614, grad_fn=<AddBackward0>)
[('logits', torch.Size([1, 100, 50257])), ('L_base', torch.Size([])), ('L_ablated', torch.Size([])), ('loss', torch.Size([])), ('attention_ablations', torch.Size([1, 100, 8, 128])), ('neuron_ablations', torch.Size([1, 100, 8, 512])), ('attention_ablation_mask_density', torch.Size([])), ('neuron_ablation_mask_density', torch.Size([]))]


In [9]:
ds = datasets.load_dataset("tdooms/TinyStories")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [10]:
def prepare_data(dataset_name="tdooms/TinyStories", split="train", separator_token="<|endoftext|>", output_file="train.bin"):
    # Initialize tokenizer
    enc = tiktoken.get_encoding("gpt2")
    separator_token_id = enc.encode_single_token(separator_token)

    def process(example):
        ids = enc.encode_ordinary(example['text'])  # encode_ordinary ignores any special tokens
        ids.append(separator_token_id)  # Add separator token at the end of each example
        return {'ids': ids, 'len': len(ids)}

    # Load and process the dataset
    if not os.path.exists(output_file):
        print(f"Processing {dataset_name} dataset...")
        ds = load_dataset(dataset_name, split=split)

        tokenized = ds.map(
            process,
            remove_columns=['text'],
            desc="Tokenizing the dataset",
            num_proc=8,
        )

        # Concatenate all ids into one large file
        arr_len = np.sum(tokenized['len'], dtype=np.uint64)
        dtype = np.uint16  # Can use uint16 since enc.max_token_value == 50256 is < 2**16
        arr = np.memmap(output_file, dtype=dtype, mode='w+', shape=(arr_len,))

        total_batches = 1024
        idx = 0

        for batch_idx in tqdm(range(total_batches), desc=f'Writing {output_file}'):
            # Batch together samples for faster write
            batch = tokenized.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
            arr_batch = np.concatenate(batch['ids'])
            # Write into mmap
            arr[idx : idx + len(arr_batch)] = arr_batch
            idx += len(arr_batch)

        arr.flush()
        print(f"Dataset processed and saved to {output_file}")
    else:
        print(f"{output_file} already exists. Skipping processing.")

    # Load the processed data
    data = np.memmap(output_file, dtype=np.uint16, mode='r')
    return data, enc

In [11]:
data, tokenizer = prepare_data(separator_token="<|endoftext|>")
print(f"Processed data shape: {data.shape}")
print(f"First 10 tokens: {data[:10]}")
print(f"Last 10 tokens: {data[-10:]}")

# Decode a small portion to verify
sample = data[:10000]
decoded = tokenizer.decode(sample.tolist())
print("\nSample decoded text:")
print(decoded[:5000] + "...")  # Print first 500 characters

train.bin already exists. Skipping processing.
Processed data shape: (396967656,)
First 10 tokens: [ 8888    11 19919   373   845  6568    13   679   373  1016]
Last 10 tokens: [  262 25103 12023   290  8359   257 12625  8073    13 50256]

Sample decoded text:
Today, Tommy was very excited. He was going flying with his mom and dad to a special place. There was a big green flag waving in the wind when they arrived. Tommy was so happy to see it. 

He hopped out of the car and ran up to the flag. He couldn't believe how big it was! He wanted to reach out and touch it. His mom said he could and Tommy smiled.

He waved his little arms, trying to make the flag move. But the wind was too strong. His dad said, "Let me help you, Tommy." He reached out and grabbed the green flag and waved it back and forth. 

Now it was Tommy's turn! He clapped his hands and laughed as he tried to move the flag. He waved it for a long time until he was too tired to do it anymore. The big green flag was so exciti

In [12]:
prepare_data(split="validation", output_file="validation.bin")

validation.bin already exists. Skipping processing.


(memmap([ 7454,  2402,   257, ..., 20567,    13, 50256], dtype=uint16),
 <Encoding 'gpt2'>)

In [13]:
class TrainingConfig:
    train_file = "train.bin"
    val_file = "validation.bin"
    block_size = 256
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_batches = 10000
    batch_size = 64
    learning_rate = 4e-3
    weight_decay = 0.0
    max_grad_norm = 1.0
    save_path = "best_model.pt"
    eval_iters = 100
    log_interval = 1000

In [14]:
class BatchGenerator:
    def __init__(self, data_file, block_size, batch_size, device):
        self.data_file = data_file
        self.block_size = block_size
        self.batch_size = batch_size
        self.device = device
        self.device_type = 'cuda' if 'cuda' in device.type else 'cpu'

    def get_batch(self, shifted=True):
        # We recreate np.memmap every batch to avoid a memory leak, as per
        # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
        data = np.memmap(self.data_file, dtype=np.uint16, mode='r')

        # Generate random starting indices
        ix = torch.randint(len(data) - self.block_size, (self.batch_size,))

        shift = 1 if shifted else 0
        # Create input and target tensors
        x = torch.stack([torch.from_numpy((data[i:i+self.block_size]).astype(np.int64)) for i in ix])
        y = torch.stack([torch.from_numpy((data[i+shift:i+shift+self.block_size]).astype(np.int64)) for i in ix])

        if self.device_type == 'cuda':
            # Pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
            x, y = x.pin_memory().to(self.device, non_blocking=True), y.pin_memory().to(self.device, non_blocking=True)
        else:
            x, y = x.to(self.device), y.to(self.device)

        return x, y

In [15]:
class LossEstimator:
    def __init__(self, model, train_batch_gen, val_batch_gen, eval_iters):
        self.model = model
        self.train_batch_gen = train_batch_gen
        self.val_batch_gen = val_batch_gen
        self.eval_iters = eval_iters

    def estimate_loss(self):
        stat_names = ["loss", "L_base", "L_ablated", "attention_ablation_mask_density", "neuron_ablation_mask_density"]
        out = {}
        for stat in stat_names:
            out[stat] = {}
        self.model.eval()
        with torch.inference_mode():
            for split, batch_gen in [('train', self.train_batch_gen), ('val', self.val_batch_gen)]:
                stats = {}
                for stat in stat_names:
                    stats[stat] = torch.zeros(self.eval_iters)
                for k in tqdm(range(self.eval_iters)):
                    X, Y = batch_gen.get_batch()
                    ret = self.model(X, Y)
                    for stat in stat_names:
                        stats[stat][k] = ret[stat].item()
                for stat in stat_names:
                    out[stat][split] = stats[stat].mean()
        self.model.train()
        return out

    def estimate_loss_pretrainedformat(self):
        out = {}
        self.model.eval()
        with torch.inference_mode():
            for split, batch_gen in [('train', self.train_batch_gen), ('val', self.val_batch_gen)]:
                losses = torch.zeros(self.eval_iters)
                for k in tqdm(range(self.eval_iters)):
                    X, Y = batch_gen.get_batch(shifted=False)
                    ret = self.model(X, labels=Y)
                    losses[k] = ret["loss"].item()
                out[split] = losses.mean()
        self.model.train()
        return out

In [16]:
def train_gptneo(model, config):
    train_batch_gen = BatchGenerator(config.train_file, config.block_size, config.batch_size, config.device)
    val_batch_gen = BatchGenerator(config.val_file, config.block_size, config.batch_size, config.device)
    loss_estimator = LossEstimator(model, train_batch_gen, val_batch_gen, config.eval_iters)

    model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.num_batches)

    best_val_loss = float('inf')

    for iteration in tqdm(range(config.num_batches)):
        model.train()

        # Get batch
        x, y = train_batch_gen.get_batch()

        # Forward pass
        loss = model(x, targets=y)["loss"]

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        if config.max_grad_norm:
            nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

        optimizer.step()
        scheduler.step()

        # Logging
        if (iteration + 1) % config.log_interval == 0:
            stats = loss_estimator.estimate_loss()
            print(f"Iteration {iteration}: train loss {stats['loss']['train']:.4f}, val loss {stats['loss']['val']:.4f}")
            print(f"train L_base {stats['L_base']['train']:.4f}, val L_base {stats['L_base']['val']:.4f}")
            print(f"train L_ablated {stats['L_ablated']['train']:.4f}, val L_ablated {stats['L_ablated']['val']:.4f}")
            print(f"train attention ablation mask density {stats['attention_ablation_mask_density']['train']:.4f}, val {stats['attention_ablation_mask_density']['val']:.4f}")
            print(f"train neuron ablation mask density {stats['neuron_ablation_mask_density']['train']:.4f}, val {stats['neuron_ablation_mask_density']['val']:.4f}")
            print(f"The current learning rate: {optimizer.param_groups[0]['lr']:.4f}")

            # Save best model
            if stats['loss']['val'] < best_val_loss:
                best_val_loss = stats['loss']['val']
                torch.save(model.state_dict(), config.save_path)
                print(f"New best model saved to {config.save_path}")

    print("Training completed!")

In [None]:
config = GPTNeoConfig(loss_coeff_attention_density=0.1, loss_coeff_neuron_density=0.1)
model = GPTNeo(config)

training_config = TrainingConfig()

train_gptneo(model, training_config)

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Iteration 999: train loss 4.5526, val loss 4.5570
train L_base 4.1304, val L_base 4.1344
train L_ablated 4.1973, val L_ablated 4.2010
train attention ablation mask density 0.0086, val 0.0086
train neuron ablation mask density 0.0160, val 0.0160
The current learning rate: 0.0039
New best model saved to best_model.pt


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Iteration 1999: train loss 3.8857, val loss 3.8936
train L_base 3.5223, val L_base 3.5296
train L_ablated 3.6010, val L_ablated 3.6072
train attention ablation mask density 0.0128, val 0.0128
train neuron ablation mask density 0.0201, val 0.0201
The current learning rate: 0.0036
New best model saved to best_model.pt


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Iteration 2999: train loss 3.3846, val loss 3.3957
train L_base 3.0667, val L_base 3.0768
train L_ablated 3.1409, val L_ablated 3.1512
train attention ablation mask density 0.0135, val 0.0135
train neuron ablation mask density 0.0245, val 0.0245
The current learning rate: 0.0032
New best model saved to best_model.pt


In [45]:
config = GPTNeoConfig(loss_coeff_attention_density=0, loss_coeff_neuron_density=0)
model = GPTNeo(config)
model.load_state_dict(torch.load("best_model.pt"))
model.to(device)

GPTNeo(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 64)
    (wpe): Embedding(2048, 64)
    (h): ModuleList(
      (0-7): 8 x GPTNeoBlock(
        (ln_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (attention): ModuleDict(
            (k_proj): Linear(in_features=64, out_features=64, bias=False)
            (v_proj): Linear(in_features=64, out_features=64, bias=False)
            (q_proj): Linear(in_features=64, out_features=64, bias=False)
            (out_proj): Linear(in_features=64, out_features=64, bias=True)
          )
        )
        (ln_2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=64, out_features=256, bias=True)
          (c_proj): Linear(in_features=256, out_features=64, bias=True)
          (act): NewGELUActivation()
        )
      )
    )
    (ln_f): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=64, ou

In [46]:
batch = BatchGenerator("train.bin", 256, 32, device).get_batch()

In [47]:
ret = model(batch[0])

In [48]:
ret["attention_ablations"].std()

tensor(0.4076, device='cuda:0', grad_fn=<StdBackward0>)

In [49]:
import plotly.express as px

px.histogram(ret["attention_ablations"][:10,:10].flatten().detach().cpu())

In [50]:
px.imshow(ret["neuron_ablations"][0,10,:,:64].detach().cpu())

In [41]:
px.imshow(ret["neuron_ablations"][1,10,:,:64].detach().cpu())

In [None]:
ts1m.to(device)
LossEstimator(ts1m, BatchGenerator("train.bin", 256, 32, device), BatchGenerator("validation.bin", 256, 32, device), 100).estimate_loss_pretrainedformat()

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

{'train': tensor(2.3472), 'val': tensor(2.3606)}

In [51]:
enc = tiktoken.get_encoding("gpt2")

In [61]:
enc.decode(model.generate(enc.encode("""Sara and Ben are playing in the snow. They make a big snowman with a hat and a scarf. They are happy and laugh.
But then a big dog comes. The dog is angry and barks. He runs to the snowman and bites his hat. Sara and Ben are scared and
cry. ”Go away, dog! Leave our snowman alone!” Sara shouts. But the dog does not listen. He bites the scarf and the snowman’s
nose. He shakes his head and makes the snowman fall.
Sara and"""), 100, temperature=0.01))

'Sara and Ben are playing in the snow. They make a big snowman with a hat and a scarf. They are happy and laugh.\nBut then a big dog comes. The dog is angry and barks. He runs to the snowman and bites his hat. Sara and Ben are scared and\ncry. ”Go away, dog! Leave our snowman alone!” Sara shouts. But the dog does not listen. He bites the scarf and the snowman’s\nnose. He shakes his head and makes the snowman fall.\nSara and Ben are sad. They are sorry. They are sorry. They are not angry. They are not hurt. They are sorry. They are not hurt. They are not hurt. They are not hurt. They are friends. They are not hurt. They are not hurt. They are not hurt. They are not hurt. They are friends.<|endoftext|>Once upon a time, there was a little girl named Lily. She loved to play outside in the park. One day, she saw a big tree'

In [62]:
print(_)

Sara and Ben are playing in the snow. They make a big snowman with a hat and a scarf. They are happy and laugh.
But then a big dog comes. The dog is angry and barks. He runs to the snowman and bites his hat. Sara and Ben are scared and
cry. ”Go away, dog! Leave our snowman alone!” Sara shouts. But the dog does not listen. He bites the scarf and the snowman’s
nose. He shakes his head and makes the snowman fall.
Sara and Ben are sad. They are sorry. They are sorry. They are not angry. They are not hurt. They are sorry. They are not hurt. They are not hurt. They are not hurt. They are friends. They are not hurt. They are not hurt. They are not hurt. They are not hurt. They are friends.<|endoftext|>Once upon a time, there was a little girl named Lily. She loved to play outside in the park. One day, she saw a big tree
