In [None]:
# Code adapted from: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/ac5dcd03a40a08a8af7e1a67ade37f28cf88db43/ML/Pytorch/GANs/2.%20DCGAN/train.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as tfms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os
import numpy as np
from itertools import product
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

In [None]:
class Discriminator(nn.Module):
    """DCGAN discriminator: N x channels_img x 64 x 64 images -> N x 1 x 1 x 1 scores in (0, 1).

    Args:
        channels_img: number of channels of the input images.
        features_d: channel width of the first convolution; each of the three
            downsampling blocks doubles it (features_d -> 2x -> 4x -> 8x).
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        layers = [
            # 64x64 -> 32x32; the first layer has no BatchNorm.
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Three stride-2 blocks: 32x32 -> 16x16 -> 8x8 -> 4x4, doubling channels each time.
        for mult in (1, 2, 4):
            layers.append(self._block(features_d * mult, features_d * mult * 2, 4, 2, 1))
        # A 4x4 kernel with no padding collapses the remaining 4x4 map to 1x1.
        layers.append(nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0))
        layers.append(nn.Sigmoid())
        self.disc = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Bias-free Conv2d -> BatchNorm2d -> LeakyReLU(0.2) unit."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2))

    def forward(self, x):
        """Score a batch of images; the output has shape (N, 1, 1, 1)."""
        return self.disc(x)

In [None]:
class Generator(nn.Module):
    """DCGAN generator: N x channels_noise x 1 x 1 noise -> N x channels_img x 64 x 64 images in [-1, 1].

    Args:
        channels_noise: dimensionality of the latent noise vector.
        channels_img: number of channels of the generated images.
        features_g: base channel width; the upsampling blocks use 16x, 8x, 4x
            and 2x this value.
    """

    def __init__(self, channels_noise, channels_img, features_g):
        super().__init__()
        # (in_channels, out_channels, stride, padding) per upsampling block; kernel size is 4.
        specs = [
            (channels_noise, features_g * 16, 1, 0),  # 1x1 -> 4x4
            (features_g * 16, features_g * 8, 2, 1),  # 4x4 -> 8x8
            (features_g * 8, features_g * 4, 2, 1),  # 8x8 -> 16x16
            (features_g * 4, features_g * 2, 2, 1),  # 16x16 -> 32x32
        ]
        layers = [self._block(c_in, c_out, 4, s, p) for c_in, c_out, s, p in specs]
        layers += [
            # 32x32 -> 64x64, no BatchNorm on the output layer.
            nn.ConvTranspose2d(features_g * 2, channels_img, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Bias-free ConvTranspose2d -> BatchNorm2d -> ReLU unit."""
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Generate images from noise of shape (N, channels_noise, 1, 1)."""
        return self.net(x)

In [None]:
def initialize_weights(model, mean=0.0, std=0.02):
    """Re-initialise Conv2d, ConvTranspose2d and BatchNorm2d weights in place.

    Follows the DCGAN recipe of drawing weights from a normal distribution
    (default N(0, 0.02^2)). Biases and all other layer types are left untouched.

    Args:
        model: any ``nn.Module``; every sub-module returned by
            ``model.modules()`` (including ``model`` itself) is visited.
        mean: mean of the normal distribution.
        std: standard deviation of the normal distribution.

    NOTE(review): PyTorch's DCGAN tutorial instead draws BatchNorm weights from
    N(1, 0.02^2) and zeroes their biases; this function keeps zero-mean for all.
    """
    for m in model.modules():
        if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            continue
        # BatchNorm2d(affine=False) has no learnable weight (weight is None).
        if m.weight is None:
            continue
        # nn.init functions already run under torch.no_grad(), so `.data` is not needed.
        nn.init.normal_(m.weight, mean, std)

In [None]:
def test():
    """Smoke-test model output shapes on a random batch of eight 3-channel 64x64 images.

    Raises:
        AssertionError: "Discriminator test failed" if the discriminator does
            not return (N, 1, 1, 1), or "Generator test failed" if the generator
            does not return (N, C, 64, 64).
    """
    batch, channels, height, width = 8, 3, 64, 64
    latent_dim = 100

    images = torch.randn((batch, channels, height, width))
    critic = Discriminator(channels, 8)
    assert critic(images).shape == (batch, 1, 1, 1), "Discriminator test failed"

    generator = Generator(latent_dim, channels, 8)
    latent = torch.randn((batch, latent_dim, 1, 1))
    assert generator(latent).shape == (batch, channels, height, width), "Generator test failed"

In [None]:
# Shape smoke test for Discriminator / Generator; raises AssertionError on a mismatch.
test()

In [None]:
class PathsDataset(torch.utils.data.Dataset):
    """Dataset of 2-D paths rasterised into binary matrices.

    Every regular file in ``path`` is read with ``np.loadtxt`` and interpreted
    as a sequence of ``(x, y)`` points. The file can be either flattened
    (``x0 y0 x1 y1 ...``) or hold one ``x y`` pair per row. The path is
    linearly interpolated at every integer x between its smallest and largest
    x value (y truncated to int). The visited ``[y, x]`` cells of a zero matrix
    of size ``shape`` are then set to 1.

    Args:
        path: directory containing the path files (sub-directories are skipped).
        transform: optional callable applied to each float32 matrix in
            ``__getitem__``.
        shape: ``(rows, cols)`` of each matrix. Rows are indexed by y and columns
            by x, so coordinates must fit inside it. Out-of-range indices raise
            ``IndexError``; negative ones wrap, as usual numpy indexing does.
    """

    def __init__(self, path, transform=None, shape=(100, 100)):
        print("Loading paths dataset...")
        self.paths = []  # one rasterised matrix per file
        # sorted() makes the dataset order independent of the file system.
        for filename in sorted(os.listdir(path)):
            file_path = os.path.join(path, filename)
            # Skip sub-directories such as Jupyter's `.ipynb_checkpoints`.
            if not os.path.isfile(file_path):
                continue
            self.paths.append(self._rasterize_path(file_path, shape))
        self.transform = transform
        print("Done!")

    @staticmethod
    def _rasterize_path(file_path, shape):
        """Read one path file and return it as a ``shape`` matrix of 0s and 1s."""
        with open(file_path, "r") as f:
            flat_path = np.loadtxt(f)
        # (-1, 2) accepts both the flat layout and one "x y" pair per row.
        points = np.asarray(flat_path).reshape(-1, 2)
        # np.interp requires increasing x-coordinates; results are meaningless otherwise.
        points = points[np.argsort(points[:, 0], kind="stable")]

        x_lo = int(points[:, 0].min())
        x_hi = int(points[:, 0].max())
        # Every integer column from x_lo to x_hi inclusive (linspace with
        # x_hi - x_lo samples skipped some columns after the int cast).
        xvals = np.arange(x_lo, x_hi + 1)
        yvals = np.interp(xvals, points[:, 0], points[:, 1]).astype(int)

        path_matrix = np.zeros(shape)
        path_matrix[yvals, xvals] = 1
        return path_matrix

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return matrix ``idx`` as float32, passed through ``transform`` when one is set."""
        x = np.float32(self.paths[idx])
        if self.transform is not None:
            x = self.transform(x)
        return x

