In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

In [None]:
# Create Data class
class Data(Dataset):
    """Synthetic 1-D regression dataset: y = -3x + 1 plus Gaussian noise.

    The training split additionally has a few labels overwritten with
    outliers, so the effect of the learning rate on a fit to corrupted data
    can be compared against clean validation data.
    """

    # Constructor
    def __init__(self, train=True, seed=None):
        """Build the inputs, the noise-free targets and the noisy labels.

        Args:
            train (bool, optional): If True (default), overwrite some labels
                with outliers: ``y[0] = 0`` and ``y[50:55] = 20``. If False,
                the labels are the true function plus noise only.
            seed (int, optional): Seed for a private ``torch.Generator`` used
                to draw the noise, making the dataset reproducible. ``None``
                (default) draws from torch's global RNG as before.

        Attributes set:
            x: inputs of shape (N, 1), evenly spaced from -3 (inclusive) to
                3 (exclusive) in steps of 0.1.
            f: noise-free targets ``-3 * x + 1``, same shape as ``x``.
            y: labels ``f + 0.1 * noise`` (plus outliers when ``train``).
            len: number of samples N.
        """
        self.x = torch.arange(-3, 3, 0.1).view(-1, 1)
        self.f = -3 * self.x + 1

        # A dedicated generator keeps seeded datasets independent of any
        # other use of the global RNG; None falls back to the global RNG.
        gen = None
        if seed is not None:
            gen = torch.Generator().manual_seed(seed)
        self.y = self.f + 0.1 * torch.randn(self.x.size(), generator=gen)
        self.len = self.x.shape[0]

        # Introduce outliers for the training dataset only
        if train:
            self.y[0] = 0
            self.y[50:55] = 20

    # Getter
    def __getitem__(self, index):
        """Return the ``(x, y)`` pair at ``index``.

        Args:
            index (int): Index of the data point.

        Returns:
            tuple: ``(self.x[index], self.y[index])``.
        """
        return self.x[index], self.y[index]

    # Get Length
    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.len

In [None]:
# Seed torch's global RNG so the noise drawn by Data (and later draws from
# the global RNG) is reproducible under Restart & Run All.
torch.manual_seed(0)

# Create training dataset and validation dataset
train_data = Data()
val_data = Data(train=False)

# Plot the (outlier-containing) training labels against the noise-free function
fig, ax = plt.subplots()
ax.plot(train_data.x.numpy(), train_data.y.numpy(), 'xr', label="training data")
ax.plot(train_data.x.numpy(), train_data.f.numpy(), label="true function")
ax.set(xlabel='x', ylabel='y', title='Training data with outliers')
ax.legend()
plt.show()

In [None]:
# Create Linear Regression Class
class linear_regression(nn.Module):
    """Thin ``nn.Module`` wrapper around a single ``nn.Linear`` layer."""

    def __init__(self, input_size, output_size):
        """Create the wrapped affine layer.

        Args:
            input_size (int): Number of input features.
            output_size (int): Number of output values.
        """
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Return the affine layer's prediction for the batch ``x``."""
        return self.linear(x)

In [None]:
# Loss and data loader shared by every training run below.
# MSELoss (default reduction) averages the squared error over its inputs.
criterion = nn.MSELoss()
# batch_size=1 -> each optimizer step sees one sample (plain SGD); shuffle is
# left at its default (off), so samples are visited in dataset order.
trainloader = DataLoader(train_data, batch_size=1)

In [None]:
# Define the training function (one fresh model per learning rate)
def train_model_with_lr(iter, lr_list):
    """Train one fresh linear model per learning rate and record its losses.

    Each model is a ``linear_regression(1, 1)`` trained with SGD over
    ``trainloader`` for ``iter`` epochs, then evaluated with ``criterion`` on
    the full training set and on the validation set.

    Side effects (module globals read by later cells):
        MODELS: list of trained models, in the same order as ``lr_list``.
        train_error, validation_error: 1-D tensors of length ``len(lr_list)``;
            entry ``i`` holds the final loss for ``lr_list[i]``. They are
            re-allocated here so their length always matches ``lr_list``
            (previously a list of a different length caused an IndexError
            or left stale entries).

    Args:
        iter (int): Number of epochs (passes over ``trainloader``) per model.
            The name shadows the builtin but is kept for keyword callers.
        lr_list (list): Learning rates to try.
    """
    global MODELS, train_error, validation_error
    MODELS = []
    train_error = torch.zeros(len(lr_list))
    validation_error = torch.zeros(len(lr_list))

    for i, lr in enumerate(lr_list):
        model = linear_regression(1, 1)
        optimizer = optim.SGD(model.parameters(), lr=lr)

        for _ in range(iter):
            for x, y in trainloader:
                loss = criterion(model(x), y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Evaluation only: no autograd graph is needed for these losses.
        with torch.no_grad():
            train_error[i] = criterion(model(train_data.x), train_data.y).item()
            validation_error[i] = criterion(model(val_data.x), val_data.y).item()

        MODELS.append(model)

# Define learning rates (log-spaced, so they plot evenly on a log axis)
learning_rates = [0.0001, 0.001, 0.01, 0.1]
N_EPOCHS = 10  # passes over the training set for each learning rate

# Final loss per learning rate; train_model_with_lr writes these.
train_error = torch.zeros(len(learning_rates))
validation_error = torch.zeros(len(learning_rates))

# Seed the global RNG so each model's random initial weights are reproducible
torch.manual_seed(0)

# Train models with different learning rates
train_model_with_lr(N_EPOCHS, learning_rates)

In [None]:
# Plot the final training and validation loss for each learning rate.
# Both values come from nn.MSELoss, i.e. a mean squared error, not a total.
fig, ax = plt.subplots()
ax.semilogx(learning_rates, train_error.numpy(), 'o-', label='training loss')
ax.semilogx(learning_rates, validation_error.numpy(), 'o-', label='validation loss')
ax.set(xlabel='Learning Rate', ylabel='Mean Squared Error',
       title='Final loss vs. learning rate')
ax.legend()
plt.show()

In [None]:
# Plot each trained model's predictions on the validation inputs
fig, ax = plt.subplots()
with torch.no_grad():  # inference only; no autograd graph needed
    for model, learning_rate in zip(MODELS, learning_rates):
        yhat = model(val_data.x)
        ax.plot(val_data.x.numpy(), yhat.numpy(), label='lr:' + str(learning_rate))

# The noisy validation labels (what validation_error was computed on) and
# the noise-free target function, shown separately so neither is mislabelled.
ax.plot(val_data.x.numpy(), val_data.y.numpy(), 'or', label='validation data')
ax.plot(val_data.x.numpy(), val_data.f.numpy(), 'k--', label='true function')
ax.set(xlabel='x', ylabel='y', title='Validation predictions by learning rate')
ax.legend()
plt.show()