<a href="https://colab.research.google.com/github/afirdousi/pytorch-basics/blob/main/006_neural_network_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
# Choose the compute backend once so the rest of the notebook can stay
# device-agnostic: use CUDA when PyTorch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Report the PyTorch version and the selected device, one per line.
print(
    f"Using PyTorch version = { torch.__version__ }",
    f"Using device = { device }",
    sep="\n",
)

Using PyTorch version = 2.0.1+cu118
Using device = cpu


In [4]:
# Classification

# Types & Examples:
# Binary Class Classification: Spam vs Not Spam | Dog or Cat?
# Multiclass Classification: Sushi or Pizza or Biryani? | Dog or Cat or Tiger or Cow or whatever ...
# Multi Label Classification: One record can have multiple labels, e.g. classifying a news article into several categories, or tagging a Wikipedia article with all the categories relevant to the page

In [6]:
# Classification Inputs and Outputs

# Example:

# Input (Picture) ---> ML Algo ---> Output (Label)

# Input
  # Generally, computer vision problems resize pictures to 224x224
  # Each picture has W = 224, H = 224 and C = 3 ... C here is the color channels R, G, B
  # We create a numerical encoding of that like
  # [
  #     [0.31], [0.62], [0.44],
  #     [0.29], [0.95], [0.16],
  #     ...
  # ]

# ML Algo
  # Often exists for the type of problem you are solving
  # If not, you can create a new one

# Output
  # These are prediction probabilities
  # For example, if we are predicting whether the picture is sushi, burger or pizza,
  # for each input we might return the probability of the image being each one

  # something like [0.97, 0.00, 0.03], i.e. a 97% chance it's sushi


In [7]:
# Input and Output shapes

# For example, image input is represented something like:

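# [ batch_size, color_channels, width, height ] # the dimension order is a convention that varies between frameworks (PyTorch image layers typically expect [batch, channels, height, width]); it just has to match what the model expects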

# Example:
# Shape = [ None, 3, 224, 224 ]
# or
# Shape = [ 32, 3, 224, 224 ] # 32 is a common batch size
# Check https://twitter.com/ylecun/status/989610208497360896?s=20
# Batch size is saying train on 32 images at once


# Output for multiclass like sushi, burger or biryani
# Shape = [3]

In [22]:
# Architecture of a Classification Model

# Define sequential input layers: number of in_features and number of out_features

# for example:

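# in_features is the size of each input sample; out_features is the number of neurons (units) in the layer,
# so each layer's out_features must equal the next layer's in_features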
# Toy fully-connected network: 3 input features -> 100 -> 100 -> 3 outputs.
# nn.Linear takes (in_features, out_features) positionally; modules run in the
# order they are listed and are indexed 0..3 inside the Sequential container.
fun_model = nn.Sequential(
    nn.Linear(3, 100),    # [0] input layer: 3 features in, 100 units out
    nn.Linear(100, 100),  # [1] hidden layer: 100 in, 100 units out
    # [2] hidden activation layer
    # More here: https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity
    nn.ReLU(),
    nn.Linear(100, 3),    # [3] output layer: one value per class
)

# There is also output layer activation like Sigmoid (torch.sigmoid) for binary classification or Softmax (torch.softmax) for multiclass classification
# You will also define Loss function: Binary Crossentropy (torch.nn.BCELoss) for binary classification AND Cross Entropy (torch.nn.CrossEntropyLoss) for multiclass classification
# You will also define Optimizer function: SGD, Adam (see torch.optim) --> applies for both binary and multiclass classification

In [23]:
# Display the model's state_dict: an OrderedDict of parameter tensors keyed by
# "<layer index>.<parameter name>" (e.g. '0.weight' for the first Linear layer).
# The values are randomly initialised, so they differ between runs unless a seed is set.
fun_model.state_dict()

OrderedDict([('0.weight',
              tensor([[-0.5668,  0.1125,  0.0617],
                      [-0.3779, -0.4895, -0.4872],
                      [ 0.1916,  0.0400,  0.3436],
                      [ 0.3215, -0.1303, -0.1735],
                      [-0.4442,  0.1094, -0.2282],
                      [-0.4222, -0.0020, -0.5535],
                      [ 0.1413, -0.0788,  0.2773],
                      [-0.0040,  0.0065, -0.3290],
                      [ 0.3223,  0.3576,  0.4861],
                      [ 0.3635,  0.5611,  0.5202],
                      [ 0.3831, -0.3791,  0.3612],
                      [ 0.0448, -0.2140,  0.1948],
                      [-0.4230,  0.4305, -0.2996],
                      [ 0.4768, -0.2743, -0.3628],
                      [ 0.5444, -0.3661, -0.3216],
                      [ 0.0628, -0.5092,  0.0734],
                      [ 0.4263, -0.1445, -0.3409],
                      [ 0.3257,  0.3870,  0.2601],
                      [-0.2123, -0.4786, -0.3752],
     