In [1]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "5"
import torch
import torch.nn as nn
import torchvision 
from torchvision import transforms
import numpy as np 
import tqdm

In [3]:
# train a 3 layer MLP on mnist
model = nn.Sequential(
    nn.Linear(784, 150, bias=False),
    nn.ReLU(),
    nn.Linear(150, 200, bias=False),
    nn.ReLU(),
    nn.Linear(200, 10, bias=False),
)

if torch.__version__.startswith("2"):
    print("Using torch 2.0 compile")
    torch.compile(model)
    
train_set = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        torch.flatten
    ])
)
test_set = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        torch.flatten
    ])
)

train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=128,
    shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=128,
    shuffle=True
)

loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim,
    mode='max',
    factor=0.1,
    patience=2,
    min_lr = 1e-4
)


epochs = 12
eval_every = 100 # batches 

best_model = None

with torch.autocast("cuda"):
    for epoch in tqdm.tqdm(range(epochs)):
        
        train_loss = 0
        eval_loss = 0
        train_acc = 0
        eval_acc = 0
        best_acc = 0
        
        # train
        for i, (x, y) in enumerate(train_loader):
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            optim.zero_grad()
            loss.backward()
            optim.step()
            
            acc = (y_pred.argmax(dim=1) == y).float().mean()
            train_acc += acc 
            train_loss += loss.item()
            
            # eval
            if i % eval_every == 0:
                for x, y in test_loader:
                    model.eval()
                    y_pred = model(x)
                    loss = loss_fn(y_pred, y)
                    
                    eval_loss += loss.item()
                    acc = (y_pred.argmax(dim=1) == y).float().mean()
                    eval_acc += acc    
                    
                print(f"Epoch {epoch}: train acc: {train_acc / eval_every}, train loss: {train_loss / eval_every},"
                    f" eval acc: {eval_acc / len(test_loader)}, eval loss: {eval_loss / len(test_loader)},"
                    f" lr: {optim.param_groups[0]['lr']}")

                
                if eval_acc / len(test_loader) > best_acc:
                    best_acc = eval_acc / len(test_loader)
                    best_model = model.state_dict()
                    
                train_loss = 0
                eval_loss = 0
                train_acc = 0
                eval_acc = 0
        
        scheduler.step(eval_acc / len(test_loader))

Using torch 2.0 compile


  0%|          | 0/12 [00:00<?, ?it/s]

Epoch 0: train acc: 0.0009374999790452421, train loss: 0.023020157814025877, eval acc: 0.13587816059589386, eval loss: 2.2883858409108995, lr: 0.1
Epoch 0: train acc: 0.8125, train loss: 0.610088667422533, eval acc: 0.9277096390724182, eval loss: 0.23557359234818928, lr: 0.1
Epoch 0: train acc: 0.9270312786102295, train loss: 0.24177031613886357, eval acc: 0.9427412748336792, eval loss: 0.18410793008117737, lr: 0.1
Epoch 0: train acc: 0.9390624761581421, train loss: 0.20490242075175047, eval acc: 0.9498615264892578, eval loss: 0.15951774462680274, lr: 0.1
Epoch 0: train acc: 0.9419531226158142, train loss: 0.1893035962432623, eval acc: 0.9521360993385315, eval loss: 0.157280340413504, lr: 0.1


  8%|▊         | 1/12 [00:21<03:56, 21.46s/it]

Epoch 1: train acc: 0.009687500074505806, train loss: 0.0009697346389293671, eval acc: 0.9570807218551636, eval loss: 0.13866330181024497, lr: 0.1
Epoch 1: train acc: 0.9591405987739563, train loss: 0.13834078039973974, eval acc: 0.9606408476829529, eval loss: 0.12890377475679676, lr: 0.1
Epoch 1: train acc: 0.9521874785423279, train loss: 0.1521232938952744, eval acc: 0.9642998576164246, eval loss: 0.12586349305472797, lr: 0.1
Epoch 1: train acc: 0.9555468559265137, train loss: 0.14240980783477425, eval acc: 0.9588607549667358, eval loss: 0.13827000522125465, lr: 0.1
Epoch 1: train acc: 0.9619531035423279, train loss: 0.12675292495638132, eval acc: 0.9672666192054749, eval loss: 0.11402801503250494, lr: 0.1


 17%|█▋        | 2/12 [00:43<03:36, 21.63s/it]

