## Experimenting with momentum
In this exercise, your goal is to find a momentum value that lets the optimizer reach the minimum of the non-convex function `x**4 + x**3 - 5*x**2` within 20 steps. You will experiment with two different momentum values. For this problem, the learning rate is fixed at 0.01.
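
To see why momentum matters for this function, it helps to locate its stationary points. The derivation below is a worked sketch and not part of the original exercise; it factors the derivative of the function given above:

$$
f(x) = x^4 + x^3 - 5x^2, \qquad f'(x) = 4x^3 + 3x^2 - 10x = x\,(4x - 5)(x + 2)
$$

The stationary points are therefore $x = -2$, $x = 0$, and $x = 1.25$. Since $f''(x) = 12x^2 + 6x - 10$ gives $f''(0) = -10 < 0$, the point $x = 0$ is a local maximum. The other two are minima, with $f(-2) = -12$ (the global minimum) and $f(1.25) = -875/256 \approx -3.42$ (a local minimum). An optimizer that starts to the right of $x = 0$ has to get over that hump to reach the global minimum.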

You are provided with the `optimize_and_plot()` function, which takes the learning rate as its first argument and the momentum as its second. It runs the SGD optimizer for `num_epochs` steps (20 in this exercise) and displays the results.
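
As a rough illustration of what momentum does on this function, here is a minimal pure-Python sketch of the momentum update (velocity = momentum * velocity + gradient, then x = x - lr * velocity), which is the rule `torch.optim.SGD` applies with its default dampening and without Nesterov. The helper names `f`, `grad_f`, and `sgd_momentum`, the starting point `x0=2.0`, and the momentum values `0.0` and `0.9` are assumptions made for illustration; the exercise does not specify them.

```python
def f(x):
    # Non-convex objective from the exercise
    return x**4 + x**3 - 5 * x**2


def grad_f(x):
    # Analytic derivative of f
    return 4 * x**3 + 3 * x**2 - 10 * x


def sgd_momentum(momentum, lr=0.01, steps=20, x0=2.0):
    """Run `steps` momentum updates on f and return every iterate.

    x0 is a hypothetical starting point, not given in the exercise.
    """
    x = x0
    velocity = 0.0
    trajectory = [x]
    for _ in range(steps):
        # Accumulate past gradients, then step against the velocity
        velocity = momentum * velocity + grad_f(x)
        x = x - lr * velocity
        trajectory.append(x)
    return trajectory


# Compare two assumed momentum values at the fixed learning rate of 0.01
for m in (0.0, 0.9):
    final_x = sgd_momentum(m)[-1]
    print(f"momentum={m}: final x = {final_x:.4f}, f(x) = {f(final_x):.4f}")
```

Where the iterate ends up after 20 steps depends on both the momentum and the starting point, so it is worth trying several values of each before drawing conclusions.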

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

def optimize_and_plot(lr, momentum, num_epochs=100, input_size=1, output_size=1, x_train=None, y_train=None):
    # Initialize model, criterion, and optimizer
    model = nn.Linear(input_size, output_size)
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

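    # Training data is required; fail early with a clear message
    if x_train is None or y_train is None:
        raise ValueError("x_train and y_train must both be provided")

    # Convert data to float32 PyTorch tensors
    x_train = torch.as_tensor(x_train, dtype=torch.float32)
    y_train = torch.as_tensor(y_train, dtype=torch.float32)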

    losses = []

    for epoch in range(num_epochs):
        # Forward pass
        outputs = model(x_train)
        loss = criterion(outputs, y_train)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Save the loss for plotting
        losses.append(loss.item())

        # Print the loss every 10 epochs and add a red marker
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
            plt.scatter(epoch + 1, loss.item(), color='red', marker='o')

    # Plot the loss over epochs
    plt.plot(range(1, num_epochs+1), losses, label='Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Over Epochs')
    plt.legend()
    plt.show()

# Example usage:
# lr = 0.01
# momentum = 0.9  # example value only; momentum is a required argument
# x_train, y_train, input_size, output_size are provided based on your specific use case
# optimize_and_plot(lr, momentum, num_epochs=100, input_size=input_size, output_size=output_size, x_train=x_train, y_train=y_train)


In [2]:
'''Just do the same as 5'''
