In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import trange

# Define the model
class SimpleModel(nn.Module):
    """Single-layer linear classifier over flattened 28x28 inputs.

    The model consists of one fully connected layer mapping 784 input
    features to 10 output scores; no activation is applied to the output.
    """

    def __init__(self):
        super().__init__()
        # Attribute name `fc` is kept so parameter names stay `fc.weight` / `fc.bias`.
        self.fc = nn.Linear(in_features=28 * 28, out_features=10)

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return the linear layer's output."""
        # -1 lets the leading (batch) dimension be inferred from the element count.
        flattened = x.view(-1, 28 * 28)
        scores = self.fc(flattened)
        return scores

def main(epochs=50, batch_size=32, lr=0.01, momentum=0.9, data_root="./data"):
    """Train ``SimpleModel`` on the MNIST training split using SGD.

    Downloads MNIST into ``data_root`` when it is not already present, runs
    ``epochs`` passes over the training set, and prints the mean per-batch
    loss after every epoch. Calling ``main()`` with no arguments uses the
    original hard-coded settings.

    Args:
        epochs: Number of full passes over the training set.
        batch_size: Number of samples per mini-batch.
        lr: SGD learning rate.
        momentum: SGD momentum factor.
        data_root: Directory where the MNIST files are stored/downloaded.

    Raises:
        ValueError: If the training loader yields no batches.
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the MNIST dataset; Normalize subtracts 0.5 and divides by 0.5.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_dataset = datasets.MNIST(root=data_root, train=True, transform=transform, download=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Use the loader's own length for averaging instead of relying on the
    # inner loop index, which is undefined when the loader is empty.
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("training loader is empty; nothing to train on")

    # Instantiate the model
    model = SimpleModel().to(device)
    # Explicitly select training mode so mode-dependent layers behave
    # correctly if any are added to the model later.
    model.train()

    # Use the standard PyTorch optimizer
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    # Loss function
    criterion = nn.CrossEntropyLoss()

    # Training loop
    for epoch in trange(epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # .item() converts the scalar loss tensor to a Python float.
            running_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {running_loss/num_batches}")

    print("Finished Training!")
    
# Start training only when run directly (script or notebook cell), so that
# importing this module does not kick off a full training run.
if __name__ == "__main__":
    main()

  2%|▏         | 1/50 [00:09<08:05,  9.92s/it]

Epoch 1, Loss: 0.4268134767025709


  4%|▍         | 2/50 [00:19<07:50,  9.80s/it]

Epoch 2, Loss: 0.3632228336006403


  6%|▌         | 3/50 [00:28<07:36,  9.72s/it]

Epoch 3, Loss: 0.3490768820807338


  8%|▊         | 4/50 [00:38<07:26,  9.72s/it]

Epoch 4, Loss: 0.34025460026760895


 10%|█         | 5/50 [00:48<07:15,  9.68s/it]

Epoch 5, Loss: 0.339296058489879


 12%|█▏        | 6/50 [00:58<07:08,  9.73s/it]

Epoch 6, Loss: 0.3327415968090296


 14%|█▍        | 7/50 [01:07<06:58,  9.74s/it]

Epoch 7, Loss: 0.3299983456969261


 16%|█▌        | 8/50 [01:17<06:50,  9.78s/it]

Epoch 8, Loss: 0.32903885013659795


 18%|█▊        | 9/50 [01:27<06:40,  9.77s/it]

Epoch 9, Loss: 0.3305103853404522


 20%|██        | 10/50 [01:37<06:34,  9.86s/it]

Epoch 10, Loss: 0.3200183881888787


 22%|██▏       | 11/50 [01:47<06:25,  9.88s/it]

Epoch 11, Loss: 0.32868734936118127


 24%|██▍       | 12/50 [01:57<06:15,  9.87s/it]

Epoch 12, Loss: 0.32189359853466354


 26%|██▌       | 13/50 [02:07<06:03,  9.84s/it]

Epoch 13, Loss: 0.3220683484971523


 28%|██▊       | 14/50 [02:17<05:54,  9.85s/it]

Epoch 14, Loss: 0.3203840114702781


 30%|███       | 15/50 [02:26<05:41,  9.76s/it]

Epoch 15, Loss: 0.31750070289919774


 32%|███▏      | 16/50 [02:35<05:28,  9.65s/it]

Epoch 16, Loss: 0.31478713118682305


 34%|███▍      | 17/50 [02:45<05:18,  9.64s/it]

Epoch 17, Loss: 0.32167294321457546


 36%|███▌      | 18/50 [02:55<05:08,  9.64s/it]

Epoch 18, Loss: 0.31165535723070303


 38%|███▊      | 19/50 [03:04<04:58,  9.62s/it]

Epoch 19, Loss: 0.3196751326893767


 40%|████      | 20/50 [03:14<04:50,  9.69s/it]

Epoch 20, Loss: 0.31724788200954596


KeyboardInterrupt: 