# Hierarchical equality

Reproducing the hierarchical equality experiment in [Finding Alignments Between Interpretable Causal Variables and Distributed Neural Representations](https://arxiv.org/pdf/2303.02536.pdf).

In [13]:
FORCE_CPU = False
SEED = 2384

HIDDEN_SIZE = 16
TASK_INPUT_SIZE = 1
TASK_TRAIN_SIZE = 2**22
TASK_TEST_SIZE = 2**12
TASK_BATCH_SIZE = 2048
TASK_LR = 1e-3
TASK_EPOCHS = 100

SAVE_MODEL = True
LOAD_MODEL = False
MODEL_PATH = "saved_models/hierarchical-equality.pt"
MODEL_TRAIN_DETAILS_PATH = "saved_models/hierarchical-equality-train-details.pickle"

## Setup

In [2]:
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/SamAdamDay/mechanistic-interpretability-projects.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

import plotly.io as pio
pio.renderers.default = "colab+vscode"

Running as a Jupyter notebook
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [14]:
from typing import Tuple
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import numpy as np
from numpy.typing import NDArray

from tqdm import tqdm

import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

In [4]:
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f0f7cbac710>

In [5]:
if not FORCE_CPU and torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


## Train a model

In [6]:
# create a three layer MLP in torch
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.layer1 = nn.Linear(input_size, hidden_size, bias=True)
        self.layer2 = nn.Linear(hidden_size, hidden_size, bias=True)
        self.layer3 = nn.Linear(hidden_size, output_size, bias=True)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        x = self.softmax(self.layer3(x))
        return x

The dataset is divided equally between positive and negative classes. The positive instances are equally divided between:
- (A, A, B, B)
- (A, B, C, D)

The negative instances are equally divided between:
- (A, A, B, C)
- (A, B, C, C)

(There are more options but I chose to omit them)

In [7]:
def generate_data(size: int, input_size: int) -> Tuple[NDArray, NDArray]:

    X =  np.empty((size, input_size, 4))
    y = np.empty((size,))

    group_size = size // 4
    # (A, A, B, B)
    X[:group_size, :, 0] = X[:group_size, :, 1] = np.random.standard_normal((group_size, TASK_INPUT_SIZE))
    X[:group_size, :, 2] = X[:group_size, :, 3] = np.random.standard_normal((group_size, TASK_INPUT_SIZE))
    # (A, B, C, D)
    X[group_size:2*group_size, :, :] = np.random.standard_normal((group_size, TASK_INPUT_SIZE, 4))
    # (A, A, B, C)
    X[2*group_size:3*group_size, :, 0] = X[2*group_size:3*group_size, :, 1] = np.random.standard_normal((group_size, TASK_INPUT_SIZE))
    X[2*group_size:3*group_size, :, 2:4] = np.random.standard_normal((group_size, TASK_INPUT_SIZE, 2))
    # (A, B, C, C)
    X[3*group_size:, :, 0:2] = np.random.standard_normal((group_size, TASK_INPUT_SIZE, 2))
    X[3*group_size:, :, 2] = X[3*group_size:, :, 3] = np.random.standard_normal((group_size, TASK_INPUT_SIZE))

    y[:size // 2] = 1
    y[size // 2:] = 0

    return X, y

In [8]:
train_x, train_y = generate_data(TASK_TRAIN_SIZE, TASK_INPUT_SIZE)
test_x, test_y = generate_data(TASK_TEST_SIZE, TASK_INPUT_SIZE)

print(train_x.shape)
print(train_y.shape)
print(test_x.shape)
print(test_y.shape)

(4194304, 1, 4)
(4194304,)
(4096, 1, 4)
(4096,)


In [9]:
train_dataset = TensorDataset(torch.from_numpy(train_x).float(), torch.from_numpy(train_y).float())
train_dataloader = DataLoader(train_dataset, batch_size=TASK_BATCH_SIZE, shuffle=True)

test_dataset = TensorDataset(torch.from_numpy(test_x).float(), torch.from_numpy(test_y).float())
test_dataloader = DataLoader(test_dataset, batch_size=TASK_BATCH_SIZE, shuffle=True)

In [10]:
model = MLP(TASK_INPUT_SIZE * 4, HIDDEN_SIZE, 2)
model.to(device)

print(model)

MLP(
  (layer1): Linear(in_features=4, out_features=16, bias=True)
  (layer2): Linear(in_features=16, out_features=16, bias=True)
  (layer3): Linear(in_features=16, out_features=2, bias=True)
  (relu): ReLU()
  (softmax): Softmax(dim=1)
)


In [11]:
if LOAD_MODEL:

    model.load_state_dict(torch.load(MODEL_PATH))

    with open(MODEL_TRAIN_DETAILS_PATH, "rb") as f:
        dict = pickle.load(f)
        train_losses = dict["train_losses"]
        test_losses = dict["test_losses"]
        test_accuracies = dict["test_accuracies"]

else:

    optimizer = optim.SGD(model.parameters(), lr=TASK_LR)
    loss_fn = nn.CrossEntropyLoss()

    train_losses = np.empty((TASK_EPOCHS,))
    test_losses = np.empty((TASK_EPOCHS,))
    test_accuracies = np.empty((TASK_EPOCHS,))

    for epoch in range(TASK_EPOCHS):
        total_train_loss = 0.0
        for batch in tqdm(train_dataloader, desc=f"Training [{epoch+1}/{TASK_EPOCHS}]"):
            batch_x, batch_y = batch
            batch_x = batch_x.flatten(start_dim=1)
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            output = model(batch_x)
            loss = loss_fn(output, batch_y.long())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_dataloader)
        train_losses[epoch] = avg_train_loss

        with torch.no_grad():

            total_test_loss = 0.0
            total_test_accuracy = 0.0
            
            for batch in test_dataloader:
                batch_x, batch_y = batch
                batch_x = batch_x.flatten(start_dim=1)
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)

                output = model(batch_x)
                loss = loss_fn(output, batch_y.long())

                pred = torch.argmax(output, dim=1)
                total_test_accuracy += torch.sum(pred == batch_y.long()).item() / len(pred)

                total_test_loss += loss.item()

            avg_test_loss = total_test_loss / len(test_dataloader)
            test_losses[epoch] = avg_test_loss

            avg_test_accuracy = total_test_accuracy / len(test_dataloader)
            test_accuracies[epoch] = avg_test_accuracy

        print(f"Epoch [{epoch+1}/{TASK_EPOCHS}] train loss: {avg_train_loss:0.5f}, test acc: {avg_test_accuracy:0.5%}")

Training [1/100]:   0%|          | 0/2048 [00:00<?, ?it/s]

Training [1/100]: 100%|██████████| 2048/2048 [00:25<00:00, 81.69it/s] 


Epoch [0/100] train loss: 0.69513, test acc: 49.92676%


Training [2/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.56it/s] 


Epoch [1/100] train loss: 0.69414, test acc: 49.92676%


Training [3/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.10it/s] 


Epoch [2/100] train loss: 0.69362, test acc: 49.90234%


Training [4/100]: 100%|██████████| 2048/2048 [00:23<00:00, 85.88it/s] 


Epoch [3/100] train loss: 0.69334, test acc: 50.17090%


Training [5/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.59it/s] 


Epoch [4/100] train loss: 0.69318, test acc: 50.56152%


Training [6/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.19it/s] 


Epoch [5/100] train loss: 0.69308, test acc: 50.21973%


Training [7/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.79it/s] 


Epoch [6/100] train loss: 0.69300, test acc: 51.85547%


Training [8/100]: 100%|██████████| 2048/2048 [00:25<00:00, 80.20it/s]


Epoch [7/100] train loss: 0.69294, test acc: 51.39160%


Training [9/100]: 100%|██████████| 2048/2048 [00:25<00:00, 79.92it/s]


Epoch [8/100] train loss: 0.69288, test acc: 51.78223%


Training [10/100]: 100%|██████████| 2048/2048 [00:25<00:00, 79.25it/s]


Epoch [9/100] train loss: 0.69283, test acc: 51.73340%


Training [11/100]: 100%|██████████| 2048/2048 [00:25<00:00, 79.78it/s]


Epoch [10/100] train loss: 0.69277, test acc: 51.68457%


Training [12/100]: 100%|██████████| 2048/2048 [00:24<00:00, 82.15it/s]


Epoch [11/100] train loss: 0.69272, test acc: 51.73340%


Training [13/100]: 100%|██████████| 2048/2048 [00:24<00:00, 81.93it/s]


Epoch [12/100] train loss: 0.69266, test acc: 52.22168%


Training [14/100]: 100%|██████████| 2048/2048 [00:24<00:00, 82.04it/s]


Epoch [13/100] train loss: 0.69261, test acc: 52.34375%


Training [15/100]: 100%|██████████| 2048/2048 [00:24<00:00, 82.92it/s]


Epoch [14/100] train loss: 0.69255, test acc: 52.56348%


Training [16/100]: 100%|██████████| 2048/2048 [00:24<00:00, 82.11it/s]


Epoch [15/100] train loss: 0.69249, test acc: 52.85645%


Training [17/100]: 100%|██████████| 2048/2048 [00:24<00:00, 82.86it/s]


Epoch [16/100] train loss: 0.69242, test acc: 53.22266%


Training [18/100]: 100%|██████████| 2048/2048 [00:24<00:00, 83.12it/s] 


Epoch [17/100] train loss: 0.69236, test acc: 53.36914%


Training [19/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.08it/s]


Epoch [18/100] train loss: 0.69230, test acc: 53.61328%


Training [20/100]: 100%|██████████| 2048/2048 [00:24<00:00, 83.47it/s]


Epoch [19/100] train loss: 0.69223, test acc: 54.15039%


Training [21/100]: 100%|██████████| 2048/2048 [00:24<00:00, 83.13it/s]


Epoch [20/100] train loss: 0.69217, test acc: 54.24805%


Training [22/100]: 100%|██████████| 2048/2048 [00:24<00:00, 82.63it/s]


Epoch [21/100] train loss: 0.69211, test acc: 54.54102%


Training [23/100]: 100%|██████████| 2048/2048 [00:24<00:00, 82.99it/s]


Epoch [22/100] train loss: 0.69204, test acc: 54.85840%


Training [24/100]: 100%|██████████| 2048/2048 [00:24<00:00, 81.98it/s]


Epoch [23/100] train loss: 0.69198, test acc: 55.34668%


Training [25/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.27it/s] 


Epoch [24/100] train loss: 0.69192, test acc: 55.51758%


Training [26/100]: 100%|██████████| 2048/2048 [00:24<00:00, 83.73it/s] 


Epoch [25/100] train loss: 0.69186, test acc: 55.83496%


Training [27/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.54it/s] 


Epoch [26/100] train loss: 0.69179, test acc: 56.51855%


Training [28/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.42it/s] 


Epoch [27/100] train loss: 0.69173, test acc: 56.56738%


Training [29/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.13it/s] 


Epoch [28/100] train loss: 0.69166, test acc: 56.78711%


Training [30/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.30it/s] 


Epoch [29/100] train loss: 0.69159, test acc: 57.15332%


Training [31/100]: 100%|██████████| 2048/2048 [00:24<00:00, 85.15it/s] 


Epoch [30/100] train loss: 0.69152, test acc: 57.34863%


Training [32/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.02it/s] 


Epoch [31/100] train loss: 0.69145, test acc: 57.69043%


Training [33/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.40it/s] 


Epoch [32/100] train loss: 0.69138, test acc: 57.95898%


Training [34/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.52it/s] 


Epoch [33/100] train loss: 0.69130, test acc: 58.30078%


Training [35/100]: 100%|██████████| 2048/2048 [00:23<00:00, 88.21it/s] 


Epoch [34/100] train loss: 0.69122, test acc: 58.54492%


Training [36/100]: 100%|██████████| 2048/2048 [00:22<00:00, 89.05it/s] 


Epoch [35/100] train loss: 0.69114, test acc: 58.74023%


Training [37/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.99it/s] 


Epoch [36/100] train loss: 0.69105, test acc: 59.05762%


Training [38/100]: 100%|██████████| 2048/2048 [00:23<00:00, 85.92it/s] 


Epoch [37/100] train loss: 0.69096, test acc: 59.25293%


Training [39/100]: 100%|██████████| 2048/2048 [00:24<00:00, 85.20it/s]


Epoch [38/100] train loss: 0.69087, test acc: 59.35059%


Training [40/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.93it/s] 


Epoch [39/100] train loss: 0.69077, test acc: 59.42383%


Training [41/100]: 100%|██████████| 2048/2048 [00:23<00:00, 85.56it/s] 


Epoch [40/100] train loss: 0.69067, test acc: 59.54590%


Training [42/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.48it/s] 


Epoch [41/100] train loss: 0.69056, test acc: 59.57031%


Training [43/100]: 100%|██████████| 2048/2048 [00:22<00:00, 89.10it/s] 


Epoch [42/100] train loss: 0.69045, test acc: 59.57031%


Training [44/100]: 100%|██████████| 2048/2048 [00:22<00:00, 89.60it/s] 


Epoch [43/100] train loss: 0.69033, test acc: 59.54590%


Training [45/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.40it/s] 


Epoch [44/100] train loss: 0.69021, test acc: 59.69238%


Training [46/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.05it/s] 


Epoch [45/100] train loss: 0.69009, test acc: 59.64355%


Training [47/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.61it/s] 


Epoch [46/100] train loss: 0.68996, test acc: 59.49707%


Training [48/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.63it/s] 


Epoch [47/100] train loss: 0.68982, test acc: 59.91211%


Training [49/100]: 100%|██████████| 2048/2048 [00:22<00:00, 89.19it/s] 


Epoch [48/100] train loss: 0.68968, test acc: 59.93652%


Training [50/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.85it/s] 


Epoch [49/100] train loss: 0.68953, test acc: 59.74121%


Training [51/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.22it/s] 


Epoch [50/100] train loss: 0.68938, test acc: 59.64355%


Training [52/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.40it/s] 


Epoch [51/100] train loss: 0.68922, test acc: 59.91211%


Training [53/100]: 100%|██████████| 2048/2048 [00:24<00:00, 85.10it/s] 


Epoch [52/100] train loss: 0.68905, test acc: 60.10742%


Training [54/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.64it/s] 


Epoch [53/100] train loss: 0.68887, test acc: 60.13184%


Training [55/100]: 100%|██████████| 2048/2048 [00:24<00:00, 82.86it/s]


Epoch [54/100] train loss: 0.68868, test acc: 60.22949%


Training [56/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.01it/s] 


Epoch [55/100] train loss: 0.68849, test acc: 60.22949%


Training [57/100]: 100%|██████████| 2048/2048 [00:23<00:00, 88.61it/s] 


Epoch [56/100] train loss: 0.68828, test acc: 60.32715%


Training [58/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.32it/s] 


Epoch [57/100] train loss: 0.68807, test acc: 60.20508%


Training [59/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.20it/s] 


Epoch [58/100] train loss: 0.68784, test acc: 60.32715%


Training [60/100]: 100%|██████████| 2048/2048 [00:24<00:00, 83.49it/s] 


Epoch [59/100] train loss: 0.68761, test acc: 60.30273%


Training [61/100]: 100%|██████████| 2048/2048 [00:22<00:00, 89.34it/s] 


Epoch [60/100] train loss: 0.68736, test acc: 60.42480%


Training [62/100]: 100%|██████████| 2048/2048 [00:22<00:00, 89.76it/s] 


Epoch [61/100] train loss: 0.68711, test acc: 60.27832%


Training [63/100]: 100%|██████████| 2048/2048 [00:22<00:00, 90.36it/s] 


Epoch [62/100] train loss: 0.68684, test acc: 60.22949%


Training [64/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.31it/s] 


Epoch [63/100] train loss: 0.68656, test acc: 60.30273%


Training [65/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.92it/s] 


Epoch [64/100] train loss: 0.68627, test acc: 60.30273%


Training [66/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.30it/s] 


Epoch [65/100] train loss: 0.68596, test acc: 60.47363%


Training [67/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.16it/s] 


Epoch [66/100] train loss: 0.68564, test acc: 60.47363%


Training [68/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.62it/s] 


Epoch [67/100] train loss: 0.68529, test acc: 60.40039%


Training [69/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.94it/s] 


Epoch [68/100] train loss: 0.68492, test acc: 60.40039%


Training [70/100]: 100%|██████████| 2048/2048 [00:24<00:00, 83.47it/s] 


Epoch [69/100] train loss: 0.68452, test acc: 60.49805%


Training [71/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.25it/s] 


Epoch [70/100] train loss: 0.68410, test acc: 60.44922%


Training [72/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.46it/s] 


Epoch [71/100] train loss: 0.68366, test acc: 60.47363%


Training [73/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.62it/s] 


Epoch [72/100] train loss: 0.68318, test acc: 60.64453%


Training [74/100]: 100%|██████████| 2048/2048 [00:24<00:00, 82.27it/s] 


Epoch [73/100] train loss: 0.68268, test acc: 60.69336%


Training [75/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.85it/s] 


Epoch [74/100] train loss: 0.68215, test acc: 60.83984%


Training [76/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.63it/s] 


Epoch [75/100] train loss: 0.68160, test acc: 61.13281%


Training [77/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.45it/s] 


Epoch [76/100] train loss: 0.68102, test acc: 61.49902%


Training [78/100]: 100%|██████████| 2048/2048 [00:23<00:00, 85.63it/s] 


Epoch [77/100] train loss: 0.68040, test acc: 62.01172%


Training [79/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.63it/s] 


Epoch [78/100] train loss: 0.67975, test acc: 62.50000%


Training [80/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.79it/s] 


Epoch [79/100] train loss: 0.67907, test acc: 62.74414%


Training [81/100]: 100%|██████████| 2048/2048 [00:23<00:00, 88.31it/s] 


Epoch [80/100] train loss: 0.67836, test acc: 63.08594%


Training [82/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.89it/s] 


Epoch [81/100] train loss: 0.67761, test acc: 63.52539%


Training [83/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.48it/s] 


Epoch [82/100] train loss: 0.67682, test acc: 63.91602%


Training [84/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.19it/s] 


Epoch [83/100] train loss: 0.67599, test acc: 64.50195%


Training [85/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.56it/s] 


Epoch [84/100] train loss: 0.67512, test acc: 65.01465%


Training [86/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.89it/s] 


Epoch [85/100] train loss: 0.67420, test acc: 65.52734%


Training [87/100]: 100%|██████████| 2048/2048 [00:23<00:00, 88.40it/s] 


Epoch [86/100] train loss: 0.67322, test acc: 65.82031%


Training [88/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.72it/s] 


Epoch [87/100] train loss: 0.67220, test acc: 66.16211%


Training [89/100]: 100%|██████████| 2048/2048 [00:23<00:00, 88.38it/s] 


Epoch [88/100] train loss: 0.67112, test acc: 66.43066%


Training [90/100]: 100%|██████████| 2048/2048 [00:22<00:00, 89.13it/s] 


Epoch [89/100] train loss: 0.67000, test acc: 66.45508%


Training [91/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.57it/s] 


Epoch [90/100] train loss: 0.66882, test acc: 66.89453%


Training [92/100]: 100%|██████████| 2048/2048 [00:24<00:00, 83.99it/s] 


Epoch [91/100] train loss: 0.66759, test acc: 66.87012%


Training [93/100]: 100%|██████████| 2048/2048 [00:24<00:00, 84.78it/s] 


Epoch [92/100] train loss: 0.66632, test acc: 66.99219%


Training [94/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.32it/s] 


Epoch [93/100] train loss: 0.66502, test acc: 67.11426%


Training [95/100]: 100%|██████████| 2048/2048 [00:23<00:00, 88.16it/s] 


Epoch [94/100] train loss: 0.66369, test acc: 67.28516%


Training [96/100]: 100%|██████████| 2048/2048 [00:23<00:00, 86.34it/s] 


Epoch [95/100] train loss: 0.66234, test acc: 67.50488%


Training [97/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.79it/s] 


Epoch [96/100] train loss: 0.66095, test acc: 67.62695%


Training [98/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.37it/s] 


Epoch [97/100] train loss: 0.65954, test acc: 67.57812%


Training [99/100]: 100%|██████████| 2048/2048 [00:23<00:00, 87.45it/s] 


Epoch [98/100] train loss: 0.65809, test acc: 67.91992%


Training [100/100]: 100%|██████████| 2048/2048 [00:23<00:00, 88.18it/s] 

Epoch [99/100] train loss: 0.65662, test acc: 67.96875%





In [15]:
if SAVE_MODEL:
    torch.save(model.state_dict(), MODEL_PATH)
    with open(MODEL_TRAIN_DETAILS_PATH, "wb") as f:
        pickle.dump({"train_losses": train_losses, "test_losses": test_losses, "test_accuracies": test_accuracies}, f)

In [24]:
def plot_loss(train_losses, test_losses):
    epochs = list(range(1, len(train_losses) + 1))

    # Create a line plot for train loss
    train_trace = go.Scatter(
        x=epochs,
        y=train_losses,
        mode='lines',
        name='Train Loss'
    )

    # Create a line plot for test loss
    test_trace = go.Scatter(
        x=epochs,
        y=test_losses,
        mode='lines',
        name='Test Loss'
    )

    # Create the layout for the plot
    layout = go.Layout(
        title='Train and Test Loss over Epochs',
        xaxis=dict(title='Epoch'),
        yaxis=dict(title='Loss'),
        showlegend=True
    )

    # Combine the train and test traces into a data list
    data = [train_trace, test_trace]

    # Create the figure and display the plot
    fig = go.Figure(data=data, layout=layout)
    fig.show()

plot_loss(train_losses, test_losses)

In [25]:
fig = px.line(x=range(1, len(test_accuracies) + 1), y=test_accuracies, title="Test Accuracy over Epochs")
fig.update_layout(xaxis_title="Epoch", yaxis_title="Accuracy", yaxis_tickformat="0.0%")

In [26]:
with torch.no_grad():

    total_train_loss = 0.0
    total_train_accuracy = 0.0
    
    for batch in tqdm(train_dataloader):
        batch_x, batch_y = batch
        batch_x = batch_x.flatten(start_dim=1)
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        output = model(batch_x)

        pred = torch.argmax(output, dim=1)
        total_train_accuracy += torch.sum(pred == batch_y.long()).item() / len(pred)

    avg_train_accuracy = total_train_accuracy / len(train_dataloader)

print(f"Train accuracy: {avg_train_accuracy:0.5%}")

100%|██████████| 2048/2048 [00:20<00:00, 99.07it/s] 

Train accuracy: 68.41300%



