## Progress Measures for Grokking on Real-world Tasks

This is the code for the main results of the ICML workshop paper.
Paper link: https://arxiv.org/abs/2405.12755

In [1]:
from collections import defaultdict
from itertools import islice
import random
import time
from pathlib import Path
import math
from IPython.display import clear_output
import numpy as np
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.distributions import Categorical
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

  warn(


In [2]:
config = {
    'train_points': 1000,
    'test_points': 1000,
    'optimization_steps': 100000,
    'batch_size': 1000,
    'loss_function': nn.MSELoss,
    'optimizer': torch.optim.AdamW,
    # 'weight_decay': 0.01,
    'weight_decay': 0.01,
    'lr': 1e-3,
    # 'lr': 1e-2,
    # 'initialization_scale': 8.0,
    'initialization_scale': 8.0,
    'download_directory': ".",
    # 'depth': 3,
    'depth': 3,
    # 'width': 200,
    'width': 200,
    'activation': nn.ReLU,
    'log_freq': math.ceil(100000 / 150),
    'device': torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    'dtype': torch.float64
}

In [3]:
def compute_accuracy(network, x, y):
    """Return the fraction of rows of `x` that `network` classifies as `y`.

    Args:
        network: callable mapping `x` to a 2-D tensor of logits (one row per sample).
        x: batch of inputs; `x.size(0)` is used as the number of samples.
        y: tensor of integer class labels, compared element-wise with the arg-max.

    Returns:
        The accuracy as a Python float.
    """
    with torch.no_grad():
        predictions = network(x).argmax(dim=1)
        n_correct = (predictions == y).sum()
        accuracy = n_correct / x.size(0)
    return accuracy.item()

def compute_loss(network, x, y, config):
    """Return the mean loss of `network` on the batch `(x, y)`.

    The loss class in `config['loss_function']` is instantiated with its
    default arguments and applied to the logits and the one-hot encoded
    labels, exactly as the training loop does.  The previous version divided
    that value by the batch size a second time, although the default
    reduction of the loss already averages, so the logged losses were too
    small by a factor of `x.size(0)`.

    Args:
        network: callable mapping `x` to a tensor of 10 logits per sample.
        x: batch of inputs.
        y: tensor of integer labels in [0, 10), used to index a 10x10 identity.
        config: dict providing 'loss_function' (a loss class) and 'device'.

    Returns:
        The loss as a Python float.
    """
    with torch.no_grad():
        criterion = config['loss_function']()
        # Rows of the identity matrix are the one-hot targets (10 classes).
        one_hots = torch.eye(10, 10).to(config['device'])
        logits = network(x)
        return criterion(logits, one_hots[y]).item()

In [4]:
def data(config, seed=0):
    """Seed the RNGs and return one batch of MNIST train and test data.

    Args:
        config: dict providing 'download_directory', 'train_points',
            'test_points', 'batch_size' and 'device'.
        seed: value passed to the torch, CUDA, `random` and NumPy seeders.

    Returns:
        `(train_x, train_y, test_x, test_y)`, all moved to `config['device']`.
        Only the first batch of each (shuffled) loader is returned; with the
        notebook's config the batch size equals the subset size, so that is
        the whole subset.
    """
    # Same seeding order as before: torch, CUDA, random, NumPy.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, random.seed, np.random.seed):
        seed_fn(seed)

    root = config['download_directory']

    # load dataset
    train_set = torchvision.datasets.MNIST(
        root=root, train=True, transform=torchvision.transforms.ToTensor(), download=True)
    test_set = torchvision.datasets.MNIST(
        root=root, train=False, transform=torchvision.transforms.ToTensor(), download=True)

    def _first_batch(dataset, n_points):
        # Restrict to the first `n_points` samples and draw one shuffled batch.
        subset = torch.utils.data.Subset(dataset, range(n_points))
        loader = torch.utils.data.DataLoader(subset, batch_size=config['batch_size'], shuffle=True)
        return next(iter(loader))

    # Train batch is drawn before the test batch (keeps the RNG stream unchanged).
    train_x, train_y = _first_batch(train_set, config['train_points'])
    test_x, test_y = _first_batch(test_set, config['test_points'])

    device = config['device']
    return train_x.to(device), train_y.to(device), test_x.to(device), test_y.to(device)

In [5]:
class Model(nn.Module):
    """Fully connected network on flattened 784-dimensional inputs.

    `config['depth']` Linear layers are created; every Linear except the
    output layer is followed by `config['activation']()`.  Hidden layers have
    `config['width']` units and the output layer has 10.  With depth 1 the
    single layer is built as a first layer (784 -> width plus activation), so
    no 10-unit output layer exists in that case.
    """

    def __init__(self, config):
        super().__init__()
        self.flatten = nn.Flatten()
        self.layers = nn.ModuleList()
        # Defined but not applied in forward(); see the note there.
        self.dropout = nn.Dropout(p=0.3)

        depth, width = config['depth'], config['width']
        for idx in range(depth):
            is_first = idx == 0
            if not is_first and idx == depth - 1:
                # Output layer: no activation afterwards.
                self.layers.append(nn.Linear(width, 10))
                continue
            in_features = 784 if is_first else width
            self.layers.append(nn.Linear(in_features, width))
            self.layers.append(config['activation']())

    def forward(self, x):
        """Flatten `x` and pass it through every entry of `self.layers` in order."""
        out = self.flatten(x)
        # Dropout before each Linear layer was tried at some point and is
        # currently disabled; `self.dropout` is therefore unused here.
        for module in self.layers:
            out = module(out)
        return out

def create_model(config):
    """Build a `Model` on `config['device']` with all parameters rescaled.

    Every parameter (weights and biases) of the freshly initialised model is
    multiplied by `config['initialization_scale']`.

    Returns:
        The rescaled model.
    """
    scale = config['initialization_scale']
    model = Model(config).to(config['device'])
    with torch.no_grad():
        for param in model.parameters():
            param.data = scale * param.data
    return model

In [6]:
def dynamic_plot_v2(config, log):
    """Build a 1x5 Plotly figure of accuracies and the four progress measures.

    Columns, left to right: train/test accuracy, weight norm, activation
    sparsity, weight entropy and local circuit complexity, all against the
    logged optimisation step on a log-scaled x-axis.

    Args:
        config: accepted for signature compatibility; not read here.
        log: dict with 'log_steps', 'train_accuracies', 'test_accuracies'
            and the tensor lists 'norms', 'sparsities', 'weights_entropies',
            'greedy_circuit_complexities'.

    Returns:
        The Plotly figure (not shown).
    """
    def _to_numpy(tensors):
        # Logged measures are torch tensors, possibly on GPU / attached to a graph.
        return np.array([t.cpu().detach() for t in tensors])

    log_steps = np.array(log['log_steps'])

    # (subplot column, legend name, line colour, y values) in drawing order.
    curves = [
        (1, 'Train Accuracy', 'red', log['train_accuracies']),
        (1, 'Test Accuracy', 'green', log['test_accuracies']),
        (2, 'Weight Norm (L2)', 'purple', _to_numpy(log['norms'])),
        (3, 'Activation Sparsity', 'blue', _to_numpy(log['sparsities'])),
        (4, 'Weight Entropy', 'orange', _to_numpy(log['weights_entropies'])),
        (5, 'Local Circuit Complexity', 'black', _to_numpy(log['greedy_circuit_complexities'])),
    ]

    # One row, five columns (subplot titles intentionally omitted).
    fig = make_subplots(rows=1, cols=5)
    for column, label, colour, values in curves:
        trace = go.Scatter(opacity=1, line_smoothing=0.8, x=log_steps, y=values,
                           mode='lines', name=label, line=dict(width=4, color=colour))
        fig.add_trace(trace, row=1, col=column)

    # Overall size, white background and legend fonts.
    fig.update_layout(autosize=False, width=1250, height=350,
                      plot_bgcolor='rgba(255, 255, 255, 1)',
                      legend=dict(font=dict(size=15, color="black")),
                      legend_title=dict(font=dict(size=20, color="blue")))

    # Horizontal legend, right-aligned, anchored just above the plotting area (y=1.02).
    fig.update_layout(legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1))

    # Log-scaled, labelled x-axes and light grid lines on every subplot.
    for column in range(1, 6):
        fig.update_xaxes(title_text="Optimization Steps", type='log', row=1, col=column)
        fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', row=1, col=column)
        fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGray', row=1, col=column)

    # Shade the step range 15000-100000 in green; blue/red bands for earlier
    # ranges were used in other versions of the figure and are disabled.
    fig.add_vrect(
        x0=15000,
        x1=100000,
        fillcolor="green",
        opacity=0.2,
        line_width=0,
    )

    # Legend-only marker (no data) that labels the green band.
    band_marker = dict(size=20, symbol='square', opacity=0.2, color='green')
    fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers', marker=band_marker,
                             showlegend=True, name='Generalization', legendgroup='Generalization'))

    return fig


