In [0]:
!pip install parse

In [0]:
import math
import random
import zipfile
from parse import *

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# Torchvision
import torchvision.utils
import torchvision.transforms as transforms
import torchvision.datasets as dset
import torchvision.utils as vutils
from torch.utils.data import Dataset, DataLoader

In [0]:
#### Global configuration

# Number of input channels (3 = RGB)
nc = 3
# Size of the latent vector (z)
latent_dims = 512
# Learning rate
lr = 1e-3
# Number of training epochs
num_epochs = 10
# Batch size
batch_size = 128
# Side length (pixels) the images are resized to (see ImgDataset.transform)
image_size = 64
# Base number of convolutional filters used by the encoder/decoder
capacity = 32
# Number of available GPUs (0 forces the CPU)
ngpu = 1
# Directory the dataset is extracted to
dataroot = "datasets/shapenet/"

# Seed every RNG the notebook relies on:
#   rs     -> shuffling of the id lists (all_ids)
#   random -> target-view sampling when data loading runs in the main process
#   torch  -> weight initialisation, DataLoader shuffling and the per-worker
#             seeds DataLoader derives for its worker processes
seed = 123
rs = np.random.RandomState(seed)
random.seed(seed)
torch.manual_seed(seed)

In [0]:
#Usar esta opción si está en colab

#!wget https://www.dropbox.com/s/lc01fm5o8dbrp59/nvs_chair.zip

In [0]:
#with zipfile.ZipFile("nvs_chair.zip","r") as zip_ref:
#  zip_ref.extractall(dataroot)

In [0]:
### Manipulación y preparación de los datos.

class ImgDataset(Dataset):
  """Novel-view training pairs built from multi-view renders of objects.

  Each item is the 3-tuple (src_img, dst_img, target_viewpoint): the image
  named by the requested id, a second render of the same object from a
  randomly drawn viewpoint, and that viewpoint encoded as the float tensor
  [sin(rotation), cos(rotation), sin(elevation), cos(elevation)].
  """

  def __init__(self, ids, img_path=None, dataset_name="chair"):
    """
    Args:
      ids: iterable of image ids; each id is read from
        "<img_path>/<dataset_name>/<id>.png". The object id is the part of
        the id before the first "_".
      img_path: dataset root directory (defaults to the global `dataroot`).
      dataset_name: sub-directory of `img_path` that holds the renders.
    """
    self._ids = list(ids)
    self.dataset_name = dataset_name
    self.img_path = dataroot if img_path is None else img_path
    self.transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # Map each RGB channel from [0, 1] to [-1, 1]. Normalize's signature is
        # (mean, std, inplace), so a third tuple would be read as `inplace`.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

  def __getitem__(self, idx):
    """Return (src_img, dst_img, target_viewpoint) for the idx-th id.

    The target render is drawn at random among the views of the same object:
    elevation in {0, 10, 20} degrees and rotation index in {0, 2, ..., 34}.
    The rotation index is what appears in the file name; it is converted to
    degrees as index * 10 (0-340 degrees in 20-degree steps).
    """
    img_id = self._ids[idx]
    object_id = img_id.split("_")[0]

    src_img = self.readImageToArray(img_id)

    elevation = random.choice([0, 10, 20])
    rotation = random.choice(range(0, 36, 2))

    dst_img = self.readImageToArray("{}_{}_{}".format(object_id, rotation, elevation))

    src_img = self.transform(src_img)
    dst_img = self.transform(dst_img)

    return (src_img, dst_img, self.viewpoint_vector(rotation * 10, elevation))

  def __len__(self):
    return len(self._ids)

  @staticmethod
  def viewpoint_vector(rotation_deg, elevation_deg):
    """Encode a viewpoint given in degrees as a float32 tensor
    [sin(rotation), cos(rotation), sin(elevation), cos(elevation)]."""
    rot = math.radians(rotation_deg)
    elev = math.radians(elevation_deg)
    return torch.tensor([math.sin(rot), math.cos(rot),
                         math.sin(elev), math.cos(elev)]).float()

  def readImageToArray(self, in_id):
    """Load "<img_path>/<dataset_name>/<in_id>.png" as an RGB PIL image.

    Despite the name this returns a PIL image (not an array). The pixels are
    decoded eagerly and the file is closed before returning. Converting to
    RGB guarantees the 3 channels the encoder expects.
    """
    import os.path as osp

    path = osp.join(self.img_path, self.dataset_name, in_id + '.png')
    with Image.open(path) as img:
      return img.convert('RGB')


