In [1]:
# Code adapted from: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/ac5dcd03a40a08a8af7e1a67ade37f28cf88db43/ML/Pytorch/GANs/2.%20DCGAN/train.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as tfms
from torch.utils.data import DataLoader

import os, math
import shutil
import random
import numpy as np
import skfmm

import wandb

## Initialize Weights & Biases (WandB) Logging

Configure the run: data inputs, network structure, hyperparameters, and internal settings.

In [None]:
RECORD_METRICS = True

# Inputs
MAP_NAME = 'map_64x64'
DATASET = 'training'
BATCH_SIZE = 50


# Structure
# Per-layer lists: FEATURES_*[i] is the input channel count of layer i + 1.
# The layer counts are derived from the lists so the two can never disagree.
KERNEL_CRIT = [4,4,4,4,4]
STRIDE_CRIT = [2,2,2,2,1]
PAD_CRIT = [1,1,1,1,0]
FEATURES_CRIT = [3,64,128,256,512]
NUM_LAYERS_CRIT = len(KERNEL_CRIT)

KERNEL_GEN = [4,4,4,4,4,4,4,4,4,4]
STRIDE_GEN = [2,2,2,2,1,1,2,2,2,2]
PAD_GEN = [1,1,1,1,0,0,1,1,1,1]
FEATURES_GEN = [3,64,128,256,512,1024,512,256,128,64]
NUM_LAYERS_GEN = len(KERNEL_GEN)

assert len(STRIDE_CRIT) == len(PAD_CRIT) == len(FEATURES_CRIT) == NUM_LAYERS_CRIT, "critic layer lists differ in length"
assert len(STRIDE_GEN) == len(PAD_GEN) == len(FEATURES_GEN) == NUM_LAYERS_GEN, "generator layer lists differ in length"


# Hyperparameters
LR_CRIT = 1e-5
LR_GEN = 1e-5
CRIT_ITERATIONS = 5   # critic updates per generator update
LAMBDA = 10           # gradient-penalty coefficient


# Reproducibility: seed the Python, NumPy and PyTorch RNGs before any model
# weights, noise or data shuffles are drawn.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


# Internal Data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAP_SHAPE = (64,64)
NOISE_SHAPE = (BATCH_SIZE, 1, MAP_SHAPE[0], MAP_SHAPE[1])

NUM_EPOCHS = 50
START_EPOCH = 0

Initialize Weights & Biases (WandB) logging for this run, recording the configuration above.

In [None]:
GROUP = ''

# Run configuration logged to WandB (and stored with every checkpoint).
CONFIG = {
    'map_name': MAP_NAME,
    'dataset': DATASET,

    'layers_crit': NUM_LAYERS_CRIT,
    'kernels_crit': KERNEL_CRIT,
    'stride_crit': STRIDE_CRIT,
    'padding_crit': PAD_CRIT,
    'features_crit': FEATURES_CRIT,

    'layers_gen': NUM_LAYERS_GEN,
    'kernels_gen': KERNEL_GEN,
    'stride_gen': STRIDE_GEN,
    'padding_gen': PAD_GEN,
    'features_gen': FEATURES_GEN,

    'batch_size': BATCH_SIZE,
    'learning_rate_crit': LR_CRIT,
    'learning_rate_gen': LR_GEN,
    'crit_iterations': CRIT_ITERATIONS,
    'gp_coefficient': LAMBDA,
}

# `run` is only defined when metrics are recorded.
if RECORD_METRICS:
    run = wandb.init(
        project='wgan-gp',
        entity='aicv-lab',
        config=CONFIG,
        group=GROUP,
    )

## Define the GAN's Structure