In [7]:
def dynamic_plot(config, log):
    """Build a single-panel Plotly figure overlaying accuracies and measures.

    Train/test accuracy go on the primary y-axis; weight norm, activation
    sparsity, weight entropy and local circuit complexity are each min-max
    rescaled to roughly [0, 1] and drawn on the secondary y-axis.

    Args:
        config: dict; only 'initialization_scale' is read (for the title).
        log: dict with 'log_steps', 'train_accuracies', 'test_accuracies'
            and the tensor lists 'norms', 'sparsities', 'weights_entropies',
            'greedy_circuit_complexities'.

    Returns:
        The Plotly figure (not shown).
    """
    def _to_numpy(tensors):
        return np.array([t.cpu().detach() for t in tensors])

    def _rescale(values):
        # Min-max scaling; the 1e-8 guards against a zero range.
        lowest, highest = min(values), max(values)
        return (values - lowest) / (highest - lowest + 1e-8)

    fig = make_subplots(specs=[[{"secondary_y": True}]])

    # first, scale everything from 0 to 1 after converting to numpy
    log_steps = np.array(log['log_steps'])
    raw_norms = _to_numpy(log['norms'])
    raw_sparsities = _to_numpy(log['sparsities'])
    raw_entropies = _to_numpy(log['weights_entropies'])
    raw_complexities = _to_numpy(log['greedy_circuit_complexities'])

    # (legend name, colour, y values) for each axis, in drawing order.
    accuracy_curves = [
        ('train accuracy', 'red', log['train_accuracies']),
        ('test accuracy', 'green', log['test_accuracies']),
    ]
    measure_curves = [
        ('weight norm (L2)', 'purple', _rescale(raw_norms)),
        ('activation sparsity', 'blue', _rescale(raw_sparsities)),
        ('weight entropy', 'orange', _rescale(raw_entropies)),
        ('local circuit complexity', 'black', _rescale(raw_complexities)),
    ]

    for label, colour, values in accuracy_curves:
        fig.add_trace(go.Scatter(opacity=1, line_smoothing=0.8, x=log_steps, y=values, mode='lines',
                                 name=label, line=dict(width=4, color=colour)),
                      secondary_y=False)
    for label, colour, values in measure_curves:
        fig.add_trace(go.Scatter(opacity=0.5, line_smoothing=0.8, x=log_steps, y=values, mode='lines',
                                 name=label, line=dict(width=2, color=colour)),
                      secondary_y=True)

    fig.update_layout(title_text=f"depth-3 width-200 ReLU MLP on MNIST\nUnconstrained Optimization α = {config['initialization_scale']}")

    fig.update_xaxes(title_text="Optimization Steps", type='log')
    fig.update_yaxes(title_text="<b>Accuracy</b>", secondary_y=False)

    # White background with light grid lines.
    fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)'})
    grid_style = dict(showgrid=True, gridwidth=1, gridcolor='LightGray')
    fig.update_layout(xaxis=dict(grid_style), yaxis=dict(grid_style))

    fig.update_layout(autosize=False, width=1000, height=600)

    # Applied to every trace, so the width=2 set on the measure curves above is overridden.
    fig.update_traces(line=dict(width=4))

    # font size
    fig.update_layout(legend=dict(font=dict(size=20, color="black")),
                      legend_title=dict(font=dict(size=20, color="blue")))

    return fig

