In [32]:
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

import os
import itertools

In [None]:
LIST_LEN = 2 # [d1, d2]
SEQ_LEN = LIST_LEN * 2 + 1 # [d1, d2, SEP, o1, o2]
NO_DUPES = True # whether to use the no-dupes test dataset (i.e. d1 != d2)

N_DIGITS = 100
DIGITS = list(range(N_DIGITS)) # digit tokens 0 .. N_DIGITS-1
PAD = N_DIGITS # special padding token (first id after the digits)
SEP = N_DIGITS + 1 # special separator token between the input list and the output positions (+1 so its id differs from PAD)
VOCAB = len(DIGITS) + 2  # digits + PAD + SEP

# For backward compatibility with older versions
USE_PAD = True # whether to use the PAD token in the input sequences (or just SEP)
if not USE_PAD:
    # NOTE(review): this shrinks VOCAB to N_DIGITS + 1, but SEP is still N_DIGITS + 1,
    # i.e. one past the last valid token id, and get_dataset() writes PAD into the
    # inputs regardless of USE_PAD. Confirm how older versions numbered SEP before
    # running with USE_PAD = False.
    VOCAB -= 1  # -1 for the PAD token

D_MODEL = 16
N_HEAD = 1 # 1
N_LAYER = 3 # 2
USE_LN = False # use layer norm in model
USE_BIAS = False # use bias in model
FREEZE_WV = True # no value matrix in attn
FREEZE_WO = True # no output matrix in attn (i.e. attn head can only copy inputs to outputs)
WEIGHT_DECAY = 0.01 # default 0.01

TRAIN_SPLIT = 0.8 # 80% train, 20% test

DEV = (
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
device = DEV

# Seed every RNG library imported above so Restart & Run All reproduces the same
# dataset shuffle (get_dataset draws from the global torch RNG) and any later randomness.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' suppresses the returned Generator's repr in the cell output


<torch._C.Generator at 0x7e84d0214750>

In [34]:
def get_dataset(generator=None):
    """Build the shuffled train/validation splits for the list-copy task.

    Every LIST_LEN-tuple of tokens drawn from DIGITS becomes one example:

        target: [d1, ..., dk, SEP, d1, ..., dk]
        input:  [d1, ..., dk, SEP, PAD, ..., PAD]

    If NO_DUPES is set, only tuples whose elements are pairwise distinct are
    kept (for LIST_LEN == 2 this is exactly d1 != d2).

    Inputs and targets are shuffled with the same permutation. The first
    int(TRAIN_SPLIT * n) examples go to training and the rest to validation.

    Args:
        generator: optional ``torch.Generator`` used for the shuffle. When None
            (the default), the global torch RNG is used, so the split then
            depends on ``torch.manual_seed`` having been called beforehand.

    Returns:
        (train_ds, val_ds): ``TensorDataset``s of (input, target) pairs. Both
        tensors have shape (n, SEQ_LEN) and dtype int64.
    """
    # All LIST_LEN-tuples of digits, shape (len(DIGITS) ** LIST_LEN, LIST_LEN).
    all_data = torch.tensor(
        list(itertools.product(DIGITS, repeat=LIST_LEN)), dtype=torch.int64
    )
    if NO_DUPES:
        # Keep rows whose elements are all distinct. After sorting each row, any
        # duplicate shows up as two equal neighbours. Unlike comparing only
        # columns 0 and 1, this is correct for any LIST_LEN (including 1).
        sorted_rows, _ = all_data.sort(dim=1)
        all_distinct = (sorted_rows[:, 1:] != sorted_rows[:, :-1]).all(dim=1)
        all_data = all_data[all_distinct]
    n_data = len(all_data)

    # Targets: [d1..dk, SEP, d1..dk]. Assumes SEQ_LEN == 2 * LIST_LEN + 1.
    all_targets = torch.full((n_data, SEQ_LEN), SEP, dtype=torch.int64)
    all_targets[:, :LIST_LEN] = all_data
    all_targets[:, LIST_LEN + 1:] = all_data

    # Inputs: same prefix, with every position after SEP replaced by PAD.
    all_inputs = all_targets.clone()
    all_inputs[:, LIST_LEN + 1:] = PAD

    # Shuffle inputs and targets together so pairs stay aligned.
    perm = torch.randperm(n_data, generator=generator)
    all_inputs = all_inputs[perm]
    all_targets = all_targets[perm]

    n_train = int(TRAIN_SPLIT * n_data)
    train_ds = TensorDataset(all_inputs[:n_train], all_targets[:n_train])
    val_ds = TensorDataset(all_inputs[n_train:], all_targets[n_train:])
    return train_ds, val_ds

In [35]:
# The file name encodes the generation settings so different configurations get different files.
DATASET_NAME = f"listlen{LIST_LEN}_digits{N_DIGITS}_{'nodupes' if NO_DUPES else 'dupes'}"
DATASET_PATH = f"data/{DATASET_NAME}.pt"

# Refuse to overwrite an existing dataset file so a saved split is never silently replaced.
if os.path.exists(DATASET_PATH):
    raise FileExistsError(f"{DATASET_PATH} already exists. Please delete it or change the dataset name.")

train_ds, val_ds = get_dataset()

# torch.save does not create missing parent directories; make sure data/ exists first.
os.makedirs(os.path.dirname(DATASET_PATH) or ".", exist_ok=True)
# NOTE(review): this pickles the TensorDataset objects themselves. PyTorch >= 2.6 defaults
# torch.load to weights_only=True, which rejects such objects, so loaders will likely need
# torch.load(DATASET_PATH, weights_only=False). Only do that for trusted files.
torch.save({
    'train': train_ds,
    'val': val_ds
}, DATASET_PATH)

print(f"Dataset saved to {DATASET_PATH}")
print("Train dataset size:", len(train_ds))
print("Validation dataset size:", len(val_ds))
print("Input example:", train_ds[0][0])
print("Target example:", train_ds[0][1])

Dataset saved to data/listlen2_digits100_dupes.pt
Train dataset size: 8000
Validation dataset size: 2000
Input example: tensor([ 60,  44, 101, 100, 100])
Target example: tensor([ 60,  44, 101,  60,  44])
