In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [2]:
batch_size = 256
num_workers = 4
lr = 0.0005
num_epochs = 400

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

In [4]:
class Residual(nn.Module):
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super(Residual, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels, out_channels, 1, stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        if self.conv3:
            x = self.conv3(x)
        return out + x

In [5]:
def residual_block(in_channels, out_channels, num_repeat, first_block=False):
    if first_block:
        assert in_channels == out_channels
    blk = []
    for i in range(num_repeat):
        if i == 0 and not first_block:
            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))
        else:
            blk.append(Residual(out_channels, out_channels))
    return nn.Sequential(*blk)

In [6]:
class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        blk = []
        blk.append(nn.Sequential(
            nn.Conv2d(1, 64, 7, 2, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(3, 2)))
        blk.append(residual_block(64, 64, 2, True))
        blk.append(residual_block(64, 128, 2))
        blk.append(residual_block(128, 256, 2))
        blk.append(residual_block(256, 512, 2))
        blk.append(nn.AdaptiveAvgPool2d(1))
        self.blk = nn.Sequential(*blk)
        self.fc = nn.Linear(512, 10)
    
    def forward(self, x):
        b, c, h, w = x.shape
        return self.fc(self.blk(x).view(b, -1))

In [7]:
trans = []
trans.append(transforms.Resize((96, 96)))
trans.append(transforms.ToTensor())
transform = transforms.Compose(trans)
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transform)
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transform)
print(len(mnist_train), len(mnist_test))
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

60000 10000


In [8]:
net = ResNet().cuda()
optimizer = torch.optim.Adam(net.parameters(), lr)
loss = torch.nn.CrossEntropyLoss()

In [9]:
def train_FashionMNIST(net, train_iter, optimizer):
    train_loss = 0.0
    train_acc = 0.0
    train_num = 0
    
    for X, y in train_iter:
        X = X.cuda()
        y = y.cuda()
        y_hat = net(X)
        l = loss(y_hat, y)
        optimizer.zero_grad()
        # net.zero_grad()
        l.backward()
        optimizer.step()
        
        train_loss += l.item()
        train_acc += (y_hat.argmax(dim=1) == y).sum().item()
        train_num += y.shape[0]
    
    train_loss /= train_num
    train_acc /= train_num
    print('train loss: %.4f, train acc: %.3f' % (train_loss, train_acc))

In [10]:
def test_FashionMNIST(net, test_iter):
    test_acc = 0.0
    test_num = 0
    
    for X, y in test_iter:
        X = X.cuda()
        y = y.cuda()
        y_hat = net(X)
        test_acc += (y_hat.argmax(dim=1) == y).sum().item()
        test_num += y.shape[0]
    
    test_acc /= test_num
    print('test acc: %.3f' % (test_acc))

In [11]:
for i in range(num_epochs):
    print(f'epoch: {i}')
    train_FashionMNIST(net, train_iter, optimizer)
    test_FashionMNIST(net, test_iter)
    print('----------------')

epoch: 0
train loss: 0.0015, train acc: 0.859
test acc: 0.895
----------------
epoch: 1
train loss: 0.0010, train acc: 0.912
test acc: 0.903
----------------
epoch: 2
train loss: 0.0008, train acc: 0.922
test acc: 0.913
----------------
epoch: 3
train loss: 0.0007, train acc: 0.933
test acc: 0.922
----------------
epoch: 4
train loss: 0.0006, train acc: 0.941
test acc: 0.913
----------------
epoch: 5
train loss: 0.0006, train acc: 0.947
test acc: 0.922
----------------
epoch: 6
train loss: 0.0005, train acc: 0.952
test acc: 0.927
----------------
epoch: 7
train loss: 0.0004, train acc: 0.958
test acc: 0.926
----------------
epoch: 8
train loss: 0.0004, train acc: 0.962
test acc: 0.924
----------------
epoch: 9
train loss: 0.0003, train acc: 0.967
test acc: 0.925
----------------
epoch: 10
train loss: 0.0003, train acc: 0.972
test acc: 0.926
----------------
epoch: 11
train loss: 0.0003, train acc: 0.977
test acc: 0.923
----------------
epoch: 12
train loss: 0.0002, train acc: 0.978
tes