In [None]:
class Critic(nn.Module):
    """Convolutional critic: four conv blocks followed by a 1-channel conv head.

    Args:
        f: channel list; ``f[i]`` is the input channel count of block ``i + 1``
            (entries 0-4 are read).
        k, s, p: per-block kernel size, stride and padding (entries 0-4 are read).
        device: device the layers are created on.

    Block 1 has no normalization, blocks 2-4 use affine ``InstanceNorm2d``, and
    blocks 1-4 end in ``LeakyReLU(0.2)``. Block 5 is a bare convolution with no
    activation, so the returned scores are unbounded.
    """

    def __init__(self, f, k, s, p, device='cpu'):
        super().__init__()

        def conv_block(i, out_channels, normalize, activate):
            # Layer i of the config lists: conv -> [InstanceNorm] -> [LeakyReLU]
            layers = [nn.Conv2d(f[i], out_channels, k[i], s[i], p[i], device=device)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_channels, affine=True, device=device))
            if activate:
                layers.append(nn.LeakyReLU(0.2))
            return nn.Sequential(*layers)

        # Attribute names block1..block5 keep the state_dict keys stable.
        self.block1 = conv_block(0, f[1], normalize=False, activate=True)
        self.block2 = conv_block(1, f[2], normalize=True, activate=True)
        self.block3 = conv_block(2, f[3], normalize=True, activate=True)
        self.block4 = conv_block(3, f[4], normalize=True, activate=True)
        self.block5 = conv_block(4, 1, normalize=False, activate=False)

    def forward(self, x):
        for block in (self.block1, self.block2, self.block3, self.block4, self.block5):
            x = block(x)
        return x

In [3]:
# 3 input channels (in the training loop: noise, initial path, map)
class Generator(nn.Module):
    """Encoder-decoder generator whose single-channel output is rounded to {0, 1}.

    Blocks 1-5 are ``Conv2d`` layers and blocks 6-10 are ``ConvTranspose2d``
    layers. Blocks 2-4 and 6-9 use affine ``InstanceNorm2d``. Every block but
    the last ends in ``ReLU``; the last ends in ``Sigmoid``, and its output is
    rounded with a straight-through estimator (see ``_round``).

    Args:
        f: channel list; ``f[i]`` is the input channel count of block ``i + 1``
            (entries 0-9 are read).
        k, s, p: per-block kernel size, stride and padding (entries 0-9 are read).
        device: device the layers are created on.
    """

    def __init__(self, f, k, s, p, device='cpu'):
        super().__init__()

        def make_block(i, out_channels, transpose, normalize, final=False):
            # Layer i of the config lists: conv -> [InstanceNorm] -> ReLU / Sigmoid
            conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
            layers = [conv_cls(f[i], out_channels, k[i], s[i], p[i], device=device)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_channels, affine=True, device=device))
            layers.append(nn.Sigmoid() if final else nn.ReLU())
            return nn.Sequential(*layers)

        # Encoder; attribute names block1..block10 keep the state_dict keys stable.
        self.block1 = make_block(0, f[1], transpose=False, normalize=False)
        self.block2 = make_block(1, f[2], transpose=False, normalize=True)
        self.block3 = make_block(2, f[3], transpose=False, normalize=True)
        self.block4 = make_block(3, f[4], transpose=False, normalize=True)
        self.block5 = make_block(4, f[5], transpose=False, normalize=False)
        # Decoder
        self.block6 = make_block(5, f[6], transpose=True, normalize=True)
        self.block7 = make_block(6, f[7], transpose=True, normalize=True)
        self.block8 = make_block(7, f[8], transpose=True, normalize=True)
        self.block9 = make_block(8, f[9], transpose=True, normalize=True)
        self.block10 = make_block(9, 1, transpose=True, normalize=False, final=True)

    def forward(self, x):
        for idx in range(1, 11):
            x = getattr(self, f'block{idx}')(x)
        return self._round(x)

    def _round(self, mat):
        """Round in the forward pass; pass the gradient through unchanged.

        The forward value is ``round(mat)``; the rounding residual is detached,
        so d(output)/d(mat) == 1.
        """
        # TODO: cite something? (this function is based off of Thor's code)
        return mat + (torch.round(mat) - mat).detach()


In [None]:
# Save model definitions alongside this run's checkpoints.
# Guarded: `run` only exists when RECORD_METRICS is True (wandb.init is skipped otherwise).
if RECORD_METRICS:
    savepath = os.path.join(os.getcwd(), 'checkpoints', run.name)
    os.makedirs(savepath, exist_ok=True)
    # NOTE(review): copies ./GAN.py from the working directory; presumably a module
    # mirroring the Critic/Generator classes above -- confirm it is kept in sync.
    shutil.copy('./GAN.py', os.path.join(savepath, 'GAN.py'))

## Define Essential Functions

