In [1]:
!pip install torchsummary

Collecting torchsummary
  Using cached torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)
Using cached torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1


In [10]:
!pip install rf_calc



In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F  # Add this line

from tqdm import tqdm
from model import SimpleMNISTNet

# Global lists for tracking metrics
train_losses = []
test_losses = []
train_acc = []
test_acc = []

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    pbar = tqdm(train_loader)
    correct = 0
    processed = 0
    running_loss = 0.0

    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        y_pred = model(data)

        loss = F.nll_loss(y_pred, target)
        loss.backward()
        optimizer.step()

        # Update metrics
        pred = y_pred.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)
        running_loss += loss.item()

        # Update progress bar
        pbar.set_description(
            desc=f'Loss={loss.item():.4f} Batch_id={batch_idx} Accuracy={100*correct/processed:.2f}'
        )

    # Calculate epoch metrics
    epoch_loss = running_loss / len(train_loader)
    epoch_accuracy = 100. * correct / processed

    train_losses.append(epoch_loss)
    train_acc.append(epoch_accuracy)

    return epoch_loss, epoch_accuracy

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)

    accuracy = 100. * correct / len(test_loader.dataset)  # Calculate accuracy
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        accuracy))  # Use the calculated accuracy

    test_acc.append(accuracy)  # Append accuracy to the list
    return accuracy  # Return the accuracy value

def validate_model(model, device, train_loader):
    # Check the number of parameters
    num_params = model.count_parameters()
    print(f"Number of parameters: {num_params}")

    # Validate only parameter count
    if num_params < 50000:
        print("Model validation successful: Parameters within limit")
        return True
    else:
        raise ValueError(
            f"Model validation failed: Parameters={num_params} (limit: 50000)"
        )

def train_and_test():
    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model initialization
    network = SimpleMNISTNet().to(device)
    network.print_model_summary()

    # Calculate dataset statistics
    initial_transform = transforms.Compose([transforms.ToTensor()])
    temp_dataset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=initial_transform
    )
    mean, std = calculate_dataset_statistics(temp_dataset)
    print(f"Dataset statistics - Mean: {mean:.4f}, Std: {std:.4f}")

    # Transforms
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((mean,), (std,))
    ])

    transform_train = transforms.Compose([
        transforms.RandomRotation((-7, 7)),
        transforms.ToTensor(),
        transforms.Normalize((mean,), (std,)),
    ])

    # Datasets and Loaders
    train_dataset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform_train
    )
    test_dataset = torchvision.datasets.MNIST(
        root='./data', train=False, download=True, transform=transform
    )

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=128,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )

    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=128,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )

    # Validate the model
    validate_model(network, device, train_loader)

    # Optimizer and Scheduler
    optimizer = optim.SGD(network.parameters(), lr=0.01, momentum=0.9)
    scheduler = StepLR(optimizer, step_size=6, gamma=0.1)


    # Training Loop
    best_accuracy = 0

    for epoch in range(15):
        print(f"\nEpoch {epoch+1}/15")

        # Training phase
        train_loss, train_accuracy = train(network, device, train_loader, optimizer, epoch)

        # Testing phase
        test_accuracy = test(network, device, test_loader)

        print(f"\nEpoch {epoch+1}")
        print(f"Training - Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.2f}%")
        print(f"Testing  - Accuracy: {test_accuracy:.2f}%")

        scheduler.step()

        # Save best model
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            print(f"Best Test Accuracy: {best_accuracy:.2f}%")
            torch.save(network.state_dict(), 'best_model.pth')

    print("\nTraining completed!")
    print(f"Best Test Accuracy: {best_accuracy:.2f}%")

