In [7]:
import torch
from torch import nn
from torchvision import datasets, transforms
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

In [9]:
train_data=datasets.MNIST(root='data',train=True,download=True,
                     transform=transforms.ToTensor())

test_data=datasets.MNIST(root='data',train=False,download=True,transform=transforms.ToTensor())

In [11]:
train_loader=torch.utils.data.DataLoader(train_data,batch_size=32,shuffle=True)
test_loader=torch.utils.data.DataLoader(test_data,batch_size=32)

In [12]:
model=nn.Sequential(
    nn.Flatten(),
    nn.Linear(784,128),
    nn.ReLU(),
    nn.Linear(128,10)
)
device = torch.device("cpu")
model.to(device)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=128, bias=True)
  (2): ReLU()
  (3): Linear(in_features=128, out_features=10, bias=True)
)

In [15]:
loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)

In [17]:
for epoch in range(3):
    for X,y in train_loader:
        X, y = X.to(device), y.to(device)
        pred=model(X)
        loss=loss_fn(pred,y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')

Epoch 1, Loss: 0.0899
Epoch 2, Loss: 0.0265
Epoch 3, Loss: 0.2903


In [18]:
correct=0
total=0

with torch.no_grad():
    for X,y in test_loader:
        X, y = X.to(device), y.to(device)
        pred=model(X)
        correct+=(pred.argmax(1)==y).sum().item()
        total+=y.size(0)
    print(f"Accuracy {100* correct/total:.2f}%")

Accuracy 97.14%


In [19]:
torch.cuda.empty_cache()

In [None]:
import matplotlib.pyplot as plt
import torch

model.cpu()  # make 100% sure model is on CPU
images, labels = next(iter(test_loader))
images, labels = images.cpu(), labels.cpu()

with torch.no_grad():
    preds = model(images).cpu().argmax(1)

plt.figure(figsize=(8, 4))
for i in range(6):
    plt.subplot(2, 3, i+1)
    plt.imshow(images[i].squeeze(), cmap='gray')
    plt.title(f"P:{preds[i].item()} | T:{labels[i].item()}")
    plt.axis('off')
plt.tight_layout()
plt.savefig("mnist_predictions.png")
print("✅ Saved predictions as mnist_predictions.png")
