In [14]:
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import math
import glob

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch
from google.colab import auth

#auth.authenticate_user()

#!curl https://sdk.cloud.google.com | bash
#!gcloud init

!gsutil cp -r gs://sat2plan-bucket/data-10k data




[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Copying gs://sat2plan-bucket/data-10k/005189_Chicago_41.90551_-88.34870.png...
Copying gs://sat2plan-bucket/data-10k/005190_Austin_30.25950_-97.66191.png...
Copying gs://sat2plan-bucket/data-10k/005191_San_Diego_32.77769_-117.25764.png...
Copying gs://sat2plan-bucket/data-10k/005192_Riga_56.85899_24.09087.png...
Copying gs://sat2plan-bucket/data-10k/005193_Jacksonville_30.43615_-81.44320.png...
Copying gs://sat2plan-bucket/data-10k/005194_San_Diego_32.74516_-117.03384.png...
Copying gs://sat2plan-bucket/data-10k/005195_Moscow_55.87324_37.54155.png...
Copying gs://sat2plan-bucket/data-10k/005196_Fort_Worth_32.71852_-97.48200.png...
Copying gs://sat2plan-bucket/data-10k/005197_Chicago_41.70420_-87.75552.png...
Copying gs://sat2plan-bucket/data-10k/005198_San_Antonio_29.40938_-98.40285.png...
Copying gs://sat2plan-bucket/data-10k/005199_Austin_30.37953_-97.62619.png...
Copying gs://sat2plan-bucket/data-10k/005200_London_51.5

In [59]:

n_epochs = 200
batch_size = 32
lr = 0.0002
b1 = 0.5
b2 = 0.999
n_cpu = 6
latent_dim = 100
img_size = 128
channels = 3
sample_interval = 10
from_scratch = True

cuda = True if torch.cuda.is_available() else False
from google.colab import drive
drive.mount('/content/drive')
os.makedirs("/content/drive/MyDrive/sat2plan/images", exist_ok=True)


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02).cuda()
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02).cuda()
        torch.nn.init.constant_(m.bias.data, 0.0).cuda()


class Generator(nn.Module):
    def __init__(self, feature=64):
        super(Generator, self).__init__()

        self.init_size = img_size // 4  # Initial size before upsampling
        self.downscale = nn.Sequential(
            nn.Conv2d(channels, feature, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(feature),
            nn.ReLU(True),
            nn.Conv2d(feature, feature * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(feature * 2),
            nn.ReLU(True),
            nn.Conv2d(feature * 2, feature * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(feature * 4),
            nn.ReLU(True),
            nn.Conv2d(feature * 4, feature * 8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(feature * 8),
            nn.ReLU(True),
            #nn.Flatten(feature * 8, latent_dim)
          )

        self.upscale = nn.Sequential(

            nn.ConvTranspose2d(feature * 8, feature * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(feature * 4),
            nn.ReLU(True),
            # state size. ``(ngf*4) x 8 x 8``
            nn.ConvTranspose2d( feature * 4, feature * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(feature * 2),
            nn.ReLU(True),
            # state size. ``(ngf*2) x 16 x 16``
            nn.ConvTranspose2d( feature * 2, feature, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(feature),
            nn.ReLU(True),
            # state size. ``(ngf) x 32 x 32``
            nn.ConvTranspose2d( feature, channels,kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, img):
        #print(z.shape)
        out = self.downscale(img)
        #print(out.shape)
        #out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.upscale(out)
        # print("IMG", img.shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(
                0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = img_size // 2 ** 4
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        #print(img.shape)
        out = self.model(img)
        #print(out.shape)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity


# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
if not from_scratch:

    generator_files = glob.glob('/content/drive/MyDrive/sat2plan/models_checkpoint/generator.pth')
    discriminator_files = glob.glob('/content/drive/MyDrive/sat2plan/models_checkpoint/discriminator.pth')

    generator_files.sort(key=os.path.getmtime, reverse=True)
    discriminator_files.sort(key=os.path.getmtime, reverse=True)

    if generator_files:
        print(generator_files[0])
        generator.load_state_dict(torch.load(generator_files[0]))
    if discriminator_files:
        discriminator.load_state_dict(torch.load(discriminator_files[0]))

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# mount



# Configure data loader
dataloader = torch.utils.data.DataLoader(
    datasets.ImageFolder("data/", transform=transforms.Compose([
        # transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])),
    batch_size=batch_size,
    shuffle=True,

)

# Optimizers
optimizer_G = torch.optim.Adam(
    generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(
    discriminator.parameters(), lr=lr, betas=(b1, b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:

# ----------
#  Training
# ----------


for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        sat = F.interpolate(imgs[:, :, :, :512],
                            size=(img_size, img_size)).cuda()
        plan = F.interpolate(imgs[:, :, :, 512:],
                             size=(img_size, img_size)).cuda()

        # plt.imshow(plan[0].permute(1, 2, 0))
        # plt.show()
        # plt.imshow(sat[0].permute(1, 2, 0))
        # plt.show()

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.shape[0], 1).fill_(
            1.0), requires_grad=False).cuda()
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(
            0.0), requires_grad=False).cuda()

        # Configure input
        real_imgs = plan

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Generate a batch of images
        gen_imgs = generator(sat)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        # Sauvegarder l'image
        # save_image(concatenated_images,
        #            f'gen_images/concatenated_image{epoch}-{i} .jpg')

        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            concatenated_images = torch.cat(
                (gen_imgs[:-5], sat[:-5], real_imgs[:-5]), dim=2)

            save_image(concatenated_images, "/content/drive/MyDrive/sat2plan/images/%d.png" %
                       batches_done, nrow=5, normalize=True)
    torch.save(generator.state_dict(),
               f'/content/drive/MyDrive/sat2plan/models_checkpoint/generator.pth')
    torch.save(discriminator.state_dict(),
               f'/content/drive/MyDrive/sat2plan/models_checkpoint/discriminator.pth')

[Epoch 0/200] [Batch 0/157] [D loss: 0.693198] [G loss: 0.693498]
[Epoch 0/200] [Batch 1/157] [D loss: 0.693219] [G loss: 0.692759]
[Epoch 0/200] [Batch 2/157] [D loss: 0.693241] [G loss: 0.692893]
[Epoch 0/200] [Batch 3/157] [D loss: 0.693274] [G loss: 0.693209]
[Epoch 0/200] [Batch 4/157] [D loss: 0.693242] [G loss: 0.693120]
[Epoch 0/200] [Batch 5/157] [D loss: 0.693284] [G loss: 0.693086]
[Epoch 0/200] [Batch 6/157] [D loss: 0.693255] [G loss: 0.693027]
[Epoch 0/200] [Batch 7/157] [D loss: 0.693150] [G loss: 0.692988]
[Epoch 0/200] [Batch 8/157] [D loss: 0.693119] [G loss: 0.693168]
[Epoch 0/200] [Batch 9/157] [D loss: 0.693093] [G loss: 0.693120]
[Epoch 0/200] [Batch 10/157] [D loss: 0.693083] [G loss: 0.693221]
[Epoch 0/200] [Batch 11/157] [D loss: 0.692904] [G loss: 0.693230]
[Epoch 0/200] [Batch 12/157] [D loss: 0.693084] [G loss: 0.693133]
[Epoch 0/200] [Batch 13/157] [D loss: 0.692848] [G loss: 0.693209]
[Epoch 0/200] [Batch 14/157] [D loss: 0.692951] [G loss: 0.693196]
[Epoc