In [None]:
def initialize_weights(model):
    """Re-draw layer weights from N(0, 0.02) in place, following the DCGAN paper.

    Only ``Conv2d``, ``ConvTranspose2d`` and ``BatchNorm2d`` weights are touched.
    Biases and all other module types (including ``InstanceNorm2d``) keep their
    default initialization.
    """
    dcgan_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)
    for module in filter(lambda m: isinstance(m, dcgan_layers), model.modules()):
        nn.init.normal_(module.weight.data, mean=0.0, std=0.02)

In [5]:
def gradient_penalty(coeff, critic, real, fake, device="cpu"):
    """WGAN-GP gradient penalty.

    Samples x_hat = eps * real + (1 - eps) * fake, with one eps ~ U[0, 1) per
    sample, so each x_hat lies on the segment between a real and a fake sample.
    Penalizes the critic's per-sample gradient L2 norm at x_hat for deviating
    from 1.

    Args:
        coeff: penalty weight (lambda).
        critic: callable mapping a batch to critic scores.
        real, fake: batches of identical shape ``(B, ...)``.
        device: device on which the interpolation weights are sampled.

    Returns:
        Scalar tensor ``coeff * mean((||d critic(x_hat) / d x_hat||_2 - 1) ** 2)``.
        It is built with ``create_graph=True`` so it can be back-propagated into
        the critic.
    """
    # Uniform weights (torch.rand), one per sample, broadcast over the remaining
    # dims. A normal draw (torch.randn) would put x_hat off the real-fake segment.
    eps = torch.rand((real.shape[0],) + (1,) * (real.dim() - 1), device=device)
    x_hat = eps * real + (1 - eps) * fake
    if not x_hat.requires_grad:
        # Neither input tracks gradients (e.g. a detached fake): make x_hat a
        # leaf that does, so autograd.grad has something to differentiate.
        x_hat.requires_grad_(True)

    critic_output = critic(x_hat)
    grads = torch.autograd.grad(
        outputs=critic_output,
        inputs=x_hat,
        grad_outputs=torch.ones_like(critic_output),
        create_graph=True,      # the penalty itself must be differentiable
        retain_graph=True,
    )[0]

    grad_norm = grads.flatten(start_dim=1).norm(p=2, dim=1)   # per-sample L2 norm
    return coeff * torch.mean((grad_norm - 1) ** 2)

In [6]:
# Need to override __init__, __len__, __getitem__
# as per datasets requirement
class PathsDataset(torch.utils.data.Dataset):
    """Map-style dataset of rasterized paths over a single map.

    Every regular file in ``path_dir`` holds one path as integer ``(x, y)``
    points, either flattened (``x0 y0 x1 y1 ...``) or one point per row. Each
    path becomes a float array of shape ``(2, *shape)``:

    * channel 0 -- the path, with consecutive points joined by line segments;
    * channel 1 -- the straight line between the path's first and last points.

    Points index matrices as ``mat[y, x]``.

    Args:
        path_dir: directory with one path file per sample. Files are read in
            sorted order; sub-directories and empty files are skipped.
        map_file: text map; the first two rows are skipped and the remaining
            values reshaped to ``shape``. Stored as ``self.map`` with shape
            ``(1, *shape)``.
        transform: optional callable applied to each sample tensor.
        shape: (rows, cols) of the map and of every path matrix.
        device: device the returned sample tensors are moved to.
    """

    def __init__(self, path_dir, map_file, transform=None, shape = (100,100), device='cpu'):
        self.device = device
        self.transform = transform
        self.map = np.loadtxt(map_file, skiprows=2).reshape(shape)[np.newaxis, :, :]

        self.paths = []
        # sorted(): sample order must not depend on the filesystem's listing order
        for filename in sorted(os.listdir(path_dir)):
            file_path = os.path.join(path_dir, filename)
            if not os.path.isfile(file_path):
                continue    # e.g. sub-directories
            with open(file_path, 'r') as f:
                # reshape(-1, 2) accepts both flat and two-column layouts
                path = np.asarray(np.loadtxt(f), dtype=int).reshape(-1, 2)
            if path.size == 0:
                continue    # empty file: nothing to rasterize
            self.paths.append(self.convert_path(shape, path))
        print("Done!")

    @staticmethod
    def _draw_segment(mat, x1, y1, x2, y2):
        """Set the cells from (x1, y1) towards (x2, y2) to 1, excluding the end point.

        The segment is sampled once per column (y from x) and once per row
        (x from y), so both shallow and steep segments come out gap-free.
        """
        x_dir = 1 if x1 < x2 else -1
        y_dir = 1 if y1 < y2 else -1
        if x2 != x1:
            m = (y2 - y1) / (x2 - x1)
            for x in range(x1, x2, x_dir):
                mat[int(round(m * (x - x1) + y1)), x] = 1
        if y2 != y1:
            m = (x2 - x1) / (y2 - y1)
            for y in range(y1, y2, y_dir):
                mat[y, int(round(m * (y - y1) + x1))] = 1

    def convert_path(self, map_dim, path):
        """Rasterize ``path`` (an ``(N, 2)`` int array of x, y points, N >= 1).

        Returns a ``(2, *map_dim)`` float array: the connected path, stacked
        with the straight start-to-end "initial path".
        """
        path_mat = np.zeros(map_dim, dtype=float)
        for (x1, y1), (x2, y2) in zip(path[:-1], path[1:]):
            self._draw_segment(path_mat, x1, y1, x2, y2)
        x_end, y_end = path[-1]
        path_mat[y_end, x_end] = 1      # segments exclude their end point

        initial_path = np.zeros_like(path_mat)
        x_start, y_start = path[0]
        self._draw_segment(initial_path, x_start, y_start, x_end, y_end)
        initial_path[y_start, x_start] = 1
        initial_path[y_end, x_end] = 1

        return np.stack((path_mat, initial_path))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        x = torch.Tensor(np.float32(self.paths[idx])).to(self.device)
        if self.transform:
            x = self.transform(x)
        return x

