In [None]:
import time
import torch
from torch import nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.datasets as dset
from IPython.display import display, clear_output
from torch.autograd import Variable

In [None]:
def status(batch_size, ep, epoch, i, loss, data_loader):
    """Redraw a three-line textual progress report in the notebook output.

    Prints ``ep/epoch``, a 20-slot progress bar for batch ``i`` (0-based) out of
    ``len(data_loader)``, and ``loss / ((i + 1) * batch_size)``.  Because the
    loss is divided by the number of samples seen so far, ``loss`` is
    presumably meant to be a running total over the epoch — confirm at the
    call site.
    """
    clear_output(wait=True)  # replace the previous report instead of appending

    batches_done = i + 1
    batch_count = len(data_loader)
    filled = int(batches_done / (batch_count / 20))
    bar = '=' * filled + '>' + ' ' * (20 - filled)

    print(f"{ep}/{epoch}")
    print(f"batch: {batches_done}/{batch_count} [{bar}]")
    print('Loss: %.4g ' % (loss / (batches_done * batch_size)))
    
#-------------------------------------------------------------------
def showAllImages(x, y, z, index=1, titles=None):
    """Show one sample from each of three image batches side by side.

    The training loop calls this as ``showAllImages(input, generated, target)``.

    Args:
        x, y, z: channel-first image batches indexed ``[sample, channel, row, col]``.
            They may live on any device and may require grad.
        index: which sample to show from each batch. Defaults to 1, the original
            behaviour. It is clamped to the last available sample, so a short
            batch (e.g. a final DataLoader batch of size 1) no longer raises
            IndexError.
        titles: optional sequence of three panel titles.
    """
    if titles is None:
        titles = (None, None, None)

    fig, axes = plt.subplots(1, 3, figsize=(12, 8))
    for ax, batch, title in zip(axes, (x, y, z), titles):
        sample = batch[min(index, batch.size(0) - 1)].detach().cpu()
        if sample.is_floating_point():
            # imshow clips float RGB to [0, 1] anyway (with a warning); the
            # generator's tanh output can be negative.
            sample = sample.clamp(0, 1)
        # CHW -> HWC, the layout matplotlib's imshow expects
        ax.imshow(sample.permute(1, 2, 0).numpy())
        if title is not None:
            ax.set_title(title)
    plt.show()

#-------------------------------------------------------------------
def conv(dimIn, dimOut):
    """Double 3x3 convolution block that keeps height and width unchanged.

    Layout: Conv(dimIn->dimOut) -> BatchNorm -> LeakyReLU(0.2) ->
    Conv(dimOut->dimOut) -> BatchNorm.  The layers stay positional (indices
    0..4) so saved ``state_dict`` keys are unaffected.

    NOTE(review): there is no activation after the second BatchNorm, and the
    encoder pools its output directly — confirm this is intended.
    """
    def conv3x3(channels_in):
        # kernel 3, stride 1, padding 1 -> "same" spatial size
        return nn.Conv2d(channels_in, dimOut, kernel_size=3, stride=1, padding=1)

    layers = [
        conv3x3(dimIn),
        nn.BatchNorm2d(dimOut),
        nn.LeakyReLU(0.2, inplace=True),
        conv3x3(dimOut),
        nn.BatchNorm2d(dimOut),
    ]
    return nn.Sequential(*layers)

#-------------------------------------------------------------------
def pool():
    """2x2 max-pool with stride 2: halves height and width (floor for odd sizes)."""
    return nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

#-------------------------------------------------------------------
def invConv(dimIn, dimOut):
    """Up-sampling block: doubles height and width, maps dimIn -> dimOut channels.

    With kernel 3, stride 2, padding 1 and output_padding 1, the transposed
    convolution outputs exactly ``2 * in_size`` along each spatial axis.  It is
    followed by BatchNorm and LeakyReLU(0.2).
    """
    upsample = nn.ConvTranspose2d(
        dimIn, dimOut,
        kernel_size=3, stride=2, padding=1, output_padding=1,
    )
    norm = nn.BatchNorm2d(dimOut)
    activation = nn.LeakyReLU(0.2, inplace=True)
    return nn.Sequential(upsample, norm, activation)
    
#-------------------------------------------------------------------
def last(dimIn, dimOut):
    """Output head: a 3x3 same-padding convolution to dimOut channels, then tanh.

    NOTE(review): tanh produces values in [-1, 1], while the training targets
    built with ``transforms.ToTensor()`` lie in [0, 1].  Consider a sigmoid head
    or normalizing the images to [-1, 1].
    """
    projection = nn.Conv2d(dimIn, dimOut, kernel_size=3, stride=1, padding=1)
    return nn.Sequential(projection, nn.Tanh())

