In [22]:
import torch
import torchvision.transforms as trForms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import os
import glob
import random
import cv2
import torch.nn as nn
import torch.nn.functional as F

## Dataset

In [23]:
class ImageDataset(Dataset):
    """Unpaired two-domain image dataset (domain A and domain B).

    Image files are discovered with ``glob`` under ``<root>/<model>/A/*`` and
    ``<root>/<model>/B/*``. Every sample is a dict ``{"A": image, "B": image}``.

    Args:
        root: Dataset root directory.
        transform: List of transforms; it is wrapped in ``trForms.Compose``.
            When ``None`` the images are returned untransformed (RGB PIL images).
        model: Name of the split sub-directory, e.g. ``"train"`` or ``"test"``.
        unaligned: If true, the B image of each sample is drawn at random;
            otherwise B is indexed deterministically (``index % len(listB)``).
            The default ``None`` selects random pairing only for the
            ``"train"`` split, so evaluation splits visit every B image in a
            reproducible order.

    Raises:
        FileNotFoundError: If either domain folder contains no files.
    """

    def __init__(self, root = "", transform = None, model = "train", unaligned = None):
        super(ImageDataset, self).__init__()
        # Compose is only built when transforms were actually supplied.
        self.transform = trForms.Compose(transform) if transform is not None else None
        # Two image sets; each sample is an (A, B) pair, e.g. apple-orange.
        self.pathA = os.path.join(root, model, "A/*")  # glob pattern root/<model>/A/*
        self.pathB = os.path.join(root, model, "B/*")  # glob pattern root/<model>/B/*
        # glob returns files in arbitrary order; sort so indices are reproducible.
        self.listA = sorted(glob.glob(self.pathA))
        self.listB = sorted(glob.glob(self.pathB))
        # Fail early with a clear message instead of a ZeroDivisionError /
        # IndexError from __getitem__ later on.
        for pattern, files in ((self.pathA, self.listA), (self.pathB, self.listB)):
            if not files:
                raise FileNotFoundError("no images found matching {!r}".format(pattern))
        self.unaligned = (model == "train") if unaligned is None else bool(unaligned)

    def _pair_paths(self, index):
        """Return the ``(path_A, path_B)`` file pair used for sample ``index``."""
        path_a = self.listA[index % len(self.listA)]
        if self.unaligned:
            path_b = random.choice(self.listB)
        else:
            path_b = self.listB[index % len(self.listB)]
        return path_a, path_b

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return it as a fully loaded 3-channel RGB image."""
        # Converting to RGB guards against grayscale / RGBA files, which would
        # not match the 3-channel Normalize and Conv2d(3, ...) used downstream.
        # The context manager closes the underlying file handle.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, index):
        im_pathA, im_pathB = self._pair_paths(index)

        im_A = self._load_rgb(im_pathA)
        im_B = self._load_rgb(im_pathB)

        if self.transform is not None:
            im_A = self.transform(im_A)
            im_B = self.transform(im_B)
        return {"A": im_A, "B": im_B}

    def __len__(self):
        # One epoch covers the larger domain; the smaller one wraps around.
        return max(len(self.listA), len(self.listB))

In [24]:
# Dataset root; ImageDataset globs <root>/<split>/A/* and <root>/<split>/B/* beneath it.
root = "/data/zhuowei/datasets/cyclegan/datasets/apple2orange"

将图片缩放（`Resize(256)` 会把短边缩放到 256 像素并保持宽高比），插值方式为双线性插值（BILINEAR）。

In [25]:
# Minimal preprocessing for the shape check: bilinear resize, then PIL image -> tensor.
transforms = [trForms.Resize(256, Image.BILINEAR), trForms.ToTensor()]

In [26]:
# Loader over the "train" split with the minimal transforms: one shuffled sample per batch, one worker.
dataloader = DataLoader(ImageDataset(root=root, transform=transforms, model="train"), shuffle=True, batch_size=1, num_workers=1)

In [27]:
# Sanity check: fetch a single batch, print the shape of its domain-A tensor, and stop.
for batch in dataloader:
    print(batch["A"].shape)
    break

torch.Size([1, 3, 256, 256])


# Model: CycleGAN

生成器的主干网络由 ResNet 风格的残差块（residual block）堆叠而成。

In [28]:
class resBlock(nn.Module):
    """Residual block with a constant channel count.

    The body is pad -> 3x3 conv -> instance norm -> ReLU -> pad -> 3x3 conv ->
    instance norm; ``forward`` adds the block input to the body's output.

    Args:
        in_channel: Number of channels of the input (and of the output).
    """

    def __init__(self, in_channel):
        super().__init__()

        def padded_conv_norm():
            # Reflection-padding by 1 before an unpadded 3x3 convolution keeps
            # the spatial size unchanged; the border is filled from the input.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_channel, in_channel, kernel_size=3),
                nn.InstanceNorm2d(in_channel),
            ]

        layers = padded_conv_norm()
        layers.append(nn.ReLU(inplace=True))
        layers.extend(padded_conv_norm())

        # Chain the operators into a single module.
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        # Skip connection: input plus the transformed input.
        residual = self.conv_block(x)
        return x + residual

## Generator

In [29]:
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: a 7x7 stem convolution (3 -> 64 channels), two stride-2
    downsampling convolutions (64 -> 128 -> 256), nine ``resBlock`` modules at
    256 channels, two stride-2 transposed convolutions (256 -> 128 -> 64) and
    a 7x7 output convolution back to 3 channels followed by ``tanh``.

    The output has the same spatial size as the input when height and width
    are divisible by 4, and its values lie in [-1, 1] because of the tanh.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Stem: a large 7x7 kernel. Padding by 3 first keeps the spatial size.
        net = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        ]  # for a 256x256 input the output is 64 x 256 x 256

        # Downsampling: each stage halves H and W and doubles the channels.
        in_channel = 64
        out_channel = in_channel * 2
        for _ in range(2):
            net += [
                nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=2, padding=1),
                # The norm must be declared with the conv's output width. The
                # previous hard-coded 64 went unnoticed only because the norm
                # has no affine parameters (recent PyTorch warns about it).
                nn.InstanceNorm2d(out_channel),
                nn.ReLU(inplace=True)
            ]
            in_channel = out_channel
            out_channel = in_channel * 2

        # Residual blocks at the bottleneck resolution.
        for _ in range(9):
            net += [resBlock(in_channel)]

        # Upsampling: each stage doubles H and W and halves the channels.
        out_channel = in_channel // 2
        for _ in range(2):
            net += [
                nn.ConvTranspose2d(in_channel, out_channel, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_channel),
                nn.ReLU(inplace=True)
            ]
            in_channel = out_channel
            out_channel = in_channel // 2

        # Output head: back to 3 channels, squashed into [-1, 1].
        net += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channel, 3, 7),
            nn.Tanh()
        ]
        self.model = nn.Sequential(*net)

    def forward(self, x):
        return self.model(x)

## Discriminator

In [30]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator producing one score per image.

    Four stride-2 4x4 convolutions (3 -> 64 -> 128 -> 256 -> 512 channels,
    LeakyReLU(0.2), instance norm on all but the first) are followed by a 4x4
    convolution to a single channel. ``forward`` averages that score map over
    its spatial dimensions and returns a tensor of shape ``(batch, 1)``.
    """

    def __init__(self):
        super().__init__()

        # The first stage has no normalisation layer.
        layers = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Remaining strided stages: conv -> instance norm -> LeakyReLU.
        widths = (64, 128, 256, 512)
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        # Single-channel score map.
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        score_map = self.model(x)
        batch = score_map.size()[0]
        spatial = score_map.size()[2:]
        # Average over the whole score map, then flatten to (batch, 1).
        pooled = F.avg_pool2d(score_map, spatial)
        return pooled.view(batch, -1)

