In [12]:
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import tqdm

In [16]:
fashion_mnist_train = FashionMNIST(
    "data/FashionMNIST", 
    train = True,
    download = True,
    transform = transforms.ToTensor()
)

In [17]:
fashion_mnist_test = FashionMNIST(
    "data/FashionMNIST",
    train = False,
    download = True,
    transform = transforms.ToTensor()
)

In [18]:
batch_size = 128
train_loader = DataLoader(
    fashion_mnist_train,
    batch_size = batch_size,
    shuffle = True
)
test_loader = DataLoader(
    fashion_mnist_test,
    batch_size = batch_size,
    shuffle = False
)

In [19]:
class FlattenLayer(nn.Module):
    def forward(self, x):
        sizes = x.size()
        return x.view(sizes[0], -1)
    

In [23]:
conv_net = nn.Sequential(
    nn.Conv2d(1, 32, 5),
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.BatchNorm2d(32),
    nn.Dropout2d(0.25),
    nn.Conv2d(32, 64, 5),
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.BatchNorm2d(64),
    nn.Dropout2d(0.25),
    FlattenLayer()
)

In [30]:
test_input = torch.ones(1, 1, 28, 28)
conv_output_size = conv_net(test_input).size()[-1]
print(test_input.size())
print(conv_output_size)

torch.Size([1, 1, 28, 28])
1024


In [31]:
mlp = nn.Sequential(
    nn.Linear(conv_output_size, 200),
    nn.ReLU(),
    nn.BatchNorm1d(200),
    nn.Dropout(0.25),
    nn.Linear(200, 10)
)

net = nn.Sequential(
    conv_net,
    mlp
)

In [62]:
def eval_net(net, data_loader, device = "cpu"):
    net.eval()
    ys = torch.tensor([])
    y_preds = torch.tensor([])
    
    for x, y in data_loader:
        x = x.to(device)
        y = y.to(device)
        
        with torch.no_grad():
            _, y_pred = net(x).max(1)
            
        ys = torch.cat((ys, y))
        y_preds = torch.cat((y_preds, y_pred))
        
    acc = (ys == y_preds).sum().item() / len(ys)
    return acc

def train_net(net, 
              train_loader, 
              test_loader,
              optimizer_cls = optim.Adam,
              loss_fn = nn.CrossEntropyLoss(),
              n_iter = 10,
              device = "cpu"
             ):
    
    train_losses = []
    train_acc = []
    val_acc = []
    optimizer = optimizer_cls(net.parameters())
    
    for epoc in range(n_iter):
        running_loss = 0.0
        net.train()
        n = 0
        n_acc = 0
        
        for i, (xx, yy) in tqdm.tqdm(enumerate(train_loader), total = len(train_loader)):
            xx = xx.to(device)
            yy = yy.to(device)
            h = net(xx)
            loss = loss_fn(h, yy)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n += len(xx)
            _, y_pred = h.max(1)
            n_acc += (yy == y_pred).sum().item()
            
        train_losses.append(running_loss / i)
        train_acc.append(n_acc / n)
        val_acc.append(eval_net(net, test_loader, device))
        if epoc % 10 != 0:
            continue
            
        print(epoc, train_losses[-1], train_acc[-1], val_acc[-1], flush = True)

In [63]:
train_net(net, train_loader, test_loader, n_iter = 20)

100%|██████████| 469/469 [00:53<00:00,  8.69it/s]


0 0.18512561079910678 0.9298333333333333 0.9106


100%|██████████| 469/469 [00:52<00:00,  8.94it/s]
100%|██████████| 469/469 [00:52<00:00,  8.86it/s]
100%|██████████| 469/469 [00:52<00:00,  8.89it/s]
100%|██████████| 469/469 [00:52<00:00,  8.97it/s]
100%|██████████| 469/469 [00:52<00:00,  8.85it/s]
100%|██████████| 469/469 [00:54<00:00,  8.53it/s]
100%|██████████| 469/469 [00:51<00:00,  9.12it/s]
100%|██████████| 469/469 [00:55<00:00,  8.47it/s]
100%|██████████| 469/469 [00:54<00:00,  8.59it/s]
100%|██████████| 469/469 [00:55<00:00,  8.44it/s]


10 0.14431103981999505 0.94455 0.9189


100%|██████████| 469/469 [00:53<00:00,  8.73it/s]
100%|██████████| 469/469 [00:54<00:00,  8.63it/s]
100%|██████████| 469/469 [00:54<00:00,  8.68it/s]
100%|██████████| 469/469 [00:54<00:00,  8.64it/s]
100%|██████████| 469/469 [00:56<00:00,  8.25it/s]
100%|██████████| 469/469 [00:54<00:00,  8.61it/s]
100%|██████████| 469/469 [00:55<00:00,  8.44it/s]
100%|██████████| 469/469 [00:59<00:00,  7.86it/s]
100%|██████████| 469/469 [00:53<00:00,  8.82it/s]


In [65]:
eval_net(net, test_loader)

0.9208