In [2]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/5/16 15:23
"""

import os
import torch
import albumentations
from torchvision import transforms
from albumentations.pytorch import ToTensorV2

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DATASET_DIR = "data/img_align_celeba"
TRAIN_DIR = "data/celeba/train"
VAL_DIR = "data/celeba/val"
WIDTH,HEIGHT = 128,128
LEARNING_RATIO = 1e-4
BATCH_SIZE_T = 16
BATCH_SIZE_V = 16
SHUFFLE = True
NUM_WORKERS = 0
SAVE_MODELS = "models"
LOAD_MODELS = "models"
SAVE_IMAGES = "images"
NUM_EPOCHS = 1
DOWNSCALE = 4
NEW_SIZE_W,NEW_SIZE_H = WIDTH // DOWNSCALE,HEIGHT // DOWNSCALE
LOSS_RATIO = 1e-3
BEAT1 = 0.9
BEAT2 = 0.9
EPSILON = 1e-8

#对32 x 32和128 x 128的图像进行增强
Transform = albumentations.Compose([
    albumentations.HorizontalFlip(p = 0.5),
    albumentations.VerticalFlip(p = 0.5),
    albumentations.RandomBrightness(limit=0.2),
    albumentations.HueSaturationValue(hue_shift_limit=20,sat_shift_limit=30),
    albumentations.RandomContrast(limit=0.2),
    albumentations.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]),
    ToTensorV2()
])
#对128 x 128的图像进行裁剪到32 x 32
down_Transform = albumentations.Compose([
    albumentations.Resize(height=NEW_SIZE_W,width=NEW_SIZE_H)
])


In [4]:
"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/5/16 15:53
"""

import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader,Dataset

class MyDataset(Dataset):
    def __init__(self,root_dir):
        super(MyDataset, self).__init__()
        self.root_dir = root_dir
        self.imgs_list = os.listdir(root_dir)

    def __len__(self):
        return len(self.imgs_list)

    def __getitem__(self, index):
        img_name = self.imgs_list[index]
        img_path = os.path.join(self.root_dir,img_name)
        image = np.array(Image.open(img_path).convert("RGB"))
        labels = Transform(image=image)["image"]
        features = Transform(image=down_Transform(image=image)["image"])["image"]
        return labels,features

if __name__ == '__main__':
    mydataT = MyDataset("/kaggle/input/celebadataset/celeba/train")
    print(mydataT.__len__())
    print("labels.shape: {}--feautres.shape: {}".format(np.shape(mydataT[0][0]),np.shape(mydataT[0][1])))
    mydataV = MyDataset("/kaggle/input/celebadataset/celeba/val")
    print(mydataV.__len__())
    print("labels.shape: {}--feautres.shape: {}".format(np.shape(mydataV[0][0]), np.shape(mydataV[0][1])))

10319
labels.shape: torch.Size([3, 128, 128])--feautres.shape: torch.Size([3, 32, 32])
1147
labels.shape: torch.Size([3, 128, 128])--feautres.shape: torch.Size([3, 32, 32])


In [10]:
"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/5/16 14:30
"""

import torch
from torchinfo import summary

class ConvBlock(torch.nn.Module):
    def __init__(self,in_channels,out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels = in_channels,out_channels=out_channels,kernel_size=(3,3),
                            stride=(1,1),padding=(1,1),bias=False),
#             torch.nn.BatchNorm2d(num_features=out_channels),
            torch.nn.LeakyReLU(negative_slope=0.01)
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3),
                            stride=(2, 2), padding=(1, 1), bias=False),
#             torch.nn.BatchNorm2d(num_features=out_channels),
            torch.nn.LeakyReLU(negative_slope=0.01)
        )
    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class Discriminator(torch.nn.Module):
    def __init__(self,in_channels,out_channels,data_format = "channels_first"):
        super(Discriminator, self).__init__()
        if data_format == "channels_first":
            self._input_shape = [-1,3,128,128]
            self.bn_axis = 1
        else:
            assert data_format == "channels_last"
            self._input_shape=[-1,128,128,3]
            self.bn_axis = 3

        self.conv1 = ConvBlock(in_channels=in_channels,out_channels = 64)
        self.conv2 = ConvBlock(in_channels=64,out_channels=128)
        self.conv3 = ConvBlock(in_channels=128,out_channels=256)
        self.conv4 = ConvBlock(in_channels=256,out_channels=512)
        # self.fc1 = torch.nn.Linear(in_features=512 * 8 * 8,out_features=256)
        # self.fc2 = torch.nn.Linear(in_features=256,out_features=out_channels)
        self.conv5 = torch.nn.Conv2d(in_channels=512, out_channels=out_channels, kernel_size=(3, 3),
                            stride=(2, 2), padding=(1, 1), bias=False)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self,x):
        # x = torch.reshape(x,self._input_shape)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        # x = x.view(-1,512 * 8 * 8)
        # x = self.fc1(x)
        # x = self.fc2(x)
        x = self.conv5(x)
        x = self.sigmoid(x)

        return x

if __name__ == '__main__':
    x = torch.randn(size = (1,3,128,128))
    model = Discriminator(in_channels=3,out_channels=1)
    summary(model,input_size=(1,3,128,128))
    print(model(x).shape)

torch.Size([1, 1, 4, 4])


In [11]:
"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/5/16 13:27
"""

