# Notebook for standalone experiments

### Set Up Model

In [1]:
# Prefer the GPU (an L4 GPU on Google Colab is recommended); fall back to the CPU so the
# notebook can still start on machines without CUDA.
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# %%capture
# !pip install transformer-lens==2.11.0
from transformer_lens import HookedTransformer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Exactly one `model_name` assignment is active; the alternatives stay commented out
# so a different model can be selected by toggling a single line.

## Simple models
# model_name = "attn-only-1l"
# model_name = "gelu-1l"
# model_name = "tiny-stories-1L-21M"

## Small models
# model_name = "tiny-stories-1M"
# model_name = "tiny-stories-3M"
# model_name = "tiny-stories-8M"
# model_name = "tiny-stories-28M"
model_name = "tiny-stories-33M"
# model_name = "tiny-stories-instruct-33M"

## Large models
# model_name = "gpt2-small"
# model_name = "gpt2-medium"
# model_name = "gpt2-xl"
# model_name = "llama-7b"
# model_name = "meta-llama/Llama-3.2-1B"
# model_name = "meta-llama/Llama-3.2-1B-Instruct"
# model_name = "Qwen/Qwen2.5-0.5B"
# model_name = "Qwen/Qwen2.5-1.5B"
# model_name = "Qwen/Qwen2.5-3B"

# Load the pretrained weights onto `device` and switch to inference mode in one chain
# (`eval()` hands back the same module).
model = HookedTransformer.from_pretrained(model_name, device=device).eval()

`torch_dtype` is deprecated! Use `dtype` instead!


Loaded pretrained model tiny-stories-33M into HookedTransformer


### Set up Libraries

In [39]:
# %%capture
# !pip install fancy_einsum
# !pip install datasets==3.6.0

In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import pickle
import numpy as np
import matplotlib.pyplot as plt
import time
import random
from copy import deepcopy
from datasets import load_dataset, Dataset, DatasetDict
import math
import re
import itertools
from fancy_einsum import einsum
import json
import os

from inversion_optimisation.utils import DATA_PATH
from pathlib import Path
# DATA_PATH = Path("content")

In [41]:
class DotDict(dict):
    """A dict whose keys can also be read, written and deleted as attributes.

    Reading a missing *regular* attribute returns None instead of raising, so optional
    config fields can simply be left unset (e.g. `cfg.reg_weight is not None`).
    Dunder names (``__x__``) are never treated as keys: looking up a missing one raises
    AttributeError. Protocols such as copy, pickle and hasattr probe for optional hooks
    like ``__deepcopy__``/``__getstate__``/``__setstate__``, and must not receive None.
    """

    def __getattr__(self, name):
        # Only invoked after normal attribute lookup has failed.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return self.get(name)

    def __setattr__(self, name, value):
        # Attribute assignment stores a key; the instance __dict__ is never used.
        self[name] = value

    def __delattr__(self, name):
        # Raises KeyError if the key is absent (unchanged behaviour).
        del self[name]

In [42]:
def get_paper_summary_stats_new(results, epochs):
    """Summarise a list of inversion results (the dicts produced by onehot_search).

    Args:
        results: non-empty list of dicts with keys "found_solution" (bool),
            "done_epochs" (int), and "true_tokens" / "pred_tokens" (equal-length
            1-D token tensors).
        epochs: the cumulative success curve covers epochs 1 .. epochs-1.

    Returns:
        dict with
          percent_zero_loss: % of results with found_solution set,
          percent_exact_inversion: % whose pred_tokens equal true_tokens exactly,
          zero_losses_at_epoch: entry e-1 = number of solved results that finished at or
              before epoch e (failed runs are not counted),
          success_final_epoch: per-position % of results whose predicted token matches.
        All percentages are rounded to 4 decimals.

    Raises:
        ValueError: if results is empty.
    """
    from collections import Counter

    if len(results) == 0:
        raise ValueError("results must be non-empty")
    num_results = len(results)

    num_zero_loss = sum(1 for result in results if result["found_solution"])
    num_exact_inversion = sum(
        1 for result in results if torch.equal(result["true_tokens"], result["pred_tokens"])
    )

    # Cumulative number of successful inversions by epoch. Only solved results count:
    # failed runs also stop (at max_epochs) but never reached zero loss.
    solved_at_epoch = Counter(result["done_epochs"] for result in results if result["found_solution"])
    zero_losses_at_epoch = []
    running_total = 0
    for e in range(1, epochs):
        running_total += solved_at_epoch[e]
        zero_losses_at_epoch.append(running_total)

    # Per-position success: compare all results at once, shape (num_results, input_len)
    true_stack = torch.stack([result["true_tokens"] for result in results])
    pred_stack = torch.stack([result["pred_tokens"] for result in results])
    hits_per_position = (true_stack == pred_stack).sum(dim=0).tolist()
    success_final_epoch = [round(hits / num_results * 100, 4) for hits in hits_per_position]

    stats = {}
    stats["percent_zero_loss"] = round((num_zero_loss / num_results) * 100, 4)
    stats["percent_exact_inversion"] = round((num_exact_inversion / num_results) * 100, 4)
    stats["zero_losses_at_epoch"] = zero_losses_at_epoch
    stats["success_final_epoch"] = success_final_epoch
    return stats

In [43]:
def load_dataset_tokens(target_strategy, input_len, num_targets, include_bos, random_sentence, random_start):
    """Stream a Hugging Face text dataset and collect `num_targets` token windows.

    Uses the notebook-level `model` for tokenisation.

    Args:
        target_strategy: "tinystories", "reddit" or "wikipedia"; selects the dataset, split
            and text column (any other value raises KeyError).
        input_len: number of tokens in each returned window.
        num_targets: number of windows to collect.
        include_bos: if False, the first token returned by `model.to_tokens` (presumably the
            prepended BOS) is dropped before windowing.
        random_sentence: use one randomly chosen sentence (split after . ! ?) of the first
            1000 characters instead of the whole text.
        random_start: start the window at a random offset that still fits `input_len` tokens
            instead of at the beginning.

    Returns:
        CPU tensor of shape (num_targets, input_len), or None (after printing a message)
        if the dataset ran out before enough windows were found.
    """
    name, split, column = {
        "tinystories": ["roneneldan/TinyStories", "validation", "text"],
        "reddit": ["sentence-transformers/reddit", "train", "body"],
        "wikipedia": ["lucadiliello/english_wikipedia", "train", "maintext"]
    }[target_strategy]
    ds = load_dataset(name, split=split, streaming=True)

    # Want to use new data for each new input length.
    # NOTE(review): this skips one item fewer than (input_len-1)*num_targets (the original
    # counter check was off by one); kept so previously generated targets are reproduced.
    dataset_offset = (input_len - 1) * num_targets
    num_to_skip = max(dataset_offset - 1, 0)

    loaded_true_tokens = []
    for data in itertools.islice(ds, num_to_skip, None):
        # Choose which sentence to take
        string = data[column][:1000]
        if random_sentence:
            sentence_pattern = r'(?<=[.!?])\s+'
            string = random.choice(re.split(sentence_pattern, string))

        # Tokenise and choose which snippet of sentence to take
        tokens = model.to_tokens(string)[0]
        offset = 0 if include_bos else 1
        # Largest extra shift that still leaves input_len tokens after the base offset;
        # subtracting `offset` keeps the window from running past the end of the sequence.
        max_shift = len(tokens) - offset - input_len
        if random_start and max_shift >= 0:
            offset += random.randint(0, max_shift)
        tokens = tokens[offset:offset + input_len]

        if len(tokens) == input_len:  # In case sentence is too short
            loaded_true_tokens.append(tokens)
        if len(loaded_true_tokens) >= num_targets:
            break

    if len(loaded_true_tokens) < num_targets:
        print("DIDNT LOAD NUM TARGETS DATASET")
        return None

    return torch.stack(loaded_true_tokens).to("cpu")