In [None]:
class PezzGen(nn.Module):
    """U-Net-style generator mapping a 3-channel image to a 3-channel image.

    Encoder: four ``conv`` stages (3 -> 6 -> 12 -> 24 -> 48 channels). Each is
    followed by a 2x2 max-pool that halves height and width.
    Bridge: ``conv`` 48 -> 96.
    Decoder: four stages. Each ``invConv`` doubles height and width and halves
    the channels. Its output is concatenated along dim 1 with the matching
    encoder output (skip connection), and a ``conv`` brings the channels back
    to the halved width.
    Head: ``last`` maps 6 -> 3 channels.

    Because of the four 2x down-samplings, input height and width must be
    divisible by 16 or the skip concatenations will not line up.

    Submodule attribute names (conv1..4, pool1..4, bridge, inv1..4, up1..4,
    last) and their registration order are part of the saved-model format.
    """

    def __init__(self):
        super().__init__()
        print("\n------Initializing PezzGen------\n")

        # Encoder: registered in the order conv1, pool1, conv2, pool2, ...
        encoder_widths = (3, 6, 12, 24, 48)
        for stage, (c_in, c_out) in enumerate(zip(encoder_widths, encoder_widths[1:]), start=1):
            setattr(self, f"conv{stage}", conv(c_in, c_out))
            setattr(self, f"pool{stage}", pool())

        self.bridge = conv(48, 96)

        # Decoder: up{k} consumes [inv{k} output, skip], hence 2 * c_out inputs.
        decoder_widths = (96, 48, 24, 12, 6)
        for stage, (c_in, c_out) in enumerate(zip(decoder_widths, decoder_widths[1:]), start=1):
            setattr(self, f"inv{stage}", invConv(c_in, c_out))
            setattr(self, f"up{stage}", conv(2 * c_out, c_out))

        self.last = last(6, 3)

    def forward(self, img):
        skips = []
        features = img
        for stage in range(1, 5):
            features = getattr(self, f"conv{stage}")(features)
            skips.append(features)  # kept pre-pool for the decoder
            features = getattr(self, f"pool{stage}")(features)

        features = self.bridge(features)

        for stage in range(1, 5):
            features = getattr(self, f"inv{stage}")(features)
            # deepest skip first: inv1 pairs with conv4, inv4 with conv1
            features = torch.cat([features, skips.pop()], dim=1)
            features = getattr(self, f"up{stage}")(features)

        return self.last(features)

net = PezzGen()
              

In [None]:
# parameters
batch_size = 4      # samples per batch; passed to the DataLoader below
img_size = 256      # height of each half-image after resizing/cropping
lr = 0.0005
epoch = 10

# Generator
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# nn.DataParallel requires the wrapped module's parameters to already live on
# device_ids[0] (cuda:0); move the model before wrapping it.
net.to(device)
generator = nn.DataParallel(net)

img_dir = "./maps/"
# Each file holds two images side by side, so crop to a (H, 2*H) strip.
# transforms.Resize replaces the deprecated/removed transforms.Scale (same semantics).
trainset = dset.ImageFolder(
    root=img_dir,
    transform=transforms.Compose([
        transforms.Resize(size=img_size),
        transforms.CenterCrop(size=(img_size, img_size * 2)),
        transforms.ToTensor(),
    ]),
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=2)

recon_loss_func = nn.MSELoss()
gen_optimizer = torch.optim.Adam(net.parameters(), lr=lr)




In [None]:
generator.train()  # BatchNorm must use batch statistics while training
for ep in range(epoch):
    # Sum of per-sample losses this epoch. status() divides by (i+1)*batch_size,
    # so it then reports the mean loss per sample seen so far (exact except for
    # a short final batch).
    running_loss = 0.0
    for i, (image, _) in enumerate(trainloader):
        # First half of the width is the satellite image, second half the map.
        satel_image, map_image = torch.chunk(image, chunks=2, dim=3)
        x = satel_image.to(device)
        target = map_image.to(device)

        gen_optimizer.zero_grad()
        y = generator(x)  # call the module, not .forward, so hooks run

        current_loss = recon_loss_func(y, target)
        current_loss.backward()
        gen_optimizer.step()

        # MSELoss returns a batch mean; scale back to a per-sample sum.
        running_loss += current_loss.item() * x.size(0)

        # status
        status(trainloader.batch_size, ep + 1, epoch, i, running_loss,
               trainloader)

        # images display: input, generated, target
        if i % 100 == 0:
            showAllImages(x, y, target)
            time.sleep(2)

In [None]:
# Pickle the whole DataParallel-wrapped model (architecture + weights).
# NOTE(review): loading this file requires the PezzGen/conv/... definitions to be
# importable and, on recent PyTorch, torch.load(..., weights_only=False).
# Saving generator.module.state_dict() would be more portable.
torch.save(obj=generator, f='./myCNN.pkl')

In [None]:
# Run the generator on one training batch and compare input / output / target.
was_training = generator.training
generator.eval()  # BatchNorm: use running statistics, not this batch's own
try:
    with torch.no_grad():
        data, _ = next(iter(trainloader))
        img, ideal = torch.chunk(data, chunks=2, dim=3)
        res = generator(img.to(device)).cpu()
finally:
    # Restore the previous mode so re-running the training cell is unaffected.
    generator.train(was_training)

panels = [("input", img[0]), ("generated", res[0]), ("target", ideal[0])]
fig, axes = plt.subplots(1, 3, figsize=(12, 8))
for ax, (title, sample) in zip(axes, panels):
    # Clamp: the tanh head can emit negatives; imshow would clip with a warning.
    # permute: CHW -> HWC for matplotlib.
    ax.imshow(sample.clamp(0, 1).permute(1, 2, 0).numpy())
    ax.set_title(title)
plt.show()