In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import os
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [32]:
dir_path = "/root/autodl-tmp/hw2/"

class PairedData(Dataset):
    def __init__(self, phase):
        super(PairedData, self).__init__() 
        self.img_path_list = self.load_img_data(phase)     # 获取数据列表
        self.label_path_list = self.load_label_data(phase)  # 获取label列表
        self.num_samples = len(self.img_path_list)                 # 数据量

    def __getitem__(self, idx):
        face = cv2.imread(self.img_path_list[idx])     # 读取数据
        face = cv2.resize(face, (256,256))
        face = face.astype('float32') / 127.5 - 1.     # 归一化
        face = face.transpose(2, 0, 1)                 # HWC -> CHW

        cartoon = cv2.imread(self.label_path_list[idx])      # 读取数据
        cartoon = cartoon.astype('float32') / 127.5 - 1.     # 归一化
        cartoon = cartoon.transpose(2, 0, 1)                 # HWC -> CHW
        
        return torch.from_numpy(face), torch.from_numpy(cartoon)

    def __len__(self):
        return self.num_samples

    @staticmethod
    def load_img_data(phase):
        assert phase in ['train', 'test'], "phase1 should be set within ['train', 'test']"
        data_path = dir_path + "datasets/" + phase + "A/"
        all_list =  [os.path.join(data_path, x) for x in os.listdir(data_path)]
        print(all_list[:5])

        if phase == 'train':
            return all_list[:1500]
        else:
            return all_list[:150]
        
    @staticmethod
    def load_label_data(phase):
        assert phase in ['train', 'test'], "phase1 should be set within ['train', 'test']"
        data_path = dir_path + "datasets/" + phase + "B/"
        all_list =  [os.path.join(data_path, x) for x in os.listdir(data_path)]
        print(all_list[:5])

        if phase == 'train':
            return all_list[:1500]
        else:
            return all_list[:150]

In [33]:
paired_dataset_train = PairedData('train')
paired_dataset_test = PairedData('test')

['/root/autodl-tmp/hw2/datasets/trainA/0.png', '/root/autodl-tmp/hw2/datasets/trainA/10.png', '/root/autodl-tmp/hw2/datasets/trainA/100.png', '/root/autodl-tmp/hw2/datasets/trainA/1000.png', '/root/autodl-tmp/hw2/datasets/trainA/1001.png']
['/root/autodl-tmp/hw2/datasets/trainB/10.png', '/root/autodl-tmp/hw2/datasets/trainB/0.png', '/root/autodl-tmp/hw2/datasets/trainB/100.png', '/root/autodl-tmp/hw2/datasets/trainB/1000.png', '/root/autodl-tmp/hw2/datasets/trainB/1001.png']
['/root/autodl-tmp/hw2/datasets/testA/5002.png', '/root/autodl-tmp/hw2/datasets/testA/5001.png', '/root/autodl-tmp/hw2/datasets/testA/5003.png', '/root/autodl-tmp/hw2/datasets/testA/5004.png', '/root/autodl-tmp/hw2/datasets/testA/5005.png']
['/root/autodl-tmp/hw2/datasets/testB/5001.png', '/root/autodl-tmp/hw2/datasets/testB/5002.png', '/root/autodl-tmp/hw2/datasets/testB/5003.png', '/root/autodl-tmp/hw2/datasets/testB/5004.png', '/root/autodl-tmp/hw2/datasets/testB/5005.png']


In [34]:
class Downsample(nn.Module):
    # LeakyReLU => conv => batch norm
    def __init__(self, in_dim, out_dim, kernel_size=4, stride=2, padding=1):
        super(Downsample, self).__init__()

        self.layers = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_dim)
        )

    def forward(self, x):
        x = self.layers(x)
        return x


class Upsample(nn.Module):
    # ReLU => deconv => batch norm => dropout
    def __init__(self, in_dim, out_dim, kernel_size=4, stride=2, padding=1, use_dropout=False):
        super(Upsample, self).__init__()

        sequence = [
            nn.ReLU(),
            nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_dim)
        ]

        if use_dropout:
            sequence.append(nn.Dropout(p=0.5))

        self.layers = nn.Sequential(*sequence)

    def forward(self, x, skip):
        x = self.layers(x)
        x = torch.cat([x, skip], dim=1)
        return x