## Initialize Model & Data

In [8]:
def grad_from_map(map, device=device, max=1, min=-1, slope=0.1):
    """Return the negated signed-distance field of a 0/1 map as a tensor.

    Cells equal to 0 are mapped to ``min`` and cells equal to 1 to ``max``.
    ``skfmm.distance`` is then run on that field with grid spacing
    ``dx=slope``, and the negated result is returned as a float tensor on
    ``device``.

    The caller's array is not modified; a float copy is remapped instead.
    """
    # Float copy: keeps the input intact and avoids truncating non-integer
    # min/max when the input has an integer dtype.
    level = np.array(map, dtype=float)
    # Build both masks before assigning, so cells already remapped to `min`
    # can never be caught by the `== 1` test (e.g. when min == 1).
    is_zero = level == 0
    is_one = level == 1
    level[is_zero] = min
    level[is_one] = max
    sd = skfmm.distance(level, dx=slope)   # compute signed distance
    return torch.Tensor(-sd).to(device)

In [9]:
# Load the map, add (batch, channel) axes and repeat it BATCH_SIZE times so it can be
# concatenated channel-wise with every batch, then convert it with grad_from_map.
# NOTE(review): `map` shadows the Python builtin; later cells depend on this name.
map = np.loadtxt(f"./env/{MAP_NAME}/{MAP_NAME}.txt", skiprows=2).reshape(MAP_SHAPE)
map = map[np.newaxis,np.newaxis,:,:]
map = np.repeat(map, BATCH_SIZE, axis=0)
map = grad_from_map(map)

# drop_last=True keeps every batch at exactly BATCH_SIZE, matching the repeated `map`.
train_dataset = PathsDataset(path_dir = f"./env/{MAP_NAME}/paths/{DATASET}", map_file = f"./env/{MAP_NAME}/{MAP_NAME}.txt", shape = MAP_SHAPE, transform=None, device=device)
dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)


Done!


In [10]:
curr_epoch = START_EPOCH

gen = Generator(FEATURES_GEN, KERNEL_GEN, STRIDE_GEN, PAD_GEN, device=device)
critic = Critic(FEATURES_CRIT, KERNEL_CRIT, STRIDE_CRIT, PAD_CRIT, device=device)

# DCGAN-style weight re-initialization, in place. Done before creating the
# optimizers for readability; the optimizers hold the same parameter tensors,
# so the order does not change the result.
initialize_weights(gen)
initialize_weights(critic)

# Adam with betas (0.0, 0.9) for both networks.
opt_gen = optim.Adam(gen.parameters(), lr=LR_GEN, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LR_CRIT, betas=(0.0, 0.9))

In [11]:
# Noise held fixed for the whole run so logged generator samples are comparable
# across steps; .abs() makes it non-negative, like the per-step training noise.
fixed_noise = torch.randn(NOISE_SHAPE, device=device).abs()

