In [1]:
#FashionMNIST using LARS (Layer wise adaptive weight scaling)

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [3]:
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128,10)

        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x
        

In [4]:
class LARSOptimizer(optim.Optimizer):
    def __init__(self, params, lr, weight_decay=0, trust_coef=0.001):
        defaults = {'lr': lr, 'weight_decay': weight_decay, 'trust_coef': trust_coef}
        super(LARSOptimizer, self).__init__(params, defaults)

    def step(self):
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            trust_coef = group['trust_coef']

            for p in group['params']:
                if p.grad is None:
                    continue

                global_lr = group['lr']
                weight_norm = torch.norm(p.data, p=2)
                grad_norm = torch.norm(p.grad.data, p=2)
                local_lr = trust_coef * weight_norm / (grad_norm + weight_decay * weight_norm)
                adjusted_lr = global_lr * local_lr

                p.grad.data += weight_decay * p.data
                p.data -= adjusted_lr * p.grad.data

In [5]:
transform = transforms.ToTensor()

In [6]:
trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)



Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|███████████████████████████| 26421880/26421880 [01:52<00:00, 233956.15it/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|█████████████████████████████████| 29515/29515 [00:00<00:00, 149264.48it/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|█████████████████████████████| 4422102/4422102 [00:13<00:00, 322153.17it/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|█████████████████████████████████| 5148/5148 [00:00<00:00, 28523483.48it/s]

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw






In [7]:
trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)

testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=32, shuffle=False)

# Initialize model and optimizer
model = SimpleMLP()
optimizer = LARSOptimizer(model.parameters(), lr=0.01, weight_decay=1e-4)



In [10]:
criterion = nn.CrossEntropyLoss()
num_epochs = 8
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(trainloader):
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

Epoch [1/8], Loss: 1.0580
Epoch [2/8], Loss: 1.1794
Epoch [3/8], Loss: 1.0322
Epoch [4/8], Loss: 0.9201
Epoch [5/8], Loss: 1.2918
Epoch [6/8], Loss: 0.9122
Epoch [7/8], Loss: 0.8917
Epoch [8/8], Loss: 1.0934


In [None]:
num_epochs = 8
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(trainloader):
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)