In [8]:
def model_pruner(base_model, kappa, config):
    """Return a new model whose weight tensors are a randomly dropped copy of `base_model`'s.

    Every state-dict entry whose key contains 'weight' is passed through
    `F.dropout(..., p=kappa)`; all other entries (e.g. biases) are copied
    unchanged.  `base_model` itself is not modified.

    Args:
        base_model: model to prune.
        kappa: dropout probability applied to the weights.
        config: passed to `create_model` to build the container model.

    Returns:
        A freshly created model loaded with the pruned state.
    """
    # Build the container first so RNG consumption stays in the same order.
    edited_model = create_model(config)
    pruned_state = base_model.state_dict().copy()
    weight_keys = [name for name in pruned_state if 'weight' in name]
    for name in weight_keys:
        pruned_state[name] = F.dropout(pruned_state[name], p=kappa)
    edited_model.load_state_dict(pruned_state)
    return edited_model

def logit_kl_div(model1, model2, x_train, config):
    """Return the mean KL divergence KL(p1 || p2) between two models' predictions.

    `p1` and `p2` are the row-wise softmax outputs of `model1` and `model2`
    on `x_train`, each shifted by 1e-8 so the logarithms stay finite.

    Args:
        model1, model2: callables returning logits of shape (batch, classes).
        x_train: batch fed to both models.
        config: unused; kept for signature compatibility.

    Returns:
        A scalar tensor: the per-sample KL divergence averaged over the batch.
    """
    # Neither model is switched to eval mode here.
    eps = 1e-8
    p1 = torch.nn.functional.softmax(model1(x_train), dim=1) + eps
    p2 = torch.nn.functional.softmax(model2(x_train), dim=1) + eps
    log_ratio = torch.log(p1) - torch.log(p2)
    per_sample = torch.sum(p1 * log_ratio, dim=1)
    return torch.mean(per_sample)

