In [1]:
# train a 3 layer MLP on mnist
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "5"
import torch
import torch.nn as nn
import torchvision 
from torchvision import transforms
import numpy as np 
import tqdm

model = nn.Sequential(
    nn.Linear(784, 150, bias=False),
    nn.ReLU(),
    nn.Linear(150, 200, bias=False),
    nn.ReLU(),
    nn.Linear(200, 10, bias=False),
)

if torch.__version__.startswith("2"):
    print("Using torch 2.0 compile")
    torch.compile(model)
    
train_set = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        torch.flatten
    ])
)
test_set = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        torch.flatten
    ])
)

train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=128,
    shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=128,
    shuffle=True
)

loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim,
    mode='max',
    factor=0.1,
    patience=2,
    min_lr = 1e-4
)


epochs = 12
eval_every = 100 # batches 

best_model = None

with torch.autocast("cuda"):
    for epoch in tqdm.tqdm(range(epochs)):
        
        train_loss = 0
        eval_loss = 0
        train_acc = 0
        eval_acc = 0
        best_acc = 0
        
        # train
        for i, (x, y) in enumerate(train_loader):
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            optim.zero_grad()
            loss.backward()
            optim.step()
            
            acc = (y_pred.argmax(dim=1) == y).float().mean()
            train_acc += acc 
            train_loss += loss.item()
            
            # eval
            if i % eval_every == 0:
                for x, y in test_loader:
                    model.eval()
                    y_pred = model(x)
                    loss = loss_fn(y_pred, y)
                    
                    eval_loss += loss.item()
                    acc = (y_pred.argmax(dim=1) == y).float().mean()
                    eval_acc += acc    
                    
                print(f"Epoch {epoch}: train acc: {train_acc / eval_every}, train loss: {train_loss / eval_every},"
                    f" eval acc: {eval_acc / len(test_loader)}, eval loss: {eval_loss / len(test_loader)},"
                    f" lr: {optim.param_groups[0]['lr']}")

                
                if eval_acc / len(test_loader) > best_acc:
                    best_acc = eval_acc / len(test_loader)
                    best_model = model.state_dict()
                    
                train_loss = 0
                eval_loss = 0
                train_acc = 0
                eval_acc = 0
        
        scheduler.step(eval_acc / len(test_loader))

Using torch 2.0 compile


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch 0: train acc: 0.0006249999860301614, train loss: 0.023246397972106935, eval acc: 0.10146360844373703, eval loss: 2.301863613007944, lr: 0.1
Epoch 0: train acc: 0.8145312666893005, train loss: 0.6028286927938461, eval acc: 0.9208860993385315, eval loss: 0.2596525170757801, lr: 0.1
Epoch 0: train acc: 0.9255468845367432, train loss: 0.24776186376810075, eval acc: 0.9344343543052673, eval loss: 0.20644960184640523, lr: 0.1
Epoch 0: train acc: 0.9366406202316284, train loss: 0.20699586383998395, eval acc: 0.9401701092720032, eval loss: 0.19996813892186444, lr: 0.1
Epoch 0: train acc: 0.9435155987739563, train loss: 0.18499109778553247, eval acc: 0.9466969966888428, eval loss: 0.1583833569967294, lr: 0.1


  8%|▊         | 1/12 [00:20<03:50, 20.95s/it]

Epoch 1: train acc: 0.009765625, train loss: 0.0006559369713068008, eval acc: 0.9614319801330566, eval loss: 0.12200899863073343, lr: 0.1
Epoch 1: train acc: 0.9572656154632568, train loss: 0.14233654592186212, eval acc: 0.9590585231781006, eval loss: 0.1340248633789111, lr: 0.1
Epoch 1: train acc: 0.9573437571525574, train loss: 0.14094047885388136, eval acc: 0.962124228477478, eval loss: 0.12388285706880726, lr: 0.1
Epoch 1: train acc: 0.9614843726158142, train loss: 0.12964561261236668, eval acc: 0.9670688509941101, eval loss: 0.10503417272356493, lr: 0.1
Epoch 1: train acc: 0.9625781178474426, train loss: 0.12707419376820325, eval acc: 0.962124228477478, eval loss: 0.1154981301743773, lr: 0.1


 17%|█▋        | 2/12 [00:41<03:28, 20.89s/it]

