In [8]:
import torch
import torch.nn.functional as F

class SAGA:
    def __init__(self, X, y, loss_fn, grad_loss_fn, lr=0.1, epochs=100):
        self.X = X
        self.y = y
        self.loss_fn = loss_fn
        self.grad_loss_fn = grad_loss_fn
        self.lr = lr
        self.epochs = epochs
        self.n_samples, self.n_features = X.shape
        self.w = torch.zeros(self.n_features, dtype=torch.float32, requires_grad=True)
        self.grad_store = torch.zeros((self.n_samples, self.n_features), dtype=torch.float32)

    def fit(self):
        for epoch in range(self.epochs):
            indices = torch.randperm(self.n_samples)
            avg_grad_w = torch.mean(self.grad_store, dim=0)
            
            for i in indices:
                xi = self.X[i]
                yi = self.y[i]
                
                pred = torch.dot(xi, self.w) 
                
                grad_w = self.grad_loss_fn(pred, yi, xi)
                
                new_w = self.w - self.lr * (grad_w - self.grad_store[i, :] + avg_grad_w)
                
                self.w = new_w
                
                self.grad_store[i, :] = grad_w.detach()

            # Compute and display the loss for the current epoch
            current_loss = self.compute_loss()
            print(f"Epoch {epoch + 1}/{self.epochs}, Loss: {current_loss:.4f}")

    def compute_loss(self):
        preds = self.X @ self.w 
        return self.loss_fn(preds, self.y).item()

    def predict(self, X):
        return X @ self.w 

In [10]:
# Generate some synthetic data
torch.manual_seed(42)
X = torch.randn(100, 2)
y = (torch.sigmoid(X[:, 0] * 2 - X[:, 1] * 3) > 0.5).float()

# Define the log-sum-exp loss function for squared residuals
def log_sum_exp_loss(preds, targets):
    squared_residuals = (preds - targets) ** 2
    return torch.logsumexp(squared_residuals, dim=0)

# Define the gradient of the log-sum-exp loss function for squared residuals
def log_sum_exp_grad(pred, y, x):
    squared_residual = (pred - y) ** 2
    exp_term = torch.exp(squared_residual)
    grad_common_term = (2 * (pred - y) * exp_term) / torch.sum(exp_term)
    grad_w = grad_common_term * x
    return grad_w



# Initialize and train SAGA
saga = SAGA(X, y, log_sum_exp_loss, log_sum_exp_grad, lr=0.001, epochs=100)
saga.fit()

# Predict on new data
X_test = torch.randn(10, 2)
predictions = saga.predict(X_test)
print("Predictions:", predictions)

Epoch 1/100, Loss: 5.1099
Epoch 2/100, Loss: 5.0893
Epoch 3/100, Loss: 5.0737
Epoch 4/100, Loss: 5.0624
Epoch 5/100, Loss: 5.0538
Epoch 6/100, Loss: 5.0472
Epoch 7/100, Loss: 5.0421
Epoch 8/100, Loss: 5.0383
Epoch 9/100, Loss: 5.0352
Epoch 10/100, Loss: 5.0329
Epoch 11/100, Loss: 5.0310
Epoch 12/100, Loss: 5.0295
Epoch 13/100, Loss: 5.0283
Epoch 14/100, Loss: 5.0274
Epoch 15/100, Loss: 5.0266
Epoch 16/100, Loss: 5.0259
Epoch 17/100, Loss: 5.0254
Epoch 18/100, Loss: 5.0250
Epoch 19/100, Loss: 5.0247
Epoch 20/100, Loss: 5.0244
Epoch 21/100, Loss: 5.0241
Epoch 22/100, Loss: 5.0239
Epoch 23/100, Loss: 5.0238
Epoch 24/100, Loss: 5.0236
Epoch 25/100, Loss: 5.0235
Epoch 26/100, Loss: 5.0234
Epoch 27/100, Loss: 5.0233
Epoch 28/100, Loss: 5.0232
Epoch 29/100, Loss: 5.0232
Epoch 30/100, Loss: 5.0231
Epoch 31/100, Loss: 5.0231
Epoch 32/100, Loss: 5.0231
Epoch 33/100, Loss: 5.0230
Epoch 34/100, Loss: 5.0230
Epoch 35/100, Loss: 5.0230
Epoch 36/100, Loss: 5.0230
Epoch 37/100, Loss: 5.0230
Epoch 38/1