In [8]:
import numpy as np
import torch
import plotly.graph_objects as go

In [9]:
# 1. Generate the dataset
# Noisy samples of sin(x), with x drawn uniformly from [-2*pi, 2*pi).
# The legacy global NumPy RNG is seeded so the draws below repeat on every run;
# X must be drawn before epsilon to keep the same random stream.
np.random.seed(5)
n_samples = 100
half_range = 2 * np.pi
X = np.random.uniform(low=-half_range, high=half_range, size=n_samples)
epsilon = np.random.normal(loc=0, scale=0.1, size=n_samples)  # additive Gaussian noise
y = np.sin(X) + epsilon

# Convert to PyTorch tensors
# Cast the float64 arrays to float32 (this makes a copy) and add a trailing
# axis so both tensors are column vectors of shape (n_samples, 1).
X_tensor = torch.from_numpy(X).to(torch.float32).unsqueeze(1)
y_tensor = torch.from_numpy(y).to(torch.float32).unsqueeze(1)

In [10]:
# 2. Initialize parameters for linear model
# np.random.seed does not seed torch, so the starting point is drawn from a
# dedicated, explicitly seeded generator. This makes the initial (w, b), the
# loss curve and the final fit repeat on every Restart & Run All, without
# touching torch's global RNG state.
init_generator = torch.Generator().manual_seed(5)
w = torch.randn(1, generator=init_generator, requires_grad=True, dtype=torch.float32)
b = torch.randn(1, generator=init_generator, requires_grad=True, dtype=torch.float32)

# 3. Define learning rate and number of epochs
learning_rate = 0.01
num_epochs = 1000
loss_fn = torch.nn.MSELoss()

# 4. Gradient Descent Loop
# Full-batch gradient descent on the MSE of the linear model y = w * x + b.
# loss_values[i] is the loss measured at the start of epoch i, i.e. before
# that epoch's parameter update is applied.
loss_values = []
for epoch in range(num_epochs):
    # Forward pass: compute predictions
    y_pred = w * X_tensor + b

    # Compute the loss
    loss = loss_fn(y_pred, y_tensor)
    loss_values.append(loss.item())

    # Backward pass: compute gradients
    loss.backward()

    # Update parameters (manually, as no optimizer is used)
    # no_grad keeps the in-place updates out of the autograd graph.
    with torch.no_grad():
        w -= learning_rate * w.grad
        b -= learning_rate * b.grad

        # Zero gradients after updating
        # backward() accumulates into .grad, so it must be cleared each epoch.
        w.grad.zero_()
        b.grad.zero_()

In [11]:
# 5. Plot the results using Plotly
# Scatter plot of the true data
data_scatter = go.Scatter(
    x=X, y=y, mode='markers', name='True Data', marker=dict(color='blue', size=8)
)

# Linear fit plot
# A 'lines' trace joins its points in array order. X is in random draw order,
# so plotting it directly makes the line double back over itself between
# every pair of samples. Evaluating the fit on sorted x draws one clean
# left-to-right line through the same points.
X_sorted = np.sort(X)
data_line = go.Scatter(
    x=X_sorted,
    y=w.item() * X_sorted + b.item(),
    mode='lines',
    name='Linear Fit',
    line=dict(color='red', width=2),
)

# Combine the two plots
fig = go.Figure(data=[data_scatter, data_line])
fig.update_layout(
    title="Linear Fit to Non-Linear Data",
    xaxis_title="X",
    yaxis_title="y",
    template="ggplot2"
)
fig.show()

In [12]:
# Plot loss over epochs
# One line trace: x is the epoch index 0..num_epochs-1, y is the MSE recorded
# for that epoch during training.
epoch_axis = list(range(num_epochs))
loss_trace = go.Scatter(
    x=epoch_axis,
    y=loss_values,
    mode='lines',
    name='Training Loss',
)

fig_loss = go.Figure(data=[loss_trace])
fig_loss.update_layout(
    title="Training Loss Over Epochs",
    xaxis_title="Epochs",
    yaxis_title="Loss (MSE)",
    template="ggplot2",
)
fig_loss.show()
