In [None]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as Data
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim
%matplotlib inline
# Seed PyTorch's global RNG (CPU and all CUDA devices) so weight initialisation
# and DataLoader shuffle order repeat across Restart & Run All.
# NOTE: cuDNN kernels may still be nondeterministic on GPU.
SEED = 0
torch.manual_seed(SEED)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class GlobalAvgPool2d(nn.Module):
    """Average each channel over its whole spatial extent.

    Maps an (N, C, H, W) input to (N, C, 1, 1) for any H and W, so the layer
    does not need to know the feature-map size in advance.
    """

    def forward(self, x):
        # Adaptive pooling to a 1x1 output averages over the full H x W plane,
        # i.e. avg_pool2d with a kernel covering the entire spatial size.
        return F.adaptive_avg_pool2d(x, output_size=1)

In [None]:
class FlattenLayer(nn.Module):
    """Flatten every dimension after the batch dimension.

    (N, d1, d2, ...) -> (N, d1 * d2 * ...). A 1-D input of shape (N,) is
    returned as (N, 1).
    """

    def forward(self, x):
        if x.dim() == 1:
            return x.unsqueeze(1)
        # Unlike x.view(N, -1), torch.flatten copies only when it must, so it
        # accepts non-contiguous inputs (e.g. after transpose/permute). It also
        # computes the trailing size explicitly, so an empty batch (N == 0) works.
        return torch.flatten(x, start_dim=1)

In [None]:
def conv_block(in_channels, out_channels):
    """Pre-activation conv unit used inside a dense block: BN -> ReLU -> 3x3 conv.

    The 3x3 convolution uses padding 1, so height and width are preserved. Only
    the channel count changes, from in_channels to out_channels.
    """
    layers = [
        nn.BatchNorm2d(in_channels),
        nn.ReLU(),
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
    ]
    return nn.Sequential(*layers)

In [None]:
class DenseBlock(nn.Module):
    """DenseNet dense block built from `num_convs` conv_block units.

    Unit k sees the block input concatenated (along dim 1) with the outputs of
    all earlier units, so its input width is in_channels + k * out_channels.
    Every unit adds out_channels new channels.

    Attributes:
        net: nn.ModuleList of the conv_block units, in application order.
        out_channels: channel count of the block output,
            in_channels + num_convs * out_channels.
    """

    def __init__(self, num_convs, in_channels, out_channels):
        super().__init__()
        self.net = nn.ModuleList(
            conv_block(in_channels + k * out_channels, out_channels)
            for k in range(num_convs)
        )
        self.out_channels = in_channels + num_convs * out_channels

    def forward(self, X):
        features = X
        for unit in self.net:
            # Append this unit's new feature maps after everything seen so far.
            features = torch.cat((features, unit(features)), dim=1)
        return features

In [None]:
# Sanity check: a dense block of 2 conv units, each adding 10 channels, applied
# to a 3-channel input must give 3 + 2 * 10 = 23 channels. The padded 3x3 convs
# leave H and W unchanged.
blk = DenseBlock(2, 3, 10)
X = torch.rand(4, 3, 8, 8)
Y = blk(X)
assert blk.out_channels == 23
assert Y.shape == (4, 23, 8, 8), Y.shape
Y.shape

In [None]:
def transition_block(in_channels, out_channels):
    """DenseNet transition layer: BN -> ReLU -> 1x1 conv -> 2x2 average pool.

    The 1x1 convolution maps in_channels to out_channels. The stride-2 average
    pool halves height and width (rounding down).
    """
    layers = [
        nn.BatchNorm2d(in_channels),
        nn.ReLU(),
        nn.Conv2d(in_channels, out_channels, kernel_size=1),
        nn.AvgPool2d(kernel_size=2, stride=2),
    ]
    return nn.Sequential(*layers)

In [None]:
# DenseNet stem: a 7x7 stride-2 convolution, BN, ReLU, then a 3x3 stride-2
# max-pool. A 1 x 96 x 96 input comes out as 64 x 24 x 24.
net = nn.Sequential(*[
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),  # 96 -> 48
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),      # 48 -> 24
])

In [None]:
# Four dense blocks of 4 conv units each, growth rate 32. Between consecutive
# dense blocks, a transition block halves the channel count (and, through its
# 2x2 average pool, the spatial size). `num_channels` tracks the running width
# so the head cell can size its layers.
num_channels, growth_rate = 64, 32
num_convs_in_dense_blocks = [4, 4, 4, 4]

