In [11]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
flag = 0
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))
        flag = 1
        break
    if flag == 1:
        break
    

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/cyclegandataset/vangogh2photo/val/testB/2014-08-16 00_41_44.jpg


In [3]:
"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/4/28 17:14
"""

import torch
import torchvision
from torchvision import transforms

class Block(torch.nn.Module):
    def __init__(self,in_channels,out_channels,stride):
        super(Block, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=(4,4),
                            stride=stride,padding=1,bias=True,padding_mode='reflect'),
            torch.nn.BatchNorm2d(num_features=out_channels),
            torch.nn.LeakyReLU(negative_slope=0.2,inplace=True)
        )
    def forward(self,x):
        out = self.conv(x)
        return out

class Discriminator(torch.nn.Module):
    def __init__(self, in_channels=3,features=(64,128,256,512)):
        super(Discriminator, self).__init__()
        self.features = features
        self.initial = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=in_channels, out_channels=features[0], kernel_size=(4, 4),
                            stride=(2,2), padding=1, bias=True, padding_mode='reflect'),
            torch.nn.BatchNorm2d(num_features=features[0]),
            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
        )
        layers = []
        in_channels=features[0]
        for feature in features[1:]:
            layers.append(
                Block(in_channels,feature,stride=1 if feature == features[-1] else 2)
            )
            in_channels=feature
        layers.append(torch.nn.Conv2d(in_channels=in_channels,out_channels=1,kernel_size=(4,4),
                                      stride=(1,1),padding=1,padding_mode='reflect'))
        #将值归一化到[0-1]
        layers.append(torch.nn.Sigmoid())
        #对layers进行解序列
        self.model = torch.nn.Sequential(
            *layers
        )
    def forward(self,x):
        x = self.initial(x)
        out= self.model(x)
        return out

if __name__ == '__main__':
    x = torch.randn(size = (5,3,256,256),device='cpu')
    model = Discriminator(in_channels=3)
    preds = model(x)
    print(preds.shape)

torch.Size([5, 1, 30, 30])


In [4]:
"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/4/28 17:14
"""

import torch
import torchvision
from torchinfo import summary
from torchvision import transforms

class ConvBlock(torch.nn.Module):
    def __init__(self,in_channels,out_channels,down=True,use_act=True,**kwargs):
        super(ConvBlock, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=in_channels,out_channels=out_channels,padding_mode='reflect',**kwargs)
            if down
            else torch.nn.ConvTranspose2d(in_channels=in_channels,out_channels=out_channels,**kwargs),
            torch.nn.BatchNorm2d(num_features=out_channels),
            torch.nn.ReLU(inplace=True) if use_act else torch.nn.Identity()
        )
    def forward(self,x):
        return self.conv(x)

class ResidualBlock(torch.nn.Module):
    def __init__(self,channels):
        super(ResidualBlock, self).__init__()
        self.block = torch.nn.Sequential(
            ConvBlock(channels,channels,kernel_size = 3,padding = 1),
            ConvBlock(channels,channels,use_act=False,kernel_size = 3,padding = 1),
        )
    def forward(self,x):
        return x + self.block(x)

