In [None]:
# Code adapted from: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/ac5dcd03a40a08a8af7e1a67ade37f28cf88db43/ML/Pytorch/GANs/2.%20DCGAN/train.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as tfms
from torch.utils.data import DataLoader

import os, math
import random
import numpy as np

import wandb

## Initialize Weights and Biases

Configure the run: data inputs, network structure, and training hyperparameters.

In [None]:
# Inputs
MAP_NAME = 'map_64x64'
DATASET = 'smooth_paths'
BATCH_SIZE = 50


# Structure
# NUM_LAYERS_* conv layers per network; KERNEL_* / PAD_* hold one entry per layer,
# FEATURES_* holds the output width of each hidden (non-final) conv layer.
NUM_LAYERS_CRIT = 4
KERNEL_CRIT = [3,3,3,3]
PAD_CRIT = [1,1,1,1]
FEATURES_CRIT = [64,128,64]

NUM_LAYERS_GEN = 4
KERNEL_GEN = [3,3,3,3]
PAD_GEN = [1,1,1,1]
FEATURES_GEN = [64,128,64]


# Hyperparameters
LR_CRIT = 1e-5
LR_GEN = 1e-5
CRIT_ITERATIONS = 5     # critic updates per generator update
LAMBDA = 10             # gradient-penalty coefficient


# Internal Data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAP_SHAPE = (64,64)
NOISE_SHAPE = (BATCH_SIZE, 1, MAP_SHAPE[0], MAP_SHAPE[1])

NUM_EPOCHS = 100
START_EPOCH = 0

# Reproducibility: seed every RNG the notebook may draw from
# (torch.manual_seed also seeds CUDA devices).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Critic and Generator hard-code four conv blocks (three hidden widths); fail fast
# if the configuration above is edited out of sync with that.
assert NUM_LAYERS_CRIT == len(KERNEL_CRIT) == len(PAD_CRIT) == len(FEATURES_CRIT) + 1 == 4
assert NUM_LAYERS_GEN == len(KERNEL_GEN) == len(PAD_GEN) == len(FEATURES_GEN) + 1 == 4

Initialize the WandB run and log the configuration above.

In [None]:
# NOTE(review): an empty string is passed as the wandb group — confirm wandb treats '' as "no group".
GROUP=''

# Everything needed to reproduce the run is stored in the WandB run config.
CONFIG = dict(
    map_name = MAP_NAME,
    dataset = DATASET,

    num_layers_critic = NUM_LAYERS_CRIT,
    kernel_sizes_critic = KERNEL_CRIT,
    padding_critic = PAD_CRIT,
    num_features_critic = FEATURES_CRIT,

    num_layers_gen = NUM_LAYERS_GEN,
    kernel_size_gen = KERNEL_GEN,
    padding_gen = PAD_GEN,
    num_features_gen = FEATURES_GEN,

    batch_size = BATCH_SIZE,
    learning_rate_critic = LR_CRIT,
    learning_rate_gen = LR_GEN,
    critic_iterations = CRIT_ITERATIONS,
    gp_coefficient = LAMBDA,

    # Training length (and resume point) also determine the result.
    num_epochs = NUM_EPOCHS,
    start_epoch = START_EPOCH,
)

wandb.init(project='wgan-gp', entity='aicv-lab', config=CONFIG, group=GROUP)

## Define the GAN's Structure

In [None]:
class Critic(nn.Module):
    """WGAN-GP critic: four stride-1 conv blocks followed by a spatial average.

    Input is a (B, in_channels, H, W) batch (in this notebook: path, initial
    path, map). The final conv produces a 1-channel map that is averaged over
    its spatial positions, giving ONE score per sample, shape (B,).

    Per-sample scores matter for the gradient penalty: if the whole batch were
    reduced to a single scalar here, each sample's input-gradient would be
    scaled by 1/B and the penalty would push per-sample gradient norms towards
    B instead of 1. Callers that want the batch mean apply torch.mean to the
    output (torch.mean of the (B,) scores equals the old scalar output).

    Args:
        device: device the layers are created on.
        in_channels: number of input channels (default 3).
        features: output widths of the three hidden conv layers
            (default: FEATURES_CRIT).
        kernels: kernel size of each of the four conv layers (default: KERNEL_CRIT).
        pads: padding of each of the four conv layers (default: PAD_CRIT).
    """
    def __init__(self, device='cpu', in_channels=3, features=None, kernels=None, pads=None):
        super(Critic, self).__init__()

        # Fall back to the notebook-level configuration when not given explicitly.
        features = FEATURES_CRIT if features is None else features
        kernels = KERNEL_CRIT if kernels is None else kernels
        pads = PAD_CRIT if pads is None else pads

        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernels[0], 1, pads[0], device=device),
            nn.LeakyReLU(0.2)
        )

        # Conv bias is disabled where InstanceNorm (with its own affine shift) follows.
        self.block2 = nn.Sequential(
            nn.Conv2d(features[0], features[1], kernels[1], 1, pads[1], bias=False, device=device),
            nn.InstanceNorm2d(features[1], affine=True, device=device),
            nn.LeakyReLU(0.2)
        )

        self.block3 = nn.Sequential(
            nn.Conv2d(features[1], features[2], kernels[2], 1, pads[2], bias=False, device=device),
            nn.InstanceNorm2d(features[2], affine=True, device=device),
            nn.LeakyReLU(0.2)
        )

        self.block4 = nn.Sequential(
            nn.Conv2d(features[2], 1, kernels[3], 1, pads[3], device=device), # convert to single channel
        )

    def forward(self, x):
        y = self.block1(x)
        y = self.block2(y)
        y = self.block3(y)
        y = self.block4(y)

        # Average over channel/spatial dims only: one score per sample, shape (B,).
        return torch.mean(y, dim=tuple(range(1, y.dim())))

