<a href="https://colab.research.google.com/github/DXL64/MNIST-using-CNN/blob/main/cnn_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision.datasets as datasets
from tqdm import tqdm
import matplotlib.pyplot as plt

class ConvBlock(torch.nn.Module):
  """Conv2d -> BatchNorm2d -> ReLU, chained in a single ``Sequential``.

  Args:
    in_channels: channel count of the incoming feature map.
    out_channels: channel count produced by the convolution.
    kernel_size: kernel size forwarded to ``Conv2d`` (all other conv
      arguments keep their library defaults).
  """
  def __init__(self, in_channels, out_channels, kernel_size):
    super().__init__()
    conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size)
    norm = torch.nn.BatchNorm2d(out_channels)
    activation = torch.nn.ReLU(inplace = True)
    # Stored under the name `block`, so parameters are keyed `block.<idx>.*`.
    self.block = torch.nn.Sequential(conv, norm, activation)

  def forward(self, x):
    """Apply convolution, batch norm and ReLU to ``x`` and return the result."""
    features = self.block(x)
    return features
def testBlock():
  """Smoke test: print the output shape of a 1->32 channel, 3x3 ConvBlock."""
  dummy_batch = torch.zeros(10, 1, 32, 32)
  conv_block = ConvBlock(1,32,3)
  print(conv_block(dummy_batch).shape)
testBlock()


torch.Size([10, 32, 30, 30])


In [None]:
class CNN(torch.nn.Module):
  """Stack of ConvBlocks followed by Flatten -> Linear -> BatchNorm1d.

  Args:
    channels: output channel count of each ConvBlock, in order. Block ``i``
      takes the output width of block ``i-1`` as its input width.
    kernel_size: kernel size shared by every ConvBlock.
    fc_in: number of input features of the final ``Linear`` layer, i.e. the
      flattened size of the last ConvBlock's output for the image size the
      model will be fed. It is not derived automatically.
    num_class: number of output scores per sample.
    in_channels: channel count of the input images. Defaults to 1, which is
      the value that used to be hard-coded.
  """
  def __init__(self, channels, kernel_size, fc_in, num_class, in_channels = 1):
    super().__init__()
    # widths[i] feeds block i and widths[i+1] is what block i produces.
    widths = [in_channels, *channels]
    layers = [
        ConvBlock(c_in, c_out, kernel_size)
        for c_in, c_out in zip(widths[:-1], widths[1:])
    ]
    layers.append(torch.nn.Flatten())
    layers.append(torch.nn.Linear(fc_in, num_class))
    # The classifier output is batch-normalised; no softmax is applied here.
    layers.append(torch.nn.BatchNorm1d(num_class))
    self.net = torch.nn.Sequential(*layers)

  def forward(self, x):
    """Return the ``num_class`` scores computed by the full stack for ``x``."""
    return self.net(x)
def testCNN():
  """Smoke test: ten 3x3 blocks (16..160 channels) on a zero 28x28 batch."""
  widths = list(range(16, 161, 16))
  model = CNN(widths, 3, 10240, 10)
  dummy_batch = torch.zeros(10, 1, 28, 28)
  print(model(dummy_batch).shape)
testCNN()

torch.Size([10, 10])


In [None]:
class M3(CNN):
  """CNN variant with ten 3x3 ConvBlocks of widths 32, 48, ..., 176.

  The forward pass is inherited from ``CNN``.
  """
  def __init__(self):
    # For a 28x28 input, each unpadded 3x3 conv shrinks the map by 2 per axis:
    # 28 - 10*2 = 8, so the classifier sees 176 * 8 * 8 = 11264 features.
    super().__init__([16*(i+2) for i in range(10)], 3, 11264, 10)

In [None]:
class M5(CNN):
  """CNN variant with five 5x5 ConvBlocks of widths 32, 64, ..., 160.

  The forward pass is inherited from ``CNN``.
  """
  def __init__(self):
    # For a 28x28 input, each unpadded 5x5 conv shrinks the map by 4 per axis:
    # 28 - 5*4 = 8, so the classifier sees 160 * 8 * 8 = 10240 features.
    super().__init__([32*(i+1) for i in range(5)], 5, 10240, 10)

In [None]:
class M7(CNN):
  """CNN variant with four 7x7 ConvBlocks of widths 48, 96, 144, 192.

  The forward pass is inherited from ``CNN``.
  """
  def __init__(self):
    # For a 28x28 input, each unpadded 7x7 conv shrinks the map by 6 per axis:
    # 28 - 4*6 = 4, so the classifier sees 192 * 4 * 4 = 3072 features.
    super().__init__([48*(i+1) for i in range(4)], 7, 3072, 10)

