In [1]:
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.autograd import Function
from typing import List, Optional, Tuple, Union
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
import matplotlib.pyplot as plt
import numpy as np

In [2]:
class Lion(Optimizer):
  """Lion optimizer: sign-based updates driven by an interpolated momentum.

  For every parameter that has a gradient, one call to `step` performs:

    1. decoupled weight decay:  p <- p * (1 - lr * weight_decay)
    2. update direction:        u <- sign(beta1 * exp_avg + (1 - beta1) * grad)
    3. parameter update:        p <- p - lr * u
    4. momentum update:         exp_avg <- beta2 * exp_avg + (1 - beta2) * grad

  Args:
    params: iterable of parameters or parameter-group dicts to optimize.
    lr: learning rate (must be >= 0).
    betas: pair (beta1, beta2); beta1 weights the momentum when forming the
      update direction, beta2 is the decay rate of the stored momentum.
      Both must lie in [0, 1].
    weight_decay: decoupled weight-decay coefficient (must be >= 0).

  Raises:
    ValueError: if any hyperparameter is outside its valid range.
  """

  def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
    # Fail at construction time instead of silently diverging (negative lr) or
    # crashing later inside step() (betas that cannot be unpacked into a pair).
    if lr < 0.0:
      raise ValueError(f"Invalid learning rate: {lr}")
    if len(betas) != 2 or not all(0.0 <= beta <= 1.0 for beta in betas):
      raise ValueError(f"Invalid betas (expected two values in [0, 1]): {betas}")
    if weight_decay < 0.0:
      raise ValueError(f"Invalid weight_decay value: {weight_decay}")

    defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
    super().__init__(params, defaults)

  @torch.no_grad()
  def step(self, closure=None):
    """Perform a single optimization step.

    Args:
      closure: optional callable that re-evaluates the model and returns the
        loss. It is invoked with gradients enabled.

    Returns:
      The value returned by `closure`, or None when no closure is given.
    """
    loss = None
    if closure is not None:
      with torch.enable_grad():
        loss = closure()

    for group in self.param_groups:
      # Hyperparameters are constant within a group; read them once.
      lr = group['lr']
      weight_decay = group['weight_decay']
      beta1, beta2 = group['betas']

      for p in group['params']:
        if p.grad is None:
          continue
        grad = p.grad

        # Decoupled weight decay: shrink the weights directly rather than
        # adding a penalty term to the gradient.
        p.mul_(1 - lr * weight_decay)

        # Lazily create the momentum buffer the first time a parameter is seen.
        state = self.state[p]
        if len(state) == 0:
          state['exp_avg'] = torch.zeros_like(p)
        exp_avg = state['exp_avg']

        # `update` is a fresh tensor, so the in-place sign_() does not touch
        # the stored momentum.
        update = exp_avg * beta1 + grad * (1 - beta1)
        p.add_(update.sign_(), alpha=-lr)

        # The stored momentum is refreshed with beta2 only after it has been
        # used (with beta1) to build the update above.
        exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

    return loss

In [3]:
class SimpleRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain.

    Each vector along the last axis is divided by `rms + eps`, where `rms` is
    the square root of the mean of its squared entries, and then multiplied
    element-wise by `weight` (initialised to ones).

    Args:
        dim: size of the last dimension of the input; length of `weight`.
        eps: constant added to the RMS (outside the square root) before dividing.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return `x / (rms(x) + eps) * weight`, with rms taken over the last axis."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        denominator = mean_square.sqrt() + self.eps
        return (x / denominator) * self.weight

