# In [2]:
"""
  PyTorch ecosystem
  Chosen over TensorFlow because it is easier to debug and easier to extend with custom components.
"""
import torch
import torch.nn as nn # foundation of neural networks: nn.Module and layer building blocks
import torch.optim as optim # optimizers

"""
  TorchVision
"""
import torchvision.datasets as datasets # ready-to-go dataset
import torchvision.transforms as transforms # preprocessing, normalization

from torch.utils.data import DataLoader # standard batching/shuffling pipeline for training

import numpy as np
import matplotlib.pyplot as plt

# In [5]:
"""
nn.Module
Base class for all neural network modules.

Your models should also subclass this class.
"""
class CNNEncoder(nn.Module):
    """Two-stage convolutional feature extractor.

    Each stage is ``Conv2d(kernel_size=3, padding=1) -> BatchNorm2d -> ReLU``.
    A 3x3 kernel with padding 1 (and the default stride of 1) keeps the
    spatial size unchanged. Only the channel count grows:
    ``input_channels -> 64 -> 128``.

    Conv1d is typically used for audio/text, Conv2d for 2-D images and
    Conv3d for video/volumetric data. This encoder uses Conv2d. BatchNorm
    re-normalises each stage's activations, which usually makes training
    more stable.

    Args:
        input_channels: number of channels in the input images
            (1 for grayscale data such as MNIST).
        feature_dim: accepted for interface compatibility but currently NOT
            used by this module. The output is a 128-channel feature map,
            not a ``feature_dim``-sized vector. TODO(review): add a
            projection head or remove this argument.
    """

    def __init__(self, input_channels=1, feature_dim=512):
        super().__init__()

        # Stage 1: lift the raw image to 64 feature maps (typically low-level
        # detectors such as edges and corners).
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)

        # Stage 2: combine stage-1 maps into 128 higher-level feature maps.
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)

    def forward(self, x):
        """Encode a batch of images into a feature map.

        Args:
            x: tensor of shape ``(batch, input_channels, H, W)``; for MNIST
                this is ``(batch, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, 128, H, W)``. Values are non-negative
            because the last operation is a ReLU.
        """
        # (B, C_in, H, W) -> (B, 64, H, W)
        hidden = torch.relu(self.bn1(self.conv1(x)))
        # (B, 64, H, W) -> (B, 128, H, W)
        return torch.relu(self.bn2(self.conv2(hidden)))