In [None]:
%matplotlib inline
from data_loader import DataLoader
import numpy as np
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Batch size for the visual sanity check; the preview loop below indexes samples 0-2 of each batch.
batch_size = 3
# DataLoader comes from the project's data_loader module and is constructed with the batch size.
dataloader = DataLoader(batch_size)

def show(img):
    """Print the shape of ``img`` and draw it in a new matplotlib figure.

    ``img`` is passed through torchvision's ``ToPILImage`` and the resulting
    PIL image is converted to a NumPy array before being given to ``imshow``.
    """
    print(img.shape)
    as_array = np.array(transforms.ToPILImage()(img))
    plt.figure()
    plt.imshow(as_array)

# Preview two batches. For each batch every image is shown first, then every
# map, which is the same display order as the original hand-unrolled calls.
# Iterating over range(batch_size) keeps the preview in step with the
# batch size configured above instead of hard-coding indices 0-2.
for i in range(2):
    (b_img, b_map) = dataloader.get_batch()
    for batch in (b_img, b_map):
        for sample_idx in range(batch_size):
            show(batch[sample_idx])

In [None]:
from discriminator import Discriminator
from generator import Generator
import torch
import torch.nn as nn
from torch.autograd import Variable

# Training hyper-parameters.
batch_size = 10
lr = 0.0003

# Build both networks and move them to the GPU when CUDA is available.
discriminator = Discriminator()
generator = Generator()
if torch.cuda.is_available():
    discriminator.cuda()
    generator.cuda()
# Binary cross-entropy criterion used by the training cell.
loss_function = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate.
d_optim = torch.optim.Adam(discriminator.parameters(), lr=lr)
g_optim = torch.optim.Adam(generator.parameters(), lr=lr)

num_epoch = 1
# NOTE(review): DataLoader is imported in the first cell, not in this one.
dataloader = DataLoader(batch_size)
num_batch = dataloader.num_batches  # length of data / batch_size
# NOTE(review): overrides the value above with a fixed 15 — presumably a debugging cap; confirm before a full run.
num_batch = 15

In [None]:
import time
from tqdm import tqdm

def to_variable(x,requires_grad=True):
    """Wrap ``x`` in an autograd ``Variable``.

    When ``torch.cuda.is_available()`` is true the tensor is moved with
    ``.cuda()`` before wrapping. ``requires_grad`` is forwarded to ``Variable``.
    """
    on_gpu = torch.cuda.is_available()
    tensor = x.cuda() if on_gpu else x
    return Variable(tensor, requires_grad)

# Adversarial training loop.
#
# Batches alternate between a discriminator update (odd n_updates) and a
# generator update (even n_updates). d_cost_avg / g_cost_avg accumulate the
# per-step losses over the epoch (they are running sums, not averages).
counter = 0
start_time = time.time()
# Most recent mean discriminator outputs. Initialised up front so the progress
# line can always be formatted, even if one branch has not run yet.
real_score = 0.0
fake_score = 0.0
# Weight of the content (BCE against the ground-truth map) term in the generator loss.
alpha = 0.05
for current_epoch in tqdm(range(1,num_epoch+1)):
    n_updates = 1

    d_cost_avg = 0
    g_cost_avg = 0
    for idx in range(num_batch):

        (batch_img, batch_map) = dataloader.get_batch()
        batch_img = to_variable(batch_img,requires_grad=False) # [-1,3,h,w]
        batch_map = to_variable(batch_map,requires_grad=False) # [-1,1,h,w]
        real_labels = to_variable(torch.ones(batch_size),requires_grad=False)
        fake_labels = to_variable(torch.zeros(batch_size),requires_grad=False)

        if n_updates % 2 == 1:
            # Discriminator step: real (image, map) pairs should score 1 and
            # (image, generated map) pairs should score 0.
            discriminator.zero_grad()
            inp_d = torch.cat((batch_img,batch_map),1)
            # Flatten the [-1, 1] output so its shape matches the label vector.
            outputs = discriminator(inp_d).view(-1)
            d_real_loss = loss_function(outputs,real_labels)
            real_score = outputs.data.mean()

            # detach() stops this step from back-propagating into the generator.
            fake_map = generator(batch_img).detach()
            inp_d = torch.cat((batch_img,fake_map),1)
            fake_outputs = discriminator(inp_d).view(-1)
            d_fake_loss = loss_function(fake_outputs, fake_labels)

            # The previous objective, sum(log D(x)), had the wrong sign for a
            # minimiser (it drove real scores towards 0) and ignored fakes.
            d_loss = d_real_loss + d_fake_loss
            # NOTE(review): `.data[0]` is the pre-0.4 PyTorch way to read a
            # scalar loss; newer releases need `.item()` instead.
            d_cost_avg += d_loss.data[0]

            d_loss.backward()
            d_optim.step()
        else:
            # Generator step: adversarial term plus alpha-weighted content term.
            generator.zero_grad()
            fake_map = generator(batch_img)
            inp_d = torch.cat((batch_img,fake_map),1)
            outputs = discriminator(inp_d)
            fake_score = outputs.data.mean()

            g_gen_loss = loss_function(fake_map,batch_map)
            g_dis_loss = -torch.log(outputs)
            g_loss = torch.sum(g_dis_loss + alpha * g_gen_loss)

            g_cost_avg += g_loss.data[0]

            g_loss.backward()
            g_optim.step()

        n_updates += 1

        # Progress line every 5 batches.
        if (idx+1)%5 == 0:
            print("Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, D(x): %.2f, D(G(x)): %.2f, time: %4.4f"
                % (current_epoch, num_epoch, idx+1, num_batch, d_cost_avg, g_cost_avg,
                real_score, fake_score, time.time()-start_time))
        counter += 1

    # Report the accumulated losses every 3 epochs.
    # TODO: no weights or validation predictions are saved here yet.
    if current_epoch % 3 == 0:
        # Single pre-formatted string so the output is the same on Python 2 and 3.
        print("Epoch: %s  train_loss-> %s" % (current_epoch, (d_cost_avg, g_cost_avg)))