In [None]:
# 3 input channels, stacked in the order (noise, initial path, map).
class Generator(nn.Module):
    """Fully-convolutional generator emitting a single-channel binary map.

    Four stride-1 conv blocks (conv -> [InstanceNorm] -> ReLU, with a final
    conv -> Sigmoid) keep the spatial size set by KERNEL_GEN / PAD_GEN. The
    sigmoid output is then rounded with a straight-through estimator, so the
    forward pass yields 0/1 values while gradients flow as if no rounding
    happened.
    """
    def __init__(self, device='cpu'):
        super().__init__()

        widths = FEATURES_GEN
        ks = KERNEL_GEN
        ps = PAD_GEN

        # Layers are created in the same order as always so that default
        # parameter initialisation consumes the RNG identically.
        self.block1 = nn.Sequential(
            nn.Conv2d(3, widths[0], ks[0], 1, ps[0], device=device),
            nn.ReLU(),
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(widths[0], widths[1], ks[1], 1, ps[1], device=device),
            nn.InstanceNorm2d(widths[1], affine=True, device=device),
            nn.ReLU(),
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(widths[1], widths[2], ks[2], 1, ps[2], device=device),
            nn.InstanceNorm2d(widths[2], affine=True, device=device),
            nn.ReLU(),
        )
        # Collapse to one channel and squash into (0, 1).
        self.block4 = nn.Sequential(
            nn.Conv2d(widths[2], 1, ks[3], 1, ps[3], device=device),
            nn.Sigmoid(),
        )

    def forward(self, x):
        out = x
        for stage in (self.block1, self.block2, self.block3, self.block4):
            out = stage(out)
        return self._round(out)

    def _round(self, mat):
        """Straight-through rounding: value is round(mat), gradient is identity.

        The detached copy cancels ``mat`` in the forward value, while the
        trailing ``mat`` term carries the gradient.
        TODO: cite the source (based off of Thor's code).
        """
        rounded = torch.round(mat)
        return (rounded - mat.detach()) + mat


## Define Essential Functions

In [None]:
def initialize_weights(model):
    """Re-initialise weights in place, DCGAN style: N(mean=0, std=0.02).

    Applies to every Conv2d, ConvTranspose2d and BatchNorm2d sub-module of
    ``model``; biases and all other layer types are left untouched. Note that
    the networks in this notebook use InstanceNorm2d, which is not in that
    list, so their norm layers keep PyTorch's default initialisation.
    """
    dcgan_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    for layer in filter(lambda mod: isinstance(mod, dcgan_layers), model.modules()):
        nn.init.normal_(layer.weight, mean=0.0, std=0.02)

