# How sharp is Occam's razor? Toy models for complexity via grokking


By Louis Jaburi

## 1. Introduction

Why do neural networks generalize? While a full explanation is still missing, an important principle seems to be Occam's razor. Many versions of this principle exist, but for this document we will summarize it as
> *solutions that are simpler will be preferred.*

What remains elusive is the notion of *simple*, or equivalently the notion of complexity. Several competing notions have been proposed, for example:

1.   Kolmogorov complexity,
2.   Notions from statistical learning theory, such as VC dimension or Rademacher complexity.

But they all come with their respective downsides: Kolmogorov complexity requires fixing a universal Turing machine (UTM) and is in general uncomputable. The more concrete notions mentioned in 2. have failed empirical verification, see [Zhang et al.](https://arxiv.org/pdf/1611.03530.pdf).

As there is a lot of uncertainty about which theoretical notion of complexity is empirically valid, in this document we take the opposite approach: starting from empirical results, we want to gather intuition that can guide us towards a theoretical notion explaining why DNNs prefer to learn certain solutions over others.

To do this, we propose a toy model for studying generalization behaviour. More precisely, in our setup multiple hypotheses are consistent with the training data. A useful theory of complexity should predict which hypothesis is likely to be learned.




## 2. Method

We first start with the high-level picture. Consider two supervised learning tasks, where we are trying to learn two different functions $f_1, f_2: X\to Y$. Their graphs give two datasets $D_1, D_2\subset X\times Y$, and we can consider the intersection $D_{12}=D_1\cap D_2$, i.e. the points on which $f_1$ and $f_2$ agree. We can ask:

> If we train a model on the dataset $D_{12}$, which function will it learn?

By Occam's razor, the simpler function should be the one that is learned during training (or at least the one that is more likely to be learned).

<img src="http://cogeometry.com/assets/Frame3.jpg" width =500>


More generally, we could start with a family of functions $\{f_i: X\to Y\}_{i\in I}$ and consider $D_I=\bigcap_{i\in I} D_i$. Is there then a hierarchy determining the learning behaviour?
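The construction of $D_I$ is easy to make concrete. Below is a minimal sketch in which each $D_i$ is the graph of $f_i$ stored as a Python set, and $D_I$ is their intersection. The helper name `agreement_set` and the toy operations on $\mathbb{Z}/4$ (addition, subtraction, projection) are illustrative assumptions, not the groups studied in Section 3.

```python
from itertools import product


def agreement_set(functions, domain):
    """D_I: all pairs (x, y) with y = f(x) for every f in `functions`.

    Each D_i is the graph {(x, f_i(x)) : x in domain}; intersecting the graphs
    keeps exactly the inputs on which all functions agree, with their common label.
    """
    graphs = [{(x, f(x)) for x in domain} for f in functions]
    return set.intersection(*graphs)


n = 4
X = list(product(range(n), repeat=2))  # all pairs (a, b) in Z/4 x Z/4


def add(ab):
    return (ab[0] + ab[1]) % n


def sub(ab):
    return (ab[0] - ab[1]) % n


def first(ab):
    return ab[0]  # projection onto the left operand


D_12 = agreement_set([add, sub], X)          # agree iff 2b = 0 mod 4
D_123 = agreement_set([add, sub, first], X)  # agree iff b = 0
print(len(X), len(D_12), len(D_123))  # 16 8 4
```

Adding functions to the family can only shrink $D_I$, so larger families leave less ambiguous training data.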

To get meaningful results, we impose some requirements:


1.   The functions $f_1, f_2$ should capture some structure. Using random maps $f_1,f_2: X\to Y$ is unlikely to result in interesting or informative generalization behaviour,
2.   At least initially, the training data should be free of noise to reduce other potential biases. Although later on this poses an interesting question in itself: (How) does noise prefer certain solutions?
3. The training data set should be sufficiently big.

We propose grokking, where MLPs learn simple algorithmic tasks, as a good toy model. Not only are the three conditions above fulfilled, but we also observe that the model already exhibits a ranking between two options: memorization and generalization. Furthermore, we find a setup where the training data set $D_{12}$ makes up as much as $75\%$ of the full input space $X$ (which serves as our test set), while still being consistent with both $f_1$ and $f_2$. Moreover, we can make $X$ arbitrarily large.

Ideally, we would have a notion of complexity $c$ such that $c(\text{memorization})>c(f_2)> c(f_1)$. We would then expect $f_1$ to be preferred over time, or at least to be more likely to be learned.

### Possible measurements for "complexity"


What kind of measurements can we do to observe complexity? For example:

1.   Training time until grokking is completed. Do more complex tasks require longer training time?
2.   The learning coefficient as described in [Lau et al.](https://arxiv.org/pdf/2308.12108.pdf). Using the widely applicable Bayesian information criterion (WBIC), their approach probes the local geometry of the loss landscape.

     Informed by considerations based on the free energy formula in singular learning theory (see [*loc. cit.*](https://arxiv.org/pdf/2308.12108.pdf), p. 5), the learning coefficient should approximate a notion of "effective parameters" used by the MLP. Fewer effective parameters means a simpler model.
3.   Circuit size. In particular, [Varma et al.](https://arxiv.org/pdf/2309.02390.pdf) suggest that "*...weight decay prefers circuits [...] that require less parameter norm to produce a given logit value*".
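To make item 2 more concrete, here is a minimal sketch of an SGLD-based estimator in the spirit of [Lau et al.](https://arxiv.org/pdf/2308.12108.pdf): sample $w$ from a tempered posterior localized at a trained parameter $w^*$, at inverse temperature $\beta^*=1/\log n$, and return $\hat\lambda(w^*) = n\beta^*\big(\mathbb{E}_w[L_n(w)] - L_n(w^*)\big)$. It assumes full-batch gradients on a toy quadratic loss in NumPy rather than minibatch gradients on the MLP; the name `estimate_llc` and the hyperparameters `step_size`, `gamma` and `burn_in` are illustrative assumptions, not the settings behind the plots in Section 6.

```python
import numpy as np


def estimate_llc(loss, grad, w_star, n, num_steps=5000, step_size=1e-3,
                 gamma=1.0, burn_in=1000, seed=0):
    """Sketch of an SGLD-based local learning coefficient estimate.

    Samples w from exp(-n*beta*L(w) - gamma/2 * |w - w_star|^2) with
    beta = 1/log(n), then returns n * beta * (mean sampled loss - L(w_star)).
    """
    rng = np.random.default_rng(seed)
    beta = 1.0 / np.log(n)
    w = w_star.copy()
    sampled_losses = []
    for step in range(num_steps):
        # Gradient of the potential: tempered loss plus localization term
        drift = n * beta * grad(w) + gamma * (w - w_star)
        noise = rng.normal(scale=np.sqrt(step_size), size=w.shape)
        w = w - 0.5 * step_size * drift + noise
        if step >= burn_in:
            sampled_losses.append(loss(w))
    return n * beta * (np.mean(sampled_losses) - loss(w_star))


# Regular (non-singular) toy loss L(w) = |w|^2 / 2 in d = 2 dimensions.
# Its learning coefficient is d/2 = 1, so the estimate should land close to 1
# (up to sampling noise and a small step-size bias).
d = 2
llc = estimate_llc(lambda w: 0.5 * float(w @ w), lambda w: w, np.zeros(d), n=1000)
print(llc)
```

For a singular model, such as an MLP with many equivalent parametrizations, the estimate can fall below $d/2$, which is what makes it a candidate measure of "effective parameters".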





Running the experiments described below serves as a way to empirically test the above hypotheses. The reader is invited to skip to Section 4 to run the experiments themselves, and to Sections 5 and 6 for an overview of the current results.

## 3. Description of the experiments


We first describe the algorithmic tasks $f_1, f_2$ that are grokked. Grokking was first observed in [Power et al.](https://arxiv.org/pdf/2201.02177.pdf) when training networks on binary operations such as addition in the finite group ℤ/nℤ: the model quickly reaches perfect accuracy on the training data but fails to generalize. After training for much longer, however, it also achieves perfect accuracy on the test data.

A deconstruction of the learned algorithm for modular addition via Fourier transforms was given in [Nanda et al.](https://arxiv.org/pdf/2301.05217.pdf). A more general approach for finite groups using representation theory was given in [Chughtai et al.](https://arxiv.org/abs/2302.03025). An interesting distinction between learned algorithms, the Clock vs. Pizza algorithms, was given in [Zhong et al.](https://arxiv.org/pdf/2306.17844.pdf).


In our case, we study different group operations on a set $M$ of size $100$, i.e. $$f_1,f_2: M\times M\to M$$ are given by $f_i(a,b)=a+_i b$. Thus $|X|=|M|^2 =10000$, $|Y|=100$, and the size of the training data set will be $|D_{12}|=7500$.

**The following paragraph constructs the two groups and may be safely skipped.**

Consider more generally $M=\{1,\dots,2N\}$ (above, $N=50$), identified with pairs $(a,c)$ where $a\in\mathbb{Z}/N$ and $c\in\mathbb{Z}/2$. We can endow it with two group structures:

1.   The commutative group $\mathbb{Z}/N\times \mathbb{Z}/2$, defined by $(a,b)+_1(c,d)=(a+c,b+d)$,
2.   The [semidirect product](https://en.wikipedia.org/wiki/Semidirect_product) $\mathbb{Z}/N \rtimes \mathbb{Z}/2$, defined by $(a,0)+_2(b,c)=(a+b,c)$ and $(a,1)+_2(b,c)=(a+(N/2+1)\cdot b, 1+c)$.

One can verify that $+_1$ and $+_2$ agree on exactly $75\%$ of $X$: they differ precisely when the left operand is of the form $(a,1)$ and $b$ is odd, since then $(N/2)\cdot b\not\equiv 0 \pmod N$. (This will be expanded on in the future.)
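The $75\%$ claim can be checked directly. Below is a minimal pure-Python sketch (all helper names are hypothetical) that rebuilds both Cayley tables with the same rule as `twisted_group` in the Code section, where `GroupData` passes the factor `N_1 // 2 + 1`. It also checks the group axioms, which surfaces a caveat worth keeping in mind: multiplication by $N/2+1$ is an automorphism of $\mathbb{Z}/N$ (of order 2) exactly when $4\mid N$. For $N=50$ the factor $26$ is not invertible modulo $50$, so the operation built this way is not associative; $N=12$ (factor $7$) is shown for contrast.

```python
def semidirect_table(N, factor):
    """Cayley table on {0, ..., 2N-1}; index k encodes the pair (k % N, k // N).

    (a, c) + (b, d) = (a + factor**c * b mod N, c + d mod 2), the rule that
    twisted_group(cyclic(...), lambda x: factor * x) implements.
    """
    size = 2 * N
    table = [[0] * size for _ in range(size)]
    for i in range(size):
        for j in range(size):
            a, c = i % N, i // N
            b, d = j % N, j // N
            twisted_b = (factor * b) % N if c == 1 else b
            table[i][j] = (a + twisted_b) % N + N * ((c + d) % 2)
    return table


def agreement_fraction(table1, table2):
    size = len(table1)
    same = sum(table1[i][j] == table2[i][j] for i in range(size) for j in range(size))
    return same / size**2


def is_commutative(table):
    r = range(len(table))
    return all(table[x][y] == table[y][x] for x in r for y in r)


def is_associative(table):
    r = range(len(table))
    return all(table[table[x][y]][z] == table[x][table[y][z]]
               for x in r for y in r for z in r)


N = 50
G1 = semidirect_table(N, 1)           # Z/N x Z/2
G2 = semidirect_table(N, N // 2 + 1)  # twisted by x -> (N/2 + 1) * x, as in GroupData

print(agreement_fraction(G1, G2))               # 0.75
print(is_commutative(G1), is_commutative(G2))   # True False
g2_associative = is_associative(G2)
print(g2_associative)                           # False: 26 is not invertible mod 50
small_ok = is_associative(semidirect_table(12, 12 // 2 + 1))
print(small_ok)                                 # True: 7 * 7 = 49 = 1 mod 12
```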

We trained two different kinds of MLPs. The first is the one used in [Chughtai et al.](https://arxiv.org/abs/2302.03025), the second is the one used in [Investigating the learning coefficient of modular addition: hackathon project](https://www.lesswrong.com/posts/4v3hMuKfsGatLXPgt/investigating-the-learning-coefficient-of-modular-addition). In both cases there are two embedding matrices, which are not tied.

By default, we used the Adam optimizer with weight decay $0.0002$. In both cases there are roughly 17k parameters.
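As a check on the parameter count, here is the arithmetic for the default dimensions of the `Parameters` dataclass in the Code section ($N=100$, `embed_dim` $=32$, `hidden_size` $=64$), following the layer shapes of the `MLP` and `MLP2` classes. The helper functions are illustrative.

```python
def embedding_params(vocab, dim):
    return vocab * dim


def linear_params(fan_in, fan_out, bias=True):
    return fan_in * fan_out + (fan_out if bias else 0)


N, embed_dim, hidden = 100, 32, 64  # defaults of the Parameters dataclass

# MLP: concatenated embeddings -> one hidden layer -> unembedding
mlp = (2 * embedding_params(N, embed_dim)
       + linear_params(2 * embed_dim, hidden)
       + linear_params(hidden, N))

# MLP2: each embedding gets its own linear map; the outputs are summed
mlp2 = (2 * embedding_params(N, embed_dim)
        + 2 * linear_params(embed_dim, hidden)
        + linear_params(hidden, N))

print(mlp, mlp2)  # 17060 17124
```

Both models land just above 17k parameters; the embeddings (6400) and the unembedding (6500) account for most of them.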

In this colab we will use the latter model.

<img src="https://pbs.twimg.com/media/FpCHLTvaAAEnsGy?format=png&name=4096x4096" width =500>

<img src="https://res.cloudinary.com/lesswrong-2-0/image/upload/f_auto,q_auto/v1/mirroredImages/AhnoNzES5qiTnnavt/odmceomuhkcwtp4bshet" width=400>


Left: Image taken from [Chughtai et al.](https://arxiv.org/abs/2302.03025)

Right: Image taken from [Investigating the learning coefficient of modular addition: hackathon project](https://www.lesswrong.com/posts/4v3hMuKfsGatLXPgt/investigating-the-learning-coefficient-of-modular-addition)

## Code

You can run all cells in this section and proceed to the next section.



(Available also at https://github.com/LouisYRYJ/Finite-groups)

In [None]:
%pip install wandb torch tqdm einops


In [None]:
import torch as t
import os
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm
import random
import wandb
import copy
from einops import rearrange
from dataclasses import dataclass

from datetime import datetime
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
device = t.device("cuda" if t.cuda.is_available() else "cpu")


In [None]:
def twisted_group(group, automorphism=lambda x: x):
    """Constructs the semidirect product of `group` with Z/2Z using the given automorphism.

    Index k < n encodes the pair (k, 0) and index k >= n encodes (k - n, 1),
    where n is the order of `group`.
    """
    group_cardinality = group.size(dim=0)
    new_cardinality = group_cardinality * 2
    new_group = t.zeros((new_cardinality, new_cardinality), dtype=t.int64)

    for i in range(new_cardinality):
        for j in range(new_cardinality):
            # (a, 0) + (b, 0) = (a + b, 0)
            if i < group_cardinality and j < group_cardinality:
                new_group[i, j] = group[i, j]

            # (a, 0) + (b, 1) = (a + b, 1)
            if i < group_cardinality and j >= group_cardinality:
                new_group[i, j] = group[i, j - group_cardinality] + group_cardinality

            # (a, 1) + (b, 0) = (a + automorphism(b), 1)
            if i >= group_cardinality and j < group_cardinality:
                new_group[i, j] = (
                    group[i - group_cardinality, automorphism(j) % group_cardinality]
                    + group_cardinality
                )

            # (a, 1) + (b, 1) = (a + automorphism(b), 0)
            if i >= group_cardinality and j >= group_cardinality:
                new_group[i, j] = group[
                    i - group_cardinality,
                    automorphism(j - group_cardinality) % group_cardinality,
                ]

    return new_group


def cyclic(params):
    cyclic_group = t.zeros((params.N_1, params.N_1), dtype=t.int64)
    for i in range(params.N_1):
        for j in range(params.N_1):
            cyclic_group[i, j] = (i + j) % params.N_1
    return cyclic_group


class GroupData(Dataset):
    def __init__(self, params):
        self.group1 = twisted_group(cyclic(params))
        self.group2 = twisted_group(cyclic(params), lambda x: (params.N_1 // 2 + 1) * x)
        self.group1_list = [
            (i, j, self.group1[i, j].item())
            for i in range(self.group1.size(0))
            for j in range(self.group1.size(1))
        ]

        self.group2_list = [
            (i, j, self.group2[i, j].item())
            for i in range(self.group2.size(0))
            for j in range(self.group2.size(1))
        ]

        # Set lookups keep these filters linear (rather than quadratic) in |X|
        group1_set, group2_set = set(self.group1_list), set(self.group2_list)
        self.group1_only = [
            item for item in self.group1_list if item not in group2_set
        ]
        self.group2_only = [
            item for item in self.group2_list if item not in group1_set
        ]

        if params.data_group1 and not params.data_group2:
            self.train_data = self.group1_list
        elif params.data_group2 and not params.data_group1:
            self.train_data = self.group2_list
        else:
            self.train_data = [
                i for i, j in zip(self.group1_list, self.group2_list) if i == j
            ]  # intersection of G_1 and G_2

        self.train_data = self.train_data + random.sample(
            self.group1_only, params.add_points_group1
        )  # add points from G_1 exclusively
        self.train_data = self.train_data + random.sample(
            self.group2_only, params.add_points_group2
        )  # add points from G_2 exclusively

        self.train_data_tensor = t.tensor(self.train_data).to(device)

    def __getitem__(self, idx):
        x, y, label = self.train_data_tensor[idx]  # left operand, right operand, result
        return [x, y], label

    def __len__(self):
        return len(self.train_data)


In [None]:
class MLP(t.nn.Module):
    def __init__(self, params):
        super().__init__()
        self.Embedding_left = t.nn.Embedding(params.N, params.embed_dim)
        self.Embedding_right = t.nn.Embedding(params.N, params.embed_dim)
        self.linear = t.nn.Linear(params.embed_dim * 2, params.hidden_size, bias=True)
        if params.activation == "gelu":
            self.activation = t.nn.GELU()
        if params.activation == "relu":
            self.activation = t.nn.ReLU()
        self.Umbedding = t.nn.Linear(params.hidden_size, params.N, bias=True)

    def forward(self, a):
        x1 = self.Embedding_left(a[0])
        x2 = self.Embedding_right(a[1])
        x12 = t.cat([x1, x2], -1)
        hidden = self.linear(x12)
        hidden = self.activation(hidden)
        out = self.Umbedding(hidden)
        return out

In [None]:
class MLP2(t.nn.Module):
    def __init__(self, params):
        super().__init__()
        self.Embedding_left = t.nn.Embedding(params.N, params.embed_dim)
        self.Embedding_right = t.nn.Embedding(params.N, params.embed_dim)
        self.linear_left = t.nn.Linear(params.embed_dim, params.hidden_size, bias=True)
        self.linear_right = t.nn.Linear(params.embed_dim, params.hidden_size, bias=True)
        if params.activation == "gelu":
            self.activation = t.nn.GELU()
        if params.activation == "relu":
            self.activation = t.nn.ReLU()
        self.Umbedding = t.nn.Linear(params.hidden_size, params.N, bias=True)

    def forward(self, a):
        x1 = self.Embedding_left(a[0])
        x2 = self.Embedding_right(a[1])
        hidden_x1 = self.linear_left(x1)
        hidden_x2 = self.linear_right(x2)
        hidden_sum = hidden_x1 + hidden_x2
        hidden = self.activation(hidden_sum)
        out = self.Umbedding(hidden)
        return out

In [None]:
@dataclass
class Parameters:
    N_1: int = 50
    N: int = N_1 * 2  # derived from the default N_1; not updated if N_1 is passed
    embed_dim: int = 32
    hidden_size: int = 64
    num_epoch: int = 2000
    batch_size: int = 512  # note: train() currently uses full-batch training
    activation: str = "relu"  # gelu or relu
    checkpoint_every: int = 5
    max_steps_per_epoch: int = N * N // batch_size
    train_frac: float = 1.0
    weight_decay: float = 0.0002
    lr: float = 0.01
    beta_1: float = 0.9
    beta_2: float = 0.98
    warmup_steps: int = 0  # annotated so that it is a proper dataclass field
    optimizer: str = "adam"  # adamw or adam or sgd
    data_group1: bool = True  # training data G_1
    data_group2: bool = True  # training data G_2
    add_points_group1: int = 0  # add points from G_1 only
    add_points_group2: int = 0  # add points from G_2 only


Accuracy and cross entropy loss functions.

In [None]:
def loss_fn(logits, labels):
    """
    Compute cross entropy loss.

    Args:
        logits (Tensor): (batch, group.order) tensor of logits
        labels (Tensor): (batch) tensor of labels

    Returns:
        Tensor: scalar cross entropy loss
    """
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1, index=labels[:, None])[:, 0]
    return -correct_log_probs.mean()


def get_accuracy(logits, labels):
    """
    Compute accuracy of model.

    Args:
        logits (torch.tensor): (batch, group.order) tensor of logits
        labels (torch.tensor): (batch) tensor of labels

    Returns:
        float: accuracy
    """
    return ((logits.argmax(-1) == labels).sum() / len(labels)).item()




def test_loss(model, params, Group_Dataset):
    """Create all possible pairs (x,y) and return loss and accuracy for G_1 and G_2"""
    test_labels_x = t.tensor([num for num in range(params.N) for _ in range(params.N)]).to(device)
    test_labels_y = t.tensor([num % params.N for num in range(params.N * params.N)]).to(device)

    logits = model([test_labels_x, test_labels_y])
    labels_group_1 = rearrange(Group_Dataset.group1, "a b-> (a b)").to(device)
    labels_group_2 = rearrange(Group_Dataset.group2, "a b-> (a b)").to(device)

    loss_group_1 = loss_fn(logits, labels_group_1)
    loss_group_2 = loss_fn(logits, labels_group_2)

    accuracy_group_1 = get_accuracy(logits, labels_group_1)
    accuracy_group_2 = get_accuracy(logits, labels_group_2)

    return (loss_group_1, loss_group_2), (accuracy_group_1, accuracy_group_2)

random.seed(42)

def random_indices(full_dataset, params):
    """Picks random subset of indices the data given"""
    num_indices = int(len(full_dataset) * params.train_frac)
    picked_indices = random.sample(list(range(len(full_dataset))), num_indices)
    return picked_indices


Training function

In [None]:
def train(model, params):
    current_time = datetime.today().strftime("%Y-%m-%d %H:%M:%S")
    wandb.init(

        project="Grokking ambiguous data",
        name=f"experiment_{current_time}",
        config={
            "Epochs": params.num_epoch,
            "Batch size": params.batch_size,
            "Cardinality": params.N,
            "Embedded dimension": params.embed_dim,
            "Hidden dimension": params.hidden_size,
            "Training": (params.data_group1, params.data_group2),
            "Added points": (params.add_points_group1, params.add_points_group2),
            "Train frac": params.train_frac,
            "Weight decay": params.weight_decay,
            "Learning rate": params.lr,
            "Warm up steps": params.warmup_steps,
        },
    )
    Group_Dataset = GroupData(params=params)

    train_data = t.utils.data.Subset(
        Group_Dataset, random_indices(Group_Dataset, params)
    )
    # Full-batch training: the whole (sub)dataset forms a single batch,
    # so params.batch_size is not used here.
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=len(train_data),
        shuffle=True,
        drop_last=False
    )

    criterion = t.nn.CrossEntropyLoss()

    if params.optimizer == "sgd":
        optimizer = t.optim.SGD(model.parameters(), lr=params.lr)
    if params.optimizer == "adam":
        optimizer = t.optim.Adam(
            model.parameters(),
            weight_decay=params.weight_decay,
            lr=params.lr,
        )
    if params.optimizer == "adamw":
        optimizer = t.optim.AdamW(
            model.parameters(),
            weight_decay=params.weight_decay,
            lr=params.lr,
            betas=[params.beta_1, params.beta_2],
        )

    average_loss_training = 0
    step = 0
    for epoch in range(params.num_epoch):
        with t.no_grad():
            model.eval()

            average_loss_training = average_loss_training / len(train_loader)  # mean over the previous epoch's batches

            losses_test, accuracies_test = test_loss(model, params, Group_Dataset)
            wandb.log({"Loss G_1": losses_test[0], "Loss G_2": losses_test[1]})
            wandb.log(
                {"Accuracy G_1": accuracies_test[0], "Accuracy G_2": accuracies_test[1]}
            )
            wandb.log({"Training loss": average_loss_training})
            average_loss_training = 0
        for x, z in train_loader:
            global_step = step  # optimizer steps taken so far (used for LR warmup)
            if global_step < params.warmup_steps:
                lr = global_step * params.lr / float(params.warmup_steps)
            else:
                lr = params.lr
            for g in optimizer.param_groups:
                g["lr"] = lr

            model.train()
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, z)
            average_loss_training += loss.item()
            loss.backward()
            optimizer.step()
            step += 1

    wandb.finish()

## 4. Running experiments

We ran the following experiments:

1.   Train a network on a random subset of size $40\%$ of $D_1$ or $D_2$ respectively (i.e. the usual grokking setup)
2.   Train a network on $D_{12}$
3.   Train a network on $40\%$ of $D_{12}$
4.   Train a network on $D_{12}$, but add a small number of points from $D_1\setminus D_2$ or $D_2\setminus D_1$ respectively
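For orientation, the training-set sizes implied by these four configurations can be worked out from the Code section: `GroupData` first appends the exclusive points, and `random_indices` then keeps `int(len(dataset) * train_frac)` of them. A minimal sketch, where `num_train_points` is a hypothetical helper and experiment 4 uses one extra point, as in Section 4.4:

```python
def num_train_points(base_size, add_points=0, train_frac=1.0):
    """Training-set size produced by GroupData + random_indices (sketch)."""
    return int((base_size + add_points) * train_frac)


X_size = 100 * 100          # |X| = |M|^2 = |D_1| = |D_2|
D12_size = 3 * X_size // 4  # |D_12|: the two operations agree on 75% of X

experiments = {
    "1. 40% of D_1 (or D_2)": num_train_points(X_size, train_frac=0.4),
    "2. all of D_12": num_train_points(D12_size),
    "3. 40% of D_12": num_train_points(D12_size, train_frac=0.4),
    "4. D_12 plus one exclusive point": num_train_points(D12_size, add_points=1),
}
for name, size in experiments.items():
    print(f"{name}: {size} training points")
# 4000, 7500, 3000 and 7501 training points respectively
```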



To run the experiments yourself, make sure that you have run all the cells in the previous "Code" section. You will also need a wandb API key. Alternatively, you can skip to the next section for a summary.

### 4.1. Usual grokking
Train a network on a random subset of size $40\%$ of $D_1$ and of $D_2$ respectively, i.e. the usual grokking setup. This is mostly a sanity check that the model with these hyperparameters is indeed capable of grokking; nothing unusual happens here.

In [None]:
ExperimentsParameters = Parameters(data_group1= True, data_group2=False, train_frac=0.4)

model = MLP2(ExperimentsParameters).to(device=device)

train(model=model, params=ExperimentsParameters)

W&B run history and final values for this run:

| Metric | Run history | Final value |
|---|---|---|
| Accuracy G_1 | ▁▂▃▄▄▄▄▄▄▄▄▄▅▇▇█████████████████████████ | 0.996 |
| Accuracy G_2 | ▁▂▃▄▄▄▄▄▄▄▄▄▅▇▇█████████████████████████ | 0.7484 |
| Loss G_1 | ▄▅▇██▇▆▅▄▄▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ | 0.04031 |
| Loss G_2 | ▂▃▅▇█▇▆▅▄▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ | 3.68317 |
| Training loss | █▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ | 0.0013 |


Now let's grok $D_2$.

In [None]:
ExperimentsParameters = Parameters(data_group1= False, data_group2= True, train_frac=0.4)

model = MLP2(ExperimentsParameters).to(device=device)

train(model=model, params=ExperimentsParameters)

### 4.2. Train a network on ambiguous training data $D_{12}$

Now we train the model on the full ambiguous dataset $D_{12}$.

In [None]:
ExperimentsParameters = Parameters()

model = MLP2(ExperimentsParameters).to(device=device)

train(model=model, params=ExperimentsParameters)

### 4.3. Train a network on 40% of the ambiguous training data $D_{12}$

In [None]:
ExperimentsParameters = Parameters(train_frac=0.4)

model = MLP2(ExperimentsParameters).to(device=device)

train(model=model, params=ExperimentsParameters)

### 4.4. Train a network on ambiguous training data $D_{12}$ and add points from $D_1\setminus D_2$ or $D_2\setminus D_1$

In [None]:
ExperimentsParameters = Parameters(add_points_group1 = 1) # Alternatively add_points_group2

model = MLP2(ExperimentsParameters).to(device=device)

train(model=model, params=ExperimentsParameters)

## 5. Results


So far, most of our focus has been on experiment 4.2, i.e. training the model on the full ambiguous data set $D_{12}$.
Let us first compare the accuracy for $f_1$ and $f_2$ (denoted by G_1 and G_2 in the graph below).




<img src="http://cogeometry.com/assets/accuracy1.png" width =500>
<img src="http://cogeometry.com/assets/accuracy2.png" width =500>

Interactive version: [Accuracy G_1](https://api.wandb.ai/links/louisyryj/18z3zs3n)
[Accuracy G_2](https://api.wandb.ai/links/louisyryj/l48upmwf)

We make the following observations:


1.   In three runs, $f_1$ was fully grokked,
2.   In no run was $f_2$ grokked,
3.   In most runs, the accuracy of $f_1$ plateaus at some intermediate level,
4.   In one run, the accuracy of $f_2$ bumps up once. Notice that in this run the accuracies of $f_1$ and $f_2$ go up simultaneously (around steps 400-800).



Some consequences are:

1. It seems that $f_1$ is easier to learn than $f_2$. From the point of view of human intuition this is reasonable: $+_1$ is commutative, which by default seems easier to comprehend, whereas $+_2$ is not. Group operations are monsters of symmetry, and the simplest form of symmetry is $x+_1y=y+_1x$,
2. But this raises the question: what happens in most of the other runs? Can we give a comprehensible interpretation of the intermediate levels?
3. How can we understand the bump that occurs simultaneously in $f_1$ and $f_2$?



Let's look at the loss curves.

<img src=http://cogeometry.com/assets/loss1.png width =500>
<img src=http://cogeometry.com/assets/loss2.png width =500>

Interactive version: [Loss G_1](https://api.wandb.ai/links/louisyryj/xa62ez8o)
[Loss G_2](https://api.wandb.ai/links/louisyryj/dyz4kpy0)

Again, we observe that $f_2$ exhibits some sort of slingshot behaviour between steps 400 and 800.

## 6. Evaluation of results


Let's focus on two runs (more to be added).

The first corresponds to 4.1 and the second to 4.3.




### 6.1 Usual grokking and local learning coefficient measure (as in 4.1)

Let's look at the losses:

<img src=http://cogeometry.com/assets/Grokking_1.png>

This is what we expected to see. The training loss (in yellow) goes down very quickly. The test loss of $G_1$ goes down to approx. $0$ eventually.

Here is a visualization of the training run (open gif in new tab to view the run again):

<img src=http://cogeometry.com/assets/Grokking_1.gif>

Each cell corresponds to one data point $x\in X$. A cell is:
- orange, if the model learned $G_1$,
- green, if the model learned $G_2$,
- blue, if the model learned the data point correctly and the output is the same for $G_1$ and $G_2$.
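The colouring of the gif frames can be reproduced from the model's predictions and the two Cayley tables. A minimal NumPy sketch, assuming the predictions are given as an $|M|\times|M|$ array of argmax outputs; `colour_cells` and the integer colour codes are hypothetical, and cells where the model matches neither table (not covered by the legend above) get a fourth code. The toy tables on $\mathbb{Z}/4$ are for illustration only.

```python
import numpy as np

# Hypothetical integer codes for the legend above
ORANGE, GREEN, BLUE, OTHER = 0, 1, 2, 3


def colour_cells(preds, labels_g1, labels_g2):
    """Assign a colour code to every input x in X.

    All arguments are integer arrays of shape (|M|, |M|): the model's argmax
    predictions and the Cayley tables of G_1 and G_2.
    """
    preds, l1, l2 = (np.asarray(arr) for arr in (preds, labels_g1, labels_g2))
    differ = l1 != l2
    colours = np.full(preds.shape, OTHER)
    colours[(preds == l1) & differ] = ORANGE  # learned G_1 where the groups disagree
    colours[(preds == l2) & differ] = GREEN   # learned G_2 where the groups disagree
    colours[(preds == l1) & ~differ] = BLUE   # correct, and both groups agree
    return colours


# Toy check on Z/4 with addition vs. subtraction (not the groups of Section 3):
n = 4
a, b = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
add_table, sub_table = (a + b) % n, (a - b) % n
colours = colour_cells(add_table, add_table, sub_table)  # a "model" that learned addition
print(np.bincount(colours.ravel(), minlength=4))  # [8 0 8 0]
```

Each frame can then be drawn with, for example, matplotlib's `imshow` and a four-colour colormap.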


In this case roughly $75\%$ of the cells are blue and, since we grokked $G_1$, the rest are orange.


Now let's look at the local learning coefficient and the accuracy of $G_1$.

<img src=http://cogeometry.com/assets/llc_grokking_1.png>



Something interesting happens here: the LLC changes over time, indicating that something is still changing in the model even though the training loss is already close to zero!

Of course in this case, we knew what to test for (namely the accuracy of $G_1$). But imagine we were training the model without this knowledge. Then we would just see (almost) $0$ training loss, but not that something is happening in the background!

The measurements are still quite noisy. Better calibration (running more fine-grained sweeps over the LLC estimation hyperparameters) and more accurate measurements (sampling more chains from the local loss landscape) might give a clearer picture.

### 6.2 Grokking and local learning coefficient measure in the ambiguous case (as in 4.3)



Let's again look at the losses first:
<img src=http://cogeometry.com/assets/Grokking_12.png>

There is a strange slingshot behaviour: although $G_2$ starts off better, in the end $G_1$ is grokked.
Let's look at the training run again (open the gif in a new tab to see the run again):

<img src=http://cogeometry.com/assets/Grokking_12.gif>

We can see that $G_2$ is preferred at initialization: in the first frame, at epoch $0$, there are more green points. But eventually $G_1$ wins out, as the losses above also indicate.

Now let's look at the local learning coefficient and the accuracy of $G_1$.

<img src=http://cogeometry.com/assets/llc_grokking_12.png>

This is slightly odd: it seems like we are grokking $G_1$ twice.

First around step $5$ and then again around step $20$ (corresponding to epochs $285$ and $1140$).

Also, while the change at step $5$ seems to be picked up by the LLC, the change at step $20$ goes unnoticed.


## 7. Further steps

**Next steps which are already in progress:**

- Run more experiments to get more reliable data on outcomes
- Run experiments where there are more than $2$ group structures that can be learned

**Beyond that**:


- Try to solve the mystery of the intermediate grokking and ungrokking
- Start measuring the formation of circuits (and their size)