In [None]:
from torch import nn
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn,optim
from torch.utils.data import Dataset, DataLoader

In [None]:
# Seed PyTorch's global RNG so every Restart & Run All yields the same results
SEED = 1
torch.manual_seed(SEED)

# Data class

In [None]:
class Data(Dataset):
    """Toy 1-D regression dataset built around the line ``y = -3x + 1``.

    Inputs are ``torch.arange(-3, 3, 0.1)`` reshaped into a column vector.

    * ``train=True``: targets are the line plus Gaussian noise scaled by 0.1,
      and every target from index 50 onward is overwritten with the constant
      20 so the training split contains outliers.
    * ``train=False``: targets are the noise-free line.

    Attributes:
        x: Column tensor of inputs.
        f: Noise-free targets ``-3 * x + 1`` (available on both splits).
        y: Targets returned by ``__getitem__``.
        len: Number of samples.
    """

    def __init__(self, train=True):
        """Build the inputs and targets for the requested split.

        Args:
            train (bool): ``True`` for the noisy split with outliers,
                ``False`` for the clean validation split.
        """
        # Inputs and the generating line are identical for both splits
        self.x = torch.arange(-3, 3, 0.1).view(-1, 1)
        self.f = -3 * self.x + 1
        self.len = self.x.shape[0]

        if train:
            self.y = self.f + 0.1 * torch.randn(self.x.size())
            # Inject outliers: the tail of the training targets is a constant
            self.y[50:] = 20
        else:
            # Clone so that y and f never alias the same storage
            self.y = self.f.clone()

    def __getitem__(self, index):
        """Return the ``(x, y)`` pair stored at ``index``."""
        return self.x[index], self.y[index]

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.len

In [None]:
# Instantiate the noisy training split first, then the clean validation split
train_data = Data(train=True)
val_data = Data(train=False)

In [None]:
# Overlay the training targets, the validation targets and the generating line
data_series = (
    (train_data.x, train_data.y, ('xr',), 'training data'),
    (val_data.x, val_data.y, ('xy',), 'validation data'),
    (train_data.x, train_data.f, (), 'true function'),
)
for xs, ys, fmt, name in data_series:
    plt.plot(xs.numpy(), ys.numpy(), *fmt, label=name)
plt.xlabel('x')
plt.ylabel('y')
plt.legend(loc='upper right')
plt.show()

# Linear regression class

In [None]:
class linear_regression(nn.Module):
    """Linear model made of a single ``nn.Linear`` layer."""

    def __init__(self, input_size, output_size):
        """Create the layer mapping ``input_size`` features to ``output_size`` outputs."""
        super().__init__()
        self.linear = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Return the linear layer's prediction for ``x``."""
        return self.linear(x)

In [None]:
# One input feature, one output: a scalar-to-scalar linear model
model = linear_regression(input_size=1, output_size=1)

In [None]:
# Mean-squared-error cost, plain SGD over the model parameters,
# and a loader that yields one training sample per batch
criterion = nn.MSELoss()
optimizer = optim.SGD(params=model.parameters(), lr=0.1)
trainloader = DataLoader(train_data, batch_size=1)

# Early stopping

In [None]:
# Loss histories; the training loop appends one entry per optimizer step
LOSS_TRAIN = []
LOSS_VAL = []
# NOTE(review): `n` is not referenced in the visible cells; kept in case later cells use it
n = 1
# Start from +inf rather than an arbitrary 1000 so that any finite validation
# loss counts as an improvement and a checkpoint gets written
min_loss = float('inf')

In [None]:
# Define the function for training the model with early stopping criterion
def train_model_early_stopping(epochs, min_loss):
    """
    Train the model and checkpoint the weights with the lowest validation loss.

    Operates on the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``trainloader``, ``train_data`` and ``val_data`` objects. After every
    optimizer step the full training and validation losses are appended to
    ``LOSS_TRAIN`` and ``LOSS_VAL``. Whenever the validation loss drops below
    the best value seen so far, the model's state dict is saved to
    'best_model.pt'. Training always runs for all ``epochs``; "early stopping"
    here means keeping the best checkpoint, not ending the loop early.

    Args:
    epochs (int): Number of training epochs.
    min_loss (float): Initial minimum validation loss.

    Returns:
    float: The lowest validation loss observed, or the initial ``min_loss``
    if it was never beaten (in which case no checkpoint was written).
    """
    for _ in range(epochs):
        for x, y in trainloader:
            yhat = model(x)
            loss = criterion(yhat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Monitoring only: no gradients are needed, so avoid building
            # autograd graphs for the full-dataset forward passes
            with torch.no_grad():
                loss_train = criterion(model(train_data.x), train_data.y).item()
                loss_val = criterion(model(val_data.x), val_data.y).item()

            # Store losses in respective lists
            LOSS_TRAIN.append(loss_train)
            LOSS_VAL.append(loss_val)

            # Strict improvement only: keep the earliest checkpoint on ties
            if loss_val < min_loss:
                min_loss = loss_val
                torch.save(model.state_dict(), 'best_model.pt')

    return min_loss

In [None]:
# Run 20 epochs, starting from the initial validation-loss threshold
train_model_early_stopping(epochs=20, min_loss=min_loss)

In [None]:
# Compare the recorded training and validation cost histories
for history, name in ((LOSS_TRAIN, 'training cost'), (LOSS_VAL, 'validation cost')):
    plt.plot(history, label=name)
plt.xlabel("Iterations")
plt.ylabel("Cost")
plt.legend(loc='upper right')
plt.show()

# Model comparison

In [None]:
# Rebuild the architecture and restore the checkpoint with the lowest validation loss
model_best = linear_regression(1,1)
model_best.load_state_dict(torch.load('best_model.pt'))

# Predictions are only plotted, so disable autograd tracking for the forward
# passes instead of stripping it afterwards with `.data`
with torch.no_grad():
    yhat_best = model_best(val_data.x).numpy()
    yhat_last = model(val_data.x).numpy()

# Plot against the actual inputs (not the sample index) so the x-axis matches
# the earlier data plot
x_val = val_data.x.numpy()
plt.plot(x_val, yhat_best, label = 'best model')
plt.plot(x_val, yhat_last, label = 'maximum iterations')
plt.plot(x_val, val_data.y.numpy(), 'rx', label = 'true line')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.show()