In [None]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms.transforms import Resize
from tqdm import tqdm

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Images per training batch; also the number of noise vectors sampled per step.
BATCH = 32
# Directory prefix under which the generator/critic checkpoints are stored.
PATH = "./model/"

In [None]:
# Image pipeline: tensor conversion, shorter side scaled to 128 px, centre crop
# to exactly 128x128, then each channel mapped from [0, 1] to [-1, 1] (the
# range of the generator's Tanh output).
# The centre crop matters: Resize(128) alone keeps the aspect ratio, so a
# non-square image would not be 128x128 and the critic's final 4x4 convolution
# would no longer reduce it to a single score. For square images the crop is a
# no-op.
DATASET = torchvision.datasets.ImageFolder(
    "data",
    transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]),
)

# drop_last=True keeps every batch at exactly BATCH images.
# Pinned host memory is only useful for host-to-GPU copies, so it is requested
# only when CUDA is available.
DATALOADER = DataLoader(
    dataset=DATASET,
    batch_size=BATCH,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
    num_workers=6,
    drop_last=True,
)

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator: (N, 100, 1, 1) noise -> (N, 3, 128, 128) image.

    The output passes through Tanh, so pixel values lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Project the 1x1 latent vector onto a 4x4 feature map.
        layers = [
            nn.ConvTranspose2d(100, 1024, kernel_size=4, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(),
        ]

        # Each stage doubles the resolution: 4 -> 8 -> 16 -> 32 -> 64.
        widths = (1024, 512, 256, 128, 64)
        for n_in, n_out in zip(widths, widths[1:]):
            layers.extend((
                nn.ConvTranspose2d(n_in, n_out, kernel_size=4, stride=2,
                                   padding=1, bias=False),
                nn.BatchNorm2d(n_out),
                nn.LeakyReLU(),
            ))

        # Final upsampling to 128x128 RGB (this layer keeps its bias).
        layers.append(nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1))
        layers.append(nn.Tanh())

        self.seq = nn.Sequential(*layers)

    def forward(self, value):
        """Map a batch of latent vectors to a batch of images."""
        return self.seq(value)


class Critic(nn.Module):
    """Convolutional critic: (N, 3, 128, 128) image -> (N, 1, 1, 1) unbounded score.

    There is no final activation; the raw convolution output is the score.
    """

    def __init__(self):
        super().__init__()

        # Each stage halves the resolution: 128 -> 64 -> 32 -> 16 -> 8 -> 4.
        widths = (3, 64, 128, 256, 512, 1024)
        layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            layers.extend((
                nn.Conv2d(n_in, n_out, kernel_size=4, stride=2, padding=1,
                          bias=False),
                nn.BatchNorm2d(n_out),
                nn.LeakyReLU(),
            ))

        # Collapse the 4x4 map into one score per image (this layer keeps its bias).
        layers.append(nn.Conv2d(1024, 1, kernel_size=4))

        self.seq = nn.Sequential(*layers)

    def forward(self, value):
        """Score a batch of images."""
        return self.seq(value)


