Importing cellular automata & optimization classes, and other stuff

In [4]:
import os
import sys
import shutil


sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(''))))

from lamm_automata.blender import Lattice, clear_initial
from lamm_automata.ruleset import conway, seeds
from lamm_automata.genetic import Optimizer, RulesetMutator, ArbitraryRulesetMutator
from lamm_automata.genetic.mutation import ICMutation, SRTMutation
from lamm_automata.objectives import surface_to_vol

import numpy as np
import torch
import pandas as pd
import random
import datetime
import decimal
from copy import deepcopy

Import transformer modules

In [1]:
from x_transformers import Decoder, TransformerWrapper
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

More imports (mostly PyTorch stuff)

In [2]:
from torch.utils.data import Dataset, IterableDataset, DataLoader, get_worker_info
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F

import time
import pickle
import json

Open all experiment data as Pandas DataFrames

In [5]:
list_df = []
for filename in os.listdir(os.path.join(os.getcwd(), 'data')):
    df = pd.read_hdf(f'data/{filename}')
    list_df.append(df)
    # print data to make sure they are correct
    print(f'EXPERIMENT DATA FROM {filename}:')
    # with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    #     # print(df)

EXPERIMENT DATA FROM test_experiment_100ITERS_32GRID_40.h5:
EXPERIMENT DATA FROM test_experiment_100ITERS_32GRID_41.h5:
EXPERIMENT DATA FROM test_experiment_100ITERS_32GRID_6.h5:
EXPERIMENT DATA FROM test_experiment_100ITERS_32GRID_12.h5:
EXPERIMENT DATA FROM test_experiment_100ITERS_32GRID_9.h5:
EXPERIMENT DATA FROM test_experiment_100ITERS_32GRID_14.h5:
EXPERIMENT DATA FROM test_experiment_100ITERS_32GRID_17.h5:
EXPERIMENT DATA FROM test_experiment_100ITERS_32GRID_24.h5:
EXPERIMENT DATA FROM test_experiment_100ITERS_32GRID_49.h5:
EXPERIMENT DATA FROM test_experiment_100ITERS_32GRID_36.h5:
EXPERIMENT DATA FROM test_experiment_100ITERS_32GRID_39.h5:
EXPERIMENT DATA FROM test_experiment_100ITERS_32GRID_31.h5:
EXPERIMENT DATA FROM test_experiment_100ITERS_32GRID_32.h5:
EXPERIMENT DATA FROM test_experiment_100ITERS_32GRID_44.h5:
EXPERIMENT DATA FROM test_experiment_100ITERS_32GRID_45.h5:
EXPERIMENT DATA FROM test_experiment_100ITERS_32GRID_4.h5:
EXPERIMENT DATA FROM test_experiment_100ITE

Import tokenizer modules

In [6]:
from tokenizers import Tokenizer, models, pre_tokenizers, trainers, processors
from transformers import PreTrainedTokenizerFast
from TokenizerChanger import TokenizerChanger
from torchtune.generation import sample

  from .autonotebook import tqdm as notebook_tqdm


Load tokenizer

In [7]:
# model we want to open and epoch number
model_name = "test_model_04_12_2029"

model_dir = os.path.join("models", model_name)

delimiter_tokens = ["[BMB]","[ICM]","[SRTM]","[EMB]"]
special_tokens = ["[UNK]", "[PAD]", "[BOS]", "[EOS]"]
# perf_num_tokens = ["[P]","#"]
perf_num_tokens = ["[BP]", "[PPOS]", "[PNEG]", "[PEPOS]", "[PENEG]", "[EP]"]

# load tokenizer
tokenizer = PreTrainedTokenizerFast(tokenizer_file = os.path.join(model_dir,"tokenizer.json"))

# special tokens
encoder_special_tokens_dict = {"additional_special_tokens": delimiter_tokens+special_tokens+perf_num_tokens}
tokenizer.add_special_tokens(encoder_special_tokens_dict)
tokenizer.add_special_tokens({'pad_token': '[PAD]', 'bos_token': '[BOS]', 'unk_token': '[UNK]', 'eos_token': '[EOS]'})

tokenizer.padding_side = "right"

