MNIST Image Classification and DCGAN Image Generation 🖼️

This repository trains, tests, and deploys a handwritten-digit classifier to an Amazon SageMaker endpoint with a custom inference script. It showcases two major components:

  1. Image classification using a Convolutional Neural Network (CNN) on the MNIST dataset.
  2. Image generation using a Deep Convolutional Generative Adversarial Network (DCGAN).

Prerequisites 🌎

  • Python 3.6 or above
  • PyTorch
  • torchvision
  • scikit-learn
  • seaborn
  • matplotlib
  • boto3
  • sagemaker

Installation 🎩

  1. Clone the repository:

     git clone https://github.com/Sid1279/MNIST-CNN.git
     cd MNIST-CNN

  2. Install the required Python packages:

     pip install -r requirements.txt

  3. In the deployment notebook (deploy.ipynb), rename the endpoint and bucket to your desired values.

  4. Create a notebook instance in AWS SageMaker for deployment.

  5. Upload the MNIST PyTorch CNN.ipynb notebook to SageMaker.

  6. Run all cells in the notebook to deploy the trained model to Amazon SageMaker.

Image Classification: MNIST PyTorch CNN.ipynb

How does it work? 🤳🏽

This notebook demonstrates the process of training a Convolutional Neural Network (CNN) on the MNIST dataset and deploying the trained model using Amazon SageMaker.

  1. Data Preparation: The script downloads the MNIST dataset and applies data transformations such as normalization and tensor conversion using torchvision.transforms. It creates custom Dataset objects and uses DataLoader to efficiently load and iterate over the training and test data.
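A minimal sketch of this pipeline (the batch size and normalization constants below are assumptions, not values taken from the notebook):

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Convert images to tensors and normalize. 0.1307/0.3081 are the
# commonly used MNIST statistics; the notebook may use different values.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_data = datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_data = datasets.MNIST(root="data", train=False, download=True, transform=transform)

# DataLoaders batch and shuffle the data for training and evaluation
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)
```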

  2. Model Architecture: The script defines a CNN model using the NeuralNetwork class, which inherits from torch.nn.Module. The model consists of convolutional layers, activation functions, pooling layers, and fully connected layers. The forward pass method implements the model's computation flow.
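The README does not spell out the exact layers, so the sketch below only illustrates the general shape such a NeuralNetwork class could take (channel counts and layer sizes are assumptions):

```python
import torch.nn as nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # Two conv blocks: convolution -> ReLU -> max-pooling
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),              # 28x28 -> 14x14
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),              # 14x14 -> 7x7
        )
        # Fully connected head mapping features to 10 digit classes
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.classifier(self.features(x))
```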

  3. Model Training: The script initializes the model, defines the loss function (cross-entropy loss), and sets up the optimizer (Adam). It then loops over the specified number of epochs, performing the following steps in each epoch (a condensed sketch follows the bullets below):

    • Iterates over the training data, computes the forward pass, calculates the loss, and performs backpropagation to update the model's parameters.
    • Computes training accuracy and loss and logs them to TensorBoard using SummaryWriter.
    • Evaluates the model on the test data, computes testing accuracy and loss, and logs them to TensorBoard.
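A condensed sketch of one training epoch, reusing the names from the sketches above; the epoch count, learning rate, and TensorBoard tags are illustrative, not taken from the notebook:

```python
import torch
from torch.utils.tensorboard import SummaryWriter

model = NeuralNetwork()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
writer = SummaryWriter()

for epoch in range(10):  # epoch count is an assumption
    model.train()
    correct, total, running_loss = 0, 0, 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)            # forward pass
        loss = criterion(outputs, labels)
        loss.backward()                    # backpropagation
        optimizer.step()                   # parameter update

        running_loss += loss.item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

    # Log training metrics to TensorBoard
    writer.add_scalar("Loss/train", running_loss / len(train_loader), epoch)
    writer.add_scalar("Accuracy/train", correct / total, epoch)
```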
  4. Evaluation: After training, the script generates a confusion matrix to visualize the model's performance on the training set using sklearn.metrics.confusion_matrix and seaborn.heatmap. The confusion matrix provides insights into the model's ability to correctly classify different digits.
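A sketch of how such a confusion matrix could be produced, reusing the model and train_loader from the sketches above:

```python
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# Collect predictions and ground-truth labels over the training set
model.eval()
preds, targets = [], []
with torch.no_grad():
    for images, labels in train_loader:
        preds.extend(model(images).argmax(dim=1).tolist())
        targets.extend(labels.tolist())

# Rows are true digits, columns are predicted digits
cm = confusion_matrix(targets, preds)
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=range(10), yticklabels=range(10))
plt.xlabel("Predicted digit")
plt.ylabel("True digit")
plt.show()
```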

  5. Model Deployment: The script saves the trained model's state dictionary using torch.save and creates a tar.gz archive containing the model file. It uploads the archive to an S3 bucket using the AWS SDK (boto3). It then uses the SageMaker Python SDK to create a PyTorchModel object, specifying the model data location, IAM role, and other necessary information. The model is deployed to an endpoint using the deploy method, specifying the endpoint name and instance type.
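A hedged sketch of this deployment step with the SageMaker Python SDK; the bucket, endpoint name, entry-point script, instance type, and framework versions below are placeholders, not values from the notebook:

```python
import boto3
import sagemaker
from sagemaker.pytorch import PyTorchModel

# Upload the tar.gz archive to S3 ("my-bucket" and the key are placeholders)
boto3.client("s3").upload_file("model.tar.gz", "my-bucket", "model/model.tar.gz")

pytorch_model = PyTorchModel(
    model_data="s3://my-bucket/model/model.tar.gz",
    role=sagemaker.get_execution_role(),  # IAM role of the notebook instance
    entry_point="inference.py",           # custom inference script (name assumed)
    framework_version="1.12",             # framework/Python versions assumed
    py_version="py38",
)

predictor = pytorch_model.deploy(
    endpoint_name="mnist-cnn-endpoint",   # placeholder endpoint name
    initial_instance_count=1,
    instance_type="ml.m5.large",          # instance type is an assumption
)
```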

  6. Interactive Drawing App: An interactive drawing application built with PyQt5 lets users draw freehand on a canvas with the mouse. The drawn images are preprocessed and classified by the pre-trained CNN model, providing real-time digit classification and visualization of the model's predictions.
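A sketch of the preprocessing such an app needs before calling the model; PIL stands in here for the PyQt5 canvas, and the function name is hypothetical:

```python
import torch
from PIL import Image, ImageOps
from torchvision import transforms

def classify_drawing(path, model):
    """Load a drawn image, match MNIST's format, and predict a digit."""
    img = Image.open(path).convert("L")   # grayscale
    img = ImageOps.invert(img)            # MNIST digits are white on black
    img = img.resize((28, 28))
    x = transforms.ToTensor()(img).unsqueeze(0)  # shape (1, 1, 28, 28)
    # Apply the same Normalize(...) used at training time if the model expects it
    with torch.no_grad():
        return model(x).argmax(dim=1).item()
```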

  7. TensorFlow Conversion: Check out the very similar model built with TensorFlow in MNIST Tensorflow CNN.ipynb! It covers all the key features outlined above, translated from PyTorch to TensorFlow, with the exception of model deployment.
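For comparison, a minimal tf.keras equivalent of the CNN sketched earlier (same assumed layer sizes as the PyTorch sketch, not the notebook's exact architecture):

```python
import tensorflow as tf

model_tf = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, padding="same", activation="relu",
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(2),
    tf.keras.layers.Conv2D(64, 3, padding="same", activation="relu"),
    tf.keras.layers.MaxPooling2D(2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10),  # logits for the 10 digit classes
])
model_tf.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
```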

DCGAN Image Generation: MNIST DCGAN PyTorch.ipynb

Model Architecture

The DCGAN model consists of a generator and a discriminator, implemented using PyTorch. The generator creates high-quality images by learning from the training data distribution, while the discriminator tries to distinguish between real and fake images.
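The README does not list the exact layer configuration, so the sketch below shows a typical DCGAN shape for 28x28 MNIST images (latent size and channel counts are assumptions):

```python
import torch.nn as nn

latent_dim = 100  # size of the noise vector (assumed)

# Generator: noise of shape (N, latent_dim, 1, 1) -> 1x28x28 image
generator = nn.Sequential(
    nn.ConvTranspose2d(latent_dim, 128, 7, 1, 0),  # -> 128 x 7 x 7
    nn.BatchNorm2d(128), nn.ReLU(),
    nn.ConvTranspose2d(128, 64, 4, 2, 1),          # -> 64 x 14 x 14
    nn.BatchNorm2d(64), nn.ReLU(),
    nn.ConvTranspose2d(64, 1, 4, 2, 1),            # -> 1 x 28 x 28
    nn.Tanh(),
)

# Discriminator: image -> single real/fake logit via strided convolutions
discriminator = nn.Sequential(
    nn.Conv2d(1, 64, 4, 2, 1),                     # -> 64 x 14 x 14
    nn.LeakyReLU(0.2),
    nn.Conv2d(64, 128, 4, 2, 1),                   # -> 128 x 7 x 7
    nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
    nn.Conv2d(128, 1, 7, 1, 0),                    # -> 1 x 1 x 1
    nn.Flatten(),
)
```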

Training

The DCGAN model is trained using the MNIST dataset. The generator creates images from random noise, and the discriminator is trained to distinguish between real MNIST images and the images generated by the generator.
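A condensed sketch of one pass over the data under the architecture assumptions above, following the standard GAN recipe of alternating discriminator and generator updates:

```python
import torch

criterion = torch.nn.BCEWithLogitsLoss()
opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Real images should be normalized to [-1, 1] to match the generator's Tanh output
for real, _ in train_loader:
    n = real.size(0)
    noise = torch.randn(n, latent_dim, 1, 1)
    fake = generator(noise)

    # Discriminator step: real images labeled 1, generated images labeled 0
    d_loss = (criterion(discriminator(real), torch.ones(n, 1)) +
              criterion(discriminator(fake.detach()), torch.zeros(n, 1)))
    opt_d.zero_grad(); d_loss.backward(); opt_d.step()

    # Generator step: try to make the discriminator output "real" for fakes
    g_loss = criterion(discriminator(fake), torch.ones(n, 1))
    opt_g.zero_grad(); g_loss.backward(); opt_g.step()
```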

During training, generated images are saved to the images/ folder at regular intervals so you can visualize how image quality progresses. You can also write additional samples to the generated_images/ folder to see how closely they resemble actual handwritten digits (0-9).
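A sketch of how intermediate samples can be written out with torchvision's save_image (the file name is illustrative; the folders follow the layout described above):

```python
import torch
from torchvision.utils import save_image

# Save a grid of 64 generated samples; called periodically during training
noise = torch.randn(64, latent_dim, 1, 1)
samples = generator(noise)
save_image(samples, "images/epoch_010.png", nrow=8, normalize=True)
```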
