In [15]:
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms

In [None]:
transform = transforms.Compose([transforms.ToTensor()])
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

batch_size = 32
#split train, val, test using train_test_split
from sklearn.model_selection import train_test_split
train_data, val_data = train_test_split(mnist_data, test_size= 0.3, random_state=42)
val_data, test_data = train_test_split(val_data,test_size=0.5, random_state=42)

train_loader_pytorch = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader_pytorch = torch.utils.data.DataLoader(val_data,batch_size=batch_size, shuffle=True)
test_loader_pytorch = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True)



torch.Size([32, 1, 28, 28])
torch.Size([32])


In [30]:
# Baseline model with Convolutional model
class ConvNet(torch.nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(1,32,3,1)
        self.pool1 = torch.nn.MaxPool2d(2,2)
        self.conv2 = torch.nn.Conv2d(32,64,3,1)
        self.pool2 = torch.nn.MaxPool2d(2,2)
        self.fc1 = torch.nn.Linear(1600,128)
        self.fc2 = torch.nn.Linear(128,10)
        self.relu = torch.nn.ReLU()
    def forward(self,X):
        X = self.conv1(X)
        X = self.relu(X)
        X = self.pool1(X)
        X = self.conv2(X)
        X = self.relu(X)
        X = self.pool2(X)
        X = torch.flatten(X,1)
        X = self.fc1(X)
        X = self.relu(X)
        X = self.fc2(X)
        return X


In [None]:
from torch.optim.optimizer import Optimizer

class StochasticArmijoSGD(Optimizer):
    def __init__(self, params, lr=1.0, c=1e-4, tau=0.5, max_backtracks=10):
        """
        Stochastic Armijo line search optimizer.

        Args:
            params: model parameters
            lr: initial step size (alpha_0)
            c: Armijo condition parameter (small, e.g. 1e-4)
            tau: step size reduction factor (0 < tau < 1, e.g. 0.5)
            max_backtracks: max number of backtracking steps
        """
        defaults = dict(lr=lr, c=c, tau=tau, max_backtracks=max_backtracks)
        super(StochasticArmijoSGD, self).__init__(params, defaults)

    def step(self, closure):
        if closure is None:
            raise ValueError("StochasticArmijoSGD requires a closure that returns loss and computes grads.")

        # compute initial loss + grads
        loss = closure(backward=True)
        loss_value = loss.item()

        grads = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    grads.append(p.grad.detach().clone().view(-1))

        if len(grads) == 0:
            raise RuntimeError("No gradients found — check that your closure calls loss.backward() and model parameters require grad.")
        g = torch.cat(grads)
        g_norm_sq = g.dot(g).item()

        # store original params
        orig_params = []
        for group in self.param_groups:
            orig_params.append([p.detach().clone() for p in group['params']])

        # line search
        for group, orig_group_params in zip(self.param_groups, orig_params):
            lr = group['lr']
            c = group['c']
            tau = group['tau']
            max_backtracks = group['max_backtracks']

            step_size = lr
            success = False

            for _ in range(max_backtracks):
                # trial update (safe to do without grad tracking)
                with torch.no_grad():
                    for p, orig_p in zip(group['params'], orig_group_params):
                        if p.grad is None:
                            continue
                        p.data.copy_(orig_p - step_size * p.grad)

                # re-evaluate loss (no backward)
                new_loss = closure(backward=False)

                if new_loss.item() <= loss_value - c * step_size * g_norm_sq:
                    success = True
                    break

                step_size *= tau

            if not success:
                with torch.no_grad():
                    for p, orig_p in zip(group['params'], orig_group_params):
                        p.data.copy_(orig_p)

        return loss

In [45]:
# Training loop
import tqdm
criterion = torch.nn.CrossEntropyLoss()

num_epochs = 20  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for epoch in range(num_epochs):
    for optimizer_name in ["Armijo", "Adam"]:
        #reset model weights
        model = ConvNet()

        if optimizer_name == "Adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        else:
            optimizer = StochasticArmijoSGD(model.parameters(), lr=1.0, c=1e-4, tau=0.5, max_backtracks=10)

        for images, labels in tqdm.tqdm(train_loader_pytorch):
            images, labels = images.to(device), labels.to(device)
            if optimizer_name == "Adam":
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
            else:  # armijo
                def closure(backward=True):
                    optimizer.zero_grad()
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    if backward:
                        loss.backward()
                    return loss
                optimizer.step(closure)
        # validation
        model.eval()
        val_loss = 0
        correct =0 
        total= 0
        with torch.no_grad():
            for images, labels in val_loader_pytorch:
                outputs = model(images)
                loss = criterion(outputs,labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data,1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            val_loss = val_loss / len(val_loader_pytorch)
            val_accuracy = correct / total
        print(f'Epoch [{epoch+1}/{num_epochs}], Optimizer: {optimizer_name}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')



100%|██████████| 1313/1313 [01:17<00:00, 16.88it/s]


Epoch [1/20], Optimizer: Armijo, Val Loss: 2.3095, Val Accuracy: 0.0951


100%|██████████| 1313/1313 [00:22<00:00, 58.10it/s]


Epoch [1/20], Optimizer: Adam, Val Loss: 0.0660, Val Accuracy: 0.9800


100%|██████████| 1313/1313 [01:41<00:00, 12.95it/s]


Epoch [2/20], Optimizer: Armijo, Val Loss: 2.1211, Val Accuracy: 0.3279


100%|██████████| 1313/1313 [00:21<00:00, 61.00it/s]


Epoch [2/20], Optimizer: Adam, Val Loss: 0.0663, Val Accuracy: 0.9801


100%|██████████| 1313/1313 [00:28<00:00, 46.06it/s]


Epoch [3/20], Optimizer: Armijo, Val Loss: 2.3241, Val Accuracy: 0.0951


100%|██████████| 1313/1313 [00:21<00:00, 62.00it/s]


Epoch [3/20], Optimizer: Adam, Val Loss: 0.0622, Val Accuracy: 0.9811


100%|██████████| 1313/1313 [00:35<00:00, 37.39it/s]


Epoch [4/20], Optimizer: Armijo, Val Loss: 2.3079, Val Accuracy: 0.1058


100%|██████████| 1313/1313 [00:21<00:00, 60.02it/s]


Epoch [4/20], Optimizer: Adam, Val Loss: 0.0706, Val Accuracy: 0.9764


100%|██████████| 1313/1313 [00:29<00:00, 45.17it/s]


Epoch [5/20], Optimizer: Armijo, Val Loss: 2.3096, Val Accuracy: 0.1058


100%|██████████| 1313/1313 [00:23<00:00, 56.20it/s]


Epoch [5/20], Optimizer: Adam, Val Loss: 0.0723, Val Accuracy: 0.9779


100%|██████████| 1313/1313 [01:54<00:00, 11.47it/s]


Epoch [6/20], Optimizer: Armijo, Val Loss: 1.8612, Val Accuracy: 0.3147


100%|██████████| 1313/1313 [00:19<00:00, 65.97it/s]


Epoch [6/20], Optimizer: Adam, Val Loss: 0.0849, Val Accuracy: 0.9737


100%|██████████| 1313/1313 [01:16<00:00, 17.05it/s]


Epoch [7/20], Optimizer: Armijo, Val Loss: 2.3109, Val Accuracy: 0.1131


100%|██████████| 1313/1313 [00:20<00:00, 65.25it/s]


Epoch [7/20], Optimizer: Adam, Val Loss: 0.0753, Val Accuracy: 0.9766


100%|██████████| 1313/1313 [00:33<00:00, 39.61it/s]


Epoch [8/20], Optimizer: Armijo, Val Loss: 2.3086, Val Accuracy: 0.1007


100%|██████████| 1313/1313 [00:25<00:00, 52.22it/s]


Epoch [8/20], Optimizer: Adam, Val Loss: 0.0731, Val Accuracy: 0.9782


100%|██████████| 1313/1313 [00:49<00:00, 26.33it/s]


Epoch [9/20], Optimizer: Armijo, Val Loss: 0.7832, Val Accuracy: 0.7789


100%|██████████| 1313/1313 [00:23<00:00, 56.38it/s]


Epoch [9/20], Optimizer: Adam, Val Loss: 0.0628, Val Accuracy: 0.9813


100%|██████████| 1313/1313 [02:11<00:00,  9.99it/s]


Epoch [10/20], Optimizer: Armijo, Val Loss: 2.2201, Val Accuracy: 0.2009


100%|██████████| 1313/1313 [00:30<00:00, 42.83it/s]


Epoch [10/20], Optimizer: Adam, Val Loss: 0.0680, Val Accuracy: 0.9793


100%|██████████| 1313/1313 [00:46<00:00, 28.32it/s]


Epoch [11/20], Optimizer: Armijo, Val Loss: 2.3134, Val Accuracy: 0.1010


100%|██████████| 1313/1313 [00:32<00:00, 40.06it/s]


Epoch [11/20], Optimizer: Adam, Val Loss: 0.0790, Val Accuracy: 0.9757


100%|██████████| 1313/1313 [01:16<00:00, 17.24it/s]


Epoch [12/20], Optimizer: Armijo, Val Loss: 2.3086, Val Accuracy: 0.0951


100%|██████████| 1313/1313 [00:29<00:00, 44.34it/s]


Epoch [12/20], Optimizer: Adam, Val Loss: 0.0747, Val Accuracy: 0.9778


100%|██████████| 1313/1313 [00:43<00:00, 30.03it/s]


Epoch [13/20], Optimizer: Armijo, Val Loss: 2.3040, Val Accuracy: 0.1058


100%|██████████| 1313/1313 [00:31<00:00, 41.23it/s]


Epoch [13/20], Optimizer: Adam, Val Loss: 0.0624, Val Accuracy: 0.9818


100%|██████████| 1313/1313 [00:54<00:00, 24.05it/s]


Epoch [14/20], Optimizer: Armijo, Val Loss: 2.3086, Val Accuracy: 0.1058


100%|██████████| 1313/1313 [00:31<00:00, 41.23it/s]


Epoch [14/20], Optimizer: Adam, Val Loss: 0.0701, Val Accuracy: 0.9787


100%|██████████| 1313/1313 [01:18<00:00, 16.67it/s]


Epoch [15/20], Optimizer: Armijo, Val Loss: 2.3053, Val Accuracy: 0.1058


100%|██████████| 1313/1313 [00:32<00:00, 40.79it/s]


Epoch [15/20], Optimizer: Adam, Val Loss: 0.0700, Val Accuracy: 0.9793


100%|██████████| 1313/1313 [01:05<00:00, 19.92it/s]


Epoch [16/20], Optimizer: Armijo, Val Loss: 2.3100, Val Accuracy: 0.1006


100%|██████████| 1313/1313 [00:31<00:00, 41.39it/s]


Epoch [16/20], Optimizer: Adam, Val Loss: 0.0690, Val Accuracy: 0.9789


100%|██████████| 1313/1313 [01:04<00:00, 20.35it/s]


Epoch [17/20], Optimizer: Armijo, Val Loss: 2.3108, Val Accuracy: 0.1010


100%|██████████| 1313/1313 [00:30<00:00, 43.02it/s]


Epoch [17/20], Optimizer: Adam, Val Loss: 0.0672, Val Accuracy: 0.9803


100%|██████████| 1313/1313 [01:15<00:00, 17.37it/s]


Epoch [18/20], Optimizer: Armijo, Val Loss: 0.8134, Val Accuracy: 0.7444


100%|██████████| 1313/1313 [00:32<00:00, 40.14it/s]


Epoch [18/20], Optimizer: Adam, Val Loss: 0.0613, Val Accuracy: 0.9813


100%|██████████| 1313/1313 [01:54<00:00, 11.46it/s]


Epoch [19/20], Optimizer: Armijo, Val Loss: 2.3097, Val Accuracy: 0.1131


100%|██████████| 1313/1313 [00:31<00:00, 41.76it/s]


Epoch [19/20], Optimizer: Adam, Val Loss: 0.0568, Val Accuracy: 0.9831


100%|██████████| 1313/1313 [02:31<00:00,  8.64it/s]


Epoch [20/20], Optimizer: Armijo, Val Loss: 2.1144, Val Accuracy: 0.2328


100%|██████████| 1313/1313 [00:30<00:00, 42.83it/s]


Epoch [20/20], Optimizer: Adam, Val Loss: 0.0729, Val Accuracy: 0.9780