In [12]:
# Put both networks in training mode.
# NOTE(review): the saved output below this cell shows a `Discriminator` repr that no
# longer matches the `Critic` class; re-run the notebook to refresh it.
gen.train()
critic.train()

Discriminator(
  (block1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
  )
  (block2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (block3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (block4): Sequential(
    (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (block5): Sequential(
    (0): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
    (1): Ada

## Train the Model

In [14]:
def _save_checkpoint(model, subdir):
    """Save ``model``'s state dict with the dataset name and run config.

    Written to ``checkpoints/<run name>/<subdir>/step_<wandb step>.tar`` under
    the working directory. It reads the WandB ``run``, so call it only when
    RECORD_METRICS is True.
    """
    ckpt_dir = os.path.join(os.getcwd(), 'checkpoints', run.name, subdir)
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(
        {
            'dataset': DATASET,
            'config': CONFIG,
            'state': model.state_dict()
        },
        os.path.join(ckpt_dir, f'step_{run.step}.tar'),
    )


for epoch in range(NUM_EPOCHS):
    curr_epoch += 1
    for batch_idx, real in enumerate(dataloader):
        # Dataset samples hold (path, initial path); append the distance map so
        # real batches have the same 3 channels as the generated ones below.
        real = torch.concat((real.to(device), map), axis=1)

        initial_path = real[:,1:2,:,:]
        # Fixed noise, conditioned on this batch's initial paths, for logged samples.
        fixed_input = torch.concat((fixed_noise, initial_path, map), axis=1)

        ### Training critic: minimize E[critic(fake)] - E[critic(real)] + GP
        for _ in range(CRIT_ITERATIONS):
            noise = torch.randn(NOISE_SHAPE, device=device).abs()
            noise = torch.concat((noise, initial_path, map), axis=1)

            fake = gen(noise)
            fake = torch.concat((fake, initial_path, map), axis=1)
            critic_real = critic(real)
            critic_fake = critic(fake)
            gp = gradient_penalty(LAMBDA, critic, real, fake, device=device) # compute the gradient penalty
            loss_critic = torch.mean(critic_fake) - torch.mean(critic_real) + gp

            critic.zero_grad()
            # retain_graph: the generator graph behind the last `fake` is reused
            # by the generator step below.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        ### Training generator: min E(-critic(gen_fake))
        output = critic(fake)
        loss_gen = -torch.mean(output)
        # `fake` is not detached, so the critic backward also left gradients on
        # the generator; clear them before the generator's own backward.
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses, checkpoint and log samples occasionally
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{curr_epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} " +
                  f"Loss D: {loss_critic:.4f}, Lambda GP: {gp:.4f}, loss G: {loss_gen:.4f}"
            )

            if RECORD_METRICS:
                _save_checkpoint(gen, 'gen')
                _save_checkpoint(critic, 'crit')

            # Samples are only logged, so no autograd graph is needed.
            n_show = min(8, BATCH_SIZE)
            with torch.no_grad():
                outputs = gen(fixed_input[:n_show])
            inputs = real[:n_show]
            outputs = torch.concat((outputs, fixed_input[:n_show, 1:2], map[:n_show]), axis=1)

            if RECORD_METRICS:
                wandb.log({
                    'epoch': curr_epoch,
                    'generator loss': loss_gen.item(),
                    'critic loss': loss_critic.item(),
                    'gradient penalty': gp.item(),
                    'fake': wandb.Image(outputs),
                    'real': wandb.Image(inputs)
                })

Epoch [1/500] Batch 0/468 Loss D: 32.0964, Lambda GP: 337.1546, loss G: -1.5879
Epoch [1/500] Batch 100/468 Loss D: -12.1294, Lambda GP: 34.8987, loss G: 27.8032
Epoch [1/500] Batch 200/468 Loss D: -7.7926, Lambda GP: 21.1115, loss G: 24.2308
Epoch [1/500] Batch 300/468 Loss D: -8.2358, Lambda GP: 18.5229, loss G: 26.4175
Epoch [1/500] Batch 400/468 Loss D: -8.5411, Lambda GP: 13.8237, loss G: 27.1672
Epoch [2/500] Batch 0/468 Loss D: -9.3875, Lambda GP: 21.2846, loss G: 27.4417


KeyboardInterrupt: 

In [None]:
# Mark the WandB run as finished (only started when RECORD_METRICS is True).
if RECORD_METRICS:
    wandb.finish()