In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from tqdm import tqdm as tqdm
from IPython.display import clear_output
import matplotlib.pyplot as plt
import time
import numpy as np
import os
import pickle
from IPython.core.debugger import set_trace
import sys
import PIL
from PIL import Image
from matplotlib import pyplot as plt


# load loss functions
sys.path.append('../loss')
from loss_provider import LossProvider

In [2]:
dataset_path = '../datasets/celebA/'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
data_dim = (3,64,64)
data_size = np.prod(data_dim)
batch_size = 128

In [3]:
 
# key word args for loading data
kwargs = {'num_workers': 1, 'pin_memory': True} if device == "cuda" else {}

# transformers
transformers = transforms.Compose([
    transforms.ToTensor()                                # as tensors
])
transformers_la = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor()                                # as tensors
])

data_set = datasets.ImageFolder(dataset_path, transform=transformers)
data_set_la = datasets.ImageFolder(dataset_path, transform=transformers_la)


# load datasets and make them easily fetchable in DataLoaders
data_loader = torch.utils.data.DataLoader(
    data_set,
    batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)
data_loader_la = torch.utils.data.DataLoader(
    data_set_la,
    batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)

# Data loading

In [4]:
for data, lable in data_loader:
    x = data
    break
    
for data, lable in data_loader_la:
    x_la = data
    break
    
x = x.to(device)
x_la = x_la.to(device)

# We reconstruct an inut sample, for 500 iterations

In [0]:
class ReconSample(nn.Module):
    def __init__(self, ground_truth, loss_function):
        super().__init__()
        self.loss = loss_function
        self.recon = nn.Parameter(torch.randn(ground_truth.shape))
        self.sigmoid = nn.Sigmoid()
        
    def get_recon(self):
        return self.sigmoid(self.recon)
    
    def forward(self, ground_truth):
        return self.loss(self.get_recon(), ground_truth)

def runtime_test(x, loss_function, epochs=500):
    reconstructor = ReconSample(x, loss_function)
    reconstructor = reconstructor.to(device)
    optimizer = torch.optim.SGD(reconstructor.parameters(), lr=10**-4)
    
    # train
    torch.cuda.reset_max_memory_allocated()
    mem0 =  torch.cuda.max_memory_allocated() 
    reconstructor.loss = reconstructor.loss.to(device)
    t0 = time.time()
    for iter in tqdm(range(epochs), leave=True, position=0):
        optimizer.zero_grad()
        loss = reconstructor.forward(x)
        loss.backward()
        optimizer.step()
    t1 = time.time()
    mem1 = torch.cuda.max_memory_allocated()
        
    return {'runtime':t1 - t0, 'memory':(mem1 - mem0) / (1024**2)}
        

# Run test for each loss function.

In [9]:
loss_provider = LossProvider()
results = {}
for _ in range(5):
    for color_model in ['RGB', 'LA']:
        for loss_metric in loss_provider.loss_functions:
            if loss_metric == 'Watson-vgg':
                continue
            key = loss_metric + ' ' + color_model
            if key not in results:
                results[key] = {}
                results[key]['runtime'] = []
                results[key]['memory'] = []
            loss_function = loss_provider.get_loss_function(loss_metric, color_model)
            data = x if color_model == 'RGB' else x_la
            res =  runtime_test(data, loss_function, epochs=500)
            results[key]['runtime'].append(res['runtime'])
            results[key]['memory'].append(res['memory'])

pickle.dump(results, open(os.path.join(g_drive_path, 'runtime_results_repitition.pickle'), 'wb'))
    

100%|██████████| 500/500 [00:00<00:00, 1529.55it/s]
100%|██████████| 500/500 [00:00<00:00, 2431.09it/s]
100%|██████████| 500/500 [00:03<00:00, 125.67it/s]
100%|██████████| 500/500 [00:12<00:00, 41.09it/s]
100%|██████████| 500/500 [00:13<00:00, 37.82it/s]
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:08<00:00, 68.4MB/s]
100%|██████████| 500/500 [01:08<00:00,  7.28it/s]
Downloading: "https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth" to /root/.cache/torch/checkpoints/squeezenet1_1-f364aa15.pth
100%|██████████| 4.74M/4.74M [00:00<00:00, 24.1MB/s]
100%|██████████| 500/500 [00:09<00:00, 51.94it/s]
100%|██████████| 500/500 [00:00<00:00, 2511.63it/s]
100%|██████████| 500/500 [00:00<00:00, 2648.92it/s]
100%|██████████| 500/500 [00:05<00:00, 97.34it/s]
100%|██████████| 500/500 [00:03<00:00, 143.77it/s]
100%|██████████| 500/500 [00:03<00:00, 136.24it/s]
100%|██████████| 500/50

# Print results.

In [12]:
for model in results:
    print('{}: runtime mean: {}s, max memory {}Mb'.format(model, np.mean(results[model]['runtime']), max(results[model]['memory'])))

L1 RGB: runtime mean: 0.2212538242340088s, max memory 24.0009765625Mb
L2 RGB: runtime mean: 0.18922247886657714s, max memory 24.0009765625Mb
SSIM RGB: runtime mean: 3.9674572944641113s, max memory 114.00390625Mb
Watson-dct RGB: runtime mean: 12.134006309509278s, max memory 96.00537109375Mb
Watson-fft RGB: runtime mean: 13.034832668304443s, max memory 111.00830078125Mb
Deeploss-vgg RGB: runtime mean: 68.31339735984803s, max memory 2213.69580078125Mb
Deeploss-squeeze RGB: runtime mean: 9.598410892486573s, max memory 544.99609375Mb
L1 LA: runtime mean: 0.177947473526001s, max memory 8.0009765625Mb
L2 LA: runtime mean: 0.17431650161743165s, max memory 8.0009765625Mb
SSIM LA: runtime mean: 5.137272596359253s, max memory 38.00244140625Mb
Watson-dct LA: runtime mean: 3.4779707908630373s, max memory 35.0634765625Mb
Watson-fft LA: runtime mean: 3.6676302909851075s, max memory 37.5029296875Mb
Deeploss-vgg LA: runtime mean: 68.26854333877563s, max memory 2205.69580078125Mb
Deeploss-squeeze LA: ru

In [13]:
torch.cuda.get_device_name(device=None)

'Tesla P100-PCIE-16GB'