In [49]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0), (1))])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset = train_dataset, batch_size = 64, shuffle = True)
test_loader = DataLoader(dataset = test_dataset, batch_size = 64, shuffle = False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 7824421.16it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 9371475.62it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 6604035.63it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4557542.77it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [82]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Iterable
from torch.nn import Parameter

class BasicNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 128, bias=False),
            nn.GELU(),
            # nn.BatchNorm1d(128),
            nn.Linear(128, 64, bias=False),
            nn.GELU(),
            # nn.BatchNorm1d(64),
            nn.Linear(64, 10, bias=False)
        )
        
    def forward(self, x: torch.Tensor):
        return self.model(x)
    
# Claude's function
def matrix_power_neg_quarter(A, epsilon=1e-6):
    """
    Compute A^(-1/4) for a symmetric positive definite matrix A.
    """
    # Compute eigendecomposition
    eigenvalues, eigenvectors = torch.linalg.eigh(A)
    
    # Compute powered eigenvalues
    powered_eigenvalues = 1.0 / torch.pow(torch.clamp(eigenvalues, min=epsilon), 0.25)
    
    # Reconstruct the matrix
    return eigenvectors @ torch.diag(powered_eigenvalues) @ eigenvectors.t()

class ShampooOptimizer(torch.optim.Optimizer):
    def __init__(self, parameters: Iterable[Parameter], lr=0.001, betas=(0.9, 0.999), epsilon=1e-8):
        self.params = list(parameters)
        self.lr = lr
        defaults = dict(lr=lr, betas=betas, epsilon=epsilon)
        super(ShampooOptimizer, self).__init__(self.params, defaults)
        
        # assumes 2D everything. Don't want to even think about 3D

        # self.G_hat = [torch.zeros(p.shape) for p in self.params] # d1 x d2
        self.L = [epsilon * torch.eye(int(p.shape[0])) for p in self.params] # d1 x d1
        # self.L_tilda = [torch.zeros((p.shape[0], p.shape[0])) for p in self.params] # d1 x d1
        
        self.R = [epsilon * torch.eye(int(p.shape[1])) for p in self.params] # d2 x d2
        # self.R_tilda = [torch.zeros((p.shape[1], p.shape[1])) for p in self.params] # d2 x d2
        
        # self.M = [epsilon * torch.zeros(p.shape) for p in self.params] # d1 x d2
        
    def step(self, closure = None):
        loss = None
        if closure is not None:
            loss = closure()
            
    
        # I'm treating weights and bias seperately 
        L = self.L
        R = self.R
        for i, param in enumerate(self.params):
            if param.grad is None:
                continue
            L[i].add_(torch.matmul(param.grad, param.grad.T))
            R[i].add_(torch.matmul(param.grad.T, param.grad))
            
            D = torch.matmul(torch.matmul(matrix_power_neg_quarter(L[i]), param.grad), matrix_power_neg_quarter(R[i]))
            
            param.data.sub_(D, alpha=self.lr)
            
        return loss

In [87]:
import time
model = BasicNN()
loss_fn = F.cross_entropy
optimizer = ShampooOptimizer(model.parameters(), lr=0.1)

total_loss = 0
last_time = time.time()
for idx, (x, label) in enumerate(train_loader):
    batch_size = x.shape[0]
    logits = model(x.view(batch_size, -1))
    logits = F.softmax(logits, dim=-1)

    loss = loss_fn(logits, label)
    total_loss += loss.item()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if idx % 100 == 0 and idx != 0:
        print(f"average loss at index {idx}: {total_loss/100} | took {time.time()-last_time} seconds")
        last_time = time.time()
        total_loss = 0

average loss at index 100: 1.7933749449253082 | took 12.877652406692505 seconds
average loss at index 200: 1.5970934927463531 | took 11.907608985900879 seconds
average loss at index 300: 1.5768053781986238 | took 11.973227739334106 seconds
average loss at index 400: 1.5609676575660705 | took 12.131686687469482 seconds
average loss at index 500: 1.5553697979450225 | took 12.106920957565308 seconds
average loss at index 600: 1.552356003522873 | took 12.12163758277893 seconds
average loss at index 700: 1.5471753704547881 | took 12.13194990158081 seconds
average loss at index 800: 1.5423400354385377 | took 12.200146436691284 seconds
average loss at index 900: 1.5334815895557403 | took 12.112584829330444 seconds


In [91]:
correct = 0
total = 0

model.eval()
with torch.no_grad():
    for idx, (img, label) in enumerate(test_loader):
        batch_size = img.shape[0]
        logits = model(img.view(batch_size, -1))
        logits = F.softmax(logits, dim=-1)
        
        correct += (logits.argmax(dim=-1) == label).sum()
        total += label.numel()



print(f"total accuracy: {correct/total}")
# It does poorly because of improper backward pass implementation

total accuracy: 0.9404000043869019


In [25]:
torch.optim.Adam([torch.randn(3,3)])

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)