In [None]:
def initialize_weights(model):
    """Re-initialise the conv and batch-norm parameters of `model` in place.

    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02).
    BatchNorm2d weights are drawn from N(1, 0.02) and their biases set to 0.
    Convolution biases and every other module type are left untouched.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    for layer in model.modules():
        if isinstance(layer, conv_types):
            nn.init.normal_(layer.weight.data, mean=0.0, std=0.02)
        if isinstance(layer, nn.BatchNorm2d):
            nn.init.normal_(layer.weight.data, mean=1.0, std=0.02)
            nn.init.zeros_(layer.bias.data)

def _load_or_init(model, checkpoint_path):
    """Return the module stored at `checkpoint_path`, else `model` freshly initialised.

    A checkpoint is used only when the file exists and holds an instance of the
    same class as `model`; anything else falls back to `initialize_weights`.
    """
    if os.path.isfile(checkpoint_path):
        # Checkpoints are whole pickled modules; only load files you trust.
        # NOTE(review): on PyTorch releases where torch.load defaults to
        # weights_only=True this call may need weights_only=False - confirm
        # against the installed version.
        restored = torch.load(checkpoint_path, map_location=DEVICE)
        if isinstance(restored, type(model)):
            return restored
        print(f"ignoring {checkpoint_path}: holds {type(restored).__name__}, "
              f"expected {type(model).__name__}")
    model.apply(initialize_weights)
    return model


def train(dataloader, epoch, n_critic=5, clip_value=0.01, lr=0.00005):
    """Train the generator/critic pair with the Wasserstein loss and weight clipping.

    Both networks are resumed from PATH + "Gnet.pth" / PATH + "Cnet.pth" when
    those checkpoints exist and hold the expected class; otherwise they start
    from `initialize_weights`. After every epoch both networks are written back
    to those files.

    Args:
        dataloader: iterable yielding `(image_batch, label)` pairs; labels are ignored.
        epoch: number of passes over `dataloader`.
        n_critic: critic updates performed per generator update.
        clip_value: critic parameters are clamped to [-clip_value, clip_value]
            after every critic update.
        lr: RMSprop learning rate used for both networks.
    """
    critic = _load_or_init(Critic(), PATH + "Cnet.pth")
    critic.train()
    critic.to(device=DEVICE)
    print(critic)

    generator = _load_or_init(Generator(), PATH + "Gnet.pth")
    generator.train()
    generator.to(device=DEVICE)
    print(generator)

    optim_critic = optim.RMSprop(critic.parameters(), lr=lr)
    optim_generator = optim.RMSprop(generator.parameters(), lr=lr)

    # Fixed latent batch so the logged preview grids are comparable across steps.
    static_noise = torch.randn((BATCH, 100, 1, 1), device=DEVICE)
    tensorboard_step = 0

    # Make sure the checkpoint directory exists before the first save.
    os.makedirs(PATH, exist_ok=True)
    writer = SummaryWriter("runs/GAN/test")

    ld = []  # critic losses since the last TensorBoard log
    lg = []  # generator losses since the last TensorBoard log
    try:
        for epoch_index in range(epoch):
            print(f"epoch: {epoch_index+1}")
            for batch_index, (image, _) in enumerate(tqdm(dataloader), start=1):
                image = image.to(device=DEVICE)

                # train critic
                for _ in range(n_critic):
                    critic_noise = torch.randn((BATCH, 100, 1, 1), device=DEVICE)
                    # Detach: the critic update needs no gradients through the
                    # generator, so its graph does not have to be kept alive.
                    critic_fake = generator(critic_noise).detach()
                    critic_real = critic(image).reshape(-1)
                    critic_output = critic(critic_fake).reshape(-1)

                    # wasserstein loss
                    loss_critic = -(torch.mean(critic_real) -
                                    torch.mean(critic_output))

                    ld.append(loss_critic.item())
                    for param in critic.parameters():
                        param.grad = None
                    loss_critic.backward()
                    optim_critic.step()

                    # weight clipping
                    for p in critic.parameters():
                        p.data.clamp_(-clip_value, clip_value)

                # train generator
                generator_noise = torch.randn((BATCH, 100, 1, 1), device=DEVICE)
                generator_fake = generator(generator_noise)
                generator_output = critic(generator_fake).reshape(-1)
                loss_generator = -torch.mean(generator_output)
                lg.append(loss_generator.item())
                for param in generator.parameters():
                    param.grad = None
                loss_generator.backward()
                optim_generator.step()

                # Log on the first batch of every epoch and every 50th batch after.
                if batch_index % 50 == 0 or batch_index == 1:
                    logToTensorboard(ld, lg, static_noise, writer,
                                     tensorboard_step, generator, image)
                    tensorboard_step += 1
                    ld = []
                    lg = []

            print("----------------\n\rsaving...")
            torch.save(generator, PATH + "Gnet.pth")
            torch.save(critic, PATH + "Cnet.pth")
            writer.flush()
            print("saved\n\r----------------")
    finally:
        # Close (and thereby flush) the writer even if training is interrupted.
        writer.close()

def logToTensorboard(ld, lg, static_noise, writer, step, generator, real):
    """Write mean losses plus generated and real image grids to TensorBoard.

    Args:
        ld: critic loss values to average.
        lg: generator loss values to average.
        static_noise: latent batch fed to `generator` for the preview grid.
        writer: SummaryWriter receiving the scalars and images.
        step: global step under which everything is recorded.
        generator: network used to render the preview images.
        real: batch of real images shown alongside the generated ones.
    """
    with torch.no_grad():
        mean_losses = {
            'loss_Critic': torch.tensor(ld).mean().item(),
            'loss_Generator': torch.tensor(lg).mean().item(),
        }
        writer.add_scalars('current_run', mean_losses, global_step=step)

        # At most BATCH images go into each grid.
        generated = generator(static_noise).cpu()
        generated_grid = torchvision.utils.make_grid(generated[:BATCH],
                                                     normalize=True)
        writer.add_image("TestImage", generated_grid, global_step=step)

        reference_grid = torchvision.utils.make_grid(real[:BATCH],
                                                     normalize=True)
        writer.add_image("RealImage", reference_grid, global_step=step)


In [None]:
def start(epochs=50):
    """Train on the module-level DATALOADER.

    Args:
        epochs: number of passes over the dataset; defaults to 50.
    """
    train(DATALOADER, epochs)


# Start training only when this code runs as the main module.
if __name__ == "__main__":
    start()