In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np

# from faces_model_test import FaceTest
from data import ResizeTransform, EmojiDataset, MSCeleb1MDataset

# Notebook setup: render matplotlib figures inline and enable IPython's
# autoreload extension (mode 2 re-imports modules before each cell runs), so
# edits to local modules such as `data` are picked up without a restart.
# run_line_magic() replaces the deprecated InteractiveShell.magic().
_ipython = get_ipython()
if _ipython is not None:
    _ipython.run_line_magic('matplotlib', 'inline')
    _ipython.run_line_magic('reload_ext', 'autoreload')
    _ipython.run_line_magic('autoreload', '2')

In [None]:
# Preview a single emoji sample (batch size 1, shuffled).
train_set = EmojiDataset('./datasets/emoji_data/', 0, 1000, transform=ResizeTransform(96))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=True)

data_iter = iter(train_loader)

# DataLoader iterators have no `.next()` method in current PyTorch; the builtin
# next() works on every version.
img_tens = next(data_iter)
# Take the first image of the batch and move axis 0 (presumably channels) last,
# since plt.imshow expects (H, W, C).
plt.imshow(np.transpose(img_tens.numpy()[0], (1, 2, 0)))
plt.show()

In [None]:
# Estimate the per-channel mean and std of the emoji images (e.g. to build a
# normalisation transform), then preview one sample.
train_set = EmojiDataset('./datasets/emoji_data/', 0, 1000, transform=ResizeTransform(96))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100, shuffle=True)

MAX_STAT_BATCHES = 100  # upper bound on batches read; fewer are used if the loader is shorter


