<a href="https://colab.research.google.com/github/JohnKim8121/CoTLengthGeneralizationExperiment1/blob/main/Experiment1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Define the vocabulary and encoding/decoding
import string
import torch # Import torch here for tensor conversion

# The 26 uppercase letters are the ordinary (data) tokens.
alphabet = [*string.ascii_uppercase]  # ['A','B',...,'Z']

# Marker tokens that describe the task rather than carry data.
T1_TOKEN = "[F1]"   # operation marker: ROT13
T2_TOKEN = "[F2]"   # operation marker: POS1 (cyclic left shift by one)
THINK_TOKEN = "<think>"
ANSWER_TOKEN = "<answer>"
endoftext_token = "<|endoftext|>"

# Full vocabulary: letters first (ids 0-25), then the marker tokens.
vocab = [
    *alphabet,
    T1_TOKEN,
    T2_TOKEN,
    THINK_TOKEN,
    ANSWER_TOKEN,
    endoftext_token,
]
vocab_size = len(vocab)

# Lookup tables in both directions; a token's id is its index in `vocab`.
token_to_id = dict(zip(vocab, range(vocab_size)))
id_to_token = dict(enumerate(vocab))

def encode_sequence(seq_tokens):
    """Map each token string (letter or special marker) to its vocabulary id.

    Returns a list of ids in the same order as ``seq_tokens``. A token that
    is not in the vocabulary raises ``KeyError``.
    """
    ids = []
    for token in seq_tokens:
        ids.append(token_to_id[token])
    return ids

def decode_sequence(id_list):
    """Map each token id back to its token string.

    Every element is passed through ``int()`` first, so the ids may be
    plain Python ints or 0-d tensor elements. An id that is not in the
    vocabulary raises ``KeyError``.
    """
    tokens = []
    for raw_id in id_list:
        # Tensor elements are not usable as dict keys directly; cast first.
        tokens.append(id_to_token[int(raw_id)])
    return tokens

# Transformation functions
def apply_rot(sequence, n=13):
    """Apply ROT-n to a sequence of uppercase letters.

    Args:
        sequence: Iterable of single-character strings, each in 'A'..'Z'.
        n: Number of alphabet positions to shift by. The shift wraps
            modulo 26, so negative values shift backwards.

    Returns:
        A new list in which every letter has been shifted by ``n``.

    Raises:
        ValueError: If an element is not a single uppercase ASCII letter
            (for example a special marker token such as "[F1]").
    """
    base = ord("A")
    result = []
    for ch in sequence:
        # Validate against the letter range directly instead of scanning
        # the vocabulary tables: only 'A'..'Z' can be rotated.
        if not (isinstance(ch, str) and len(ch) == 1 and "A" <= ch <= "Z"):
            raise ValueError(f"Unexpected token in sequence: {ch}")
        # Shift the letter's 0-25 index by n and wrap around the alphabet.
        new_idx = (ord(ch) - base + n) % 26
        result.append(chr(base + new_idx))
    return result

def apply_pos(sequence, n=1):
    """Apply a cyclic position shift (left rotation by ``n``) to a sequence.

    The element at index ``i`` of the input ends up at index ``i - n``
    (modulo the length) of the result, e.g. ``[Q, N, G] -> [N, G, Q]`` for
    ``n=1``.

    Args:
        sequence: Sequence of tokens to rotate.
        n: Number of positions to rotate left. Values larger than the
            length wrap around; negative values rotate right.

    Returns:
        A new list with the rotated elements. An empty input yields an
        empty list.
    """
    items = list(sequence)
    if not items:
        # Nothing to rotate; also avoids taking a modulo by zero below.
        return []
    k = n % len(items)
    return items[k:] + items[:k]

Experiment 1 trains on sequence lengths 3, 4, and 5 and tests generalization to the unseen lengths 1 and 6.

In [3]:
import random

# Function to generate a random sequence of given length
def random_sequence(length):
    """Return a list of ``length`` letters, each drawn uniformly from ``alphabet``."""
    letters = []
    for _ in range(length):
        letters.append(random.choice(alphabet))
    return letters

# Build the training pool: one list of (prompt_ids, target_ids) pairs per
# in-distribution sequence length.
train_examples3 = []
train_examples4 = []
train_examples5 = []
train_lengths = [3, 4, 5]  # in-distribution sequence lengths
examples_by_length = {3: train_examples3, 4: train_examples4, 5: train_examples5}