def all_ids(dataset_name='chair', shuffle_train=True, shuffle_test=True,
            data_dir='.', rng=None):
    """
    Read the train/test image-id lists (adapted from the TB-Networks repo).

    Reads "<data_dir>/id_<dataset_name>_train.txt" and
    "<data_dir>/id_<dataset_name>_test.txt", one id per line, skipping
    blank lines.

    Args:
      dataset_name: prefix of the id files.
      shuffle_train: shuffle the train list in place.
      shuffle_test: shuffle the test list in place.
      data_dir: directory containing the id files (default: current dir).
      rng: object with a `shuffle(list)` method; defaults to the global `rs`.

    Returns:
      (ids_train, ids_test) as lists of strings.
    """
    import os.path as osp

    def _read_ids(split):
        # Read one split file; strip before filtering because raw lines keep
        # their trailing newline and would otherwise never be empty.
        path = osp.join(data_dir, 'id_' + dataset_name + '_' + split + '.txt')
        with open(path, 'r') as fp:
            return [line.strip() for line in fp if line.strip()]

    ids_train = _read_ids('train')
    ids_test = _read_ids('test')

    if rng is None and (shuffle_train or shuffle_test):
        rng = rs
    if shuffle_train:
        rng.shuffle(ids_train)
    if shuffle_test:
        rng.shuffle(ids_test)

    return ids_train, ids_test


ids_train, ids_test = all_ids()

dataset_train = ImgDataset(ids_train)

# Use the configured batch size instead of repeating the literal; drop_last
# keeps every batch exactly `batch_size` images long.
train_loader = DataLoader(dataset_train,
                          batch_size=batch_size,
                          num_workers=3, drop_last=True,
                          shuffle=True)
assert len(train_loader) > 0, (
    "train_loader yields no batches: %d training ids < batch_size=%d with drop_last=True"
    % (len(dataset_train), batch_size))

device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
print(device)

