In [8]:
import torch
import torch.nn as nn
import math
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict, OrderedDict
import pandas as pd
from typing import Union, List

class MathTokenizer:
    def __init__(self, base: int = 10):
        self.base = base
        self.special_tokens = ['[PAD]', '[SOS]', '[EOS]', '[UNK]']
        self.pad_token, self.sos_token, self.eos_token, self.unk_token = self.special_tokens

        
        self.digits = [str(i) for i in range(base)]

        # Vocabulary: special tokens + signs + digit symbols
        self.vocab = self.special_tokens + ['+', '-'] + self.digits
        self.token2id = {tok: idx for idx, tok in enumerate(self.vocab)}
        self.id2token = {idx: tok for tok, idx in self.token2id.items()}

    def _int_to_base(self, n: int) -> List[str]:
        """
        Convert a non-negative integer to its digit list in the current base.
        Returns a list of digit symbols (strings).
        """
        if n == 0:
            return [self.digits[0]]
        digits: List[str] = []
        while n > 0:
            digits.append(self.digits[n % self.base])
            n //= self.base
        return list(reversed(digits))

    def encode(self, sequence: Union[str, List[str]]) -> List[int]:
        """
        Encode a sequence (either a string of single-character tokens or a list of token strings)
        into token IDs, adding SOS and EOS.
        """
        if isinstance(sequence, str):
            # legacy: split string into single-character tokens
            tokens = [self.sos_token] + list(sequence) + [self.eos_token]
        else:
            # sequence is already a list of token strings
            tokens = [self.sos_token] + sequence + [self.eos_token]
        return [self.token2id.get(tok, self.token2id[self.unk_token]) for tok in tokens]

    def decode(self, ids: List[int]) -> List[str]:
        """
        Decode a list of token IDs to the sequence of token strings,
        stripping out special tokens.
        """
        tokens = [self.id2token.get(i, self.unk_token) for i in ids]
        # Remove special tokens
        return [tok for tok in tokens if tok not in (self.sos_token, self.eos_token, self.pad_token)]
    
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position_enc = np.array(
        [
            [pos / np.power(10000, 2 * (j // 2) / d_model) for j in range(d_model)]
            for pos in range(max_len)
        ]
    )
        position_enc = torch.Tensor(position_enc)
        pe = torch.zeros_like(position_enc)
        pe[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        pe[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        return x + self.pe[:x.size(1)]

class LCMTransformer(nn.Module):
    def __init__(self, tokenizer, d_model=128, nhead=8, num_layers=3, max_length=512, dropout=0.1):
        super().__init__()
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer.vocab)
        
        self.embedding = nn.Embedding(self.vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_length)
        
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            batch_first=True,
            dropout=dropout
        )
        
        self.fc_out = nn.Linear(d_model, self.vocab_size)
        self.pad_id = tokenizer.token2id[tokenizer.pad_token]

    def forward(self, src, tgt):
        # src: (batch_size, src_len)
        # tgt: (batch_size, tgt_len)
        
        # Embedding + positional encoding
        src_emb = self.pos_encoder(self.embedding(src))
        tgt_emb = self.pos_encoder(self.embedding(tgt))
        
        # Create masks
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(src.device)
        
        # Key padding masks must be 2D (batch_size, seq_len)
        src_key_padding_mask = (src == self.pad_id)
        tgt_key_padding_mask = (tgt == self.pad_id)
        
        # Transformer forward
        output = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask
        )
        
        return self.fc_out(output)

    def predict(self, a: int, b: int, max_length: int = 20) -> int | None:
        """Predict lcm for integers a and b, in whatever base this tokenizer uses."""
        # 1) Build the input token list properly, not as one big string
        digits_a = self.tokenizer._int_to_base(abs(a))
        digits_b = self.tokenizer._int_to_base(abs(b))
        src_tokens = ['+'] + digits_a + ['+'] + digits_b

        # 2) Encode & send through the model
        src_ids = self.tokenizer.encode(src_tokens)
        src = torch.tensor(src_ids, device=next(self.parameters()).device).unsqueeze(0)

        sos_id = self.tokenizer.token2id[self.tokenizer.sos_token]
        eos_id = self.tokenizer.token2id[self.tokenizer.eos_token]

        tgt_ids = [sos_id]
        for _ in range(max_length):
            tgt = torch.tensor(tgt_ids, device=src.device).unsqueeze(0)
            with torch.no_grad():
                logits = self(src, tgt)
            next_id = logits.argmax(-1)[0, -1].item()
            tgt_ids.append(next_id)
            if next_id == eos_id:
                break

        # 3) Decode the token IDs back to digit‐strings
        pred_tokens = self.tokenizer.decode(tgt_ids[1:])  # skip the SOS

        # 4) Convert the list of digit‐strings into an integer
        try:
            value = 0
            for tok in pred_tokens:
                # skip any stray '+' signs
                if tok == '+':
                    continue
                digit = self.tokenizer.digits.index(tok)
                value = value * self.tokenizer.base + digit
            return value
        except Exception:
            return None


class LCMDataset(Dataset):
    def __init__(self, max_num=100, num_samples=100000, seed=42, test=False, base=10):
        self.rng = np.random.RandomState(seed)
        self.sample_generated = {i:0 for i in range (1,max_num+1)}
        self.tokenizer = MathTokenizer(base)
        self.data = []
        
        for _ in range(num_samples):
            if test:
                lcm = self.rng.randint(1, 100000)
                divisors = [d for d in range(1,lcm+1) if lcm % d == 0]
                while True:
                    a, b = np.random.choice(divisors, 2)
                    if math.lcm(a,b) == lcm:
                        break
            else:        
                a, b = self.rng.randint(1, np.sqrt(max_num), size=2)
                lcm = math.lcm(a,b)
                
                
            self.sample_generated[lcm] += 1
            sign_a = ['+']
            sign_b = ['+']
            
            a = self.tokenizer._int_to_base(abs(a))
            b = self.tokenizer._int_to_base(abs(b))
            lcm = self.tokenizer._int_to_base(abs(lcm))
            
            src = sign_a + a + sign_b + b
            tgt = ["+"] + lcm
            
            self.data.append((
                self.tokenizer.encode(src),
                self.tokenizer.encode(tgt)
            ))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return (
            torch.tensor(self.data[idx][0]),
            torch.tensor(self.data[idx][1])
        )

def collate_fn(batch):
    src_batch, tgt_batch = zip(*batch)
    src = nn.utils.rnn.pad_sequence(src_batch, padding_value=tokenizer.token2id[tokenizer.pad_token], batch_first=True)
    tgt = nn.utils.rnn.pad_sequence(tgt_batch, padding_value=tokenizer.token2id[tokenizer.pad_token], batch_first=True)
    return src, tgt

def base_to_int(s, b):
    num = 0
    for digit in s:
        num = num * b + int(digit)
    return num

def compute_accuracy(model, dataloader, device, max_int, base=10):
    model.eval()
    perfect_sequences = 0
    total_sequences = 0
    lcm_correct = defaultdict(int)
    lcm_total = defaultdict(int)
    with torch.no_grad():
        for src, tgt in dataloader:
            src, tgt = src.to(device), tgt.to(device)
            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]
            
            # Get model predictions
            output = model(src, tgt_input)
            preds = output.argmax(-1)
                  
            # 1. Move all sequences to CPU at once
            tgt_np = tgt_output.cpu().numpy()
            preds_np = preds.cpu().numpy()
            pad_id = model.pad_id
            
            # 2. Process all sequences in batch
            for seq_idx in range(tgt_output.shape[0]):
                # Get non-pad tokens
                target_tokens = tgt_np[seq_idx][1:]
                pred_tokens = preds_np[seq_idx][1:]
                
                # Skip empty sequences
                if len(target_tokens) == 0:
                    continue
                
                # Batch decode using tokenizer (more efficient than one-by-one)
                try:
                    correct_lcm = tokenizer.decode(target_tokens)
                    predicted_lcm = tokenizer.decode(pred_tokens)
                    
                    if base!=10:
                        correct_lcm = base_to_int(correct_lcm, base)
                        predicted_lcm = base_to_int(predicted_lcm, base)
                            
                    lcm_total[correct_lcm] += 1
                    if predicted_lcm == correct_lcm:
                        lcm_correct[correct_lcm] += 1
                except (ValueError, AttributeError):
                    continue
    
    # Sort LCM values in ascending order
    sorted_lcms = sorted(lcm_total.keys())
    sorted_per_lcm = OrderedDict((k, lcm_correct[k]/lcm_total[k]) for k in sorted_lcms)
    sorted_correct = OrderedDict((k, lcm_correct[k]) for k in sorted_lcms)
    sorted_total = OrderedDict((k, lcm_total[k]) for k in sorted_lcms)
    
    return {
        'overall_accuracy': sum(lcm_correct.values()) / max(sum(lcm_total.values()), 1),
        'per_lcm_accuracy': sorted_per_lcm,
        'correct_counts': sorted_correct,
        'total_counts': sorted_total,
        'sorted_lcms': sorted_lcms  # List of LCM values in order
    }



In [9]:
# Initialize DataFrame to store results
per_lcm_df = pd.DataFrame()

# Initialize components
validate_step = 1
layers = 2
heads = 4
hidden_dimension = 256
length = 512
lr = 10e-5
batch = 128
max_int = 1000000
sample_size = 10000
dropout = 0
max_epoch = 10
base = 30
tokenizer = MathTokenizer(base)
seed = None

model = LCMTransformer(tokenizer, d_model=hidden_dimension, nhead=heads, num_layers=layers, max_length=length, dropout=dropout)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


# Training loop
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.token2id[tokenizer.pad_token])

