In [1]:
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim

# Define the linear model
class LinearModel(nn.Module):
    """Single fully connected layer mapping flattened inputs to class scores.

    Args:
        in_features: Number of values per sample once every non-batch
            dimension is flattened. Defaults to 784, the size the layer
            was originally hard-coded to.
        num_classes: Number of scores produced per sample. Defaults to 10.
    """

    def __init__(self, in_features=784, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Flatten each sample and apply the linear layer.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions multiply out to ``in_features``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding unnormalised
            scores (no softmax is applied here).
        """
        # flatten() copies when the input is non-contiguous and copes with an
        # empty batch; view(x.size(0), -1) raises a RuntimeError in both cases.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

# Seed PyTorch's global RNG so the DataLoader shuffle order and the model's
# initial weights are the same on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Tunable settings, named once instead of being repeated as bare literals.
DATA_ROOT = './data'
BATCH_SIZE = 64
LEARNING_RATE = 0.01

# Load the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    # Normalize(mean, std) with one value per channel; these look like the
    # commonly quoted MNIST statistics -- TODO confirm before reusing elsewhere.
    transforms.Normalize((0.1307,), (0.3081,))
])

# download=True fetches the files into DATA_ROOT when they are not already there.
train_dataset = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

# Only the training batches are shuffled; the test loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Create the linear model
model = LinearModel()

# Define the loss function and optimizer
# CrossEntropyLoss consumes the raw scores returned by the model together
# with integer class labels.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)

# Train the model
num_epochs = 10
steps_per_epoch = len(train_loader)  # constant across epochs; used only for logging
for epoch in range(num_epochs):
    # Count batches from 1 so `step` is already the number shown in the log.
    for step, (images, labels) in enumerate(train_loader, start=1):
        # Forward pass and loss for the current mini-batch.
        logits = model(images)
        loss = criterion(logits, labels)

        # Drop gradients left over from the previous step, backpropagate,
        # then apply one SGD update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log the latest batch loss on every 100th step only.
        if step % 100 != 0:
            continue
        print(f'Epoch [{epoch+1}/{num_epochs}], Step [{step}/{steps_per_epoch}], Loss: {loss.item():.4f}')

# Test the model
model.eval()  # switch the module to evaluation mode before measuring accuracy
with torch.no_grad():  # no gradients are needed for inference
    correct = 0
    total = 0
    for data, targets in test_loader:
        outputs = model(data)
        # The predicted class is the index of the largest score in each row.
        # argmax avoids the legacy `.data` attribute (redundant under no_grad)
        # and does not compute max values that would only be discarded.
        predicted = outputs.argmax(dim=1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

    # Percentage of test samples whose predicted class equals the label.
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')


  warn(f"Failed to load image Python extension: {e}")


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


9913344it [00:01, 6754200.08it/s]                              


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


29696it [00:00, 894084.74it/s]           


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


1649664it [00:00, 1922968.96it/s]                             


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


5120it [00:00, 9510556.46it/s]          


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch [1/10], Step [100/938], Loss: 0.6864
Epoch [1/10], Step [200/938], Loss: 0.4771
Epoch [1/10], Step [300/938], Loss: 0.4810
Epoch [1/10], Step [400/938], Loss: 0.7025
Epoch [1/10], Step [500/938], Loss: 0.2717
Epoch [1/10], Step [600/938], Loss: 0.4637
Epoch [1/10], Step [700/938], Loss: 0.5501
Epoch [1/10], Step [800/938], Loss: 0.4003
Epoch [1/10], Step [900/938], Loss: 0.4138
Epoch [2/10], Step [100/938], Loss: 0.3411
Epoch [2/10], Step [200/938], Loss: 0.3436
Epoch [2/10], Step [300/938], Loss: 0.5511
Epoch [2/10], Step [400/938], Loss: 0.4869
Epoch [2/10], Step [500/938], Loss: 0.4117
Epoch [2/10], Step [600/938], Loss: 0.2690
Epoch [2/10], Step [700/938], Loss: 0.2928
Epoch [2/10], Step [800/938], Loss: 0.3721
Epoch [2/10], Step [900/938], Loss: 0.3113
Epoch [3/10], Step [100/938], Loss: 0.3754
Epoch [3/10], Step [200/938], Loss: 0.2287
Epoch [3/10], Step [300/938], Loss: 0.3345
Epoch [3/10], Step [40