In [None]:
# Code adapted from: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/ac5dcd03a40a08a8af7e1a67ade37f28cf88db43/ML/Pytorch/GANs/2.%20DCGAN/train.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as tfms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import os
import numpy as np
import math
import matplotlib.pyplot as plt

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic that maps a single-channel image to one score.

    Four stride-2 4x4 convolutions (each halving the spatial size, rounding
    down) are followed by a final 4x4 convolution to a single channel and a
    global average pool, so the output has shape (N, 1, 1, 1). No sigmoid is
    applied inside the module; the raw score is returned.

    Args:
        features: channel count of the first convolution; later stages use
            2x, 4x and 8x this value.
        device: device on which the layers' parameters are created.
    """

    def __init__(self, features, device='cpu'):
        super().__init__()

        def strided_conv(in_ch, out_ch, **extra):
            # Every convolution in the critic is 4x4 with stride 2.
            return nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, device=device, **extra)

        def norm_stage(in_ch, out_ch):
            # conv (bias dropped, BatchNorm supplies the shift) -> BatchNorm -> LeakyReLU
            return nn.Sequential(
                strided_conv(in_ch, out_ch, padding=1, bias=False),
                nn.BatchNorm2d(out_ch, device=device),
                nn.LeakyReLU(0.2),
            )

        # First stage has no normalisation.
        self.block1 = nn.Sequential(
            strided_conv(1, features, padding=1),
            nn.LeakyReLU(0.2),
        )
        self.block2 = norm_stage(features, features * 2)
        self.block3 = norm_stage(features * 2, features * 4)
        self.block4 = norm_stage(features * 4, features * 8)

        # Collapse to one channel, then average whatever spatial extent is
        # left so inputs of different sizes all yield a single value.
        self.block5 = nn.Sequential(
            strided_conv(features * 8, 1, padding=0),
            nn.AdaptiveAvgPool2d(1),
        )

    def forward(self, x):
        """Return the critic score for each image in the batch `x`."""
        stages = (self.block1, self.block2, self.block3, self.block4, self.block5)
        for stage in stages:
            x = stage(x)
        return x

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator producing single-channel images.

    `block1` is a 4x4 stride-1 transposed convolution (spatial size grows by
    3); `block2`-`block5` are 4x4 stride-2 transposed convolutions (each
    doubles the spatial size). The final activation is Tanh, so outputs lie
    in [-1, 1].

    Args:
        noise_channels: number of channels of the input noise tensor.
        features: base channel multiplier; stages use 16x, 8x, 4x and 2x
            this value before the final single-channel layer.
        device: device on which the layers' parameters are created.
    """

    def __init__(self, noise_channels, features, device='cpu'):
        super(Generator, self).__init__()

        # No normalisation on the first stage.
        self.block1 = nn.Sequential(
            nn.ConvTranspose2d(noise_channels, features * 16, kernel_size=4, stride=1, padding=0, device=device),
            nn.LeakyReLU(0.2)
        )

        self.block2 = nn.Sequential(
            nn.ConvTranspose2d(features * 16, features * 8, kernel_size=4, stride=2, padding=1, bias=False, device=device),
            nn.BatchNorm2d(features * 8, device=device),
            nn.LeakyReLU(0.2)
        )

        self.block3 = nn.Sequential(
            nn.ConvTranspose2d(features * 8, features * 4, kernel_size=4, stride=2, padding=1, bias=False, device=device),
            nn.BatchNorm2d(features * 4, device=device),
            nn.LeakyReLU(0.2)
        )

        self.block4 = nn.Sequential(
            nn.ConvTranspose2d(features * 4, features * 2, kernel_size=4, stride=2, padding=1, bias=False, device=device),
            nn.BatchNorm2d(features * 2, device=device),
            nn.LeakyReLU(0.2)
        )

        self.block5 = nn.Sequential(
            nn.ConvTranspose2d(features * 2, 1, kernel_size=4, stride=2, padding=1, device=device),
            nn.Tanh()
        )

    def forward(self, x, map_shape=None):
        """Generate images from the noise tensor `x`.

        Args:
            x: noise of shape (N, noise_channels, h, w).
            map_shape: optional (height, width) of the desired output. When
                given, the result is resized to it with adaptive average
                pooling; when omitted, the network's native output size is
                returned unchanged.

        Returns:
            Tensor of shape (N, 1, H, W) with values in [-1, 1].
        """
        y = self.block1(x)
        y = self.block2(y)
        y = self.block3(y)
        y = self.block4(y)
        y = self.block5(y)

        # Resize only on request so the generator can also be sampled
        # without the caller knowing the target map size.
        if map_shape is not None:
            y = F.adaptive_avg_pool2d(y, map_shape)
        return y