In [35]:
class UnetGenerator(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, ngf=64):
        super(UnetGenerator, self).__init__()

        self.down1 = nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1)
        self.down2 = Downsample(ngf, ngf*2)
        self.down3 = Downsample(ngf*2, ngf*4)
        self.down4 = Downsample(ngf*4, ngf*8)
        self.down5 = Downsample(ngf*8, ngf*8)
        self.down6 = Downsample(ngf*8, ngf*8)
        self.down7 = Downsample(ngf*8, ngf*8)

        self.center = Downsample(ngf*8, ngf*8)

        self.up7 = Upsample(ngf*8, ngf*8, use_dropout=True)
        self.up6 = Upsample(ngf*8*2, ngf*8, use_dropout=True)
        self.up5 = Upsample(ngf*8*2, ngf*8, use_dropout=True)
        self.up4 = Upsample(ngf*8*2, ngf*8)
        self.up3 = Upsample(ngf*8*2, ngf*4)
        self.up2 = Upsample(ngf*4*2, ngf*2)
        self.up1 = Upsample(ngf*2*2, ngf)

        self.output_block = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(ngf*2, output_nc, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        
        c = self.center(d7)
        
        x = self.up7(c, d7)
        x = self.up6(x, d6)
        x = self.up5(x, d5)
        x = self.up4(x, d4)
        x = self.up3(x, d3)
        x = self.up2(x, d2)
        x = self.up1(x, d1)

        x = self.output_block(x)
        return x

In [36]:
class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc=6, ndf=64):
        super(NLayerDiscriminator, self).__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1), 
            nn.LeakyReLU(0.2),
            
            ConvBlock(ndf, ndf*2),
            ConvBlock(ndf*2, ndf*4),
            ConvBlock(ndf*4, ndf*8, stride=1),

            nn.Conv2d(ndf*8, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.layers(input)


class ConvBlock(nn.Module):
    # conv => batch norm => LeakyReLU
    def __init__(self, in_dim, out_dim, kernel_size=4, stride=2, padding=1):
        super(ConvBlock, self).__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_dim),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x):
        x = self.layers(x)
        return x

In [37]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = UnetGenerator().to(device)
discriminator = NLayerDiscriminator().to(device)

In [38]:
from torch import nn, optim

# 超参数
LR = 1e-4
BATCH_SIZE = 8
EPOCHS = 100

# 优化器
optimizerG = optim.Adam(
    generator.parameters(),
    lr=LR,
    betas=(0.5, 0.999)
)

optimizerD = optim.Adam(
    discriminator.parameters(),
    lr=LR,
    betas=(0.5, 0.999)
)

# 损失函数
bce_loss = nn.BCELoss()
l1_loss = nn.L1Loss()

# dataloader
data_loader_train = DataLoader(
    paired_dataset_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True
)

data_loader_test = DataLoader(
    paired_dataset_test,
    batch_size=BATCH_SIZE
)

In [39]:
out = generator(torch.ones([8, 3, 256, 256]).to(device))
print('生成器输出尺寸：', out.shape)

out = discriminator(torch.ones([8, 6, 256, 256]).to(device))
print('鉴别器输出尺寸：', out.shape)

生成器输出尺寸： torch.Size([8, 3, 256, 256])
鉴别器输出尺寸： torch.Size([8, 1, 30, 30])


In [40]:
results_save_path = './results/'
if os.path.exists(results_save_path):
    os.makedirs(results_save_path, exist_ok=True)  # 保存每个epoch的测试结果

weights_save_path = './checkpoints/'
if os.path.exists(weights_save_path):
    os.makedirs(weights_save_path, exist_ok=True)  # 保存模型

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [41]:
EPOCHS = 50
from torch.utils.tensorboard import SummaryWriter
# 实例化SummaryWriter对象, 用tensorboard记录实验结果
writer = SummaryWriter(log_dir='/root/tf-logs/')

In [42]:
for epoch in range(EPOCHS):
    for real_A, real_B in tqdm(data_loader_train):
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        optimizerD.zero_grad()

        # D(real)
        real_AB = torch.cat([real_A, real_B], dim=1)
        d_real_predict = discriminator(real_AB)
        d_real_loss = bce_loss(d_real_predict, torch.ones_like(d_real_predict))

        # D(fake)
        fake_B = generator(real_A).detach()
        fake_AB = torch.cat([real_A, fake_B], dim=1)
        d_fake_predict = discriminator(fake_AB)
        d_fake_loss = bce_loss(d_fake_predict, torch.zeros_like(d_fake_predict))

        # train D
        d_loss = (d_real_loss + d_fake_loss) / 2.
        d_loss.backward()
        optimizerD.step()

        optimizerG.zero_grad()

        # D(fake)
        fake_B = generator(real_A)
        fake_AB = torch.cat([real_A, fake_B], dim=1)
        g_fake_predict = discriminator(fake_AB)
        g_bce_loss = bce_loss(g_fake_predict, torch.ones_like(g_fake_predict))
        g_l1_loss = l1_loss(fake_B, real_B) * 100.
        g_loss = g_bce_loss + g_l1_loss

        # train G
        g_loss.backward()
        optimizerG.step()

        # 记录训练和验证损失
        writer.add_scalar('Loss/D', d_loss.item(), global_step = epoch)
        writer.add_scalar('Loss/G', g_loss.item(), global_step = epoch)

    print(f'Epoch [{epoch+1}/{EPOCHS}] Loss D: {d_loss.item()}, Loss G: {g_loss.item()}')

    if (epoch+1) % 10 == 0:
        torch.save(generator.state_dict(), os.path.join(weights_save_path, 'epoch'+str(epoch+1).zfill(3)+'.pt'))

        # test
        generator.eval()
        with torch.no_grad():
            for real_A, real_B in data_loader_test:
                real_A = real_A.to(device)
                real_B = real_B.to(device)
                break

            fake_B = generator(real_A)
            result = torch.cat([real_A[:3], real_B[:3], fake_B[:3]], dim=3)

            result = result.detach().cpu().numpy().transpose(0, 2, 3, 1)
            result = result * 127.5 + 127.5
            
            result = np.vstack(result)
            result = (result * 255).astype(np.uint8)
            
        cv2.imwrite(os.path.join(results_save_path, 'epoch'+str(epoch+1).zfill(3)+'.png'), result)

        generator.train()