import torch
import numpy as np
from torchinfo import summary

#残差模块
class _IdentityBlock(torch.nn.Module):
    #data_format判断输入图像的通道C位置[w,h,c]或者[c,w,h]
    def __init__(self,in_channels,out_channels,stride = (1,1),data_format = "channels_first"):
        super(_IdentityBlock, self).__init__()
        self.bn_axis = 1 if data_format == "channels_first" else 3
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=(3,3),
                            stride=stride,padding=(1,1),bias=False),
#             torch.nn.BatchNorm2d(num_features=out_channels),
            torch.nn.PReLU(num_parameters=1),
            torch.nn.Conv2d(in_channels = out_channels,out_channels = out_channels,kernel_size=(3,3),
                            stride=(1,1),padding=(1,1),bias=False),
#             torch.nn.BatchNorm2d(num_features=out_channels)
        )
    def forward(self,x):
        x_resnet = self.conv(x)
        out = x + x_resnet
        return out

def phaseShift(inputs,scale,shape_1,shape_2):
    x = torch.reshape(inputs,shape_1)
    x = torch.reshape(x,[0,1,3,2,4])
    return torch.reshape(x,shape_2)

#使用PixelShuffle进行上采样
def PixelShuffle(inputs,scale = 2):
    """
    :param inputs: 进行上采样的输入
    :param scale: 上采样的倍率
    :return:
    """
    size = np.shape(inputs)
    batch_size = size[0]
    h = size[1]
    w = size[2]
    c = size[-1]
    #进行上采样之后需要进行通道数1/4
    channel_target = c // (scale * scale)
    #获得上采样因子
    channel_factor = c // channel_target
    shape_1 = [batch_size,h,w,channel_factor // scale,channel_target//scale]
    shape_2 = [batch_size,h * scale,w * scale]
    #reshape and transpose for periods shuffle for each channel
    input_split = torch.split(inputs,channel_target,dim=3)
    output = torch.cat([phaseShift(x,scale,shape_1,shape_2) for x in input_split],dim=3)
    return output

# 生成器
class Generator(torch.nn.Module):
    def __init__(self,upscale = 2,data_format = "channels_last"):
        super(Generator, self).__init__()
        self.upscale = 2
        if data_format == "channels_first":
            self._input_shape = [-1,3,32,32]
            self.bn_axis = 1
        else:
            assert data_format == "channels_last"
            self._input_shape=[-1,32,32,3]
            self.bn_axis = 3

        self.initial_conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(9, 9), stride=(1, 1),
                            padding=(4,4),bias=False),
            torch.nn.PReLU(num_parameters=1)
        )

        #使用残差模块
        self.identityBlocks = [_IdentityBlock(in_channels=64,out_channels=64) for _ in range(16)]
        self.Blocks = torch.nn.Sequential(
            *self.identityBlocks
        )

        self.conv2 = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1),
                        padding=(1, 1))

        # 由于进行了一次上采样，通道数减为原来的1/4，所以输入通道数为256=>64
        self.upconv1 = torch.nn.Conv2d(in_channels=64,out_channels=256,kernel_size=(3,3),stride=(1,1),
                        padding=(1,1))
        self.prelu1 = torch.nn.PReLU(num_parameters=1)
        #由于进行了一次上采样，通道数减为原来的1/4，所以输入通道数为256=>64
        self.upconv2 = torch.nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(3, 3), stride=(1, 1),
                        padding=(1, 1))
        self.prelu2 = torch.nn.PReLU(num_parameters=1)
        self.conv3 = torch.nn.Conv2d(in_channels=64, out_channels=3, kernel_size=(9, 9), stride=(1, 1),
                        padding=(4, 4), bias=False)
    def forward(self,x):
        # x = torch.reshape(x,self._input_shape)
        x = self.initial_conv(x)

        x_resnet = self.Blocks(x)
        x_resnet = self.conv2(x_resnet)
        x = x + x_resnet

        #进行第一次上采样
        x = self.upconv1(x)
        x = torch.nn.PixelShuffle(self.upscale)(x)
        x = self.prelu1(x)

        #进行第二次上采样
        x = self.upconv2(x)
        x = torch.nn.PixelShuffle(self.upscale)(x)
        x = self.prelu2(x)

        x = self.conv3(x)
        x = torch.tanh(x)

        return x