In [None]:
def initialize_weights(model):
    """Initialise `model` in place following the DCGAN scheme.

    Convolution and transposed-convolution kernels are drawn from
    N(0, 0.02). BatchNorm scale parameters are drawn from N(1, 0.02) and
    their shifts are set to zero, so a freshly initialised BatchNorm layer
    stays close to the identity instead of scaling activations by ~0.

    Args:
        model: any nn.Module; all of its submodules are visited.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            # `weight`/`bias` are None when the layer was built with affine=False.
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

In [None]:
class PathsDataset(torch.utils.data.Dataset):
    """Dataset of paths rasterised into 2-D occupancy matrices.

    Every file in `path` is read with `np.loadtxt` and interpreted as a
    sequence of (x, y) waypoints. Each path is drawn into a matrix of size
    `shape` (rows x columns) where cells on the path are 1 and all others 0.

    Implements the `__len__` / `__getitem__` protocol expected of a
    map-style `torch.utils.data.Dataset`.

    Args:
        path: directory containing one path file per sample.
        transform: optional callable applied to each float32 matrix in
            `__getitem__`.
        shape: (rows, columns) of the rasterised matrix.
        device: device that transformed tensors are moved to. Defaults to
            CUDA when available, otherwise CPU.
    """

    # Scale factor from path-file coordinates to matrix cells.
    RATIO = 400/128

    def __init__(self, path, transform=None, shape = (100,100), device=None):
        self.transform = transform
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self.paths = []  # one rasterised matrix per file
        # Sorted so that sample indices are stable across runs and machines.
        for filename in sorted(os.listdir(path)):
            with open(os.path.join(path, filename), 'r') as f:
                flat_path = np.loadtxt(f)
            # Accepts both a flat x0 y0 x1 y1 ... list and an N x 2 table.
            waypoints = np.asarray(flat_path, dtype=int).reshape(-1, 2)
            self.paths.append(self.convert_path(shape, waypoints))
        print("Done!")

    def convert_path(self, map_dim, path, ratio=None):
        """Rasterise waypoints into a matrix of shape `map_dim`.

        Waypoints are scaled down by `ratio` and consecutive points are
        connected by sweeping once along x and once along y, so steep and
        shallow segments are both drawn without gaps.

        Args:
            map_dim: (rows, columns) of the output matrix.
            path: integer array of shape (N, 2) holding (x, y) waypoints.
            ratio: divisor applied to the coordinates; defaults to `RATIO`.

        Returns:
            Float array of shape `map_dim` with 1 on path cells, 0 elsewhere.

        Raises:
            IndexError: if a scaled coordinate falls outside `map_dim`.
        """
        if ratio is None:
            ratio = self.RATIO

        path_mat = np.zeros(map_dim, dtype=float)
        if path.shape[0] == 0:
            return path_mat

        # Scale every waypoint once; all drawing below uses cell coordinates.
        cells = [(math.floor(px / ratio), math.floor(py / ratio)) for px, py in path]

        for (x1, y1), (x2, y2) in zip(cells, cells[1:]):
            x_dir = 1 if x1 < x2 else -1
            y_dir = 1 if y1 < y2 else -1

            # Sweep along x, deriving y from the segment's slope.
            if x2 != x1:
                m = (y2 - y1) / (x2 - x1)
                for x in range(x1, x2, x_dir):
                    path_mat[round(m * (x - x1) + y1), x] = 1

            # Sweep along y, deriving x from the inverse slope.
            if y2 != y1:
                m = (x2 - x1) / (y2 - y1)
                for y in range(y1, y2, y_dir):
                    path_mat[y, round(m * (y - y1) + x1)] = 1

        # The sweeps stop one cell short of each segment's end point, so the
        # final waypoint is added explicitly, in the same scaled coordinates.
        last_x, last_y = cells[-1]
        path_mat[last_y, last_x] = 1

        return path_mat

    def __len__(self):
        """Return the number of loaded paths."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return sample `idx` as float32, transformed if a transform is set."""
        x = np.float32(self.paths[idx])

        if self.transform:
            x = self.transform(x)
            # Respect the configured device instead of assuming CUDA exists.
            if isinstance(x, torch.Tensor):
                x = x.to(self.device)

        return x