Epoch 2: train acc: 0.009687500074505806, train loss: 0.001050637736916542, eval acc: 0.9689477682113647, eval loss: 0.10096564399567727, lr: 0.1
Epoch 2: train acc: 0.9700000286102295, train loss: 0.10002584071829915, eval acc: 0.9695411324501038, eval loss: 0.09808941899881332, lr: 0.1
Epoch 2: train acc: 0.9685937762260437, train loss: 0.108954002186656, eval acc: 0.9650909900665283, eval loss: 0.11934234425897085, lr: 0.1
Epoch 2: train acc: 0.9642968773841858, train loss: 0.12090791739523411, eval acc: 0.9658821225166321, eval loss: 0.11060461432590515, lr: 0.1
Epoch 2: train acc: 0.9672656059265137, train loss: 0.11016163788735867, eval acc: 0.9698378443717957, eval loss: 0.09561656171432402, lr: 0.1


 25%|██▌       | 3/12 [01:02<03:07, 20.83s/it]

Epoch 3: train acc: 0.009843749925494194, train loss: 0.0005461175739765167, eval acc: 0.9726067781448364, eval loss: 0.09149287763652922, lr: 0.1
Epoch 3: train acc: 0.9736718535423279, train loss: 0.08497834936715662, eval acc: 0.972507894039154, eval loss: 0.08745302523256433, lr: 0.1
Epoch 3: train acc: 0.9702343940734863, train loss: 0.10156306691467762, eval acc: 0.9757713675498962, eval loss: 0.08615235306464042, lr: 0.1
Epoch 3: train acc: 0.9747655987739563, train loss: 0.08392089699395001, eval acc: 0.9692444801330566, eval loss: 0.11262341025668562, lr: 0.1
Epoch 3: train acc: 0.9722656011581421, train loss: 0.09245081733912229, eval acc: 0.9746835231781006, eval loss: 0.08653592999694468, lr: 0.1


 33%|███▎      | 4/12 [01:23<02:45, 20.74s/it]

Epoch 4: train acc: 0.009921874850988388, train loss: 0.00025272699072957037, eval acc: 0.9736946225166321, eval loss: 0.09162282367635524, lr: 0.010000000000000002
Epoch 4: train acc: 0.9809374809265137, train loss: 0.059943809346295895, eval acc: 0.9789358973503113, eval loss: 0.06937654392463685, lr: 0.010000000000000002
Epoch 4: train acc: 0.9848437309265137, train loss: 0.04672667453764007, eval acc: 0.9811115264892578, eval loss: 0.06493434265318188, lr: 0.010000000000000002
Epoch 4: train acc: 0.985156238079071, train loss: 0.0445464039593935, eval acc: 0.9798259735107422, eval loss: 0.06380936931347168, lr: 0.010000000000000002
Epoch 4: train acc: 0.987500011920929, train loss: 0.04000535752158612, eval acc: 0.9818037748336792, eval loss: 0.0602712632242378, lr: 0.010000000000000002


 42%|████▏     | 5/12 [01:43<02:24, 20.67s/it]

Epoch 5: train acc: 0.009921874850988388, train loss: 0.00020769592374563216, eval acc: 0.9819027185440063, eval loss: 0.05902310645891518, lr: 0.010000000000000002
Epoch 5: train acc: 0.9874218702316284, train loss: 0.040622996229212734, eval acc: 0.9826938509941101, eval loss: 0.05862420238554478, lr: 0.010000000000000002
Epoch 5: train acc: 0.9888281226158142, train loss: 0.03430558605352416, eval acc: 0.9822982549667358, eval loss: 0.05822027453421792, lr: 0.010000000000000002
Epoch 5: train acc: 0.9885156154632568, train loss: 0.0359953520889394, eval acc: 0.9817048907279968, eval loss: 0.05872260904694094, lr: 0.010000000000000002
Epoch 5: train acc: 0.9888281226158142, train loss: 0.033574492561165244, eval acc: 0.9832871556282043, eval loss: 0.05630905228325083, lr: 0.010000000000000002


 50%|█████     | 6/12 [02:04<02:03, 20.62s/it]