Epoch 2: train acc: 0.009062499739229679, train loss: 0.0018261298537254333, eval acc: 0.9676621556282043, eval loss: 0.11078253968418399, lr: 0.1
Epoch 2: train acc: 0.9681249856948853, train loss: 0.10654620604589582, eval acc: 0.9671677350997925, eval loss: 0.11460428218109699, lr: 0.1
Epoch 2: train acc: 0.9635937213897705, train loss: 0.11885077698156238, eval acc: 0.9690466523170471, eval loss: 0.11316778362268888, lr: 0.1
Epoch 2: train acc: 0.965624988079071, train loss: 0.11094953145831823, eval acc: 0.9691455960273743, eval loss: 0.10688534883570068, lr: 0.1
Epoch 2: train acc: 0.9684374928474426, train loss: 0.10919686485081911, eval acc: 0.9713212251663208, eval loss: 0.09527894179042004, lr: 0.1


 25%|██▌       | 3/12 [01:04<03:14, 21.58s/it]

Epoch 3: train acc: 0.009609375149011612, train loss: 0.001873568296432495, eval acc: 0.9698378443717957, eval loss: 0.10139738287352308, lr: 0.1
Epoch 3: train acc: 0.9689843654632568, train loss: 0.09640532900579274, eval acc: 0.9672666192054749, eval loss: 0.10881779979490026, lr: 0.1
Epoch 3: train acc: 0.9735937714576721, train loss: 0.09307361700572074, eval acc: 0.9708267450332642, eval loss: 0.0960476752134818, lr: 0.1
Epoch 3: train acc: 0.969921886920929, train loss: 0.09772212279029191, eval acc: 0.9711233973503113, eval loss: 0.09525389632186558, lr: 0.1
Epoch 3: train acc: 0.9697656035423279, train loss: 0.09760922577232123, eval acc: 0.9744857549667358, eval loss: 0.08542495347188102, lr: 0.1


 33%|███▎      | 4/12 [01:26<02:52, 21.51s/it]

Epoch 4: train acc: 0.009921874850988388, train loss: 0.00022985322400927543, eval acc: 0.9732001423835754, eval loss: 0.08694166765560078, lr: 0.010000000000000002
Epoch 4: train acc: 0.9803906083106995, train loss: 0.06124153076671064, eval acc: 0.9801226258277893, eval loss: 0.0633531309711405, lr: 0.010000000000000002
Epoch 4: train acc: 0.9840624928474426, train loss: 0.05311148347333074, eval acc: 0.980617105960846, eval loss: 0.061296733992197844, lr: 0.010000000000000002
Epoch 4: train acc: 0.9854687452316284, train loss: 0.04415240168105811, eval acc: 0.9829905033111572, eval loss: 0.05599070841422941, lr: 0.010000000000000002
Epoch 4: train acc: 0.9853906035423279, train loss: 0.04459624802693725, eval acc: 0.9828916192054749, eval loss: 0.055379907317648205, lr: 0.010000000000000002


 42%|████▏     | 5/12 [01:47<02:29, 21.38s/it]

Epoch 5: train acc: 0.009843749925494194, train loss: 0.00032492294907569884, eval acc: 0.9829905033111572, eval loss: 0.05499761088388748, lr: 0.010000000000000002
Epoch 5: train acc: 0.9888281226158142, train loss: 0.03701133350841701, eval acc: 0.982594907283783, eval loss: 0.05569509451595854, lr: 0.010000000000000002
Epoch 5: train acc: 0.9872656464576721, train loss: 0.04014173424337059, eval acc: 0.9832871556282043, eval loss: 0.05308365862510061, lr: 0.010000000000000002
Epoch 5: train acc: 0.9887499809265137, train loss: 0.0371159563260153, eval acc: 0.9840783476829529, eval loss: 0.0526392324518153, lr: 0.010000000000000002
Epoch 5: train acc: 0.9871875047683716, train loss: 0.04034820873523131, eval acc: 0.983188271522522, eval loss: 0.05281196342106272, lr: 0.010000000000000002


 50%|█████     | 6/12 [02:09<02:09, 21.50s/it]