In [9]:
def grokking_measure(log):
    """Compute, store and return a delay-based grokking score.

    The score is the gap between the first log entry at which test accuracy
    is no longer below 0.8 and the first at which train accuracy is no longer
    below 0.9, divided by the number of log entries.  It is 0 when either
    threshold is never reached.  Both cut-off indices are printed, and the
    score is also written to `log['grokking']`.
    """
    def _first_not_below(values, threshold):
        # Index of the first entry that is not < threshold, or len(values) if none.
        return next((idx for idx, value in enumerate(values) if not value < threshold), len(values))

    steps = len(log['train_accuracies'])
    train_cutoff = _first_not_below(log['train_accuracies'][:steps], 0.9)
    test_cutoff = _first_not_below(log['test_accuracies'][:steps], 0.8)
    print(train_cutoff, test_cutoff)

    reached_both = train_cutoff < steps and test_cutoff < steps
    log['grokking'] = (test_cutoff - train_cutoff) / steps if reached_both else 0  # 0: no grokking
    return log['grokking']

In [10]:
def progress_measures(
        model,
        x_train,
        layer_idx_sparsity=2,
        layer_idx_entropy=2,
        tau_sparsity=1,
        config=None,
    ):
    """Compute the progress measures of `model` on `x_train`.

    Args:
        model: network exposing `layers`; inputs are reshaped to (-1, 784).
        x_train: batch of inputs.
        layer_idx_sparsity: index into `model.layers` whose output is used
            for the sparsity measures (must also have a `.weight`).
        layer_idx_entropy: index into `model.layers` whose weight is used
            for the entropy measure.
        tau_sparsity: threshold below which an activation counts as inactive.
        config: passed to `create_model` / `model_pruner`.
            NOTE(review): defaults to None, but those helpers index into it,
            so callers effectively have to pass a real config.

    Returns:
        `(activations, sparsity, diff_sparsity, weights_entropy,
        greedy_circuit_complexity)`: the captured layer output (on CPU), the
        hard fraction of activations below tau (float), its sigmoid-smoothed
        counterpart (tensor), `-sum(|w| log|w|)` of the chosen layer (tensor)
        and the mean KL divergence between `model` and a copy with half of
        its weights dropped (tensor).
    """
    captured = []

    hidden = x_train.clone().reshape(-1, 784)
    for idx, module in enumerate(model.layers):
        hidden = module(hidden)
        if idx != layer_idx_sparsity:
            continue
        captured.append(hidden.cpu())
        # sparsity is the fraction of activations that are less than tau
        sparsity = (hidden < tau_sparsity).float().mean().item()
        # Smooth version: sigmoid of the signed distance to tau; beta sets the sharpness.
        beta = 0.3
        diff_sparsity = torch.sigmoid(-(hidden - tau_sparsity) / beta).mean()

    # weight sparsity (computed but not returned)
    sparsity_layer_weights = model.layers[layer_idx_sparsity].weight
    weights_sparsity = torch.mean((sparsity_layer_weights < tau_sparsity).float())

    # Entropy is evaluated on a copy of the model's weights.
    cloned_model = create_model(config)
    cloned_model.load_state_dict(model.state_dict().copy())
    weight_magnitudes = torch.abs(cloned_model.layers[layer_idx_entropy].weight)
    weights_entropy = -torch.sum(weight_magnitudes * torch.log(weight_magnitudes))

    # Output shift caused by randomly dropping half of the weights.
    pruned_model = model_pruner(model, 0.5, config)
    greedy_circuit_complexity = logit_kl_div(model, pruned_model, x_train, config)

    activations = torch.cat(captured, dim=1)

    return activations, sparsity, diff_sparsity, weights_entropy, greedy_circuit_complexity

