In [None]:
from juliacall import Main as jl
import numpy as np

# Seed NumPy's global RNG so that every matrix drawn with np.random in this
# notebook (G, true_W, X and the initial weights W) is identical after a
# Restart & Run All. Only NumPy is seeded here; any randomness inside the
# Julia code is unaffected.
SEED = 0
np.random.seed(SEED)

# Load the Julia source file; the functions called through `jl` further down
# (NewtonSchulz5, LinearPredict, Train, TrainNSAdapt) are expected to come from it.
jl.include("batch_sgd_ns.jl")

## Newton-Schulz

In [None]:
# Random 5x5 input matrix, and the result of the Julia NewtonSchulz5 routine
# applied to it (converted to a NumPy array so it can be used with @ and .T).
G = np.random.rand(5, 5)
ns_result = jl.NewtonSchulz5(G)
O = np.array(ns_result)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Two panels side by side: |M.T @ M| for the input matrix G (left) and for the
# Newton-Schulz output O (right), both drawn on the same 0..1 colour scale.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

panels = [
    (G, "abs(G.T @ G) (Input Matrix)"),
    (O, "abs(O.T @ O) (Output Matrix)"),
]
for ax, (matrix, title) in zip(axes, panels):
    gram = np.abs(matrix.T @ matrix)
    sns.heatmap(gram, annot=True, fmt=".2f", cmap="viridis", vmin=0, vmax=1, ax=ax)
    ax.set_title(title)

plt.show()

## Toy model

In [None]:
# --- Training configuration -------------------------------------------------
N_try = 50             # number of random initialisations the curves are averaged over
N_epochs = 1000
batch_size = 100
learning_rate = 0.005

# --- Problem size -----------------------------------------------------------
input_dim, output_dim = 10, 2
m = 1000               # number of rows in X

# --- Synthetic data ---------------------------------------------------------
# The target is produced by a known linear map `true_W` applied to random inputs,
# so the task is to recover that map from (X, Y).
# A realistic model would stack several layers with non-linearities; extending
# this example in that direction requires implementing backpropagation.
true_W = np.random.normal(loc=0, scale=1, size=(input_dim, output_dim))
X = np.random.rand(m, input_dim)
Y = jl.LinearPredict(true_W, X)

In [None]:
# Per-epoch accumulators for the three optimiser variants. Inside the loop they
# hold running sums over the trials; after the loop they are divided by N_try
# so that each entry is the mean over all trials.
with_newton_schulz_losses = np.zeros(N_epochs)
without_newton_schulz_losses = np.zeros(N_epochs)
adaptive_losses = np.zeros(N_epochs)

with_newton_schulz_times = np.zeros(N_epochs)
without_newton_schulz_times = np.zeros(N_epochs)
adaptive_times = np.zeros(N_epochs)

for _ in range(N_try):
    # One random initialisation per trial, shared by all three variants.
    W = np.random.normal(0, 1, (input_dim, output_dim))

    # Each call gets its own copy of W. NumPy arrays handed to Julia can be
    # wrapped without copying, so an in-place weight update inside a Julia
    # routine would otherwise change W for the calls that follow and the three
    # variants would no longer start from the same point.

    # Boolean flag True: the run accumulated as "with Newton-Schulz".
    _, losses, times = jl.Train(W.copy(), X, Y, N_epochs, batch_size, True, learning_rate)
    with_newton_schulz_losses += np.array(losses)
    with_newton_schulz_times += np.array(times)

    # Boolean flag False: the run accumulated as "without Newton-Schulz".
    _, losses, times = jl.Train(W.copy(), X, Y, N_epochs, batch_size, False, learning_rate)
    without_newton_schulz_losses += np.array(losses)
    without_newton_schulz_times += np.array(times)

    # NOTE(review): 1e-4 is presumably the threshold used by the adaptive
    # variant; confirm its meaning against batch_sgd_ns.jl.
    _, losses, times = jl.TrainNSAdapt(W.copy(), X, Y, N_epochs, batch_size, 1e-4, learning_rate)
    adaptive_losses += np.array(losses)
    adaptive_times += np.array(times)