In [31]:
# Throwaway instances used only by the shape checks in the following cells.
G = Generator()
D = Discriminator()

In [32]:
# Dummy all-ones batch (1 image, 3 channels, 256x256) pushed through the generator.
input_tensor = torch.ones((1,3,256,256), dtype = torch.float)
out = G(input_tensor)

In [33]:
# Shape of the generator output computed in the previous cell.
print(out.shape)

torch.Size([1, 3, 256, 256])


In [34]:
# Same dummy batch through the discriminator; prints the shape of its score tensor.
out = D(input_tensor)
print(out.shape)

torch.Size([1, 1])


# Train

In [35]:
from utils import tensor2image, LambdaLR, weights_init_normal, ReplayBuffer
import itertools
import tensorboardX

In [36]:
# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [37]:
# Hyper-parameters.
batch_size = 1
size = 256          # side length of the training crops
lr = 0.0002         # initial learning rate shared by all optimizers
n_epochs = 200
epoch = 0           # starting epoch passed to the LR decay rule
decay_epoch = 100   # passed to the LR decay rule

# Networks: one generator per translation direction, one discriminator per domain.
netG_A2B = Generator().to(device)
netG_B2A = Generator().to(device)
netD_A = Discriminator().to(device)
netD_B = Discriminator().to(device)

