In [1]:
import os
import torch
#from utils import save_checkpoint, load_checkpoint, save_some_examples
import torch.nn as nn
import torch.optim as optim
#import config
#from model import Generator,Discriminator
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

%reload_ext autoreload
%autoreload 2

In [2]:
def save_some_examples(gen, val_loader, epoch, folder):
    x, y = next(iter(val_loader))
    x, y = x.to(DEVICE), y.to(DEVICE)
    gen.eval()
    with torch.no_grad():
        y_fake = gen(x)
        y_fake = y_fake * 0.5 + 0.5  # remove normalization#
        save_image(y_fake, folder + f"/y_gen_{epoch}.png")
        save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
        save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png")
    gen.train()


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


In [3]:

####################################################################################################################
################################################## DISCRIMINATOR ###################################################
####################################################################################################################
class CNN_Block(nn.Module):
    def __init__(self,in_channels,out_channels,stride=2):
        super().__init__()
        self.conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 4, stride, bias=False, padding_mode="reflect"),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2)
            )
    def forward(self, x):
        return self.conv(x)

class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features = [64,128,256,512]):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels*2, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2)
        )

        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(
                CNN_Block(in_channels, feature, stride=1 if feature==features[-1] else 2)

            )
            in_channels = feature
        layers.append(
            nn.Conv2d(
                in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            )
        )

        self.model = nn.Sequential(*layers)

    def forward(self,x,y):
        ### X = Correct Satellite Image
        ### Y = Correct/Fake Image

        x = torch.cat([x,y],dim=1)
        x = self.initial(x)
        return self.model(x)

####################################################################################################################
#################################################### GENERATOR #####################################################
####################################################################################################################

class Block(nn.Module):
    def __init__(self, in_channels, out_channels, down = True, act="relu", use_dropout=False):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU() if act=="relu" else nn.LeakyReLU(0.2),
        )
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)

    def forward(self,x):
        x = self.conv(x)
        return self.dropout(x) if self.use_dropout else x

class Generator(nn.Module):
    def __init__(self,in_channels=3,features=64):
        super().__init__()
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2)
        )  # 128 X 128

        ##############################################################################
        ################################## ENCODER ###################################
        ##############################################################################
        self.down1 = Block(features, features*2, down=True, act="leaky", use_dropout=False)    # 64 X 64
        self.down2 = Block(features*2, features*4, down=True, act="leaky", use_dropout=False)  # 32 X 32
        self.down3 = Block(features*4, features*8, down=True, act="leaky", use_dropout=False)  # 16 X 16
        self.down4 = Block(features*8, features*8, down=True, act="leaky", use_dropout=False)  # 8 X 8
        self.down5 = Block(features*8, features*8, down=True, act="leaky", use_dropout=False)  # 4 X 4
        self.down6 = Block(features*8, features*8, down=True, act="leaky", use_dropout=False)  # 2 X 2
        ##############################################################################
        ################################# BOTTLENECK #################################
        ##############################################################################
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features*8,features*8,4,2,1,padding_mode="reflect"),                      # 1 X 1
            nn.ReLU()
        )
        ##############################################################################
        ################################## DECODER ###################################
        ##############################################################################
        self.up1 = Block(features*8, features*8, down=False, act="relu", use_dropout=True)
        self.up2 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=True)
        self.up3 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=True)
        self.up4 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=False)
        self.up5 = Block(features*8*2, features*4, down=False, act="relu", use_dropout=False)
        self.up6 = Block(features*4*2, features*2, down=False, act="relu", use_dropout=False)
        self.up7 = Block(features*2*2, features, down=False, act="relu", use_dropout=False)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(features*2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )


    def forward(self,x):
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        d7 = self.down6(d6)

        bottleneck = self.bottleneck(d7)

        up1 = self.up1(bottleneck)
        up2 = self.up2(torch.cat([up1, d7], 1))
        up3 = self.up3(torch.cat([up2, d6], 1))
        up4 = self.up4(torch.cat([up3, d5], 1))
        up5 = self.up5(torch.cat([up4, d4], 1))
        up6 = self.up6(torch.cat([up5, d3], 1))
        up7 = self.up7(torch.cat([up6, d2], 1))

        return self.final_up(torch.cat([up7, d1],1))


In [4]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_DIR = "./maps/train"
VAL_DIR = "./maps/val"
SAMPLE_INTERVAL = 10
LEARNING_RATE = 2e-4
BETA1 = 0.5
BATCH_SIZE = 16
NUM_WORKERS = 2
IMAGE_SIZE = 256
CHANNELS_IMG = 3
L1_LAMBDA = 100
LAMBDA_GP = 10
NUM_EPOCHS = 200
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"
device_2 = torch.device(DEVICE)


In [5]:
import os

os.makedirs("images", exist_ok=True)
os.makedirs("data", exist_ok=True)

In [6]:
torch.backends.cudnn.benchmark = True
Gen_loss = []
Dis_loss = []

# !gsutil -m cp -r gs://sat2plan-bucket/data-1k data


In [7]:
#from google.colab import drive
#drive.mount('/content/drive')

