In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
import torch
from torch import nn, optim
from jcopdl.callback import Callback

# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

## Dataset and Dataloader

In [3]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [4]:
bs = 64  # mini-batch size shared by both loaders

# Every image is reduced to a single grey channel, then converted to a tensor.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
])

# Both splits use an ImageFolder layout (one sub-directory per class) and the same transform.
train_set, test_set = (
    datasets.ImageFolder(root, transform=transform)
    for root in ("data/train/", "data/test/")
)

# The training loader reshuffles each epoch and reads batches with 4 worker processes.
trainloader = DataLoader(train_set, batch_size=bs, shuffle=True, num_workers=4)
# NOTE(review): shuffle=True here makes next(iter(testloader)) in the evaluation section
# return different images on every run; consider shuffle=False or a fixed seed.
testloader = DataLoader(test_set, batch_size=bs, shuffle=True)

## Training preparation

In [5]:
from autoencoder import Autoencoder

In [6]:
# Autoencoder with a 32-dimensional latent code, placed on the selected device.
model = Autoencoder(z_size=32).to(device)
# NOTE(review): BCELoss requires model outputs in [0, 1] (e.g. a final sigmoid) - confirm in autoencoder.py.
criterion = nn.BCELoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
# jcopdl callback handles logging, checkpoints and early stopping; files go under "model".
callback = Callback(model, outdir="model")

## Training

In [7]:
from tqdm.auto import tqdm

In [8]:
def loop_fn(mode, dataset, dataloader, model, criterion, optimizer, device):
    """Run one epoch of autoencoder training or evaluation and return the mean loss.

    Each batch is flattened to 784 values per sample. The model is asked to
    reconstruct its own input, so the loss is ``criterion(model(x), x)``. The
    labels yielded by the dataloader are ignored.

    Parameters
    ----------
    mode : str
        ``"train"`` to update the weights with ``optimizer``; ``"test"`` to only evaluate.
    dataset : sized
        Dataset behind ``dataloader``; its ``len`` normalises the summed loss.
    dataloader : iterable of (feature, label) batches
        Source of the batches.
    model, criterion, optimizer
        The network, the reconstruction loss, and the optimizer.
        ``optimizer`` is only used when ``mode == "train"``.
    device : torch.device
        Device that every batch is moved to.

    Returns
    -------
    float
        Per-sample average loss over ``dataset``.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``"train"`` nor ``"test"``, or ``dataset`` is empty.
    """
    if mode not in ("train", "test"):
        # Previously an unknown mode silently ran with whatever train/eval state the model had.
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
    n_samples = len(dataset)
    if n_samples == 0:
        raise ValueError("dataset is empty; cannot compute an average cost")

    is_train = mode == "train"
    model.train(is_train)  # train(False) is equivalent to eval()
    cost = 0.0
    # Track gradients only while training. Evaluation then builds no autograd graph
    # even if the caller forgets an outer torch.no_grad(); an outer no_grad still works.
    with torch.set_grad_enabled(is_train):
        for feature, _ in tqdm(dataloader, desc=mode.title()):
            feature = feature.view(-1, 784).to(device)

            output = model(feature)
            loss = criterion(output, feature)

            if is_train:
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            # Weight by batch size so the final value is a per-sample mean,
            # even when the last batch is smaller.
            cost += loss.item() * feature.shape[0]
    return cost / n_samples

In [None]:
# Train until the callback's early-stopping check reports that training should halt.
stop_training = False
while not stop_training:
    train_cost = loop_fn("train", train_set, trainloader, model, criterion, optimizer, device)
    with torch.no_grad():
        test_cost = loop_fn("test", test_set, testloader, model, criterion, optimizer, device)

    # Record this epoch's costs, write a checkpoint, then refresh the live cost plot.
    callback.log(train_cost, test_cost)
    callback.save_checkpoint()
    callback.cost_runtime_plotting()

    # Early stopping watches test_cost; once it signals a stop, draw the final cost curve.
    stop_training = callback.early_stopping(model, monitor="test_cost")
    if stop_training:
        callback.plot_cost()

Train:   0%|          | 0/563 [00:11<?, ?it/s]

Test:   0%|          | 0/313 [00:00<?, ?it/s]


Epoch     1
Train_cost  = 0.3051 | Test_cost  = 0.2901 | 