In [None]:
import cv2
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def to_variable(x,requires_grad=True):
    """Return ``x`` as a ``Variable`` with the requested ``requires_grad`` flag.

    The tensor is transferred with ``.cuda()`` beforehand whenever
    ``torch.cuda.is_available()`` returns True.
    """
    if not torch.cuda.is_available():
        return Variable(x,requires_grad)
    return Variable(x.cuda(),requires_grad)

def show(img):
    """Draw ``img`` in a fresh matplotlib figure.

    The input goes through torchvision's ``ToPILImage`` and is then converted
    to a NumPy array for ``imshow``.
    """
    to_pil = transforms.ToPILImage()
    pixels = np.array(to_pil(img))
    plt.figure()
    plt.imshow(pixels)
    
def save_gray(img, path):
    """Convert ``img`` to a PIL image, print ``path``, then save the image there."""
    image = transforms.ToPILImage()(img)
    print(path)
    image.save(path)
    
def predict(model, img, epoch, path):
    """Run ``model`` on one image and save the result as ``<path><epoch>.png``.

    ``img`` is converted with ``ToTensor`` (0-255 values become 0-1.0) and
    displayed, then given a leading batch dimension and wrapped without
    ``requires_grad`` before being passed to ``model``. The output is moved to
    the CPU, its batch dimension is dropped, and it is written by ``save_gray``.
    """
    tensor = transforms.ToTensor()(img)
    show(tensor)
    batch = to_variable(tensor.unsqueeze(0), False)
    prediction = model(batch).cpu().data.squeeze(0)
    save_gray(prediction, path + str(epoch) + ".png")
    
# Run the generator on one validation image and write the predicted map.
DIR_TO_SAVE = "./generator_output/"
# cv2.imread does not raise on a missing or unreadable file; it returns None,
# which would otherwise surface later as an obscure error inside predict().
# NOTE(review): cv2.imread yields BGR channel order — confirm this matches
# the channel order the generator was trained on.
validation_sample = cv2.imread("COCO_val2014_000000143859.png")
if validation_sample is None:
    raise IOError("Could not read validation image: COCO_val2014_000000143859.png")
predict(generator, validation_sample, current_epoch, DIR_TO_SAVE)

In [None]:
# Display a previously saved generator output.
path = "./generator_output/6.png"
# Open `path`: `new_path` exists only as a local variable inside predict(),
# so reading it here raised a NameError.
s = np.array(Image.open(path))
plt.figure()
plt.imshow(s)

In [None]:
import pickle
from generator import Generator
import torch.utils.model_zoo as model_zoo

# NOTE(review): pickle.load can execute arbitrary code; only load a vgg16.pkl
# that comes from a trusted source.
# The file handle has its own name (it used to be `f`, which was rebound to a
# URL string inside the block) and is closed as soon as the pickle is read,
# so it is not held open during the download below.
with open('vgg16.pkl', 'rb') as pkl_file:
    data = pickle.load(pkl_file)
print(len(data))
G = Generator()
#print(G)
z = 'https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg16-00b39a1b.pth'
# Alternative checkpoint URL; not used by the load below.
f = 'https://download.pytorch.org/models/vgg16-397923af.pth'
G.load_state_dict(model_zoo.load_url(z))
print(len(data))

In [1]:
from discriminator import Discriminator
from torch.autograd import Variable
import torch

# Smoke test: push one random batch through a freshly constructed Discriminator.
#D = Discriminator()
# Uniform random input of size [17, 4, 192, 256]; its max and min are printed as a range check.
x = Variable(torch.rand([17, 4, 192, 256]))
print(x.data.max())
print(x.data.min())
model = Discriminator()
print('Discriminator input', x.size()) # [-1, 4, 192, 256]: 4 channels = 3 colour channels + 1 saliency map.
out = model(x)
print('Discriminator out ', out) # [-1, 1]


0.999999642372
6.28642737865e-09
('Discriminator input', torch.Size([17, 4, 192, 256]))
('Discriminator out ', Variable containing:
 0.4530
 0.4530
 0.4531
 0.4530
 0.4530
 0.4531
 0.4530
 0.4530
 0.4530
 0.4531
 0.4529
 0.4530
 0.4530
 0.4530
 0.4530
 0.4529
 0.4530
[torch.FloatTensor of size 17x1]
)
