In [60]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import glob
import os
from PIL import Image
from torch.utils import data
from torchvision import transforms
import torch.nn.functional as F
import torchvision

In [None]:
# Dataset: Anime Faces (Kaggle) — https://www.kaggle.com/datasets/soumikrakshit/anime-faces

In [61]:
# glob returns paths in arbitrary, filesystem-dependent order; sort them so the
# dataset (and any positional preview of it) is reproducible across runs/machines.
imgs_path = sorted(glob.glob('../input/anime-faces/data/*.png'))

In [62]:
# Quick sanity check on the raw files: print each array's shape and draw the
# first six faces in a 2x3 grid.
plt.figure(figsize=(12, 8))
for i, img_path in enumerate(imgs_path[:6]):
    img = np.array(Image.open(img_path))
    print(img.shape)
    plt.subplot(2, 3, i + 1)
    plt.imshow(img)

In [63]:
# Preprocessing pipeline used by the dataset:
#   ToTensor  -> float tensor in [0, 1]
#   Normalize -> (x - 0.5) / 0.5, i.e. rescaled to [-1, 1] to match the generator's tanh output
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=0.5, std=0.5)]
)

In [64]:
class Face_dataset(data.Dataset):
    """Map-style dataset over a list of image file paths.

    Args:
        imgs_path: sequence of image file paths; item ``index`` loads ``imgs_path[index]``.
        transform: optional callable applied to the loaded PIL image. When None
            (the default), the notebook-level ``transform`` pipeline is used.
    """

    def __init__(self, imgs_path, transform=None):
        self.imgs_path = imgs_path
        self.transform = transform

    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        # Close the file handle deterministically; convert() loads the pixels into a
        # new in-memory image, so it stays valid after the file is closed. Forcing RGB
        # guarantees 3 channels even if a stray RGBA/grayscale file is in the folder,
        # which would otherwise break batching and the 3-channel discriminator.
        with Image.open(img_path) as pil_img:
            rgb_img = pil_img.convert('RGB')
        # Resolved lazily so the module-level pipeline can be (re)defined after this class.
        pipeline = self.transform if self.transform is not None else transform
        return pipeline(rgb_img)

    def __len__(self):
        return len(self.imgs_path)

In [65]:
dataset = Face_dataset(imgs_path=imgs_path)   # one item per image file found by glob

In [66]:
BATCHSIZE = 32
# Reshuffle every epoch so each minibatch sees a fresh mix of faces.
dataloader = data.DataLoader(
    dataset,
    batch_size=BATCHSIZE,
    shuffle=True,
)

In [67]:
# Pull one batch of training data and draw its first six images.
imgs_batch = next(iter(dataloader))
plt.figure(figsize=(12, 8))
for i, img in enumerate(imgs_batch[:6]):
    img = img.permute(1, 2, 0).numpy()   # (C, H, W) -> (H, W, C) for imshow
    img = (img + 1) / 2                  # undo Normalize: [-1, 1] -> [0, 1]
    plt.subplot(2, 3, i + 1)
    plt.imshow(img)

In [68]:
# Generator
class Generator(nn.Module):
    """Maps a 100-d noise vector to a 3x64x64 image with values in [-1, 1] (tanh).

    Per-sample shape flow:
        100 -> 256*16*16 -> (256, 16, 16) -> (128, 16, 16) -> (64, 32, 32) -> (3, 64, 64)
    """

    def __init__(self):
        super().__init__()
        # Submodule names/order are kept stable so state_dict keys stay compatible.
        self.linear1 = nn.Linear(100, 256 * 16 * 16)
        self.bn1 = nn.BatchNorm1d(256 * 16 * 16)
        # k=3, stride 1, padding 1: spatial size unchanged (16 -> 16)
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(3, 3), padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        # k=4, stride 2, padding 1: doubles spatial size (16 -> 32)
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        # k=4, stride 2, padding 1: 32 -> 64, projected to 3 channels
        self.deconv3 = nn.ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=2, padding=1)

    def forward(self, x):
        # NOTE(review): ReLU is applied before BatchNorm at each stage; kept as-is.
        h = self.bn1(F.relu(self.linear1(x)))
        h = h.view(-1, 256, 16, 16)
        h = self.bn2(F.relu(self.deconv1(h)))
        h = self.bn3(F.relu(self.deconv2(h)))
        return torch.tanh(self.deconv3(h))