for i, num_convs in enumerate(num_convs_in_dense_blocks):
    DB = DenseBlock(num_convs, num_channels, growth_rate)
    net.add_module(f"DenseBlocks_{i}", DB)
    num_channels = DB.out_channels
    # No transition after the final dense block.
    if i < len(num_convs_in_dense_blocks) - 1:
        net.add_module(f"transition_block_{i}",
                       transition_block(num_channels, num_channels // 2))
        num_channels //= 2

In [None]:
# Classifier head: final BN + ReLU, global average pooling down to 1x1, then
# flatten and a linear layer producing 10 outputs (one per class).
head_layers = (
    ("BN", nn.BatchNorm2d(num_channels)),
    ("relu", nn.ReLU()),
    ("global_avg_pool", GlobalAvgPool2d()),
    ("fc", nn.Sequential(FlattenLayer(), nn.Linear(num_channels, 10))),
)
for layer_name, head_layer in head_layers:
    net.add_module(layer_name, head_layer)

In [None]:
# Display the full layer hierarchy (rich display via the cell's last expression).
net

In [None]:
# Push a dummy 1 x 1 x 96 x 96 batch through each top-level stage and print its
# output shape. The probe runs in eval mode under no_grad, so it neither builds
# an autograd graph nor folds random data into the BatchNorm running statistics.
# The model's previous train/eval mode is restored afterwards.
X = torch.rand((1, 1, 96, 96))
was_training = net.training
net.eval()
try:
    with torch.no_grad():
        for name, layer in net.named_children():
            X = layer(X)
            print(name, 'output shape:\t', X.shape)
finally:
    net.train(was_training)

In [None]:
def evaluate_acc(net, data_iter, device):
    """Return the fraction of samples in `data_iter` that `net` classifies correctly.

    The prediction for each sample is the argmax over dim 1 of the model output.
    The model runs in eval mode without gradient tracking. Its previous
    train/eval mode (as reported by ``net.training``) is restored on exit, even
    if an exception is raised, so a training loop calling this does not silently
    continue in eval mode.

    Args:
        net: model mapping a batch X to per-class scores.
        data_iter: iterable of ``(X, y)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ZeroDivisionError: if ``data_iter`` yields no samples.
    """
    was_training = net.training
    net.eval()
    acc, n = 0.0, 0
    try:
        with torch.no_grad():
            for X, y in data_iter:
                y = y.to(device)
                acc += (net(X.to(device)).argmax(dim=1) == y).float().sum().item()
                n += y.shape[0]
    finally:
        net.train(was_training)
    return acc / n

In [None]:
def train_model(device, net, optimizer, train_iter, test_iter, n_epochs):
    """Train `net` with cross-entropy loss and print progress once per epoch.

    Args:
        device: device the model and every batch are moved to.
        net: classifier whose output argmax over dim 1 is the predicted class.
        optimizer: optimizer already bound to ``net``'s parameters.
        train_iter: iterable of ``(X, y)`` training batches.
        test_iter: iterable of ``(X, y)`` batches scored with ``evaluate_acc``
            after each epoch.
        n_epochs: number of passes over ``train_iter``.

    Per epoch, prints the mean per-batch training loss, training accuracy, test
    accuracy and wall-clock time. The epoch summary raises ZeroDivisionError if
    ``train_iter`` yields no batches.
    """
    net = net.to(device)
    print('train on: ', device)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(n_epochs):
        # Re-enter training mode every epoch. The model may arrive in eval mode,
        # and evaluate_acc (run at the end of each epoch) evaluates in eval mode.
        # Without this, BatchNorm would train on frozen running statistics.
        net.train()
        batch_ct, train_l_sum, n, train_acc_sum, start = 0, 0.0, 0, 0.0, time.time()
        for X, y in train_iter:
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            batch_loss = criterion(y_hat, y)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            # .item() yields a detached Python float. Summing the loss tensor
            # itself would chain every batch into one growing autograd graph.
            train_l_sum += batch_loss.item()
            batch_ct += 1
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_acc(net, test_iter, device)
        print('epoch %d, loss %.3f, train acc %.3f, test acc %.3f, time %.3f sec'
              % (epoch + 1, train_l_sum / batch_ct, train_acc_sum / n, test_acc,
                 time.time() - start))

In [None]:
batch_size = 256
# Dataset root: set the FMNIST_ROOT environment variable to point at a local
# copy of Fashion-MNIST. Otherwise this hardcoded fallback path is used.
rt = os.environ.get('FMNIST_ROOT', r'D:\notebook_canticle\Datasets\fmnist/')


def load_fm(rt, batch_size, resize=None, download=False):
    """Build Fashion-MNIST train/test DataLoaders.

    Args:
        rt: dataset root directory passed to torchvision.datasets.FashionMNIST.
        batch_size: mini-batch size for both loaders.
        resize: if truthy, images are resized with transforms.Resize(size=resize)
            before ToTensor.
        download: forwarded to FashionMNIST. Set True to fetch the data into
            `rt` when it is missing. Defaults to False.

    Returns:
        (train_iter, test_iter). The training loader reshuffles every epoch; the
        test loader iterates in a fixed order.
    """
    trans = [transforms.Resize(size=resize)] if resize else []
    trans.append(transforms.ToTensor())
    transform = transforms.Compose(trans)
    fm_train = torchvision.datasets.FashionMNIST(
        root=rt, train=True, transform=transform, download=download)
    fm_test = torchvision.datasets.FashionMNIST(
        root=rt, train=False, transform=transform, download=download)
    train_iter = Data.DataLoader(fm_train, batch_size=batch_size, shuffle=True)
    test_iter = Data.DataLoader(fm_test, batch_size=batch_size, shuffle=False)
    return train_iter, test_iter


# Images are upsampled to 96x96, the size the shape probe above uses.
train_iter, test_iter = load_fm(rt, batch_size, 96)

In [None]:
# Training hyperparameters, defined here so this cell runs on a fresh kernel.
# lr=0.001 is Adam's default learning rate; num_epochs is a tunable budget.
lr, num_epochs = 0.001, 5
optimizer = optim.Adam(net.parameters(), lr=lr)
train_model(device, net, optimizer, train_iter, test_iter, num_epochs)