<a href="https://colab.research.google.com/github/JohnKim8121/CoTLengthGeneralizationExperiment1/blob/main/Experiment1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
# Define the vocabulary and encoding/decoding
import string

# Base alphabet tokens
alphabet = list(string.ascii_uppercase)  # ['A','B',...,'Z']
# Special tokens
T1_TOKEN = "[F1]"   # denotes ROT13
T2_TOKEN = "[F2]"   # denotes POS1
THINK_TOKEN = "<think>"
ANSWER_TOKEN = "<answer>"

# Construct vocabulary list
vocab = alphabet + [T1_TOKEN, T2_TOKEN, THINK_TOKEN, ANSWER_TOKEN]
vocab_size = len(vocab)
token_to_id = {tok: i for i, tok in enumerate(vocab)}
id_to_token = {i: tok for tok, i in token_to_id.items()}

def encode_sequence(seq_tokens):
    """Encode a sequence of token strings (letters or special markers) to list of ids."""
    return [token_to_id[token] for token in seq_tokens]

def decode_sequence(id_list):
    """Decode a list of token ids back to token strings."""
    return [id_to_token[i] for i in id_list]

# Transformation functions
def apply_rot(sequence, n=13):
    """Apply ROT-n to a sequence of letters (list of chars)."""
    result = []
    for ch in sequence:
        if ch not in token_to_id or ch not in alphabet:
            raise ValueError(f"Unexpected token in sequence: {ch}")
        # shift letter by n
        new_idx = (ord(ch) - ord('A') + n) % 26
        result.append(chr(ord('A') + new_idx))
    return result

def apply_pos(sequence, n=1):
    """Apply cyclic position shift (left rotate by n) to a sequence of letters."""
    l = len(sequence)
    # left rotation by n: element at index i moves to index i-n (mod l) in the result
    return [sequence[(i + n) % l] for i in range(l)]


Experiment 1 is training len 3,4,5 to test len 1 and 6.

In [20]:
import random

# Function to generate a random sequence of given length
def random_sequence(length):
    return [random.choice(alphabet) for _ in range(length)]

# Generate training examples
train_examples = []
train_lengths = [3, 4, 5]  # in-distribution sequence lengths
for length in train_lengths:
    # For each length, generate a number of examples (you can adjust count for real training)
    for _ in range(1000):  # e.g., 1000 samples per length for illustration
        seq = random_sequence(length)
        # Randomly decide one-step or two-step transformation
        if random.random() < 0.5:
            # Single-step: choose F1 or F2 randomly
            if random.random() < 0.5:
                # ROT13 single-step
                prompt_tokens = seq + [T1_TOKEN, ANSWER_TOKEN]
                result_seq = apply_rot(seq, n=13)
            else:
                # POS1 single-step
                prompt_tokens = seq + [T2_TOKEN, ANSWER_TOKEN]
                result_seq = apply_pos(seq, n=1)
            output_tokens = result_seq  # final result only
        else:
            # Two-step: randomly choose combination of two transforms (allow repeats)
            # First transformation:
            first_is_rot = random.random() < 0.5
            if first_is_rot:
                interm_seq = apply_rot(seq, n=13)
                first_token = T1_TOKEN
            else:
                interm_seq = apply_pos(seq, n=1)
                first_token = T2_TOKEN
            # Second transformation:
            second_is_rot = random.random() < 0.5
            if second_is_rot:
                final_seq = apply_rot(interm_seq, n=13)
                second_token = T1_TOKEN
            else:
                final_seq = apply_pos(interm_seq, n=1)
                second_token = T2_TOKEN
            # Prompt includes both operations then <think>
            prompt_tokens = seq + [first_token, second_token, THINK_TOKEN]
            # Output includes intermediate result, second op token, <answer>, then final result
            output_tokens = interm_seq + [second_token, ANSWER_TOKEN] + final_seq
        # Encode to token ids
        input_ids = encode_sequence(prompt_tokens)
        output_ids = encode_sequence(output_tokens)
        train_examples.append((input_ids, output_ids))

