In [5]:
from src import *
import numpy as np
import torch
import time


%matplotlib inline


In [6]:
# Sanity check: display whether PyTorch reports a usable CUDA device.
torch.cuda.is_available()


True

# CUDA

# Parameters

In [7]:
from configs import *

# Main

### Dataset

In [8]:
# Per-image preprocessing: resize to (resize_h, resize_w), then apply ToTensor.
# The pipeline is bound to its own name so the `transforms` module (presumably
# torchvision.transforms, brought in by a star import) is not overwritten;
# rebinding `transforms` to the Compose instance made this cell fail when re-run.
image_transforms = transforms.Compose([
    transforms.Resize(size=(resize_h, resize_w)),
    transforms.ToTensor(),
])
dataset = CustomDataset(dataset_name=dataset_name, transforms=image_transforms)

# Settings shared by both loaders; only the split and shuffling differ.
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    prefetch_factor=prefetch_factor,
)
train_loader = dataset.get_dataloader(is_train=True, shuffle=True, **loader_kwargs)
test_loader = dataset.get_dataloader(is_train=False, shuffle=False, **loader_kwargs)

In [9]:
# Seed every RNG the notebook can draw from. NumPy is seeded alongside torch
# so that any np.random sampling in later cells is repeatable as well.
rng_seed = 1
np.random.seed(rng_seed)
torch.manual_seed(rng_seed)
torch.cuda.manual_seed(rng_seed)

# `cuda`, `Tensor` and `device` are module-level names read by later cells.
cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
device = torch.device('cuda' if cuda else 'cpu')
print('PyTorch is using', device)

PyTorch is using cuda


### Model

In [10]:
# Input geometry handed to both networks: (height, width, channels).
input_dim = (resize_h, resize_w, input_ch)

# Autoencoder and the discriminator, both moved to the selected device and
# printed so the layer layout is visible in the cell output.
model = AAE(input_dim, channels, num_z).to(device)
discriminator = Discriminator(input_dim, channels, num_z).to(device)
for network in (model, discriminator):
    print(network)

# One Adam optimizer per network, both starting from init_lr.
# NOTE: no learning-rate scheduler is attached; a LambdaLR decaying by
# lr_decay ** epoch was sketched here previously but left disabled.
adam_settings = {'lr': init_lr}
optimizer_G = optim.Adam(model.parameters(), **adam_settings)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_settings)

AAE(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): GELU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): GELU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): GELU()
    (8): Dropout(p=0.1, inplace=False)
    (9): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (10): GELU()
    (11): Dropout(p=0.1, inplace=False)
    (12): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (13): GELU()
    (14): Dropout(p=0.1, inplace=False)
    (15): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (16): GELU()
    (17): Dropout(p=0.1, inplace=False)
    (18): Flatten(start_dim=1, end_dim=-1)
  )
  (z_mu): Linear(in_features=20480, out_features=1024, bias=True)
  (z_logvar): Linear(in_features=20480, out_features=1024, bia

In [None]:

from torch.autograd import Variable  # no longer used in this cell; kept in case later cells reference the name

# Loss terms: binary cross-entropy on the discriminator output, and L1
# between the reconstruction and the input image.
adversarial_loss = torch.nn.BCELoss()
pixelwise_loss = torch.nn.L1Loss()

# Relative weights of the two terms in the generator loss.
adv_weight = 0.001
recon_weight = 0.999

# Accumulated wall-clock training time over all epochs (visualisation excluded).
total_time = 0.0

for epoch in range(1, epochs + 1):
    # Re-assert training mode every epoch.
    # NOTE(review): reconstruct() below may switch the model to eval mode;
    # its source is not visible here, so this is a defensive assumption.
    model.train()
    discriminator.train()

    g_loss = d_loss = None
    epoch_start = time.perf_counter()

    for imgs, _ in train_loader:
        batch_len = imgs.shape[0]

        # Move the batch to the training device as float32.
        real_imgs = imgs.to(device=device, dtype=torch.float32)

        # Adversarial targets: 1 for samples from the prior, 0 for encoder outputs.
        valid = torch.ones(batch_len, 1, device=device)
        fake = torch.zeros(batch_len, 1, device=device)

        # ---- Generator (encoder + decoder) step ----
        optimizer_G.zero_grad()

        mu, logvar = model.encode(real_imgs)
        encoded_imgs = model.reparameterize(mu, logvar)
        decoded_imgs = model.decode(encoded_imgs)

        # Small adversarial term (fool the discriminator on the latent code)
        # plus a dominant pixel-wise reconstruction term.
        g_loss = (
            adv_weight * adversarial_loss(discriminator(encoded_imgs), valid)
            + recon_weight * pixelwise_loss(decoded_imgs, real_imgs)
        )
        g_loss.backward()
        optimizer_G.step()

        # ---- Discriminator step ----
        # zero_grad() also discards the gradients the generator's backward
        # pass left on the discriminator parameters.
        optimizer_D.zero_grad()

        # Standard-normal prior samples, drawn directly on the device so they
        # follow the torch seed and need no host-to-device copy.
        z = torch.randn(batch_len, num_z, device=device)

        real_loss = adversarial_loss(discriminator(z), valid)
        # detach(): the discriminator update must not backpropagate into the encoder.
        fake_loss = adversarial_loss(discriminator(encoded_imgs.detach()), fake)
        d_loss = 0.5 * (real_loss + fake_loss)

        d_loss.backward()
        optimizer_D.step()

    if g_loss is None:
        raise RuntimeError('train_loader yielded no batches')

    # CUDA kernels run asynchronously; wait for them so the timing is real.
    if device.type == 'cuda':
        torch.cuda.synchronize()
    # Time for the whole epoch, not just its last batch.
    dt = time.perf_counter() - epoch_start
    total_time += dt

    print(f'Epoch {epoch} / {epochs} in {dt:.2f} secs')
    # Losses reported are those of the last batch of the epoch.
    print(f'Generator loss {g_loss.item():.4f}, Discriminator loss {d_loss.item():.4f}')

    # Reconstruct test samples and show them next to the originals.
    samples, recons = reconstruct(model, test_loader, device)
    visualize_imgs(samples, recons)

# max(..., 1) avoids a ZeroDivisionError when epochs == 0.
print(f'Average {total_time / max(epochs, 1):.2f} secs per epoch consumed')
print(f'Total {total_time:.2f} secs consumed')