In [None]:
from model import *
from utils import *

BATCH_SIZE = 256  # batch_size handed to the DataLoaders used for training and reconstruction plots
LR = 1e-3  # passed to load_checkpoint, used as Adam's lr and as the scheduler's max_lr
MAX_EPOCHS = 50  # last argument of train(); presumably the epoch limit - confirm in train()

In [None]:
# Preview a few randomly drawn dataset samples, one image per batch.
PREVIEW_COUNT = 5

multiSet = MultiSet('pokemon')
dataLoader = Utils.DataLoader(dataset=multiSet, shuffle=True, batch_size=1)

for i, data in enumerate(dataLoader):
    if i >= PREVIEW_COUNT:
        # Stop here instead of silently iterating over the rest of the dataset.
        break
    plt.imshow(data[0,:,:,:].squeeze().numpy())
    plt.show()

In [None]:
# Resume from the checkpoint directory when possible; otherwise build a fresh
# model, optimizer and scheduler.
try:
    net, epoch, losses, bces, kls, optimizer, scheduler = load_checkpoint("./checkpoints/", LR)
except Exception as err:
    # `except Exception` instead of a bare `except:` so that kernel interrupts
    # (KeyboardInterrupt / SystemExit) are not mistaken for "no checkpoint".
    print("Could not load checkpoint ({0}: {1})".format(type(err).__name__, err))
    net = Net() # Initialize model
    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            net = torch.nn.DataParallel(net)
        net = net.cuda()
    epoch = 0
    losses = []
    bces = []
    kls = []
    optimizer = optim.Adam(net.parameters(), lr=LR, amsgrad=True)
    scheduler = SGDRScheduler(optimizer, min_lr=1e-5, max_lr=LR, cycle_length=500, current_step=0)
    print("Starting new training")

gen_data_list()
multiSet = MultiSet('pokemon')
dataloader = Utils.DataLoader(dataset=multiSet, shuffle=True, batch_size=BATCH_SIZE)

train_losses, bces, kls = train(net, optimizer, scheduler,dataloader, epoch, "POKEVAE", losses, bces, kls, MAX_EPOCHS)

# One figure per curve returned by train(); each is titled with its variable name.
for curve_name, curve in (("train_losses", train_losses), ("bces", bces), ("kls", kls)):
    plt.figure(figsize=(10,5))
    plt.plot(curve)
    plt.title(curve_name)
    plt.show()

In [None]:
# Load one specific checkpoint. A failure here should stop the cell, so the
# call is made directly: the former `try: ... except: raise` wrapper re-raised
# every error unchanged and added nothing.
net, epoch, losses, bces, kls, optimizer, scheduler = load_checkpoint("./checkpoints/VAE_epoch_000490.pth", LR)

gen_data_list()
multiSet = MultiSet('pokemon')
dataLoader = Utils.DataLoader(dataset=multiSet, shuffle=True, batch_size=BATCH_SIZE)

# Only the first (randomly shuffled) batch is plotted.
for images in dataLoader:
    clear_output(wait=True)
    multi_plot(images, net)
    break

In [None]:
# NOTE(review): this rebinds `net` to a newly constructed Net(); if sweep()
# reads the global `net`, it will run on these fresh weights rather than on a
# loaded checkpoint - confirm this is intended.
net = Net()
cuda_ok = torch.cuda.is_available()
if cuda_ok and torch.cuda.device_count() > 1:
    # More than one GPU visible: wrap the model for data-parallel execution.
    net = torch.nn.DataParallel(net)
if cuda_ok:
    net = net.cuda()

gen_data_list()
multiSet = MultiSet('pokemon')
dataLoader = Utils.DataLoader(batch_size=1, shuffle=True, dataset=multiSet)

# Run a single sweep on the first batch only.
# (argument meanings live in sweep(); presumably index 99 over -15..15 in steps of 1 - TODO confirm)
for images in dataLoader:
    sweep(images, 99, -15, 15, 1)
    break

In [None]:
# Visualize moving average of losses
def visualize_losses_moving_average(losses,window=50,boundary='valid',ylim=None):
    """Plot a loss curve together with its moving average.

    Args:
        losses: 1-D sequence of loss values.
        window: number of consecutive points averaged together.
        boundary: forwarded to ``np.convolve`` as its ``mode`` argument
            ('valid', 'same' or 'full').
        ylim: optional y-axis limits forwarded to ``plt.ylim``; when ``None``
            the automatic limits are kept.

    The moving average is skipped (only the raw curve is drawn) when there
    are fewer than ``window`` points or ``window`` is smaller than 1.
    """
    losses = np.asarray(losses, dtype=float)

    plt.figure(figsize=(10,5))
    plt.plot(losses)

    if window >= 1 and losses.size >= window:
        mav_losses = np.convolve(losses, np.ones(window)/window, boundary)
        if boundary == 'valid':
            # 'valid' drops the first window-1 positions; pad with NaN so the
            # averaged curve lines up with the raw one on the x axis. The other
            # modes already start at index 0 and must not be shifted.
            mav_losses = np.append(np.full(window-1, np.nan), mav_losses)
        plt.plot(mav_losses)

    # `ylim` used to be read from an undefined name; it is now an explicit,
    # optional argument.
    if ylim is not None:
        plt.ylim(ylim)
    plt.show()

