In [None]:
# Code adapted from: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/ac5dcd03a40a08a8af7e1a67ade37f28cf88db43/ML/Pytorch/GANs/2.%20DCGAN/train.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as tfms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import os, math
import numpy as np

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic mapping a 3-channel image batch to one score per sample.

    The network is five stride-2, 4x4 convolution stages. Stages 2-4 use a
    bias-free convolution followed by affine instance normalisation and
    LeakyReLU(0.2). The last stage reduces to a single channel and is followed
    by a global average pool, so the output has shape ``(N, 1, 1, 1)``. No
    output activation is applied.

    Args:
        features: Channel width of the first convolution; the following
            stages use 2x, 4x and 8x this width.
        device: Device on which the layer parameters are created.
    """

    def __init__(self, features, device='cpu'):
        super().__init__()

        w1, w2, w4, w8 = (features * mult for mult in (1, 2, 4, 8))

        # Stage 1 keeps the convolution bias and has no normalisation layer.
        self.block1 = nn.Sequential(
            nn.Conv2d(3, w1, kernel_size=4, stride=2, padding=1, device=device),
            nn.LeakyReLU(0.2),
        )

        # Stages 2-4 share one layout; only the channel widths differ.
        self.block2 = self._down_block(w1, w2, device)
        self.block3 = self._down_block(w2, w4, device)
        self.block4 = self._down_block(w4, w8, device)

        # Collapse to a single channel, then average whatever spatial extent
        # is left into one value per sample.
        self.block5 = nn.Sequential(
            nn.Conv2d(w8, 1, kernel_size=4, stride=2, padding=0, device=device),
            nn.AdaptiveAvgPool2d(1),
        )

    @staticmethod
    def _down_block(in_channels, out_channels, device):
        """Build one conv -> instance-norm -> LeakyReLU downsampling stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False, device=device),
            nn.InstanceNorm2d(out_channels, affine=True, device=device),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Run ``x`` through the five stages in order and return the scores."""
        out = x
        for stage in (self.block1, self.block2, self.block3, self.block4, self.block5):
            out = stage(out)
        return out

In [None]:
# The generator input has 3 channels; the training loop stacks them in the
# order (noise, initial path, map).
class Generator(nn.Module):
    """Fully convolutional generator producing a single-channel binary image.

    Two 3x3 convolutions with padding 2 each grow the spatial size by 2, and
    two 3x3 transposed convolutions with padding 2 each shrink it by 2, so the
    output has the same height and width as the input. The final sigmoid output
    is hard-rounded to {0, 1} with a straight-through gradient (see
    ``_round``).

    Args:
        features: Channel width of the first and third blocks; the second
            block uses twice this width.
        device: Device on which the layer parameters are created.
    """

    def __init__(self, features, device='cpu'):
        super().__init__()

        self.block1 = nn.Sequential(
            nn.Conv2d(3, features, 3, 1, 2, device=device),
            nn.ReLU()
        )

        self.block2 = nn.Sequential(
            nn.Conv2d(features, features*2, 3, 1, 2, device=device),
            nn.InstanceNorm2d(features*2, affine=True, device=device),
            nn.ReLU()
        )

        # NOTE(review): this block uses BatchNorm2d while block2 uses
        # InstanceNorm2d. Left unchanged because swapping the layer would
        # change the state_dict layout of existing checkpoints.
        self.block3 = nn.Sequential(
            nn.ConvTranspose2d(features*2, features, 3, 1, 2, device=device),
            nn.BatchNorm2d(features, device=device),
            nn.ReLU()
        )

        self.block4 = nn.Sequential(
            nn.ConvTranspose2d(features, 1, 3, 1, 2, device=device),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Map an ``(N, 3, H, W)`` input to an ``(N, 1, H, W)`` image of 0s and 1s."""
        y = self.block1(x)
        y = self.block2(y)
        y = self.block3(y)
        y = self.block4(y)

        return self._round(y)

    def _round(self, mat):
        """Round ``mat`` to the nearest integer with a straight-through gradient.

        The forward value equals ``torch.round(mat)``; in the backward pass the
        rounding is treated as the identity, so gradients reach ``mat``
        unchanged.
        """
        # TODO: cite something? (this function is based off of Thor's code)
        mat_hard = torch.round(mat)
        # .detach() (rather than the legacy .data) removes the soft value from
        # the graph while still letting autograd track the tensor safely.
        return (mat_hard - mat.detach()) + mat