Epoch 6: train acc: 0.009765625, train loss: 0.0004800581932067871, eval acc: 0.9824960231781006, eval loss: 0.057481262016499154, lr: 0.010000000000000002
Epoch 6: train acc: 0.9892968535423279, train loss: 0.03184480210999027, eval acc: 0.9835838675498962, eval loss: 0.05543783839638758, lr: 0.010000000000000002
Epoch 6: train acc: 0.9899218678474426, train loss: 0.02917545793345198, eval acc: 0.9834849834442139, eval loss: 0.05662036053040596, lr: 0.010000000000000002
Epoch 6: train acc: 0.9886718988418579, train loss: 0.03481555435690097, eval acc: 0.9841772317886353, eval loss: 0.05548799330592627, lr: 0.010000000000000002
Epoch 6: train acc: 0.989062488079071, train loss: 0.034338021259754894, eval acc: 0.9838805198669434, eval loss: 0.05489967645829708, lr: 0.010000000000000002


 58%|█████▊    | 7/12 [02:24<01:43, 20.65s/it]

Epoch 7: train acc: 0.009765625, train loss: 0.000679905116558075, eval acc: 0.9833860993385315, eval loss: 0.05469738915960505, lr: 0.0010000000000000002
Epoch 7: train acc: 0.9896093606948853, train loss: 0.0319267761381343, eval acc: 0.9833860993385315, eval loss: 0.054860241581461855, lr: 0.0010000000000000002
Epoch 7: train acc: 0.9915624856948853, train loss: 0.02667244266718626, eval acc: 0.9835838675498962, eval loss: 0.05434909136302746, lr: 0.0010000000000000002
Epoch 7: train acc: 0.9917187690734863, train loss: 0.027622374065686017, eval acc: 0.9827927350997925, eval loss: 0.05622238102899511, lr: 0.0010000000000000002
Epoch 7: train acc: 0.9916406273841858, train loss: 0.0283700835716445, eval acc: 0.9835838675498962, eval loss: 0.054027486160762986, lr: 0.0010000000000000002


 67%|██████▋   | 8/12 [02:45<01:22, 20.63s/it]

Epoch 8: train acc: 0.009921874850988388, train loss: 0.00015269946306943892, eval acc: 0.983781635761261, eval loss: 0.05420899158343673, lr: 0.0010000000000000002
Epoch 8: train acc: 0.9915624856948853, train loss: 0.0279787433182355, eval acc: 0.9836827516555786, eval loss: 0.05399251416549558, lr: 0.0010000000000000002
Epoch 8: train acc: 0.9908593893051147, train loss: 0.02955838860711083, eval acc: 0.983781635761261, eval loss: 0.053859674214470916, lr: 0.0010000000000000002
Epoch 8: train acc: 0.9900781512260437, train loss: 0.02974007256794721, eval acc: 0.9832871556282043, eval loss: 0.05593558600985834, lr: 0.0010000000000000002
Epoch 8: train acc: 0.9906250238418579, train loss: 0.02831370714586228, eval acc: 0.983188271522522, eval loss: 0.06042297940135379, lr: 0.0010000000000000002


 75%|███████▌  | 9/12 [03:06<01:01, 20.63s/it]