if __name__ == '__main__':
    x = torch.randn(size = (1,3,32,32))
    gen = Generator(upscale=2,data_format="channels_first")
    summary(gen,input_size=(1,3,32,32))
    print(gen(x).shape)



torch.Size([1, 3, 128, 128])


In [12]:
"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/5/16 15:26
"""
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

#保存模型
def save_model(model,optimizer,epoch):
    """
    :param model:
    :param epoch:
    :return:
    """
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, os.path.join(str(epoch)+'gen.tar'))


def load_checkpoin(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    for param_group in optimizer.param_group:
        param_group["lr"] = lr

def generate_and_save_images(save_dir,features,gen,epoch):
    """
    :param save_dir:
    :param gen:
    :param epoch:
    :return:
    """
    predictions = gen(features)
    pass

def draw(gen_loss,disc_loss):
    """
    :param gen_loss:
    :param disc_loss:
    :return:
    """
    plt.plot(range(1, len(gen_loss) + 1), gen_loss, label='genLoss')
    plt.plot(range(1, len(disc_loss) + 1), disc_loss, label='discLoss')
    plt.legend()
    plt.title('GEN_DISC-LOSS')
    plt.savefig('logs/figure.png')

def save_images(model,epoch,step,val_loader):
    imgs = model(val_loader).detach().cpu().numpy()
    fig = plt.figure(figsize=(2,2))
    imgs = np.squeeze(imgs)
    for i in range(4):
        plt.subplot(2,2,i + 1)
        plt.imshow((imgs[i] + 1) / 2)
        plt.axis("off")
    plt.savefig(os.path.join(config.SAVE_IMAGES,str(epoch)+"_"+str(step)+'.png'))

In [8]:
import gc
gc.collect()

0

In [13]:
"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/5/16 16:23
"""

import os
import torch
import numpy as np
from tqdm import tqdm
from torchinfo import summary
import torchvision.models.vgg
from torch.utils.data import DataLoader,Dataset


def vgg19():
    """
    vgg19主要是用来在这里提取特征，进生成器生成的图片输入到vgg19中和将标签图片输入到vgg19中，
    将两个样本最后输出的特征之间的距离作为最后计算损失值
    :return:
    """
    vgg19 = torchvision.models.vgg.vgg19(pretrained = True,progress = True)
    # vgg19.classifier = torch.nn.Sequential()
    # vgg19.avgpool = torch.nn.Sequential()
    #去掉最后的全连接层和avg_pooling层
    vgg19 = torch.nn.Sequential(*(list(vgg19.children())[:-2])[0][:36])
    # summary(vgg19,input_size=(1,3,128,128))
    # print(vgg19)
    return vgg19

#生成器损失
def create_g_loss(d_output,g_output,labels,loss_model,loss_fn_gen):
    """
    :param d_output: 判别器输出结果
    :param g_output: 生成器生成的图片
    :param labels: 标签值
    :param loss_model: 损失模型
    :return:
    """
    gene_ce_loss = loss_fn_gen(torch.ones_like(d_output),d_output)
    # print('labels.shape: {}'.format(np.shape(labels)))
    # print('g_output.shape: {}'.format(np.shape(g_output)))
    vgg_loss = torch.mean(torch.square(loss_model(labels) - loss_model(g_output)))
    g_loss = vgg_loss + LOSS_RATIO * gene_ce_loss
    return g_loss

#判别器损失值
def create_d_loss(disc_real_output,disc_fake_output,loss_fn_real,loss_fn_fake):
    """
    :param disc_real_output: 标签值输入到判别器的输出结果
    :param disc_fake_output: 生成器生成的图片输入到判别器中的输出结果
    :return:
    """
    disc_real_loss = loss_fn_real(torch.ones_like(disc_real_output),disc_real_output)
    disc_fake_loss = loss_fn_fake(torch.zeros_like(disc_fake_output),disc_fake_output)
    disc_loss =  disc_fake_loss + disc_real_loss
    return disc_loss