In [None]:
def initialize_weights(model):
    """Initialise ``model`` in place following the DCGAN reference scheme.

    * ``Conv2d`` / ``ConvTranspose2d`` weights are drawn from N(0, 0.02).
    * ``BatchNorm2d`` scale is drawn from N(1, 0.02) and its shift is zeroed.
      A zero-centred scale would start the layer with almost no gain.

    Norm layers created with ``affine=False`` have no parameters and are
    skipped. Convolution biases and other layer types are left as they are.

    Args:
        model: Any ``nn.Module``; all of its sub-modules are visited.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

In [None]:
def gradient_penalty(critic, real, fake, device="cpu"):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolations.

    For each sample a single mixing coefficient epsilon ~ U[0, 1) is drawn and
    the critic is evaluated at ``epsilon * real + (1 - epsilon) * fake``. The
    penalty is the batch mean of ``(||grad||_2 - 1) ** 2``, where the gradient
    is taken with respect to the interpolated input.

    Args:
        critic: Callable mapping a batch to per-sample scores.
        real: Batch of real samples; the first dimension is the batch.
        fake: Batch of generated samples with the same shape as ``real``.
        device: Unused; kept so existing calls keep working. The mixing
            coefficients are created on ``real.device``.

    Returns:
        A scalar tensor. It is built with ``create_graph=True`` so it can be
        back-propagated as part of the critic loss.

    Raises:
        ValueError: If ``real`` and ``fake`` differ in shape.
    """
    if real.shape != fake.shape:
        raise ValueError(
            f"real and fake must have the same shape, got {tuple(real.shape)} and {tuple(fake.shape)}"
        )

    # One coefficient per sample, broadcast over all remaining dimensions.
    eps_shape = (real.shape[0],) + (1,) * (real.dim() - 1)
    epsilon = torch.rand(eps_shape, device=real.device)
    interpolated_images = real * epsilon + fake * (1 - epsilon)

    # If neither input carries a graph (e.g. fake was detached), the mix is a
    # leaf tensor and must be flagged so autograd can differentiate w.r.t. it.
    if not interpolated_images.requires_grad:
        interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Gradient of the scores w.r.t. the interpolated images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (not view) so a non-contiguous gradient is handled too
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)