In [69]:
# Discriminator
class Discriminator(nn.Module):
    """Real/fake classifier for 3x64x64 images; returns a probability in (0, 1), shape (N, 1).

    Spatial flow: 64 -> 31 (conv1: k=3, stride 2, no padding) -> 15 (conv2: same),
    which is why the final linear layer takes 128*15*15 features.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, 2)
        self.conv2 = nn.Conv2d(64, 128, 3, 2)
        self.bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128*15*15, 1)

    def forward(self, x):
        # F.dropout2d defaults to training=True; pass the module's mode so that
        # .eval() actually disables dropout (as nn.Dropout2d would).
        x = F.dropout2d(F.leaky_relu(self.conv1(x)), p=0.5, training=self.training)
        x = F.dropout2d(F.leaky_relu(self.conv2(x)), p=0.5, training=self.training)
        x = self.bn(x)
        # Flatten per sample: a wrongly sized input now fails loudly in self.fc instead
        # of being silently re-split into a different batch size by view(-1, ...).
        x = x.flatten(1)
        return torch.sigmoid(self.fc(x))

In [70]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# Generator is built first, then the discriminator (same init order as before).
gen, dis = Generator().to(device), Discriminator().to(device)
loss_fn = torch.nn.BCELoss()   # binary cross-entropy on the discriminator's sigmoid output
# The discriminator's learning rate is 10x smaller than the generator's.
d_optimizer = torch.optim.Adam(dis.parameters(), lr=1e-5)
g_optimizer = torch.optim.Adam(gen.parameters(), lr=1e-4)
# Visualization helper
def generate_and_save_images(model, epoch, test_input):
    """Display one row of images produced by ``model`` from the noise batch ``test_input``.

    Args:
        model: generator module; its output is expected to be NCHW with values in
            [-1, 1] (as produced by a tanh output layer).
        epoch: current epoch number; only used by the (disabled) savefig line.
        test_input: noise batch passed straight to ``model``; one image per row.
    """
    # no_grad so .numpy() works even when the caller has autograd enabled.
    with torch.no_grad():
        predictions = model(test_input).permute(0, 2, 3, 1).cpu().numpy()
    n_images = predictions.shape[0]
    # One row, one column per image, sized to the number of images.
    plt.figure(figsize=(2.5 * n_images, 2.5))
    for i in range(n_images):
        plt.subplot(1, n_images, i + 1)
        # Map [-1, 1] back to [0, 1] for imshow; clip guards against any overshoot.
        plt.imshow(np.clip((predictions[i] + 1) / 2, 0, 1))
        plt.axis('off')
    # plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
# Fixed noise used for the per-epoch previews, so samples are comparable across epochs:
# NUM_PREVIEW standard-normal latent vectors of length 100 -> NUM_PREVIEW preview images.
NUM_PREVIEW = 8
test_seed = torch.randn(NUM_PREVIEW, 100, device=device)

In [75]:
D_loss = []                          # per-epoch average discriminator loss
G_loss = []                          # per-epoch average generator loss

# Training loop
for epoch in range(100):
    D_epoch_loss = 0
    G_epoch_loss = 0
    # Each value accumulated below is already a per-batch mean (BCELoss default
    # reduction='mean'), so average over the number of batches, not samples.
    count = len(dataloader)
    for img in dataloader:
        img = img.to(device)
        size = img.shape[0]
        random_seed = torch.randn(size, 100, device=device)   # generator noise input

        # --- Discriminator step: real images -> label 1, generated images -> label 0 ---
        d_optimizer.zero_grad()
        real_output = dis(img)
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output))
        d_real_loss.backward()
        generated_img = gen(random_seed)
        # detach() so the discriminator loss does not backpropagate into the generator
        fake_output = dis(generated_img.detach())
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))
        d_fake_loss.backward()

        disc_loss = d_real_loss + d_fake_loss      # total discriminator loss (logging only)
        d_optimizer.step()

        # --- Generator step: make the just-updated discriminator label fakes as real ---
        g_optimizer.zero_grad()
        fake_output = dis(generated_img)
        gen_loss = loss_fn(fake_output, torch.ones_like(fake_output))
        # This backward also writes gradients into dis's parameters; they are cleared by
        # d_optimizer.zero_grad() at the start of the next iteration.
        gen_loss.backward()
        g_optimizer.step()

        D_epoch_loss += disc_loss.item()
        G_epoch_loss += gen_loss.item()

    D_epoch_loss /= count
    G_epoch_loss /= count
    D_loss.append(D_epoch_loss)
    G_loss.append(G_epoch_loss)
    # End of epoch: report progress and show samples generated from the fixed test_seed.
    print("Epoch:", epoch)
    # The preview forward pass needs no autograd graph.
    with torch.no_grad():
        generate_and_save_images(gen, epoch, test_seed)

In [74]:
# Per-epoch average losses recorded during training; each curve uses its own length.
fig, ax = plt.subplots()
ax.plot(range(1, len(D_loss) + 1), D_loss, label='D_loss')
ax.plot(range(1, len(G_loss) + 1), G_loss, label='G_loss')
ax.set(xlabel='epoch', ylabel='loss', title='GAN training losses')
ax.legend()
plt.show()