In [0]:
# Preview (up to) 64 source images from one training batch.
src_imgs, dst_imgs, viewpoint = next(iter(train_loader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
# make_grid works on CPU tensors, so there is no need to round-trip through
# the GPU just to plot; permute CHW -> HWC for imshow.
grid = vutils.make_grid(src_imgs[:64], padding=2, normalize=True)
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.show()

In [0]:
### Arquitectura de Red 8192

In [0]:
class Encoder(nn.Module):
    """Convolutional image encoder producing a `latent_dims`-d vector.

    Expects (N, nc, image_size, image_size) inputs with image_size divisible
    by 8. Three stride-2 convolutions each halve the spatial size, so the
    last feature map is (N, 2*capacity, image_size // 8, image_size // 8);
    it is flattened and projected by a linear layer to (N, latent_dims).
    """

    def __init__(self):
        super(Encoder, self).__init__()
        c = capacity
        # Side of the final feature map (8 for image_size = 64).
        side = image_size // 8
        # With S = image_size:
        self.conv1 = nn.Conv2d(in_channels=nc, out_channels=c, kernel_size=4, stride=2, padding=1)      # out: c x S/2 x S/2
        self.conv2 = nn.Conv2d(in_channels=c, out_channels=c*2, kernel_size=4, stride=2, padding=1)    # out: 2c x S/4 x S/4
        self.conv3 = nn.Conv2d(in_channels=c*2, out_channels=c*2, kernel_size=4, stride=2, padding=1)  # out: 2c x S/8 x S/8

        self.fc = nn.Linear(in_features=c*2*side*side, out_features=latent_dims)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # flatten to (N, 2c * side * side)
        return self.fc(x)

class Decoder(nn.Module):
    """Decode (N, latent_dims + 64) codes into (N, nc, image_size, image_size) images in [-1, 1].

    The input is the image latent concatenated with the 64-d viewpoint
    embedding. Three linear layers map it to a
    (2*capacity, image_size // 8, image_size // 8) feature map, which three
    stride-2 transposed convolutions upsample back to image resolution
    (image_size must be divisible by 8). The final tanh bounds the output.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        c = capacity
        # Bottleneck shape, mirroring the Encoder's last feature map. Stored
        # so forward() does not depend on module globals at call time.
        self._channels = c * 2
        self._side = image_size // 8

        # +64: width of the viewpoint embedding concatenated to the latent.
        self.fc1 = nn.Linear(in_features=latent_dims+64, out_features=latent_dims)
        self.fc2 = nn.Linear(in_features=latent_dims, out_features=latent_dims)
        self.fc3 = nn.Linear(in_features=latent_dims, out_features=self._channels*self._side*self._side)

        self.conv1 = nn.ConvTranspose2d(in_channels=c*2, out_channels=c*2, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose2d(in_channels=c*2, out_channels=c, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.ConvTranspose2d(in_channels=c, out_channels=nc, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = x.view(x.size(0), self._channels, self._side, self._side)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return torch.tanh(self.conv3(x))

class Viewpoint(nn.Module):
    """Embed a 4-d viewpoint vector into a 64-d feature.

    The input is an (N, 4) tensor; in this notebook it holds
    [sin(rotation), cos(rotation), sin(elevation), cos(elevation)].
    It goes through three linear layers (4 -> 64 -> 64 -> 64), each followed
    by a ReLU, so every output entry is non-negative.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 64)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return hidden
    
class Autoencoder(nn.Module):
    """Viewpoint-conditioned autoencoder.

    Encodes a source image, embeds a target viewpoint vector, concatenates
    both codes and decodes them into the predicted image for that viewpoint.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.viewpoint = Viewpoint()

    def forward(self, x, theta):
        """
        Args:
          x: batch of source images.
          theta: batch of viewpoint vectors -- not a single angle.
        Returns:
          decoder([encoder(x), viewpoint(theta)]) with the two codes
          concatenated along dim 1.
        """
        code = torch.cat([self.encoder(x), self.viewpoint(theta)], dim=1)
        return self.decoder(code)
    
# Build the model and move its parameters to the selected device
# (nn.Module.to returns the module itself).
autoencoder = Autoencoder().to(device)

In [0]:
# Count only the parameters the optimizer will update.
num_params = sum(
    param.numel()
    for param in autoencoder.parameters()
    if param.requires_grad
)
print('Number of parameters: %d' % num_params)
print(autoencoder)

In [0]:
## Tomadas del repo de ejemplo
# Funciones de graficación de las imagenes
def to_img(x):
    """Map a tensor from the [-1, 1] range to [0, 1], clamping values outside it."""
    return (x + 1).mul(0.5).clamp(0, 1)

def show_image(img):
    """Plot a (C, H, W) image tensor with values in [-1, 1] using matplotlib.

    The tensor is moved to the CPU, rescaled to [0, 1] via to_img and
    transposed to (H, W, C) for imshow.
    """
    pixels = to_img(img.cpu()).numpy()
    plt.imshow(pixels.transpose(1, 2, 0))

def visualise_output(images, viewpoint, model):
    """Run `model` on (images, viewpoint) without gradients and plot the result.

    Up to 50 outputs are rescaled to [0, 1] and shown in a grid with 10
    images per row and 5 pixels of padding.
    """
    with torch.no_grad():
        outputs = model(images.to(device), viewpoint.to(device))
        grid = torchvision.utils.make_grid(to_img(outputs.cpu())[:50], 10, 5)
        plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
        plt.show()


In [0]:
optimizer = torch.optim.Adam(params=autoencoder.parameters(), lr=lr, weight_decay=1e-5)

# Put the network in training mode.
autoencoder.train()

# Mean reconstruction loss of each completed epoch.
train_loss_avg = []

print('Training ...')
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0

    for src_imgs, dst_imgs, viewpoint in train_loader:
        src_imgs = src_imgs.to(device)
        dst_imgs = dst_imgs.to(device)
        viewpoint = viewpoint.to(device)

        # Predict the target view from the source image and target viewpoint.
        image_batch_recon = autoencoder(src_imgs, viewpoint)

        # Reconstruction error against the real target view.
        loss = F.mse_loss(image_batch_recon, dst_imgs)

        # Backpropagation.
        optimizer.zero_grad()
        loss.backward()

        # Update the weights with the back-propagated gradients.
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError(
            "train_loader produced no batches (fewer training ids than the "
            "batch size with drop_last=True?)")

    train_loss_avg.append(epoch_loss / num_batches)

    # Show the epoch's last batch: input, target and the network's prediction.
    print("Input")
    show_image(torchvision.utils.make_grid(src_imgs[:50],10,5))
    plt.show()

    print("Target")
    show_image(torchvision.utils.make_grid(dst_imgs[:50],10,5))
    plt.show()

    print("Reconstruction")
    # Plot the prediction already computed for this batch; feeding it back
    # through the model would show a reconstruction of the reconstruction.
    show_image(torchvision.utils.make_grid(image_batch_recon[:50].detach(),10,5))
    plt.show()

    print('Epoch [%d / %d] average reconstruction error: %f' % (epoch+1, num_epochs, train_loss_avg[-1]))
    print("-"*25)

In [0]:
# Per-epoch mean reconstruction loss, shown as the cell's output.
train_loss_avg

In [0]:
fig, ax = plt.subplots()
# The training log numbers epochs from 1, so plot against the same x values.
ax.plot(range(1, len(train_loss_avg) + 1), train_loss_avg, marker='o')
ax.set_xlabel('Epochs')
ax.set_ylabel('Reconstruction error')
ax.set_title('Training reconstruction error (MSE)')
plt.show()

In [0]:
## Model in evaluation mode.
autoencoder.eval()

# Fetch a batch of training samples for the interactive demo below. The
# first two batches are skipped. DataLoader iterators implement the
# iterator protocol (`next(it)`); they have no `.next()` method in current
# PyTorch.
iterator = iter(train_loader)
next(iterator)
next(iterator)

images, target, viewpoint = next(iterator)

In [0]:
### Generamos una UI mínima para interactuar con la red y generar
### las nuevas vistas.

In [0]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

In [0]:
autoencoder.eval()

def f(rotation, elevation):
  """Render the demo batch `images` from a chosen viewpoint and plot it.

  Args:
    rotation: rotation angle in degrees.
    elevation: elevation angle in degrees.

  The viewpoint is encoded like in training, as
  [sin(rotation), cos(rotation), sin(elevation), cos(elevation)].
  Up to 50 rendered images are shown in a grid of 10 per row.
  """
  with torch.no_grad():

    rsin, rcos = math.sin(math.radians(rotation)), math.cos(math.radians(rotation))
    esin, ecos = math.sin(math.radians(elevation)), math.cos(math.radians(elevation))

    batch = images.to(device)

    # One copy of the viewpoint per image, sized from the batch itself so it
    # matches regardless of the configured `batch_size`.
    viewpoint = torch.tensor([[rsin, rcos, esin, ecos]]).float()
    viewpoint = viewpoint.repeat(batch.size(0), 1).to(device)

    rendered = to_img(autoencoder(batch, viewpoint).cpu())

    fig, ax = plt.subplots(1, 1)
    np_imagegrid = torchvision.utils.make_grid(rendered[:50], 10, 5).numpy()
    ax.imshow(np.transpose(np_imagegrid, (1, 2, 0)))
    plt.show()

interactive_plot = interactive(
    f,
    rotation=widgets.IntSlider(min=0, max=360, step=1, value=10),
    elevation=widgets.IntSlider(min=0, max=20, step=1, value=10),
)
# The last child of an `interactive` widget is its output area; give it a
# fixed height.
output = interactive_plot.children[-1]
output.layout.height = '350px'
interactive_plot