In [None]:
# A map-style Dataset has to provide __init__, __len__ and __getitem__.
class PathsDataset(torch.utils.data.Dataset):
    """Dataset of rasterised paths, each paired with a start/goal heat-map.

    Every regular file in ``path_dir`` is read with ``np.loadtxt`` and
    interpreted as a sequence of integer ``(x, y)`` points. Each path is
    converted once, at construction time, into a ``(2, rows, cols)`` array
    (see ``convert_path``) and kept in ``self.paths``.

    Args:
        path_dir: Directory containing one path file per sample.
        map_file: Text file with the map grid; its first two lines are
            skipped and the rest is reshaped to ``shape``.
        transform: Optional callable applied to each tensor in
            ``__getitem__``.
        shape: ``(rows, cols)`` of the map and of every rasterised path.
        device: Device that ``__getitem__`` moves each sample to.
    """

    def __init__(self, path_dir, map_file, transform=None, shape = (100,100), device='cpu'):
        self.device = device
        self.transform = transform
        self.paths = []  # one (2, rows, cols) array per path file

        # Stored with a leading channel axis; not used by the other methods of this class.
        self.map = np.loadtxt(map_file, skiprows=2).reshape(shape)
        self.map = self.map[np.newaxis, :, :]

        # sorted() makes the sample order reproducible across runs/platforms.
        for filename in sorted(os.listdir(path_dir)):
            full_name = os.path.join(path_dir, filename)
            # Skip sub-directories and hidden entries (e.g. .ipynb_checkpoints).
            if filename.startswith('.') or not os.path.isfile(full_name):
                continue

            with open(full_name, 'r') as f:
                # flat_path / path / path_matrix are kept as attributes for
                # backward compatibility; they only hold the last file read.
                self.flat_path = np.loadtxt(f)

            # reshape(-1, 2) accepts both a flat "x y x y ..." file and a
            # two-column file.
            self.path = np.asarray(self.flat_path, dtype=int).reshape(-1, 2)
            self.path_matrix = self.convert_path(shape, self.path)
            self.paths.append(self.path_matrix)

        print("Done!")

    @staticmethod
    def _draw_segment(grid, x1, y1, x2, y2):
        """Mark the cells of the segment (x1, y1) -> (x2, y2), end point excluded.

        The segment is sampled twice, once per integer x and once per integer
        y, so neither shallow nor steep segments leave gaps.
        """
        if x2 != x1:
            m = (y2 - y1) / (x2 - x1)
            for x in range(x1, x2, 1 if x1 < x2 else -1):
                grid[round(m * (x - x1) + y1), x] = 1

        if y2 != y1:
            m = (x2 - x1) / (y2 - y1)
            for y in range(y1, y2, 1 if y1 < y2 else -1):
                grid[y, round(m * (y - y1) + x1)] = 1

    def convert_path(self, map_dim, path, slope=-0.05):
        """Rasterise ``path`` and build its start/goal heat-map.

        Args:
            map_dim: ``(rows, cols)`` of the output grids.
            path: Integer array of shape ``(N, 2)`` holding ``(x, y)`` points,
                where ``x`` indexes columns and ``y`` indexes rows.
            slope: Change in heat-map value per unit of distance (negative
                values decay away from the end points).

        Returns:
            Float array of shape ``(2, rows, cols)``. Channel 0 is 1 on every
            cell touched by the path and 0 elsewhere. Channel 1 is
            ``max(slope * d + 1, 0)``, where ``d`` is the Euclidean distance to
            the nearer of the first and last path points.

        Raises:
            ValueError: If ``path`` contains no points.
        """
        path = np.asarray(path, dtype=int).reshape(-1, 2)
        if path.shape[0] == 0:
            raise ValueError("path must contain at least one (x, y) point")

        path_mat = np.zeros(map_dim, dtype=float)

        # Make the path continuous by filling in every consecutive segment.
        points = [(int(px), int(py)) for px, py in path]
        for (x1, y1), (x2, y2) in zip(points, points[1:]):
            self._draw_segment(path_mat, x1, y1, x2, y2)

        x_start, y_start = points[0]
        x_goal, y_goal = points[-1]
        path_mat[y_goal, x_goal] = 1     # segments exclude their end point

        # Heat-map around the start and goal. Row indices are y and column
        # indices are x, which also holds for non-square maps.
        rows, cols = path_mat.shape
        ys, xs = np.indices((rows, cols))
        dis_start = np.sqrt((xs - x_start) ** 2 + (ys - y_start) ** 2)
        dis_goal = np.sqrt((xs - x_goal) ** 2 + (ys - y_goal) ** 2)
        dis = np.minimum(dis_start, dis_goal)
        initial_path = np.maximum(slope * dis + 1, 0.0)

        return np.stack((path_mat, initial_path))

    def __len__(self):
        """Number of paths loaded from ``path_dir``."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a float32 tensor on ``self.device``."""
        x = torch.as_tensor(self.paths[idx], dtype=torch.float32, device=self.device)

        if self.transform:
            x = self.transform(x)

        return x

In [None]:
# Inputs
MAP_NAME = 'map_64x64'
MAP_SHAPE = (64,64)
# Alternative (non-square) map:
# MAP_NAME = '8x12_map'
# MAP_SHAPE = (163,243)