In [None]:
def noise_dims(map_dims):
    """Return the spatial noise size that makes the Generator emit `map_dims`.

    Each transposed convolution of the generator is inverted in turn, from
    the output layer back to the input, using
        in = (out + 2*padding - kernel_size) // stride + 1
    Blocks 5..2 use kernel 4, stride 2, padding 1; block 1 uses kernel 4,
    stride 1, padding 0.

    When a map size cannot be reached exactly the result is rounded down, so
    the generator's native output may be smaller than requested and has to be
    resized afterwards. Sizes too small for the network are clamped to 1.

    Args:
        map_dims: iterable of target sizes, one per spatial dimension.

    Returns:
        Tuple with the noise size for each entry of `map_dims`.
    """
    noise_shape = []
    for dim in map_dims:
        # Blocks 5, 4, 3, 2: kernel 4, stride 2, padding 1
        for _ in range(4):
            dim = ((dim + (2*1) - 4) // 2) + 1

        # Block 1: kernel 4, stride 1, padding 0
        dim = ((dim + (2*0) - 4) // 1) + 1

        # A noise tensor needs at least one cell per dimension.
        noise_shape.append(max(1, dim))

    return tuple(noise_shape)

In [None]:
# Hyperparameters and run configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the random number generators so weight initialisation, batch shuffling
# and noise sampling are repeatable from a fresh kernel.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# NOTE(review): the WGAN paper trains with a learning rate of 5e-5; confirm 5e-4 is intended.
LEARNING_RATE = 5e-4
BATCH_SIZE = 64
CHANNELS_IMG = 1
NOISE_DIM = 100        # channels of the generator's input noise
NUM_EPOCHS = 20
FEATURES_DISC = 64
FEATURES_GEN = 64

MAP_SHAPE = (128,128)  # spatial size of generated maps
noise_shape = noise_dims(MAP_SHAPE)

# Specific to WGAN
WEIGHT_CLIP = 0.01 # C param from WGAN paper
CRITIC_ITERATIONS = 5 # how many times the critic loop runs for each generator loop

In [None]:
# Preprocessing applied to every sample: convert to a tensor, then apply
# (x - 0.5) / 0.5 to each of the CHANNELS_IMG channels.
transforms = tfms.Compose([
    tfms.ToTensor(),
    tfms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
])


In [None]:
# Training data: path files rasterised into matrices by PathsDataset.
# NOTE(review): shape=(400, 400) differs from MAP_SHAPE, the size of the
# generated samples — confirm this is intended.
dataset = PathsDataset(
    path="./env/map_20x20/paths/",
    transform=transforms,
    shape=(400, 400),
)

# drop_last keeps every batch at exactly BATCH_SIZE samples.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

# Build both networks on the selected device, then re-initialise their weights.
gen = Generator(NOISE_DIM, FEATURES_GEN, device=device)
critic = Discriminator(FEATURES_DISC, device=device)
initialize_weights(gen)
initialize_weights(critic)


In [None]:
# One RMSprop optimizer per network, sharing the same learning rate.
gen_params = gen.parameters()
opt_gen = optim.RMSprop(gen_params, lr=LEARNING_RATE)
critic_params = critic.parameters()
opt_critic = optim.RMSprop(critic_params, lr=LEARNING_RATE)

In [None]:
# A fixed noise batch, reused at every logging step so the logged fake
# images are comparable over the course of training.
fixed_noise = torch.randn(
    (BATCH_SIZE, NOISE_DIM, noise_shape[0], noise_shape[1]),
    device=device,
)

# Separate TensorBoard writers for real and generated image grids.
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")

# Global step counter for the TensorBoard image logs.
step = 0

In [None]:
# Switch both networks to training mode before the training loop.
gen.train(mode=True)
critic.train(mode=True)

In [None]:
for epoch in range(NUM_EPOCHS):
    # Unsupervised: the dataloader yields samples only, no target labels.
    for batch_idx, real in enumerate(dataloader):
        # Make sure the real batch lives on the same device as the networks
        # (a no-op when it is already there).
        real = real.to(device)

        ### Training critic: max E[critic(real)] - E[critic(fake)]
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn((BATCH_SIZE, NOISE_DIM, noise_shape[0], noise_shape[1]), device=device)
            fake = gen(noise, MAP_SHAPE)
            critic_real = critic(real)
            critic_fake = critic(fake)
            # The paper maximises this quantity; optimizers minimise, so negate it.
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))
            critic.zero_grad()
            # Keep the graph: `fake` from the last iteration is reused for the generator step.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

            # Clip the critic's weights to [-WEIGHT_CLIP, WEIGHT_CLIP] after each update.
            for p in critic.parameters():
                p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        ### Training generator: max E[critic(fake)], i.e. min -E[critic(fake)]
        output = critic(fake)
        loss_gen = -torch.mean(output)
        gen.zero_grad()  # discard generator gradients accumulated during the critic updates
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and log image grids to tensorboard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise, MAP_SHAPE)
                # take out (up to) BATCH_SIZE examples
                img_grid_real = torchvision.utils.make_grid(
                    real[:BATCH_SIZE], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:BATCH_SIZE], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

In [None]:
# Sample a fresh batch from the generator for visual inspection.
noise = torch.randn((BATCH_SIZE, NOISE_DIM, noise_shape[0], noise_shape[1]), device=device)
# Pass MAP_SHAPE so the samples have the same size as during training;
# gradients are not needed for plotting.
with torch.no_grad():
    fake = gen(noise, MAP_SHAPE)

In [None]:
# Display channel 0 of sample 7 from the generated batch.
plt.imshow(fake.detach().cpu().numpy()[7, 0])