## Training

In [None]:
# Train the model for `num_training_updates` optimizer steps. The objective is
# the variance-normalized reconstruction MSE, plus the loss term returned by
# the model itself, plus an explicit L2 penalty over all model parameters.
model.train()
train_res_recon_error = []
train_res_internal_error = []
l2_lambda = 1e-3  # weight of the L2 parameter penalty

# Keep a single iterator alive across steps. Calling next(iter(loader)) on
# every step rebuilds the loader iterator each time and only ever reads its
# first batch.
training_batches = iter(training_loader)

for i in tqdm(xrange(num_training_updates)):
    try:
        (data,) = next(training_batches)
    except StopIteration:
        # The loader is exhausted: start another pass over the data.
        training_batches = iter(training_loader)
        (data,) = next(training_batches)
    data = data.to(device)
    batch_size = data.size(0)
    optimizer.zero_grad()

    # The third model output is not used by this cell.
    data_recon, internal_loss, test = model(data)
    i_loss = internal_loss
    # Fold the per-sample axis into the batch axis so the target lines up with
    # the reconstruction.
    # NOTE(review): hard-codes 20 items of 1x64x64 per sample -- confirm this
    # matches the dataset and the model output shape.
    data_test = data.view(batch_size * 20, 1, 64, 64)

    recon_error = F.mse_loss(data_recon, data_test) / data_variance
    l2_reg = sum(torch.sum(param ** 2) for param in model.parameters())
    loss = recon_error + i_loss + l2_lambda * l2_reg
    loss.backward()

    optimizer.step()

    train_res_recon_error.append(recon_error.item())
    # Record a plain Python float: appending the loss object itself would keep
    # each step's autograd graph alive (if it is a tensor) and can make
    # np.mean below fail on device / grad-requiring tensors.
    train_res_internal_error.append(float(i_loss))
    if (i + 1) % 50 == 0:
        # Report running means over (at most) the last 100 steps.
        print('%d iterations' % (i + 1))
        print('recon_error: %.3f' % np.mean(train_res_recon_error[-100:]))
        print('internal_loss: %.3f' % np.mean(train_res_internal_error[-100:]))
        print()
        

In [None]:

# Smooth the recorded reconstruction errors with a Savitzky-Golay filter
# (201-sample window, 7th-order polynomial) and plot the result on a
# logarithmic y-axis.
train_res_recon_error_smooth = savgol_filter(
    train_res_recon_error,
    window_length=201,
    polyorder=7,
)

f = plt.figure(figsize=(16, 4))
# Left-hand slot of a 1x2 subplot grid; the right-hand slot is left unused.
ax = f.add_subplot(121)
ax.plot(train_res_recon_error_smooth)
ax.set_yscale('log')
ax.set_xlabel('iteration')
ax.set_title('Smoothed NMSE.')

## Test

In [None]:
# Evaluate the model on 10 validation batches without tracking gradients,
# recording the variance-normalized reconstruction MSE of each batch.
model.eval()

# NOTE: this rebinds the list that held the training history, so the training
# curve is no longer available under this name after this cell runs.
train_res_recon_error = []

with torch.no_grad():
    overall_loss = 0

    # The parameters do not change during evaluation, so the L2 term is the
    # same for every batch; compute it once instead of once per iteration.
    l2_reg = sum(torch.sum(param ** 2) for param in model.parameters())

    # Keep a single iterator alive across batches. Calling next(iter(loader))
    # per step only ever reads the loader's first batch, so an unshuffled
    # loader would be scored on the same batch every time.
    validation_batches = iter(validation_loader)

    for i in tqdm(xrange(10)):
        try:
            (data,) = next(validation_batches)
        except StopIteration:
            # Fewer than 10 batches available: start another pass.
            validation_batches = iter(validation_loader)
            (data,) = next(validation_batches)
        batch_size = data.size(0)
        data = data.to(device)

        # The third model output is not used by this cell.
        data_recon, internal_loss, test = model(data)

        # Fold the per-sample axis into the batch axis so the target lines up
        # with the reconstruction.
        # NOTE(review): hard-codes 20 items of 1x64x64 per sample -- confirm.
        data_test = data.view(batch_size * 20, 1, 64, 64)

        recon_error = F.mse_loss(data_recon, data_test) / data_variance
        # Full objective, kept for parity with the training cell; only the
        # reconstruction error is accumulated and reported below.
        loss = recon_error + internal_loss + l2_lambda * l2_reg
        overall_loss += recon_error.item()

        train_res_recon_error.append(recon_error.item())

        if (i + 1) % 10 == 0:
            # Mean over (at most) the last 100 recorded batches.
            print('%d iterations' % (i + 1))
            print('recon_error: %.3f' % np.mean(train_res_recon_error[-100:]))
            print()


### Visualization

In [None]:
def show(img):
    """Display an image tensor with matplotlib, hiding both axes.

    Args:
        img: A torch tensor holding the image. Its axes are permuted with
            (1, 2, 0) before plotting, so it is presumably laid out as
            (channels, height, width) -- confirm against callers. The tensor
            may live on any device and may require grad; it is detached and
            copied to the CPU before conversion to NumPy.
    """
    # Plain img.numpy() raises for CUDA tensors and tensors that require grad.
    npimg = img.detach().cpu().numpy()
    plt.figure(figsize=(10, 5))
    # imshow returns the image artist; its parent axes are reached via `.axes`.
    image = plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
    image.axes.get_xaxis().set_visible(False)
    image.axes.get_yaxis().set_visible(False)
    plt.show()

In [None]:
# Tile the contents of sample index 5 of the current `data` batch into a grid
# with 10 images per row and display it.
# NOTE(review): `data` is whatever batch was assigned last (presumably the
# final validation batch) -- confirm the cell execution order.
sample_images = data[5, :, :, :, :].cpu().data
grid_img = make_grid(sample_images, nrow=10)
show(grid_img)

In [None]:
# Tile the contents of sample index 5 of `output` into a grid with 10 images
# per row and display it.
# NOTE(review): `output` is not assigned anywhere in the visible cells; it
# presumably should hold the model reconstruction reshaped to the same 5-D
# layout as `data` -- confirm where it is defined before running this cell.
output_images = output[5, :, :, :, :].cpu().data
grid_img = make_grid(output_images, nrow=10)
show(grid_img)