In [16]:
# train a 3 layer MLP on mnist
import torch
import torch.nn as nn
import torchvision 
from torchvision import transforms
import os
import numpy as np 
import tqdm

model = nn.Sequential(
    nn.Linear(784, 150, bias=False),
    nn.ReLU(),
    nn.Linear(150, 200, bias=False),
    nn.ReLU(),
    nn.Linear(200, 10, bias=False),
)

train_set = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        torch.flatten
    ])
)
test_set = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        torch.flatten
    ])
)

train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=128,
    shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=128,
    shuffle=True
)

loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim,
    mode='max',
    factor=0.1,
    patience=2,
    min_lr = 1e-4
)


epochs = 12
eval_every = 100 # batches 

best_model = None

with torch.autocast("cuda"):
    for epoch in tqdm.tqdm(range(epochs)):
        
        train_loss = 0
        eval_loss = 0
        train_acc = 0
        eval_acc = 0
        best_acc = 0
        
        # train
        for i, (x, y) in enumerate(train_loader):
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            optim.zero_grad()
            loss.backward()
            optim.step()
            
            acc = (y_pred.argmax(dim=1) == y).float().mean()
            train_acc += acc 
            train_loss += loss.item()
            
            # eval
            if i % eval_every == 0:
                for x, y in test_loader:
                    model.eval()
                    y_pred = model(x)
                    loss = loss_fn(y_pred, y)
                    
                    eval_loss += loss.item()
                    acc = (y_pred.argmax(dim=1) == y).float().mean()
                    eval_acc += acc    
                    
                print(f"Epoch {epoch}: train acc: {train_acc / eval_every}, train loss: {train_loss / eval_every} \
                        , eval acc: {eval_acc / len(test_loader)}, eval loss: {eval_loss / len(test_loader)}\
                        lr: {optim.param_groups[0]['lr']}"
                    )
                
                if eval_acc / len(test_loader) > best_acc:
                    best_acc = eval_acc / len(test_loader)
                    best_model = model.state_dict()
                    
                train_loss = 0
                eval_loss = 0
                train_acc = 0
                eval_acc = 0
        
        scheduler.step(eval_acc / len(test_loader))

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0: train acc: 0.0010937500046566129, train loss: 0.023133807182312012, eval acc: 0.12480221688747406, eval loss: 2.2896044616457782
Epoch 0: train acc: 0.7679687738418579, train loss: 0.8164816942811012, eval acc: 0.8830102682113647, eval loss: 0.3744132918647573
Epoch 0: train acc: 0.8978906273841858, train loss: 0.33041078358888626, eval acc: 0.9067444801330566, eval loss: 0.2938633547930778
Epoch 0: train acc: 0.924609363079071, train loss: 0.2633488894253969, eval acc: 0.934335470199585, eval loss: 0.22176922697432433
Epoch 0: train acc: 0.9344531297683716, train loss: 0.2194523583352566, eval acc: 0.9376977682113647, eval loss: 0.19800268471995486


 10%|█         | 1/10 [00:17<02:36, 17.43s/it]

Epoch 1: train acc: 0.009140624664723873, train loss: 0.0026094257831573487, eval acc: 0.9476858973503113, eval loss: 0.17335808729823632
Epoch 1: train acc: 0.9498437643051147, train loss: 0.1726352559775114, eval acc: 0.9516416192054749, eval loss: 0.1560080115317921
Epoch 1: train acc: 0.9547656178474426, train loss: 0.15803983636200428, eval acc: 0.9575751423835754, eval loss: 0.13968276142885414
Epoch 1: train acc: 0.9541406035423279, train loss: 0.14909979697316886, eval acc: 0.9599485993385315, eval loss: 0.13456800778078126
Epoch 1: train acc: 0.9606249928474426, train loss: 0.1312691617757082, eval acc: 0.9629153609275818, eval loss: 0.12041272510644756


 20%|██        | 2/10 [00:34<02:16, 17.06s/it]

