In [None]:
# Mount Google Drive inside the Colab runtime. The dataset cache and the saved
# model files in later cells all live under '/content/drive/MyDrive/...'.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import numpy as np
import torch
import torchvision
import torch.nn as nn
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import torch.nn.functional as F

# Seed every RNG this notebook actually draws from. Weight initialisation,
# dropout and DataLoader shuffling all use torch's generator, so seeding NumPy
# alone would not make runs repeatable.
np.random.seed(1)
torch.manual_seed(1)

In [None]:
# Hyperparameters shared by the cells below.
# Number of label classes; not referenced by name in the cells shown here.
NUMBER_OF_CLASSES = 10
# Number of full passes over train_loader made by the training loop.
EPOCHS = 40
# Mini-batch size used for both the train and the test DataLoader.
BATCH_SIZE = 128
# Disabled: no validation split is created in the cells shown here.
# VALIDATION_SPLIT = 0.2

In [None]:
# Load the Fashion-MNIST train/test splits, cached in the Drive project folder
# (download=True fetches the files only when they are not already under root).
working_dir = '/content/drive/MyDrive/EECS 545 project/'

# Compose applies its transforms in sequence, so this has to be an ordered
# list. A set literal only works by accident while it holds a single item:
# once a second transform (e.g. Normalize) is added the order is arbitrary.
transform_list = [
    torchvision.transforms.ToTensor(),
    # torchvision.transforms.Normalize(mean=(0), std=(1.0)),
]
transforms = torchvision.transforms.Compose(transform_list)

dataset_root = working_dir + 'Fashion_MNIST'
train_set = torchvision.datasets.FashionMNIST(train=True, root=dataset_root, download=True, transform=transforms)
test_set = torchvision.datasets.FashionMNIST(train=False, root=dataset_root, download=True, transform=transforms)

In [None]:
# Training batches are reshuffled every epoch. The test set is read in a fixed
# order: the summed metrics do not depend on batch order, so shuffling there
# only costs reproducibility.
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class SimpleCNN(nn.Module):
    """Small two-convolution classifier for 1-channel 28x28 images.

    Architecture: conv(1->32, 4x4, stride 2) -> ReLU -> conv(32->64, 4x4,
    stride 2) -> ReLU -> dropout(0.25) -> flatten -> linear(1600->128) -> ReLU
    -> dropout(0.5) -> linear(128->num_classes).

    The forward pass returns raw logits (no softmax).

    Args:
        num_classes: Width of the final linear layer. Defaults to 10, so
            ``SimpleCNN()`` builds the same network as before and existing
            state_dicts keep loading.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(4, 4), stride=(2, 2))
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(4, 4), stride=(2, 2))
        self.drop1 = nn.Dropout(p=0.25)
        # A 28x28 input shrinks to 13x13 after conv1 and 5x5 after conv2, so
        # the flattened feature vector has 64 * 5 * 5 = 1600 entries.
        self.lin1 = nn.Linear(64 * 5 * 5, 128)
        self.drop2 = nn.Dropout(p=0.5)
        self.output = nn.Linear(128, num_classes)

    def forward(self, input_batch):
        """Map a batch of shape (N, 1, 28, 28) to logits of shape (N, num_classes)."""
        x = F.relu(self.conv1(input_batch))
        x = F.relu(self.conv2(x))
        x = self.drop1(x)
        # Keep the batch dimension, collapse channels and spatial dims.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.lin1(x))
        x = self.drop2(x)
        return self.output(x)

In [None]:
# Train on the GPU when one is attached, otherwise fall back to the CPU so the
# cell still runs on a CPU-only runtime instead of raising.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cnn_mnist = SimpleCNN().to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn_mnist.parameters(), lr=0.001)

# Make sure dropout is active while training.
cnn_mnist.train()
for ep in range(EPOCHS):
    for i, (features, label) in enumerate(train_loader):
        features = features.to(device)
        label = label.to(device).long()

        logits = cnn_mnist(features)
        loss = loss_func(logits, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log the current mini-batch loss every 100 steps.
        if i % 100 == 0:
            print("Epoch: ", ep + 1, "Loss: ", loss.item())


In [None]:
# Persist only the trained weights (the state_dict, not the module object) to
# the Drive project folder.
torch.save(cnn_mnist.state_dict(), working_dir + "cnn_fashion_mnist_model.pth")

In [None]:
# Evaluate on the test set: per-sample average loss and overall accuracy.
cnn_mnist.eval()  # disable dropout for inference
avg_loss = 0
avg_acc = 0
count = 0
# No gradients are needed here; without no_grad every batch would build an
# autograd graph that the running loss sum then keeps alive.
with torch.no_grad():
    for features, label in test_loader:
        features = features.to(device)
        label = label.to(device).long()
        logits = cnn_mnist(features)
        batch_size = logits.shape[0]

        # loss_func returns the MEAN loss over the batch; scale it back to a
        # batch sum so that dividing by the sample count below gives a true
        # per-sample average (the last batch may be smaller than the rest).
        avg_loss = avg_loss + loss_func(logits, label) * batch_size
        avg_acc = avg_acc + torch.sum(logits.argmax(dim=1) == label)
        count = count + batch_size

avg_acc = avg_acc / count
avg_loss = avg_loss / count
print("Accuracy: ", avg_acc.item(), "Loss: ", avg_loss.item())

In [None]:
# Export a CPU copy of the weights so the file loads on machines without a GPU.
# Module.to() moves the module in place, which would leave cnn_mnist on the CPU
# and break a re-run of the evaluation cell (inputs on `device`, weights on
# CPU). Copying the tensors instead leaves the live model where it is.
cpu_state_dict = {name: tensor.cpu() for name, tensor in cnn_mnist.state_dict().items()}
torch.save(cpu_state_dict, working_dir + "cnn_fashion_mnist_model_cpu.pth")