In [8]:
def str_to_tensor(dat_str):
    return torch.Tensor(tokenizer(dat_str, padding="longest", return_tensors="pt")["input_ids"])

def tensor_decode(tensor):
    if(tensor.dim()==2):
        return [tokenizer.decode(seq) for seq in tensor]
    return tokenizer.decode(tensor)

Encode dataset

In [11]:
def decimal_to_tokens(val):
    formatted_val = "{:.2e}".format(val)
    digits, exp = formatted_val.split("e")

    digits = digits.replace(".","")
    digits = digits.replace("+","")
    digits = digits.replace("-","")

    exp = int(exp)
    return f"{'[PPOS]' if val > 0 else '[PNEG]'} {' '.join(digits)} {'[PEPOS]' if exp >= 0 else '[PENEG]'} {abs(exp)}"

data_strings = []
initial_ics = []
initial_srts = []
# performance_metric_vals = []
# store performance metric values separately to put them in the input directly, they shouldn't be tokenized
for data_frame in list_df:
    # final goal: [BOS][BMB][ICM][pos of mutation in IC (2 tokens)][SRTM][pos of mutation in SRT (3 tokens)][P][performance metric after the mutation][EMB] ... [EOS]
    data_str = '[BOS] '
    # perf_met = []
    for index, row in data_frame.iterrows():
        # print(row)
        if index == 0:
            # initial conditions
            grid_sz = row["ic_cell_pos"]
            initial_ic = row["ic_state_old"]
            initial_srt = row["srt_state_old"]
            initial_ics.append(initial_ic)
            initial_srts.append(initial_srt)
            # print(f'grid_sz: {grid_sz}, initial_ic: {initial_ic}, initial_srt: {initial_srt}')
        else:
            data_str += '[BMB] '
            # columns: ic_cell_pos, ic_state_old, ic_state_new, srt_cell_pos, srt_state_old, srt_state_new, objective
            ic_mut_strings = []
            srt_mut_strings = []
            # only include mutation positions since otherwise data will be too big
            if type(row["ic_cell_pos"]) != int and len(row["ic_cell_pos"].shape) >= 1:
                if len(row["ic_cell_pos"].shape) == 2:
                    # batch update
                    for i in range(row["ic_cell_pos"].shape[0]):
                        ic_mut_strings.append(f'[ICM] {row["ic_cell_pos"][i][0]} {row["ic_cell_pos"][i][1]} ')
                else:
                    ic_mut_strings.append(f'[ICM] {row["ic_cell_pos"][0]} {row["ic_cell_pos"][1]} ')
            
            if type(row["srt_cell_pos"]) != int and len(row["srt_cell_pos"].shape) >= 1:
                if len(row["srt_cell_pos"].shape) == 2:
                    # batch update
                    for i in range(row["srt_cell_pos"].shape[0]):
                        ic_mut_strings.append(f'[SRTM] {row["srt_cell_pos"][i][0]} {row["srt_cell_pos"][i][1]} {row["srt_cell_pos"][i][2]} ')
                else:
                    ic_mut_strings.append(f'[SRTM] {row["srt_cell_pos"][0]} {row["srt_cell_pos"][1]} {row["srt_cell_pos"][2]} ')
            data_str += ''.join(ic_mut_strings)
            data_str += ''.join(srt_mut_strings)
            data_str += f'[BP] {decimal_to_tokens(row["objective"])} [EP] '
            data_str += '[EMB] '
    data_str += '[EOS]'

    data_strings.append(data_str)

torch.set_printoptions(profile="full")
encoded = str_to_tensor(data_strings)
# print(encoded)

Load the model

In [12]:
epoch_num = 1
print(f"Loading model {model_name} in epoch number {epoch_num}")

# load model hyperparameters
with open(os.path.join(model_dir,'model_hyperparams.json')) as f:
    d = json.load(f)
    print(f"model hyperparams: {d}")
    num_tokens = d["num_tokens"]
    max_seq_len = d["max_seq_len"]
    DIM = d["dim"]
    DEPTH = d["depth"]
    HEADS = d["heads"]
    ATTN_FLASH = d["attn_flash"]
    ROTARY_POS_EMB = d["rotary_pos_emb"]

