In [None]:
!pip install gdown==4.4.0

In [1]:
import numpy as np
import os
import torch
import torch.nn as nn

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models 

def weights_init(m):
    """Re-initialise a single layer in place, dispatching on its class name.

    Meant to be handed to ``nn.Module.apply``. Modules whose class name
    contains ``'Conv'`` get weights drawn from N(mean=0.0, std=0.02);
    modules whose class name contains ``'BatchNorm'`` get weights drawn from
    N(mean=1.0, std=0.02) and a zeroed bias. Anything else is left untouched.

    Args:
        m (nn.Module): layer to (possibly) re-initialise.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

class CelebaEncoder(nn.Module):
    """ Celeba Encoder

        Five stride-2 convolutions followed by a linear + Tanh projection.
        The linear layer expects exactly ``init_num_filters * 8`` features, so
        the convolutional stack has to end at 1 x 1 spatially; with the
        strides used here that is the case for 64 x 64 inputs
        (64 -> 32 -> 16 -> 8 -> 4 -> 1).

        Args:
            init_num_filters (int): number of filters produced by the first convolution;
                deeper layers use 2x, 4x and 8x this width
            lrelu_slope (float): positive number indicating LeakyReLU negative slope
            embedding_dim (int): embedding dimensionality
            nc (int): number of channels of the input image
            dropout (float): dropout probability applied after every activation
    """
    def __init__(self, init_num_filters=16, lrelu_slope=0.2, embedding_dim=128, nc=3, dropout=0.05):
        super().__init__()

        self.init_num_filters_ = init_num_filters
        self.lrelu_slope_ = lrelu_slope
        self.embedding_dim_ = embedding_dim

        width = self.init_num_filters_

        # Stem: image channels -> width, spatial size halved, no batch norm.
        layers = [
            nn.Conv2d(nc, width * 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(self.lrelu_slope_, inplace=True),
            nn.Dropout(dropout),
        ]

        # Three down-sampling stages; each doubles the channel count
        # (width -> 2*width -> 4*width -> 8*width) and halves the spatial size.
        for mult in (1, 2, 4):
            layers += [
                nn.Conv2d(width * mult, width * mult * 2, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(width * mult * 2),
                nn.LeakyReLU(self.lrelu_slope_, inplace=True),
                nn.Dropout(dropout),
            ]

        # Head: un-padded 4x4 convolution that collapses a 4 x 4 map to 1 x 1.
        layers.append(
            nn.Conv2d(width * 8, width * 8, kernel_size=4, stride=2, padding=0, bias=False)
        )

        self.features = nn.Sequential(*layers)

        # Projection to the embedding; Tanh bounds every coordinate to [-1, 1].
        self.fc_out = nn.Sequential(
            nn.Linear(width * 8, self.embedding_dim_),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode a batch of images into ``(batch, embedding_dim)`` vectors."""
        feature_map = self.features(x)
        return self.fc_out(torch.flatten(feature_map, start_dim=1))