In [8]:
dataloader = torch.utils.data.DataLoader(
    datasets.ImageFolder("data/", transform=transforms.Compose([
        # transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [9]:
def train(netG, netD, train_dl, OptimizerG, OptimizerD, L1_Loss, BCE_Loss):
    loop = tqdm(train_dl)
    for idx, (x,y) in enumerate(loop):
        ############## Train Discriminator ##############
        #with torch.cuda.amp.autocast():
        y_fake = netG(x)
        D_real = netD(x,y)
        D_real_loss = BCE_Loss(D_real, torch.ones_like(D_real))
        D_fake = netD(x,y_fake.detach())
        D_fake_loss = BCE_Loss(D_fake, torch.zeros_like(D_fake))
        D_loss = (D_real_loss + D_fake_loss)/2

        netD.zero_grad()
        Dis_loss.append(D_loss.item())
        D_loss.backward()
        #D_Scaler.scale(D_loss).backward()
        OptimizerD.step()
        #D_Scaler.step(OptimizerD)
        #D_Scaler.update()

        ############## Train Generator ##############
        #with torch.cuda.amp.autocast():
        D_fake = netD(x, y_fake)
        G_fake_loss = BCE_Loss(D_fake, torch.ones_like(D_fake))
        L1 = L1_Loss(y_fake,y) * L1_LAMBDA
        G_loss = G_fake_loss + L1

        OptimizerG.zero_grad()
        Gen_loss.append(G_loss.item())
        G_loss.backward()
        #G_Scaler.scale(G_loss).backward()
        #G_Scaler.step(OptimizerG)
        OptimizerG.step()
        #G_Scaler.update()

        if idx % 10 == 0:
            loop.set_postfix(
                D_real=torch.sigmoid(D_real).mean().item(),
                D_fake=torch.sigmoid(D_fake).mean().item(),
            )

In [10]:
def main():
    netD = Discriminator(in_channels=3).to(device_2)#.cuda()
    netG = Generator(in_channels=3).to(device_2)#.cuda()
    OptimizerD = torch.optim.Adam(netD.parameters(),lr=LEARNING_RATE,betas=(BETA1,0.999))
    OptimizerG = torch.optim.Adam(netG.parameters(),lr=LEARNING_RATE,betas=(BETA1,0.999))
    BCE_Loss = nn.BCEWithLogitsLoss()
    L1_Loss = nn.L1Loss()

    cuda = True if torch.cuda.is_available() else False
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_GEN,netG,OptimizerG,LEARNING_RATE
        )
        load_checkpoint(
            CHECKPOINT_DISC,netD,OptimizerD,LEARNING_RATE
        )

    """ train_dataset = Satellite2Map_Data(root=TRAIN_DIR)
    train_dl = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=NUM_WORKERS,pin_memory=True)
    # G_Scaler = torch.cuda.amp.GradScaler()
    # D_Scaler = torch.cuda.amp.GradScaler()
    val_dataset = Satellite2Map_Data(root=VAL_DIR)
    val_dl = DataLoader(val_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=NUM_WORKERS,pin_memory=True) """


    """ for epoch in range(NUM_EPOCHS):
        train(
            netG, netD, train_dl, OptimizerG, OptimizerD, L1_Loss, BCE_Loss
        )
        #Generator_loss.append(g_loss.item())
        #Discriminator_loss.append(d_loss.item())
        if SAVE_MODEL and epoch%50==0:
            save_checkpoint(netG, OptimizerG, filename=CHECKPOINT_GEN)
            save_checkpoint(netD, OptimizerD, filename=CHECKPOINT_DISC)
        if epoch%2==0:
            save_some_examples(netG, val_dl, epoch, folder="evaluation") """

    for epoch in range(NUM_EPOCHS):
        for i, (imgs, _) in enumerate(dataloader):

            sat = F.interpolate(imgs[:, :, :, :512],
                                size=(IMAGE_SIZE, IMAGE_SIZE)).to(device_2)#.cuda()
            plan = F.interpolate(imgs[:, :, :, 512:],
                                 size=(IMAGE_SIZE, IMAGE_SIZE)).to(device_2)#.cuda()

            # Adversarial ground truths
            valid = Variable(Tensor(imgs.shape[0], 1).fill_(
                1.0), requires_grad=False).to(device_2)#.cuda()
            fake = Variable(Tensor(imgs.shape[0], 1).fill_(
                0.0), requires_grad=False).to(device_2)#.cuda()

            # Configure input
            real_imgs = plan

            ############## Train Discriminator ##############
            #with torch.cuda.amp.autocast():
            y_fake = netG(sat)
            D_real = netD(sat, real_imgs)
            D_real_loss = BCE_Loss(D_real, torch.ones_like(D_real))
            D_fake = netD(sat,y_fake.detach())
            D_fake_loss = BCE_Loss(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss)/2

            netD.zero_grad()
            Dis_loss.append(D_loss.item())
            D_loss.backward()
            #D_Scaler.scale(D_loss).backward()
            OptimizerD.step()
            #D_Scaler.step(OptimizerD)
            #D_Scaler.update()

            ############## Train Generator ##############
            #with torch.cuda.amp.autocast():
            D_fake = netD(sat, y_fake)
            G_fake_loss = BCE_Loss(D_fake, torch.ones_like(D_fake))
            L1 = L1_Loss(y_fake,real_imgs) * L1_LAMBDA
            G_loss = G_fake_loss + L1

            OptimizerG.zero_grad()
            Gen_loss.append(G_loss.item())
            G_loss.backward()
            #G_Scaler.scale(G_loss).backward()
            #G_Scaler.step(OptimizerG)
            OptimizerG.step()
            #G_Scaler.update()


            print(
               "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch+1, NUM_EPOCHS, i+1, len(dataloader), D_loss.item(), G_loss.item())
            )

            batches_done = epoch * len(dataloader) + i
            if batches_done % SAMPLE_INTERVAL == 0:
                concatenated_images = torch.cat(
                    (y_fake[:-5], sat[:-5], real_imgs[:-5]), dim=2)

                save_image(concatenated_images, "images/%d.png" %
                       batches_done, nrow=5, normalize=True)

In [11]:
main()

  valid = Variable(Tensor(imgs.shape[0], 1).fill_(


RuntimeError: CUDA error: the launch timed out and was terminated
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
