In [1]:
import numpy as np

def empirical_risk(w, b, X, y):
    """Return the empirical risk (mean squared error) of a linear model.

    The model predicts ``X.dot(w) + b``; the risk is the mean of the squared
    residuals ``y - prediction`` over all samples.

    Parameters
    ----------
    w : ndarray, shape (d,)
        Weight vector.
    b : float
        Bias / intercept, added to every prediction.
    X : ndarray, shape (n, d)
        Design matrix, one sample per row.
    y : ndarray, shape (n,)
        Targets. Must be 1-D: a column vector of shape (n, 1) would
        broadcast against the (n,) predictions and give a wrong result.

    Returns
    -------
    numpy.float64
        ``mean((y - (X.dot(w) + b)) ** 2)``.
    """
    predictions = X.dot(w) + b
    residuals = y - predictions
    return np.mean(residuals ** 2)

def gradient_descent(X, y, C, learning_rate=0.01, num_iterations=1000,
                     log_every=100, rng=None):
    """Fit a linear model by projected gradient descent on the mean squared error.

    Minimises ``mean((y - (X.dot(w) + b)) ** 2)`` subject to ``||w||_2 <= C``.
    Each iteration takes a plain gradient step on ``(w, b)`` and then projects
    ``w`` back onto the L2 ball of radius ``C``. The bias ``b`` is not
    constrained.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Design matrix, one sample per row.
    y : ndarray, shape (n,)
        Targets (1-D).
    C : float
        Radius of the L2 ball the weight vector is confined to. Must be >= 0.
    learning_rate : float, default 0.01
        Step size of each gradient update.
    num_iterations : int, default 1000
        Number of gradient steps.
    log_every : int or None, default 100
        Print the empirical risk every ``log_every`` iterations. The risk is
        measured after that iteration's update and projection. ``None`` or
        ``0`` disables logging.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of the random initial weights. When ``None`` (the default),
        the legacy global NumPy RNG (``np.random.randn``) is used, so results
        follow ``np.random.seed``.

    Returns
    -------
    w : ndarray, shape (d,)
        Learned weights, with ``||w||_2 <= C`` up to floating-point rounding.
    b : float
        Learned bias.

    Raises
    ------
    ValueError
        If ``C`` is negative. A negative radius would otherwise silently
        flip the sign of the weights on every projection.
    """
    if C < 0:
        raise ValueError(f"C must be non-negative, got {C}")

    n, d = X.shape
    w = np.random.randn(d) if rng is None else rng.standard_normal(d)
    b = 0.0
    grad_scale = -2 / n  # derivative of the mean of squared residuals

    for i in range(num_iterations):
        predictions = X.dot(w) + b
        residuals = y - predictions

        w_grad = grad_scale * X.T.dot(residuals)
        b_grad = grad_scale * np.sum(residuals)

        # Plain gradient step on both parameters.
        w -= learning_rate * w_grad
        b -= learning_rate * b_grad

        # Project w back onto the L2 ball of radius C (norm computed once).
        w_norm = np.linalg.norm(w)
        if w_norm > C:
            w = w / w_norm * C

        if log_every and i % log_every == 0:
            print(f'Iteration {i}, Empirical Risk: {empirical_risk(w, b, X, y)}')

    return w, b

# Example usage: recover a known linear model from noisy synthetic data.
if __name__ == "__main__":
    # Seed the legacy global RNG before any draws: it drives the synthetic
    # data below and also the random weight initialisation inside
    # gradient_descent, so the printed numbers are reproducible.
    np.random.seed(0)

    # Ground-truth parameters of the data-generating model.
    true_w = np.array([1.5, -2.0, 1.0])
    true_b = 0.5

    # 100 samples with 3 standard-normal features, plus Gaussian noise of std 0.5.
    X = np.random.randn(100, 3)
    noise = 0.5 * np.random.randn(100)
    y = X.dot(true_w) + true_b + noise

    # Radius of the L2 ball the learned weight vector is confined to.
    C = 5.0

    # Empirical risk minimisation over the norm-bounded linear function class.
    w, b = gradient_descent(X, y, C)

    print(
        f'Learned weights: {w}',
        f'Learned bias: {b}',
        f'Final Empirical Risk: {empirical_risk(w, b, X, y)}',
        sep='\n',
    )


Iteration 0, Empirical Risk: 5.997800192536225
Iteration 100, Empirical Risk: 0.27687754673180687
Iteration 200, Empirical Risk: 0.21536728196419652
Iteration 300, Empirical Risk: 0.2146004869876152
Iteration 400, Empirical Risk: 0.2145876318735173
Iteration 500, Empirical Risk: 0.2145873108459609
Iteration 600, Empirical Risk: 0.21458729983355457
Iteration 700, Empirical Risk: 0.21458729938647395
Iteration 800, Empirical Risk: 0.21458729936689303
Iteration 900, Empirical Risk: 0.21458729936600573
Learned weights: [ 1.47897143 -2.02529826  1.02072218]
Learned bias: 0.4089768375447561
Final Empirical Risk: 0.21458729936596502