for length in train_lengths:
    # 10000 samples are drawn per length (adjust the count as needed).
    for _ in range(10000):
        seq = random_sequence(length)

        # Coin flip: single-step task (heads) or two-step task (tails).
        is_single_step = random.random() < 0.5
        if is_single_step:
            # Second coin flip picks the operation: ROT13 or POS1.
            use_rot = random.random() < 0.5
            if use_rot:
                op_token = T1_TOKEN
                result_seq = apply_rot(seq, n=13)
            else:
                op_token = T2_TOKEN
                result_seq = apply_pos(seq, n=1)
            # Single-step prompts go straight to <answer>; the target is
            # just the transformed sequence.
            prompt_tokens = seq + [op_token, ANSWER_TOKEN]
            output_tokens = result_seq
        else:
            # Two-step task: each operation is picked independently, so the
            # same operation may be applied twice.
            first_is_rot = random.random() < 0.5
            if first_is_rot:
                first_token = T1_TOKEN
                interm_seq = apply_rot(seq, n=13)
            else:
                first_token = T2_TOKEN
                interm_seq = apply_pos(seq, n=1)

            second_is_rot = random.random() < 0.5
            if second_is_rot:
                second_token = T1_TOKEN
                final_seq = apply_rot(interm_seq, n=13)
            else:
                second_token = T2_TOKEN
                final_seq = apply_pos(interm_seq, n=1)

            # Prompt: input, both operation markers, then <think>.
            prompt_tokens = seq + [first_token, second_token, THINK_TOKEN]
            # Target: intermediate result, the remaining operation marker,
            # <answer>, then the final result.
            output_tokens = interm_seq + [second_token, ANSWER_TOKEN] + final_seq

        # Store the example as a pair of id lists in the bucket for its length.
        example = (encode_sequence(prompt_tokens), encode_sequence(output_tokens))
        examples_by_length[length].append(example)

# Out-of-distribution evaluation sets: length 1 (shorter than anything seen
# in training) and length 6 (longer than anything seen in training).
test_examples_len1 = []
test_examples_len6 = []
for _ in range(200):
    seq1 = random_sequence(1)
    seq6 = random_sequence(6)

    # Length 1: single-step ROT13, answered directly after <answer>.
    prompt1 = [*seq1, T1_TOKEN, ANSWER_TOKEN]
    out1 = apply_rot(seq1, 13)
    test_examples_len1.append((encode_sequence(prompt1), encode_sequence(out1)))

    # Length 6: two-step ROT13 then POS1, with the intermediate result
    # written out after <think>.
    interm6 = apply_rot(seq6, 13)
    final6 = apply_pos(interm6, 1)
    prompt6 = [*seq6, T1_TOKEN, T2_TOKEN, THINK_TOKEN]
    out6 = [*interm6, T2_TOKEN, ANSWER_TOKEN, *final6]
    test_examples_len6.append((encode_sequence(prompt6), encode_sequence(out6)))


In [4]:
# Show one randomly chosen length-3 training pair as readable tokens.
print("Example training sample (decoded):")
ex_in, ex_out = random.choice(train_examples3)
for label, ids in (("Prompt:", ex_in), ("Target:", ex_out)):
    print(label, " ".join(decode_sequence(ids)))

Example training sample (decoded):
Prompt: Q N G [F2] [F2] <think>
Target: N G Q [F2] <answer> G Q N


In [5]:
# Inspect one length-4 training example, first as raw token ids and then
# decoded back to token strings (prompt first, target second).
i = 109
print((train_examples4[i][0]), (train_examples4[i][1]))
print(decode_sequence(train_examples4[i][0]), decode_sequence(train_examples4[i][1]))

[16, 9, 15, 10, 27, 27, 28] [9, 15, 10, 16, 27, 29, 15, 10, 16, 9]
['Q', 'J', 'P', 'K', '[F2]', '[F2]', '<think>'] ['J', 'P', 'K', 'Q', '[F2]', '<answer>', 'P', 'K', 'Q', 'J']