# Generate evaluation examples for length 1 and 6 (unseen lengths)
test_examples_len1 = []
test_examples_len6 = []
for _ in range(200):  # generate some test examples
    seq1 = random_sequence(1)
    seq6 = random_sequence(6)
    # We'll test on single and double ops for these lengths as well
    # Single op for length1
    res1 = apply_rot(seq1, 13)
    prompt1 = seq1 + [T1_TOKEN, ANSWER_TOKEN]
    out1 = res1
    test_examples_len1.append((encode_sequence(prompt1), encode_sequence(out1)))
    # Two ops for length6
    interm6 = apply_rot(seq6, 13)
    final6 = apply_pos(interm6, 1)
    prompt6 = seq6 + [T1_TOKEN, T2_TOKEN, THINK_TOKEN]  # e.g., first ROT13 then POS1
    out6 = interm6 + [T2_TOKEN, ANSWER_TOKEN] + final6
    test_examples_len6.append((encode_sequence(prompt6), encode_sequence(out6)))


In [21]:
print("Example training sample (decoded):")
ex_in, ex_out = random.choice(train_examples)
print("Prompt:", " ".join(decode_sequence(ex_in)))
print("Target:", " ".join(decode_sequence(ex_out)))


Example training sample (decoded):
Prompt: K O J T [F2] <answer>
Target: O J T K


In [40]:
i = 99
print((train_examples[i][0]), (train_examples[i][1]))
print(decode_sequence(train_examples[i][0]), decode_sequence(train_examples[i][1]))

[4, 11, 25, 26, 27, 28] [17, 24, 12, 27, 29, 24, 12, 17]
['E', 'L', 'Z', '[F1]', '[F2]', '<think>'] ['R', 'Y', 'M', '[F2]', '<answer>', 'Y', 'M', 'R']


In [22]:
!pip install flax optax




In [25]:
from flax import linen as nn
import jax
import jax.numpy as jnp
from functools import partial

# Model hyperparameters
d_model = 32        # embedding/hidden size
num_heads = 4       # number of attention heads
num_layers = 4      # number of transformer decoder layers
d_ff = 4 * d_model  # feed-forward hidden dim (e.g. 128)
dropout_rate = 0.0  # (for simplicity, we disable dropout in this example)

class TransformerDecoderBlock(nn.Module):
    """Single transformer decoder block with causal self-attention."""
    @nn.compact
    def __call__(self, x, train=True, attn_mask=None):
        # Layer norm
        norm1 = nn.LayerNorm()(x)
        # Self-attention (causal)
        attn = nn.SelfAttention(num_heads=num_heads,
                                qkv_features=d_model,
                                use_bias=True,
                                broadcast_dropout=False,
                                deterministic=not train)(norm1, mask=attn_mask)
        x = x + attn  # residual connection
        # Second layer norm and feed-forward MLP
        norm2 = nn.LayerNorm()(x)
        ff = nn.Dense(d_ff)(norm2)
        ff = nn.gelu(ff)
        ff = nn.Dense(d_model)(ff)
        x = x + ff  # residual connection
        return x

class GPTDecoder(nn.Module):
    """Small GPT-like decoder model."""

    max_len: int  # new argument
    @nn.compact
    def __call__(self, input_ids, train=True):
        # input_ids: [batch, seq_length]
        batch, seq_len = input_ids.shape
        # Token embeddings
        embed_init = nn.initializers.normal(stddev=0.02)
        token_embed = nn.Embed(num_embeddings=vocab_size, features=d_model,
                               embedding_init=embed_init)
        x = token_embed(input_ids)  # shape [batch, seq_len, d_model]
        # Positional embeddings (learned)
        pos_embed = self.param("pos_embedding", nn.initializers.normal(stddev=0.01),
                               (1, self.max_len, d_model))
        x = x + pos_embed[:, :seq_len, :]
        # Apply multiple Transformer decoder blocks
        # Causal mask: allow attention only to current and past positions
        # mask shape: [batch, 1, seq_len, seq_len]
        if attn_mask := True:  # (will define mask below outside Module for JIT)
            pass
        attn_mask = None  # placeholder, see below on how we pass mask
        for _ in range(num_layers):
            x = TransformerDecoderBlock()(x, train=train, attn_mask=attn_mask)
        # Final layer norm
        x = nn.LayerNorm()(x)
        # Output projection
        logits = nn.Dense(vocab_size, use_bias=False)(x)
        return logits