Epoch 2: train acc: 0.009687500074505806, train loss: 0.0010434401035308838, eval acc: 0.9624208807945251, eval loss: 0.11770394779270209
Epoch 2: train acc: 0.9690625071525574, train loss: 0.1095591752231121, eval acc: 0.965585470199585, eval loss: 0.11295794030722184
Epoch 2: train acc: 0.9684374928474426, train loss: 0.10447842994704842, eval acc: 0.968156635761261, eval loss: 0.10091136004539984
Epoch 2: train acc: 0.9696093797683716, train loss: 0.10024900089949369, eval acc: 0.9679588675498962, eval loss: 0.10323356535238555
Epoch 2: train acc: 0.9735937714576721, train loss: 0.08900752695277334, eval acc: 0.9711233973503113, eval loss: 0.09280011813521763


 30%|███       | 3/10 [00:50<01:57, 16.86s/it]

Epoch 3: train acc: 0.009843749925494194, train loss: 0.00033551327884197236, eval acc: 0.970530092716217, eval loss: 0.09129477308803721
Epoch 3: train acc: 0.9781249761581421, train loss: 0.07360215123742819, eval acc: 0.9714201092720032, eval loss: 0.0893281562772544
Epoch 3: train acc: 0.9783594012260437, train loss: 0.07638222517445684, eval acc: 0.9732001423835754, eval loss: 0.08386152166920373
Epoch 3: train acc: 0.9775000214576721, train loss: 0.07783071626909077, eval acc: 0.9734968543052673, eval loss: 0.0818927831994959
Epoch 3: train acc: 0.9778125286102295, train loss: 0.07336723348125815, eval acc: 0.9748813509941101, eval loss: 0.0800139919583556


 40%|████      | 4/10 [01:07<01:40, 16.72s/it]

Epoch 4: train acc: 0.009843749925494194, train loss: 0.000767785981297493, eval acc: 0.9751780033111572, eval loss: 0.07558449819872651
Epoch 4: train acc: 0.985546886920929, train loss: 0.05273984972387552, eval acc: 0.9757713675498962, eval loss: 0.0756629487806106
Epoch 4: train acc: 0.9872656464576721, train loss: 0.047638762630522254, eval acc: 0.9769580960273743, eval loss: 0.07259726532899975
Epoch 4: train acc: 0.9862499833106995, train loss: 0.04978282698430121, eval acc: 0.977155864238739, eval loss: 0.07351518735949751
Epoch 4: train acc: 0.9861719012260437, train loss: 0.05130794600583613, eval acc: 0.977155864238739, eval loss: 0.07163156507701814


 50%|█████     | 5/10 [01:24<01:23, 16.73s/it]

Epoch 5: train acc: 0.009999999776482582, train loss: 0.00029629159718751907, eval acc: 0.9780458807945251, eval loss: 0.07078435880285275
Epoch 5: train acc: 0.9881250262260437, train loss: 0.04700637098401785, eval acc: 0.9773536324501038, eval loss: 0.07166017134544216
Epoch 5: train acc: 0.9878125190734863, train loss: 0.04699056256562471, eval acc: 0.977749228477478, eval loss: 0.07034215670597704
Epoch 5: train acc: 0.9862499833106995, train loss: 0.05018775161355734, eval acc: 0.9775514006614685, eval loss: 0.0711555746649215
Epoch 5: train acc: 0.9877343773841858, train loss: 0.04695337207987905, eval acc: 0.9776503443717957, eval loss: 0.07049027455475511


 60%|██████    | 6/10 [01:41<01:07, 16.99s/it]