# `weights_init_normal` is imported from utils but was never applied, so the
# networks silently trained from PyTorch's default initialisation.
# NOTE(review): assumes it takes a single module argument, as Module.apply
# requires - confirm against utils.weights_init_normal.
netG_A2B.apply(weights_init_normal)
netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

# Losses.
loss_GAN = torch.nn.MSELoss()      # adversarial loss on discriminator scores
loss_Cycle = torch.nn.L1Loss()     # A -> B -> A and B -> A -> B reconstruction
loss_identity = torch.nn.L1Loss()  # similarity between real data and generated data

In [38]:
# Optimizer & LR
# NOTE(review): beta2 = 0.9999 is unusually high; the Adam default is 0.999 - confirm this is not a typo.
opt_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), lr = lr, betas=(0.5,0.9999))  # parameters of both generators are chained and optimized together

optD_A = torch.optim.Adam(netD_A.parameters(), lr = lr, betas = (0.5,0.9999))
optD_B = torch.optim.Adam(netD_B.parameters(), lr = lr, betas = (0.5, 0.9999))

In [39]:
# Custom learning-rate decay, one scheduler per optimizer.
def _build_lr_scheduler(optimizer):
    """Wrap ``optimizer`` in a LambdaLR scheduler driven by its own decay rule.

    The multiplier comes from ``utils.LambdaLR(n_epochs, epoch, decay_epoch).step``.
    NOTE(review): the shape of that decay is defined in utils and is not
    visible here - confirm against utils.LambdaLR.
    """
    decay_rule = LambdaLR(n_epochs, epoch, decay_epoch)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=decay_rule.step)


lr_scheduler_G = _build_lr_scheduler(opt_G)
lr_scheduler_D_A = _build_lr_scheduler(optD_A)
lr_scheduler_D_B = _build_lr_scheduler(optD_B)

In [40]:
data_root = root
# Pre-allocated 1 x 3 x size x size buffers that each batch is copied into.
input_A = torch.ones([1,3,size,size], dtype = torch.float).to(device)
input_B = torch.ones([1,3,size,size], dtype = torch.float).to(device)
# Discriminator targets: 1 for "real", 0 for "fake".
label_real = torch.ones([1], dtype = torch.float, requires_grad = False).to(device)
label_fake = torch.zeros([1], dtype = torch.float, requires_grad = False).to(device)

# Buffers of previously generated images used when training the discriminators.
# NOTE(review): behaviour is defined in utils.ReplayBuffer - not visible here.
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

log_path = "logs"
write_log = tensorboardX.SummaryWriter(log_path)  # writes the logs under log_path

# Training-time preprocessing and augmentation.
transform_ = [
    trForms.Resize(int(256 * 1.12), Image.BICUBIC),  # upscale to 1.12x the crop size
    trForms.RandomCrop(256),  # random crop back to 256
    trForms.RandomHorizontalFlip(),
    trForms.ToTensor(),
    trForms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
]

In [41]:
# Training loader: shuffled batches from the "train" split with the augmenting transforms.
dataloader = DataLoader(ImageDataset(root, transform_, model ="train"), batch_size = batch_size, shuffle = True, num_workers = 8)

# Global iteration counter used as the TensorBoard step.
step = 0

In [21]:

# Create the checkpoint directory up front so the first save, which only
# happens after a full epoch, cannot fail on a missing folder.
checkpoint_dir = "/data/zhuowei_common/models/cyclegan"
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(n_epochs):
    for i, batch in enumerate(dataloader):
        # Move the batch to the training device directly. The previous
        # `torch.tensor(input_A.copy_(...))` re-wrapped an existing tensor
        # (extra copy plus a UserWarning) and its fixed 1x3x256x256 buffer
        # broke for any other batch shape.
        real_A = batch["A"].to(device=device, dtype=torch.float)
        real_B = batch["B"].to(device=device, dtype=torch.float)

        ############################### Generators ###############################

        opt_G.zero_grad()

        # Identity loss: feeding a generator an image that is already in its
        # target domain should return that image (B is the orange domain).
        same_B = netG_A2B(real_B)
        loss_identity_B = loss_identity(same_B, real_B) * 5.0

        same_A = netG_B2A(real_A)
        loss_identity_A = loss_identity(same_A, real_A) * 5.0

        # Adversarial loss: each generator wants the discriminator of its
        # target domain to score its output as real. Targets are expanded to
        # the prediction's shape; comparing a [1] target with an [N, 1]
        # prediction relied on implicit broadcasting and triggered a warning.
        fake_B = netG_A2B(real_A)   # apples -> fake oranges
        pred_fake = netD_B(fake_B)
        loss_GAN_A2B = loss_GAN(pred_fake, label_real.expand_as(pred_fake))

        fake_A = netG_B2A(real_B)
        pred_fake = netD_A(fake_A)
        loss_GAN_B2A = loss_GAN(pred_fake, label_real.expand_as(pred_fake))

        # Cycle-consistency loss: translating there and back should
        # reproduce the original image.
        recover_A = netG_B2A(fake_B)
        loss_Cycle_ABA = loss_Cycle(recover_A, real_A) * 10.0

        recover_B = netG_A2B(fake_A)
        loss_Cycle_BAB = loss_Cycle(recover_B, real_B) * 10.0

        loss_G_identity = loss_identity_A + loss_identity_B
        loss_G_GAN = loss_GAN_A2B + loss_GAN_B2A
        loss_G_cycle = loss_Cycle_ABA + loss_Cycle_BAB
        loss_G = loss_G_identity + loss_G_GAN + loss_G_cycle

        loss_G.backward()
        opt_G.step()

        ############################# Discriminator A ############################

        optD_A.zero_grad()

        pred_real = netD_A(real_A)
        loss_D_real = loss_GAN(pred_real, label_real.expand_as(pred_real))

        # Per the original author's note, the buffer stores the new fake and
        # hands back a randomly chosen stored one (see utils.ReplayBuffer).
        fake_A = fake_A_buffer.push_and_pop(fake_A)
        # detach(): the discriminator update must not back-propagate into
        # the generator that produced the fake.
        pred_fake = netD_A(fake_A.detach())
        loss_D_fake = loss_GAN(pred_fake, label_fake.expand_as(pred_fake))

        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A.backward()
        optD_A.step()

        ############################# Discriminator B ############################

        optD_B.zero_grad()

        pred_real = netD_B(real_B)
        loss_D_real = loss_GAN(pred_real, label_real.expand_as(pred_real))

        fake_B = fake_B_buffer.push_and_pop(fake_B)
        pred_fake = netD_B(fake_B.detach())
        loss_D_fake = loss_GAN(pred_fake, label_fake.expand_as(pred_fake))

        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B.backward()
        optD_B.step()

        ################################# Logging ################################

        step += 1
        # .item() yields plain Python floats, so neither the console nor the
        # TensorBoard writer holds on to graph-attached tensors.
        scalars = {
            "loss_G": loss_G.item(),
            "loss_G_identity": loss_G_identity.item(),
            "loss_G_GAN": loss_G_GAN.item(),
            "loss_G_cycle": loss_G_cycle.item(),
            "loss_D_A": loss_D_A.item(),
            "loss_D_B": loss_D_B.item(),
        }
        print("epoch: {}, loss G: {}, loss_G_identity: {}, loss_G_GAN: {}, loss_G_cycle: {}, loss_D_A: {}, loss_D_B: {} "\
                .format(epoch,
                        scalars["loss_G"],
                        scalars["loss_G_identity"],
                        scalars["loss_G_GAN"],
                        scalars["loss_G_cycle"],
                        scalars["loss_D_A"],
                        scalars["loss_D_B"]))
        for tag, value in scalars.items():
            write_log.add_scalar(tag, value, global_step=step)

    # Update the learning rates once per epoch.
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    # Overwrite the checkpoints after every epoch.
    torch.save(netG_A2B.state_dict(), "/data/zhuowei_common/models/cyclegan/netG_A2B.pth")
    torch.save(netG_B2A.state_dict(), "/data/zhuowei_common/models/cyclegan/netG_B2A.pth")
    torch.save(netD_A.state_dict(), "/data/zhuowei_common/models/cyclegan/netD_A.pth")
    torch.save(netD_B.state_dict(), "/data/zhuowei_common/models/cyclegan/netD_B.pth")