In [None]:
def get_datasets():
  """Build the MNIST train and test datasets.

  Both splits are rooted at ``./data`` with ``download = True`` and use a
  transform pipeline consisting only of ``ToTensor``.

  Returns:
    dict with keys ``'train'`` and ``'test'`` mapping to the two datasets.
  """
  from torchvision import transforms as T

  def make_transform():
    # Each split gets its own pipeline object, even though they are identical.
    return T.Compose([T.ToTensor()])

  splits = {}
  for split_name, is_train in (('train', True), ('test', False)):
    splits[split_name] = datasets.MNIST(
        './data', train = is_train, download = True, transform = make_transform()
    )
  return splits

ds = get_datasets()
ds.keys()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



dict_keys(['train', 'test'])

In [None]:
def get_dataloaders(datasets, config):
    """Wrap the train/test datasets in DataLoaders.

    Args:
      datasets: mapping with ``'train'`` and ``'test'`` datasets. Note that
        inside this function the name shadows the module-level
        ``torchvision.datasets`` import.
      config: object whose ``batch_size`` attribute sets the batch size of
        both loaders. If the attribute is missing, 120 is used, which is the
        value that was previously hard-coded.

    Returns:
      dict with keys ``'train'`` and ``'test'``. The train loader shuffles
      and drops the last incomplete batch; the test loader does neither.
    """
    batch_size = getattr(config, "batch_size", 120)
    train_loader = torch.utils.data.DataLoader(
        datasets['train'], batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True
    )
    test_loader = torch.utils.data.DataLoader(
        datasets['test'], batch_size=batch_size, shuffle=False, num_workers=2, drop_last=False
    )
    return dict(train=train_loader, test=test_loader)

class Config:
  """Hyper-parameters and runtime settings for the training run."""
  def __init__(self):
    self.batch_size = 120
    self.learning_rate = 1e-3
    # Use the GPU when one is visible; otherwise fall back to the CPU so that
    # moving the model and batches to `device` does not raise.
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    # Multiplicative learning-rate decay factor passed to the LR scheduler.
    self.gamma = 0.98
    self.n_epoch = 50

config = Config()
loaders = get_dataloaders(ds, config)

# Sanity check: show the tensor shapes of the first training batch, if any.
first_batch = next(iter(loaders['train']), None)
if first_batch is not None:
  x, y = first_batch
  print(x.shape, y.shape)

torch.Size([120, 1, 28, 28]) torch.Size([120])


In [None]:
import matplotlib.pyplot as plt
# Train M5 on the MNIST loaders, evaluating on the test split after every epoch.
net = M5()
net = net.to(config.device)
optimizer = torch.optim.Adam(net.parameters(), lr = config.learning_rate)
loss_function = torch.nn.CrossEntropyLoss()
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma = config.gamma)

# The bar advances once per batch, and every epoch walks BOTH loaders, so the
# test loader has to be counted n_epoch times as well.
total_steps = config.n_epoch * (len(loaders['train']) + len(loaders['test']))
# Per-epoch accuracy in percent, collected for the plot below.
plotTrainX = []
plotTrainY = []
plotTestX = []
plotTestY = []
with tqdm(total=total_steps, position=0, leave=True) as pbar:
  for epoch in range(config.n_epoch):
    # ---- training pass ----
    net.train()
    total = 0
    total_correct = 0
    for step, (images, labels) in enumerate(loaders['train']) :
      images, labels = images.to(config.device), labels.to(config.device)
      out = net(images)
      loss = loss_function(out, labels)

      # Running train accuracy, measured in train mode while weights change.
      ypred = torch.argmax(out, dim=1)
      batch_correct = torch.sum(ypred==labels)
      total += len(labels)
      total_correct += batch_correct.item()

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      pbar.update()
      pbar.set_description(f"{epoch} - {step} loss = {loss.item():.4f}")
    plotTrainX.append(epoch)
    plotTrainY.append(round(total_correct/total*100, 2))

    # ---- evaluation pass ----
    net.eval()
    total = 0
    total_correct = 0
    with torch.no_grad():
      for step, (images, labels) in enumerate(loaders['test']) :
        images, labels = images.to(config.device), labels.to(config.device)
        out = net(images)
        preds = torch.argmax(out, dim = 1)
        batch_correct = torch.sum(preds == labels).item()
        total += len(labels)
        total_correct += batch_correct
        pbar.update()
        pbar.set_description(f"{epoch} - {step} test_accuracy = {total_correct/total*100:.2f}")
    plotTestX.append(epoch)
    plotTestY.append(round(total_correct/total*100, 2))

    # Decay the learning rate once per epoch.
    lr_scheduler.step()

# Report and plot after the progress bar has been closed.
print(plotTrainX, plotTrainY)
fig, ax = plt.subplots()
ax.plot(plotTrainX, plotTrainY, label="train")
ax.plot(plotTestX, plotTestY, label="test")
ax.set(title="M5 accuracy per epoch", xlabel="epoch", ylabel="accuracy (%)")
ax.legend()
plt.show()

RuntimeError: ignored