In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
import pylab
import matplotlib.pyplot as plt

import numpy as np
%matplotlib inline

import logging

# Fetch (or create) the logger registered under the name "logger".
logger = logging.getLogger("logger")
# The logger itself lets every record from DEBUG upwards through.
logger.setLevel(logging.DEBUG)

# Create handler1: a stream handler (stderr by default) that prefixes each
# message with a timestamp and a right-aligned, 8-character level name.
handler1 = logging.StreamHandler()
handler1.setFormatter(logging.Formatter("%(asctime)s %(levelname)8s %(message)s"))

# Attach the handler; without this the formatter above is never used.
# Handlers left over from a previous run of this cell are removed first so
# that re-executing the cell does not print every record multiple times.
for _stale_handler in list(logger.handlers):
    logger.removeHandler(_stale_handler)
logger.addHandler(handler1)

from tensorboardX import SummaryWriter
import os



In [None]:
# Select the compute device: CUDA when torch reports it as available,
# otherwise the CPU. The chosen device is echoed for the notebook reader.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
import labnote as lb
if lb.utils.is_executed_on_ipython():
    %reload_ext autoreload
    %autoreload 2

note = lb.Note(arguments=['conf_train_args.yaml',{'config':'conf_AE.yaml'}])

writer = writer = SummaryWriter(note.params.output)

# Fix seed
if 'seed' in note.params.keys():
    np.random.seed(note.params.seed)
    _ = torch.manual_seed(note.params.seed)

In [None]:
# MNIST digits dataset (training split), stored under ./mnist/.
train_data = MNIST(
    root='./mnist/',
    train=True,                       # this is training data
    transform=transforms.ToTensor(),  # Converts a PIL.Image or numpy.ndarray to a
                                      # torch.FloatTensor of shape (C x H x W), scaled to [0.0, 1.0]
    download=True,                    # download it if you don't have it
)

# Inspect the raw tensors and plot one example.
# `.data` / `.targets` are used instead of the deprecated `.train_data` /
# `.train_labels` aliases, which emit warnings in current torchvision.
print(train_data.data.size())       # (60000, 28, 28)
print(train_data.targets.size())    # (60000)
plt.imshow(train_data.data[2].numpy(), cmap='gray')
plt.title('%i' % train_data.targets[2])
plt.show()

# Data loader for easy mini-batch return in training. Each image batch has shape
# (B, 1, 28, 28) with B up to note.params.batch_size; samples are reshuffled every epoch.
train_loader = DataLoader(dataset=train_data, batch_size=note.params.batch_size, shuffle=True)


In [None]:
import src.network as net

# Build the autoencoder from the experiment parameters, sharing the notebook's
# logger and TensorBoard writer with it.
AE = net.AutoEncoder(
    note.params,
    logger,
    input_dim=1,
    gpu=note.params.gpu,
    writer=writer,
)

# note.params.optimizer is a format template; filled with the model's variable
# name ('AE') and the learning rate, it is evaluated as a Python expression.
# NOTE(review): eval() executes whatever the config contains — only run this
# notebook with trusted configuration files.
optimizer = eval(
    note.params.optimizer.format('AE', note.params.learning_rate)
)

In [None]:
# Training loop: reconstruct each mini-batch with the autoencoder, back-propagate
# the reconstruction loss, and periodically log an image and the running loss.
epoch = 0
global_step = 0     # counts optimizer steps across all epochs
AE.train()

from sys import stdout

for epoch in range(0,note.params.num_epochs):
    for batch_idx, (x,y) in enumerate(train_loader):
        optimizer.zero_grad()
        z,x_pred = AE.forward(x)
        # z = z.detach() # to use z for further calculation.

        # calc reconstruction loss
        l = AE.calc_loss('reconstruction',global_step,x_pred,x)
        l.backward()

        optimizer.step()

        # Every 100 batches, send the first reconstructed image to TensorBoard.
        if batch_idx % 100 == 0:
            writer.add_image('x_pred_epoch%3d'%epoch, x_pred[0], global_step)

        # Every 10 batches, overwrite the progress line ('\r' returns to line start).
        if batch_idx % 10 == 0:
            # .item() yields a plain Python float on any device; the previous
            # `.detach().numpy()` raises when the loss tensor lives on a GPU.
            loss_data = l.item()
            stdout.write(
                'Train Epoch: {}/{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\r'.
                format(epoch, note.params.num_epochs, batch_idx * len(x), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss_data))
        global_step += 1

In [None]:
# Model save/load round-trip test, part 1: persist the trained autoencoder
# to "<note.params.output>/model".

model_path = os.path.join(
    note.params.output,
    'model',
)
AE.save(model_path)

In [None]:
# Restore a second autoencoder from the path the first model was saved to,
# reusing the same logger, writer and GPU setting.
AE2 = net.AutoEncoder.load(
    model_path,
    logger=logger,
    input_dim=1,
    gpu=note.params.gpu,
    writer=writer,
)

# Same optimizer template as for AE, this time naming the reloaded model.
# NOTE(review): eval() runs the config string as code — trusted configs only.
optimizer2 = eval(
    note.params.optimizer.format('AE2', note.params.learning_rate)
)

In [None]:
# Sanity-check the reloaded model: run one pass over the training data with
# AE2 and its own optimizer, printing the reconstruction loss of every batch.
for src_img,src_label in train_loader:
    optimizer2.zero_grad()
    z,reconst = AE2.forward(src_img)
    # z = z.detach() # to use z for further calculation.

    # calc reconstruction loss on the reloaded model (was `AE.calc_loss`,
    # a copy-paste slip that routed the loss through the original model).
    # NOTE(review): global_step is not advanced in this loop, so every call
    # receives the same step value — confirm whether that is intended.
    l = AE2.calc_loss('reconstruction',global_step,reconst,src_img)
    l.backward()
    print(l)
    optimizer2.step()
