Getting-Started-with-Pytorch

In this notebook I showcase a convolutional neural network that achieves 99.6% accuracy on the MNIST handwritten digit classification problem. The model is built with PyTorch Lightning, a package that suits beginners and experts alike because it offers simple yet powerful APIs.

header.png (MNIST image from Wikipedia)



Abstract

This project is based on the MNIST dataset, and the architecture comes from the article linked in the References section. Although that article implements the architecture in Keras, I re-wrote it in PyTorch Lightning. PyTorch is widely used as an academic tool in AI, so it is well worth learning PyTorch from scratch.



Table of Contents

  1. Requirements
  2. Dataset and pre-processing
  3. Architecture
  4. Colaboratory Notebook
  5. References

Requirements

This project has no special requirements beyond PyTorch, PyTorch Lightning, and torchvision. The MNIST dataset itself is available in torchvision.datasets, so no manual download is needed; a short loading sketch follows.
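For example, here is a minimal sketch of loading MNIST through torchvision.datasets. The batch size, normalization constants, and data path are illustrative choices, not values taken from the notebook.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Convert images to tensors and normalize with commonly used MNIST statistics
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# download=True fetches the dataset on first use
train_set = datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_set = datasets.MNIST(root="data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)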

Dataset and pre-processing

Data augmentation is extremely important. For image data, it means we can artificially increase the number of images our model sees. This is achieved by rotating the image, flipping it, zooming in or out, changing lighting conditions, cropping, and so on; a sketch of a typical augmentation pipeline is shown below.
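As an illustration, a torchvision transform pipeline along these lines could be used. The specific ranges here are assumptions, not the values used in the notebook.

from torchvision import transforms

# Small random rotations, shifts, and zooms; horizontal flips are usually
# avoided on MNIST because a flipped digit is no longer a valid digit.
augmentation = transforms.Compose([
    transforms.RandomRotation(degrees=10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])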

Architecture

In order to build a strong deep neural network, we go through the following steps:

  1. Add Convolutional Layers: the building blocks of ConvNets, which do the heavy computation
  2. Add Pooling Layers: step along the image, reducing parameters and decreasing the likelihood of overfitting
  3. Add Batch Normalization Layers: scale down outliers and keep the network from relying too heavily on any particular weight
  4. Add Dropout Layers: a regularization technique that randomly drops a percentage of neurons (usually 20% to 50%) to avoid overfitting
  5. Add a Flatten Layer: flattens the input into a 1D vector
  6. Add Dense Layers: fully connected layers that perform a linear operation on their input
  7. Add an Output Layer: one unit per class, with sigmoid for binary classification and softmax for multi-class classification

The reference Keras implementation from the article is shown below, followed by a PyTorch Lightning sketch of the same architecture.
# Reference Keras implementation from the original article
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Conv2D, MaxPooling2D, BatchNormalization,
                                     Dropout, Flatten, Dense)

model = Sequential()

# Block 1: two 3x3 convolutions with 32 filters, batch norm, pooling, dropout
model.add(Conv2D(filters=32, kernel_size=(3, 3), activation='relu', strides=1,
                 padding='same', data_format='channels_last', input_shape=(28, 28, 1)))
model.add(BatchNormalization())
model.add(Conv2D(filters=32, kernel_size=(3, 3), activation='relu', strides=1,
                 padding='same', data_format='channels_last'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2), strides=2, padding='valid'))
model.add(Dropout(0.25))

# Block 2: same structure with 64 filters
model.add(Conv2D(filters=64, kernel_size=(3, 3), activation='relu', strides=1,
                 padding='same', data_format='channels_last'))
model.add(BatchNormalization())
model.add(Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same',
                 activation='relu', data_format='channels_last'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2), padding='valid', strides=2))
model.add(Dropout(0.25))

# Classifier head: flatten, two dense layers with batch norm and dropout, softmax output
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.25))
model.add(Dense(1024, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
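
Since this repository re-implements the model in PyTorch Lightning, the following is a minimal sketch of how an equivalent LightningModule might look. The class name, learning rate, and layer ordering details are my own illustration and are not copied from the notebook.

import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl

class MNISTClassifier(pl.LightningModule):
    """CNN roughly mirroring the Keras architecture above."""

    def __init__(self, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.features = nn.Sequential(
            # Block 1: two 3x3 convolutions with 32 filters, then pooling and dropout
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(32),
            nn.MaxPool2d(2), nn.Dropout(0.25),
            # Block 2: two 3x3 convolutions with 64 filters, then pooling and dropout
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU(), nn.BatchNorm2d(64),
            nn.MaxPool2d(2), nn.Dropout(0.25),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512), nn.ReLU(), nn.BatchNorm1d(512), nn.Dropout(0.25),
            nn.Linear(512, 1024), nn.ReLU(), nn.BatchNorm1d(1024), nn.Dropout(0.5),
            nn.Linear(1024, 10),  # raw logits; softmax is folded into the loss
        )

    def forward(self, x):
        return self.classifier(self.features(x))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

Training then reduces to constructing a trainer and calling fit, for example trainer = pl.Trainer(max_epochs=10) followed by trainer.fit(MNISTClassifier(), train_loader), with Lightning handling device placement and the training loop.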




Colaboratory Notebook

Another way to train and test the model is to use the .ipynb file in the main directory. It is very informative: it builds intuition about the pre-processing steps and makes you more familiar with the dataset. In addition, you don't even need to clone the repository, because the notebook can be executed online in Google Colaboratory.

References

The following resources helped in implementing this project:

  1. MNIST: Keras Simple CNN (99.6%), an article by Brendan Artley
  2. The PyTorch Quickstart tutorial on the official website
  3. Maths-ML, developed by Farzad Yousefi
  4. House Price Prediction, developed by Farzad Yousefi
  5. The Machine Learning course published on Coursera
