# Import
---


In [0]:
import os
import time

import pandas as pd

import torch
import torch.nn            as nn
import torch.nn.functional as F
import torch.optim         as optim

import torchvision
import torchvision.transforms as transforms
import torchvision.datasets   as datasets

# Model
---
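AlexNetを基にした小さい画像用CNNモデル

A small CNN for small single-channel images, based on AlexNet. Its layer layout must match the Fashion-MNIST–pretrained weights (`models/FMNIST.pt`) that are loaded below and then fine-tuned on MNIST (transfer learning).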

In [0]:
class CNN(nn.Module):
    """Small-image CNN modelled on AlexNet.

    The convolutional trunk takes single-channel input (the first Conv2d has
    in_channels=1). An AdaptiveAvgPool2d then resizes the 256-channel feature
    map to a fixed 6x6 grid, so the classifier always receives 256*6*6 features.

    Args:
        num_classes: number of outputs of the final Linear layer.

    Note:
        The position of each module inside ``features`` / ``classifier``
        determines its state_dict key (e.g. ``features.0.weight``,
        ``classifier.6.bias``). Keep the layer order unchanged so that saved
        weights keep loading.
    """

    def __init__(self, num_classes):
        super().__init__()

        # (in_channels, out_channels, kernel_size, padding, followed_by_maxpool)
        conv_spec = (
            (  1,  32, 5, 2, True),
            ( 32,  64, 5, 2, True),
            ( 64, 128, 3, 1, False),
            (128, 256, 3, 1, False),
            (256, 256, 3, 1, False),
        )
        trunk = []
        for c_in, c_out, k, pad, pool_after in conv_spec:
            trunk += [nn.Conv2d(c_in, c_out, kernel_size=k, padding=pad),
                      nn.ReLU(inplace=True)]
            if pool_after:
                trunk.append(nn.MaxPool2d(kernel_size=2, stride=2))
        self.features = nn.Sequential(*trunk)

        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        flat_dim = 256 * 6 * 6
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(flat_dim, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        # flatten the (256, 6, 6) feature map of each sample into 9216 values
        return self.classifier(pooled.view(-1, 256 * 6 * 6))

In [0]:
# Use the GPU when one is present, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

net = CNN(num_classes=10).to(device)
# Load the weights pretrained on Fashion-MNIST (FMNIST); they are fine-tuned on MNIST below.
# map_location moves tensors that were saved from a GPU onto `device`, so the
# checkpoint also loads on CPU-only machines.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
net.load_state_dict(torch.load('models/FMNIST.pt', map_location=device))
loss_func = nn.CrossEntropyLoss()
# The optimizer receives all parameters, so every layer (not only the head) is fine-tuned.
optimizer = optim.Adam(net.parameters())

# Load Dataset
---

In [4]:
# ToTensor converts images to float tensors; Normalize(0.5, 0.5) then applies (x - 0.5) / 0.5.
# NOTE(review): presumably this matches the normalisation used when pretraining on FMNIST -- confirm.
toTensor  = transforms.ToTensor()
normalize = transforms.Normalize((0.5, ), (0.5, ))
transform = transforms.Compose([toTensor, normalize])

train_set = datasets.MNIST(root='../data', download=True, train=True , transform=transform)
valid_set = datasets.MNIST(root='../data', download=True, train=False, transform=transform)

# Shuffle only the training data. Evaluation order does not affect accuracy, so a
# fixed order keeps validation deterministic and avoids a needless permutation.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1000, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=1000, shuffle=False)

print('size  :', train_set[0][0].size())
print('train :', len(train_set))
print('valid :', len(valid_set))

size  : torch.Size([1, 28, 28])
train : 60000
valid : 10000


# Training & Validation
---

In [0]:
def train():
    """Run one epoch of mini-batch training over `train_loader`.

    Reads the notebook-level globals `net`, `train_loader`, `loss_func`,
    `optimizer` and `device`, and updates `net`'s parameters in place.

    Returns:
        The sum of the per-batch `loss.item()` values for this epoch
        (a sum over batches, not an average).
    """
    net.train()

    epoch_loss = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        # Clear gradients *before* backward. If a previous run of the cell was
        # interrupted between backward() and step(), its stale gradients would
        # otherwise be added into this update.
        optimizer.zero_grad()
        # forward
        outputs = net(x)
        # loss
        loss = loss_func(outputs, y)
        epoch_loss += loss.item()
        # backward + parameter update
        loss.backward()
        optimizer.step()

    return epoch_loss

In [0]:
def valid():
    """Return the classification accuracy of `net` on `valid_loader` as a fraction.

    Reads the notebook-level globals `net`, `valid_loader` and `device`.
    Leaves `net` in eval mode; `train()` switches it back to train mode.
    """
    net.eval()

    n_correct = 0
    with torch.no_grad():
        # Evaluate in mini-batches to bound memory use.
        for x, y in valid_loader:
            x, y = x.to(device), y.to(device)
            preds = net(x).argmax(dim=1)
            n_correct += (preds == y).sum().item()

    # Divide by the size of the dataset actually iterated, so the ratio stays
    # correct if valid_loader is ever built from a different dataset.
    return n_correct / len(valid_loader.dataset)

In [7]:
max_epoch = 30

# One row per epoch: [epoch, train seconds, summed train loss, valid seconds, valid accuracy]
results = []
for epoch in range(max_epoch):
    # perf_counter is a monotonic, high-resolution clock intended for measuring
    # durations; time.time() can jump if the system clock is adjusted.
    t_start = time.perf_counter()
    loss = train()
    t_train = time.perf_counter()
    accu = valid()
    t_valid = time.perf_counter()

    train_sec = t_train - t_start
    valid_sec = t_valid - t_train
    results.append([epoch+1, train_sec, loss, valid_sec, accu])
    print(f'{epoch+1:>2} | {train_sec:5.2f}s | {loss:>6.3f} | {valid_sec:5.2f}s | {accu:.2%}')

 1 | 11.64s | 43.162 |  1.58s | 97.64%
 2 | 11.73s | .4.651 |  1.61s | 98.69%
 3 | 11.41s | .2.945 |  1.55s | 98.90%
 4 | 11.28s | .2.071 |  1.65s | 98.98%
 5 | 11.05s | .1.486 |  1.57s | 99.11%
 6 | 11.01s | .1.327 |  1.54s | 99.03%
 7 | 11.21s | .1.088 |  1.55s | 99.17%
 8 | 11.33s | .0.893 |  1.57s | 99.11%
 9 | 11.36s | .0.895 |  1.55s | 99.15%
10 | 11.37s | .0.680 |  1.56s | 98.98%
11 | 11.30s | .0.614 |  1.55s | 99.07%
12 | 11.20s | .0.492 |  1.61s | 99.11%
13 | 11.18s | .0.654 |  1.55s | 99.23%
14 | 11.29s | .0.513 |  1.55s | 99.14%
15 | 11.31s | .0.333 |  1.61s | 99.25%
16 | 11.31s | .0.350 |  1.55s | 99.25%
17 | 11.34s | .0.478 |  1.57s | 99.13%
18 | 11.31s | .0.274 |  1.55s | 99.23%
19 | 11.32s | .0.534 |  1.57s | 99.10%
20 | 11.31s | .0.457 |  1.55s | 99.11%
21 | 11.34s | .0.334 |  1.56s | 99.12%
22 | 11.31s | .0.375 |  1.56s | 99.22%
23 | 11.30s | .0.328 |  1.56s | 99.25%
24 | 11.29s | .0.304 |  1.54s | 99.24%
25 | 11.28s | .0.189 |  1.57s | 99.21%
26 | 11.25s | .0.360 |  1

# Save Model & Results
---

In [0]:
# Persist only the fine-tuned parameters (the state_dict), the same format
# the pretrained FMNIST checkpoint was loaded from.
torch.save(obj=net.state_dict(), f='models/TL_FMNISTtoMNIST.pt')

In [0]:
# Column names follow the row layout appended in the training loop.
columns = ['epoch', 'train time(s)', 'loss', 'valid time(s)', 'accu']
df = pd.DataFrame(data=results, columns=columns).set_index('epoch')
# to_csv does not create missing parent directories; create ./results so a
# fresh checkout does not fail here after a full training run.
os.makedirs('./results', exist_ok=True)
df.to_csv('./results/TL_FMNISTtoMNIST.csv')