class Generator(torch.nn.Module):
    def __init__(self,img_channels,num_features = 64,num_residual=9):
        super(Generator, self).__init__()
        self.initial = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=img_channels,out_channels=num_features,kernel_size=(7,7),stride=(1,1),padding=3,padding_mode='reflect'),
            torch.nn.ReLU(inplace=True)
        )
        self.down_blocks = torch.nn.ModuleList(
            [
                ConvBlock(in_channels=num_features,out_channels=num_features*2,kernel_size=(3,3),stride=(2,2),padding=1),
                ConvBlock(in_channels=num_features*2, out_channels=num_features * 4, kernel_size=(3, 3), stride=(2, 2), padding=1)
            ]
        )
        self.residual_blocks = torch.nn.Sequential(
            *[ResidualBlock(num_features*4) for _ in range(num_residual)]
        )
        self.up_blocks = torch.nn.ModuleList(
            [
                ConvBlock(in_channels=num_features*4,out_channels=num_features*2,down=False,kernel_size=3,stride = 2,padding=1,output_padding=1),
                ConvBlock(in_channels=num_features * 2, out_channels=num_features * 1, down=False, kernel_size=3, stride=2, padding=1,output_padding=1)
            ]
        )
        self.last = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=num_features, out_channels=img_channels, kernel_size=(7, 7), stride=(1, 1),
                            padding=3, padding_mode='reflect'),
            torch.nn.Tanh()
        )
    def forward(self,x):
        x = self.initial(x)
        for layer in self.down_blocks:
            x = layer(x)
        x = self.residual_blocks(x)
        for layer in self.up_blocks:
            x = layer(x)
        out = self.last(x)
        return out

if __name__ == '__main__':
    img_channels = 3
    img_size = 256
    x = torch.randn(size = (2,img_channels,img_size,img_size))
    model = Generator(img_channels,9)
    summary(model,input_size=(2,img_channels,img_size,img_size))
#     img = model(x)
#     print(img.shape)

In [5]:
"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/4/29 11:36
"""
import torch
import albumentations #深度学习增强库
from torchvision import transforms
from albumentations.pytorch import ToTensorV2

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_DIR = "/kaggle/input/cyclegandataset/vangogh2photo/train"
VAL_DIR = "/kaggle/input/cyclegandataset/vangogh2photo/val"

BATCH_SIZE = 1
LEARNING_RATE = 2e-4
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
NUM_WORKERS = 0
NUM_EPOCHS = 20
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_GEN_H="gen.pth.tar"
CHECKPOINT_GEN_Z = "genz.pth.tar"
CHECKPOINT_CRITICH_H="critich.pth.tar"
CHECKPOINT_CRITICH_Z="criticz.pth.tar"

transform = albumentations.Compose(
    [
        albumentations.Resize(width=256,height=256),
        albumentations.HorizontalFlip(p = 0.5),
        albumentations.ColorJitter(p=0.1),#颜色抖动
        albumentations.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]),
        albumentations.pytorch.ToTensorV2()
    ],
    additional_targets={"image0":"image"},
)


In [6]:
"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/4/29 11:27
"""

import copy
import os
import random

import torch
import numpy as np

def save_checkpoint(model,optimizer,filename = "my_checkpoint.pth.tar",epochs = 0):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict":model.state_dict(),
        "optimizer":optimizer.state_dict(),
    }
    torch.save(checkpoint,filename + str(epochs))
