In [1]:
import os
import sys

import torchvision
from torch.backends import cudnn
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils import data
from torchvision import transforms
from torchvision import datasets

import numpy as np
from matplotlib import pyplot as plt
import PIL.Image as Image

from pandas import read_fwf, DataFrame
from tqdm   import tqdm_notebook as tqdm

In [2]:
#Load our custom libs
from utils import *
from radioreader import *
from methods import *
from kittler import kittler_float

In [3]:
lrg = read_fwf('catalog/mrt-table3.txt', skiprows=41, header=None)
labeled = DataFrame({'Name':lrg[0], 'Label':lrg[7]})

#load the images
names = labeled['Name'].tolist()
labels = labeled['Label'].tolist()
images = []
directory = 'lrg'
ext = 'fits'

for i in tqdm(range(len(names))):
    f_name = '{0}/{1}.{2}'.format(directory, 
                                  names[i].replace('.','_'), 
                                  ext)
    im = readImg(f_name, normalize=True, sz=128)
#     k = kittler_float(im, copy=False)
    images.append(im.T)

images = np.array(images)
extended_sources = [ 0 if (l == '1' or l == '1F') else 1 for l in labels]
extended_sources = np.array(np.array(extended_sources))
print('# of Extended ', extended_sources.sum() , 'of', len(extended_sources))


HBox(children=(IntProgress(value=0, max=1442), HTML(value='')))


# of Extended  1037 of 1442


In [4]:
class LRG(data.Dataset):
    def __init__(self, images, target, transform=None):
        self.data = images
        self.labels = extended_sources
        self.data_len = len(self.data)
        if(transform == None):
            self.transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.RandomRotation(180),
#                 transforms.RandomResizedCrop((128,128),(.5,1)),
                transforms.ToTensor()])
        else : self.transform = transform

    def __getitem__(self, index):
        index = index % self.data_len
        np_arr = self.data[index, :]
        y = self.labels[index]
        ## reshape np_arr to 28x28
        np_arr = np_arr.reshape(128, 128)

        ## convert to PIL-image
        img = Image.fromarray((np_arr*255).astype('uint8'))

        #apply the transformations and return tensors
        return self.transform(img), y
    def __len__(self):
        return self.data_len * 10

In [None]:
# test_noise = noise_sample(10,10,1,1,1,'cpu')
params = {}
params['num_z']      = 72
params['dis_c_dim']  =  2 # size of discrete latent code
params['num_con_c']  = 10 # size of continuous latent code

params['batch_size'] = 10
params['epochs'] = 200

params['image_size']  = 128
params['num_workers'] = 1

#optimization params
params['lrD']   = 0.0002  # Learning rate for Discriminator
params['lrG']   = 0.001   # Learning rate for Generator
params['beta1'] = 0.5   # Momentum 1 in Adam
params['beta2'] = 0.999 # Momentum 2 in Adam

######################
test_noise = noise_sample(params['dis_c_dim'], 
                          params['num_con_c'], params['num_z'],
                          params['batch_size'],'cpu')

In [None]:
dataset = LRG(images, extended_sources)
dataloader = data.DataLoader(dataset, batch_size=params['batch_size'])
dataset.__len__()

14420