def compare_rmsnorm():
    """Print how far SimpleRMSNorm's output is from nn.RMSNorm's on random data.

    A (32, 128, 512) standard-normal tensor is generated, moved to the
    module-level `device`, and passed through both layers (each built with its
    default settings for a 512-wide last dimension). The maximum and mean
    absolute element-wise differences are printed.
    """
    shape = (32, 128, 512)  # (batch_size, seq_len, hidden_dim)
    hidden_dim = shape[-1]

    # Sample first, then move, exactly as before.
    x = torch.randn(*shape).to(device)

    custom_layer = SimpleRMSNorm(hidden_dim).to(device)
    reference_layer = nn.RMSNorm(hidden_dim).to(device)

    abs_diff = torch.abs(custom_layer(x) - reference_layer(x))
    max_diff = abs_diff.max()
    mean_diff = abs_diff.mean()

    print(f"Maximum difference between outputs: {max_diff:.6f}")
    print(f"Mean difference between outputs: {mean_diff:.6f}")

# Select the compute device: CUDA when torch reports it as available, else CPU.
# `compare_rmsnorm` reads this module-level `device`, so it must be assigned
# before the call below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Run the SimpleRMSNorm vs. nn.RMSNorm comparison and print the differences.
compare_rmsnorm()

Using device: cuda
Maximum difference between outputs: 0.000005
Mean difference between outputs: 0.000001


In [4]:
class ExpCosFunction(Function):
    """Custom autograd function for f(x, y) = exp(x) + cos(y).

    The backward pass implements the analytic partial derivatives
    df/dx = exp(x) and df/dy = -sin(y).
    """

    @staticmethod
    def forward(ctx, x, y):
        """Compute exp(x) + cos(y) and keep both inputs for the backward pass."""
        ctx.save_for_backward(x, y)
        return x.exp() + y.cos()

    @staticmethod
    def backward(ctx, grad_output):
        """Chain `grad_output` through the partial derivatives w.r.t. x and y."""
        saved_x, saved_y = ctx.saved_tensors

        d_x = grad_output * saved_x.exp()
        d_y = grad_output * saved_y.sin().neg()
        return d_x, d_y

def compute_function_custom(x, y):
    """Evaluate exp(x) + cos(y) and its gradients through ExpCosFunction.

    `x` and `y` are converted to float32 tensors that track gradients. The
    result must be a single-element tensor, since argument-free `backward()`
    and `.item()` are used.

    Returns:
        dict with keys 'value', 'grad_x' and 'grad_y' holding Python floats.
    """
    x_tensor, y_tensor = (
        torch.tensor(raw, dtype=torch.float32, requires_grad=True)
        for raw in (x, y)
    )

    result = ExpCosFunction.apply(x_tensor, y_tensor)
    result.backward()  # fills x_tensor.grad and y_tensor.grad

    summary = {'value': result.item()}
    summary['grad_x'] = x_tensor.grad.item()
    summary['grad_y'] = y_tensor.grad.item()
    return summary

def compute_function_torch(x, y):
    """Evaluate exp(x) + cos(y) and its gradients using built-in torch ops.

    `x` and `y` are converted to float32 tensors that track gradients. The
    result must be a single-element tensor, since argument-free `backward()`
    and `.item()` are used.

    Returns:
        dict with keys 'value', 'grad_x' and 'grad_y' holding Python floats.
    """
    x_tensor, y_tensor = (
        torch.tensor(raw, dtype=torch.float32, requires_grad=True)
        for raw in (x, y)
    )

    result = x_tensor.exp() + y_tensor.cos()
    result.backward()  # fills x_tensor.grad and y_tensor.grad

    summary = {'value': result.item()}
    summary['grad_x'] = x_tensor.grad.item()
    summary['grad_y'] = y_tensor.grad.item()
    return summary

def compare_implementations(x_val, y_val):
    """Print the custom and built-in results for one (x, y) input side by side.

    Both `compute_function_custom` and `compute_function_torch` are evaluated
    at the given point; for the function value and each gradient the two
    numbers and their absolute difference are printed.
    """
    custom_result = compute_function_custom(x_val, y_val)
    torch_result = compute_function_torch(x_val, y_val)

    def gap(key):
        # Absolute disagreement between the two implementations for one entry.
        return abs(custom_result[key] - torch_result[key])

    print(f"\nComparing implementations for x={x_val}, y={y_val}:")
    print("\nFunction values:")
    print(f"Custom implementation: {custom_result['value']}")
    print(f"PyTorch implementation: {torch_result['value']}")
    print(f"Difference: {gap('value')}")

    # The two gradient sections share one layout; only the variable name differs.
    for var in ("x", "y"):
        key = f"grad_{var}"
        print(f"\ngradients with respect to {var}:")
        print(f"custom implementation: {custom_result[key]}")
        print(f"pytorch implementation: {torch_result[key]}")
        print(f"difference: {gap(key)}")