In [None]:
# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

IMAGE_SIZE = 64  # must match the 64x64 layout Discriminator / Generator are built for
CHANNELS_IMG = 1  # 1 for the path maps (or MNIST); 3 for an RGB set such as CelebA

# Target number of loss points / image grids written to TensorBoard per run.
MAX_DATA_POINTS = 100
MAX_IMG_DATA = 10

# Search grid: each key lists the values to try, and training runs once per
# combination (itertools.product over `param_values`). Earlier experiments also
# tried lr_disc=1e-7, lr_gen=1e-4, batch_size=5, num_epochs=50, noise_dim=128 and
# features=256 (at IMAGE_SIZE=256).
parameters = {
    "lr_disc": [2e-6],
    "lr_gen": [2e-4],
    "batch_size": [10],
    "num_epochs": [25],
    "noise_dim": [100],
    "features_disc": [64],
    "features_gen": [64],
}
param_values = list(parameters.values())
total_param_runs = np.prod([len(values) for values in parameters.values()])


In [None]:
# Convert each float32 matrix to a tensor, then apply Normalize(mean=0.5, std=0.5)
# per channel, i.e. x -> (x - 0.5) / 0.5, so 0/1 path pixels become -1/1
# (the same range as the generator's Tanh output).
transforms = tfms.Compose([
    tfms.ToTensor(),
    tfms.Normalize([0.5] * CHANNELS_IMG, [0.5] * CHANNELS_IMG),
])

In [None]:
# Rasterise the paths at IMAGE_SIZE x IMAGE_SIZE, the input size the models are
# built for. To train on another torchvision dataset instead (e.g. datasets.MNIST
# or datasets.ImageFolder), swap it in here and set CHANNELS_IMG to match.
dataset = PathsDataset(path="./data/map_64x64/", shape=(IMAGE_SIZE, IMAGE_SIZE), transform=transforms)
# Fail here with a clear message rather than later inside the training cell.
assert len(dataset) > 0, "No path files found in ./data/map_64x64/"