def load_checkpoin(checkpoint_file,model,optimizer,lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file,map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    for param_group in optimizer.param_group:
        param_group["lr"] = lr

def seed_everthing(seed = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic=True
    torch.backends.cudnn.benchmark=False



if __name__ == '__main__':
    pass


In [7]:
"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/4/28 17:14
"""

import os
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset

class vanGoPhotoDataset(Dataset):
    def __init__(self,root_vango,root_photo,transform=None):
        super(vanGoPhotoDataset, self).__init__()
        self.root_vango = root_vango
        self.root_photo = root_photo
        self.transform = transform

        self.vango_Images = os.listdir(self.root_vango)
        self.photo_Images = os.listdir(self.root_photo)

        self.length_dataset = max(len(self.vango_Images),len(self.photo_Images))
        self.vango_len = len(self.vango_Images)
        self.photo_len = len(self.photo_Images)

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, index):
        vango_img = self.vango_Images[index % self.vango_len]
        photo_img = self.photo_Images[index % self.photo_len]

        vango_path = os.path.join(self.root_vango,vango_img)
        photo_path = os.path.join(self.root_photo,photo_img)

        vango_img = np.array(Image.open(vango_path).convert("RGB"))
        photo_img = np.array(Image.open(photo_path).convert("RGB"))

        if self.transform:
            argumentation = self.transform(image = vango_img,image0 = photo_img)
            vango_img = argumentation["image"]
            photo_img = argumentation["image0"]
        return vango_img,photo_img




In [None]:
"""
@Author : Keep_Trying_Go
@Major  : Computer Science and Technology
@Hobby  : Computer Vision
@Time   : 2023/4/28 17:14
"""
import os
import sys
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision import transforms

from torchvision.utils import save_image
from torch.utils.data import DataLoader,Dataset

def train_fn(disc_H,disc_Z,gen_H,gen_Z,loader,opt_disc,opt_gen,L1,mse,d_scale,g_scale,epoch):
    loop = tqdm(loader,leave=True)
    for idx ,(vango,photo) in enumerate(loop):
        vango = vango.to(DEVICE)
        photo = photo.to(DEVICE)
        
        #train discriminator
        with torch.cuda.amp.autocast():
            fake_photo = gen_H(vango)
            D_H_real = disc_H(photo)
            D_H_fake = disc_H(fake_photo.detach())
            D_H_real_loss = mse(D_H_real,torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_fake_loss + D_H_real_loss

            fake_vango = gen_Z(photo)
            D_Z_real = disc_Z(vango)
            D_Z_fake = disc_Z(fake_vango.detach())
            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_fake_loss + D_Z_real_loss

            D_loss = D_H_loss + D_Z_loss

        opt_disc.zero_grad()
        d_scale.scale(D_loss).backward()
        d_scale.step(opt_disc)
        d_scale.update()

    #train Generator H and Z
        with torch.cuda.amp.autocast():
            #adversarial loss for both generator
            D_H_fake = disc_H(fake_photo)
            D_Z_fake = disc_Z(fake_vango)
            loss_G_Z = mse(D_H_fake,torch.ones_like(D_H_fake))
            loss_G_H = mse(D_H_fake,torch.ones_like(D_Z_fake))

            #cycle loss
            cycle_vango = gen_Z(fake_photo)
            cycle_photo = gen_H(fake_vango)
            cycle_vango_loss = L1(vango,cycle_vango)
            cycle_photo_loss = L1(photo,cycle_photo)

            # identity loss
            identity_vango = gen_Z(vango)
            identity_photo = gen_H(photo)
            identity_vango_loss = L1(vango,identity_vango)
            identity_photo_loss = L1(photo,identity_photo)

            G_loss = (
                loss_G_H + loss_G_Z
                + cycle_vango_loss * LAMBDA_CYCLE
                +cycle_photo_loss * LAMBDA_CYCLE
                +identity_vango_loss * LAMBDA_IDENTITY
                +identity_photo_loss * LAMBDA_IDENTITY
            )
        opt_gen.zero_grad()
        g_scale.scale(G_loss).backward()
        g_scale.step(opt_gen)
        g_scale.update()


        with torch.no_grad():
            if idx % 1000 == 0:
                save_image(fake_photo*0.5 + 0.5,f"VangoTophoto{epoch}_{idx}.png")
                save_image(fake_vango*0.5 + 0.5,f"phtotToVango{epoch}_{idx}.png")
                print("******************************************************************\n")
                print("--------------------G_loss : {:.6}-------------------".format(G_loss))
                print("--------------------D_loss : {:.6}-------------------".format(D_loss))


def main_():
    disc_H = Discriminator(in_channels=3).to(DEVICE)
    disc_Z = Discriminator(in_channels=3).to(DEVICE)
    gen_Z = Generator(img_channels=3,num_residual=9).to(DEVICE)
    gen_H = Generator(img_channels=3,num_residual=9).to(DEVICE)
    opt_disc = torch.optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr = LEARNING_RATE,
        betas=(0.5,0.999)
    )
    opt_gen = torch.optim.Adam(
        list(gen_H.parameters()) + list(gen_Z.parameters()),
        lr = LEARNING_RATE,
        betas=(0.5,0.999)
    )
    L1 = torch.nn.L1Loss()
    mse = torch.nn.MSELoss()

    #导入预训练模型
    if LOAD_MODEL:
        load_checkpoin(
            CHECKPOINT_GEN_H,gen_H,opt_gen,LEARNING_RATE
        )
        load_checkpoin(
            CHECKPOINT_GEN_Z, gen_Z, opt_gen, LEARNING_RATE
        )
        load_checkpoin(
            CHECKPOINT_CRITICH_H, disc_H, opt_disc, LEARNING_RATE
        )
        load_checkpoin(
            CHECKPOINT_CRITICH_Z, disc_Z, opt_disc, LEARNING_RATE
        )

    dataset = vanGoPhotoDataset(
        root_vango=TRAIN_DIR + "/trainA",root_photo=TRAIN_DIR + "/trainB",
        transform=transform
    )

    loader = DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=False
    )
    g_scale = torch.cuda.amp.GradScaler()
    d_scale = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_fn(disc_H,disc_Z,gen_H,gen_Z,loader,opt_disc,opt_gen,L1,mse,d_scale,g_scale,epoch)
        if SAVE_MODEL:
            save_checkpoint(gen_H,opt_gen,filename=CHECKPOINT_GEN_H,epochs = epoch)
            save_checkpoint(gen_Z, opt_gen, filename=CHECKPOINT_GEN_Z,epochs = epoch)
            save_checkpoint(disc_H, opt_disc, filename=CHECKPOINT_CRITICH_H,epochs = epoch)
            save_checkpoint(disc_Z, opt_disc, filename=CHECKPOINT_CRITICH_Z,epochs = epoch)


if __name__ == '__main__':
    main_()


  0%|          | 1/6287 [00:01<2:18:45,  1.32s/it]

******************************************************************

--------------------G_loss : 10.689-------------------
--------------------D_loss : 1.03261-------------------


 16%|█▌        | 1001/6287 [10:03<53:54,  1.63it/s] 

******************************************************************

--------------------G_loss : 5.08701-------------------
--------------------D_loss : 0.519297-------------------


 32%|███▏      | 2001/6287 [20:04<43:29,  1.64it/s]

******************************************************************

--------------------G_loss : 4.16226-------------------
--------------------D_loss : 0.574354-------------------


 48%|████▊     | 3001/6287 [30:05<33:29,  1.64it/s]

******************************************************************

--------------------G_loss : 5.24547-------------------
--------------------D_loss : 0.542539-------------------


 64%|██████▎   | 4001/6287 [40:06<23:27,  1.62it/s]

******************************************************************

--------------------G_loss : 4.1881-------------------
--------------------D_loss : 0.542041-------------------


 80%|███████▉  | 5001/6287 [50:08<13:07,  1.63it/s]

******************************************************************

--------------------G_loss : 3.9214-------------------
--------------------D_loss : 0.983858-------------------


 95%|█████████▌| 6001/6287 [1:00:11<02:55,  1.63it/s]

******************************************************************

--------------------G_loss : 3.8233-------------------
--------------------D_loss : 0.565725-------------------


100%|██████████| 6287/6287 [1:03:04<00:00,  1.66it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


  0%|          | 1/6287 [00:00<1:06:37,  1.57it/s]

******************************************************************

--------------------G_loss : 4.02927-------------------
--------------------D_loss : 0.335928-------------------


 16%|█▌        | 1001/6287 [10:00<54:02,  1.63it/s] 

******************************************************************

--------------------G_loss : 3.92392-------------------
--------------------D_loss : 0.28532-------------------


 32%|███▏      | 2001/6287 [19:59<43:49,  1.63it/s]

******************************************************************

--------------------G_loss : 3.64139-------------------
--------------------D_loss : 0.616068-------------------


 48%|████▊     | 3001/6287 [29:59<33:19,  1.64it/s]

******************************************************************

--------------------G_loss : 3.28053-------------------
--------------------D_loss : 0.457228-------------------


 64%|██████▎   | 4001/6287 [39:58<23:06,  1.65it/s]

******************************************************************

--------------------G_loss : 3.3704-------------------
--------------------D_loss : 0.517752-------------------


 80%|███████▉  | 5001/6287 [49:56<13:02,  1.64it/s]

******************************************************************

--------------------G_loss : 3.269-------------------
--------------------D_loss : 0.516287-------------------


 95%|█████████▌| 6001/6287 [59:54<02:54,  1.63it/s]

******************************************************************

--------------------G_loss : 3.66621-------------------
--------------------D_loss : 0.587637-------------------


100%|██████████| 6287/6287 [1:02:45<00:00,  1.67it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


  0%|          | 1/6287 [00:00<1:05:54,  1.59it/s]

******************************************************************

--------------------G_loss : 3.09499-------------------
--------------------D_loss : 0.304768-------------------


 16%|█▌        | 1001/6287 [09:57<53:23,  1.65it/s] 

******************************************************************

--------------------G_loss : 3.15497-------------------
--------------------D_loss : 0.446501-------------------


 32%|███▏      | 2001/6287 [19:54<43:31,  1.64it/s]

******************************************************************

--------------------G_loss : 3.14788-------------------
--------------------D_loss : 0.694607-------------------


 48%|████▊     | 3001/6287 [29:52<33:16,  1.65it/s]

******************************************************************

--------------------G_loss : 2.73008-------------------
--------------------D_loss : 0.470462-------------------


 64%|██████▎   | 4001/6287 [39:50<23:07,  1.65it/s]

******************************************************************

--------------------G_loss : 3.28358-------------------
--------------------D_loss : 0.495972-------------------


 80%|███████▉  | 5001/6287 [49:46<13:00,  1.65it/s]

******************************************************************

--------------------G_loss : 2.98879-------------------
--------------------D_loss : 0.530085-------------------


 95%|█████████▌| 6001/6287 [59:43<02:54,  1.64it/s]

******************************************************************

--------------------G_loss : 3.22349-------------------
--------------------D_loss : 0.52813-------------------


100%|██████████| 6287/6287 [1:02:34<00:00,  1.67it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


  0%|          | 1/6287 [00:00<1:06:22,  1.58it/s]

******************************************************************

--------------------G_loss : 3.01564-------------------
--------------------D_loss : 0.511454-------------------


 16%|█▌        | 1001/6287 [09:56<53:31,  1.65it/s] 

******************************************************************

--------------------G_loss : 3.02639-------------------
--------------------D_loss : 0.226997-------------------


 32%|███▏      | 2001/6287 [19:53<43:28,  1.64it/s]

******************************************************************

--------------------G_loss : 3.07585-------------------
--------------------D_loss : 0.298132-------------------


 48%|████▊     | 3001/6287 [29:49<33:08,  1.65it/s]

******************************************************************

--------------------G_loss : 3.00439-------------------
--------------------D_loss : 0.502537-------------------


 64%|██████▎   | 4001/6287 [39:45<23:02,  1.65it/s]

******************************************************************

--------------------G_loss : 2.89359-------------------
--------------------D_loss : 0.658576-------------------


 80%|███████▉  | 5001/6287 [49:41<12:59,  1.65it/s]

******************************************************************

--------------------G_loss : 2.7869-------------------
--------------------D_loss : 0.472824-------------------


 92%|█████████▏| 5758/6287 [57:13<05:14,  1.68it/s]

In [2]:
# import shutil
# import os
 
# if __name__ == '__main__':
#     path = '/kaggle/working'
#     if os.path.exists(path):
#         shutil.rmtree(path)
#         print('删除完成')
#     else:
#         print('原本为空')