Epoch 9: train acc: 0.009999999776482582, train loss: 4.71358560025692e-05, eval acc: 0.9823971390724182, eval loss: 0.05668435206848987, lr: 0.0010000000000000002
Epoch 9: train acc: 0.9901562333106995, train loss: 0.030069835144095124, eval acc: 0.9834849834442139, eval loss: 0.05355700517321577, lr: 0.0010000000000000002
Epoch 9: train acc: 0.9919531345367432, train loss: 0.026737412011716515, eval acc: 0.9830893874168396, eval loss: 0.05812597695938489, lr: 0.0010000000000000002
Epoch 9: train acc: 0.9907812476158142, train loss: 0.02682540925685316, eval acc: 0.9840783476829529, eval loss: 0.053480812605662724, lr: 0.0010000000000000002
Epoch 9: train acc: 0.9917187690734863, train loss: 0.02680053055402823, eval acc: 0.9840783476829529, eval loss: 0.05341307197909661, lr: 0.0010000000000000002


 83%|████████▎ | 10/12 [03:26<00:41, 20.69s/it]

Epoch 10: train acc: 0.009921874850988388, train loss: 0.00045234546065330505, eval acc: 0.9832871556282043, eval loss: 0.0593593227264436, lr: 0.00010000000000000003
Epoch 10: train acc: 0.9917187690734863, train loss: 0.027876479031983762, eval acc: 0.9839794039726257, eval loss: 0.053309094955440305, lr: 0.00010000000000000003
Epoch 10: train acc: 0.9915624856948853, train loss: 0.02765995785361156, eval acc: 0.983188271522522, eval loss: 0.059383448561111205, lr: 0.00010000000000000003
Epoch 10: train acc: 0.9924218654632568, train loss: 0.025830444670282303, eval acc: 0.9840783476829529, eval loss: 0.053318472414077084, lr: 0.00010000000000000003
Epoch 10: train acc: 0.9906250238418579, train loss: 0.029232904515229165, eval acc: 0.9832871556282043, eval loss: 0.054444220718703695, lr: 0.00010000000000000003


 92%|█████████▏| 11/12 [03:47<00:20, 20.69s/it]

Epoch 11: train acc: 0.009921874850988388, train loss: 0.00028110474348068236, eval acc: 0.9839794039726257, eval loss: 0.05329462136451414, lr: 0.00010000000000000003
Epoch 11: train acc: 0.9911718964576721, train loss: 0.029427686266135424, eval acc: 0.9839794039726257, eval loss: 0.053310558991211975, lr: 0.00010000000000000003
Epoch 11: train acc: 0.9925000071525574, train loss: 0.026688256522174925, eval acc: 0.9839794039726257, eval loss: 0.053309991943231445, lr: 0.00010000000000000003
Epoch 11: train acc: 0.9909374713897705, train loss: 0.029677310059778392, eval acc: 0.9832871556282043, eval loss: 0.05461381115549822, lr: 0.00010000000000000003
Epoch 11: train acc: 0.991406261920929, train loss: 0.026521230812650173, eval acc: 0.9832871556282043, eval loss: 0.05499958628759118, lr: 0.00010000000000000003


100%|██████████| 12/12 [04:08<00:00, 20.70s/it]


In [None]:
# save the first 5000 images in the test set as numpy array
test_images = []
test_labels = []
for i, (x, y) in enumerate(test_loader):
    test_images.append(x)
    test_labels.append(y)

test_images = torch.cat(test_images).numpy()
test_labels = torch.cat(test_labels).numpy()
np.save("./data/test_images.npy", test_images)
np.save("./data/test_labels.npy", test_labels)

# save model weights as numpy array
for idx, (name, param) in enumerate(best_model.items()):
    name = name.split(".")
    name[0] = str(idx)
    name = ".".join(name)
    np.save(f"./data/linear_{name}.npy", param.detach().cpu().numpy()) # W @ X -> X @ W.T, cnpy reads column-major
    


((784, 150), dtype('float32'))

In [None]:
param = np.load("./data/linear_0.weight.npy")
print("shape:", param.shape, ",", param.dtype, ", first 3 values (cupy): ", param[0, :3])

shape: (784, 150) , float32 , first 3 values:  [-0.02242141 -0.01722893  0.01018926]