In [None]:
class Generator(nn.Module):
    def __init__(self, z_dim=62, cc_dim=2, dc_dim=10):
        super(Generator, self).__init__()
        i_dim = z_dim + cc_dim + dc_dim
        self.main = nn.Sequential(
            nn.ConvTranspose2d(i_dim, 1024, 1, 1, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            nn.ConvTranspose2d(1024,  512, 8, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

    def forward(self, z):
        out = self.main(z)
        return out

class FrontEnd(nn.Module):
    def __init__(self):
        super(FrontEnd, self).__init__()
        
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(512, 1024, 8, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.1, inplace=True),
        )

    def forward(self, x):
        return self.main(x)
    
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.linear = nn.Sequential(
            nn.Conv2d(1024, 1, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        out = self.linear(x).view(-1, 1)
        return out #prob of being real

class Q(nn.Module):
    def __init__(self, cc_dim=2, dc_dim=10):
        super(Q, self).__init__()
        self.conv = nn.Conv2d(1024, 128, 1, bias=False)
        self.bn = nn.BatchNorm2d(128)
        self.lReLU = nn.LeakyReLU(0.1, inplace=True)
        
        self.conv_disc = nn.Conv2d(128, dc_dim, 1)
        self.conv_mu   = nn.Conv2d(128, cc_dim, 1)
        self.conv_var  = nn.Conv2d(128, cc_dim, 1)

    def forward(self, x):
        y = self.conv(x)
        
        disc_logits = self.conv_disc(y).squeeze()
        mu  = self.conv_mu(y).squeeze()
        var = self.conv_var(y).squeeze().exp()

        return disc_logits, mu, var

In [None]:
gen = Generator(params['num_z'], params['num_con_c'], params['dis_c_dim'])
dsc = Discriminator()
q   = Q(params['num_con_c'], params['dis_c_dim'])
fe  = FrontEnd()

# fe(gen(test_noise[0])).size()
d, _ = next(iter(dataloader))
# fe(d).size()
# q(fe(gen(test_noise[0])))[0].size()

In [None]:
device = 'cpu'
init_weights = True
if torch.cuda.is_available():
    print('using cuda')
    gen.cuda()
    dsc.cuda()
    q.cuda()
    fe.cuda()
    device = 'cuda'
if init_weights:
    gen.apply(weights_init)
    dsc.apply(weights_init)
    q.apply(weights_init)
    fe.apply(weights_init)

using cuda


In [None]:
optim_d = optim.Adam([{'params':fe.parameters()}, {'params': dsc.parameters()}], 
                 lr=params['lrD'], betas=[params['beta1'], params['beta2']])
optim_g = optim.Adam([{'params':gen.parameters()}, {'params': q.parameters()}], 
                 lr=params['lrG'], betas=[params['beta1'], params['beta2']])

In [None]:
G_losses = []
D_losses = []

BCE = nn.BCELoss().to(device) #Binary Cross Entropy loss
CE  = nn.CrossEntropyLoss().to(device)
CQ  = NormalNLLLoss()

label = torch.FloatTensor(params['batch_size'], 1).to(device)
label = Variable(label, requires_grad=False)
for epoch in range(params['epochs']):
    for n_i, batch_data in enumerate(dataloader):
        ##########################
        # Optimize Discriminator #
        ##########################
        optim_d.zero_grad()
        ####### Real Data  #######
        real_im   = batch_data[0].to(device) # batch data also contains label info
        fe_out = fe(real_im)
        real_prob = dsc(fe_out)  # propability of classifying as real
        label.data.fill_(1)
        loss_real = BCE(real_prob, label) #
        loss_real.backward()
        ####### Fake Data  #######
        noise, idx = noise_sample(params['dis_c_dim'], 
                          params['num_con_c'], params['num_z'],
                          params['batch_size'], device)
        fake_im = gen(noise)
        fe_out2 = fe(fake_im.detach())
        fake_prob = dsc(fe_out2)  # propability of classifying as real
        label.data.fill_(0)
        loss_fake = BCE(fake_prob, label)
        loss_fake.backward()
        
        discriminator_loss = loss_real + loss_fake
        optim_d.step()
        ##########################
        #   Optimize Generator   #
        ##########################
        optim_g.zero_grad()
        
        fe_out = fe(fake_im)
        fake_prob = dsc(fe_out)  # propability of classifying as real
        label.fill_(1)
        reconstruct_loss = BCE(fake_prob, label)
        
        q_logits, q_mu, q_var = q(fe_out)
        target = torch.LongTensor(idx).to(device)
        # Calculating loss for discrete latent code.
        dis_loss = CE(q_logits, target)

        # Calculating loss for continuous latent code.
        con_c = noise[:, params['num_z'] + params['dis_c_dim'] : ].view(-1, params['num_con_c'])
        con_loss = CQ(con_c, q_mu, q_var) * 0.1

        # Generator Loss (Reconstruct, Discreate and Latent code)
        generator_loss = reconstruct_loss + dis_loss + con_loss
        generator_loss.backward()
        optim_g.step()

        ##########################
        #      Logging Part      #
        ##########################
        if (n_i) % 100 == 0 and n_i > 0:
            sys.stdout.write('Epoch/Iter:{0}/{1}, D_Loss: {2}, G_Loss: {3}\r'.format(
                epoch + 1, n_i, discriminator_loss.item(), generator_loss.item())
            )
            sys.stdout.flush()
    # Save the losses for plotting.
    G_losses.append(generator_loss.item())
    D_losses.append(discriminator_loss.item())
    print('')

Epoch/Iter:1/1400, D_Loss: 0.169759601354599, G_Loss: 3.475357294082641675
Epoch/Iter:2/1400, D_Loss: 0.06105554848909378, G_Loss: 4.391475200653076
Epoch/Iter:3/1100, D_Loss: 0.05432390049099922, G_Loss: 4.252645492553711

In [None]:
idx = np.arange(10).repeat(10)
one_hot = np.zeros((100, 10))
one_hot[range(100), idx] = 1
    
fix_noise = torch.Tensor(100, 62).uniform_(-1, 1).cuda()
dis_c = torch.FloatTensor(100, 10).cuda()
dis_c.data.copy_(torch.Tensor(one_hot))

c = np.linspace(-1, 1, 10).reshape(1, -1)
c = np.repeat(c, 10, 0).reshape(-1, 1)
c1 = np.hstack([c, np.zeros_like(c)])
c2 = np.hstack([np.zeros_like(c), c])
    
con_c = torch.FloatTensor(100, 2).cuda()
con_c.data.copy_(torch.from_numpy(c1))
z = torch.cat([fix_noise, dis_c, con_c], 1).view(-1, 74, 1, 1)
x_save_c1 = gen(z)

con_c.data.copy_(torch.from_numpy(c2))
z = torch.cat([fix_noise, dis_c, con_c], 1).view(-1, 74, 1, 1)
x_save_c2 = gen(z)