In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cpu


In [13]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)


In [15]:
class MNISTNet(nn.Module):
    def __init__(self):
        super(MNISTNet, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = MNISTNet().to(device)

In [16]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [17]:
epochs = 20
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss:.4f}")

Epoch [1/20], Loss: 254.9725
Epoch [2/20], Loss: 105.4269
Epoch [3/20], Loss: 75.4619
Epoch [4/20], Loss: 57.3910
Epoch [5/20], Loss: 45.7533
Epoch [6/20], Loss: 37.5149
Epoch [7/20], Loss: 31.5153
Epoch [8/20], Loss: 26.8191
Epoch [9/20], Loss: 26.6334
Epoch [10/20], Loss: 21.1173
Epoch [11/20], Loss: 19.2190
Epoch [12/20], Loss: 19.5657
Epoch [13/20], Loss: 16.1058
Epoch [14/20], Loss: 16.0196
Epoch [15/20], Loss: 12.4565
Epoch [16/20], Loss: 16.1686
Epoch [17/20], Loss: 14.7064
Epoch [18/20], Loss: 12.4440
Epoch [19/20], Loss: 10.8782
Epoch [20/20], Loss: 8.0326


In [18]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        output = model(images)
        _, predicted = torch.max(output.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"\nTest Accuracy: {100 * correct / total:.2f}%")


Test Accuracy: 97.93%


In [19]:
torch.save(model.state_dict(), "mnist_model.pth")
print("Model saved as mnist_model.pth")

Model saved as mnist_model.pth


In [25]:
# Test the updated model loading
import numpy as np
from PIL import Image, ImageOps
import importlib
import sys

# Import and reload Model_load to get latest changes
try:
    import Model_load
    importlib.reload(Model_load)
except:
    import Model_load

def test_model_with_sample():
    # Load a test image from the test dataset
    test_image, test_label = test_dataset[0]
    
    # Convert tensor back to PIL Image for testing
    test_img_array = test_image.squeeze().numpy()
    
    # Properly denormalize
    test_img_array = test_img_array * 0.3081 + 0.1307
    
    # Convert to 0-255 range and ensure proper data type
    test_img_array = np.clip(test_img_array * 255, 0, 255).astype(np.uint8)
    test_pil = Image.fromarray(test_img_array)
    
    # Save it temporarily
    test_pil.save("test_sample.png")
    
    print("=== Image Analysis ===")
    print(f"Image mean: {np.mean(test_img_array):.2f}")
    print(f"Mean > 127? {np.mean(test_img_array) > 127} (NEW logic: would trigger inversion)")
    print(f"Mean < 127? {np.mean(test_img_array) < 127} (OLD logic: would trigger inversion)")
    
    # Test direct model prediction
    print("\n=== Direct Model Test ===")
    model_cpu = model.cpu()
    with torch.no_grad():
        output = model_cpu(test_image.unsqueeze(0))
        predicted_direct = torch.argmax(output, dim=1).item()
    
    print(f"Actual label: {test_label}")
    print(f"Direct prediction: {predicted_direct}")
    
    # Test file-based prediction
    print("\n=== File-based Test (Updated) ===")
    result = Model_load.predict_mnist_probabilities("test_sample.png")
    
    print("File-based predictions:")
    print(result)
    
    # Clean up
    import os
    if os.path.exists("test_sample.png"):
        os.remove("test_sample.png")

# Run the test
test_model_with_sample()

Model loaded and ready for inference!
=== Image Analysis ===
Image mean: 23.51
Mean > 127? False (NEW logic: would trigger inversion)
Mean < 127? True (OLD logic: would trigger inversion)

=== Direct Model Test ===
Actual label: 7
Direct prediction: 7

=== File-based Test (Updated) ===
File-based predictions:
0: 0.0000
1: 0.0000
2: 0.0000
3: 0.0000
4: 0.0000
5: 0.0000
6: 0.0000
7: 1.0000
8: 0.0000
9: 0.0000