# Checkpointing
LOAD = False    # load generator/critic checkpoints instead of initialising fresh weights
SAVE = True     # write a checkpoint for both networks after every epoch
GEN_PATH = './checkpoints/wgan_gp/generator/'
DISC_PATH = './checkpoints/wgan_gp/critic/'
LOAD_EPOCH = 0  # The epoch checkpoint to load

# Reproducibility: seed torch's RNGs before any model, noise or shuffling is created.
SEED = 42
torch.manual_seed(SEED)

# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 1e-4
BATCH_SIZE = 50
IMAGE_SIZE = 64     # only referenced by commented-out code in the visible cells
CHANNELS_IMG = 1    # only referenced by commented-out code in the visible cells
NUM_EPOCHS = 2
FEATURES_DISC = 64
FEATURES_GEN = 64

# One single-channel noise image per sample, matching the map size.
NOISE_SHAPE = (BATCH_SIZE, 1, MAP_SHAPE[0], MAP_SHAPE[1])

# Specific to WGAN
CRITIC_ITERATIONS = 5 # how many times the critic loop runs for each generator loop
LAMBDA_GP = 10        # weight of the gradient penalty in the critic loss

In [None]:
# Identity pipeline: no transform steps are enabled, so samples pass through
# unchanged. The previously listed ToTensor / Normalize(0.5, 0.5) steps are
# disabled.
transforms = tfms.Compose([])


In [None]:
# Map grid as a (BATCH_SIZE, 1, rows, cols) float tensor: the same map is
# tiled once per sample so it can be concatenated channel-wise with a batch.
# NOTE(review): `map` shadows the Python builtin; the name is kept because the
# training cell refers to it.
map = np.loadtxt(f"./data/{MAP_NAME}/{MAP_NAME}.txt", skiprows=2).reshape((1, 1) + tuple(MAP_SHAPE))
map = torch.Tensor(np.tile(map, (BATCH_SIZE, 1, 1, 1))).to(device)

# drop_last=True keeps every batch at exactly BATCH_SIZE samples, which the
# pre-tiled map tensor above requires.
dataset = PathsDataset(
    path_dir=f"./data/{MAP_NAME}/paths_variety/",
    map_file=f"./data/{MAP_NAME}/{MAP_NAME}.txt",
    shape=MAP_SHAPE,
    transform=transforms,
    device=device,
)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# train_dataset = PathsDataset(path_dir = f"./env/{MAP_NAME}/subsets/training", map_file = f"./env/{MAP_NAME}/{MAP_NAME}.txt", shape = MAP_SHAPE, transform=transforms, device=device)
# test_dataset = PathsDataset(path_dir = f"./env/{MAP_NAME}/subsets/testing", map_file = f"./env/{MAP_NAME}/{MAP_NAME}.txt", shape = MAP_SHAPE, transform=transforms, device=device)

# dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# dataloader_test = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [None]:
# Number of epochs completed so far; overwritten below when resuming.
curr_epoch = 0

gen = Generator(FEATURES_GEN, device=device)
critic = Discriminator(FEATURES_DISC, device=device)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas = (0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas = (0.0, 0.9))

if LOAD:
    # map_location remaps tensors saved on another device (e.g. a GPU
    # checkpoint opened on a CPU-only machine) onto the current one.
    # Load gen
    checkpoint = torch.load(f'{GEN_PATH}epoch-{LOAD_EPOCH}.tar', map_location=device)
    gen.load_state_dict(checkpoint['model_state_dict'])
    opt_gen.load_state_dict(checkpoint['optimizer_state_dict'])
    loss_gen = checkpoint['loss']

    # Load critic
    checkpoint = torch.load(f'{DISC_PATH}epoch-{LOAD_EPOCH}.tar', map_location=device)
    critic.load_state_dict(checkpoint['model_state_dict'])
    opt_critic.load_state_dict(checkpoint['optimizer_state_dict'])
    loss_critic = checkpoint['loss']

    # Resume the epoch counter from the critic checkpoint (the one loaded last).
    curr_epoch = checkpoint['epoch']