100%|██████████| 187/187 [01:04<00:00,  2.91it/s]


Epoch [1/50] Loss D: 0.5027278661727905, Loss G: 16.82379722595215


100%|██████████| 187/187 [01:03<00:00,  2.94it/s]


Epoch [2/50] Loss D: 0.558770477771759, Loss G: 15.72962760925293


100%|██████████| 187/187 [01:03<00:00,  2.92it/s]


Epoch [3/50] Loss D: 0.48727524280548096, Loss G: 17.980648040771484


100%|██████████| 187/187 [01:03<00:00,  2.92it/s]


Epoch [4/50] Loss D: 0.6101694107055664, Loss G: 12.626778602600098


100%|██████████| 187/187 [01:03<00:00,  2.95it/s]


Epoch [5/50] Loss D: 0.44258081912994385, Loss G: 15.544577598571777


100%|██████████| 187/187 [01:04<00:00,  2.92it/s]


Epoch [6/50] Loss D: 0.663121223449707, Loss G: 13.911996841430664


100%|██████████| 187/187 [01:03<00:00,  2.93it/s]


Epoch [7/50] Loss D: 0.5498199462890625, Loss G: 16.43813705444336


100%|██████████| 187/187 [01:03<00:00,  2.96it/s]


Epoch [8/50] Loss D: 0.49890756607055664, Loss G: 13.836112022399902


100%|██████████| 187/187 [01:03<00:00,  2.94it/s]


Epoch [9/50] Loss D: 0.46903884410858154, Loss G: 16.059768676757812


100%|██████████| 187/187 [01:03<00:00,  2.96it/s]


Epoch [10/50] Loss D: 0.5181180238723755, Loss G: 12.167506217956543


100%|██████████| 187/187 [01:03<00:00,  2.95it/s]


Epoch [11/50] Loss D: 0.6204447746276855, Loss G: 13.727514266967773


100%|██████████| 187/187 [01:03<00:00,  2.94it/s]


Epoch [12/50] Loss D: 0.7356151342391968, Loss G: 10.575288772583008


100%|██████████| 187/187 [01:04<00:00,  2.92it/s]


Epoch [13/50] Loss D: 0.3466085195541382, Loss G: 13.238426208496094


100%|██████████| 187/187 [01:05<00:00,  2.88it/s]


Epoch [14/50] Loss D: 0.4446236193180084, Loss G: 14.471049308776855


100%|██████████| 187/187 [01:04<00:00,  2.91it/s]


Epoch [15/50] Loss D: 0.3817664384841919, Loss G: 12.214873313903809


100%|██████████| 187/187 [01:03<00:00,  2.92it/s]


Epoch [16/50] Loss D: 0.3523368835449219, Loss G: 12.835640907287598


100%|██████████| 187/187 [01:03<00:00,  2.93it/s]


Epoch [17/50] Loss D: 0.5337226390838623, Loss G: 12.865813255310059


100%|██████████| 187/187 [01:03<00:00,  2.94it/s]


Epoch [18/50] Loss D: 0.8136641383171082, Loss G: 16.73992347717285


100%|██████████| 187/187 [01:02<00:00,  2.97it/s]


Epoch [19/50] Loss D: 0.34980088472366333, Loss G: 15.575443267822266


100%|██████████| 187/187 [01:03<00:00,  2.95it/s]


Epoch [20/50] Loss D: 0.3340662121772766, Loss G: 12.588224411010742


100%|██████████| 187/187 [01:03<00:00,  2.96it/s]


Epoch [21/50] Loss D: 0.1475202590227127, Loss G: 15.593640327453613


100%|██████████| 187/187 [01:03<00:00,  2.94it/s]


Epoch [22/50] Loss D: 0.3692028820514679, Loss G: 15.470340728759766


