In [17]:
import numpy as np
import pandas as pd
import time
import re
import os
import random
import torch
import matplotlib.pyplot as plt
import seaborn as sns 

import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd

from sklearn.datasets import load_iris

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import FashionMNIST

In [2]:
# Normalise each of the three channels with mean 0.5 and std 0.5
# (127.5 / 255 is exactly 0.5).
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

# Training pipeline: convert to tensor, resize to 32x32, repeat the first axis
# three times (presumably turning the single grayscale channel into three --
# TODO confirm against the dataset's image mode), then normalise.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32, 32)),
    transforms.Lambda(lambda img: img.repeat(3, 1, 1)),
    normalize,
])

# Test pipeline: the same four steps as the training pipeline.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32, 32)),
    transforms.Lambda(lambda img: img.repeat(3, 1, 1)),
    normalize,
])

In [None]:
# Build the FashionMNIST train and test splits under /content/images/,
# downloading the files when requested, each with its own transform pipeline.
train = FashionMNIST(
    root='/content/images/',
    train=True,
    download=True,
    transform=transform_train,
)
test = FashionMNIST(
    root='/content/images/',
    train=False,
    download=True,
    transform=transform_test,
)

In [4]:
# Randomly initialised ResNet-18. FashionMNIST has 10 classes, so size the
# final fully connected layer for 10 outputs instead of keeping the default
# 1000-way ImageNet head (which lets argmax return class ids that do not exist
# in the dataset and wastes capacity on 990 unused logits).
model = torchvision.models.resnet18(pretrained=False, num_classes=10)

In [7]:
# Training subset: 4000 distinct indices drawn from the first 20000 samples.
# The list itself is used as the sampler, so the visiting order is the same
# on every pass over the loader.
train_ids = random.sample(range(20000), 4000)
trainloader = torch.utils.data.DataLoader(
    dataset=train,
    batch_size=4,
    sampler=train_ids,
)

# Validation reads indices 50000..59999 of the training split, which cannot
# overlap with the training subset above.
valloader = torch.utils.data.DataLoader(
    dataset=train,
    batch_size=1024,
    sampler=range(50000, 60000),
)

# The held-out test split is iterated in full, in order.
testloader = torch.utils.data.DataLoader(dataset=test, batch_size=1024)

In [9]:
# Objective: cross-entropy between the model's logits and integer class labels.
loss_fn = nn.CrossEntropyLoss()
# Optimiser: Adam over every model parameter, learning rate 1e-3.
opt = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [18]:
# Directory that holds saved checkpoints; create it (and any missing parents)
# the first time this cell runs.
checkpoint_path = '/content/checkpoint/'
checkpoint_dir_missing = not os.path.exists(checkpoint_path)
if checkpoint_dir_missing:
  print('Making directory...')
  os.makedirs(checkpoint_path)

Making directory...


In [None]:
# Train for 180 epochs. After every epoch, score the model on the validation
# loader and save a checkpoint whenever validation accuracy improves.
best_acc = 0.0
best_epoch = None  # defined up front so the summary print can never hit a NameError
for epoch in range(1, 181):
  losses = []
  print('-------------Epoch', epoch, '-------------------')

  # ---- training pass ----
  # Put BatchNorm layers back into training mode (validation switches them to
  # eval mode at the end of every epoch).
  model.train()
  for img, label in trainloader:
    opt.zero_grad()
    output = model(img)
    loss = loss_fn(output, label)
    loss.backward()
    opt.step()
    # Keep a plain float: appending the tensor itself would hold every batch's
    # autograd graph in memory for the whole epoch.
    losses.append(loss.item())
  # Mean of the per-batch losses for this epoch (NaN if the loader was empty).
  mean_loss = sum(losses) / len(losses) if losses else float('nan')

  # ---- validation pass ----
  # eval() makes BatchNorm use its running statistics instead of updating them
  # from validation data; no_grad() skips building autograd graphs.
  model.eval()
  correct = 0
  seen = 0
  with torch.no_grad():
    for img, label in valloader:
      preds = torch.argmax(model(img), dim=1)
      correct += (preds == label).sum().item()
      seen += label.size(0)
  # Divide by the number of samples actually evaluated rather than a
  # hard-coded dataset size; acc is a Python float, not a tensor.
  acc = correct / seen if seen else 0.0

  if acc > best_acc:
    best_acc = acc
    best_epoch = epoch
    torch.save({'model_state_dict': model.state_dict(),
                'opt_state_dict': opt.state_dict(),
                'epoch': epoch,
                }, os.path.join(checkpoint_path, 'best.pt'))
  print(f'Batch Loss = {mean_loss}, Acc = {acc}, Best Acc {best_acc} at {best_epoch}')