In [1]:
import os
import logging
import numpy as np

from scipy.spatial.transform import Rotation as R
from torchvision.transforms import ToTensor
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split

from PIL import Image
from pathlib import Path
from tqdm import tqdm
import json

import sys
import os

In [2]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

In [3]:
renders_path = Path('/mnt/ML/Datasets/shapenet renders/renders_pyrender')

In [4]:
with open(renders_path / 'metadatas.json', 'r') as f:
    asset_metadatas = json.load(f)

In [5]:
def refresh_cache():
    asset_metadatas = {}
    for asset_id, asset in asset_renders_list.items():
        with open(asset['metadata'], 'r') as f:
            metadatas = json.load(f)
        asset_metadatas[asset_id] = metadatas
    with open('asset_metadatas_cache.json', 'w') as f:
        json.dump(asset_metadatas, f)

In [6]:
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/autoencoder')

In [7]:
def load_asset(metadata, channel='rgba'):
    return Image.open(renders_path / metadata['asset_id'] / metadata[f'{channel}_filename'])

In [8]:
class NormalizeBackground(object):
    def __init__(self, size):
        self.size = size
        self.normalize = transforms.Normalize([0.5], [0.5])
    
    def __call__(self, image):
        rgb = image[:3]
        alpha = image[3]
        norm_rgb = self.normalize(rgb)
        return torch.mul(norm_rgb, alpha)

In [9]:
class RandomHue(object):
    def __init__(self):
        pass
    
    def adjust_hue(self, image, hue):
        rgb = image[:3]
        alpha = image[3:]
        rgb = transforms.functional.adjust_hue(rgb, hue)
        image = torch.cat([rgb, alpha], 0)
        return image
    
    def __call__(self, data, target):
        hue = np.random.rand() - 0.5
        data = self.adjust_hue(data, hue)
        target = self.adjust_hue(target, hue)
        return data, target

In [10]:
from torch.utils.data import Dataset
from torchvision.io import read_image

class AssetDataset(Dataset):
    def __init__(self, data, data_transforms, data_augmentations=None):
        self.data = data
        self.data_transforms = data_transforms
        self.data_augmentations = data_augmentations

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        input_image = load_asset(self.data[idx]['input'])
        target_image = load_asset(self.data[idx]['target'])
        input_image = transforms.ToTensor()(input_image)
        target_image = transforms.ToTensor()(target_image)
        if self.data_augmentations is not None:
            input_image, target_image = self.data_augmentations(input_image, target_image)
        input_image = self.data_transforms(input_image)
        target_image = self.data_transforms(target_image)
        return input_image, target_image

In [11]:
def get_image(x):
    invTrans = transforms.Compose([
        transforms.Normalize(mean = [ 0., 0., 0. ],
                             std = [ 1/0.5 ]),
        transforms.Normalize(mean = [ -0.5 ],
                             std = [ 1., 1., 1. ]),
        transforms.ToPILImage(),
    ])
    img = invTrans(x)
    return img

In [12]:
import matplotlib.pyplot as plt

def plot_predictions(data, output, target):
    fig = plt.figure(figsize=(2*len(data), 2*3))
    for idx in np.arange(len(target)):
        ax = fig.add_subplot(3, len(data), idx+1, xticks=[], yticks=[])
        plt.imshow(get_image(data[idx]))
        ax = fig.add_subplot(3, len(data), idx+1+len(data), xticks=[], yticks=[])
        plt.imshow(get_image(output[idx]))
        ax = fig.add_subplot(3, len(data), idx+1+len(data)*2, xticks=[], yticks=[])
        plt.imshow(get_image(target[idx]))
    return fig

In [13]:
def get_preprocessor(image_size):
    return transforms.Compose([
        transforms.Resize(image_size),
        NormalizeBackground(image_size),
    ])

In [14]:
from torch.utils.data import DataLoader

def makeDataLoader(data, data_transforms, data_augmentations, seed=0, test_size=0.1, batch_size=16):
    train_ids, valid_ids = train_test_split(data, test_size=test_size, random_state=seed)
    train_dataset = AssetDataset(data, data_transforms, data_augmentations)
    valid_dataset = AssetDataset(data, data_transforms)
    
    kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, **kwargs)
    
    train_size = len(train_dataset)
    valid_size = len(valid_dataset)
    
    return train_loader, valid_loader, train_size, valid_size

In [15]:
nonlinearity = F.relu

In [16]:
class SelfAttention(nn.Module):
    def __init__(self, num_tokens, dim):
        super(SelfAttention, self).__init__()
        self.num_tokens = num_tokens
        self.size = dim
        self.multiheadattention = nn.MultiheadAttention(num_tokens * (dim + 1), num_tokens)
        
    def forward(self, xb):
        
        onehots = F.one_hot(torch.arange(0, self.num_tokens)).repeat(xb.shape[0], 1)
        tokens = xb.view(-1, num_tokens, dim)
        