In [11]:
def smooth_sparsity(model, x_train, layer_idx, tau):
    """Return a differentiable sparsity score summed over all Linear layers.

    The input is reshaped to (-1, 784) and passed through `model.layers`.
    After every `nn.Linear`, the mean of `sigmoid(-(output - tau) / 0.1)` is
    added to the score, i.e. a soft fraction of outputs below `tau`.

    Args:
        model: object exposing an iterable `layers`.
        x_train: batch of inputs.
        layer_idx: unused; kept for signature compatibility.
        tau: activation threshold.

    Returns:
        The accumulated score (the integer 0 if there is no Linear layer).
    """
    beta = 0.1  # sharpness of the soft comparison
    hidden = x_train.clone().reshape(-1, 784)
    diff_sparsity = 0
    for module in model.layers:
        hidden = module(hidden)
        if not isinstance(module, nn.Linear):
            continue
        diff_sparsity += torch.mean(torch.sigmoid(-(hidden - tau) / beta))
    return diff_sparsity

def smooth_weight_sparsity(model, layer_idx, tau):
    """Return a differentiable fraction of weights below `tau` in one layer.

    Computes the mean of `sigmoid(-(w - tau) / 0.1)` over the weight of
    `model.layers[layer_idx]`.
    """
    beta = 0.1  # sharpness of the soft comparison
    weights = model.layers[layer_idx].weight
    soft_below_tau = torch.sigmoid(-(weights - tau) / beta)
    return soft_below_tau.mean()

In [12]:
def smooth_entropy(model, x_train, layer_idx):
    """Return `-sum(|w| * log|w|)` accumulated over every Linear layer's weight.

    Changes from the previous version:
      * `torch.xlogy` is used so that a weight that is exactly zero
        contributes 0 instead of turning the whole sum into NaN (0 * log 0).
      * The forward pass over `x_train` was removed: its result was never
        used, and only the layer weights enter the measure.

    Args:
        model: object exposing an iterable `layers`.
        x_train: unused; kept for signature compatibility.
        layer_idx: unused; kept for signature compatibility.

    Returns:
        The accumulated entropy-like score (the integer 0 if there is no
        Linear layer).
    """
    weights_entropy = 0
    for layer in model.layers:
        if isinstance(layer, nn.Linear):
            magnitude = torch.abs(layer.weight)
            # xlogy(x, y) = x * log(y), defined as 0 where x == 0.
            weights_entropy += -torch.sum(torch.xlogy(magnitude, magnitude))
    return weights_entropy

In [13]:
def classwise_accuracy(model, x, y):
    """Return the accuracy of `model` on `(x, y)` separately for labels 0..9.

    Args:
        model: callable returning logits of shape (batch, classes).
        x: batch of inputs.
        y: tensor of integer labels.

    Returns:
        A list of 10 floats; entry `c` is the accuracy on the samples whose
        label is `c`.
    """
    per_class = []
    for label in range(10):
        mask = (y == label)
        inputs, targets = x[mask], y[mask]
        predictions = torch.argmax(model(inputs), dim=1)
        hits = torch.sum(predictions == targets)
        per_class.append((hits / inputs.size(0)).item())
    return per_class