Epoch 6: train acc: 0.009999999776482582, train loss: 8.807899430394173e-05, eval acc: 0.9834849834442139, eval loss: 0.05319766499119799, lr: 0.010000000000000002
Epoch 6: train acc: 0.9898437261581421, train loss: 0.034437640514224765, eval acc: 0.9827927350997925, eval loss: 0.05439295138739332, lr: 0.010000000000000002
Epoch 6: train acc: 0.9901562333106995, train loss: 0.03161859963787719, eval acc: 0.9833860993385315, eval loss: 0.05349271798423857, lr: 0.010000000000000002
Epoch 6: train acc: 0.9889843463897705, train loss: 0.03407466796925292, eval acc: 0.9826938509941101, eval loss: 0.053520268660600154, lr: 0.010000000000000002
Epoch 6: train acc: 0.9886718988418579, train loss: 0.03505327807040885, eval acc: 0.9821004867553711, eval loss: 0.056134261974852676, lr: 0.010000000000000002


 58%|█████▊    | 7/12 [02:30<01:47, 21.52s/it]

Epoch 7: train acc: 0.009843749925494194, train loss: 0.00033349640667438505, eval acc: 0.9835838675498962, eval loss: 0.0542416813683142, lr: 0.0010000000000000002
Epoch 7: train acc: 0.9918749928474426, train loss: 0.028499848588835447, eval acc: 0.9836827516555786, eval loss: 0.05434502239185798, lr: 0.0010000000000000002
Epoch 7: train acc: 0.9897656440734863, train loss: 0.03486225467408076, eval acc: 0.9836827516555786, eval loss: 0.05361362839341635, lr: 0.0010000000000000002
Epoch 7: train acc: 0.9903905987739563, train loss: 0.030746293964330106, eval acc: 0.9835838675498962, eval loss: 0.053360715956438945, lr: 0.0010000000000000002
Epoch 7: train acc: 0.9903125166893005, train loss: 0.030708893092814833, eval acc: 0.983188271522522, eval loss: 0.053406424661273065, lr: 0.0010000000000000002


 67%|██████▋   | 8/12 [02:52<01:26, 21.60s/it]

Epoch 8: train acc: 0.009921874850988388, train loss: 0.00010471918620169162, eval acc: 0.982594907283783, eval loss: 0.05380655919743961, lr: 0.0010000000000000002
Epoch 8: train acc: 0.991015613079071, train loss: 0.028535872872453183, eval acc: 0.9836827516555786, eval loss: 0.05291455404617812, lr: 0.0010000000000000002
Epoch 8: train acc: 0.9908593893051147, train loss: 0.02714365129126236, eval acc: 0.9838805198669434, eval loss: 0.05299098169596135, lr: 0.0010000000000000002
Epoch 8: train acc: 0.9903905987739563, train loss: 0.030029519682284445, eval acc: 0.9829905033111572, eval loss: 0.057351874570727726, lr: 0.0010000000000000002
Epoch 8: train acc: 0.9909374713897705, train loss: 0.029188879309222104, eval acc: 0.9834849834442139, eval loss: 0.05279186218932766, lr: 0.0010000000000000002


 75%|███████▌  | 9/12 [03:14<01:04, 21.66s/it]