In [17]:
class AutoEncoder(nn.Module):
    def __init__(self, size):
        super(AutoEncoder, self).__init__()
        self.size = size
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.middle = nn.Sequential(
            nn.Conv2d(256, 32, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(16*16*32, 1000),
            nn.ReLU(),
            nn.Linear(1000, 16*16*32),
            nn.ReLU(),
            nn.Unflatten(1, (32, 16, 16)),
            nn.Conv2d(32, 256, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 256, 2, stride=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 2, stride=2),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 2, stride=2),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 2, stride=2),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 3, 1),
            nn.Tanh(),
        )

    def forward(self, xb):
        output = xb
        output = self.encoder(output)
        output = self.middle(output)
        output = self.decoder(output)
        return output

In [18]:
def train(model, train_loader, loss_func, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = loss_func(0.5*output+0.5, 0.5*target+0.5)
        loss.backward()
        optimizer.step()

        writer.add_scalar('training loss',
                          loss.data.item() / len(data),
                          epoch * len(train_loader) + batch_idx)

        if batch_idx % 16 == 0:
            writer.add_figure('predictions',
                              plot_predictions(data[:4], output[:4], target[:4]),
                              global_step=epoch * len(train_loader) + batch_idx)

In [19]:
def validation(model, valid_loader, loss_func, epoch, metric=None):
    model.eval()
    validation_loss = 0
    with torch.no_grad():
        for data, target in valid_loader:
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            validation_loss += loss_func(0.5*output+0.5, 0.5*target+0.5).data.item()

        validation_loss /= len(valid_loader.dataset)

        writer.add_scalar('validation loss',
                          validation_loss,
                          epoch + 1)
        
        print(f'epoch {epoch}: Validation set: Average loss: {validation_loss:.4f}')

In [20]:
def train_model(model, train_loader, valid_loader, loss_func, lr=0.001, seed=0, epochs=30):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        train(model, train_loader, loss_func, optimizer, epoch)
        validation(model, valid_loader, loss_func, epoch)

In [21]:
image_size = 256
preprocess = get_preprocessor((image_size, image_size))
data_transforms = preprocess
data_augmentations = None#RandomHue()
data = []
for metadata in asset_metadatas:
    input_data = next(filter(lambda m: m['name'] == 'angle_30', metadata))
    target_data = next(filter(lambda m: m['name'] == 'angle_60', metadata))
    data.append({'input': input_data, 'target': target_data})

In [22]:
lr = 0.01
batch_size = 64
train_loader, valid_loader, train_size, val_size = makeDataLoader(data, data_transforms, data_augmentations, seed=0, batch_size=batch_size)
model = AutoEncoder(image_size).to(device)

In [23]:
model

AutoEncoder(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True

In [24]:
from pytorch_msssim import MS_SSIM, ms_ssim, SSIM, ssim

class MS_SSIM_Loss(MS_SSIM):
    def forward(self, img1, img2):
        return  100*(1 - super(MS_SSIM_Loss, self).forward(img1, img2))
    
ms_ssim_loss = MS_SSIM_Loss(data_range=1.0, size_average=True, channel=3)

In [25]:
train_model(model, train_loader, valid_loader, ms_ssim_loss, seed=0, epochs=200)

epoch 0: Validation set: Average loss: 0.9337
epoch 1: Validation set: Average loss: 0.7012
epoch 2: Validation set: Average loss: 0.5679
epoch 3: Validation set: Average loss: 0.4674
epoch 4: Validation set: Average loss: 0.4387
epoch 5: Validation set: Average loss: 0.4152
epoch 6: Validation set: Average loss: 0.3981
epoch 7: Validation set: Average loss: 0.5088
epoch 8: Validation set: Average loss: 0.3613
epoch 9: Validation set: Average loss: 0.5880
epoch 10: Validation set: Average loss: 0.3502
epoch 11: Validation set: Average loss: 0.3385
epoch 12: Validation set: Average loss: 0.4410
epoch 13: Validation set: Average loss: 0.3496
epoch 14: Validation set: Average loss: 0.3372
epoch 15: Validation set: Average loss: 0.3337
epoch 16: Validation set: Average loss: 0.3478
epoch 17: Validation set: Average loss: 0.3228
epoch 18: Validation set: Average loss: 0.3219
epoch 19: Validation set: Average loss: 0.3371
epoch 20: Validation set: Average loss: 0.3167
epoch 21: Validation se

KeyboardInterrupt: 