In [14]:
def logger(log, steps, model, train_x, train_y, test_x, test_y, config):
    """Append the current losses, accuracies, norms and progress measures to `log`.

    All tensors are detached before being stored.  Previously the logged
    norm, sparsity, entropy and circuit-complexity tensors kept their
    `grad_fn`, so every log entry held on to an autograd graph (and, for the
    norm, to the temporary model copy) for the rest of the run.

    Args:
        log: dict of lists to append to (see the notebook's `log` definition).
        steps: optimisation step recorded in `log['log_steps']`.
        model: network being trained.
        train_x, train_y, test_x, test_y: evaluation batches.
        config: experiment configuration.

    Returns:
        The same `log` dict, for convenience.
    """
    log['train_losses'].append(compute_loss(model, train_x, train_y, config))
    log['train_accuracies'].append(compute_accuracy(model, train_x, train_y))
    log['test_losses'].append(compute_loss(model, test_x, test_y, config))
    log['test_accuracies'].append(compute_accuracy(model, test_x, test_y))
    log['log_steps'].append(steps)

    # L2 norm over all parameters, computed on a copy of the model.  The copy
    # is kept on purpose: create_model draws random initial weights, so
    # removing it would shift the global RNG stream used by the dropout in
    # model_pruner further down.
    temp_model = create_model(config)
    temp_model.load_state_dict(model.state_dict().copy())
    total = sum(torch.pow(p, 2).sum() for p in temp_model.parameters())
    log['norms'].append(torch.sqrt(total).detach())

    # L2 norm of the parameters of the last entry of model.layers.
    last_layer = sum(torch.pow(p, 2).sum() for p in model.layers[-1].parameters())
    log['last_layer_norms'].append(float(np.sqrt(last_layer.item())))

    activations, sparsity, diff_sparsity, weights_entropy, greedy_circuit_complexity = progress_measures(model, train_x, config=config)
    # The smoothed sparsity (not the hard fraction) is what gets logged.
    log['sparsities'].append(diff_sparsity.detach())
    log['weights_entropies'].append(weights_entropy.detach())
    log['greedy_circuit_complexities'].append(greedy_circuit_complexity.detach())
    return log