Epoch 9: train acc: 0.009921874850988388, train loss: 0.00017538832500576973, eval acc: 0.9834849834442139, eval loss: 0.052853408558412064, lr: 0.0010000000000000002
Epoch 9: train acc: 0.9888281226158142, train loss: 0.03286971835885197, eval acc: 0.9835838675498962, eval loss: 0.05260746173053697, lr: 0.0010000000000000002
Epoch 9: train acc: 0.9917187690734863, train loss: 0.02558575275936164, eval acc: 0.9836827516555786, eval loss: 0.05254153088827491, lr: 0.0010000000000000002
Epoch 9: train acc: 0.9921875, train loss: 0.024617901316378265, eval acc: 0.9836827516555786, eval loss: 0.05257258074626303, lr: 0.0010000000000000002
Epoch 9: train acc: 0.991406261920929, train loss: 0.02776756959967315, eval acc: 0.983781635761261, eval loss: 0.05265917575953505, lr: 0.0010000000000000002


 83%|████████▎ | 10/12 [03:35<00:43, 21.70s/it]

Epoch 10: train acc: 0.009921874850988388, train loss: 0.00013857657089829444, eval acc: 0.9838805198669434, eval loss: 0.05268703555523217, lr: 0.00010000000000000003
Epoch 10: train acc: 0.9900000095367432, train loss: 0.030712137279333546, eval acc: 0.9838805198669434, eval loss: 0.05258933916323983, lr: 0.00010000000000000003
Epoch 10: train acc: 0.9921875, train loss: 0.024480454471195116, eval acc: 0.9824960231781006, eval loss: 0.05475099367101358, lr: 0.00010000000000000003
Epoch 10: train acc: 0.9904687404632568, train loss: 0.030018166974186897, eval acc: 0.9838805198669434, eval loss: 0.053807159809136314, lr: 0.00010000000000000003
Epoch 10: train acc: 0.9918749928474426, train loss: 0.026334327224176377, eval acc: 0.983188271522522, eval loss: 0.05370580112632317, lr: 0.00010000000000000003


 92%|█████████▏| 11/12 [03:57<00:21, 21.75s/it]

Epoch 11: train acc: 0.009921874850988388, train loss: 0.0003418070077896118, eval acc: 0.9838805198669434, eval loss: 0.05285081607466446, lr: 0.00010000000000000003
Epoch 11: train acc: 0.9900781512260437, train loss: 0.028908225242048502, eval acc: 0.9838805198669434, eval loss: 0.05257432444958323, lr: 0.00010000000000000003
Epoch 11: train acc: 0.990234375, train loss: 0.029238635851070286, eval acc: 0.9839794039726257, eval loss: 0.05254003660505802, lr: 0.00010000000000000003
Epoch 11: train acc: 0.9910937547683716, train loss: 0.029836320751346648, eval acc: 0.9832871556282043, eval loss: 0.0546753067417141, lr: 0.00010000000000000003
Epoch 11: train acc: 0.9914844036102295, train loss: 0.026752150703687222, eval acc: 0.9819027185440063, eval loss: 0.055864433333560636, lr: 0.00010000000000000003


100%|██████████| 12/12 [04:19<00:00, 21.60s/it]


In [8]:
# save the first 5000 images in the test set as numpy array
test_images = []
test_labels = []
for i, (x, y) in enumerate(test_loader):
    test_images.append(x)
    test_labels.append(y)

test_images = torch.cat(test_images).numpy()
test_labels = torch.cat(test_labels).numpy()
np.save("./data/test_images.npy", test_images)
np.save("./data/test_labels.npy", test_labels)

# save model weights as numpy array
for idx, (name, param) in enumerate(best_model.items()):
    name = name.split(".")
    name[0] = str(idx)
    name = ".".join(name)
    np.save(f"./data/linear_{name}.npy", param.detach().cpu().T.numpy()) # W @ X -> X @ W.T, cnpy reads in column-major order
    


In [10]:
param = np.load("./data/linear_0.weight.npy")
print("shape:", param.shape, ",", param.dtype, ", first 3 values (cupy): ", param[:3, 0])

shape: (784, 150) , float32 , first 3 values (cupy):  [-0.01497371 -0.0039914   0.01799681]