counter = 0

# Collect infos on number of lcm generated
samples_generated = defaultdict(int)


for epoch in range(max_epoch):
    model.train()
    total_loss = 0
    train_dataset = LCMDataset(max_num=max_int, num_samples=3*sample_size, seed=seed, base=base)
    train_dataloader = DataLoader(train_dataset, batch_size=batch, collate_fn=collate_fn, shuffle=True)
    total_sequences = 0
    perfect_sequences = 0
    
    for key in train_dataset.sample_generated:
        samples_generated[key] += train_dataset.sample_generated[key] 

    for src, tgt in train_dataloader:
        src, tgt = src.to(device), tgt.to(device)
        
        # Prepare target input/output
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        
        optimizer.zero_grad()
        output = model(src, tgt_input)
        
        # Get prediction
        preds = output.argmax(-1)
        
        # Mask out padding tokens
        mask = (tgt_output != tokenizer.token2id[tokenizer.pad_token])
            
        loss = criterion(
            output.reshape(-1, output.size(-1)),
            tgt_output.reshape(-1)
        )
        
        # Compute accuracy
    
        with torch.no_grad():
            mask = (tgt_output != model.pad_id)
            seq_match = (preds == tgt_output) | ~mask
            perfect_sequences += seq_match.all(dim=1).sum().item()
            total_sequences += tgt.size(0)
        
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    counter += 1
    
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_dataloader)}, Accuracy: {perfect_sequences / max(total_sequences, 1)}")
    if counter % validate_step == 0:
        # Accuracy validation
        validation_dataset = LCMDataset(max_num=max_int, num_samples=sample_size, seed=seed, test=True, base=base)
        validation_dataloader = DataLoader(validation_dataset, batch_size=128, collate_fn=collate_fn)
        results = compute_accuracy(model, validation_dataloader, "cuda", max_int=max_int, base=base)
        # Convert per_lcm_accuracy to a DataFrame row
        row = pd.DataFrame({
            'step': counter,
            **results['per_lcm_accuracy']  # Flattens LCMs into columns
        }, index=[0])
        
        # Append to the main DataFrame
        per_lcm_df = pd.concat([per_lcm_df, row], ignore_index=True)
        
        # Optional: Save to CSV periodically
        if counter % (validate_step * 10) == 0:  # Save every 10 validations
            per_lcm_df.to_csv("per_lcm_accuracy_history.csv", index=False)


  pe[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
  pe[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))


Epoch 1, Loss: 2.185579536823516, Accuracy: 0.0005


  output = torch._nested_tensor_from_mask(


Epoch 2, Loss: 1.9656426931949371, Accuracy: 0.0005




Epoch 3, Loss: 1.8331449184011905, Accuracy: 0.0004




Epoch 4, Loss: 1.7238990717745841, Accuracy: 0.0008666666666666666




Epoch 5, Loss: 1.5867805034556288, Accuracy: 0.0011666666666666668




Epoch 6, Loss: 1.4931402307875612, Accuracy: 0.0021333333333333334




Epoch 7, Loss: 1.440169089905759, Accuracy: 0.0025666666666666667




Epoch 8, Loss: 1.407294309900162, Accuracy: 0.0028666666666666667




Epoch 9, Loss: 1.3718386944304122, Accuracy: 0.003766666666666667




Epoch 10, Loss: 1.3522793384308511, Accuracy: 0.0046