else:
    initialize_weights(gen)
    initialize_weights(critic)

In [None]:
# Noise reused for every logged preview, so that the noise channel of the
# generator input is the same in each TensorBoard image.
fixed_noise = torch.randn(NOISE_SHAPE, device=device)

# One TensorBoard writer per image stream.
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
writer_overlay = SummaryWriter("logs/fake_overlay")

# Global step used when logging images; incremented after each logging pass.
step = 0

In [None]:
# Put both networks into training mode before the training loop runs.
gen.train()
critic.train()

In [None]:
# WGAN-GP training loop: for every batch the critic is updated
# CRITIC_ITERATIONS times, then the generator once.

# torch.save does not create missing directories.
if SAVE:
    os.makedirs(GEN_PATH, exist_ok=True)
    os.makedirs(DISC_PATH, exist_ok=True)

for epoch in range(NUM_EPOCHS):
    curr_epoch += 1
    for batch_idx, real in enumerate(dataloader):
        # real: (path raster, start/goal heat-map) from the dataset, plus the map channel.
        real = real.to(device)
        real = torch.cat((real, map), dim=1)

        initial_path = real[:, 1:2, :, :]
        fixed_input = torch.cat((fixed_noise, initial_path, map), dim=1)

        ### Training critic: max E[critic(real)] - E[critic(fake)] (minus the penalty)
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(NOISE_SHAPE, device=device)
            noise = torch.cat((noise, initial_path, map), dim=1)
            fake = gen(noise)
            # The critic sees the generated path together with the same
            # conditioning channels that accompany a real path.
            fake = torch.cat((fake, initial_path, map), dim=1)
            critic_real = critic(real)
            critic_fake = critic(fake)
            gp = gradient_penalty(critic, real, fake, device=device) # compute the gradient penalty
            # The paper maximises this objective; optimisers minimise, hence the sign flip.
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )
            critic.zero_grad()
            # retain_graph: the generator step below re-uses the graph of the last `fake`.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        ### Training generator: min -E[critic(gen_fake)]
        output = critic(fake)
        loss_gen = -torch.mean(output)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch + 1}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} "
                f"Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_input)
                # take out (up to) BATCH_SIZE examples
                img_grid_real = torchvision.utils.make_grid(
                    real[:BATCH_SIZE], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    fake[:BATCH_SIZE], normalize=True
                )

                # Overlay: generated path stacked with the heat-map and map channels.
                fake = torch.cat((fake, initial_path, map), dim=1)
                img_grid_fake_overlay = torchvision.utils.make_grid(
                    fake[:BATCH_SIZE], normalize=True
                )

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)
                writer_overlay.add_image("Fake", img_grid_fake_overlay, global_step=step)

            step += 1

    # Checkpoints are keyed by curr_epoch (the running count, which continues
    # from a loaded checkpoint) rather than by this run's loop index, so a
    # resumed run never overwrites the checkpoint it started from. The first
    # epoch of a fresh run is therefore saved as epoch-1.tar.
    if SAVE:
        # save generator checkpoint
        torch.save({
                    'epoch': curr_epoch,
                    'model_state_dict': gen.state_dict(),
                    'optimizer_state_dict': opt_gen.state_dict(),
                    'loss': float(loss_gen),    # plain number, no autograd graph attached
        }, f"{GEN_PATH}epoch-{curr_epoch}.tar")

        # save critic checkpoint
        torch.save({
                    'epoch': curr_epoch,
                    'model_state_dict': critic.state_dict(),
                    'optimizer_state_dict': opt_critic.state_dict(),
                    'loss': float(loss_critic),
        }, f"{DISC_PATH}epoch-{curr_epoch}.tar")

# Make sure every logged image reaches disk once training has finished.
for writer in (writer_real, writer_fake, writer_overlay):
    writer.flush()