def calculate_dataset_statistics(dataset):
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1000,
        shuffle=False,
        num_workers=4
    )

    mean = 0.
    std = 0.
    total_images = 0

    for images, _ in tqdm(loader, desc="Calculating dataset statistics"):
        batch_samples = images.size(0)
        images = images.view(batch_samples, images.size(1), -1)
        mean += images.mean(2).sum(0)
        std += images.std(2).sum(0)
        total_images += batch_samples

    mean /= total_images
    std /= total_images

    return mean.item(), std.item()

if __name__ == "__main__":
    train_and_test()


Model Summary:
Layer (type)               Output Shape         Param #
layer1                    torch.Size([1, 8, 26, 26]) 88
layer2                    torch.Size([1, 8, 24, 24]) 592
layer3                    torch.Size([1, 12, 22, 22]) 888
layer4                    torch.Size([1, 12, 20, 20]) 1320
pool                      torch.Size([1, 12, 10, 10]) 0
layer5                    torch.Size([1, 8, 8, 8]) 880
layer6                    torch.Size([1, 16, 6, 6]) 1184
layer7                    torch.Size([1, 12, 4, 4]) 1752
layer8                    torch.Size([1, 10, 2, 2]) 1100
avgpool                   torch.Size([1, 10, 1, 1]) 0
layer9                    torch.Size([1, 10, 1, 1]) 100
Total params: 7904


Calculating dataset statistics: 100%|██████████████████████████████████████████████████████████| 60/60 [00:11<00:00,  5.23it/s]


Dataset statistics - Mean: 0.1307, Std: 0.3015
Number of parameters: 7904
Model validation successful: Parameters within limit

Epoch 1/15


Loss=0.1463 Batch_id=468 Accuracy=89.43: 100%|███████████████████████████████████████████████| 469/469 [01:52<00:00,  4.17it/s]



Test set: Average loss: 0.0677, Accuracy: 9823/10000 (98.23%)


Epoch 1
Training - Loss: 0.4335, Accuracy: 89.43%
Testing  - Accuracy: 98.23%
Best Test Accuracy: 98.23%

Epoch 2/15


Loss=0.0774 Batch_id=468 Accuracy=98.01: 100%|███████████████████████████████████████████████| 469/469 [01:29<00:00,  5.27it/s]



Test set: Average loss: 0.0423, Accuracy: 9876/10000 (98.76%)


Epoch 2
Training - Loss: 0.0730, Accuracy: 98.01%
Testing  - Accuracy: 98.76%
Best Test Accuracy: 98.76%

Epoch 3/15


Loss=0.0707 Batch_id=468 Accuracy=98.34: 100%|███████████████████████████████████████████████| 469/469 [01:29<00:00,  5.23it/s]



Test set: Average loss: 0.0380, Accuracy: 9886/10000 (98.86%)


Epoch 3
Training - Loss: 0.0556, Accuracy: 98.34%
Testing  - Accuracy: 98.86%
Best Test Accuracy: 98.86%

Epoch 4/15


Loss=0.0513 Batch_id=468 Accuracy=98.60: 100%|███████████████████████████████████████████████| 469/469 [02:24<00:00,  3.24it/s]



Test set: Average loss: 0.0351, Accuracy: 9888/10000 (98.88%)


Epoch 4
Training - Loss: 0.0470, Accuracy: 98.60%
Testing  - Accuracy: 98.88%
Best Test Accuracy: 98.88%

Epoch 5/15


Loss=0.0785 Batch_id=468 Accuracy=98.73: 100%|███████████████████████████████████████████████| 469/469 [01:43<00:00,  4.52it/s]



Test set: Average loss: 0.0286, Accuracy: 9911/10000 (99.11%)


Epoch 5
Training - Loss: 0.0428, Accuracy: 98.73%
Testing  - Accuracy: 99.11%
Best Test Accuracy: 99.11%

Epoch 6/15


Loss=0.0158 Batch_id=468 Accuracy=98.84: 100%|███████████████████████████████████████████████| 469/469 [01:24<00:00,  5.56it/s]



Test set: Average loss: 0.0340, Accuracy: 9899/10000 (98.99%)