GPT2----------

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the PyTorch model
class GPTDecoderTorch(nn.Module):
    """Small GPT-style, decoder-only language model.

    The model is built from ``nn.TransformerDecoderLayer`` blocks. Those
    blocks contain both self-attention and cross-attention over a
    ``memory`` input; since there is no encoder, the current hidden states
    are passed as ``memory``. To keep the model autoregressive, the same
    causal mask is applied to both attention paths.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        d_model: Embedding / hidden size.
        num_heads: Number of attention heads per layer.
        num_layers: Number of stacked decoder layers.
        d_ff: Hidden size of each layer's feed-forward sub-network.
        max_len: Number of learned positional embeddings, i.e. the longest
            sequence the model accepts.
    """

    def __init__(self, vocab_size, d_model=32, num_heads=4, num_layers=4, d_ff=128,
                 max_len=100):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        # Token embedding and learned positional embedding
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len, d_model))
        # Transformer decoder layers
        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(d_model, num_heads, dim_feedforward=d_ff, dropout=0.0)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.out_proj = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x, attn_mask=None):
        """Compute next-token logits.

        Args:
            x: Tensor of shape (batch, seq_len) holding token ids.
            attn_mask: Optional additive attention mask of shape
                (seq_len, seq_len) with ``-inf`` at blocked positions.
                When omitted, a causal mask is built so that position
                ``i`` can only attend to positions ``<= i``.

        Returns:
            Tensor of shape (batch, seq_len, vocab_size) with logits.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_len``.
        """
        _, seq_len = x.shape
        if seq_len > self.max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum supported "
                f"length {self.max_len}"
            )
        if attn_mask is None:
            # A language model must never look ahead: default to causal.
            attn_mask = torch.triu(
                torch.full((seq_len, seq_len), float("-inf"),
                           dtype=self.pos_emb.dtype, device=x.device),
                diagonal=1,
            )
        else:
            attn_mask = attn_mask.to(x.device)

        # Input embeddings
        tok_embeddings = self.token_emb(x)  # (batch, seq_len, d_model)
        tok_embeddings = tok_embeddings + self.pos_emb[:, :seq_len, :]
        # The layers are constructed with the default batch_first=False,
        # so they expect (seq_len, batch, d_model).
        hs = tok_embeddings.transpose(0, 1)
        for layer in self.layers:
            # tgt_mask only restricts self-attention. The cross-attention
            # over `memory` (here: the same hidden states) needs
            # memory_mask as well, otherwise every position can read
            # information from future tokens.
            hs = layer(hs, hs, tgt_mask=attn_mask, memory_mask=attn_mask)
        hs = self.norm(hs)
        logits = self.out_proj(hs)  # (seq_len, batch, vocab_size)
        return logits.transpose(0, 1)  # back to (batch, seq_len, vocab_size)


#### TODO: check that this part is done correctly!

In [7]:
# Split sizes are derived from the length-3 pool and reused for the other
# pools: 85% train, 10% test, and the remainder (5%) validation.
n_per_length = len(train_examples3)
train_portion = int(n_per_length * 0.85)  # 85% for training
test_portion = int(n_per_length * 0.1)    # 10% for testing
val_portion = n_per_length - train_portion - test_portion  # remaining 5%
test_end = train_portion + test_portion   # index where the test slice ends

# Each pool is cut in the same order: [train | test | validation].
train_data3, test_data3, val_data3 = (
    train_examples3[:train_portion],
    train_examples3[train_portion:test_end],
    train_examples3[test_end:],
)
train_data4, test_data4, val_data4 = (
    train_examples4[:train_portion],
    train_examples4[train_portion:test_end],
    train_examples4[test_end:],
)
train_data5, test_data5, val_data5 = (
    train_examples5[:train_portion],
    train_examples5[train_portion:test_end],
    train_examples5[test_end:],
)

# Concatenate the per-length slices into the final datasets.
train_data = train_data3 + train_data4 + train_data5
test_data = test_data3 + test_data4 + test_data5
val_data = val_data3 + val_data4 + val_data5

for title, subset in (
    ("Training set length:", train_data),
    ("Validation set length:", val_data),
    ("Test set length:", test_data),
):
    print(title, len(subset))

Training set length: 25500
Validation set length: 1500
Test set length: 3000


Model and optimizer.

In [8]:
# Instantiate the decoder with the feed-forward width set to 4 * d_model,
# and an Adam optimizer over all of its parameters.
model_torch = GPTDecoderTorch(vocab_size, d_model=32, num_heads=4, num_layers=4, d_ff=4*32)
optimizer = torch.optim.Adam(model_torch.parameters(), lr=1e-3)