def train_step(features,labels,loss_model,gen,disc,opt_gen,opt_disc,loss_fn_gen,loss_fn_real,loss_fn_fake):
    """
    :param features: 低分辨率图像
    :param labels: 高分辨率图像
    :param loss_model: 使用VGG19计算输入图像的特征
    :param gen:生成器
    :param disc:判别器
    :param opt_gen:生成器优化器
    :param opt_disc:判别器优化器
    :return:
    """
    fake_img = gen(features)
    real_disc = disc(labels)
    fake_disc = disc(fake_img)
    g_loss = create_g_loss(fake_disc,fake_img,labels,loss_model,loss_fn_gen)
    #注意fake_disc.detach()
    d_loss = create_d_loss(real_disc,fake_disc.detach(),loss_fn_real,loss_fn_fake)

    opt_gen.zero_grad()
    g_loss.requires_grad_(True)
    g_loss.backward(retain_graph=True)
    opt_gen.step()

    opt_disc.zero_grad()
    d_loss.requires_grad_(True)
    d_loss.backward(retain_graph=True)
    opt_disc.step()
    return g_loss,d_loss



def train():
    #加载模型
    gen = Generator().to(DEVICE)
    disc = Discriminator(in_channels=3,out_channels=1).to(DEVICE)
    #加载数据集
    trainDataset = MyDataset(root_dir="/kaggle/input/celebadataset/celeba/train")
    valDataset = MyDataset(root_dir="/kaggle/input/celebadataset/celeba/val")
    trainLoader = DataLoader(
        dataset=trainDataset,
        batch_size=BATCH_SIZE_T,
        shuffle=True,
        num_workers=0
    )
    valLoader = DataLoader(
        dataset=valDataset,
        batch_size=BATCH_SIZE_V,
        shuffle=True,
        num_workers=0
    )
    #定义优化器
    opt_gen = torch.optim.Adam(params=gen.parameters(),lr = LEARNING_RATIO,
                               betas=(BEAT1,BEAT2),eps=EPSILON)
    opt_disc = torch.optim.Adam(params=disc.parameters(),lr = LEARNING_RATIO,
                                betas=(BEAT1,BEAT2),eps=EPSILON)

    #定义损失函数
    loss_fn_gen = torch.nn.BCELoss()
    loss_fn_real = torch.nn.BCELoss()
    loss_fn_fake = torch.nn.BCELoss()

    loss_model = vgg19()
    loss_model = loss_model.to(DEVICE)

    gen_loss = []
    disc_loss = []
    for epoch in range(NUM_EPOCHS):
        all_g_cost,all_d_cost = 0,0
        loop = tqdm(trainLoader,leave=True)
        loop.set_description(desc="training: ")
        gen.train()
        disc.train()
        for step,data in enumerate(loop):
            imgs,labels = data
            imgs,labels = imgs.to(DEVICE),labels.to(DEVICE)

            g_loss,d_loss = train_step(imgs,labels,loss_model,gen,disc,opt_gen,opt_disc,
                                       loss_fn_gen,loss_fn_real,loss_fn_fake)

            all_g_cost += g_loss
            all_d_cost += d_loss
            loop.set_description(desc="training: ")

            if step % 50 == 0 and step > 0:
                loop.set_postfix(epoch = epoch,g_loss = g_loss,d_loss = d_loss)
                print("--------------------------------------g_loss: {:.6f}--------------------------------------".format(g_loss))
                print("--------------------------------------d_loss: {:.6f}--------------------------------------".format(d_loss))
        gen_loss.append(all_g_cost / len(trainLoader))
        disc_loss.append(all_d_cost / len(trainLoader))

        gen.eval()
        disc.eval()
        with torch.no_grad():
            loop = tqdm(valLoader,leave=True)
            loop.set_description(desc="valing: ")
            for step,data in enumerate(loop):
                imgs,labels = data
                imgs,labels = imgs.to(DEVICE),labels.to(DEVICE)

                if step % 100 == 0 and step > 0:
                    utils.save_images(gen,epoch,step,imgs)
        if epoch % 10 == 0:
            save_model(gen, opt_gen, epoch)
    draw(gen_loss,disc_loss)


if __name__ == '__main__':
    # vgg19()
    train()

# import tensorflow as tf
#
# vgg19 = tf.keras.applications.vgg19.VGG19(include_top = False,weights='imagenet',input_shape=(128,128,3))
# vgg19.trainable = False
# for l in vgg19.layers:
#     l.trainable = False
#
# loss_models = tf.keras.Model(inputs = vgg19.input,outputs = vgg19.get_layer("block5_conv4").output)
# loss_models.trainable = False
# loss_models.summary()

training: :   0%|          | 0/645 [01:19<?, ?it/s]


RuntimeError: The size of tensor a (2) must match the size of tensor b (32) at non-singleton dimension 3

In [None]:
import torch, gc

gc.collect()
torch.cuda.empty_cache()