In [None]:
def gradient_penalty(coeff, critic, real, fake, device="cpu"):
    """WGAN-GP gradient penalty.

    Evaluates the critic at random interpolates
    x_hat = eps * real + (1 - eps) * fake, with one eps ~ U[0, 1) per sample,
    and penalises how far each sample's input-gradient L2 norm is from 1.

    Args:
        coeff: penalty weight (lambda).
        critic: callable scoring a batch. Its output is differentiated with an
            all-ones ``grad_outputs``; for a correct per-sample penalty it should
            return one score per sample (a batch-reduced scalar scales each
            sample's gradient by 1/B).
        real: batch of real samples, shape (B, ...).
        fake: batch of generated samples, same shape as ``real``. May be
            detached from the generator.
        device: device on which eps is drawn.

    Returns:
        0-d tensor ``coeff * mean_i (||grad_{x_hat_i} critic||_2 - 1)^2``, kept
        differentiable w.r.t. the critic parameters (create_graph=True).
    """
    # eps must lie in [0, 1] so x_hat is on the segment between real and fake;
    # torch.randn (standard normal) would extrapolate far outside it.
    eps = torch.rand((real.shape[0],) + (1,) * (real.dim() - 1), device=device, dtype=real.dtype)
    x_hat = eps * real + (1 - eps) * fake

    # x_hat only tracks gradients if `fake` does; make it differentiable
    # unconditionally so detached inputs work too.
    if not x_hat.requires_grad:
        x_hat.requires_grad_(True)

    critic_output = critic(x_hat)

    grads = torch.autograd.grad(
        outputs=critic_output,
        inputs=x_hat,
        grad_outputs=torch.ones_like(critic_output),
        create_graph=True,   # the penalty itself is back-propagated into the critic
        retain_graph=True,
    )[0]

    # L2 norm of each sample's gradient over all non-batch dimensions.
    grad_norm = grads.reshape(grads.shape[0], -1).norm(2, dim=1)
    return coeff * torch.mean((grad_norm - 1) ** 2)

In [None]:
# torch Dataset subclasses must implement __init__, __len__ and __getitem__.
class PathsDataset(torch.utils.data.Dataset):
    """Dataset of recorded paths rasterised onto a map-sized grid.

    Each item is a float32 tensor of shape (2, *shape):
      channel 0 - the recorded path drawn as a connected line of 1s,
      channel 1 - the straight line between the path's first and last point
                  (the "initial path").

    Path files hold a list of integer (x, y) coordinates (flat or one pair
    per row); x indexes grid columns and y indexes grid rows.
    """

    def __init__(self, path_dir, map_file, transform=None, shape = (100,100), device='cpu'):
        """Load and rasterise every path file in ``path_dir``.

        Args:
            path_dir: directory of path files; regular files are read in
                sorted filename order so dataset indices are reproducible.
                Sub-directories are skipped.
            map_file: map text file; the first two lines are skipped and the
                rest reshaped to ``shape``.
            transform: optional callable applied to each item tensor.
            shape: grid shape (rows, cols).
            device: device items are moved to in __getitem__.
        """
        self.device = device
        self.transform = transform
        self.paths = []  # one (2, *shape) array per path file

        # Stored for callers; __getitem__ does not use it.
        self.map = np.loadtxt(map_file, skiprows=2).reshape(shape)
        self.map = self.map[np.newaxis, :, :]

        for filename in sorted(os.listdir(path_dir)):
            file_path = os.path.join(path_dir, filename)
            if not os.path.isfile(file_path):
                continue
            with open(file_path, 'r') as f:
                flat_path = np.loadtxt(f)
            # Coordinates are truncated to ints and grouped into (x, y) pairs.
            path = np.asarray(flat_path, dtype=int).reshape(-1, 2)
            self.paths.append(self.convert_path(shape, path))
        print("Done!")

    @staticmethod
    def _draw_line(mat, x1, y1, x2, y2):
        """Set to 1 the cells of ``mat`` on segment (x1, y1) -> (x2, y2), end point excluded.

        The segment is rasterised twice, once stepping along x and once along
        y, so both shallow and steep segments are drawn without gaps. Cell
        coordinates use Python's round() (round-half-to-even).
        """
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
        if x1 != x2:
            slope = (y2 - y1) / (x2 - x1)
            step = 1 if x1 < x2 else -1
            for x in range(x1, x2, step):
                mat[round(slope * (x - x1) + y1), x] = 1
        if y1 != y2:
            slope = (x2 - x1) / (y2 - y1)
            step = 1 if y1 < y2 else -1
            for y in range(y1, y2, step):
                mat[y, round(slope * (y - y1) + x1)] = 1

    def convert_path(self, map_dim, path):
        """Rasterise ``path`` and its straight-line initial path.

        Args:
            map_dim: grid shape (rows, cols).
            path: integer array of (x, y) points, shape (N, 2).

        Returns:
            float array of shape (2, *map_dim): [path, initial path].

        Raises:
            ValueError: if ``path`` has no points.
        """
        path = np.asarray(path)
        if path.shape[0] == 0:
            raise ValueError("cannot rasterise an empty path")

        # Connect consecutive points, then add the final point (segments exclude their end).
        path_mat = np.zeros(map_dim, dtype=float)
        for (xa, ya), (xb, yb) in zip(path[:-1], path[1:]):
            self._draw_line(path_mat, xa, ya, xb, yb)
        x_start, y_start = path[0]
        x_end, y_end = path[-1]
        path_mat[y_end, x_end] = 1

        # Straight line between the start and end points, both included.
        initial_path = np.zeros_like(path_mat)
        self._draw_line(initial_path, x_start, y_start, x_end, y_end)
        initial_path[y_start, x_start] = 1
        initial_path[y_end, x_end] = 1

        return np.stack((path_mat, initial_path))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        x = np.float32(self.paths[idx])
        x = torch.Tensor(x).to(self.device)

        if self.transform:
            x = self.transform(x)

        return x