In [None]:
# Binary cross-entropy on the discriminator's sigmoid outputs, averaged over the batch.
criterion = nn.BCELoss(reduction="mean")

In [None]:
# Train one DCGAN per hyper-parameter combination; each run logs to its own
# TensorBoard directory (logs/run0, logs/run1, ...).
run_number = 0

for (lr_disc, lr_gen, batch_size, num_epochs, noise_dim, features_disc, features_gen) in product(*param_values):

    # The context manager closes (and flushes) the writer even if a run raises.
    with SummaryWriter(f"logs/run{run_number}") as writer:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        gen = Generator(noise_dim, CHANNELS_IMG, features_gen).to(device)
        disc = Discriminator(CHANNELS_IMG, features_disc).to(device)
        initialize_weights(gen)
        initialize_weights(disc)
        opt_gen = optim.Adam(gen.parameters(), lr=lr_gen, betas=(0.5, 0.999))
        opt_disc = optim.Adam(disc.parameters(), lr=lr_disc, betas=(0.5, 0.999))

        # The same 32 latent vectors are rendered at every image checkpoint.
        fixed_noise = torch.randn(32, noise_dim, 1, 1).to(device)

        gen.train()
        disc.train()

        # Log roughly MAX_DATA_POINTS loss values and MAX_IMG_DATA image grids per
        # run. max(1, ...) prevents `step % 0` (ZeroDivisionError) when a run has
        # fewer steps than that (small dataset or few epochs).
        total_steps = len(dataloader) * num_epochs
        loss_step_rate = max(1, round(total_steps / MAX_DATA_POINTS))
        img_step_rate = max(1, round(total_steps / MAX_IMG_DATA))

        step = 0  # global batch counter
        loss_step = 0  # x-axis for the loss scalars
        img_step = 0  # x-axis for the image grids

        for epoch in range(num_epochs):
            # The dataset yields images only (no labels): training is unsupervised.
            for batch_idx, real in enumerate(dataloader):
                real = real.to(device)
                # The final batch can be smaller than batch_size (drop_last=False),
                # so size the fake batch to match the real one.
                cur_batch_size = real.shape[0]
                noise = torch.randn(cur_batch_size, noise_dim, 1, 1).to(device)
                fake = gen(noise)

                ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
                disc_real = disc(real.float()).reshape(-1)
                loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
                # detach() stops this loss from back-propagating into the generator.
                disc_fake = disc(fake.detach()).reshape(-1)
                loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
                loss_disc = (loss_disc_real + loss_disc_fake) / 2
                disc.zero_grad()
                loss_disc.backward()
                opt_disc.step()

                ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
                output = disc(fake).reshape(-1)
                loss_gen = criterion(output, torch.ones_like(output))
                gen.zero_grad()
                loss_gen.backward()
                opt_gen.step()

                # Loss scalars to TensorBoard.
                if step % loss_step_rate == 0:
                    writer.add_scalar("Discriminator Loss", loss_disc.item(), loss_step)
                    writer.add_scalar("Generator Loss", loss_gen.item(), loss_step)
                    loss_step += 1

                # Progress line and image grids to TensorBoard.
                if step % img_step_rate == 0:
                    print(
                        f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} "
                        f"Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
                    )

                    with torch.no_grad():
                        # NOTE(review): gen is still in train mode here, so this
                        # forward pass also updates its BatchNorm running stats.
                        fake = gen(fixed_noise)
                        # Show at most batch_size examples of each kind.
                        img_grid_real = torchvision.utils.make_grid(
                            real[:batch_size], normalize=True
                        )
                        img_grid_fake = torchvision.utils.make_grid(
                            fake[:batch_size], normalize=True
                        )

                        writer.add_image("Real", img_grid_real, global_step=img_step)
                        writer.add_image("Fake", img_grid_fake, global_step=img_step)

                    img_step += 1

                step += 1

        # Save the run's hyper-parameters with the final generator loss as its metric.
        writer.add_hparams(
            {
                "features_gen": features_gen,
                "features_disc": features_disc,
                "noise_dim": noise_dim,
                "lr_gen": lr_gen,
                "lr_disc": lr_disc,
                "batch_size": batch_size,
                "epochs": num_epochs,
            },
            {"gen loss": loss_gen.item()},
            run_name=f"run{run_number}",
        )
    run_number += 1

In [None]:
# noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)
# fake = gen(noise)

In [None]:
# plt.imshow(fake.cpu().detach().numpy()[7][0])