test acc: 0.935
----------------
epoch: 103
train loss: 0.0000, train acc: 1.000
test acc: 0.935
----------------
epoch: 104
train loss: 0.0000, train acc: 1.000
test acc: 0.935
----------------
epoch: 105
train loss: 0.0000, train acc: 1.000
test acc: 0.935
----------------
epoch: 106
train loss: 0.0000, train acc: 1.000
test acc: 0.935
----------------
epoch: 107
train loss: 0.0000, train acc: 1.000
test acc: 0.935
----------------
epoch: 108
train loss: 0.0000, train acc: 1.000
test acc: 0.935
----------------
epoch: 109
train loss: 0.0000, train acc: 1.000
test acc: 0.935
----------------
epoch: 110
train loss: 0.0000, train acc: 1.000
test acc: 0.936
----------------
epoch: 111
train loss: 0.0000, train acc: 1.000
test acc: 0.935
----------------
epoch: 112
train loss: 0.0000, train acc: 1.000
test acc: 0.936
----------------
epoch: 113
train loss: 0.0000, train acc: 1.000
test acc: 0.935
----------------
epoch: 114
train loss: 0.0000, train acc: 1.000
test acc: 0.935
------------

train loss: 0.0000, train acc: 1.000
test acc: 0.937
----------------
epoch: 205
train loss: 0.0000, train acc: 1.000
test acc: 0.937
----------------
epoch: 206
train loss: 0.0000, train acc: 1.000
test acc: 0.937
----------------
epoch: 207
train loss: 0.0000, train acc: 1.000
test acc: 0.937
----------------
epoch: 208
train loss: 0.0000, train acc: 1.000
test acc: 0.937
----------------
epoch: 209
train loss: 0.0000, train acc: 1.000
test acc: 0.937
----------------
epoch: 210
train loss: 0.0000, train acc: 1.000
test acc: 0.937
----------------
epoch: 211
train loss: 0.0000, train acc: 1.000
test acc: 0.937
----------------
epoch: 212
train loss: 0.0000, train acc: 1.000
test acc: 0.937
----------------
epoch: 213
train loss: 0.0000, train acc: 1.000
test acc: 0.937
----------------
epoch: 214
train loss: 0.0000, train acc: 1.000
test acc: 0.937
----------------
epoch: 215
train loss: 0.0000, train acc: 1.000
test acc: 0.938
----------------
epoch: 216
train loss: 0.0000, train ac

test acc: 0.938
----------------
epoch: 306
train loss: 0.0000, train acc: 1.000
test acc: 0.938
----------------
epoch: 307
train loss: 0.0000, train acc: 1.000
test acc: 0.938
----------------
epoch: 308
train loss: 0.0000, train acc: 1.000
test acc: 0.938
----------------
epoch: 309
train loss: 0.0000, train acc: 1.000
test acc: 0.938
----------------
epoch: 310
train loss: 0.0000, train acc: 1.000
test acc: 0.938
----------------
epoch: 311
train loss: 0.0000, train acc: 1.000
test acc: 0.938
----------------
epoch: 312
train loss: 0.0000, train acc: 1.000
test acc: 0.939
----------------
epoch: 313
train loss: 0.0000, train acc: 1.000
test acc: 0.939
----------------
epoch: 314
train loss: 0.0000, train acc: 1.000
test acc: 0.938
----------------
epoch: 315
train loss: 0.0000, train acc: 1.000
test acc: 0.938
----------------
epoch: 316
train loss: 0.0000, train acc: 1.000
test acc: 0.939
----------------
epoch: 317
train loss: 0.0000, train acc: 1.000
test acc: 0.939
------------

In [12]:
# Adam: lr=0.0005, epoch ???, test_acc=0.???, batch_size=256
# Adam: lr=0.001, epoch 181, test_acc=0.938, batch_size=256
# Adam: lr=0.01, epoch 20, test_acc=0.929, batch_size=256
# Adam: lr=0.1, epoch 12, test_acc=0.882, batch_size=256