## Initialize Model & Data

In [None]:
# Load the map once and tile it across the batch: shape (BATCH_SIZE, 1, *MAP_SHAPE).
map_file = f"./env/{MAP_NAME}/{MAP_NAME}.txt"
map = np.loadtxt(map_file, skiprows=2).reshape(MAP_SHAPE)
map = torch.Tensor(np.repeat(map[None, None], BATCH_SIZE, axis=0)).to(device)

train_dataset = PathsDataset(
    path_dir=f"./env/{MAP_NAME}/paths/{DATASET}/training",
    map_file=map_file,
    shape=MAP_SHAPE,
    transform=None,
    device=device,
)
# drop_last keeps every batch at BATCH_SIZE so it lines up with the tiled `map`.
dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)


In [None]:
curr_epoch = START_EPOCH

# Build the generator first and then the critic, and re-initialise both in
# that order, so the RNG is consumed exactly as before.
gen = Generator(device=device)
critic = Critic(device=device)
initialize_weights(gen)
initialize_weights(critic)

# Adam with beta1 = 0.0, beta2 = 0.9 for both networks (initialisation above is
# in place, so the optimisers see the re-initialised parameters).
opt_gen = optim.Adam(gen.parameters(), lr=LR_GEN, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LR_CRIT, betas=(0.0, 0.9))

In [None]:
# Noise channel held fixed for the sample images logged during training;
# drawn from the same |N(0, 1)| distribution as the training noise.
fixed_noise = torch.abs(torch.randn(NOISE_SHAPE, device=device))

In [None]:
# Switch both networks to training-mode behaviour before the loop.
gen.train(mode=True)
critic.train(mode=True)

## Train the Model

In [None]:
# Last epoch number this run of the cell will reach (curr_epoch starts at
# START_EPOCH, or wherever a previous run of this cell stopped).
final_epoch = curr_epoch + NUM_EPOCHS

for epoch in range(NUM_EPOCHS):
    curr_epoch += 1
    for batch_idx, real in enumerate(dataloader):
        # Dataset items are (path, initial path); append the shared map channel.
        real = real.to(device)
        real = torch.concat((real, map), axis=1)

        initial_path = real[:,1:2,:,:]
        fixed_input = torch.concat((fixed_noise, initial_path, map), axis=1)

        ### Training critic: min E[critic(fake)] - E[critic(real)] + GP
        for _ in range(CRIT_ITERATIONS):
            noise = torch.randn(NOISE_SHAPE, device=device).abs()
            noise = torch.concat((noise, initial_path, map), axis=1)
            fake = gen(noise)
            fake = torch.concat((fake, initial_path, map), axis=1)
            critic_real = critic(real)
            critic_fake = critic(fake)
            gp = gradient_penalty(LAMBDA, critic, real, fake, device=device) # compute the gradient penalty
            loss_critic = (
                torch.mean(critic_fake) - torch.mean(critic_real) + gp
            )

            critic.zero_grad()
            # retain_graph: the generator step below back-propagates through the
            # last `fake`, whose generator sub-graph is part of this loss's graph.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        ### Training generator: min -E(critic(gen_fake))
        output = critic(fake)
        loss_gen = -torch.mean(output)
        gen.zero_grad()  # also clears generator grads accumulated by the critic loss
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and log them (plus example images) to wandb
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{curr_epoch}/{final_epoch}] Batch {batch_idx}/{len(dataloader)} " +
                  f"Loss D: {loss_critic:.4f}, Lambda GP: {gp:.4f}, loss G: {loss_gen:.4f}"
            )

            # Up to 8 examples; sampling needs no autograd graph.
            num_examples = min(8, BATCH_SIZE)
            with torch.no_grad():
                outputs = gen(fixed_input[:num_examples])
                outputs = torch.concat(
                    (outputs, fixed_input[:num_examples, 1:2, :, :], map[:num_examples, :, :, :]),
                    axis=1,
                )
            inputs = real[:num_examples]

            wandb.log({
                'epoch': curr_epoch,
                'generator loss': loss_gen.item(),
                'critic loss': loss_critic.item(),
                'gradient penalty': gp.item(),
                'fake': wandb.Image(outputs),
                'real': wandb.Image(inputs)
            })