In [17]:
# Build the model, load the data and create the optimizer and bookkeeping
# state used by the training loop.
# NOTE(review): data() seeds the RNGs, but the model is created before it is
# called, so the initial weights are not covered by that seed.
model = create_model(config)
train_x, train_y, test_x, test_y = data(config)
optimizer = config['optimizer'](model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
steps = 0
# Rows of the identity matrix serve as one-hot targets for the MSE loss.
one_hots = torch.eye(10, 10).to(config['device'])
# No trailing commas on the next two lines: they used to turn the path into a
# 1-tuple (so `checkpoints_path + "..."` raised TypeError) and the step list
# into a tuple containing a list (so `steps in saving_steps` was never true).
checkpoints_path = '/home/b-sgolechha/grokking/model_checkpoints/'
saving_steps = [20, 2000, 20000, 90000]
save_models = False

# One list per logged quantity; filled by logger() and the training loop.
log = {
    'train_losses': [],
    'test_losses': [],
    'train_accuracies': [],
    'test_accuracies': [],
    'norms': [],
    'last_layer_norms': [],
    'log_steps': [],
    'sparsities': [],
    'weights_entropies': [],
    'classwise_accuracies': [],
    'greedy_circuit_complexities': [],
    # Placeholder (the type object itself); overwritten by grokking_measure().
    'grokking': float,
}

In [18]:
# Full-batch training: every step uses the whole of train_x.
for i in range(config['optimization_steps']):
    # Optional checkpoint at the hand-picked steps.
    if save_models and steps in saving_steps:
        torch.save(model.state_dict(), checkpoints_path + f"mnist_model_{steps}.pt")

    # Log every step for the first 30 steps, every 10th step up to 150,
    # and every `log_freq` steps throughout.
    is_log_step = (
        steps < 30
        or (steps < 150 and steps % 10 == 0)
        or steps % config['log_freq'] == 0
    )
    if is_log_step:
        log = logger(log, steps, model, train_x, train_y, test_x, test_y, config)
        clear_output(wait=True)
        dynamic_plot_v2(config, log).show()
        log['classwise_accuracies'].append(classwise_accuracy(model, train_x, train_y))

    optimizer.zero_grad()
    logits = model(train_x)
    loss = config['loss_function']()(logits, one_hots[train_y])
    # Sum of squared weights of model.layers[0], [2] and [4] (the Linear
    # layers for the depth-3 model; the indices are hard-coded).
    norm = sum(torch.sum(torch.pow(model.layers[idx].weight, 2)) for idx in (0, 2, 4))
    # The norm term is subtracted, i.e. larger weights lower the loss;
    # setting coeff to 0 disables the term.
    coeff = 0.00000000018
    loss -= coeff * norm
    loss.backward()
    optimizer.step()
    steps += 1

In [17]:
# grokking_measure prints the train/test cut-off indices itself; the value
# printed here is the resulting score (0 means "no grokking").
print(grokking_measure(log))

43 191
0


In [43]:
# Render the five-panel figure from the final log and export it as a PDF.
fig = dynamic_plot_v2(config, log)
# fig.show()
path = './plots/dropout_mnist.pdf'
# Nothing else in the notebook creates ./plots, so make sure it exists
# before writing the image.
Path(path).parent.mkdir(parents=True, exist_ok=True)
fig.write_image(path)

In [None]:
# Compare the first and last five logged (smoothed) sparsity values.
print(log['sparsities'][:5], log['sparsities'][-5:])

[tensor(0.5792, device='cuda:0', grad_fn=<MeanBackward0>), tensor(0.5863, device='cuda:0', grad_fn=<MeanBackward0>), tensor(0.5940, device='cuda:0', grad_fn=<MeanBackward0>), tensor(0.6017, device='cuda:0', grad_fn=<MeanBackward0>), tensor(0.6092, device='cuda:0', grad_fn=<MeanBackward0>)] [tensor(0.9409, device='cuda:0', grad_fn=<MeanBackward0>), tensor(0.9417, device='cuda:0', grad_fn=<MeanBackward0>), tensor(0.9426, device='cuda:0', grad_fn=<MeanBackward0>), tensor(0.9423, device='cuda:0', grad_fn=<MeanBackward0>), tensor(0.9432, device='cuda:0', grad_fn=<MeanBackward0>)]


In [None]:
## RQ: Can we induce grokking by using sparsity losses without large inits and weight decay?

In [None]:
# Average absolute parameter value of the trained model:
# (sum of |p| over all parameters) / (total number of parameters).
total_params = sum(p.numel() for p in model.parameters())
total_weights = sum(p.abs().sum() for p in model.parameters())
total_weights / total_params

tensor(4.7822, device='cuda:0', grad_fn=<DivBackward0>)

In [None]:
# Same statistic for a freshly initialised model, as a baseline.
# `total_params` is reused from the previous cell.
new_model = create_model(config)
total_weights = sum(p.abs().sum() for p in new_model.parameters())
total_weights / total_params

tensor(0.1724, device='cuda:0', grad_fn=<DivBackward0>)

In [None]:
def plot_classwise_accuracies(classwise_accuracies):
    """Show one accuracy curve per class (0-9) over the logged entries.

    Args:
        classwise_accuracies: sequence of per-class accuracy lists, one list
            of 10 values per log entry; the x-axis is the entry index.

    NOTE(review): the title says "Test Set", but the visible training loop
    fills `log['classwise_accuracies']` from the training batch — confirm
    which one is intended.
    """
    entry_indices = list(range(len(classwise_accuracies)))

    # plot each class with different color and plot after each epoch
    fig = go.Figure()
    for label in range(10):
        class_curve = [entry[label] for entry in classwise_accuracies]
        fig.add_trace(go.Scatter(x=entry_indices, y=class_curve, mode='lines', name=f'{label}'))

    fig.update_layout(title_text='Classwise Accuracies (Test Set)')
    fig.update_xaxes(title_text='Epochs')
    fig.update_yaxes(title_text='Accuracy')

    # white background
    fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)'})

    # fine grids
    grid_style = dict(showgrid=True, gridwidth=1, gridcolor='LightGray')
    fig.update_layout(xaxis=dict(grid_style), yaxis=dict(grid_style))

    # fixed 800x600 canvas
    fig.update_layout(autosize=False, width=800, height=600)
    fig.show()

In [None]:
# Per-class accuracy curves collected by the training loop at every log step.
plot_classwise_accuracies(log['classwise_accuracies'])

In [None]:
# can we train a model to perform on the test set even at high norms?

In [None]:
def eval_hacky_model(model, test_x, test_y):
    """Return a one-line summary of a model's parameter norm and accuracy.

    Args:
        model: network whose parameters and predictions are evaluated.
        test_x, test_y: batch on which the accuracy is computed.

    Returns:
        A string with the L2 norm over all parameters (2 decimals) and the
        accuracy (3 decimals).
    """
    squared_norm = sum(torch.pow(p, 2).sum() for p in model.parameters())
    l2 = torch.sqrt(squared_norm).item()
    acc = compute_accuracy(model, test_x, test_y)
    return f"Norm: {round(l2, 2)}, Accuracy: {round(acc, 3)}"

In [None]:
# Norm / accuracy summary of the trained model on the test batch.
eval_hacky_model(model, test_x, test_y)