with open(os.path.join(model_dir,'model_training_params.json')) as f:
    d = json.load(f)
    print(f"model training params: {d}")

model = TransformerWrapper(
    num_tokens = num_tokens,
    max_seq_len = max_seq_len,
    attn_layers = Decoder(
        dim = DIM,
        depth = DEPTH,
        heads = HEADS,
        attn_flash = ATTN_FLASH,
        rotary_pos_emb = ROTARY_POS_EMB,
    ), 
)
# wrap the transformer into an autoregressor
model = AutoregressiveWrapper(model)

# load the model
model.load_state_dict(torch.load(os.path.join(model_dir, f"weights_{epoch_num}.pt"), weights_only=True))
if torch.cuda.is_available():
    model.cuda()
model.eval()

Loading model test_model_04_12_2029 in epoch number 1
model hyperparams: {'model_name': 'test_model', 'num_tokens': 93, 'max_seq_len': 100, 'dim': 50, 'depth': 6, 'limit_seq_len': 100, 'heads': 4, 'rotary_pos_emb': True, 'attn_flash': True, 'masking': False, 'mask_prob': 0.15}
model training params: {'epochs': 1, 'batch_size': 1, 'lr': 0.001, 'workers': 0, 'drop_last': False}


AutoregressiveWrapper(
  (net): TransformerWrapper(
    (token_emb): TokenEmbedding(
      (emb): Embedding(93, 50)
    )
    (post_emb_norm): Identity()
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (project_emb): Identity()
    (attn_layers): Decoder(
      (layers): ModuleList(
        (0): ModuleList(
          (0): ModuleList(
            (0): LayerNorm(
              (ln): LayerNorm((50,), eps=1e-05, elementwise_affine=False)
            )
            (1-2): 2 x None
          )
          (1): Attention(
            (to_q): Linear(in_features=50, out_features=256, bias=False)
            (to_k): Linear(in_features=50, out_features=256, bias=False)
            (to_v): Linear(in_features=50, out_features=256, bias=False)
            (split_q_heads): Rearrange('b n (h d) -> b h n d', h=4)
            (split_k_heads): Rearrange('b n (h d) -> b h n d', d=64)
            (split_v_heads): Rearrange('b n (h d) -> b h n d', d=64)
            (merge_heads): Rearrange('b h n d -> b n

Transformer-Powered Genetic Algorithm Performance Evaluation & Comparison with Dataset Sequences on Randomly Generated Cellular Automata

In [21]:
NUM_TEST_EXPERIMENTS = 10
NUM_TEST_SEQUENCES = 10
GRID_SIZE = 32
RULESET_MUTATOR_CLASS = ArbitraryRulesetMutator
RULE_SET = [conway(), seeds()]
OPT_FUNC = surface_to_vol
SRT_NUM_MUTATE = 10
IC_NUM_MUTATE = 10
RULESET_MUTE_PROB = 2/3

def export(lattice: Lattice, iteration: int, export_dir: str):
    stlfile = os.path.join(export_dir, "genetic%.5d.stl" % iteration)
    lattice.clear_lattice()
    lattice.update_selected(optim.state.generate())
    lattice.export_stl(stlfile)
    with open(os.path.join(export_dir, "objective.txt"), "a", encoding="utf-8") as w:
        w.write(f"{iteration}: {optim.objvalue}\n")


# Measure the difference between the resulting performance metrics when applying the predicted mutations and running the automaton

performance_logs = []

for num_test in range(NUM_TEST_EXPERIMENTS):
    # clear_initial()
    # lattice = Lattice(dim=(GRID_SIZE, GRID_SIZE, GRID_SIZE))

    ruleset_mutator_init = RULESET_MUTATOR_CLASS(rules=RULE_SET, grid_size=GRID_SIZE, mutate_p=1/(GRID_SIZE**2) * (SRT_NUM_MUTATE+IC_NUM_MUTATE), rule_mutate_p=RULESET_MUTE_PROB)

    optim_init = Optimizer(mutator=ruleset_mutator_init, objective=lambda grid: OPT_FUNC(grid))

    model_perf_improvs = []
    model_perf_sum = 0
    model_mut_batch_num = 0
    data_perf_improvs = []
    data_perf_sum = 0
    data_mut_batch_num = 0
    
    # test the mutation prediction model
    for repeat_num in range(NUM_TEST_SEQUENCES):
        optim = deepcopy(optim_init)
        
        curr_level = "seq" # seq, mutb, icm, srtm
        curr_output = torch.full((1,max_seq_len), int(str_to_tensor("[PAD]")[0]))
        curr_idx = 1
        in_idx = 0
        curr_output[0][0] = str_to_tensor("[BOS]")[0]
        ic_mutations, srt_mutations = [], []
        curr_mutation = None
        it = 0

        while curr_idx < max_seq_len:
            with torch.no_grad():
                next_output = model.net(curr_output.cuda())
                # print(model.net.num_tokens)
                # print(next_output.shape)
                logits = next_output[0][curr_idx]
                # print(logits)
            valid = False
            added_perf = False
            while valid == False:
                token = int(sample(logits)[0])
                # print(token)
                if curr_level == "seq":
                    if token == int(str_to_tensor("[BMB]")[0]):
                        # print('STARTING MUTATION BATCH')
                        curr_level = "mutb"
                        valid = True
                elif curr_level == "mutb":
                    if token == int(str_to_tensor("[EMB]")[0]):
                        # perform mutations predicted & add the resulting performance to the sequence
                        # print(f"MUTATION BATCH {it}: IC: {ic_mutations}, SRT: {srt_mutations}")
                        old_objval = optim.objvalue
                        accepted, new_state = optim.step_muts(ic_mutations, srt_mutations)
                        if accepted:
                            # print("Got a better state!", optim.objvalue)
                            model_perf_sum += optim.objvalue - old_objval
                            model_perf_improvs.append(optim.objvalue - old_objval)
                            # export(it + 1, lattice, export_dir)
                        perf_tokens = str_to_tensor(f'[BP] {decimal_to_tokens(optim.objvalue)} [EP] ')[0]
                        added_perf = True
                        it += 1
                        model_mut_batch_num += 1
                        # end mutation batch
                        curr_level = "seq"
                        valid = True
                        ic_mutations, srt_mutations = [], []
                    elif token == int(str_to_tensor("[ICM]")[0]):
                        # got ic mutation
                        curr_mutation = [0, 0]
                        curr_level = "icm"
                        valid = True
                    elif token == int(str_to_tensor("[SRTM]")[0]):
                        # got srt mutation
                        curr_mutation = [0, 0, 0]
                        curr_level = "srtm"
                        valid = True
                elif curr_level == "icm":
                    decoded = tensor_decode(torch.tensor(token, dtype=torch.int8))
                    if decoded.isdigit() and int(decoded) < GRID_SIZE:
                        # got ic dimension
                        curr_mutation[in_idx] = int(decoded)
                        in_idx += 1
                        valid = True
                        if in_idx == 2:
                            in_idx = 0
                            curr_level = "mutb"
                            # print(f"ICM: {curr_mutation}")
                            ic_mutations.append(curr_mutation)
                            curr_mutation = None

                elif curr_level == "srtm":
                    decoded = tensor_decode(torch.tensor(token, dtype=torch.int8))
                    if decoded.isdigit() and (in_idx > 0 or int(decoded) < GRID_SIZE-1):
                        # got srt dimension
                        curr_mutation[in_idx] = int(decoded)
                        in_idx += 1
                        valid = True
                        if in_idx == 3:
                            in_idx = 0
                            curr_level = "mutb"
                            # print(f"SRTM: {curr_mutation}")
                            srt_mutations.append(curr_mutation)
                            curr_mutation = None
            curr_output[0][curr_idx] = token
            curr_idx += 1
            if added_perf and (curr_idx + perf_tokens.shape[0] <= 100):
                for i in range(perf_tokens.shape[0]):
                    curr_output[0][curr_idx] = perf_tokens[i]
                    curr_idx += 1
            
    # test the dataset predictions
    for repeat_num in random.sample(range(len(encoded)), NUM_TEST_SEQUENCES):
        optim = deepcopy(optim_init)

        curr_level = "seq" # seq, mutb, icm, srtm, perf
        curr_idx = 1
        in_idx = 0
        ic_mutations, srt_mutations = [], []
        curr_mutation = None
        it = 0

        while curr_idx < len(encoded[repeat_num]):
            valid = False
            token = encoded[repeat_num][curr_idx]
            # print(tensor_decode(torch.tensor(token, dtype=torch.int8)))
            if curr_level == "seq":
                if token == int(str_to_tensor("[BMB]")[0]):
                    # print('STARTING MUTATION BATCH')
                    curr_level = "mutb"
                    valid = True
                elif token == int(str_to_tensor("[EOS]")[0]):
                    break
            elif curr_level == "mutb":
                if token == int(str_to_tensor("[EMB]")[0]):
                    # perform mutations predicted & add the resulting performance to the sequence
                    # print(f"MUTATION BATCH {it}: IC: {ic_mutations}, SRT: {srt_mutations}")
                    old_objval = optim.objvalue
                    accepted, new_state = optim.step_muts(ic_mutations, srt_mutations)
                    if accepted:
                        # print("Got a better state!", optim.objvalue)
                        data_perf_sum += optim.objvalue - old_objval
                        data_perf_improvs.append(optim.objvalue - old_objval)
                        # export(it + 1, lattice, export_dir)
                    it += 1
                    data_mut_batch_num += 1
                    # end mutation batch
                    curr_level = "seq"
                    valid = True
                    ic_mutations, srt_mutations = [], []
                elif token == int(str_to_tensor("[ICM]")[0]):
                    # got ic mutation
                    # print("GOT IC MUTATION")
                    curr_mutation = [0, 0]
                    curr_level = "icm"
                    valid = True
                elif token == int(str_to_tensor("[SRTM]")[0]):
                    # got srt mutation
                    # print("GOT SRT MUTATION")
                    curr_mutation = [0, 0, 0]
                    curr_level = "srtm"
                    valid = True
                elif token == int(str_to_tensor("[BP]")[0]):
                    # print("GOT PERF")
                    curr_level = "perf"
                    valid = True
            elif curr_level == "icm":
                decoded = tensor_decode(torch.tensor(token, dtype=torch.int8))
                if decoded.isdigit() and int(decoded) < GRID_SIZE:
                    # print("GOT IC DIM")
                    # got ic dimension
                    curr_mutation[in_idx] = int(decoded)
                    in_idx += 1
                    valid = True
                    if in_idx == 2:
                        in_idx = 0
                        curr_level = "mutb"
                        # print(f"ICM: {curr_mutation}")
                        ic_mutations.append(curr_mutation)
                        curr_mutation = None

            elif curr_level == "srtm":
                decoded = tensor_decode(torch.tensor(token, dtype=torch.int8))
                if decoded.isdigit() and (in_idx > 0 or int(decoded) < GRID_SIZE-1):
                    # print("GOT SRT DIM")
                    # got srt dimension
                    curr_mutation[in_idx] = int(decoded)
                    in_idx += 1
                    valid = True
                    if in_idx == 3:
                        in_idx = 0
                        curr_level = "mutb"
                        # print(f"SRTM: {curr_mutation}")
                        srt_mutations.append(curr_mutation)
                        curr_mutation = None
            
            elif curr_level == "perf":
                if token == int(str_to_tensor("[EP]")[0]):
                    # print("ENDED PERF")
                    curr_level = "mutb"
                valid = True
            assert valid, "NOT VALID DATASET SEQUENCE"
            curr_idx += 1
    
    performance_logs.append({"model_perf_improvs": model_perf_improvs, "avg_model_improv": float(model_perf_sum)/model_mut_batch_num if model_mut_batch_num != 0 else 0,
                             "data_perf_improvs": data_perf_improvs, "avg_data_improv": float(data_perf_sum)/data_mut_batch_num if data_mut_batch_num != 0 else 0})
    
print(f'PERFORMANCE COMPARISON: {performance_logs}')
with open(os.path.join(model_dir,'model_evaluation.json'), 'w', encoding='utf-8') as f:
    json.dump(performance_logs, f, ensure_ascii=False, indent=4)

  decoded = tensor_decode(torch.tensor(token, dtype=torch.int8))
  decoded = tensor_decode(torch.tensor(token, dtype=torch.int8))


PERFORMANCE COMPARISON: [{'model_perf_improvs': [-0.044068912901405355, -0.007574957801577575], 'avg_model_improv': -0.017214623567660976, 'data_perf_improvs': [-0.023274473284094377, -0.007173798239308304, -0.0318346823092206, -0.02760949071494423, -0.012920683534611399, -0.004234297515270491, -0.01398144397686174, -0.07412339637322063, -0.06690008805748526, -0.041100105801689324, -0.00475752413531616, -0.03044045037737053, -0.045496229340044714, -0.002796427288607184, -0.00855516128332301, -0.01591174864868261, -0.0028485949240408814, -0.012236114673535248, -0.04084392779582657, -0.01783995405248895, -0.006627017184237971, -0.0017586451684348248, -0.07026247771835958, -0.0017851857767672286, -0.043866141219711, -0.012609093043772113, -0.0021991302037065452, -0.03467433066428249, -0.008965340028307622, -0.034461397628873236, -0.011533697376996166, -0.02374921800110208, -0.07248760921504704, -0.014405228758169741, -0.005781424234284849, -0.05852704613621995, -0.05799978097793357, -0.00

Comparison of Transformer Prediction Performance with Dataset Sequences on their respective Initial Conditions (Ground Truth)

In [24]:
NUM_TEST_EXPERIMENTS = 10
NUM_TEST_SEQUENCES = 10
GRID_SIZE = 32
RULESET_MUTATOR_CLASS = ArbitraryRulesetMutator
RULE_SET = [conway(), seeds()]
OPT_FUNC = surface_to_vol
SRT_NUM_MUTATE = 10
IC_NUM_MUTATE = 10
RULESET_MUTE_PROB = 2/3

# Measure the difference between the resulting performance metrics when applying the predicted mutations and running the automaton

performance_logs = []

for num_test in random.sample(range(len(encoded)), NUM_TEST_SEQUENCES):
    # clear_initial()
    # lattice = Lattice(dim=(GRID_SIZE, GRID_SIZE, GRID_SIZE))

    ruleset_mutator_init = RULESET_MUTATOR_CLASS(rules=RULE_SET, grid_size=GRID_SIZE, mutate_p=1/(GRID_SIZE**2) * (SRT_NUM_MUTATE+IC_NUM_MUTATE), rule_mutate_p=RULESET_MUTE_PROB)
    
    optim_init = Optimizer(mutator=ruleset_mutator_init, objective=lambda grid: OPT_FUNC(grid))

    optim_init.state.initial = initial_ics[num_test]
    optim_init.state.rules = initial_srts[num_test]

    model_perf_improvs = []
    model_perf_sum = 0
    model_mut_batch_num = 0
    data_perf_improvs = []
    data_perf_sum = 0
    data_mut_batch_num = 0
    
    # test the mutation prediction model
    for repeat_num in range(NUM_TEST_EXPERIMENTS):
        optim = deepcopy(optim_init)
        
        curr_level = "seq" # seq, mutb, icm, srtm
        curr_output = torch.full((1,max_seq_len), int(str_to_tensor("[PAD]")[0]))
        curr_idx = 1
        in_idx = 0
        curr_output[0][0] = str_to_tensor("[BOS]")[0]
        ic_mutations, srt_mutations = [], []
        curr_mutation = None
        it = 0

        while curr_idx < max_seq_len:
            with torch.no_grad():
                next_output = model.net(curr_output.cuda())
                # print(model.net.num_tokens)
                # print(next_output.shape)
                logits = next_output[0][curr_idx]
                # print(logits)
            valid = False
            added_perf = False
            while valid == False:
                token = int(sample(logits)[0])
                # print(token)
                if curr_level == "seq":
                    if token == int(str_to_tensor("[BMB]")[0]):
                        # print('STARTING MUTATION BATCH')
                        curr_level = "mutb"
                        valid = True
                elif curr_level == "mutb":
                    if token == int(str_to_tensor("[EMB]")[0]):
                        # perform mutations predicted & add the resulting performance to the sequence
                        # print(f"MUTATION BATCH {it}: IC: {ic_mutations}, SRT: {srt_mutations}")
                        old_objval = optim.objvalue
                        accepted, new_state = optim.step_muts(ic_mutations, srt_mutations)
                        if accepted:
                            # print("Got a better state!", optim.objvalue)
                            model_perf_sum += optim.objvalue - old_objval
                            model_perf_improvs.append(optim.objvalue - old_objval)
                            # export(it + 1, lattice, export_dir)
                        perf_tokens = str_to_tensor(f'[BP] {decimal_to_tokens(optim.objvalue)} [EP] ')[0]
                        added_perf = True
                        it += 1
                        model_mut_batch_num += 1
                        # end mutation batch
                        curr_level = "seq"
                        valid = True
                        ic_mutations, srt_mutations = [], []
                    elif token == int(str_to_tensor("[ICM]")[0]):
                        # got ic mutation
                        curr_mutation = [0, 0]
                        curr_level = "icm"
                        valid = True
                    elif token == int(str_to_tensor("[SRTM]")[0]):
                        # got srt mutation
                        curr_mutation = [0, 0, 0]
                        curr_level = "srtm"
                        valid = True
                elif curr_level == "icm":
                    decoded = tensor_decode(torch.tensor(token, dtype=torch.int8))
                    if decoded.isdigit() and int(decoded) < GRID_SIZE:
                        # got ic dimension
                        curr_mutation[in_idx] = int(decoded)
                        in_idx += 1
                        valid = True
                        if in_idx == 2:
                            in_idx = 0
                            curr_level = "mutb"
                            # print(f"ICM: {curr_mutation}")
                            ic_mutations.append(curr_mutation)
                            curr_mutation = None

                elif curr_level == "srtm":
                    decoded = tensor_decode(torch.tensor(token, dtype=torch.int8))
                    if decoded.isdigit() and (in_idx > 0 or int(decoded) < GRID_SIZE-1):
                        # got srt dimension
                        curr_mutation[in_idx] = int(decoded)
                        in_idx += 1
                        valid = True
                        if in_idx == 3:
                            in_idx = 0
                            curr_level = "mutb"
                            # print(f"SRTM: {curr_mutation}")
                            srt_mutations.append(curr_mutation)
                            curr_mutation = None
            curr_output[0][curr_idx] = token
            curr_idx += 1
            if added_perf and (curr_idx + perf_tokens.shape[0] <= 100):
                for i in range(perf_tokens.shape[0]):
                    curr_output[0][curr_idx] = perf_tokens[i]
                    curr_idx += 1
            
    # test the dataset predictions
    optim = deepcopy(optim_init)

    curr_level = "seq" # seq, mutb, icm, srtm, perf
    curr_idx = 1
    in_idx = 0
    ic_mutations, srt_mutations = [], []
    curr_mutation = None
    it = 0

    while curr_idx < len(encoded[num_test]):
        valid = False
        token = encoded[num_test][curr_idx]
        # print(tensor_decode(torch.tensor(token, dtype=torch.int8)))
        if curr_level == "seq":
            if token == int(str_to_tensor("[BMB]")[0]):
                # print('STARTING MUTATION BATCH')
                curr_level = "mutb"
                valid = True
            elif token == int(str_to_tensor("[EOS]")[0]):
                break
        elif curr_level == "mutb":
            if token == int(str_to_tensor("[EMB]")[0]):
                # perform mutations predicted & add the resulting performance to the sequence
                # print(f"MUTATION BATCH {it}: IC: {ic_mutations}, SRT: {srt_mutations}")
                old_objval = optim.objvalue
                accepted, new_state = optim.step_muts(ic_mutations, srt_mutations)
                # print(f"old obj: {old_objval}, new obj: {optim.objvalue}")
                if accepted:
                    # print("Got a better state!", optim.objvalue)
                    data_perf_sum += optim.objvalue - old_objval
                    data_perf_improvs.append(optim.objvalue - old_objval)
                    # export(it + 1, lattice, export_dir)
                it += 1
                data_mut_batch_num += 1
                # end mutation batch
                curr_level = "seq"
                valid = True
                ic_mutations, srt_mutations = [], []
            elif token == int(str_to_tensor("[ICM]")[0]):
                # got ic mutation
                # print("GOT IC MUTATION")
                curr_mutation = [0, 0]
                curr_level = "icm"
                valid = True
            elif token == int(str_to_tensor("[SRTM]")[0]):
                # got srt mutation
                # print("GOT SRT MUTATION")
                curr_mutation = [0, 0, 0]
                curr_level = "srtm"
                valid = True
            elif token == int(str_to_tensor("[BP]")[0]):
                # print("GOT PERF")
                curr_level = "perf"
                valid = True
        elif curr_level == "icm":
            decoded = tensor_decode(torch.tensor(token, dtype=torch.int8))
            if decoded.isdigit() and int(decoded) < GRID_SIZE:
                # print("GOT IC DIM")
                # got ic dimension
                curr_mutation[in_idx] = int(decoded)
                in_idx += 1
                valid = True
                if in_idx == 2:
                    in_idx = 0
                    curr_level = "mutb"
                    # print(f"ICM: {curr_mutation}")
                    ic_mutations.append(curr_mutation)
                    curr_mutation = None

        elif curr_level == "srtm":
            decoded = tensor_decode(torch.tensor(token, dtype=torch.int8))
            if decoded.isdigit() and (in_idx > 0 or int(decoded) < GRID_SIZE-1):
                # print("GOT SRT DIM")
                # got srt dimension
                curr_mutation[in_idx] = int(decoded)
                in_idx += 1
                valid = True
                if in_idx == 3:
                    in_idx = 0
                    curr_level = "mutb"
                    # print(f"SRTM: {curr_mutation}")
                    srt_mutations.append(curr_mutation)
                    curr_mutation = None
        
        elif curr_level == "perf":
            if token == int(str_to_tensor("[EP]")[0]):
                # print("ENDED PERF")
                curr_level = "mutb"
            valid = True
        assert valid, "NOT VALID DATASET SEQUENCE"
        curr_idx += 1
    
    performance_logs.append({"model_perf_improvs": model_perf_improvs, "avg_model_improv": float(model_perf_sum)/model_mut_batch_num if model_mut_batch_num != 0 else 0,
                             "data_perf_improvs": data_perf_improvs, "avg_data_improv": float(data_perf_sum)/data_mut_batch_num if data_mut_batch_num != 0 else 0})
    
print(f'PERFORMANCE COMPARISON: {performance_logs}')
with open(os.path.join(model_dir,'model_comparison_ground_truth.json'), 'w', encoding='utf-8') as f:
    json.dump(performance_logs, f, ensure_ascii=False, indent=4)

  decoded = tensor_decode(torch.tensor(token, dtype=torch.int8))
  decoded = tensor_decode(torch.tensor(token, dtype=torch.int8))


PERFORMANCE COMPARISON: [{'model_perf_improvs': [], 'avg_model_improv': 0.0, 'data_perf_improvs': [-0.009709132184854141, -0.027187078029807132, -0.01698703576466265, -0.01977577601934044, -0.0057668332343991935], 'avg_data_improv': -0.0007942585523306356}, {'model_perf_improvs': [-0.058499224260887495, -0.03221583141691298], 'avg_model_improv': -0.030238351892600157, 'data_perf_improvs': [-0.030642111941680206, -0.013590172932373967, -0.010283084896371975, -0.015728665641890238], 'avg_data_improv': -0.0007024403541231639}, {'model_perf_improvs': [], 'avg_model_improv': 0, 'data_perf_improvs': [-0.09795350108646605, -0.003244516229476524, -0.008993270224394223, -0.03082432757234166, -0.004169022046354165, -0.008820796777938789, -0.004346576711884431, -0.005795740649326753, -0.0039641393393736735], 'avg_data_improv': -0.0016811189063755628}, {'model_perf_improvs': [-0.055430474076423764, -0.05461823118713216, -0.062418944993305026], 'avg_model_improv': -0.04311691256421524, 'data_perf_i