def channel_mean_std(loader, max_batches=None):
    """Return per-channel ``(mean, std)`` tensors over the images from ``loader``.

    Args:
        loader: iterable yielding image batches shaped (batch, channels, height,
            width), e.g. a DataLoader over EmojiDataset.
        max_batches: read at most this many batches (all of them when ``None``).
            A loader with fewer batches simply ends early instead of raising
            StopIteration.

    Returns:
        ``(mean, std)`` as float32 tensors with one entry per channel. ``std``
        is the unbiased (n - 1) estimate.

    Raises:
        ValueError: if fewer than two pixels per channel were seen.

    Sums are accumulated in float64 in a single pass. As a result, mean and std
    describe exactly the same pixels, and the pixel count comes from the data
    rather than from an assumed batch size or image size.
    """
    channel_sum = None
    channel_sq_sum = None
    n_pixels = 0
    for batch_idx, batch in enumerate(loader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        # (B, C, H, W) -> (C, B*H*W) so every statistic is a reduction over dim 1.
        flat = batch.double().transpose(0, 1).contiguous().view(batch.size(1), -1)
        if channel_sum is None:
            channel_sum = flat.sum(1)
            channel_sq_sum = (flat * flat).sum(1)
        else:
            channel_sum += flat.sum(1)
            channel_sq_sum += (flat * flat).sum(1)
        n_pixels += flat.size(1)
    if n_pixels < 2:
        raise ValueError('need at least two pixels per channel, got %d' % n_pixels)
    mean = channel_sum / n_pixels
    var = (channel_sq_sum - n_pixels * mean * mean) / (n_pixels - 1)
    # clamp guards against tiny negative values from floating-point cancellation.
    return mean.float(), var.clamp(min=0).sqrt().float()


mean, std = channel_mean_std(train_loader, max_batches=MAX_STAT_BATCHES)
print(mean)
print(std)

# Preview one sample from a fresh pass over the loader (builtin next(); the
# iterator has no .next() method in current PyTorch).
img_tens = next(iter(train_loader))
plt.imshow(np.transpose(img_tens.numpy()[0], (1, 2, 0)))
plt.show()

In [None]:
# MS-Celeb-1M 'train' split with ResizeTransform(96), served one shuffled image per batch.
train_set = MSCeleb1MDataset(
    './datasets/ms-celeb-1m/data/',
    'train',
    ResizeTransform(96),
)
train_loader = DataLoader(
    train_set,
    shuffle=True,
    batch_size=1,
)

In [None]:
# Preview one MS-Celeb-1M sample from the loader built in the previous cell.
data_iter = iter(train_loader)

# Builtin next(): DataLoader iterators no longer provide a `.next()` method.
img_tens = next(data_iter)
# First image of the batch with axis 0 (presumably channels) moved last for imshow.
plt.imshow(np.transpose(img_tens.numpy()[0], (1, 2, 0)))
plt.show()

# Begin Model Testing

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# from faces_model_test_open import FaceTest
# from faces_model_test_sphere import FaceTestSphere

from digit_model_test import digits_model_test
from data import ResizeTransform, EmojiDataset, MSCeleb1MDataset

from IPython import get_ipython

# Notebook setup: inline matplotlib figures and IPython autoreload (mode 2
# re-imports modules before each cell runs), so edits to digit_model_test /
# data take effect without a kernel restart.
# IPython.get_ipython() returns None outside an IPython session, so guard it;
# run_line_magic() replaces the deprecated InteractiveShell.magic().
_ipython = get_ipython()
if _ipython is not None:
    _ipython.run_line_magic('matplotlib', 'inline')
    _ipython.run_line_magic('reload_ext', 'autoreload')
    _ipython.run_line_magic('autoreload', '2')

In [None]:
# Build the digits model test harness. Request the GPU only when CUDA is
# actually available instead of assuming one is present.
test = digits_model_test(use_gpu=torch.cuda.is_available())

In [None]:
# Set up the harness's data loaders (presumably required before
# visualize_single_batch() and train_model() — confirm in digits_model_test).
test.create_data_loaders()

In [None]:
# Sanity-check the loaders by displaying one batch via the harness.
test.visualize_single_batch()

In [None]:
def get_num_parameters(model):
    """Count the scalar parameters of a module or of a dict of modules.

    Args:
        model: an object with ``.parameters()`` (e.g. ``nn.Module``), or a dict
            (including subclasses such as ``OrderedDict``) whose values are such
            objects.

    Returns:
        Total number of elements across all parameter tensors.
    """
    # isinstance (not an exact type comparison) so dict subclasses are
    # recognised instead of falling through to `.parameters()` and crashing.
    modules = model.values() if isinstance(model, dict) else (model,)
    # Accumulate explicitly rather than via builtin sum(): notebook cells may
    # rebind the global name `sum`.
    total_num = 0
    for module in modules:
        for p in module.parameters():
            total_num += p.numel()
    return total_num

# Instantiate the model(s), show the architecture, then report the parameter count.
test.create_model()
print(test.model)
n_params = get_num_parameters(test.model)
print('num params: ' + str(n_params))

In [None]:
# Have the harness construct its loss function(s) for training.
test.create_loss_function()

In [None]:
# Have the harness construct its optimizer(s); presumably needs create_model()
# to have run first — confirm.
test.create_optimizer()

In [None]:
# Extra options forwarded to train_model as keyword arguments (presumably
# intervals measured in batches — confirm in digits_model_test).
kwargs = dict(visualize_batches=10, save_batches=50)
test.train_model(num_epochs=3, **kwargs)

In [None]:
# Checkpoint path written by the torch.save cell and read by the torch.load cell below.
# NOTE(review): the leading number looks like a Unix timestamp (Dec 2017); the
# meaning of the "_0_1099" suffix isn't shown here — confirm.
checkpoint_name = './log/1513065292_0_1099.tar' 

In [None]:
# Persist the harness's training log to `checkpoint_name`.
# NOTE(review): torch.save silently overwrites an existing file; on Restart &
# Run All this replaces the checkpoint that the next cell loads.
torch.save(test.log, checkpoint_name)

In [None]:
import inspect

# Restore a saved training log and put its generator/discriminator back into the harness.
# The log holds the model objects themselves (they are assigned into test.model
# below), which requires full unpickling. torch >= 2.6 defaults torch.load to
# weights_only=True and would reject them, so opt out explicitly where the
# argument exists. Full unpickling can execute code: only load checkpoints you
# produced yourself.
load_kwargs = {}
if 'weights_only' in inspect.signature(torch.load).parameters:
    load_kwargs['weights_only'] = False
if not torch.cuda.is_available():
    # Remap CUDA-saved tensors to CPU storage so the checkpoint loads without a GPU.
    load_kwargs['map_location'] = lambda storage, loc: storage
log = torch.load(checkpoint_name, **load_kwargs)
G_model = log['G_model']
D_model = log['D_model']
test.log = log
test.model['G'] = G_model
test.model['D'] = D_model

In [None]:
# Plot every loss curve recorded in the training log, one figure per loss.
matplotlib.rcParams['font.size'] = 14
# NOTE(review): the first two log entries are dropped from every curve; the
# reason isn't recorded here — confirm.
SKIP_FIRST_ENTRIES = 2

for key in test.log:
    if 'loss' not in key or test.log[key] is None:
        continue
    y = test.log[key][SKIP_FIRST_ENTRIES:]
    # Explicit figure per loss instead of relying on pyplot's implicit "current figure".
    fig, ax = plt.subplots()
    ax.plot(y, label=key)
    ax.set_title(key)
    ax.set_ylabel('Loss')
    ax.set_xlabel('Log entry (after the first %d)' % SKIP_FIRST_ENTRIES)
    ax.legend()
    plt.show()