# (x, y) input pairs at which the custom and built-in implementations are compared.
test_cases = [
  (1.0, torch.pi/4),
]

# Print a comparison report for every pair, followed by a 50-character '=' rule.
for x_val, y_val in test_cases:
    compare_implementations(x_val, y_val)
    print("\n" + "="*50 + "\n")



Comparing implementations for x=1.0, y=0.7853981633974483:

Function values:
Custom implementation: 3.4253885746002197
PyTorch implementation: 3.4253885746002197
Difference: 0.0

gradients with respect to x:
custom implementation: 2.7182817459106445
pytorch implementation: 2.7182817459106445
difference: 0.0

gradients with respect to y:
custom implementation: -0.7071067690849304
pytorch implementation: -0.7071067690849304
difference: 0.0




In [5]:
class SimpleCNN(nn.Module):
    """Two-convolution classifier that outputs log-probabilities over 10 classes.

    Layout: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU ->
    2x2 max-pool -> dropout(0.25) -> flatten -> linear(9216->128) -> ReLU ->
    dropout(0.5) -> linear(128->10) -> log-softmax.

    The first linear layer expects 9216 flattened features, which is what the
    convolution/pooling stack yields for single-channel 28x28 inputs
    (64 channels * 12 * 12).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities (shape: batch x 10)."""
        # Convolutional feature extractor.
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))

        # Classifier head on the flattened feature maps.
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)