In [44]:
class CustomAdam(torch.optim.Optimizer):
    """Adam without bias correction of the moment estimates.

    For every parameter with a gradient, one call to `step` applies
        m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
        v_t     = beta2 * v_{t-1} + (1 - beta2) * g_t^2
        theta_t = theta_{t-1} - lr * m_t / (sqrt(v_t) + eps)
    i.e. standard Adam with the m_t / (1 - beta1^t) and v_t / (1 - beta2^t) corrections
    omitted. Per-parameter state: 'step', 'exp_avg' (m) and 'exp_avg_sq' (v); if a
    parameter's state is deleted it is re-created from zeros on the next step.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        super().__init__(params, dict(lr=lr, betas=betas, eps=eps))

    def step(self, closure=None):
        """Apply one update to every parameter that has a gradient.

        Returns the closure's result when a closure is given, otherwise None.
        Raises RuntimeError on sparse gradients.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr, eps = group['lr'], group['eps']
            for param in group['params']:
                if param.grad is None:
                    continue
                grad = param.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients")

                param_state = self.state[param]
                if not param_state:
                    # Lazily (re)initialise the moments
                    param_state['step'] = 0
                    param_state['exp_avg'] = torch.zeros_like(param.data)
                    param_state['exp_avg_sq'] = torch.zeros_like(param.data)
                param_state['step'] += 1

                first_moment = param_state['exp_avg']
                second_moment = param_state['exp_avg_sq']
                first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)
                second_moment.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # No bias correction: the raw moments drive the update directly
                denominator = second_moment.sqrt().add(eps)
                param.data.addcdiv_(first_moment, denominator, value=-lr)

        return loss

### Set Up Datasets (Optional)

In [None]:
# Generate the targets and (unused) initialisations for all LOGIT-inversion experiments.
# Token ids are drawn uniformly from the tokenizer vocabulary, one row per target. The RNGs
# are re-seeded for every input length so each length's files are reproducible on their own.
# Tensors are pickled on the CPU: a pickled CUDA tensor cannot be unpickled on a machine
# without CUDA, even though the loaders immediately call .to("cpu") on the result.
# NOTE(review): this cell writes to /content while onehot_search reads these files from
# DATA_PATH; make sure the two locations agree.
num_targets = 1000


def _random_token_rows(num_rows, row_len, vocab_size):
    """Return a (num_rows, row_len) CPU tensor of uniform token ids.

    Rows are drawn one at a time to keep the exact RNG stream of the per-target loop.
    """
    rows = [torch.randint(0, vocab_size, (1, row_len)) for _ in tqdm(range(num_rows))]
    return torch.cat(rows, dim=0)


for input_len in range(1, 11):
    torch.manual_seed(1)
    np.random.seed(1)
    random.seed(1)

    # Hoisted out of the per-row loop; `tokenizer.vocab` may build a fresh dict on each access
    vocab_size = len(model.tokenizer.vocab)

    # Targets are drawn before initialisations to preserve the RNG order
    target_tokens = _random_token_rows(num_targets, input_len, vocab_size)
    with open(f"/content/true_tokens_{num_targets}_{input_len}.pkl", 'wb') as file:
        pickle.dump(target_tokens, file)

    init_tokens = _random_token_rows(num_targets, input_len, vocab_size)
    with open(f"/content/initial_tokens_{num_targets}_{input_len}.pkl", 'wb') as file:
        pickle.dump(init_tokens, file)

100%|██████████| 100/100 [00:07<00:00, 14.27it/s]


In [None]:
# Generate the targets for all TEXT-inversion experiments: greedy (do_sample=False)
# continuations of output_len tokens from each random LOGIT-inversion target.
# Results are pickled on the CPU so they can be unpickled on machines without CUDA.
num_targets = 1000
output_len = 25
batch_size = 1000
for input_len in range(1, 11):
    torch.manual_seed(1)
    np.random.seed(1)
    random.seed(1)

    # NOTE(review): pickle.load can execute arbitrary code; only load files generated above
    with open(f"/content/true_tokens_{num_targets}_{input_len}.pkl", 'rb') as file:
        loaded_true_tokens = pickle.load(file).to("cpu")

    output_batches = []
    for batch_start in range(0, num_targets, batch_size):
        input_tokens = loaded_true_tokens[batch_start:batch_start + batch_size].to(device)
        generated = model.generate(
            input_tokens,
            max_new_tokens=output_len,
            do_sample=False,
            stop_at_eos=False,
            verbose=False,
            return_type="tokens",
        )
        # Drop the prompt and keep only the newly generated tokens
        output_batches.append(generated[:, input_len:].to("cpu"))
    # Concatenate once instead of growing the tensor batch by batch
    all_output_tokens = torch.cat(output_batches, dim=0)

    with open(f"/content/true_tokens_{num_targets}_{input_len}_{output_len}_greedy.pkl", 'wb') as file:
        pickle.dump(all_output_tokens, file)

In [None]:
# Filter the dataset used for evaluating privacy PII application
torch.manual_seed(1)
np.random.seed(1)
random.seed(1)

ds = load_dataset("ai4privacy/pii-masking-400k", split="train", streaming=True)
formatted_ds = {}
for data in tqdm(ds):
    # Filter out non english strings
    if data["language"] != "en":
        continue
    tokens = model.tokenizer(data["source_text"]).input_ids
    # Only keep 500 samples for each length between 15 and 24
    if len(tokens) < 15 or len(tokens) > 24:
        continue
    if len(tokens) not in formatted_ds:
        formatted_ds[len(tokens)] = []
    if len(formatted_ds[len(tokens)]) < 500:
        # Tokenise the strings and make the labels match the tokens
        tokens_decoded = []
        tokens_labels = []
        current_label = 0
        current_len = 1
        for token_id in tokens:
            decoded = model.tokenizer.decode([token_id])
            tokens_decoded.append(decoded)

            label = None
            # Check if we have passed the last label text span and should move onto the next
            if current_label < len(data["privacy_mask"]) and current_len > data["privacy_mask"][current_label]["end"]:
                current_label += 1
            # Check if we have are in the middle of the current label text span
            if current_label < len(data["privacy_mask"]) and current_len >= data["privacy_mask"][current_label]["start"]:
                    label = data["privacy_mask"][current_label]["label"]
            tokens_labels.append(label)

            current_len += len(decoded)

        new_data = {
            "source_text": data["source_text"],
            "source_text_labels": data["privacy_mask"],
            "tokens": tokens,
            "tokens_decoded": tokens_decoded,
            "tokens_labels": tokens_labels
        }
        formatted_ds[len(tokens)].append(new_data)

# # Upload to HuggingFace if want to
# dataset_dict = DatasetDict()
# for i in range(15, 25):
#     dataset_dict[f"length_{i}"] = Dataset.from_list(formatted_ds[i])
# dataset_dict.push_to_hub("AdrSkapars/pii-inversion-test-5k")

README.md:   0%|          | 0.00/15.6k [00:00<?, ?B/s]

325517it [03:50, 1412.74it/s]


In [None]:
# # Code for getting dataset onto huggingface
# from huggingface_hub import HfApi
# import os
# import yaml

# # Path to your dataset files
# dataset_dir = "pii-inversion-5k"
# username = "AdrSkapars"
# repo_name = "pii-inversion-5k"
# repo_id = f"{username}/{repo_name}"

# # Initialize Hugging Face API
# api = HfApi()

# # Create README.md with YAML configuration
# yaml_config = {
#     "configs": [
#         {
#             "config_name": "default",
#             "data_files": []
#         }
#     ]
# }

# # Add each length file as a separate split
# for length in range(15, 25):  # Range 15-24
#     file_name = f"length_{length}.jsonl"
#     file_path = os.path.join(dataset_dir, file_name)

#     if os.path.exists(file_path):
#         yaml_config["configs"][0]["data_files"].append({
#             "split": f"length_{length}",
#             "path": file_name
#         })

# # Create the README.md with YAML front matter
# readme_content = "---\n"
# readme_content += yaml.dump(yaml_config)
# readme_content += "---\n\n"
# readme_content += "# PII Inversion Dataset\n\n"
# with open(os.path.join(dataset_dir, "README.md"), "w") as f:
#     f.write(readme_content)

