In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device:{device}")

Using device:cuda


In [5]:
batch_size = 128
epochs = 50
learning_rate = 0.001
dropout_prob = 0.5
patience = 5

In [6]:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,),(0.5,))])

In [7]:
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform)
test_dataset = datasets.MNIST(root='./data',train=False,download=True,transform=transform)

100%|██████████| 9.91M/9.91M [00:00<00:00, 17.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 494kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.50MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.25MB/s]


In [8]:
train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_loader = DataLoader(test_dataset,batch_size=batch_size)

In [10]:
class MLP_BN(nn.Module):
  def __init__(self):
    super(MLP_BN,self).__init__()
    self.net = nn.Sequential(
        nn.Flatten(),

        nn.Linear(28*28,512),
        nn.BatchNorm1d(512),
        nn.ReLU(),
        nn.Dropout(dropout_prob),

        nn.Linear(512,256),
        nn.BatchNorm1d(256),
        nn.ReLU(),
        nn.Dropout(dropout_prob),

        nn.Linear(256,10)
    )

  def forward(self,x):
    return self.net(x)

model = MLP_BN().to(device)

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr = learning_rate)


In [15]:
train_losses =[]
val_losses = []

best_val_loss = float('inf')
epochs_on_imporve = 0

for epoch in range(epochs):
  model.train()
  running_loss = 0.0

  for inputs,targets in train_loader:
    inputs,targets = inputs.to(device), targets.to(device)

    outputs = model(inputs)
    loss = criterion(outputs,targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    running_loss += loss.item()

  model.eval()
  val_loss = 0.0
  with torch.no_grad():
    for inputs,targets in test_loader:
      inputs,targets = inputs.to(device),targets.to(device)
      outputs = model(inputs)
      loss = criterion(outputs,targets)
      val_loss += loss.item()


  train_losses.append(running_loss)
  val_losses.append(val_loss)

  print(f"Epoch {epoch+1}/{epochs}, Train Loss: {running_loss:.4f}, Val Loss: {val_loss:.4f}")

  if val_loss < best_val_loss:
    best_val_loss = val_loss
    epochs_on_imporve = 0
  else:
    epochs_on_imporve += 1

  if epochs_on_imporve >= patience:
    print(f"Early stopping at epoch {epoch+1}")
    break

Epoch 1/50, Train Loss: 79.4745, Val Loss: 7.7082
Epoch 2/50, Train Loss: 63.5299, Val Loss: 6.4178
Epoch 3/50, Train Loss: 54.9485, Val Loss: 5.8162
Epoch 4/50, Train Loss: 48.3393, Val Loss: 5.3017
Epoch 5/50, Train Loss: 44.9772, Val Loss: 4.9814
Epoch 6/50, Train Loss: 40.2129, Val Loss: 4.7403
Epoch 7/50, Train Loss: 38.5120, Val Loss: 4.7965
Epoch 8/50, Train Loss: 34.5021, Val Loss: 4.8539
Epoch 9/50, Train Loss: 33.6643, Val Loss: 4.2442
Epoch 10/50, Train Loss: 31.6829, Val Loss: 4.4216
Epoch 11/50, Train Loss: 29.1425, Val Loss: 4.4658
Epoch 12/50, Train Loss: 29.6268, Val Loss: 4.4637
Epoch 13/50, Train Loss: 27.0946, Val Loss: 4.4087
Epoch 14/50, Train Loss: 25.6009, Val Loss: 4.2008
Epoch 15/50, Train Loss: 24.6212, Val Loss: 4.1589
Epoch 16/50, Train Loss: 23.3636, Val Loss: 4.1793
Epoch 17/50, Train Loss: 22.8568, Val Loss: 4.0488
Epoch 18/50, Train Loss: 22.2477, Val Loss: 4.2166
Epoch 19/50, Train Loss: 22.3400, Val Loss: 4.4334
Epoch 20/50, Train Loss: 20.2726, Val Lo

In [16]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
  for inputs ,targets in test_loader:
    inputs, targets = inputs.to(device),targets.to(device)

    outputs = model(inputs)
    _, predicted = torch.max(outputs,1)
    total += targets.size(0)
    correct += (predicted == targets).sum().item()

accuracy = correct / total * 100
print(f"Final MNIST Test accuracy {accuracy}%")

Final MNIST Test accuracy 98.41%