# `train_losses` is the first value returned by the train(...) call in the
# training cell, so that cell has to be executed before this line.
visualize_losses_moving_average(train_losses)

In [None]:
#Test network accuracy
# NOTE(review): `testloader` is not defined in any visible cell and is expected
# to yield (images, labels) pairs - confirm where it comes from.

correct = 0
total = 0
device = next(net.parameters()).device  # run inputs on whatever device the model lives on
with torch.no_grad():  # evaluation only: no autograd graph needed
    for data in testloader:
        images, labels = data
        outputs, _, _ = net(Variable(images).to(device))
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        # Compare on the CPU and convert to a Python int so `correct` stays a
        # plain number instead of turning into a tensor.
        correct += (predicted.cpu() == labels.cpu()).sum().item()

if total:
    # Report the number of samples actually evaluated rather than a fixed count.
    print('Accuracy of the network on the %d test images: %d %%' % (total, 100 * correct / total))
else:
    print('No test images were evaluated; accuracy is undefined.')

In [None]:
# NOTE(review): `image` is not defined in any visible cell (an earlier cell
# binds `images`) - confirm where it comes from. Re-running this cell permutes
# `image` again, so run it once per kernel session.
image = image.permute(0,2,3,1)
for i in range(256):
    # makedirs also creates ./sweep itself and does not fail when the
    # directory already exists, unlike shelling out to `mkdir`.
    os.makedirs("./sweep/"+str(i), exist_ok=True)
    sweep(image, i, -300, 300, 5)

In [None]:
def _bare_model():
    """Return the global `net`, unwrapped from torch.nn.DataParallel if needed."""
    if isinstance(net, torch.nn.DataParallel):
        # The DataParallel wrapper does not expose encoder/latent/decoder directly.
        return net.module
    return net


def _encode_and_show(path):
    """Load the image at `path`, encode it, display its reconstruction and return the latent."""
    model = _bare_model()

    pokemon = mpimg.imread(path)
    pokemon = cv2.resize(pokemon, (RESIZE,RESIZE))
    # Min-max scale to [0, 1]. A constant image has zero range, so fall back to
    # all zeros instead of dividing by zero.
    shifted = pokemon - np.min(pokemon)
    peak = np.max(shifted)
    pokemon = shifted / peak if peak > 0 else np.zeros(shifted.shape)

    # Add a leading batch axis, then move the last axis to position 1
    # (H, W, C) -> (1, C, H, W).
    image = torch.from_numpy(pokemon).unsqueeze(0).permute(0,3,1,2)

    device = next(model.parameters()).device  # follow the model instead of assuming CUDA
    with torch.no_grad():
        z_mean, z_logvar = model.encoder(Variable(image.float()).to(device))
        z = model.latent(z_mean, z_logvar)
        ss = model.decoder(z).permute(0,2,3,1)
    plt.imshow(ss.data.cpu().squeeze().numpy())
    plt.show()
    return z


def interpolate_linear(path1, path2, savepath, STEP):
    """Save decoded frames that move linearly in latent space from image 1 towards image 2.

    Both images are loaded, resized to RESIZE x RESIZE, min-max scaled and
    encoded with the global `net`; the reconstruction of each is displayed.
    Then STEP frames are decoded at ``z1 + i * (z2 - z1) / STEP`` for
    ``i = 0 .. STEP-1`` (so the last frame stops one step short of ``z2``)
    and written to ``savepath + "<i as 3 digits>.png"``.

    Args:
        path1: path of the start image.
        path2: path of the end image.
        savepath: filename prefix for the frames; its directory part is
            created if it does not exist yet.
        STEP: number of frames to write. Nothing is written when STEP <= 0.
    """
    import os

    z1 = _encode_and_show(path1)
    z2 = _encode_and_show(path2)

    if STEP <= 0:
        # No frames requested; also avoids dividing the latent delta by zero.
        return

    out_dir = os.path.dirname(savepath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    model = _bare_model()
    step = (z2-z1)/STEP

    with torch.no_grad():
        for i in range(STEP):
            ss = model.decoder(z1 + i*step).permute(0,2,3,1)
            plt.imsave(savepath + "{0:03d}".format(i) + ".png", ss.data.cpu().squeeze().numpy())

In [None]:
# Number of interpolation frames to write between the two images.
STEP = 100
interpolate_linear("./Pokemon/squirtle.jpg", "./Pokemon/charmander.jpg", "./interpolate/squirtle_charmander/", STEP)