# # Create or update the repository
# api.create_repo(
#     repo_id=repo_id,
#     repo_type="dataset",
#     exist_ok=True
# )

# # Upload all files
# api.upload_folder(
#     folder_path=dataset_dir,
#     repo_id=repo_id,
#     repo_type="dataset"
# )

# from datasets import Dataset, DatasetDict
# from huggingface_hub import HfApi, HfFolder
# import os
# import json

# # Folder with your JSONL files
# data_dir = "pii-inversion-5k"

# # Prepare a dataset dictionary with custom splits
# dataset_dict = DatasetDict()

# for i in range(15, 25):
#     file_path = os.path.join(data_dir, f"length_{i}.jsonl")
#     with open(file_path, 'r') as f:
#         data = [json.loads(line) for line in f]
#     dataset_dict[f"length_{i}"] = Dataset.from_list(data)

# # Push to Hugging Face hub
# dataset_dict.push_to_hub("AdrSkapars/pii-inversion-5k")

### SODA Logits Inversion

#### Original

In [29]:
def onehot_search(cfg):
    """Invert final-position logits by optimising a relaxed one-hot input per target.

    For each target token sequence, the model's final-position logits are computed. A
    continuous input (scores over the vocabulary for every position) is then optimised so
    that softmax(input / cfg.temp) @ W_E reproduces those logits under a Huber loss.
    Targets are streamed into a working batch of at most cfg.max_batch_size. An item
    leaves the batch when its argmax tokens reproduce the target logits within tolerance,
    or after cfg.max_epochs epochs. Progress is checkpointed every 6 hours of optimisation
    time and resumed from the checkpoint if one exists.

    Uses notebook-level `model`, `device`, `DATA_PATH`, `DotDict`, `CustomAdam` and
    `load_dataset_tokens`.

    Args:
        cfg: attribute-style config. Fields read: target_strategy, num_targets, input_len,
            init_strategy, save_folder, max_epochs, max_batch_size, bias_correction,
            learn_rate, betas, temp, model_name, reg_weight, decay_rate, reset_epoch,
            reinit_epoch.

    Returns:
        (results, elapsed_seconds): results is a list of dicts with keys true_tokens,
        pred_tokens, found_solution and done_epochs; elapsed_seconds is rounded to 3 decimals.

    Raises:
        ValueError: unknown cfg.init_strategy, or the text dataset could not supply
            cfg.num_targets targets.
    """
    # Get the targets used for all experiments based on dataset
    if cfg.target_strategy == "random":
        # NOTE(review): pickle.load can execute arbitrary code; only load files generated by this notebook
        with open(DATA_PATH / f"true_tokens_{cfg.num_targets}_{cfg.input_len}.pkl", 'rb') as file:
            loaded_true_tokens = pickle.load(file).to("cpu")
    elif cfg.target_strategy == "privacy":
        # Privacy dataset only allows num_targets == 500 currently
        privacy_ds = load_dataset("AdrSkapars/pii-inversion-test-5k", split=f"length_{cfg.input_len}")
        loaded_true_tokens = torch.cat([torch.tensor(item["tokens"]).to(torch.int64).unsqueeze(0) for item in privacy_ds], dim=0).to("cpu")
    else:
        loaded_true_tokens = load_dataset_tokens(cfg.target_strategy, cfg.input_len, cfg.num_targets, include_bos=False, random_sentence=True, random_start=False)
        if loaded_true_tokens is None:
            # Fail here instead of with an opaque TypeError when the batch is first filled
            raise ValueError(f"Could not load {cfg.num_targets} targets of length {cfg.input_len} for '{cfg.target_strategy}'")

    # Get the initialisation (pre-softmax scores over the vocabulary per position) based on strategy
    if cfg.init_strategy == "loaded":
        with open(DATA_PATH / f"initial_tokens_{cfg.num_targets}_{cfg.input_len}.pkl", 'rb') as file:
            initialisation_tokens = pickle.load(file).to(device)
        initialisation_embeds = F.one_hot(initialisation_tokens, num_classes=model.embed.W_E.size(0)).to(model.embed.W_E.dtype).to("cpu")
    elif cfg.init_strategy == "normal":
        normal_embed = torch.empty((cfg.num_targets, cfg.input_len, model.embed.W_E.size(0)))
        _ = nn.init.normal_(normal_embed, std=0.05)
        initialisation_embeds = normal_embed.to("cpu")
    elif cfg.init_strategy == "zeros":
        initialisation_embeds = torch.zeros((cfg.num_targets, cfg.input_len, model.embed.W_E.size(0))).to("cpu")
    else:
        raise ValueError(f"Unknown init_strategy: {cfg.init_strategy!r}")

    # Initialise state variables, resuming from a checkpoint if one exists
    state_path = DATA_PATH / f'{cfg.save_folder}/checkpoint_{cfg.input_len}_{cfg.num_targets}_{cfg.max_epochs}.pt'
    if os.path.exists(state_path):
        print("LOADING STATE")
        state = torch.load(state_path, weights_only=False)
    else:
        print("INITIALISING STATE")
        state = DotDict({
            "results" : [],
            "batch_results" : [],
            "true_logits" : torch.Tensor([]).to(device),
            "optimizers" : [],
            "loaded_i" : 0,
            "epoch" : 0,
            "num_remain_items" : cfg.num_targets,
            "num_success_items" : 0,
            "elapsed_time" : 0,
            "checkpoint_elapsed_time" : 0,
        })

    while state.num_remain_items != 0 or len(state.batch_results) != 0:
        start_time = time.time()

        # Checkpoint current progress every 6 hours of optimisation time
        if state.elapsed_time - state.checkpoint_elapsed_time > (3600 * 6):
            print("\nSAVING STATE")
            state.checkpoint_elapsed_time = state.elapsed_time
            # Create the save folder if needed so a long run cannot die at its first checkpoint
            os.makedirs(os.path.dirname(state_path), exist_ok=True)
            torch.save(state, state_path)

        # Print progress: (solved/total)(started/total)epoch
        state.epoch += 1
        if state.epoch % 100 == 0:
            print(f"({state.num_success_items}/{cfg.num_targets})({cfg.num_targets-state.num_remain_items}/{cfg.num_targets}){state.epoch}", end=", ")

        with torch.no_grad():
            # Add new items to batch if have space and have more items to do
            if (cfg.max_batch_size - len(state.batch_results)) > 0 and state.num_remain_items != 0:
                num_new_items = min((cfg.max_batch_size - len(state.batch_results)), state.num_remain_items)
                state.num_remain_items -= num_new_items

                # Initialise new target and add to end (batched)
                true_tokens = loaded_true_tokens[state.loaded_i:state.loaded_i+num_new_items].to(device)
                new_true_logits = model(true_tokens).detach()[:,-1,:]
                state.true_logits = torch.cat((state.true_logits, new_true_logits))

                for i in range(num_new_items):
                    # Initialise new results tracking and add to end
                    state.batch_results.append({
                        "true_tokens": true_tokens[i].to("cpu"),
                        "pred_tokens": None,
                        "found_solution": False,
                        "done_epochs": 0,
                    })

                    # Initialise new prediction and add to end, one optimiser per sequence
                    # (each position is a separate leaf tensor in that optimiser's param group)
                    new_pred_embed = initialisation_embeds[state.loaded_i+i:state.loaded_i+i+1].to(device)
                    for j in range(cfg.input_len):
                        new_pred_embed_pos = new_pred_embed[:,j:j+1]
                        new_pred_embed_pos.requires_grad = True
                        if j == 0:
                            if cfg.bias_correction:
                                state.optimizers.append(torch.optim.Adam([new_pred_embed_pos], lr=cfg.learn_rate, betas=cfg.betas))
                            else:
                                state.optimizers.append(CustomAdam([new_pred_embed_pos], lr=cfg.learn_rate, betas=cfg.betas))
                        else:
                            state.optimizers[-1].param_groups[0]['params'].append(new_pred_embed_pos)

                state.loaded_i += num_new_items

        # Do one epoch of optimisation on batch
        for optimizer in state.optimizers:
            optimizer.zero_grad()
        # Reassemble the per-position parameters into [batch, input_len, vocab_size]
        pred_embed_pre = torch.cat([torch.cat([param for param in optimizer.param_groups[0]['params']], dim=1)
                                    for optimizer in state.optimizers], dim=0).to(device)
        pred_one_hot = torch.softmax(pred_embed_pre / cfg.temp, dim=-1)
        pred_embed = (pred_one_hot @ model.embed.W_E)
        # For model names containing "gpt"/"tiny", positional embeddings are added explicitly
        # (pos_embed is given a dummy tensor, presumably only its shape matters — confirm)
        if "gpt" not in cfg.model_name and "tiny" not in cfg.model_name:
            pred_embed_full = pred_embed
        else:
            pred_embed_full = pred_embed + model.pos_embed(pred_embed[:,:,0].detach())
        pred_logits = model(pred_embed_full, start_at_layer=0)
        loss = torch.nn.HuberLoss()(state.true_logits.detach(), pred_logits[:,-1,:])

        # Optional regulariser. The fluency penalty needs at least one next-token pair; with
        # input_len == 1 it would average an empty tensor and turn the loss into NaN.
        if cfg.reg_weight is not None and cfg.input_len > 1:
            # # Size of input penalty
            # reg_penalty = (pred_one_hot).pow(2).sum(dim=-1).sqrt() * -1
            # reg_penalty = reg_penalty.mean(dim=-1).mean(dim=-1)

            # Fluency penalty: negative log-likelihood of each argmax token given the preceding
            # positions. log_softmax avoids log(0) = -inf when softmax underflows.
            next_token_ids = pred_one_hot[:,1:,:].argmax(dim=-1).unsqueeze(-1)
            reg_penalty = pred_logits[:,:-1,:].log_softmax(dim=-1).gather(2, next_token_ids).squeeze(-1) * -1
            reg_penalty = reg_penalty.mean(dim=-1).mean(dim=-1)

            loss = loss + (cfg.reg_weight * reg_penalty)

        loss.backward()
        for optimizer in state.optimizers:
            optimizer.step()

        with torch.no_grad():
            # Add decay to embeddings
            for i in range(len(state.optimizers)):
                for j in range(len(state.optimizers[i].param_groups[0]['params'])):
                    state.optimizers[i].param_groups[0]['params'][j].mul_(cfg.decay_rate)

            # Intervene if sequence not found yet
            for i in range(len(state.batch_results)):
                targets_epoch = (state.batch_results[i]["done_epochs"]+1)
                # Reset optimiser state
                if targets_epoch % cfg.reset_epoch == 0:
                    for j in range(cfg.input_len):
                        del state.optimizers[i].state[state.optimizers[i].param_groups[0]['params'][j]]

                # Reinitialise sequence
                if targets_epoch % cfg.reinit_epoch == 0:
                    for j in range(cfg.input_len):
                        state.optimizers[i].param_groups[0]['params'][j].normal_(std=0.1)

            # Assume largest one-hot token is the true one
            pred_tokens = torch.argmax(pred_one_hot, dim=-1)

            # Update history of tokens over epochs
            disc_pred_logits = model(pred_tokens)[:,-1,:]
            # Iterate backwards so removing item i does not shift the items still to visit
            for i in range(len(state.batch_results)-1,-1,-1):
                state.batch_results[i]["done_epochs"] += 1

                # Remove item if have found a solution or reached final epoch
                threshold = 1e-4 if "tiny" in cfg.model_name else 1e-3
                have_inverted = torch.allclose(state.true_logits[i], disc_pred_logits[i], atol=threshold, rtol=threshold)
                if have_inverted:
                    state.batch_results[i]["found_solution"] = True
                    state.num_success_items += 1
                if have_inverted or (cfg.max_epochs is not None and state.batch_results[i]["done_epochs"] >= cfg.max_epochs):
                    state.batch_results[i]["pred_tokens"] = pred_tokens[i].to("cpu")
                    del state.optimizers[i]
                    state.true_logits = torch.cat((state.true_logits[:i], state.true_logits[i+1:]))
                    state.results.append(state.batch_results.pop(i))

            state.elapsed_time += time.time() - start_time

    return state.results, round(state.elapsed_time, 3)