Train:   0%|          | 0/563 [00:11<?, ?it/s]

Test:   0%|          | 0/313 [00:00<?, ?it/s]


Epoch     2
Train_cost  = 0.2811 | Test_cost  = 0.2821 | 


Train:   0%|          | 0/563 [00:11<?, ?it/s]

Test:   0%|          | 0/313 [00:00<?, ?it/s]


Epoch     3
Train_cost  = 0.2757 | Test_cost  = 0.2774 | 


Train:   0%|          | 0/563 [00:13<?, ?it/s]

Test:   0%|          | 0/313 [00:00<?, ?it/s]


Epoch     4
Train_cost  = 0.2727 | Test_cost  = 0.2757 | 


Train:   0%|          | 0/563 [00:17<?, ?it/s]

Test:   0%|          | 0/313 [00:00<?, ?it/s]


Epoch     5
Train_cost  = 0.2707 | Test_cost  = 0.2751 | 


Train:   0%|          | 0/563 [00:13<?, ?it/s]

Test:   0%|          | 0/313 [00:00<?, ?it/s]


Epoch     6
Train_cost  = 0.2691 | Test_cost  = 0.2728 | 


Train:   0%|          | 0/563 [00:12<?, ?it/s]

Test:   0%|          | 0/313 [00:00<?, ?it/s]


Epoch     7
Train_cost  = 0.2682 | Test_cost  = 0.2718 | 


Train:   0%|          | 0/563 [00:12<?, ?it/s]

Test:   0%|          | 0/313 [00:00<?, ?it/s]


Epoch     8
Train_cost  = 0.2672 | Test_cost  = 0.2724 | 
[31m==> EarlyStop patience =  1 | Best test_cost: 0.2718[0m


Train:   0%|          | 0/563 [00:16<?, ?it/s]

Test:   0%|          | 0/313 [00:00<?, ?it/s]


Epoch     9
Train_cost  = 0.2664 | Test_cost  = 0.2706 | 


Train:   0%|          | 0/563 [00:13<?, ?it/s]

Test:   0%|          | 0/313 [00:00<?, ?it/s]


Epoch    10
Train_cost  = 0.2659 | Test_cost  = 0.2707 | 
[31m==> EarlyStop patience =  1 | Best test_cost: 0.2706[0m


Train:   0%|          | 0/563 [00:14<?, ?it/s]

Test:   0%|          | 0/313 [00:00<?, ?it/s]


Epoch    11
Train_cost  = 0.2654 | Test_cost  = 0.2701 | 


Train:   0%|          | 0/563 [00:13<?, ?it/s]

## Evaluation

In [None]:
# Grab one test batch and flatten each image to 784 values, as during training.
batch_features, target = next(iter(testloader))
feature = batch_features.to(device).view(-1, 784)

In [None]:
# Show test images (row 0), their latent codes (row 1) and their reconstructions (row 2).
n_show = min(8, feature.shape[0])  # at most 8 columns; fewer if the batch is smaller

with torch.no_grad():
    model.eval()
    # Fixed: the original referenced the undefined name `features` and encoded twice.
    enc = model.encode(feature)
    # Reconstruct with the forward pass, which loop_fn trains as the reconstruction.
    # NOTE(review): assumes forward() == decode(encode(x)), so rows 1 and 2 correspond - confirm.
    dec = model(feature)

# The latent code (z_size values, 32 here) cannot be viewed as 28x28. Lay it out as the
# most square rows x cols grid whose size is exactly its length (4 x 8 for 32).
z_len = enc[0].numel()
z_rows = next(r for r in range(int(z_len ** 0.5), 0, -1) if z_len % r == 0)

fig, ax = plt.subplots(3, n_show, figsize=(17, 7), squeeze=False)
for i in range(n_show):
    ax[0, i].imshow(feature[i].view(28, 28).cpu(), cmap='gray')
    ax[0, i].axis('off')

    ax[1, i].imshow(enc[i].reshape(z_rows, -1).cpu(), cmap='gray')
    ax[1, i].axis('off')

    ax[2, i].imshow(dec[i].view(28, 28).cpu(), cmap='gray')
    ax[2, i].axis('off')

# Label each row; titles stay visible even though axis('off') hides ticks and spines.
for row, label in enumerate(["input", "latent code", "reconstruction"]):
    ax[row, 0].set_title(label, loc="left")