Epoch 6
Training - Loss: 0.0389, Accuracy: 98.84%
Testing  - Accuracy: 98.99%

Epoch 7/15


Loss=0.0236 Batch_id=468 Accuracy=99.21: 100%|███████████████████████████████████████████████| 469/469 [01:02<00:00,  7.55it/s]



Test set: Average loss: 0.0204, Accuracy: 9941/10000 (99.41%)


Epoch 7
Training - Loss: 0.0275, Accuracy: 99.21%
Testing  - Accuracy: 99.41%
Best Test Accuracy: 99.41%

Epoch 8/15


Loss=0.0195 Batch_id=468 Accuracy=99.28: 100%|███████████████████████████████████████████████| 469/469 [00:57<00:00,  8.16it/s]



Test set: Average loss: 0.0199, Accuracy: 9936/10000 (99.36%)


Epoch 8
Training - Loss: 0.0255, Accuracy: 99.28%
Testing  - Accuracy: 99.36%

Epoch 9/15


Loss=0.0487 Batch_id=468 Accuracy=99.29: 100%|███████████████████████████████████████████████| 469/469 [00:48<00:00,  9.69it/s]



Test set: Average loss: 0.0192, Accuracy: 9939/10000 (99.39%)


Epoch 9
Training - Loss: 0.0248, Accuracy: 99.29%
Testing  - Accuracy: 99.39%

Epoch 10/15


Loss=0.0233 Batch_id=468 Accuracy=99.32: 100%|███████████████████████████████████████████████| 469/469 [00:42<00:00, 10.99it/s]



Test set: Average loss: 0.0189, Accuracy: 9940/10000 (99.40%)


Epoch 10
Training - Loss: 0.0239, Accuracy: 99.32%
Testing  - Accuracy: 99.40%

Epoch 11/15


Loss=0.0275 Batch_id=468 Accuracy=99.33: 100%|███████████████████████████████████████████████| 469/469 [00:41<00:00, 11.22it/s]



Test set: Average loss: 0.0192, Accuracy: 9942/10000 (99.42%)


Epoch 11
Training - Loss: 0.0231, Accuracy: 99.33%
Testing  - Accuracy: 99.42%
Best Test Accuracy: 99.42%

Epoch 12/15


Loss=0.0079 Batch_id=468 Accuracy=99.33: 100%|███████████████████████████████████████████████| 469/469 [00:49<00:00,  9.49it/s]



Test set: Average loss: 0.0195, Accuracy: 9936/10000 (99.36%)


Epoch 12
Training - Loss: 0.0230, Accuracy: 99.33%
Testing  - Accuracy: 99.36%

Epoch 13/15


Loss=0.0096 Batch_id=468 Accuracy=99.37: 100%|███████████████████████████████████████████████| 469/469 [00:43<00:00, 10.72it/s]



Test set: Average loss: 0.0192, Accuracy: 9938/10000 (99.38%)


Epoch 13
Training - Loss: 0.0225, Accuracy: 99.37%
Testing  - Accuracy: 99.38%

Epoch 14/15


Loss=0.0324 Batch_id=468 Accuracy=99.39: 100%|███████████████████████████████████████████████| 469/469 [01:06<00:00,  7.05it/s]



Test set: Average loss: 0.0188, Accuracy: 9939/10000 (99.39%)


Epoch 14
Training - Loss: 0.0220, Accuracy: 99.39%
Testing  - Accuracy: 99.39%

Epoch 15/15


Loss=0.0353 Batch_id=468 Accuracy=99.37: 100%|███████████████████████████████████████████████| 469/469 [00:57<00:00,  8.14it/s]



Test set: Average loss: 0.0191, Accuracy: 9937/10000 (99.37%)


Epoch 15
Training - Loss: 0.0219, Accuracy: 99.37%
Testing  - Accuracy: 99.37%

Training completed!
Best Test Accuracy: 99.42%