In [9]:
# Hand-written toy batch of two sequences of token ids for a quick sanity
# check. Each `targets` row is the matching `inputs` row shifted left by
# one position, with one extra token id appended at the end.
inputs = torch.tensor([[23, 9, 25, 7, 27, 29],[0, 3, 24, 4, 26, 29]])
targets = torch.tensor([[9, 25, 7, 27, 29, 9],[3, 24, 4, 26, 29, 13]])

In [10]:
# Forward pass on the toy batch without tracking gradients.
# NOTE(review): no attn_mask is passed to the model here.
with torch.no_grad():
    logits = model_torch(inputs)
# Normalize the logits over the vocabulary dimension into probabilities.
probas = torch.softmax(logits, dim=-1)
print(probas.shape)

torch.Size([2, 6, 31])


In [11]:
# Most likely token id at every position; keepdim=True keeps a trailing
# dimension of size 1.
token_ids = torch.argmax(probas, dim=-1, keepdim=True)
print("Token IDs:\n", token_ids)

# Compare the expected tokens with the model's predictions for the first
# sequence of the batch.
print(f"Targets batch 1: {decode_sequence(targets[0])}")
print(f"Outputs batch 1: {decode_sequence(token_ids[0].flatten())}")

Token IDs:
 tensor([[[13],
         [28],
         [13],
         [18],
         [25],
         [18]],

        [[19],
         [ 8],
         [ 2],
         [19],
         [20],
         [18]]])
Targets batch 1: ['J', 'Z', 'H', '[F2]', '<answer>', 'J']
Outputs batch 1: ['N', '<think>', 'N', 'S', 'Z', 'S']


In [12]:
# Shapes before flattening: logits are (batch_size, num_tokens, vocab_size),
# targets are (batch_size, num_tokens).
print("Logits shape:", logits.shape)
print("Targets shape:", targets.shape)

# Cross-entropy expects one row of logits per target id, so the batch and
# token dimensions are merged into one.
logits_flat = logits.reshape(-1, logits.size(-1))
targets_flat = targets.reshape(-1)

print("Flattened logits:", logits_flat.shape)
print("Flattened targets:", targets_flat.shape)

# Mean cross-entropy over every position of the toy batch.
loss = F.cross_entropy(logits_flat, targets_flat)
print(loss)

Logits shape: torch.Size([2, 6, 31])
Targets shape: torch.Size([2, 6])
Flattened logits: torch.Size([12, 31])
Flattened targets: torch.Size([12])
tensor(3.6199)


The following cells train the decoder-only model.

In [13]:
#train_data