3855 
epoch: 4, loss G: 12.825312614440918, loss_G_identity: 2.023472309112549, loss_G_GAN: 0.28446170687675476, loss_G_cycle: 10.517378807067871, loss_D_A: 0.41550201177597046, loss_D_B: 0.16072432696819305 
epoch: 4, loss G: 9.2338228225708, loss_G_identity: 2.1844429969787598, loss_G_GAN: 1.141271948814392, loss_G_cycle: 5.908107757568359, loss_D_A: 0.22912465035915375, loss_D_B: 0.22713422775268555 
epoch: 4, loss G: 8.66160774230957, loss_G_identity: 2.4106836318969727, loss_G_GAN: 0.6155174374580383, loss_G_cycle: 5.635406494140625, loss_D_A: 0.3030778467655182, loss_D_B: 0.06250499188899994 
epoch: 4, loss G: 8.801414489746094, loss_G_identity: 1.7809386253356934, loss_G_GAN: 0.8843106031417847, loss_G_cycle: 6.136165142059326, loss_D_A: 0.12473195046186447, loss_D_B: 0.1601119041442871 
epoch: 4, loss G: 9.53542423248291, loss_G_identity: 1.619598627090454, loss_G_GAN: 1.3150423765182495, loss_G_cycle: 6.600783348083496, loss_D_A: 0.2558755874633789, loss_D_B: 0.151617363095283

KeyboardInterrupt: 

# Test

In [None]:
import torchvision.transforms as trForms
from torchvision.utils import save_image

In [None]:
# Load the trained generators.
netG_A2B = Generator().to(device)
netG_B2A = Generator().to(device)

# map_location remaps the stored tensors onto the current device, so
# checkpoints written on a GPU machine also load on a CPU-only one.
netG_A2B.load_state_dict(torch.load("/data/zhuowei_common/models/cyclegan/netG_A2B.pth", map_location=device))
netG_B2A.load_state_dict(torch.load("/data/zhuowei_common/models/cyclegan/netG_B2A.pth", map_location=device))


In [None]:
# Switch both generators to evaluation mode.
netG_A2B.eval()
netG_B2A.eval()

size = 256

# Test-time preprocessing: no augmentation, only the [-1, 1] normalisation
# that matches the generators' tanh output range.
transforms_ = [
    trForms.ToTensor(),
    trForms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
]


data_root = "/data/zhuowei/datasets/cyclegan/datasets/apple2orange"

dataloader = DataLoader(ImageDataset(root = data_root,transform = transforms_, model = "test"), batch_size = 1, shuffle = False, num_workers = 8)

# makedirs also creates missing parent directories and tolerates reruns;
# os.mkdir failed whenever ".../output/cyclegan" did not exist yet.
os.makedirs("/data/zhuowei_common/output/cyclegan/A", exist_ok=True)
os.makedirs("/data/zhuowei_common/output/cyclegan/B", exist_ok=True)

# Inference only: no_grad avoids building autograd graphs.
with torch.no_grad():
    for i, batch in enumerate(dataloader):
        # Batches go straight to the device; the fixed-size copy buffers
        # (input_A / input_B) are no longer needed.
        real_A = batch["A"].to(device=device, dtype=torch.float)
        # Bug fix: this used to read batch["A"], so the B -> A generator was
        # fed domain-A images and the "A" outputs were meaningless.
        real_B = batch["B"].to(device=device, dtype=torch.float)

        # Map the tanh output from [-1, 1] back to [0, 1] for saving.
        fake_B = 0.5 * (netG_A2B(real_A) + 1.0)
        fake_A = 0.5 * (netG_B2A(real_B) + 1.0)

        save_image(fake_A, "/data/zhuowei_common/output/cyclegan/A/{}.png".format(i))
        save_image(fake_B, "/data/zhuowei_common/output/cyclegan/B/{}.png".format(i))
        print(i)