#### Same as PIA

In [None]:
def onehot_search(cfg):
    """Invert final-position logits via relaxed one-hot optimisation plus a final greedy pass.

    For each target token sequence, the model's final-position logits are computed. A
    continuous input (scores over the vocabulary for every position) is then optimised so
    that softmax(input / cfg.temp) @ W_E reproduces those logits under a Huber loss.
    Targets are streamed into a working batch of at most cfg.max_batch_size. An item
    leaves the batch when its argmax tokens reproduce the target logits within tolerance,
    or after cfg.max_epochs epochs.

    If an item reaches cfg.max_epochs without an exact inversion, a final greedy pass
    replaces each position in turn with the best (lowest Huber loss) of the cfg.top_k
    vocabulary tokens nearest (L2) to that position's relaxed embedding. The match is then
    re-checked. Progress is checkpointed every 6 hours of optimisation time and resumed
    from the checkpoint if one exists.

    Uses notebook-level `model`, `device`, `DATA_PATH`, `DotDict`, `CustomAdam` and
    `load_dataset_tokens`.

    Args:
        cfg: attribute-style config. Fields read: target_strategy, num_targets, input_len,
            init_strategy, save_folder, max_epochs, max_batch_size, bias_correction,
            learn_rate, betas, temp, model_name, reg_weight, decay_rate, reset_epoch,
            reinit_epoch, top_k.

    Returns:
        (results, elapsed_seconds): results is a list of dicts with keys true_tokens,
        pred_tokens, found_solution and done_epochs; elapsed_seconds is rounded to 3 decimals.

    Raises:
        ValueError: unknown cfg.init_strategy, or the text dataset could not supply
            cfg.num_targets targets.
    """
    # Get the targets used for all experiments based on dataset
    if cfg.target_strategy == "random":
        # NOTE(review): pickle.load can execute arbitrary code; only load files generated by this notebook
        with open(DATA_PATH / f"true_tokens_{cfg.num_targets}_{cfg.input_len}.pkl", 'rb') as file:
            loaded_true_tokens = pickle.load(file).to("cpu")
    elif cfg.target_strategy == "privacy":
        # Privacy dataset only allows num_targets == 500 currently
        privacy_ds = load_dataset("AdrSkapars/pii-inversion-test-5k", split=f"length_{cfg.input_len}")
        loaded_true_tokens = torch.cat([torch.tensor(item["tokens"]).to(torch.int64).unsqueeze(0) for item in privacy_ds], dim=0).to("cpu")
    else:
        loaded_true_tokens = load_dataset_tokens(cfg.target_strategy, cfg.input_len, cfg.num_targets, include_bos=False, random_sentence=True, random_start=False)
        if loaded_true_tokens is None:
            # Fail here instead of with an opaque TypeError when the batch is first filled
            raise ValueError(f"Could not load {cfg.num_targets} targets of length {cfg.input_len} for '{cfg.target_strategy}'")

    # Get the initialisation (pre-softmax scores over the vocabulary per position) based on strategy
    if cfg.init_strategy == "loaded":
        with open(DATA_PATH / f"initial_tokens_{cfg.num_targets}_{cfg.input_len}.pkl", 'rb') as file:
            initialisation_tokens = pickle.load(file).to(device)
        initialisation_embeds = F.one_hot(initialisation_tokens, num_classes=model.embed.W_E.size(0)).to(model.embed.W_E.dtype).to("cpu")
    elif cfg.init_strategy == "normal":
        normal_embed = torch.empty((cfg.num_targets, cfg.input_len, model.embed.W_E.size(0)))
        _ = nn.init.normal_(normal_embed, std=0.05)
        initialisation_embeds = normal_embed.to("cpu")
    elif cfg.init_strategy == "zeros":
        initialisation_embeds = torch.zeros((cfg.num_targets, cfg.input_len, model.embed.W_E.size(0))).to("cpu")
    else:
        raise ValueError(f"Unknown init_strategy: {cfg.init_strategy!r}")

    # Initialise state variables, resuming from a checkpoint if one exists
    state_path = DATA_PATH / f'{cfg.save_folder}/checkpoint_{cfg.input_len}_{cfg.num_targets}_{cfg.max_epochs}.pt'
    if os.path.exists(state_path):
        print("LOADING STATE")
        state = torch.load(state_path, weights_only=False)
    else:
        print("INITIALISING STATE")
        state = DotDict({
            "results" : [],
            "batch_results" : [],
            "true_logits" : torch.Tensor([]).to(device),
            "optimizers" : [],
            "loaded_i" : 0,
            "epoch" : 0,
            "num_remain_items" : cfg.num_targets,
            "num_success_items" : 0,
            "elapsed_time" : 0,
            "checkpoint_elapsed_time" : 0,
        })

    while state.num_remain_items != 0 or len(state.batch_results) != 0:
        start_time = time.time()

        # Checkpoint current progress every 6 hours of optimisation time
        if state.elapsed_time - state.checkpoint_elapsed_time > (3600 * 6):
            print("\nSAVING STATE")
            state.checkpoint_elapsed_time = state.elapsed_time
            # Create the save folder if needed so a long run cannot die at its first checkpoint
            os.makedirs(os.path.dirname(state_path), exist_ok=True)
            torch.save(state, state_path)

        # Print progress: (solved/total)(started/total)epoch
        state.epoch += 1
        if state.epoch % 100 == 0:
            print(f"({state.num_success_items}/{cfg.num_targets})({cfg.num_targets-state.num_remain_items}/{cfg.num_targets}){state.epoch}", end=", ")

        with torch.no_grad():
            # Add new items to batch if have space and have more items to do
            if (cfg.max_batch_size - len(state.batch_results)) > 0 and state.num_remain_items != 0:
                num_new_items = min((cfg.max_batch_size - len(state.batch_results)), state.num_remain_items)
                state.num_remain_items -= num_new_items

                # Initialise new target and add to end (batched)
                true_tokens = loaded_true_tokens[state.loaded_i:state.loaded_i+num_new_items].to(device)
                new_true_logits = model(true_tokens).detach()[:,-1,:]
                state.true_logits = torch.cat((state.true_logits, new_true_logits))

                for i in range(num_new_items):
                    # Initialise new results tracking and add to end
                    state.batch_results.append({
                        "true_tokens": true_tokens[i].to("cpu"),
                        "pred_tokens": None,
                        "found_solution": False,
                        "done_epochs": 0,
                    })

                    # Initialise new prediction and add to end, one optimiser per sequence
                    # (each position is a separate leaf tensor in that optimiser's param group)
                    new_pred_embed = initialisation_embeds[state.loaded_i+i:state.loaded_i+i+1].to(device)
                    for j in range(cfg.input_len):
                        new_pred_embed_pos = new_pred_embed[:,j:j+1]
                        new_pred_embed_pos.requires_grad = True
                        if j == 0:
                            if cfg.bias_correction:
                                state.optimizers.append(torch.optim.Adam([new_pred_embed_pos], lr=cfg.learn_rate, betas=cfg.betas))
                            else:
                                state.optimizers.append(CustomAdam([new_pred_embed_pos], lr=cfg.learn_rate, betas=cfg.betas))
                        else:
                            state.optimizers[-1].param_groups[0]['params'].append(new_pred_embed_pos)

                state.loaded_i += num_new_items

        # Do one epoch of optimisation on batch
        for optimizer in state.optimizers:
            optimizer.zero_grad()
        # Reassemble the per-position parameters into [batch, input_len, vocab_size]
        pred_embed_pre = torch.cat([torch.cat([param for param in optimizer.param_groups[0]['params']], dim=1)
                                    for optimizer in state.optimizers], dim=0).to(device)
        pred_one_hot = torch.softmax(pred_embed_pre / cfg.temp, dim=-1)
        pred_embed = (pred_one_hot @ model.embed.W_E)
        # For model names containing "gpt"/"tiny", positional embeddings are added explicitly
        # (pos_embed is given a dummy tensor, presumably only its shape matters — confirm)
        if "gpt" not in cfg.model_name and "tiny" not in cfg.model_name:
            pred_embed_full = pred_embed
        else:
            pred_embed_full = pred_embed + model.pos_embed(pred_embed[:,:,0].detach())
        pred_logits = model(pred_embed_full, start_at_layer=0)
        loss = torch.nn.HuberLoss()(state.true_logits.detach(), pred_logits[:,-1,:])

        # Optional regulariser. The fluency penalty needs at least one next-token pair; with
        # input_len == 1 it would average an empty tensor and turn the loss into NaN.
        if cfg.reg_weight is not None and cfg.input_len > 1:
            # # Size of input penalty
            # reg_penalty = (pred_one_hot).pow(2).sum(dim=-1).sqrt() * -1
            # reg_penalty = reg_penalty.mean(dim=-1).mean(dim=-1)

            # Fluency penalty: negative log-likelihood of each argmax token given the preceding
            # positions. log_softmax avoids log(0) = -inf when softmax underflows.
            next_token_ids = pred_one_hot[:,1:,:].argmax(dim=-1).unsqueeze(-1)
            reg_penalty = pred_logits[:,:-1,:].log_softmax(dim=-1).gather(2, next_token_ids).squeeze(-1) * -1
            reg_penalty = reg_penalty.mean(dim=-1).mean(dim=-1)

            loss = loss + (cfg.reg_weight * reg_penalty)

        loss.backward()
        for optimizer in state.optimizers:
            optimizer.step()

        with torch.no_grad():
            # Add decay to embeddings
            for i in range(len(state.optimizers)):
                for j in range(len(state.optimizers[i].param_groups[0]['params'])):
                    state.optimizers[i].param_groups[0]['params'][j].mul_(cfg.decay_rate)

            # Intervene if sequence not found yet
            for i in range(len(state.batch_results)):
                targets_epoch = (state.batch_results[i]["done_epochs"]+1)
                # Reset optimiser state
                if targets_epoch % cfg.reset_epoch == 0:
                    for j in range(cfg.input_len):
                        del state.optimizers[i].state[state.optimizers[i].param_groups[0]['params'][j]]

                # Reinitialise sequence
                if targets_epoch % cfg.reinit_epoch == 0:
                    for j in range(cfg.input_len):
                        state.optimizers[i].param_groups[0]['params'][j].normal_(std=0.1)

            # Assume largest one-hot token is the true one
            pred_tokens = torch.argmax(pred_one_hot, dim=-1)

            # Update history of tokens over epochs
            disc_pred_logits = model(pred_tokens)[:,-1,:]
            # Iterate backwards so removing item i does not shift the items still to visit
            for i in range(len(state.batch_results)-1,-1,-1):
                state.batch_results[i]["done_epochs"] += 1

                # Check if have found a solution
                threshold = 1e-4 if "tiny" in cfg.model_name else 1e-3
                have_inverted = torch.allclose(state.true_logits[i], disc_pred_logits[i], atol=threshold, rtol=threshold)
                if have_inverted:
                    state.batch_results[i]["found_solution"] = True
                    state.num_success_items += 1

                if (not have_inverted) and (cfg.max_epochs is not None and state.batch_results[i]["done_epochs"] >= cfg.max_epochs):
                    # Have reached the end of search without solving, so do a final adaptive
                    # discretisation search, one position at a time.
                    # NOTE(review): the current token at a position is replaced by the best of the
                    # top_k candidates even if the current sequence scored lower, so this pass can
                    # also raise the loss relative to the argmax tokens.
                    seq_len = pred_embed.shape[1]

                    # Find top-k nearest neighbors based on embedding distance
                    distances = torch.cdist(pred_embed[i], model.embed.W_E, p=2)  # Shape: [seq_len, vocab_size]
                    top_k_indices = torch.topk(distances, k=cfg.top_k, dim=-1, largest=False).indices  # Shape: [seq_len, k]

                    # # Find top-k nearest neighbors based on one hot distance
                    # top_k_indices = torch.topk(pred_one_hot[i], k=cfg.top_k, dim=-1).indices

                    # Start from the argmax tokens for this example
                    pred_tokens_i = pred_tokens[i].clone()  # Shape: [seq_len]
                    true_logits_i = state.true_logits[i].unsqueeze(0)  # Shape: [1, vocab_size]

                    best_logits = None
                    for pos in range(seq_len):
                        best_loss = float('inf')
                        best_token = pred_tokens_i[pos].item()
                        best_logits = None

                        # Process k substitutions in batches of at most cfg.max_batch_size
                        for k_start in range(0, cfg.top_k, cfg.max_batch_size):
                            k_end = min(k_start + cfg.max_batch_size, cfg.top_k)
                            k_batch_size = k_end - k_start

                            # Candidate sequences for this batch of k values: [k_batch_size, seq_len]
                            candidate_tokens_batch = pred_tokens_i.unsqueeze(0).expand(k_batch_size, -1).clone()
                            candidate_tokens_batch[:, pos] = top_k_indices[pos, k_start:k_end]

                            # Per-candidate Huber loss on the final-position logits
                            candidate_logits = model(candidate_tokens_batch)  # Shape: [k_batch_size, seq_len, vocab_size]
                            candidate_losses = torch.nn.HuberLoss(reduction='none')(
                                true_logits_i.expand(k_batch_size, -1),
                                candidate_logits[:, -1, :]
                            ).mean(dim=-1)  # Shape: [k_batch_size]

                            # Find best in this batch
                            min_loss_idx = torch.argmin(candidate_losses)
                            min_loss = candidate_losses[min_loss_idx].item()
                            min_loss_logits = candidate_logits[min_loss_idx, -1, :]

                            if min_loss < best_loss:
                                best_loss = min_loss
                                best_token = candidate_tokens_batch[min_loss_idx, pos].item()
                                best_logits = min_loss_logits.clone()

                        # Update pred_tokens with best choice at this position
                        pred_tokens_i[pos] = best_token
                    pred_tokens[i] = pred_tokens_i

                    # Repeat check with the logits of the refined sequence; best_logits is None only
                    # if every candidate loss at the last position was NaN
                    have_inverted = best_logits is not None and torch.allclose(state.true_logits[i], best_logits, atol=threshold, rtol=threshold)
                    if have_inverted:
                        state.batch_results[i]["found_solution"] = True
                        state.num_success_items += 1

                # Remove item if have found a solution or reached final epoch
                if have_inverted or (cfg.max_epochs is not None and state.batch_results[i]["done_epochs"] >= cfg.max_epochs):
                    state.batch_results[i]["pred_tokens"] = pred_tokens[i].to("cpu")
                    del state.optimizers[i]
                    state.true_logits = torch.cat((state.true_logits[:i], state.true_logits[i+1:]))
                    state.results.append(state.batch_results.pop(i))

            state.elapsed_time += time.time() - start_time

    return state.results, round(state.elapsed_time, 3)