# Create causal mask function
def causal_mask(seq_len):
    # Mask shape [1, 1, seq_len, seq_len] with True allowed positions
    # We want to mask out future => allow indices i >= j
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_), k=0)
    # Add batch and head dimensions: shape (batch, heads, seq, seq)
    return mask.reshape(1, 1, seq_len, seq_len)

In [52]:
concat = train_examples[1][0]+train_examples[1][1]
concat

[3, 9, 15, 26, 29, 16, 22, 2]

In [53]:
concat[:-1]

[3, 9, 15, 26, 29, 16, 22]

In [57]:
concat[1:]
for i in range(10):
  print(train_inputs[i],train_labels[i])

[ 4  1  7 26 27 28 17 14 20 27 29 14 20  0  0  0  0  0  0] [ 1  7 26 27 28 17 14 20 27 29 14 20 17  0  0  0  0  0  0]
[ 3  9 15 26 29 16 22  0  0  0  0  0  0  0  0  0  0  0  0] [ 9 15 26 29 16 22  2  0  0  0  0  0  0  0  0  0  0  0  0]
[11 14 16 27 29 14 16  0  0  0  0  0  0  0  0  0  0  0  0] [14 16 27 29 14 16 11  0  0  0  0  0  0  0  0  0  0  0  0]
[20  7 17 26 29  7 20  0  0  0  0  0  0  0  0  0  0  0  0] [ 7 17 26 29  7 20  4  0  0  0  0  0  0  0  0  0  0  0  0]
[ 8 23 12 26 27 28 21 10 25 27 29 10 25  0  0  0  0  0  0] [23 12 26 27 28 21 10 25 27 29 10 25 21  0  0  0  0  0  0]
[ 9  6  8 26 29 22 19  0  0  0  0  0  0  0  0  0  0  0  0] [ 6  8 26 29 22 19 21  0  0  0  0  0  0  0  0  0  0  0  0]
[21 10  9 26 27 28  8 23 22 27 29 23 22  0  0  0  0  0  0] [10  9 26 27 28  8 23 22 27 29 23 22  8  0  0  0  0  0  0]
[ 9 22  0 26 27 28 22  9 13 27 29  9 13  0  0  0  0  0  0] [22  0 26 27 28 22  9 13 27 29  9 13 22  0  0  0  0  0  0]
[19 21 20 26 26 28  6  8  7 26 29 19 21  0  0  0  0  0  

####Check this part if it is correctly done!!

In [26]:
import optax

# Prepare training data as concatenated sequences and loss masks
train_inputs = []
train_labels = []
loss_masks = []  # 1 for tokens where loss should be applied, 0 for prompt tokens

for input_ids, output_ids in train_examples:
    concat = input_ids + output_ids  # concatenate prompt + target
    train_inputs.append(concat[:-1])   # model input: all except last token (next-token pred)
    train_labels.append(concat[1:])    # expected outputs: all except first token
    # Mask: 0 for prompt part labels, 1 for output part labels
    prompt_len = len(input_ids)
    seq_len = len(concat) - 1
    mask = [0]* (prompt_len - 1) + [1]* (len(concat) - prompt_len)
    # (We put 0 for prompt tokens *except* the last prompt token because once we reach <think>/<answer>,
    # the next token to predict is the first token of reasoning, which we do want to learn.
    # The above scheme masks all prompt indices except we allow starting from the first output token.)
    loss_masks.append(mask)

# Pad sequences to same length for batching
max_len = max(len(seq) for seq in train_inputs)
def pad(seq, length, pad_id=0):
    return seq + [pad_id] * (length - len(seq))
train_inputs = [pad(seq, max_len) for seq in train_inputs]
train_labels = [pad(seq, max_len) for seq in train_labels]
loss_masks = [pad(mask, max_len, pad_id=0) for mask in loss_masks]

train_inputs = jnp.array(train_inputs)
train_labels = jnp.array(train_labels)
loss_masks = jnp.array(loss_masks)

# Initialize model and optimizer
model = GPTDecoder(max_len=30)
params = model.init(jax.random.PRNGKey(0), train_inputs[:1], train=True)["params"]
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# Define loss function
def compute_loss(params, batch_inputs, batch_labels, batch_mask):
    logits = model.apply({"params": params}, batch_inputs, train=True)
    # Compute cross-entropy
    num_classes = logits.shape[-1]
    # Flatten for convenience:
    logits_flat = logits.reshape(-1, num_classes)
    labels_flat = batch_labels.reshape(-1,)
    mask_flat = batch_mask.reshape(-1,)
    # Optax cross_entropy_with_integer_labels expects unnormalized logits
    loss = optax.softmax_cross_entropy_with_integer_labels(logits_flat, labels_flat)
    # Apply mask to zero-out loss for prompt tokens
    loss = loss * mask_flat
    # Average loss
    return loss.sum() / mask_flat.sum()

# Training loop (simple form)
num_epochs = 5
batch_size = 64
num_samples = train_inputs.shape[0]

for epoch in range(1, num_epochs+1):
    # Shuffle training data
    permutation = jax.random.permutation(jax.random.PRNGKey(epoch), num_samples)
    shuffled_inputs = train_inputs[permutation]
    shuffled_labels = train_labels[permutation]
    shuffled_masks = loss_masks[permutation]
    # Batch iteration
    for i in range(0, num_samples, batch_size):
        batch_in = shuffled_inputs[i:i+batch_size]
        batch_lbl = shuffled_labels[i:i+batch_size]
        batch_m = shuffled_masks[i:i+batch_size]
        # Compute gradients
        loss_val, grads = jax.value_and_grad(compute_loss)(params, batch_in, batch_lbl, batch_m)
        # Update params
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
    print(f"Epoch {epoch} done.")


Epoch 1 done.
Epoch 2 done.
Epoch 3 done.
Epoch 4 done.
Epoch 5 done.


In [27]:
def greedy_decode(prompt_ids, max_len):
    """Generate output from model given prompt (without including the prompt termination token)."""
    # Start with prompt as context
    generated = list(prompt_ids)
    # We assume prompt already contains <think> or <answer> token at end to signal start of gen.
    for _ in range(max_len):
        inputs = jnp.array([generated])  # shape (1, cur_len)
        # Create causal mask for the current sequence length
        current_seq_len = inputs.shape[1]
        # The causal mask is created inside the model now, no need to pass it here.
        # causal_mask = create_causal_mask(current_seq_len) if current_seq_len > 1 else None

        logits = model.apply({"params": params}, inputs, train=False)  # get logits
        next_id = int(jnp.argmax(logits[0, -1]))  # pick highest probability token
        generated.append(next_id)
        # Stop if <answer> was just produced and we have some output after it (meaning chain ended)
        if next_id == token_to_id[ANSWER_TOKEN] and len(generated) > len(prompt_ids) + 1:
            break
    return generated[len(prompt_ids):]  # return the generated portion (excluding prompt)

# Evaluate exact match on test examples
for test_set, name in [(test_examples_len1, "Length-1"), (test_examples_len6, "Length-6")]:
    correct = 0
    total = len(test_set)
    for inp_ids, true_out_ids in test_set:
        # decode until we produce as many tokens as true_out (or a bit more)
        gen_ids = greedy_decode(inp_ids, max_len=len(true_out_ids)+5)
        # Compare with true output
        # Note: need to stop at the same length as true output
        gen_ids = gen_ids[:len(true_out_ids)]
        if gen_ids == true_out_ids:
            correct += 1
    print(f"{name} exact match accuracy: {100 * correct/total:.2f}%")

Length-1 exact match accuracy: 8.00%
Length-6 exact match accuracy: 0.00%


In [42]:
decode_sequence(test_set[1][0])

['R', 'X', 'B', 'A', 'U', 'Q', '[F1]', '[F2]', '<think>']

In [45]:
print(decode_sequence(test_set[1][1]))
print(test_set[1])

['E', 'K', 'O', 'N', 'H', 'D', '[F2]', '<answer>', 'K', 'O', 'N', 'H', 'D', 'E']
([17, 23, 1, 0, 20, 16, 26, 27, 28], [4, 10, 14, 13, 7, 3, 27, 29, 10, 14, 13, 7, 3, 4])


In [47]:
greedy_decode(test_set[1][0],max_len=100)

[16, 16, 17, 16, 26, 29]

In [55]:
b=99
print(decode_sequence(greedy_decode(test_set[b][0],max_len=100)))
print(decode_sequence(test_set[b][1]))

['X', 'H', 'L', 'A', '[F1]', '<answer>']
['T', 'C', 'V', 'B', 'U', 'K', '[F2]', '<answer>', 'C', 'V', 'B', 'U', 'K', 'T']


Next is testing len 3, 4 and 5

In [29]:
test_examples_len3 = []
test_examples_len4 = []
test_examples_len5 = []
for _ in range(200):  # generate some test examples
    seq3 = random_sequence(3)
    seq4 = random_sequence(4)
    seq5 = random_sequence(5)

    # Two ops for each length
    interm3 = apply_rot(seq3, 13)
    final3 = apply_pos(interm3, 1)
    prompt3 = seq3 + [T1_TOKEN, T2_TOKEN, THINK_TOKEN]  # e.g., first ROT13 then POS1
    out3 = interm3 + [T2_TOKEN, ANSWER_TOKEN] + final3
    test_examples_len3.append((encode_sequence(prompt3), encode_sequence(out3)))

    interm4 = apply_rot(seq4, 13)
    final4 = apply_pos(interm4, 1)
    prompt4 = seq4 + [T1_TOKEN, T2_TOKEN, THINK_TOKEN]  # e.g., first ROT13 then POS1
    out4 = interm4 + [T2_TOKEN, ANSWER_TOKEN] + final4
    test_examples_len4.append((encode_sequence(prompt4), encode_sequence(out4)))

    interm5 = apply_rot(seq5, 13)
    final5 = apply_pos(interm5, 1)
    prompt5 = seq5 + [T1_TOKEN, T2_TOKEN, THINK_TOKEN]  # e.g., first ROT13 then POS1
    out5 = interm5 + [T2_TOKEN, ANSWER_TOKEN] + final5
    test_examples_len5.append((encode_sequence(prompt5), encode_sequence(out5)))

In [31]:
# Evaluate exact match on test examples
for test_set, name in [(test_examples_len3, "Length-3"), (test_examples_len4, "Length-4"), (test_examples_len6, "Length-5")]:
    correct = 0
    total = len(test_set)
    for inp_ids, true_out_ids in test_set:
        # decode until we produce as many tokens as true_out (or a bit more)
        gen_ids = greedy_decode(inp_ids, max_len=len(true_out_ids)+5)
        # Compare with true output
        # Note: need to stop at the same length as true output
        gen_ids = gen_ids[:len(true_out_ids)]
        if gen_ids == true_out_ids:
            correct += 1
    print(f"{name} exact match accuracy: {100 * correct/total:.2f}%")

Length-3 exact match accuracy: 0.00%
Length-4 exact match accuracy: 0.00%
Length-5 exact match accuracy: 0.00%
