# Imports


In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

# Load Data


In [None]:
# Preprocessing pipeline: convert each image to a tensor, then normalize it
# with mean 0.5 and standard deviation 0.5 (single channel).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST train and test splits, stored under ./data and downloaded on demand.
# Both splits share the same root directory and preprocessing.
_mnist_kwargs = dict(root='./data', transform=transform, download=True)
train_dataset = torchvision.datasets.MNIST(train=True, **_mnist_kwargs)
test_dataset = torchvision.datasets.MNIST(train=False, **_mnist_kwargs)

# Linear Model


In [None]:
# Model definition
class LinearClassifier(nn.Module):
    """Linear model producing a single logit per flattened input sample.

    The model consists of one ``nn.Linear`` layer whose weight and bias are
    both initialized to zero, so a freshly built model outputs 0 for every
    input.

    Args:
        in_features: Number of values per flattened sample. Defaults to
            ``28 * 28``, the layer width used by the original model.
    """

    def __init__(self, in_features=28 * 28):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 1)  # Linear layer 1

        # Initialization of weights to zero
        self.reset_parameters()

    def reset_parameters(self):
        """Set the weight (and the bias, when present) of ``fc1`` to zero."""
        nn.init.zeros_(self.fc1.weight)
        if self.fc1.bias is not None:
            nn.init.zeros_(self.fc1.bias)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, in_features)`` and return the logits.

        Args:
            x: Tensor whose total number of elements is a multiple of
                ``in_features``.

        Returns:
            Tensor of shape ``(N, 1)`` holding one raw logit per sample.
        """
        # reshape (rather than view) also accepts non-contiguous inputs,
        # e.g. transposed or permuted batches, copying only when necessary.
        x = x.reshape(-1, self.fc1.in_features)
        return self.fc1(x)

# Train & Test Loop


## Parameters


In [None]:
# Hyperparameters
batch_size = 1    # samples per optimization step (used by every loader)
num_epochs = 15   # passes over each training subset

# Test loader: iterates the test set in a fixed order
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Binary cross-entropy computed directly on raw logits
criterion = nn.BCEWithLogitsLoss()

# Training-set sizes to evaluate: start at 4000 and grow in steps of 14000,
# up to and including the full training set when the step lands on it
_size_start = 4000
_size_step = 14000
train_sizes = range(_size_start, len(train_dataset) + 1, _size_step)

## Loop


In [None]:
# One entry per training-set size; both lists are read by the plotting cell
train_losses = []
test_errors = []
for train_size in train_sizes:
    # Train only on the first `train_size` samples of the training set
    subset_train_dataset = torch.utils.data.Subset(train_dataset, indices=list(range(train_size)))
    subset_train_loader = DataLoader(dataset=subset_train_dataset, batch_size=batch_size, shuffle=True)

    # Start from a fresh model and optimizer for every training-set size
    model = LinearClassifier()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Per-step losses, collected during the final epoch only
    train_loss_list = []
    steps_per_epoch = len(subset_train_loader)

    for epoch in range(num_epochs):
        is_last_epoch = (epoch + 1 == num_epochs)
        for i, (images, labels) in enumerate(subset_train_loader):
            # Forward pass on the parity target (0 = even digit, 1 = odd digit)
            outputs = model(images)
            labels = labels % 2
            targets = labels.float().view(-1, 1)
            loss = criterion(outputs, targets)

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Progress report every 100 steps
            if (i + 1) % 100 == 0:
                progress = 'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch + 1, num_epochs, i + 1, steps_per_epoch, loss.item())
                print(progress)
            if is_last_epoch:
                train_loss_list.append(loss.item())

    # Evaluate on the test set: count predictions that miss the true parity
    model.eval()
    uncorrect = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            parity = (labels % 2).float().view(-1, 1)
            # Sigmoid turns logits into probabilities; rounding gives 0 or 1
            predicted = torch.sigmoid(outputs).round()
            total += labels.size(0)
            uncorrect += predicted.ne(parity).sum().item()

    # Fraction of misclassified test samples
    error = uncorrect / total

    # Record the mean final-epoch training loss and the test error
    mean_train_loss = sum(train_loss_list) / len(train_loss_list)
    train_losses.append(mean_train_loss)
    test_errors.append(error)

    print('Train Size: {}, Train Loss: {:.4f}, Test Error: {:.2f}%'.format(train_size, train_losses[-1], error*100))

# Plot


In [None]:
# Draw both curves on a single set of axes
fig, ax = plt.subplots(figsize=(10, 5))

# One line per metric, each sampled at every training-set size
ax.plot(train_sizes, train_losses, marker='o', label='Train Loss')
ax.plot(train_sizes, test_errors, marker='o', label='Test Error')

# Title, axis labels, legend and grid
ax.set_title('Train Loss and Test Error vs Training Size')
ax.set_xlabel('Training Size')
ax.set_ylabel('Error')
ax.legend()
ax.grid(True)

# Render the figure
plt.show()