class CelebaDecoder(nn.Module):
    """ Celeba Decoder

        A linear layer expands the embedding into an ``(init_num_filters * 8) x 2 x 2``
        feature map, which is then up-sampled five times by a factor of two
        (2 -> 4 -> 8 -> 16 -> 32 -> 64) with nearest-neighbour interpolation and
        3x3 convolutions. The final Tanh bounds the output to [-1, 1].

        Args:
            init_num_filters (int): base filter width; the stack starts at 8x this
                width and narrows to 1x before the output convolution
            lrelu_slope (float): positive number indicating LeakyReLU negative slope
            embedding_dim (int): embedding dimensionality
            nc (int): number of channels of the reconstructed image
            dropout (float): dropout probability applied after every hidden activation
    """
    def __init__(self, init_num_filters=16, lrelu_slope=0.2, embedding_dim=128, nc=3, dropout=0.05):
        super().__init__()

        self.init_num_filters_ = init_num_filters
        self.lrelu_slope_ = lrelu_slope
        self.embedding_dim_ = embedding_dim

        width = self.init_num_filters_

        # (input multiplier, output multiplier) of each hidden up-sampling stage.
        channel_plan = ((8, 8), (8, 4), (4, 2), (2, 1))

        layers = []
        for in_mult, out_mult in channel_plan:
            layers.extend([
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(width * in_mult, width * out_mult, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(width * out_mult),
                nn.LeakyReLU(lrelu_slope, inplace=True),
                nn.Dropout(dropout),
            ])

        # Output stage: last up-sampling, projection to image channels, Tanh.
        layers.extend([
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(width * 1, nc, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Tanh(),
        ])

        self.features = nn.Sequential(*layers)

        # Embedding -> flattened (8 * width) x 2 x 2 seed feature map.
        self.fc_in = nn.Sequential(
            nn.Linear(self.embedding_dim_, width * 8 * 4),
            nn.LeakyReLU(self.lrelu_slope_, inplace=True),
        )

    def forward(self, z):
        """Decode ``(batch, embedding_dim)`` vectors into ``(batch, nc, 64, 64)`` images."""
        seed = self.fc_in(z).view(-1, self.init_num_filters_ * 8, 2, 2)
        return self.features(seed)
    
class CelebaAutoencoder(nn.Module):
    """ Celeba Autoencoder

        Chains a ``CelebaEncoder`` and a ``CelebaDecoder`` that share the same
        filter width, LeakyReLU slope and embedding size.

        Args:
            init_num_filters (int): base filter width passed to encoder and decoder
            lrelu_slope (float): positive number indicating LeakyReLU negative slope
            inter_fc_dim (int): stored on the instance for backward compatibility but
                not used by the current architecture
            embedding_dim (int): embedding dimensionality
            conv_init (str): weight initialisation scheme; only ``'normal'`` is implemented

        Raises:
            NotImplementedError: if ``conv_init`` is anything other than ``'normal'``
    """
    def __init__(self, init_num_filters=16, lrelu_slope=0.2, inter_fc_dim=128, embedding_dim=2, conv_init='normal'):
        super().__init__()

        # Reject unsupported schemes before spending time building the sub-networks.
        if conv_init != 'normal':
            raise NotImplementedError(
                f"conv_init={conv_init!r} is not supported; only 'normal' is implemented"
            )

        self.init_num_filters_ = init_num_filters
        self.lrelu_slope_ = lrelu_slope
        self.inter_fc_dim_ = inter_fc_dim
        self.embedding_dim_ = embedding_dim

        self.encoder = CelebaEncoder(init_num_filters, lrelu_slope, embedding_dim)
        self.decoder = CelebaDecoder(init_num_filters, lrelu_slope, embedding_dim)

        # Module.apply already visits every sub-module recursively, so a single
        # call initialises each layer exactly once. (Looping over self.modules()
        # and calling apply on each one re-initialised nested layers repeatedly.)
        self.apply(weights_init)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction."""
        z = self.encoder(x)
        return self.decoder(z)

In [17]:
# Shape smoke test: send one random 3 x 64 x 64 image through a freshly built
# encoder and decoder and print the size of the embedding and reconstruction.
dummy = torch.randn(1, 3, 64, 64)

enc, dec = CelebaEncoder(), CelebaDecoder()

emb = enc(dummy)
rec = dec(emb)
print(emb.size(), rec.size(), sep='\n')

torch.Size([1, 128])
torch.Size([1, 3, 64, 64])


In [18]:
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt


def plot_img(img, name):
    """Draw a channel-first image tensor with matplotlib and save the figure.

    Args:
        img (torch.Tensor): image laid out as (channels, height, width).
        name (str): path handed to ``plt.savefig``.

    Note:
        The image is drawn on matplotlib's current figure. ``imshow`` clips
        float RGB data to [0, 1], so values outside that range are not shown
        faithfully.
    """
    chw = img.detach().cpu().numpy()
    # matplotlib expects channels last: (C, H, W) -> (H, W, C).
    hwc = np.transpose(chw, (1, 2, 0))
    plt.imshow(hwc)
    plt.savefig(name)


def train(train_loader, model, optimizer, criterion,
          device, lr_schedule=None):
    """Run one training epoch of a reconstruction model.

    Every batch is moved to ``device``, reconstructed by ``model`` and scored
    with ``criterion(reconstruction, batch)``. One randomly chosen batch per
    epoch is rendered next to its reconstruction and saved as
    ``train_<batch index>.jpg`` through ``plot_img``.

    Args:
        train_loader: sized iterable yielding image batches as single tensors.
        model (nn.Module): network mapping a batch to its reconstruction.
        optimizer (torch.optim.Optimizer): optimizer updating ``model``.
        criterion (callable): loss called as ``criterion(prediction, target)``.
        device: device the batches are moved to.
        lr_schedule (callable, optional): receives the fraction of the epoch
            completed (``idx / num_iters``) and returns the learning rate that
            is applied via ``adjust_learning_rate`` before the step.

    Returns:
        dict: ``{'loss': mean loss over the batches of this epoch}``.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    num_iters = len(train_loader)
    if num_iters == 0:
        # Fail with a clear message instead of the opaque "low >= high" error
        # np.random.randint would raise below.
        raise ValueError('train_loader is empty: nothing to train on')

    loss_sum = 0.0
    model.train()
    # Index of the single batch that gets visualised this epoch.
    rand_batch = np.random.randint(0, num_iters)
    for idx, inp in enumerate(train_loader):
        if lr_schedule is not None:
            lr = lr_schedule(idx / num_iters)
            adjust_learning_rate(optimizer, lr)

        inp = inp.to(device)
        optimizer.zero_grad()

        output = model(inp)
        # PyTorch losses take (input, target): the prediction goes first.
        loss = criterion(output, inp)

        loss.backward()
        optimizer.step()

        if idx == rand_batch:
            with torch.no_grad():
                # Originals and reconstructions side by side along the last (width) axis.
                print_images = torch.cat([inp, output], dim=3)
                grid = make_grid(print_images, nrow=8, normalize=False)
                plot_img(grid.detach().cpu(), 'train_{}.jpg'.format(idx))

        loss_sum += loss.item()

    return {
        'loss': loss_sum / num_iters
    }


def test(test_loader, model, criterion,
         device):
    """Evaluate a reconstruction model on a data loader without gradients.

    Every batch is moved to ``device``, reconstructed by ``model`` (in eval
    mode) and scored with ``criterion(reconstruction, batch)``. One randomly
    chosen batch is saved next to its reconstruction as
    ``imgs/test_<batch index>.jpg``; the ``imgs`` directory is created if it
    does not exist.

    Args:
        test_loader: sized iterable yielding image batches as single tensors.
        model (nn.Module): network mapping a batch to its reconstruction.
        criterion (callable): loss called as ``criterion(prediction, target)``.
        device: device the batches are moved to.

    Returns:
        dict: ``{'loss': mean loss over the batches of the loader}``.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    num_iters = len(test_loader)
    if num_iters == 0:
        # Fail with a clear message instead of the opaque "low >= high" error
        # np.random.randint would raise below.
        raise ValueError('test_loader is empty: nothing to evaluate')

    loss_sum = 0.0
    model.eval()
    # Index of the single batch that gets visualised.
    rand_batch = np.random.randint(0, num_iters)
    for idx, inp in enumerate(test_loader):
        inp = inp.to(device)

        with torch.no_grad():
            output = model(inp)
            # PyTorch losses take (input, target): the prediction goes first.
            loss = criterion(output, inp)

            if idx == rand_batch:
                # Originals and reconstructions side by side along the last (width) axis.
                print_images = torch.cat([inp, output], dim=3)
                grid = make_grid(print_images, nrow=8, normalize=False)
                # save_image does not create missing directories.
                os.makedirs('imgs', exist_ok=True)
                save_image(grid.detach().cpu(), fp='imgs/test_{}.jpg'.format(idx))

        loss_sum += loss.item()

    return {
        'loss': loss_sum / num_iters
    }

In [28]:
import torch.optim as optim
import dataset

# Run configuration shared by the cells below.
base_lr = 1e-4
b_size = 128

# Use the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# `dataset` is a project-local module. The trailing positional 2 is presumably
# the number of loader workers - TODO confirm against dataset.build_loader.
loaders = dataset.build_loader(dataset.CelebADataset, 'data', b_size, 2)

In [29]:
# Build the autoencoder with a base width of 32 filters and place it on the
# selected device (Module.to returns the module itself).
model = CelebaAutoencoder(init_num_filters=32).to(device)

# Adam with a very light weight decay.
opt = optim.Adam(
    model.parameters(),
    lr=base_lr,
    weight_decay=1e-8,
)

In [25]:
# Total number of scalar entries across all parameters of the model.
print(sum(p.numel() for p in model.parameters()), 'params in AutoEncoder')

2721442 params in AutoEncoder


In [26]:
def learning_rate_schedule(base_lr, epoch, total_epochs):
    """Piecewise learning-rate schedule: constant, linear decay, then a floor.

    With ``progress = epoch / total_epochs``:
      * ``progress <= 0.5`` -> ``base_lr``
      * ``0.5 < progress <= 0.9`` -> linear decay from ``base_lr`` to ``0.01 * base_lr``
      * otherwise -> ``0.01 * base_lr``

    Args:
        base_lr (float): learning rate used during the constant phase.
        epoch (int): current epoch index.
        total_epochs (int): epoch count used to normalise ``epoch``; must be non-zero.

    Returns:
        float: learning rate for ``epoch``.
    """
    progress = epoch / total_epochs

    if progress <= 0.5:
        return 1.0 * base_lr

    if progress <= 0.9:
        # Fraction of the decay window covered so far, scaled so that the
        # multiplier reaches 0.01 at progress == 0.9.
        drop = (progress - 0.5) / 0.4 * 0.99
        return (1.0 - drop) * base_lr

    return 0.01 * base_lr

def adjust_learning_rate(optimizer, lr):
    """Write ``lr`` into every parameter group of ``optimizer``.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dicts.
        lr (float): learning rate to set.

    Returns:
        float: the ``lr`` that was passed in, for convenience.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)
    return lr

In [None]:
import time
import tqdm.auto as tqdm
start_epoch, max_epochs = 0, 40
criterion = nn.L1Loss()

# Held-out evaluation is currently switched off, so the reported test loss
# stays None. To re-enable it, assign
#     test_res = test(loaders['test'], model, criterion, device)
# inside the loop, right after the training call.
test_res = {'loss': None}

# NOTE: the upper bound is inclusive, so the loop runs max_epochs + 1 epochs.
for epoch in range(start_epoch, max_epochs + 1):
    tic = time.perf_counter()

    # The rate is set once per epoch; adjust_learning_rate returns the value it applied.
    lr = adjust_learning_rate(opt, learning_rate_schedule(base_lr, epoch, max_epochs))

    train_res = train(loaders['train'], model, opt, criterion, device)

    # Wall-clock duration of the epoch in seconds; printed below in minutes.
    time_ep = time.perf_counter() - tic
    print(f"[Epoch]: {epoch}, Train loss: {train_res['loss']}, Test loss: {test_res['loss']}, time elaplsed: {(time_ep / 60):.3f}")

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[Epoch]: 0, Train loss: 0.16963877804853297, Test loss: None, time elaplsed: 3.069


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[Epoch]: 1, Train loss: 0.1558675419300895, Test loss: None, time elaplsed: 3.034


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[Epoch]: 2, Train loss: 0.15382722840465682, Test loss: None, time elaplsed: 3.023


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[Epoch]: 3, Train loss: 0.15249412753917152, Test loss: None, time elaplsed: 3.001


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[Epoch]: 4, Train loss: 0.15166372418253107, Test loss: None, time elaplsed: 2.980


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[Epoch]: 5, Train loss: 0.15102593057635255, Test loss: None, time elaplsed: 2.984


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[Epoch]: 6, Train loss: 0.15070591927176774, Test loss: None, time elaplsed: 2.974


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[Epoch]: 7, Train loss: 0.15022706827431015, Test loss: None, time elaplsed: 2.964


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[Epoch]: 8, Train loss: 0.14988604559240115, Test loss: None, time elaplsed: 2.970


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[Epoch]: 9, Train loss: 0.14968018036849973, Test loss: None, time elaplsed: 2.976


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[Epoch]: 10, Train loss: 0.14937034886493425, Test loss: None, time elaplsed: 3.020


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[Epoch]: 11, Train loss: 0.14919319732572495, Test loss: None, time elaplsed: 2.963


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[Epoch]: 12, Train loss: 0.1489880167399392, Test loss: None, time elaplsed: 2.981


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[Epoch]: 13, Train loss: 0.14882057767223428, Test loss: None, time elaplsed: 2.978


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[Epoch]: 14, Train loss: 0.14877844024888595, Test loss: None, time elaplsed: 3.012


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


[Epoch]: 15, Train loss: 0.1485369990670721, Test loss: None, time elaplsed: 2.985


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
