## Understanding Grokking in Deep Learning

Grokking describes a specific learning behavior in deep learning models. Instead of improving gradually, the model shows a sudden and significant jump in its ability to generalize to unseen data, often long after it has already fit the training data and validation performance has looked stuck.

**How Grokking Works**

While the exact mechanisms are still being researched, grokking is often observed in models with a large number of parameters (even more than the data points available). It's believed that the model initially focuses on memorizing the training data. Then, through continued training and optimization, it transitions from memorization to understanding the underlying patterns and relationships within the data.

**Identifying Grokking in Your Model**

The provided code helps you visualize the training process and identify potential grokking. After running the code:

1. **Examine the Charts:** The code generates plots showing the model's performance on both the training data and a separate set of validation data.
2. **Look for a Sharp Increase:** Grokking is characterized by a sudden and significant jump in the validation accuracy line, often after a period where it remains relatively flat.
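
The visual check above can also be approximated numerically. The sketch below is not part of the notebook's code: `find_grokking_gap` and the 0.95 threshold are hypothetical choices, and it assumes both accuracy curves were logged at the same points (the notebook logs training every 10 steps and validation once per epoch, so you would need to align them first).

```python
def find_grokking_gap(train_acc, val_acc, threshold=0.95):
    """Return (train_idx, val_idx, gap) for the first log points where each
    accuracy curve reaches `threshold`, or None if either never does.

    Both lists are assumed to hold one value per shared logging point.
    """
    def first_reach(values):
        # Index of the first value at or above the threshold, else None.
        for i, v in enumerate(values):
            if v >= threshold:
                return i
        return None

    train_idx = first_reach(train_acc)
    val_idx = first_reach(val_acc)
    if train_idx is None or val_idx is None:
        return None
    return train_idx, val_idx, val_idx - train_idx


# Toy curves: training accuracy saturates early, validation jumps much later.
train_acc = [0.2, 0.6, 0.97, 0.99, 1.0, 1.0, 1.0, 1.0]
val_acc = [0.1, 0.1, 0.12, 0.11, 0.13, 0.12, 0.96, 0.99]
print(find_grokking_gap(train_acc, val_acc))  # (2, 6, 4)
```

A large positive gap between the two indices is the pattern to look for; a gap near zero suggests ordinary generalization rather than grokking.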

**What if My Model Doesn't Grok?**

If your model doesn't exhibit grokking, don't worry! It's not a guaranteed phenomenon. The research paper on grokking suggests experimenting with these parameters to potentially encourage it:

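* **Dataset Size:** Try training on different fractions of your training data. In the grokking paper, smaller training sets needed many more optimization steps before generalization appeared, and with too little data the model may not generalize at all within your training budget.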
* **Weight Decay:** This regularization technique, which penalizes large weights in the model, has been shown to promote grokking. Experiment with different weight decay values in your optimizer.
* **Noise Injection:** Adding a small amount of noise during training, either to the input data or the model's gradients, can sometimes help the model escape from poor solutions and find those that generalize better.
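
One way to organize these experiments is a small grid over training-set fraction and weight decay. This is a minimal sketch, not the paper's protocol: `subsample_pairs` is a hypothetical helper, and the fractions and weight decay values are illustrative only. Each weight decay value would be passed to the optimizer (for example the `weight_decay` argument of `torch.optim.AdamW` used in this notebook). Noise injection is not covered here.

```python
import itertools
import random


def subsample_pairs(inputs, outputs, fraction, seed=0):
    """Keep a random `fraction` of the (input, output) pairs, preserving pairing."""
    if not 0 < fraction <= 1:
        raise ValueError("fraction must be in (0, 1]")
    indices = list(range(len(inputs)))
    random.Random(seed).shuffle(indices)
    # Keep at least one pair, and restore the original order for readability.
    keep = sorted(indices[:max(1, int(len(indices) * fraction))])
    return [inputs[i] for i in keep], [outputs[i] for i in keep]


# Illustrative values only; neither list comes from the paper or the notebook.
train_fractions = [0.25, 0.5, 1.0]
weight_decays = [0.0, 1e-2, 1.0]

# Stand-in data so the sketch runs on its own.
inputs = list(range(8))
outputs = [x * 10 for x in inputs]

for fraction, weight_decay in itertools.product(train_fractions, weight_decays):
    sub_in, sub_out = subsample_pairs(inputs, outputs, fraction)
    # In a real run you would build a DataLoader from sub_in/sub_out and
    # train one model per (fraction, weight_decay) combination.
    print(f"fraction={fraction}, weight_decay={weight_decay}, pairs={len(sub_in)}")
```

Fixing the seed keeps the subsets comparable across weight decay values, so differences between runs come from the regularization rather than from which examples were kept.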


**Papers and GitHub**

This link has the grokking paper, the OpenAI code, and replications written by other people (the replications are much simpler to understand):
- https://paperswithcode.com/paper/grokking-generalization-beyond-overfitting-on

Here is a blog post (it's in French because I speak French, but you can use Google Translate):
- https://scienceetonnante.substack.com/p/grokking-les-modeles-dia-sont-ils


**Conclusion**

Grokking is not a magic way to train a model, so the code is not that special, but you can change the dataset size, the complexity of the data, and the learning algorithm to try to help the model generalize better.
***The more complex your dataset is, the longer you need to let the model train: more than 10,000 steps, maybe many more.***
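
To turn a step budget like that into a number of epochs, you can use the same arithmetic the notebook uses (`ceil(num_steps / len(train_loader))`). The sketch below is a standalone version with hypothetical numbers; `epochs_for_steps` is an illustrative helper, and it counts batches as steps, as the notebook does.

```python
from math import ceil


def epochs_for_steps(num_steps, dataset_size, batch_size):
    """Number of full passes over the data needed to reach `num_steps` batches."""
    batches_per_epoch = ceil(dataset_size / batch_size)
    return ceil(num_steps / batches_per_epoch)


# Hypothetical numbers: 400 training pairs, batch size 8, a 10,000-step budget.
print(epochs_for_steps(10_000, 400, 8))  # 200
```

With gradient accumulation, the number of optimizer updates is smaller than the number of batches, so you may want an even larger budget.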

In [None]:
!pip install einops

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl (43 kB)
Installing collected packages: einops
Successfully installed einops-0.8.0


In [None]:
import json
import os
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch import nn, Tensor  # Tensor is used in the type hints below
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from math import ceil
from einops import rearrange, repeat

# --- Data Loading and Preprocessing ---

def load_data(directory):
    """Loads JSON data from a specified directory.

    Args:
        directory (str): The path to the directory containing JSON files.

    Returns:
        list: A list of dictionaries, each representing data from a JSON file.
    """
    data = []
    for filename in os.listdir(directory):
        if filename.endswith('.json'):
            with open(os.path.join(directory, filename), 'r') as f:
                data.append(json.load(f))
    return data

def extract_pairs(data):
    """Extracts input-output pairs from the loaded data.

    Args:
        data (list): The list of dictionaries containing data from JSON files.

    Returns:
        tuple: Two lists - inputs and outputs - extracted from the data.
    """
    inputs, outputs = [], []
    for task in data:
        for pair in task['train']:  # Assumes 'train' key holds the pairs
            inputs.append(pair['input'])
            outputs.append(pair['output'])
    return inputs, outputs

def linearize_grid(grid):
    """Linearizes a 2D grid into a 1D sequence.

    Args:
        grid (list of lists): The 2D grid to linearize.

    Returns:
        list: The linearized 1D sequence.
    """
    return [cell for row in grid for cell in row]

def pad_sequence(sequence, max_length, pad_value=0):
    """Pads a sequence to a specified maximum length.

    Args:
        sequence (list): The sequence to pad.
        max_length (int): The desired maximum length.
        pad_value (int, optional): The value used for padding. Defaults to 0.

    Returns:
        list: The padded sequence.
    """
    return sequence + [pad_value] * (max_length - len(sequence))

def tokenize_sequence(sequence):
    """Tokenizes a sequence by converting elements to integers.

    Args:
        sequence (list): The sequence to tokenize.

    Returns:
        list: The tokenized sequence.
    """
    return [int(element) for element in sequence]

def create_dataloader(inputs, outputs, batch_size):
    """Creates a PyTorch DataLoader from input-output tensors.

    Args:
        inputs (list): List of input sequences.
        outputs (list): List of output sequences.
        batch_size (int): The batch size for the DataLoader.

    Returns:
        torch.utils.data.DataLoader: The created DataLoader.
    """
    input_tensors = torch.tensor(inputs, dtype=torch.long)
    output_tensors = torch.tensor(outputs, dtype=torch.long)
    dataset = TensorDataset(input_tensors, output_tensors)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# --- Model Definition (Transformer) ---

class DecoderBlock(torch.nn.Module):
    def __init__(self, dim_model: int, n_heads: int):
        super().__init__()

        self.self_attn = nn.MultiheadAttention(dim_model, n_heads)
        self.self_attn_norm = nn.LayerNorm(dim_model)
        self.ffn = nn.Sequential(
            nn.Linear(dim_model, dim_model * 4),
            nn.GELU(),
            nn.Linear(dim_model * 4, dim_model)
        )
        self.ffn_norm = nn.LayerNorm(dim_model)

    def forward(self, x: Tensor):
        attn_mask = torch.full(
            (len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype
        )
        attn_mask = torch.triu(attn_mask, diagonal=1)

        a1, _ = self.self_attn(x, x, x, attn_mask=attn_mask)
        a1 = self.self_attn_norm(x + a1)
        a2 = self.ffn(a1)
        a2 = self.ffn_norm(a1 + a2)

        return a2

class Transformer(torch.nn.Module):
    def __init__(self, num_layers: int, dim_model: int, num_heads: int, num_tokens: int, seq_len: int):
        super().__init__()

        self.token_embeddings = nn.Embedding(num_tokens, dim_model)
        self.position_embeddings = nn.Embedding(seq_len, dim_model)
        self.model = nn.Sequential(
            *[DecoderBlock(dim_model, num_heads) for _ in range(num_layers)],
            nn.LayerNorm(dim_model),
            nn.Linear(dim_model, num_tokens)
        )

    def forward(self, inputs: Tensor):
        batch_size, context_len = inputs.shape

        token_embedding = self.token_embeddings(inputs)

        positions = repeat(torch.arange(context_len, device=inputs.device), "p -> b p", b = batch_size)
        position_embedding = self.position_embeddings(positions)

        embedding = token_embedding + position_embedding

        embedding = rearrange(embedding, 'b s d -> s b d')

        return self.model(embedding)

# --- Training and Evaluation Functions ---

def train(model, train_loader, optimizer, scheduler, device, num_steps, accumulation_steps, loss_history, accuracy_history):
    """Trains the model for one pass over train_loader, stopping early at roughly num_steps batches.

    Args:
        model: The Transformer model to train.
        train_loader: DataLoader for the training data.
        optimizer: The optimizer used for training.
        scheduler: The learning rate scheduler.
        device: The device to train on (CPU or GPU).
        num_steps (int): Approximate cap on the batches processed in this call.
        accumulation_steps (int): Number of steps to accumulate gradients before updating.
        loss_history (list): List to store training loss history.
        accuracy_history (list): List to store training accuracy history.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer.zero_grad()

    for step, batch in enumerate(tqdm(train_loader, total=len(train_loader), leave=False)):
        batch = tuple(t.to(device) for t in batch)
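        inputs, labels = batch
        labels = labels.reshape(-1)  # Flatten (batch, seq) targets to (batch * seq,)

        # The model returns (seq, batch, vocab). Score every position and flatten
        # in the same (batch, seq) order as the labels so the shapes line up.
        output = rearrange(model(inputs), 's b v -> (b s) v')
        loss = criterion(output, labels) / accumulation_steps
        acc = (torch.argmax(output, dim=1) == labels).sum().item() / len(labels)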
        loss.backward()

        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        # --- More Frequent Logging ---
        if (step + 1) % 10 == 0:  # Log every 10 steps (adjust as needed)
            loss_history.append(loss.item() * accumulation_steps)
            accuracy_history.append(acc)

        # --- Early Stopping (Optional) ---
        # if len(loss_history) > 200 and loss_history[-1] < 0.05 and np.mean(loss_history[-100:]) - np.mean(loss_history[-200:-100]) < 1e-4:
        #     print("Early stopping triggered.")
        #     break

        if step >= num_steps:
            break

def evaluate(model, val_loader, device):
    """Evaluates the model on the validation set.

    Args:
        model: The Transformer model to evaluate.
        val_loader: DataLoader for the validation data.
        device: The device to evaluate on (CPU or GPU).

    Returns:
        tuple: Average validation loss and validation accuracy.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()

    correct = 0
    total_loss = 0.
    total_tokens = 0

    with torch.no_grad():
        for batch in val_loader:
            batch = tuple(t.to(device) for t in batch)
            inputs, labels = batch

            # Flatten (batch, seq) targets to (batch * seq,)
            labels = labels.reshape(-1)

            # The model returns (seq, batch, vocab); flatten to match the labels
            output = rearrange(model(inputs), 's b v -> (b s) v')
            loss = criterion(output, labels)
            total_loss += loss.item() * len(labels)
            correct += (torch.argmax(output, dim=1) == labels).sum().item()
            total_tokens += len(labels)

    # Accuracy and loss are averaged per token, not per example
    acc = correct / total_tokens
    avg_loss = total_loss / total_tokens

    print(f"Validation Loss: {avg_loss:.4f}, Validation Accuracy: {acc:.4f}")
    return avg_loss, acc

# --- Plotting with Enhanced Analysis ---

def plot_metrics(loss_history, accuracy_history, val_loss_history, val_accuracy_history):
    """Plots training and validation metrics with potential grokking highlights.

    Args:
        loss_history (list): Training loss history.
        accuracy_history (list): Training accuracy history.
        val_loss_history (list): Validation loss history.
        val_accuracy_history (list): Validation accuracy history.
    """
    plt.figure(figsize=(14, 6))

    # --- Loss Plot ---
    plt.subplot(1, 2, 1)
    plt.plot(loss_history, label='Training Loss', alpha=0.7)
    # Training is logged every 10 steps but validation once per epoch, so spread
    # the validation points over the same x-range (an approximate alignment).
    val_x = np.linspace(0, max(len(loss_history) - 1, 0), len(val_loss_history))
    plt.plot(val_x, val_loss_history, label='Validation Loss', alpha=0.7)
    plt.xlabel('Logged training steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')

    # --- Highlight Potential Grokking Region (Loss) ---
    if len(val_loss_history) > 100: # Adjust threshold as needed
        min_val_loss_idx = np.argmin(val_loss_history)
        if min_val_loss_idx > 50: # Check if the minimum is not too early
            upper = min(min_val_loss_idx + 50, len(val_x) - 1)
            plt.axvspan(val_x[min_val_loss_idx - 50], val_x[upper], color='lightblue', alpha=0.5)

    # --- Accuracy Plot ---
    plt.subplot(1, 2, 2)
    plt.plot(accuracy_history, label='Training Accuracy', alpha=0.7)
    plt.plot(val_x, val_accuracy_history, label='Validation Accuracy', alpha=0.7)
    plt.xlabel('Logged training steps')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Training and Validation Accuracy')

    # --- Highlight Potential Grokking Region (Accuracy) ---
    if len(val_accuracy_history) > 100: # Adjust threshold as needed
        max_val_acc_idx = np.argmax(val_accuracy_history)
        if max_val_acc_idx > 50: # Check if the maximum is not too early
            upper = min(max_val_acc_idx + 50, len(val_x) - 1)
            plt.axvspan(val_x[max_val_acc_idx - 50], val_x[upper], color='lightgreen', alpha=0.5)

    plt.tight_layout()
    plt.show()


In [None]:
# --- Main Function ---

def main():
    # --- Configurations ---
    num_layers = 3       # Number of decoder layers in the Transformer
    dim_model = 64      # Model dimensionality
    num_heads = 4       # Number of attention heads
    num_tokens = 11      # Size of the vocabulary (number of distinct tokens)
    batch_size = 8       # Batch size for training
    learning_rate = 1e-3 # Learning rate
    weight_decay = 1e-2  # Weight decay for regularization
    num_steps = 1000     # Total training steps
    accumulation_steps = 4 # Gradient accumulation steps
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Data Loading ---
    train_data = load_data('train')  # Load training data
    eval_data = load_data('evaluation')  # Load evaluation data
    train_inputs, train_outputs = extract_pairs(train_data)
    eval_inputs, eval_outputs = extract_pairs(eval_data)
    linearized_train_inputs = [linearize_grid(input_grid) for input_grid in train_inputs]
    linearized_train_outputs = [linearize_grid(output_grid) for output_grid in train_outputs]
    linearized_eval_inputs = [linearize_grid(input_grid) for input_grid in eval_inputs]
    linearized_eval_outputs = [linearize_grid(output_grid) for output_grid in eval_outputs]
    max_length = max(
        max(len(seq) for seq in linearized_train_inputs + linearized_train_outputs),
        max(len(seq) for seq in linearized_eval_inputs + linearized_eval_outputs)
    )
    padded_train_inputs = [pad_sequence(seq, max_length) for seq in linearized_train_inputs]
    padded_train_outputs = [pad_sequence(seq, max_length) for seq in linearized_train_outputs]
    padded_eval_inputs = [pad_sequence(seq, max_length) for seq in linearized_eval_inputs]
    padded_eval_outputs = [pad_sequence(seq, max_length) for seq in linearized_eval_outputs]
    tokenized_train_inputs = [tokenize_sequence(input_seq) for input_seq in padded_train_inputs]
    tokenized_train_outputs = [tokenize_sequence(output_seq) for output_seq in padded_train_outputs]
    tokenized_eval_inputs = [tokenize_sequence(input_seq) for input_seq in padded_eval_inputs]
    tokenized_eval_outputs = [tokenize_sequence(output_seq) for output_seq in padded_eval_outputs]
    train_loader = create_dataloader(tokenized_train_inputs, tokenized_train_outputs, batch_size)
    eval_loader = create_dataloader(tokenized_eval_inputs, tokenized_eval_outputs, batch_size)

    # --- Model, Optimizer, Scheduler Initialization ---
    model = Transformer(
        num_layers=num_layers,
        dim_model=dim_model,
        num_heads=num_heads,
        num_tokens=num_tokens,
        seq_len=max_length
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        betas=(0.9, 0.98),
        weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.1, total_iters=9
    )

    num_epochs = ceil(num_steps / len(train_loader))
    loss_history = []
    accuracy_history = []
    val_loss_history = []
    val_accuracy_history = []

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        train(model, train_loader, optimizer, scheduler, device, num_steps, accumulation_steps, loss_history, accuracy_history)
        val_loss, val_acc = evaluate(model, eval_loader, device)
        val_loss_history.append(val_loss)
        val_accuracy_history.append(val_acc)

    plot_metrics(loss_history, accuracy_history, val_loss_history, val_accuracy_history)

if __name__ == "__main__":
    main()