'Norm: 1025.71, Accuracy: 0.834'

In [None]:
def hacky_model(model):
    """Return a new model carrying a copy of `model`'s parameters.

    The container is built with `create_model` from the notebook-level
    `config`, then loaded with `model`'s state dict.
    """
    clone = create_model(config)
    source_state = model.state_dict().copy()
    clone.load_state_dict(source_state)
    return clone

In [None]:
# Train a new model to have the same logits as the old (trained) model on the
# training batch, while rewarding large weight norms.
new_model = create_model(config)
old_model_copy = create_model(config)
old_model_copy.load_state_dict(model.state_dict().copy())
new_optim = config['optimizer'](new_model.parameters(), lr=config['lr']*10)

# The old model is not updated in this loop and its forward pass is
# deterministic, so its logits are computed once, without autograd.
# Previously they were recomputed on every step with gradients enabled, which
# also accumulated never-cleared gradients into `model`'s parameters.
with torch.no_grad():
    logits = model(train_x)

for i in range(20000):
    new_logits = new_model(train_x)
    new_loss = torch.nn.functional.mse_loss(logits, new_logits)
    # "Negative weight decay": the L2 norms of the three Linear weight
    # matrices are subtracted from the loss, with a coefficient that decays
    # as 1 / (1 + 0.001 * i).
    nwd = 0.0001 / (1 + 0.001 * i)
    new_loss -= nwd * new_model.layers[0].weight.norm(2)
    new_loss -= nwd * new_model.layers[2].weight.norm(2)
    new_loss -= nwd * new_model.layers[4].weight.norm(2)
    new_model.zero_grad()
    new_loss.backward()
    new_optim.step()
    if i % 1000 == 0:
        print(f"Step {i}: {new_loss.item()},{eval_hacky_model(new_model, train_x, train_y)},{eval_hacky_model(new_model, test_x, test_y)}")

Step 0: 306.3731689453125,Norm: 94.32, Accuracy: 0.064,Norm: 94.32, Accuracy: 0.053


Step 1000: -0.05345746502280235,Norm: 1703.95, Accuracy: 0.863,Norm: 1703.95, Accuracy: 0.433
Step 2000: -0.08384177833795547,Norm: 2875.98, Accuracy: 0.955,Norm: 2875.98, Accuracy: 0.541
Step 3000: -0.09305138885974884,Norm: 3863.72, Accuracy: 0.983,Norm: 3863.72, Accuracy: 0.588
Step 4000: -0.09603177756071091,Norm: 4761.36, Accuracy: 0.994,Norm: 4761.36, Accuracy: 0.629
Step 5000: -0.09665238112211227,Norm: 5606.58, Accuracy: 0.998,Norm: 5606.58, Accuracy: 0.64
Step 6000: -0.09623942524194717,Norm: 6409.67, Accuracy: 0.999,Norm: 6409.67, Accuracy: 0.647
Step 7000: -0.09498421102762222,Norm: 7167.63, Accuracy: 1.0,Norm: 7167.63, Accuracy: 0.657
Step 8000: -0.0931207537651062,Norm: 7873.8, Accuracy: 1.0,Norm: 7873.8, Accuracy: 0.644
Step 9000: -0.09128343313932419,Norm: 8525.03, Accuracy: 1.0,Norm: 8525.03, Accuracy: 0.642
Step 10000: -0.0891101211309433,Norm: 9118.96, Accuracy: 1.0,Norm: 9118.96, Accuracy: 0.649
Step 11000: -0.08679385483264923,Norm: 9655.28, Accuracy: 1.0,Norm: 9655

In [None]:
# prediction of this new model on a sample train point
# (the first training image is reshaped into a single row before the forward pass)
logits = new_model(train_x[0].reshape(1, -1))
print(logits)
print(f'Correct: {train_y[0]}')
# train loss on this point
# NOTE(review): the statement below looks truncated in this export — a bare
# `l` is not defined anywhere in the visible source; restore the original
# line before re-running this cell.
l

tensor([[0.2179, 0.1112, 0.0868, 0.0335, 0.1071, 0.2396, 0.1821, 0.1154, 0.1765,
         0.1897]], device='cuda:0', grad_fn=<AddmmBackward0>)
Correct: 9


This should be good enough for now.

Starting something new at `information_theory/si_score_cnn_sae.ipynb`.