In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import tqdm
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.data.dataloader import default_collate


In [2]:
class VGGNet(nn.Module):
    def __init__(self):
        super(VGGNet, self).__init__()

        # Take vgg13 untrained skeleton
        self.model = models.vgg13(weights=None)
        # Since original resnet18 has a 3-channel input, we have to change it to 1-channel for greyscale
        self.model.features[0] = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        # Similarly, change final output nodes to 10, according to fashion-mnist's class
        self.model.classifier[6] = nn.Linear(4096, 10)

    def forward(self, x):
        return self.model(x)


In [3]:
model = VGGNet()

device = 'mps' if torch.backends.mps.is_available() else 'cpu'
# if torch.cuda.is_available():
#     model.cuda()

print(device)


mps


In [4]:

model.to(device)


VGGNet(
  (model): VGG(
    (features): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (15): Conv2d(256, 5

In [5]:
# Get data loader Function referenced from hw4, loader.py
# Download Fashion MNIST dataset, apply transforms, fit data into dataloader
def get_data_loader(train_transformer, valid_transformer, batch_size):
    train_loader = DataLoader(torchvision.datasets.FashionMNIST(
        download=True, root=".", transform=train_transformer, train=True), batch_size=batch_size, shuffle=True, pin_memory=True)

    val_loader = DataLoader(torchvision.datasets.FashionMNIST(download=False, root=".",
                            transform=valid_transformer, train=False), batch_size=batch_size, shuffle=False, pin_memory=True)
    return train_loader, val_loader


In [6]:
epochs = 5
batch_size = 32

data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.Normalize([0.5], [0.5])])
train_loader, val_loader = get_data_loader(data_transform, data_transform, batch_size)


In [7]:
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
train_losses = []
valid_losses = []
train_accs = []
valid_accs = []


In [8]:
# Run function referenced from hw4, helper.py
# Perform forward propagation
# If mode is training, also perform backward propagation and optimize parameters
def run(mode, dataloader, model, optimizer=None, use_cuda=torch.cuda.is_available(), device=None):
    """
    mode: either "train" or "valid". If the mode is train, we will optimize the model
    """
    running_loss = []
    criterion = nn.CrossEntropyLoss()

    actual_labels = []
    predictions = []
    for inputs, labels in tqdm.tqdm(dataloader):
        if use_cuda:
            inputs, labels = inputs.cuda(), labels.cuda()

        if device == 'mps':
            inputs, labels = inputs.to(device), labels.to(device)

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        running_loss.append(loss.item())

        actual_labels += labels.view(-1).cpu().numpy().tolist()
        _, pred = torch.max(outputs, dim=1)

        predictions += pred.view(-1).cpu().numpy().tolist()

        if mode == "train":
            # zero the parameter gradients
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    acc = np.sum(np.array(actual_labels) == np.array(
        predictions)) / len(actual_labels)
    print(mode, "Accuracy:", acc)

    loss = np.mean(running_loss)

    return loss, acc


In [9]:

# Perform the actual training and validating process for each epoch
for epoch in range(epochs):
    loss, acc = run("train", train_loader, model, optimizer, device=device)
    train_losses.append(loss)
    train_accs.append(acc)
    with torch.no_grad():
        loss, acc = run("valid", val_loader, model, optimizer, device=device)
        valid_losses.append(loss)
        valid_accs.append(acc)

print("-"*60)
print("best validation accuracy is %.4f percent" % (np.max(valid_accs) * 100))

# save the model for future reference
torch.save(model, "%s.pt" % str(valid_accs[-1]))


100%|██████████| 1875/1875 [51:11<00:00,  1.64s/it]


train Accuracy: 0.8417166666666667


100%|██████████| 313/313 [01:45<00:00,  2.96it/s]


valid Accuracy: 0.8958


100%|██████████| 1875/1875 [50:59<00:00,  1.63s/it]


train Accuracy: 0.9095333333333333


100%|██████████| 313/313 [01:44<00:00,  3.00it/s]


valid Accuracy: 0.9143


100%|██████████| 1875/1875 [50:55<00:00,  1.63s/it]


train Accuracy: 0.92505


100%|██████████| 313/313 [01:44<00:00,  3.00it/s]


valid Accuracy: 0.9146


100%|██████████| 1875/1875 [50:49<00:00,  1.63s/it]


train Accuracy: 0.9347833333333333


100%|██████████| 313/313 [01:44<00:00,  3.00it/s]


valid Accuracy: 0.9211


100%|██████████| 1875/1875 [50:51<00:00,  1.63s/it]


train Accuracy: 0.94485


100%|██████████| 313/313 [01:46<00:00,  2.95it/s]


valid Accuracy: 0.9207
------------------------------------------------------------
best validation accuracy is 92.1100 percent