def collate(batch, pad_token_id=0, ignore_index=-100, eos_token_id=None):
    """Turn (prompt_ids, target_ids) pairs into padded next-token tensors.

    Each pair is concatenated into one sequence ``prompt + target``; the
    inputs are that sequence without its last token and the labels are the
    same sequence shifted left by one. All rows are right-padded to the
    longest row in the batch.

    Args:
        batch: Non-empty list of ``(prompt_ids, target_ids)`` pairs, each a
            list of token ids.
        pad_token_id: Id used to pad ``input_ids``. It must be a valid
            vocabulary id because padded inputs still go through the
            embedding layer.
        ignore_index: Value used to pad ``labels`` (matches the default
            ``ignore_index`` of ``torch.nn.functional.cross_entropy``).
        eos_token_id: If given, this id is appended to every target so the
            model also learns where the answer ends. ``None`` appends
            nothing.

    Returns:
        Dict with four tensors of shape (batch, max_len):
        ``input_ids`` (long), ``labels`` (long), ``loss_mask`` (long, 1 on
        positions whose label belongs to the target, 0 on prompt and
        padding) and ``padding_mask`` (bool, True on padded positions).

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate() received an empty batch")

    xs, ys, ms = [], [], []
    for prompt, target in batch:
        prompt = list(prompt)
        target = list(target)
        if eos_token_id is not None:
            target = target + [eos_token_id]
        seq = prompt + target
        # One flag per token of `seq`; dropping the first entry aligns the
        # flags with the shifted labels seq[1:], so the loss is only taken
        # on target tokens (this also stays consistent for empty prompts).
        mask = ([0] * len(prompt) + [1] * len(target))[1:]
        # An explicit dtype keeps zero-length rows from becoming float tensors.
        xs.append(torch.tensor(seq[:-1], dtype=torch.long))   # inputs
        ys.append(torch.tensor(seq[1:], dtype=torch.long))    # labels
        ms.append(torch.tensor(mask, dtype=torch.long))

    max_len = max(len(x) for x in xs)

    def pad_rows(rows, value):
        """Right-pad every 1-D tensor in `rows` to max_len and stack them."""
        return torch.stack([
            torch.cat([row, torch.full((max_len - len(row),), value, dtype=row.dtype)])
            for row in rows
        ])

    X = pad_rows(xs, pad_token_id)
    Y = pad_rows(ys, ignore_index)   # padded labels are ignored by the loss
    M = pad_rows(ms, 0)              # no loss on padding
    # True exactly on the positions that were added by padding.
    padding_mask = torch.stack([torch.arange(max_len) >= len(x) for x in xs])

    return {"input_ids": X, "labels": Y, "loss_mask": M, "padding_mask": padding_mask}

TODO: append the end-of-text token (`<|endoftext|>`) to the end of each target sequence.

In [17]:
# Collate a single example (wrapped in a list to simulate a batch of one)
# and print it. The result is bound to `single_batch` rather than `input`
# so the builtin input() is not shadowed for the rest of the session.
single_batch = collate([train_data[99]])
print(single_batch)

# Collate the first 99 training examples into padded tensors for the
# loss experiments below.
collated_data = collate(train_data[:99])
input_ids = collated_data['input_ids']
labels = collated_data['labels']

loss_mask = collated_data['loss_mask']

{'input_ids': tensor([[16, 20, 18, 26, 29,  3,  7]]), 'labels': tensor([[20, 18, 26, 29,  3,  7,  5]]), 'loss_mask': tensor([[0, 0, 0, 0, 1, 1, 1]]), 'padding_mask': tensor([[False, False, False, False, False, False, False]])}


In [19]:
# Masked next-token loss on the collated mini-dataset.
# NOTE(review): no attn_mask is passed to the model here.
logits = model_torch(input_ids)                 # (B, L, V)
logits2D = logits.reshape(-1, logits.size(-1))  # (B*L, V)
targets1D = labels.reshape(-1)                  # (B*L,)
mask1D    = loss_mask.reshape(-1).float()       # (B*L,)

# Per-position loss first (reduction="none"), then average only over the
# positions where loss_mask is 1, i.e. the target tokens.
loss_all = torch.nn.functional.cross_entropy(logits2D, targets1D, reduction="none")
loss = (loss_all * mask1D).sum() / mask1D.sum()

# Bare expression so the notebook displays the value.
loss

tensor(3.5932, grad_fn=<DivBackward0>)

In [23]:
def calc_loss(input_ids, labels, loss_mask, model=None, attn_mask=None):
    """Masked next-token cross-entropy, averaged over target positions.

    Args:
        input_ids: Long tensor (B, L) of input token ids.
        labels: Long tensor (B, L) of next-token labels; padded positions
            are expected to hold cross-entropy's default ignore index
            (-100).
        loss_mask: Tensor (B, L) with 1 where a position contributes to the
            loss and 0 elsewhere.
        model: Callable ``model(input_ids, attn_mask=...)`` returning
            logits of shape (B, L, V). Defaults to the notebook-level
            ``model_torch``.
        attn_mask: Optional additive (L, L) attention mask. When omitted, a
            causal mask is built so no position can attend to later tokens.

    Returns:
        Scalar tensor with the mean loss over the masked-in positions, or
        0 when ``loss_mask`` selects no position at all.
    """
    model = model_torch if model is None else model
    if attn_mask is None:
        seq_len = input_ids.size(1)
        # -inf strictly above the diagonal blocks attention to the future.
        attn_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=input_ids.device),
            diagonal=1,
        )

    logits = model(input_ids, attn_mask=attn_mask)  # (B, L, V)
    logits2D = logits.reshape(-1, logits.size(-1))  # (B*L, V)
    targets1D = labels.reshape(-1)                  # (B*L,)
    mask1D = loss_mask.reshape(-1).float()          # (B*L,)

    loss_all = torch.nn.functional.cross_entropy(logits2D, targets1D, reduction="none")
    # Clamp the denominator so an all-zero mask gives 0 instead of NaN.
    loss = (loss_all * mask1D).sum() / mask1D.sum().clamp(min=1.0)
    return loss

This is the training loop (fix).

In [20]:
# Utility: create causal attn mask for PyTorch (size [seq_len, seq_len])
def generate_causal_mask(seq_len):
    """Build an additive causal attention mask of shape (seq_len, seq_len).

    Entries strictly above the diagonal are -inf (attention blocked); the
    diagonal and everything below it are 0 (attention allowed).
    """
    # True exactly at the positions strictly above the main diagonal.
    blocked = torch.ones(seq_len, seq_len, dtype=torch.bool).triu(diagonal=1)
    return torch.zeros(seq_len, seq_len).masked_fill(blocked, float('-inf'))

Maybe add a text generation function too.

In [42]:
# Collate the whole training set once; every row is padded to the longest
# sequence in the dataset.
collated_data = collate(train_data)

input_ids = collated_data['input_ids']
labels = collated_data['labels']
loss_mask = collated_data['loss_mask']

torch.manual_seed(123)

# Training loop
model_torch.train()
num_epochs = 5
batch_size = 64
num_samples = input_ids.size(0)
for epoch in range(1, num_epochs + 1):
    # Visit the samples in a fresh random order every epoch.
    indices = torch.randperm(num_samples)
    for i in range(0, num_samples, batch_size):
        idx = indices[i:i + batch_size]
        batch_in = input_ids[idx]
        batch_lbl = labels[idx]
        batch_mask = loss_mask[idx]
        seq_len = batch_in.size(1)

        # Causal mask for this sequence length
        attn_mask = generate_causal_mask(seq_len).to(batch_in.device)
        logits = model_torch(batch_in, attn_mask=attn_mask)

        # Masked cross-entropy on THIS batch only, computed from the
        # causally masked logits above. Evaluating the loss on the full
        # dataset at every step would make each update cost a full pass
        # over all samples and would ignore the mini-batch entirely.
        per_token = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            batch_lbl.reshape(-1),
            reduction="none",
        )
        flat_mask = batch_mask.reshape(-1).float()
        # Average over target positions; the clamp guards a batch whose
        # mask is all zeros.
        loss = (per_token * flat_mask).sum() / flat_mask.sum().clamp(min=1.0)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch} done (last batch loss = {loss.item():.4f}).")

KeyboardInterrupt: 

In [None]:
# Persist the model's parameters (state dict only) to model.pth in the
# current working directory.
torch.save(model_torch.state_dict(), "model.pth")

Test the accuracy!!

In [None]:
model_torch.eval()

def greedy_decode_torch(prompt_ids, max_len, eos_token_id=None):
    """Greedily generate a continuation of ``prompt_ids`` with ``model_torch``.

    Args:
        prompt_ids: Long tensor of shape (1, prompt_len) holding the prompt.
            Only the first row is used when picking the next token.
        max_len: Maximum number of tokens to generate.
        eos_token_id: Optional id that ends generation early once it has
            been produced (it is kept in the output). ``None`` always
            generates exactly ``max_len`` tokens.

    Returns:
        List of generated token ids, excluding the prompt.
    """
    generated = prompt_ids.clone()  # start with the prompt
    with torch.no_grad():
        for _ in range(max_len):
            seq_len = generated.size(1)
            attn_mask = generate_causal_mask(seq_len).to(generated.device)
            logits = model_torch(generated, attn_mask=attn_mask)
            # Greedy choice: most likely token after the last position.
            next_token = int(torch.argmax(logits[0, -1]))
            # Create the new token on the same device as the running
            # sequence so the concatenation also works off the CPU.
            next_tok_t = torch.tensor([[next_token]], dtype=torch.long,
                                      device=generated.device)
            generated = torch.cat([generated, next_tok_t], dim=1)
            if eos_token_id is not None and next_token == eos_token_id:
                break
    # Return generated part after the prompt
    return generated[0, prompt_ids.size(1):].tolist()

# Exact-match accuracy on the two out-of-distribution test sets.
for test_set, name in [(test_examples_len1, "Length-1"), (test_examples_len6, "Length-6")]:
    total = len(test_set)
    correct = 0
    for inp_ids, true_out_ids in test_set:
        n_expected = len(true_out_ids)
        # The decoder expects a batch dimension: shape (1, prompt_len).
        inp_t = torch.tensor([inp_ids], dtype=torch.long)
        # Generate a few tokens more than needed, then compare only the
        # first n_expected tokens against the reference.
        gen_ids = greedy_decode_torch(inp_t, max_len=n_expected + 5)[:n_expected]
        correct += int(gen_ids == true_out_ids)
    print(f"{name} exact match accuracy: {100 * correct/total:.2f}%")

In [None]:
# Decoded prompt of the second example in `test_set`. `test_set` is the
# variable left over from the most recent evaluation loop, so it refers to
# the last test set that loop iterated over.
decode_sequence(test_set[1][0])

In [None]:
# Decoded reference output of the same example, followed by the raw
# (prompt_ids, target_ids) pair.
print(decode_sequence(test_set[1][1]))
print(test_set[1])

In [None]:
# Decode the second example of `test_set`. The decoder defined in this
# notebook is greedy_decode_torch, which expects a (1, prompt_len) tensor.
# The generation budget is tied to the reference length: prompt plus 100
# generated tokens would not fit the model's 100 positional embeddings.
greedy_decode_torch(
    torch.tensor([test_set[1][0]], dtype=torch.long),
    max_len=len(test_set[1][1]) + 5,
)

In [None]:
# Compare the model's output with the reference for one test example.
# greedy_decode_torch needs a (1, prompt_len) tensor, and the generation
# budget is tied to the reference length so the sequence stays within the
# model's 100 positional embeddings.
b = 99
prompt_t = torch.tensor([test_set[b][0]], dtype=torch.long)
print(decode_sequence(greedy_decode_torch(prompt_t, max_len=len(test_set[b][1]) + 5)))
print(decode_sequence(test_set[b][1]))

Next, test on the in-distribution lengths 3, 4, and 5.

In [None]:
# In-distribution test sets: two-step examples (ROT13 then POS1) for the
# three lengths that were seen during training.
test_examples_len3 = []
test_examples_len4 = []
test_examples_len5 = []
for _ in range(200):  # generate some test examples
    # One example per length and iteration, built by the same recipe. The
    # random sequences are drawn in the order 3, 4, 5.
    for length, bucket in (
        (3, test_examples_len3),
        (4, test_examples_len4),
        (5, test_examples_len5),
    ):
        seq = random_sequence(length)
        interm = apply_rot(seq, 13)
        final = apply_pos(interm, 1)
        # Prompt: input, both operation markers, then <think>.
        prompt = seq + [T1_TOKEN, T2_TOKEN, THINK_TOKEN]
        # Target: intermediate result, remaining marker, <answer>, final result.
        out = interm + [T2_TOKEN, ANSWER_TOKEN] + final
        bucket.append((encode_sequence(prompt), encode_sequence(out)))

In [None]:
# Evaluate exact match on the in-distribution test examples. Each label is
# paired with the test set of the same length.
for test_set, name in [
    (test_examples_len3, "Length-3"),
    (test_examples_len4, "Length-4"),
    (test_examples_len5, "Length-5"),
]:
    correct = 0
    total = len(test_set)
    for inp_ids, true_out_ids in test_set:
        # greedy_decode_torch expects a (1, prompt_len) tensor of ids.
        inp_t = torch.tensor([inp_ids], dtype=torch.long)
        # Decode a few tokens more than the reference contains ...
        gen_ids = greedy_decode_torch(inp_t, max_len=len(true_out_ids) + 5)
        # ... then cut to the reference length before comparing.
        gen_ids = gen_ids[:len(true_out_ids)]
        if gen_ids == true_out_ids:
            correct += 1
    print(f"{name} exact match accuracy: {100 * correct/total:.2f}%")

In [None]:
# Scratch cell: join the prompt and target ids of one example, then look
# at the sequence without its last token.
# NOTE(review): `train_examples`, `train_inputs` and `train_labels` are not
# defined anywhere in the visible notebook (only train_examples3/4/5 and
# train_data are) - presumably left over from an earlier version; confirm
# before running.
concat = train_examples[1][0]+train_examples[1][1]
concat

concat[:-1]
# Print the first ten input/label pairs as raw ids ...
for i in range(10):
  print(train_inputs[i],train_labels[i])

# ... and again decoded to token strings.
for i in range(10):
  print(decode_sequence(train_inputs[i]),decode_sequence(train_labels[i]))