In [65]:
def outside_in_left(seq_len):
    """Yield the positions 0..seq_len-1 from the outer edges inward, left edge first.

    Example: seq_len=5 yields 0, 4, 1, 3, 2. The middle position of an odd-length
    sequence is yielded once. Nothing is yielded when seq_len <= 0.
    """
    for offset in range((seq_len + 1) // 2):
        mirror = seq_len - 1 - offset
        yield offset
        if mirror != offset:
            yield mirror

In [66]:
def outside_in_right(seq_len):
    """Yield the positions 0..seq_len-1 from the outer edges inward, right edge first.

    Example: seq_len=5 yields 4, 0, 3, 1, 2. The middle position of an odd-length
    sequence is yielded once. Nothing is yielded when seq_len <= 0.
    """
    for offset in range((seq_len + 1) // 2):
        mirror = seq_len - 1 - offset
        yield mirror
        if mirror != offset:
            yield offset

#### One-hot search with a final exhaustive search over all top-k token combinations

In [78]:
def _exhaustive_topk_search(model, one_hot_probs, true_logits_row, top_k, batch_size):
    """Score every combination of per-position top-k tokens and return the best one.

    Candidates are visited in the same order as
    ``itertools.product(*[top_k_indices[p].tolist() for p in range(seq_len)])``
    (last position varies fastest). Each batch is decoded on-device from flat indices
    with mixed-radix arithmetic, so the ``top_k ** seq_len`` candidates are never all
    materialised at once.

    Args:
        model: callable mapping a [n, seq_len] token tensor to [n, seq_len, vocab] logits.
        one_hot_probs: [seq_len, vocab] relaxed one-hot scores for a single sequence.
        true_logits_row: [vocab] target final-position logits.
        top_k: number of highest-scoring tokens kept per position.
        batch_size: number of candidate sequences scored per forward pass.

    Returns:
        (best_tokens, best_logits): the [seq_len] candidate with the lowest mean Huber
        loss against ``true_logits_row`` and its final-position logits. Returns
        (None, None) if no loss ever compared below +inf (e.g. every loss was NaN).
    """
    seq_len = one_hot_probs.size(0)
    top_k_indices = torch.topk(one_hot_probs, k=top_k, dim=-1).indices  # [seq_len, top_k]
    search_device = top_k_indices.device
    total_combinations = top_k ** seq_len

    # Place value of each position's digit. The last position is least significant,
    # which reproduces itertools.product ordering and therefore argmin tie-breaking.
    place_values = torch.tensor(
        [top_k ** p for p in range(seq_len - 1, -1, -1)], dtype=torch.long, device=search_device
    )
    positions = torch.arange(seq_len, device=search_device)
    true_logits_row = true_logits_row.unsqueeze(0)  # [1, vocab]
    huber = torch.nn.HuberLoss(reduction='none')

    best_loss = float('inf')
    best_tokens = None
    best_logits = None
    for comb_start in range(0, total_combinations, batch_size):
        comb_end = min(comb_start + batch_size, total_combinations)
        flat_idx = torch.arange(comb_start, comb_end, dtype=torch.long, device=search_device)
        # choice[r, p] selects which of position p's top-k tokens candidate r uses
        choice = (flat_idx.unsqueeze(1) // place_values) % top_k  # [n, seq_len]
        candidate_tokens_batch = top_k_indices[positions, choice]  # [n, seq_len]

        candidate_logits = model(candidate_tokens_batch)  # [n, seq_len, vocab]
        candidate_losses = huber(
            true_logits_row.expand(comb_end - comb_start, -1),
            candidate_logits[:, -1, :],
        ).mean(dim=-1)  # [n]

        # Keep the best candidate seen so far (strict < keeps the earliest on ties)
        min_loss_idx = torch.argmin(candidate_losses)
        min_loss = candidate_losses[min_loss_idx].item()
        if min_loss < best_loss:
            best_loss = min_loss
            best_tokens = candidate_tokens_batch[min_loss_idx].clone()
            best_logits = candidate_logits[min_loss_idx, -1, :].clone()

    return best_tokens, best_logits


def onehot_search(cfg):
    """Invert final-position logits by optimising relaxed one-hot input sequences.

    For each target, the true final-position logits are computed. One optimiser per
    target then updates per-position logit tensors. Their temperature softmax
    (``cfg.temp``) is multiplied with ``model.embed.W_E`` to form soft input
    embeddings.

    After every epoch, the argmax tokens are checked against the target logits with
    ``torch.allclose``. A target that matches, or that reaches ``cfg.max_epochs``, is
    moved from the working batch into the results. A target that reaches
    ``cfg.max_epochs`` without a match first gets a final exhaustive search over all
    ``cfg.top_k ** cfg.input_len`` combinations of its per-position top-k tokens.

    Progress is checkpointed to ``DATA_PATH / cfg.save_folder`` once more than
    6 hours of accumulated search time have passed since the last checkpoint. A
    run resumes from that file when it exists.

    Relies on notebook globals: ``model``, ``device``, ``DATA_PATH``, ``CustomAdam``,
    ``load_dataset_tokens``, ``DotDict``.

    Args:
        cfg: DotDict of experiment settings (target/init strategy, optimiser
            hyper-parameters, reset/reinit schedule, ``max_batch_size``,
            ``max_epochs``, ``top_k``, ``model_name``, ...).

    Returns:
        (results, elapsed_time): a list of per-target dicts with keys ``true_tokens``,
        ``pred_tokens``, ``found_solution`` and ``done_epochs``, and the total search
        time in seconds rounded to 3 decimals.

    Raises:
        ValueError: if ``cfg.init_strategy`` is not "loaded", "normal" or "zeros".
    """
    # Fail fast, before any expensive dataset loading
    if cfg.init_strategy not in ("loaded", "normal", "zeros"):
        raise ValueError(f"Unknown init_strategy: {cfg.init_strategy!r}")

    # Get the targets used for all experiments based on dataset
    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted local files.
    if cfg.target_strategy == "random":
        with open(DATA_PATH / f"true_tokens_{cfg.num_targets}_{cfg.input_len}.pkl", 'rb') as file:
            loaded_true_tokens = pickle.load(file).to("cpu")
    elif cfg.target_strategy == "privacy":
        # Privacy dataset only allows num_targets == 500 currently
        privacy_ds = load_dataset("AdrSkapars/pii-inversion-test-5k", split=f"length_{cfg.input_len}")
        loaded_true_tokens = torch.cat([torch.tensor(item["tokens"]).to(torch.int64).unsqueeze(0) for item in privacy_ds], dim=0).to("cpu")
    else:
        loaded_true_tokens = load_dataset_tokens(cfg.target_strategy, cfg.input_len, cfg.num_targets, include_bos=False, random_sentence=True, random_start=False)

    # Get the initial relaxed one-hot logits based on strategy
    vocab_size = model.embed.W_E.size(0)
    if cfg.init_strategy == "loaded":
        with open(DATA_PATH / f"initial_tokens_{cfg.num_targets}_{cfg.input_len}.pkl", 'rb') as file:
            initialisation_tokens = pickle.load(file).to(device)
        initialisation_embeds = F.one_hot(initialisation_tokens, num_classes=vocab_size).to(model.embed.W_E.dtype).to("cpu")
    elif cfg.init_strategy == "normal":
        normal_embed = torch.empty((cfg.num_targets, cfg.input_len, vocab_size))
        _ = nn.init.normal_(normal_embed, std=0.05)
        initialisation_embeds = normal_embed.to("cpu")
    else:  # "zeros"
        initialisation_embeds = torch.zeros((cfg.num_targets, cfg.input_len, vocab_size)).to("cpu")

    # Initialise state variables (or resume from a checkpoint)
    state_path = DATA_PATH / f'{cfg.save_folder}/checkpoint_{cfg.input_len}_{cfg.num_targets}_{cfg.max_epochs}.pt'
    if os.path.exists(state_path):
        print("LOADING STATE")
        state = torch.load(state_path, weights_only=False)
    else:
        print("INITIALISING STATE")
        state = DotDict({
            "results" : [],
            "batch_results" : [],
            "true_logits" : torch.Tensor([]).to(device),
            "optimizers" : [],
            "loaded_i" : 0,
            "epoch" : 0,
            "num_remain_items" : cfg.num_targets,
            "num_success_items" : 0,
            "elapsed_time" : 0,
            "checkpoint_elapsed_time" : 0,
        })

    # Tolerance for torch.allclose when deciding a target has been inverted
    threshold = 1e-4 if "tiny" in cfg.model_name else 1e-3
    checkpoint_interval = 3600 * 6  # seconds of accumulated search time between checkpoints

    while state.num_remain_items != 0 or len(state.batch_results) != 0:
        start_time = time.time()

        # Checkpoint current progress once checkpoint_interval seconds of search time have passed
        if state.elapsed_time - state.checkpoint_elapsed_time > checkpoint_interval:
            print("\nSAVING STATE")
            state.checkpoint_elapsed_time = state.elapsed_time
            # torch.save does not create missing directories
            os.makedirs(os.path.dirname(state_path), exist_ok=True)
            torch.save(state, state_path)

        # Print progress
        state.epoch += 1
        if state.epoch % 100 == 0:
            print(f"({state.num_success_items}/{cfg.num_targets})({cfg.num_targets-state.num_remain_items}/{cfg.num_targets}){state.epoch}", end=", ")

        with torch.no_grad():
            # Add new items to batch if have space and have more items to do
            if (cfg.max_batch_size - len(state.batch_results)) > 0 and state.num_remain_items != 0:
                num_new_items = min((cfg.max_batch_size - len(state.batch_results)), state.num_remain_items)
                state.num_remain_items -= num_new_items

                # Initialise new target and add to end (batched)
                true_tokens = loaded_true_tokens[state.loaded_i:state.loaded_i+num_new_items].to(device)
                new_true_logits = model(true_tokens).detach()[:,-1,:]
                state.true_logits = torch.cat((state.true_logits, new_true_logits))

                for i in range(num_new_items):
                    # Initialise new results tracking and add to end
                    state.batch_results.append({
                        "true_tokens": true_tokens[i].to("cpu"),
                        "pred_tokens": None,
                        "found_solution": False,
                        "done_epochs": 0,
                    })

                    # Initialise new prediction and add to end: one optimiser per sequence,
                    # with each position slice registered as a separate parameter
                    new_pred_embed = initialisation_embeds[state.loaded_i+i:state.loaded_i+i+1].to(device)
                    for j in range(cfg.input_len):
                        new_pred_embed_pos = new_pred_embed[:,j:j+1]
                        new_pred_embed_pos.requires_grad = True
                        if j == 0:
                            if cfg.bias_correction:
                                state.optimizers.append(torch.optim.Adam([new_pred_embed_pos], lr=cfg.learn_rate, betas=cfg.betas))
                            else:
                                state.optimizers.append(CustomAdam([new_pred_embed_pos], lr=cfg.learn_rate, betas=cfg.betas))
                        else:
                            state.optimizers[-1].param_groups[0]['params'].append(new_pred_embed_pos)

                state.loaded_i += num_new_items

        # Do one epoch of optimisation on batch
        for optimizer in state.optimizers:
            optimizer.zero_grad()
        pred_embed_pre = torch.cat([torch.cat([param for param in optimizer.param_groups[0]['params']], dim=1)
                                    for optimizer in state.optimizers], dim=0).to(device)
        pred_one_hot = torch.softmax(pred_embed_pre / cfg.temp, dim=-1)
        pred_embed = (pred_one_hot @ model.embed.W_E)
        if "gpt" not in cfg.model_name and "tiny" not in cfg.model_name:
            pred_embed_full = pred_embed
        else:
            # These model families get positional embeddings added to the input here
            # (presumably because they use absolute positional embeddings)
            pred_embed_full = pred_embed + model.pos_embed(pred_embed[:,:,0].detach())
        pred_logits = model(pred_embed_full, start_at_layer=0)
        loss = torch.nn.HuberLoss()(state.true_logits.detach(), pred_logits[:,-1,:])

        if cfg.reg_weight is not None:
            # Compute regularisation penalty
            if state.epoch >= 0:
                # # Size of input penalty
                # reg_penalty = (pred_one_hot).pow(2).sum(dim=-1).sqrt() * -1
                # reg_penalty = reg_penalty.mean(dim=-1).mean(dim=-1)

                # Fluency penalty: negative log-prob the model assigns to each next argmax token.
                # log_softmax avoids the -inf that softmax().log() produces on underflow.
                reg_penalty = pred_logits[:,:-1,:].log_softmax(dim=-1).gather(2, pred_one_hot[:,1:,:].argmax(dim=-1).unsqueeze(-1)).squeeze(-1) * -1
                reg_penalty = reg_penalty.mean(dim=-1).mean(dim=-1)

                loss = loss + (cfg.reg_weight * reg_penalty)

        loss.backward()
        for optimizer in state.optimizers:
            optimizer.step()

        with torch.no_grad():
            # Add decay to embeddings
            for optimizer in state.optimizers:
                for param in optimizer.param_groups[0]['params']:
                    param.mul_(cfg.decay_rate)

            # Intervene if sequence not found yet
            for i in range(len(state.batch_results)):
                targets_epoch = (state.batch_results[i]["done_epochs"]+1)
                # Reset optimiser state
                if targets_epoch % cfg.reset_epoch == 0:
                    for j in range(cfg.input_len):
                        del state.optimizers[i].state[state.optimizers[i].param_groups[0]['params'][j]]

                # Reinitialise sequence
                if targets_epoch % cfg.reinit_epoch == 0:
                    for j in range(cfg.input_len):
                        state.optimizers[i].param_groups[0]['params'][j].normal_(std=0.1)

            # Assume largest one-hot token is the true one
            pred_tokens = torch.argmax(pred_one_hot, dim=-1)

            # Update history of tokens over epochs
            disc_pred_logits = model(pred_tokens)[:,-1,:]
            # Iterate backwards so finished items can be popped without shifting unvisited indices
            for i in range(len(state.batch_results)-1,-1,-1):
                state.batch_results[i]["done_epochs"] += 1
                reached_max_epochs = cfg.max_epochs is not None and state.batch_results[i]["done_epochs"] >= cfg.max_epochs

                # Check if have found a solution
                have_inverted = torch.allclose(state.true_logits[i], disc_pred_logits[i], atol=threshold, rtol=threshold)

                if not have_inverted and reached_max_epochs:
                    # Search ended without a match: try every combination of the per-position
                    # top-k tokens before giving up on this item
                    best_tokens, best_logits = _exhaustive_topk_search(
                        model, pred_one_hot[i], state.true_logits[i], cfg.top_k, cfg.max_batch_size)
                    if best_logits is not None:
                        pred_tokens[i] = best_tokens
                        have_inverted = torch.allclose(state.true_logits[i], best_logits, atol=threshold, rtol=threshold)

                if have_inverted:
                    state.batch_results[i]["found_solution"] = True
                    state.num_success_items += 1

                # Remove item if have found a solution or reached final epoch
                if have_inverted or reached_max_epochs:
                    state.batch_results[i]["pred_tokens"] = pred_tokens[i].to("cpu")
                    del state.optimizers[i]
                    state.true_logits = torch.cat((state.true_logits[:i], state.true_logits[i+1:]))
                    state.results.append(state.batch_results.pop(i))

            state.elapsed_time += time.time() - start_time

    return state.results, round(state.elapsed_time, 3)

#### Experiment

In [None]:
# # Experiment parameters
# cfg = DotDict({
#     "learn_rate" : 0.065,
#     "decay_rate" : 0.9,
#     "betas" : (0.9,0.995),
#     "temp" : 0.05,
#     "reset_epoch" : 50,
#     "reinit_epoch" : 1500,
#     "reg_weight" : None,#9e-3,
#     "bias_correction" : False,
#     "target_strategy" : {
#         0: "random",
#         1: "tinystories",
#         2: "reddit",
#         3: "wikipedia",
#         4: "privacy",
#     }[0],
#     "init_strategy" : {
#         0: "loaded",
#         1: "normal",
#         2: "zeros",
#     }[2],
#     "save_folder": "OneHot_TinyStories33M_privacy",
#     "model_name": model_name,
# })

In [74]:
# Experiment parameters: optimiser settings, the per-target reset/reinit schedule, and
# which target / initialisation strategy to use (chosen by index into each option list).
cfg = DotDict(
    learn_rate=0.065,
    decay_rate=0.9,               # multiplicative decay applied to the one-hot logits each epoch
    betas=(0.9, 0.995),
    temp=0.05,                    # softmax temperature on the relaxed one-hot logits
    reset_epoch=50,               # clear each target's optimiser state every N epochs
    reinit_epoch=1500,            # re-draw each target's one-hot logits every N epochs
    reg_weight=None,              # e.g. 9e-3 to enable the fluency penalty
    bias_correction=False,        # True -> torch.optim.Adam, False -> CustomAdam
    target_strategy=["random", "tinystories", "reddit", "wikipedia", "privacy"][0],
    init_strategy=["loaded", "normal", "zeros"][2],
    # NOTE(review): folder name says "privacy" while target_strategy is "random" -- confirm intended
    save_folder="OneHot_TinyStories33M_privacy",
    model_name=model_name,
)

In [79]:
# Set parameters for dataset size
cfg.input_len = 7
cfg.top_k = 6          # final exhaustive search scores top_k ** input_len candidates per unsolved target
cfg.num_targets = 100
cfg.max_epochs = 1000

# Per-model batch sizes, indexed by input_len - 1
# (presumably tuned to fit GPU memory; None = no value recorded for that length).
_max_batch_sizes = {
    "tiny-stories-33M" : [1000,1000,1000,1000,1000,1000,1000,900,750,650,None, None, None, None, 500, 500, 475, 450, 410, 380, 375, 350, 325, 300],
    "gpt2-small": [1000,1000,1000,1000,1000,950,850,700,650,600],
    "Qwen/Qwen2.5-0.5B": [1000,700,500,400,300,240,220,200,180,100],
}[cfg.model_name]
# Fail here with a clear message rather than deep inside onehot_search
# (None breaks min(); a short list raises IndexError; input_len < 1 would silently wrap).
if not (1 <= cfg.input_len <= len(_max_batch_sizes)) or _max_batch_sizes[cfg.input_len-1] is None:
    raise ValueError(f"No max_batch_size recorded for {cfg.model_name!r} at input_len={cfg.input_len}")
cfg.max_batch_size = _max_batch_sizes[cfg.input_len-1]

# # Set random seeds
# torch.manual_seed(0)
# np.random.seed(0)
# random.seed(0)

# # Generate the targets and initialisations used for all experiments
# tokens_list = []
# for _ in tqdm(range(cfg.num_targets)):
#     tokens = torch.randint(0, len(model.tokenizer.vocab), (1, cfg.input_len)).to(device)
#     tokens_list.append(tokens)
# true_tokens = torch.cat(tokens_list, dim=0).to(device)
# with open(DATA_PATH / f'true_tokens_{cfg.num_targets}_{cfg.input_len}.pkl', 'wb') as file:
#     pickle.dump(true_tokens, file)

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

results, elapsed_time = onehot_search(cfg)
print()
stats = get_paper_summary_stats_new(results, epochs=cfg.max_epochs)
stats["elapsed_time"] = elapsed_time
stats["experiment_params"] = cfg
print("input_len", cfg.input_len, "\tpercent_zero_loss", stats["percent_zero_loss"], "\tpercent_exact_inversion", stats["percent_exact_inversion"], "\telapsed_time", stats["elapsed_time"])

# with open(DATA_PATH / f'{cfg.save_folder}/stats_{cfg.input_len}_{cfg.num_targets}_{cfg.max_epochs}.json', 'w') as f:
#     json.dump(stats, f)
# with open(DATA_PATH / f'{cfg.save_folder}/results_{cfg.input_len}_{cfg.num_targets}_{cfg.max_epochs}.pkl', 'wb') as f:
#     pickle.dump(results, f)

INITIALISING STATE
(1/100)(100/100)100, (2/100)(100/100)200, (4/100)(100/100)300, (9/100)(100/100)400, (14/100)(100/100)500, (21/100)(100/100)600, (22/100)(100/100)700, (23/100)(100/100)800, (26/100)(100/100)900, (27/100)(100/100)1000, 


100%|██████████| 7/7 [00:00<00:00, 7671.84it/s]

input_len 7 	percent_zero_loss 27.0 	percent_exact_inversion 27.0 	elapsed_time 877.604