Epoch 6: train acc: 0.009921874850988388, train loss: 0.0005395156145095825, eval acc: 0.9780458807945251, eval loss: 0.0710607511050339
Epoch 6: train acc: 0.987109363079071, train loss: 0.04560087102930993, eval acc: 0.9780458807945251, eval loss: 0.07058649712914153
Epoch 6: train acc: 0.9884374737739563, train loss: 0.04700321012176573, eval acc: 0.9780458807945251, eval loss: 0.06993568411592065
Epoch 6: train acc: 0.9864062666893005, train loss: 0.05009493986144662, eval acc: 0.977749228477478, eval loss: 0.06944180554651384
Epoch 6: train acc: 0.9889062643051147, train loss: 0.04195830235723406, eval acc: 0.9780458807945251, eval loss: 0.0696859419982456


 70%|███████   | 7/10 [01:58<00:50, 16.87s/it]

Epoch 7: train acc: 0.009999999776482582, train loss: 0.00023722447454929351, eval acc: 0.977749228477478, eval loss: 0.06988710895935192
Epoch 7: train acc: 0.9900000095367432, train loss: 0.03971543093211949, eval acc: 0.9781447649002075, eval loss: 0.07020035251692126
Epoch 7: train acc: 0.9883593916893005, train loss: 0.044605986243113876, eval acc: 0.9780458807945251, eval loss: 0.06956512378410826
Epoch 7: train acc: 0.9903905987739563, train loss: 0.04045677992515266, eval acc: 0.9773536324501038, eval loss: 0.07364676179529368
Epoch 7: train acc: 0.9868749976158142, train loss: 0.0492591520678252, eval acc: 0.9770569801330566, eval loss: 0.07213467821667466


 80%|████████  | 8/10 [02:15<00:33, 16.88s/it]

Epoch 8: train acc: 0.009921874850988388, train loss: 0.0004186992347240448, eval acc: 0.9784414768218994, eval loss: 0.06945385421107558
Epoch 8: train acc: 0.9878125190734863, train loss: 0.046494239810854196, eval acc: 0.978342592716217, eval loss: 0.0695319443092316
Epoch 8: train acc: 0.9884374737739563, train loss: 0.042786774151027204, eval acc: 0.9781447649002075, eval loss: 0.06943445870228394
Epoch 8: train acc: 0.9891406297683716, train loss: 0.04184523501433432, eval acc: 0.9782436490058899, eval loss: 0.06955708134212071
Epoch 8: train acc: 0.9884374737739563, train loss: 0.043266087006777525, eval acc: 0.9769580960273743, eval loss: 0.07128854925873913


 90%|█████████ | 9/10 [02:32<00:16, 16.90s/it]

Epoch 9: train acc: 0.009921874850988388, train loss: 0.00022818226367235183, eval acc: 0.9784414768218994, eval loss: 0.06999730705460415
Epoch 9: train acc: 0.9885937571525574, train loss: 0.044662206079810855, eval acc: 0.9782436490058899, eval loss: 0.06939372750547348
Epoch 9: train acc: 0.9897656440734863, train loss: 0.044999213563278316, eval acc: 0.9785403609275818, eval loss: 0.06980168135671676
Epoch 9: train acc: 0.9882031083106995, train loss: 0.04265894245356321, eval acc: 0.9784414768218994, eval loss: 0.06950056836999292
Epoch 9: train acc: 0.9880468845367432, train loss: 0.04303899850696325, eval acc: 0.977749228477478, eval loss: 0.07096755187486924


100%|██████████| 10/10 [03:55<00:00, 23.52s/it]


In [28]:
# save the first 5000 images in the test set as numpy array
test_images = []
test_labels = []
for i, (x, y) in enumerate(test_loader):
    test_images.append(x)
    test_labels.append(y)
    if i == 5000:
        break
test_images = torch.cat(test_images).numpy()
test_labels = torch.cat(test_labels).numpy()
np.save("./data/test_images.npy", test_images)
np.save("./data/test_labels.npy", test_labels)

# save model weights as numpy array
for name, param in best_model.items():
    np.save(f"./data/linear_{name}.npy", param.T.detach().cpu().numpy()) # W @ X -> X @ W.T
    
np.load("./data/linear_0.weight.npy").shape


(784, 150)