# Turn the running sums into means over the trials.
with_newton_schulz_losses /= N_try
without_newton_schulz_losses /= N_try
adaptive_losses /= N_try

with_newton_schulz_times /= N_try
without_newton_schulz_times /= N_try
adaptive_times /= N_try

In [None]:
import matplotlib.pyplot as plt

# Mean training loss per epoch for the three variants, on a logarithmic y-axis.
loss_curves = [
    (with_newton_schulz_losses, "With Newton-Schulz", "#6196FF"),
    (without_newton_schulz_losses, "Without Newton-Schulz", "#FF6196"),
    (adaptive_losses, "Adaptive Newton-Schulz", "#68FD5D"),
]
for curve, label, color in loss_curves:
    plt.plot(curve, label=label, color=color)

plt.yscale("log")
plt.xlabel("Epoch")
plt.ylabel("L2 Loss")
plt.title("Training Loss over Epochs")
plt.grid(alpha=0.3)
plt.legend()
plt.tight_layout()

# Uncomment to export the figure (savefig has to run before plt.show()):
# plt.savefig("training_loss_comparison.pdf")

plt.show()

In [None]:
import matplotlib.pyplot as plt

# smooth out time plots
def smooth(data, window_size=20):
    """Return the moving average of `data` over `window_size` consecutive samples.

    Only complete windows are averaged (``mode="valid"``), so the result has
    ``len(data) - window + 1`` entries, where ``window`` is `window_size`
    clamped to the range ``1 .. len(data)``.

    The clamp matters for short series: when the kernel is longer than the
    data, ``np.convolve`` swaps its two operands and the output is no longer a
    moving average of `data`. With the clamp, such a call returns the mean of
    the whole series instead, and a non-positive `window_size` returns the
    data unsmoothed.

    Args:
        data: 1-D sequence of numbers (list or NumPy array).
        window_size: Number of samples per averaging window.

    Returns:
        1-D NumPy array of window means.

    Raises:
        ValueError: If `data` is empty (raised by ``np.convolve``).
    """
    window = max(1, min(window_size, len(data)))
    kernel = np.ones(window) / window
    return np.convolve(data, kernel, mode="valid")

# Smoothed mean time per epoch for the three variants. The raw values are
# divided by 1e6 to match the "Time (ms)" axis label
# (assumes the Julia side reports nanoseconds; TODO confirm).
timing_curves = [
    (with_newton_schulz_times, "With Newton-Schulz", "#6196FF"),
    (without_newton_schulz_times, "Without Newton-Schulz", "#FF6196"),
    (adaptive_times, "Adaptive Newton-Schulz", "#68FD5D"),
]
for raw_times, label, color in timing_curves:
    plt.plot(smooth(raw_times) / 1e6, label=label, color=color)

plt.xlabel("Epoch")
plt.ylabel("Time (ms)")
plt.title("Performance Comparison - Time per Epoch")
plt.grid(alpha=0.3)
plt.legend()
plt.tight_layout()

# Uncomment to export the figure (savefig has to run before plt.show()):
# plt.savefig("training_time_comparison.pdf")

plt.show()

### Sample prediction

In [None]:
# Train a single model with the adaptive variant from a fresh random start.
# The sixth and seventh arguments (1e-4 and learning_rate) are passed
# explicitly so this run uses the same settings as the adaptive runs in the
# benchmark, instead of whatever defaults the Julia function may have.
W = np.random.normal(0, 1, (input_dim, output_dim))
W, *_ = jl.TrainNSAdapt(W, X, Y, N_epochs, batch_size, 1e-4, learning_rate)
# Keep only the first returned value (the trained weights) as a NumPy array.
W = np.array(W)

In [None]:
# Show the trained model's output on the training inputs next to the targets.
predicted = jl.LinearPredict(W, X)
print(f"Prediction: {predicted}")
print(f"Target: {Y}")