def train(model, device, train_loader, optimizer, epoch):
    """Run one training pass over `train_loader` and return the mean batch loss.

    The model is switched to training mode. For each batch the gradients are
    cleared, the negative log-likelihood loss of the model output is
    back-propagated, and the optimizer takes one step. Progress is printed on
    every 100th batch (starting with batch 0).

    Args:
        model: module whose output is fed to `F.nll_loss`.
        device: device the batches are moved to.
        train_loader: iterable of (data, target) batches that supports `len()`
            and exposes a `.dataset`.
        optimizer: optimizer that updates the model's parameters.
        epoch: epoch number, used only in the progress message.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Only every 100th batch is reported.
        if batch_idx % 100 != 0:
            continue
        print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
              f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

    return running_loss / len(train_loader)

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, '
          f'Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)\n')
    return test_loss, accuracy

def plot_metrics(lion_metrics, adam_metrics, save_path='comparison.png'):
    """Save a two-panel figure comparing Lion and Adam training histories.

    The left panel shows the per-epoch training loss, the right panel the
    per-epoch test accuracy; Lion is drawn in blue and Adam in red. The figure
    is written to `save_path` and then closed.

    Args:
        lion_metrics: dict with 'train_loss' and 'test_acc' sequences for Lion.
        adam_metrics: dict with 'train_loss' and 'test_acc' sequences for Adam.
        save_path: destination of the saved image.
    """
    epochs = range(1, len(lion_metrics['train_loss']) + 1)

    # (metrics key, panel title, y-axis label) for each panel, left to right.
    panels = (
        ('train_loss', 'Training Loss', 'Loss'),
        ('test_acc', 'Test Accuracy', 'Accuracy (%)'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (key, title, y_label) in zip(axes, panels):
        ax.plot(epochs, lion_metrics[key], 'b-', label='Lion')
        ax.plot(epochs, adam_metrics[key], 'r-', label='Adam')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True)

    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)

def train_with_optimizer(optimizer_name, model, device, train_loader, test_loader, epochs):
    """Train `model` for `epochs` epochs and collect per-epoch metrics.

    Args:
        optimizer_name: 'Lion' selects the Lion optimizer; any other value
            selects Adam. Both use a learning rate of 1e-4.
        model: network to train (updated in place).
        device: device passed through to `train` and `test`.
        train_loader: loader used for the training pass of each epoch.
        test_loader: loader used for the evaluation pass of each epoch.
        epochs: number of epochs to run.

    Returns:
        dict with lists 'train_loss', 'test_loss' and 'test_acc', one entry
        per epoch.
    """
    optimizer_cls = Lion if optimizer_name == 'Lion' else Adam
    optimizer = optimizer_cls(model.parameters(), lr=1e-4)

    history = {'train_loss': [], 'test_loss': [], 'test_acc': []}

    for epoch in range(1, epochs + 1):
        epoch_train_loss = train(model, device, train_loader, optimizer, epoch)
        epoch_test_loss, epoch_test_acc = test(model, device, test_loader)

        history['train_loss'].append(epoch_train_loss)
        history['test_loss'].append(epoch_test_loss)
        history['test_acc'].append(epoch_test_acc)

    return history

In [6]:
# Experiment: train the same CNN on MNIST once with Lion and once with Adam,
# then plot and summarise the two runs.

# Seed applied before each run so that both optimizers start from identical
# initial weights and see the same shuffling and dropout randomness; without
# it the comparison mixes optimizer effects with initialisation noise and the
# results change on every execution.
SEED = 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Convert images to tensors and standardise them.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST training-set mean and
# standard deviation - confirm if the dataset changes.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# download=True on both splits so the test split does not depend on the
# training split having been fetched first.
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000)

epochs = 10


print("\nTraining with Lion optimizer...")
torch.manual_seed(SEED)  # reset RNG: same init/shuffle/dropout stream as the Adam run
lion_model = SimpleCNN().to(device)
lion_metrics = train_with_optimizer('Lion', lion_model, device, train_loader, test_loader, epochs)


print("\nTraining with Adam optimizer...")
torch.manual_seed(SEED)  # reset RNG: same init/shuffle/dropout stream as the Lion run
adam_model = SimpleCNN().to(device)
adam_metrics = train_with_optimizer('Adam', adam_model, device, train_loader, test_loader, epochs)


plot_metrics(lion_metrics, adam_metrics)
print("\nComparison plot saved as 'comparison.png'")


print("\nFinal Results:")
print(f"Lion - Best accuracy: {max(lion_metrics['test_acc']):.2f}%")
print(f"Adam - Best accuracy: {max(adam_metrics['test_acc']):.2f}%")

Using device: cuda


100%|██████████| 9.91M/9.91M [00:00<00:00, 12.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 340kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.18MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.79MB/s]



Training with Lion optimizer...

Test set: Average loss: 0.0482, Accuracy: 9850/10000 (98.50%)


Test set: Average loss: 0.0371, Accuracy: 9873/10000 (98.73%)


Test set: Average loss: 0.0308, Accuracy: 9895/10000 (98.95%)


Test set: Average loss: 0.0272, Accuracy: 9909/10000 (99.09%)


Test set: Average loss: 0.0336, Accuracy: 9898/10000 (98.98%)


Test set: Average loss: 0.0302, Accuracy: 9914/10000 (99.14%)


Test set: Average loss: 0.0316, Accuracy: 9906/10000 (99.06%)


Test set: Average loss: 0.0311, Accuracy: 9924/10000 (99.24%)


Test set: Average loss: 0.0363, Accuracy: 9909/10000 (99.09%)


Test set: Average loss: 0.0327, Accuracy: 9914/10000 (99.14%)


Training with Adam optimizer...

Test set: Average loss: 0.1067, Accuracy: 9685/10000 (96.85%)


Test set: Average loss: 0.0637, Accuracy: 9798/10000 (97.98%)


Test set: Average loss: 0.0527, Accuracy: 9825/10000 (98.25%)


Test set: Average loss: 0.0422, Accuracy: 9854/10000 (98.54%)


Test set: Average loss: 0.0372, Accur