In [7]:
import time 
import numpy as np 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torchvision 
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [45]:
batch_size = 32 
TRAIN_SIZE = 10000
epochs = 3 


In [2]:
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.1307,), (0.3081,)),])

In [5]:
train_dataset = datasets.MNIST(root="./data", 
                                train = True, 
                                transform = transform, 
                                download = True)

test_dataset = datasets.MNIST(root="./data", 
                                train = False, 
                                transform = transform, 
                                download = True)

type(train_dataset),type(test_dataset)

(torchvision.datasets.mnist.MNIST, torchvision.datasets.mnist.MNIST)

In [50]:
train_dataset.data.shape

torch.Size([60000, 28, 28])

In [10]:
train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=batch_size, 
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset, 
                          batch_size=batch_size, 
                          shuffle=False)

type(train_loader), type(test_loader)

(torch.utils.data.dataloader.DataLoader,
 torch.utils.data.dataloader.DataLoader)

In [12]:
# len(train_dataset), len(test_dataset) # 60000, 10000

# Pre-allocate tensors of the appropriate size 
train_data = torch.zeros(60000, 1, 28, 28) 
train_labels = torch.zeros(60000, dtype=torch.long) 
test_data = torch.zeros(10000, 1, 28, 28)
test_labels = torch.zeros(10000, dtype=torch.long)

In [24]:
# Load all training data into RAM 

for idx, (data, label) in enumerate(train_loader): 
    start_idx = idx * batch_size 
    end_idx = start_idx + data.size(0) 
    train_data[start_idx:end_idx] = data
    train_labels[start_idx:end_idx] = label

print(train_data.shape, train_data.dtype)

# Load all test data into RAM 

for idx, (data, label) in enumerate(test_loader): 
    start_idx = idx * batch_size 
    end_idx = start_idx + data.size(0) 
    test_data[start_idx:end_idx] = data
    test_labels[start_idx:end_idx] = label      

print(test_data.shape, test_data.dtype)

torch.Size([60000, 1, 28, 28]) torch.float32
torch.Size([10000, 1, 28, 28]) torch.float32


In [47]:
iters_per_epoch = TRAIN_SIZE // batch_size
print("Iters per epoch:", iters_per_epoch)

Iters per epoch: 312


In [39]:
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, num_classes): 
        super(MLP, self).__init__() 
        self.fc1 = nn.Linear(in_features, hidden_features) 
        self.relu = nn.ReLU() 
        self.fc2 = nn.Linear(hidden_features, num_classes) 

    def forward(self, x): 
        x = x.reshape(batch_size, 28*28) 
        x = self.fc1(x) 
        x = self.relu(x) 
        x = self.fc2(x) 
        return x

In [40]:
model = MLP(in_features=28*28, hidden_features=256, num_classes=10)

criterion = nn.CrossEntropyLoss() 
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [41]:
# Training the model 

def train(model, criterion, optimizer, epoch):
    model.train() 
    running_loss = 0.0 

    for i in range(iters_per_epoch):

        optimizer.zero_grad() 
        data = train_data[i * batch_size : (i + 1) * batch_size] 
        targets = train_labels[i * batch_size : (i + 1) * batch_size] 

        start = time.time() 
        outputs = model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad() 
        end = time.time()
        running_loss += loss.item() 

        if i % 100 == 99 or i == 0: 
            print(f"Epoch: {epoch+1}, Iter: {i+1}, Loss: {loss}")
            print(f"Time per batch: {(end-start) * 1e3:.4f} sec")
            running_loss = 0.0 



In [42]:
def evaluate(model, test_data, test_labels): 
    device = torch.device("cpu") 
    model.to(device) 
    model.eval() 

    total_batch_accuracy = torch.tensor(0.0, device=device) 
    num_batches = 0 
    with torch.no_grad(): 
        for i in range(len(test_data) // batch_size): 
            data = test_data[i * batch_size : (i + 1) * batch_size] 
            target = test_labels[i * batch_size : (i + 1) * batch_size] 
            outputs = model(data) 

            _, predicted = torch.max(outputs, 1) 
            correct_batch = (predicted == target).sum().item() 
            total_batch = target.size(0) 
            if total_batch != 0: 
                batch_accuracy = correct_batch / total_batch 
                total_batch_accuracy += batch_accuracy 
                num_batches += 1 

    avg_batch_accuracy = total_batch_accuracy / num_batches 
    print(f"Average Batch Accuracy: {avg_batch_accuracy * 100: .2f}%") 

In [46]:
if __name__ == "__main__": 
    for epoch in range(epochs): 
        train(model, criterion, optimizer, epoch) 
        evaluate(model, test_data, test_labels) 

    print("Finished Training")  

Epoch: 1, Iter: 1, Loss: 2.3500967025756836
Time per batch: 143.3296 sec
Epoch: 1, Iter: 100, Loss: 0.32422706484794617
Time per batch: 3.2609 sec
Epoch: 1, Iter: 200, Loss: 0.3650629222393036
Time per batch: 7.0124 sec
Epoch: 1, Iter: 300, Loss: 0.19618771970272064
Time per batch: 0.0000 sec
Average Batch Accuracy:  92.39%
Epoch: 2, Iter: 1, Loss: 0.42112624645233154
Time per batch: 8.1015 sec
Epoch: 2, Iter: 100, Loss: 0.2059713751077652
Time per batch: 0.0000 sec
Epoch: 2, Iter: 200, Loss: 0.1865643709897995
Time per batch: 5.5268 sec
Epoch: 2, Iter: 300, Loss: 0.09503141045570374
Time per batch: 2.0816 sec
Average Batch Accuracy:  94.14%
Epoch: 3, Iter: 1, Loss: 0.19763953983783722
Time per batch: 0.0000 sec
Epoch: 3, Iter: 100, Loss: 0.1086488589644432
Time per batch: 4.5102 sec
Epoch: 3, Iter: 200, Loss: 0.10490388423204422
Time per batch: 1.0335 sec
Epoch: 3, Iter: 300, Loss: 0.05188275873661041
Time per batch: 0.0000 sec
Average Batch Accuracy:  94.49%
Finished Training