100%|██████████| 187/187 [01:03<00:00,  2.95it/s]


Epoch [23/50] Loss D: 0.4652760624885559, Loss G: 16.208431243896484


100%|██████████| 187/187 [01:03<00:00,  2.97it/s]


Epoch [24/50] Loss D: 0.3089368939399719, Loss G: 16.088897705078125


100%|██████████| 187/187 [01:03<00:00,  2.96it/s]


Epoch [25/50] Loss D: 0.15466326475143433, Loss G: 15.318866729736328


100%|██████████| 187/187 [01:03<00:00,  2.96it/s]


Epoch [26/50] Loss D: 0.17686261236667633, Loss G: 14.008050918579102


100%|██████████| 187/187 [01:03<00:00,  2.94it/s]


Epoch [27/50] Loss D: 0.13070616126060486, Loss G: 11.792806625366211


100%|██████████| 187/187 [01:03<00:00,  2.95it/s]


Epoch [28/50] Loss D: 0.09050644934177399, Loss G: 13.262195587158203


100%|██████████| 187/187 [01:02<00:00,  2.98it/s]


Epoch [29/50] Loss D: 0.16504643857479095, Loss G: 16.701887130737305


100%|██████████| 187/187 [01:03<00:00,  2.96it/s]


Epoch [30/50] Loss D: 0.17733721435070038, Loss G: 11.337803840637207


100%|██████████| 187/187 [01:03<00:00,  2.94it/s]


Epoch [31/50] Loss D: 0.1864265501499176, Loss G: 12.862166404724121


100%|██████████| 187/187 [01:03<00:00,  2.96it/s]


Epoch [32/50] Loss D: 0.1430969536304474, Loss G: 13.396514892578125


100%|██████████| 187/187 [01:02<00:00,  2.97it/s]


Epoch [33/50] Loss D: 0.17126424610614777, Loss G: 14.274295806884766


100%|██████████| 187/187 [01:03<00:00,  2.96it/s]


Epoch [34/50] Loss D: 0.24070236086845398, Loss G: 14.478164672851562


100%|██████████| 187/187 [01:03<00:00,  2.96it/s]


Epoch [35/50] Loss D: 0.17396512627601624, Loss G: 12.673310279846191


100%|██████████| 187/187 [01:03<00:00,  2.94it/s]


Epoch [36/50] Loss D: 0.46584630012512207, Loss G: 14.145442962646484


100%|██████████| 187/187 [01:03<00:00,  2.95it/s]


Epoch [37/50] Loss D: 0.13379889726638794, Loss G: 13.102850914001465


100%|██████████| 187/187 [01:03<00:00,  2.96it/s]


Epoch [38/50] Loss D: 2.015998601913452, Loss G: 15.213249206542969


100%|██████████| 187/187 [01:02<00:00,  2.97it/s]


Epoch [39/50] Loss D: 0.4630644619464874, Loss G: 10.872984886169434


100%|██████████| 187/187 [01:03<00:00,  2.97it/s]


Epoch [40/50] Loss D: 0.4034480154514313, Loss G: 11.869609832763672


100%|██████████| 187/187 [01:03<00:00,  2.96it/s]


Epoch [41/50] Loss D: 0.24420166015625, Loss G: 12.363496780395508


100%|██████████| 187/187 [01:03<00:00,  2.96it/s]


Epoch [42/50] Loss D: 0.31651368737220764, Loss G: 13.009367942810059


100%|██████████| 187/187 [01:03<00:00,  2.95it/s]


Epoch [43/50] Loss D: 0.2097504585981369, Loss G: 14.992843627929688


100%|██████████| 187/187 [01:02<00:00,  2.97it/s]


Epoch [44/50] Loss D: 0.22647199034690857, Loss G: 13.713573455810547


100%|██████████| 187/187 [01:03<00:00,  2.96it/s]


Epoch [45/50] Loss D: 0.36940205097198486, Loss G: 12.299605369567871


100%|██████████| 187/187 [01:03<00:00,  2.97it/s]


Epoch [46/50] Loss D: 0.7054151296615601, Loss G: 12.172661781311035


100%|██████████| 187/187 [01:03<00:00,  2.95it/s]


Epoch [47/50] Loss D: 0.17636147141456604, Loss G: 12.13811206817627


100%|██████████| 187/187 [01:02<00:00,  2.97it/s]


Epoch [48/50] Loss D: 0.13828159868717194, Loss G: 12.507722854614258


100%|██████████| 187/187 [01:03<00:00,  2.96it/s]


Epoch [49/50] Loss D: 0.433326780796051, Loss G: 10.497750282287598


100%|██████████| 187/187 [01:03<00:00,  2.97it/s]


Epoch [50/50] Loss D: 0.17312343418598175, Loss G: 12.490008354187012
