In [None]:
import os
import numpy as np
import utils
import time
import torch
import torch.nn as nn
import torch.autograd as autograd

from IPython.display import Image as Display
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
# --- Training configuration ---------------------------------------------
epochs = 200        # number of full passes over the training dataloader
latent_dim = 100    # length of the noise vector z fed to the generator
n_class = 2         # number of class labels used for conditioning
lr = 2e-4           # Adam learning rate, shared by both optimizers
save_term = 500     # a preview grid is saved every `save_term` iterations

# Reference point for the elapsed-time readout printed after each epoch.
start_time = time.time()

In [None]:
# Preprocessing applied to every training image: resize to 64x64, collapse to
# a single grayscale channel, convert to a tensor, then normalize with
# mean 0.5 / std 0.5 (the generator ends in Tanh, so real and generated
# images end up on a comparable scale).
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.Grayscale(1),
    transforms.ToTensor(),
    # One-element tuples on purpose: `(0.5)` is just the float 0.5, whereas
    # Normalize documents per-channel *sequences* for mean and std.
    transforms.Normalize((0.5,), (0.5,)),
])

In [None]:
# Training images are read from <data_dir>/train, one sub-folder per class
# (ImageFolder derives the integer label from the sub-folder).
data_dir = './Facemask/'
train_dataset = datasets.ImageFolder(
    root=os.path.join(data_dir, 'train'),
    transform=transform,
)
# Shuffled mini-batches of 64 images, loaded by two worker processes.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

In [None]:
class Generator(nn.Module):
    """Conditional generator mapping (noise vector, class label) to an image.

    The label is embedded into a `class_num`-dimensional vector and
    concatenated with the noise vector. Two fully connected layers expand
    that into a 128-channel feature map whose side is `input_size // 4`,
    and two stride-2 transposed convolutions upsample it by a factor of 4.
    The final Tanh keeps every output value in [-1, 1].

    Args:
        input_dim: Length of the noise vector `z`.
        output_dim: Number of channels in the generated image.
        input_size: Target side length of the square output image; the
            produced side is `4 * (input_size // 4)`.
        class_num: Number of distinct class labels.
    """

    def __init__(self, input_dim=100, output_dim=1, input_size=64, class_num=2):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size
        self.class_num = class_num

        # Each label gets a learned vector of length `class_num`.
        self.embed = nn.Embedding(class_num, class_num)

        # Side length and flattened size of the feature map fed to `deconv`.
        side = input_size // 4
        flat_features = 128 * side * side

        self.fclayer = nn.Sequential(
            nn.Linear(input_dim + class_num, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, flat_features),
            nn.BatchNorm1d(flat_features),
            nn.ReLU(),
        )

        # Each transposed convolution (kernel 4, stride 2, padding 1) doubles
        # the spatial resolution.
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, output_dim, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, z, label):
        """Generate one image per (z, label) row.

        Args:
            z: Noise tensor whose last dimension is `input_dim`.
            label: Integer class indices, one per row of `z`.

        Returns:
            The output of the transposed-convolution stack, with values in
            [-1, 1].
        """
        side = self.input_size // 4
        conditioned = torch.cat((z, self.embed(label)), -1)
        features = self.fclayer(conditioned).view(-1, 128, side, side)
        return self.deconv(features)

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator scoring an (image, label) pair in [0, 1].

    The label is embedded into a full-resolution single-channel map, stacked
    onto the (single-channel) image as a second channel, and passed through
    five stride-2 convolution blocks followed by a small MLP that ends in a
    Sigmoid.

    Args:
        input_dim: Only used to size the (unused) `conv` stack; kept for
            interface compatibility.
        output_dim: Number of output scores per sample.
        input_size: Side length of the square input image.
        class_num: Number of distinct class labels.
    """

    def __init__(self, input_dim=100, output_dim=1, input_size=64, class_num=2):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size
        self.class_num = class_num
        # One learned input_size x input_size map per label (previously
        # hard-coded to 64 * 64, which ignored `input_size`).
        self.embed = nn.Embedding(self.class_num, 1 * self.input_size * self.input_size)

        def blocks(in_channels, out_channels, bn=True):
            """Return the layers of one stride-2 downsampling block."""
            block = [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25),
            ]
            if bn:
                # NOTE(review): the second positional argument of BatchNorm2d
                # is `eps`, so the original `BatchNorm2d(c, 0.8)` sets
                # eps=0.8. It may have been meant as momentum -- confirm.
                # Spelled out explicitly here; numerically unchanged.
                block.append(nn.BatchNorm2d(out_channels, eps=0.8))
            return block

        # Input has 2 channels: the image plus the embedded label map.
        self.conv_block = nn.Sequential(
            *blocks(2, 32, bn=False),
            *blocks(32, 64),
            *blocks(64, 128),
            *blocks(128, 256),
            *blocks(256, 512),
        )

        # NOTE(review): this stack is never called in forward(). It is kept
        # so the module's parameters / state_dict keys stay as they were;
        # remove it once no checkpoint depends on it.
        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim + self.class_num, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )

        # Each block (kernel 3, stride 2, padding 1) maps a side of n to
        # ceil(n / 2); after five blocks a 64-pixel side becomes 2.
        spatial = self.input_size
        for _ in range(5):
            spatial = (spatial + 1) // 2

        self.fclayer = nn.Sequential(
            nn.Linear(512 * spatial * spatial, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )

    def forward(self, image, label):
        """Score a batch of images against their labels.

        Args:
            image: Tensor of shape (batch, 1, input_size, input_size).
            label: Integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, output_dim) with values in [0, 1].
        """
        # Reshape the label embedding into an image-sized extra channel.
        label_map = self.embed(label).view(
            image.size(0), 1, self.input_size, self.input_size
        )
        x = torch.cat((image, label_map), 1)
        x = self.conv_block(x)
        x = x.view(x.size(0), -1)
        return self.fclayer(x)

In [None]:
def weights_init_normal(m):
    """Re-initialise one module in place; meant to be passed to `Module.apply`.

    Modules whose class name contains "Conv" get weights drawn from
    N(0, 0.02). Modules whose class name contains "BatchNorm2d" get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left
    untouched.

    Args:
        m: The module to (possibly) re-initialise.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if "BatchNorm2d" in layer_type:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [None]:
# Build both networks on the GPU.
# NOTE(review): these assignments rebind the class names `Generator` and
# `Discriminator` to instances, so this cell cannot be executed a second time
# without first re-running the class-definition cells.
Generator = Generator().cuda()
Discriminator = Discriminator().cuda()

# Overwrite the default initialisation of the Conv* / BatchNorm2d layers.
for net in (Generator, Discriminator):
    net.apply(weights_init_normal)

# Mean-squared-error objective applied to the discriminator's outputs.
adversarial_loss = nn.MSELoss().cuda()

# Both networks use Adam with identical settings.
adam_kwargs = dict(lr=lr, betas=(0.5, 0.999))
optimizer_G = torch.optim.Adam(Generator.parameters(), **adam_kwargs)
optimizer_D = torch.optim.Adam(Discriminator.parameters(), **adam_kwargs)

In [None]:
# Make sure the preview directory exists before the first grid is written.
os.makedirs("./results", exist_ok=True)

# Number of preview images generated per class for each saved grid.
samples_per_class = 8

for epoch in range(epochs):
    for i, (images, label) in enumerate(train_dataloader):
        batch_size = images.shape[0]
        # Discriminator targets: 1.0 for "real", 0.0 for "fake".
        real = torch.ones(batch_size, 1, device="cuda")
        fake = torch.zeros(batch_size, 1, device="cuda")
        real_imgs = images.cuda()
        labels = label.cuda()

        # ---- Generator step: try to make the discriminator output "real"
        # for generated images conditioned on random labels.
        optimizer_G.zero_grad()
        z = torch.normal(mean=0, std=1, size=(batch_size, latent_dim)).cuda()
        generated_labels = torch.randint(0, n_class, (batch_size,)).cuda()
        generated_imgs = Generator(z, generated_labels)
        g_loss = adversarial_loss(Discriminator(generated_imgs, generated_labels), real)
        g_loss.backward()
        optimizer_G.step()

        # ---- Discriminator step. zero_grad() also clears the gradients the
        # generator step left on the discriminator; detach() keeps this
        # backward pass from reaching the generator.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(Discriminator(real_imgs, labels), real)
        fake_loss = adversarial_loss(Discriminator(generated_imgs.detach(), generated_labels), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Global iteration counter across epochs.
        fin = epoch * len(train_dataloader) + i
        if fin % save_term == 0:
            # Preview grid: one row of `samples_per_class` images per class.
            # The noise batch must have exactly one row per label; the
            # previous code drew only 8 noise vectors for 8 * n_class labels,
            # which made the concatenation inside the generator fail.
            # Separate variable names keep the training batch intact.
            with torch.no_grad():
                sample_z = torch.normal(
                    mean=0, std=1, size=(n_class * samples_per_class, latent_dim)
                ).cuda()
                sample_labels = torch.LongTensor(
                    [c for c in range(n_class) for _ in range(samples_per_class)]
                ).cuda()
                sample_imgs = Generator(sample_z, sample_labels)
            save_image(sample_imgs, f"./results/{fin}.png", nrow=samples_per_class, normalize=True)

    # Losses shown are those of the last batch of the epoch.
    print(f"[Epoch {epoch}/{epochs}] [D loss: {d_loss.item():.3f}] [G loss: {g_loss.item():.3f}] [Elapsed time: {time.time() - start_time:.2f}s]")

In [None]:
# Persist the weights of both networks (state_dict only, not the modules).
for net, checkpoint_path in ((Generator, "Gen_save.pt"), (Discriminator, "Disc_save.pt")):
    torch.save(net.state_dict(), checkpoint_path)

In [None]:
# Generate 100 images conditioned on label 0 and save the first 50 as a
# 10-per-row grid.
# NOTE(review): the file name assumes label 0 is the "with mask" class --
# confirm against the dataset's folder-to-index mapping.
os.makedirs('./results', exist_ok=True)

# Inference only: no_grad() avoids building an autograd graph for the batch.
with torch.no_grad():
    z = torch.normal(mean=0, std=1, size=(100, latent_dim)).cuda()
    # int64 indices: the index dtype nn.Embedding accepts on every PyTorch
    # version (replaces the legacy torch.cuda.IntTensor constructor).
    generated_labels_0 = torch.zeros(100, dtype=torch.long, device='cuda')
    gen_result_0 = Generator(z, generated_labels_0)

save_image(gen_result_0[:50], './results/with_mask.png', nrow=10, normalize=True)

In [None]:
# Generate 100 images conditioned on label 1 and save the first 50 as a
# 10-per-row grid.
# NOTE(review): the file name assumes label 1 is the "no mask" class --
# confirm against the dataset's folder-to-index mapping.
os.makedirs('./results', exist_ok=True)

# Inference only: no_grad() avoids building an autograd graph for the batch.
with torch.no_grad():
    z = torch.normal(mean=0, std=1, size=(100, latent_dim)).cuda()
    # int64 indices: the index dtype nn.Embedding accepts on every PyTorch
    # version (replaces the legacy torch.cuda.IntTensor constructor).
    generated_labels_1 = torch.ones(100, dtype=torch.long, device='cuda')
    gen_result_1 = Generator(z, generated_labels_1)

save_image(gen_result_1[:50], './results/with_no_mask.png', nrow=10, normalize=True)

In [None]:
# Export individual generated images: 10 rounds per class, 50 files per
# round, numbered 0..499 inside each class folder.
for out_dir in ('./results/with_mask', './results/without_mask'):
    os.makedirs(out_dir, exist_ok=True)

for i in range(10):
    # Inference only: no_grad() avoids building an autograd graph.
    with torch.no_grad():
        z = torch.normal(mean=0, std=1, size=(100, latent_dim)).cuda()
        generated_labels_0 = torch.zeros(100, dtype=torch.long, device='cuda')
        gen_result_0 = Generator(z, generated_labels_0)

    # Only the first 50 images of each 100-image batch are written.
    for j in range(50):
        save_image(gen_result_0[j], f'./results/with_mask/{i * 50 + j}.png', normalize=True)

for i in range(10):
    with torch.no_grad():
        z = torch.normal(mean=0, std=1, size=(100, latent_dim)).cuda()
        generated_labels_1 = torch.ones(100, dtype=torch.long, device='cuda')
        gen_result_1 = Generator(z, generated_labels_1)

    # Save 50 per round so the `i * 50 + j` numbering never collides. The
    # previous range(100) made consecutive rounds overwrite each other's
    # files (round i wrote i*50 .. i*50+99).
    for j in range(50):
        save_image(gen_result_1[j], f'./results/without_mask/{i * 50 + j}.png', normalize=True)