In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import torch
import glob as gl
import random
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode

class AnimeDataSet(Dataset):
    """Unpaired photo / anime-style dataset.

    ``index`` is ignored by ``__getitem__``: every call independently draws one
    file at random from each of

    * ``<root>/source/*``          -> ``'source'``      (RGB, ``trans``)
    * ``<root>/<mode>/style/*``    -> ``'style'``       (RGB, ``trans``)
                                   -> ``'style_gray'``  (grayscale, ``trans_gray``)
    * ``<root>/<mode>/smooth/*``   -> ``'smooth_gray'`` (grayscale, ``trans_gray``)

    The grayscale outputs are repeated to 3 channels so they can be fed to the
    same 3-channel networks as the RGB images.

    Args:
        root: dataset root directory.
        mode: sub-directory of ``root`` holding ``style/`` and ``smooth/``.
        trans: list of torchvision transforms for RGB images.
        trans_gray: list of torchvision transforms for single-channel images
            (assumed to end in a tensor conversion giving a (1, H, W) tensor).

    Raises:
        FileNotFoundError: if any of the three glob patterns matches no file.
    """

    def __init__(self, root='', mode='', trans=None, trans_gray=None):
        super().__init__()
        self.transform = transforms.Compose(trans)
        self.transform_gray = transforms.Compose(trans_gray)
        self.source_path = os.path.join(root, "source/*")
        self.style_path = os.path.join(root, f"{mode}/style/*")
        self.smooth_path = os.path.join(root, f"{mode}/smooth/*")

        self.list_style = gl.glob(self.style_path)
        self.list_smooth = gl.glob(self.smooth_path)
        self.list_source = gl.glob(self.source_path)

        # Fail fast with a clear message: otherwise random.choice([]) raises an
        # opaque IndexError inside a DataLoader worker, or (all lists empty)
        # the loader silently yields nothing.
        for pattern, files in ((self.source_path, self.list_source),
                               (self.style_path, self.list_style),
                               (self.smooth_path, self.list_smooth)):
            if not files:
                raise FileNotFoundError(f"No files match {pattern}")

    def __getitem__(self, index):
        # Draw order (style, smooth, source) is kept so seeded runs reproduce.
        style_path = random.choice(self.list_style)
        smooth_path = random.choice(self.list_smooth)
        img_path = random.choice(self.list_source)

        # Open each file once and close it deterministically.
        with Image.open(style_path) as style_file:
            style = style_file.convert('RGB')
            style_gray = style_file.convert('L')
        with Image.open(smooth_path) as smooth_file:
            smooth_gray = smooth_file.convert('L')
        with Image.open(img_path) as source_file:
            img = source_file.convert('RGB')

        # Same transform call order as before (random flips draw from torch's RNG).
        img = self.transform(img)
        img_A = self.transform(style)
        img_B = self.transform_gray(style_gray)
        img_C = self.transform_gray(smooth_gray)

        # (1, H, W) -> (3, H, W); stays a tensor like the other outputs instead
        # of round-tripping through numpy.
        img_B = img_B.squeeze(0).unsqueeze(0).repeat(3, 1, 1)
        img_C = img_C.squeeze(0).unsqueeze(0).repeat(3, 1, 1)
        return {'source': img, 'style': img_A, 'style_gray': img_B, 'smooth_gray': img_C}

    def __len__(self):
        # Nominal epoch length: the largest of the three file lists.
        return max(len(self.list_style), len(self.list_smooth), len(self.list_source))





In [None]:
import torch
import cv2
import os
import numpy as np
from tqdm import tqdm


def gram(input):
    """Gram matrix of a (b, c, h, w) feature batch, normalised by b*c*h*w.

    NOTE(review): batch and channel axes are flattened together, so the result
    is (b*c, b*c) and contains correlations *between different images of the
    batch* as well; a per-sample Gram would be (b, c, c). Kept as-is because
    the training code compares two Grams built from batches of the same size.

    Args:
        input: 4-D feature tensor (b, c, h, w).

    Returns:
        Tensor of shape (b*c, b*c).
    """
    b, c, h, w = input.size()
    # reshape (unlike view) also accepts non-contiguous feature maps.
    features = input.reshape(b * c, h * w)
    return torch.mm(features, features.T).div(b * c * h * w)
def rgb_to_yuv(image, x):
    """Convert an image in [-1, 1] RGB to YUV using the 3x3 matrix ``x``.

    The image is first mapped to [0, 1]. The channel axis ``image.ndim - 3``
    (dim 1 for NCHW, dim 0 for CHW) is contracted with the first axis of
    ``x``, so the output carries the YUV components on its LAST axis
    (e.g. NCHW in -> (N, H, W, 3) out).
    """
    unit_range = (image + 1.0) / 2.0
    channel_axis = unit_range.ndim - 3
    return torch.tensordot(unit_range, x, dims=([channel_axis], [0]))


def divisible(dim):
    """Round a (width, height) pair down to the nearest multiples of 32."""
    w, h = dim
    return w // 32 * 32, h // 32 * 32


def resize_image(image, width=None, height=None, inter=cv2.INTER_AREA):
    """Resize an HxW(xC) image, snapping the final size down to multiples of 32.

    * both ``width`` and ``height`` truthy -> resize to (width, height);
    * both None                            -> keep the current size;
    * otherwise the given dimension is used and the other one is scaled to
      keep the aspect ratio (``width`` wins if only ``height`` is falsy).

    Args:
        image: array whose first two axes are (height, width).
        width, height: requested size in pixels.
        inter: OpenCV interpolation flag.
    """
    src_h, src_w = image.shape[:2]

    if width and height:
        target = (width, height)
    elif width is None and height is None:
        target = (src_w, src_h)
    elif width is None:
        scale = height / float(src_h)
        target = (int(src_w * scale), height)
    else:
        scale = width / float(src_w)
        target = (width, int(src_h * scale))

    return cv2.resize(image, divisible(target), interpolation=inter)

import time as t
import os
import random
import numpy as np
import torch
from torch.autograd import Variable
def initialize_weights(net):
    """Initialise the weights of ``net``'s layers in place.

    * Conv2d / ConvTranspose2d / Linear: weight ~ N(0, 0.02), bias = 0.
    * BatchNorm2d: weight = 1, bias = 0.

    Layers whose weight or bias is None (``bias=False`` convs, non-affine
    norms) are skipped for that tensor. Previously this relied on a blanket
    ``except Exception: pass``, which would also have hidden real errors.
    """
    # torch.nn (not the bare name nn) so this works without depending on a
    # later notebook cell having imported `nn`.
    nn_mod = torch.nn
    for m in net.modules():
        if isinstance(m, (nn_mod.Conv2d, nn_mod.ConvTranspose2d, nn_mod.Linear)):
            if getattr(m, 'weight', None) is not None:
                m.weight.data.normal_(0, 0.02)
            if getattr(m, 'bias', None) is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn_mod.BatchNorm2d):
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()

def time_change(time):
    """Format an elapsed duration in seconds as ``"HHhMMmSSs"``.

    Previously the duration was passed to ``time.localtime``, which treats it
    as an epoch timestamp: the output was shifted by the local UTC offset
    (10 s printed as "08h00m10s" under UTC+8) and wrapped after 24 h.

    Args:
        time: elapsed seconds (fractions are truncated; negatives clamp to 0).

    Returns:
        e.g. ``time_change(3725) == "01h02m05s"``; hours are not wrapped.
    """
    total_seconds = max(0, int(time))
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{hours:02d}h{minutes:02d}m{seconds:02d}s"
def denorm(x):
    """Map a CHW tensor (assumed in [-1, 1]) to an HWC float numpy array in [0, 255]."""
    scaled = (x * 0.5 + 0.5) * 255.0
    array = scaled.detach().cpu().numpy()
    return array.transpose((1, 2, 0))
def RGB2BGR(x):
    """Return a copy of the RGB image array ``x`` in BGR channel order (OpenCV)."""
    bgr = cv2.cvtColor(x, cv2.COLOR_RGB2BGR)
    return bgr
def check_folder(log_dir):
    """Create directory ``log_dir`` (including parents) if needed and return it.

    Uses ``os.makedirs(exist_ok=True)``. This avoids the race between an
    exists-check and creation.

    Raises:
        FileExistsError: if ``log_dir`` exists but is not a directory.
            Previously such a path was silently returned.
    """
    os.makedirs(log_dir, exist_ok=True)
    return log_dir
class ImagePool():
    """Buffer of previously generated images that can be mixed into new batches.

    Args:
        pool_size: capacity of the buffer; 0 disables it (``query`` returns its
            input unchanged).
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        # Initialised unconditionally so the object is in a consistent state
        # even when the pool is disabled.
        self.num_imgs = 0
        self.images = []

    def query(self, images):
        """Return a batch mixing ``images`` with earlier stored images.

        Until the pool is full, every incoming image is stored and returned.
        After that, each image is swapped with probability 0.5: a random stored
        image is returned and the new image takes its slot. Otherwise the new
        image is returned unchanged. When the pool is enabled, the returned
        tensor does not track gradients.

        Args:
            images: batch tensor, first axis = batch.
        """
        if self.pool_size == 0:
            return images
        selected = []
        for image in images.detach():
            image = image.unsqueeze(0)
            if self.num_imgs < self.pool_size:
                self.num_imgs += 1
                self.images.append(image)
                selected.append(image)
            elif random.uniform(0, 1) > 0.5:
                slot = random.randint(0, self.pool_size - 1)
                selected.append(self.images[slot].clone())
                self.images[slot] = image
            else:
                selected.append(image)
        # torch.autograd.Variable is deprecated (a no-op wrapper since PyTorch
        # 0.4); return the concatenated tensor directly.
        return torch.cat(selected, 0)
def compute_data_mean(data_folder):
    """Per-channel mean offset of the images in ``data_folder``.

    Averages each image's per-channel mean over all readable images, then
    returns ``mean(channel_means) - channel_means``. cv2 loads images as
    B, G, R; the result is reversed, so it is ordered R, G, B. Entries cv2
    cannot decode (non-images, sub-directories) are skipped.

    Raises:
        FileNotFoundError: if ``data_folder`` does not exist.
        ValueError: if the folder contains no readable images.
    """
    if not os.path.exists(data_folder):
        raise FileNotFoundError(f'Folder {data_folder} does not exist')

    image_files = os.listdir(data_folder)
    if not image_files:
        raise ValueError(f'Folder {data_folder} contains no images')

    total = np.zeros(3)
    count = 0

    print(f"Compute mean (R, G, B) from {len(image_files)} images")

    for img_file in tqdm(image_files):
        image = cv2.imread(os.path.join(data_folder, img_file))
        if image is None:
            # Not decodable as an image: skip instead of crashing on None.mean.
            continue
        total += image.mean(axis=(0, 1))
        count += 1

    if count == 0:
        raise ValueError(f'Folder {data_folder} contains no readable images')

    channel_mean = total / count
    mean = np.mean(channel_mean)

    # channel_mean is B, G, R (cv2 order); reversing yields an R, G, B offset.
    return mean - channel_mean[..., ::-1]

In [None]:

import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
import torch.nn as nn
import torch.nn.functional as F

class DownConv(nn.Module):
    """2x spatial downsampling with unchanged channel count.

    Sum of two paths:
    * a stride-2 ``SeparableConv2D``;
    * a bilinear downscale followed by a stride-1 ``SeparableConv2D``.
    """

    def __init__(self, channels, bias=False):
        super(DownConv, self).__init__()

        self.conv1 = SeparableConv2D(channels, channels, stride=2, bias=bias)
        self.conv2 = SeparableConv2D(channels, channels, stride=1, bias=bias)

    def forward(self, x):
        out1 = self.conv1(x)
        # Resize to exactly out1's spatial size. The stride-2 conv (k=3, p=1)
        # yields ceil(H/2), while scale_factor=0.5 yields floor(H/2), so odd
        # inputs made the sum below fail. For even sizes the result is
        # identical to scale_factor=0.5.
        out2 = F.interpolate(x, size=tuple(out1.shape[-2:]), mode='bilinear', align_corners=False)
        out2 = self.conv2(out2)

        return out1 + out2


class UpConv(nn.Module):
    """2x bilinear upsampling followed by a stride-1 SeparableConv2D (channels unchanged)."""

    def __init__(self, channels, bias=False):
        super(UpConv, self).__init__()
        self.conv = SeparableConv2D(channels, channels, stride=1, bias=bias)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2.0, mode='bilinear')
        return self.conv(upsampled)


class SeparableConv2D(nn.Module):
    """Depthwise-separable 3x3 convolution block.

    depthwise 3x3 (groups=in_channels, carries ``stride``) -> InstanceNorm ->
    LeakyReLU(0.2) -> pointwise 1x1 to ``out_channels`` -> InstanceNorm ->
    LeakyReLU(0.2).
    """

    def __init__(self, in_channels, out_channels, stride=1, bias=False):
        super(SeparableConv2D, self).__init__()
        # Submodules are registered in the original order (state_dict keys and
        # initialisation order depend on it).
        self.depthwise = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=3, stride=stride, padding=1,
            groups=in_channels, bias=bias,
        )
        self.pointwise = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=1, stride=1, bias=bias,
        )
        self.ins_norm1 = nn.InstanceNorm2d(in_channels)
        self.activation1 = nn.LeakyReLU(0.2, True)
        self.ins_norm2 = nn.InstanceNorm2d(out_channels)
        self.activation2 = nn.LeakyReLU(0.2, True)

        initialize_weights(self)

    def forward(self, x):
        hidden = self.activation1(self.ins_norm1(self.depthwise(x)))
        return self.activation2(self.ins_norm2(self.pointwise(hidden)))


class ConvBlock(nn.Module):
    """Conv2d -> InstanceNorm2d -> LeakyReLU(0.2)."""

    def __init__(self, channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(
            channels, out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding, bias=bias,
        )
        self.ins_norm = nn.InstanceNorm2d(out_channels)
        self.activation = nn.LeakyReLU(0.2, True)

        initialize_weights(self)

    def forward(self, x):
        return self.activation(self.ins_norm(self.conv(x)))


class InvertedResBlock(nn.Module):
    """Inverted residual block with an identity shortcut.

    1x1 expand to ``round(expand_ratio * channels)`` (ConvBlock: conv + IN +
    LeakyReLU) -> 3x3 depthwise conv -> IN -> LeakyReLU -> 1x1 project to
    ``out_channels`` -> IN, then ``+ x``. The shortcut therefore requires
    ``channels == out_channels``.
    """

    def __init__(self, channels=256, out_channels=256, expand_ratio=2, bias=False):
        super(InvertedResBlock, self).__init__()
        bottleneck_dim = round(expand_ratio * channels)
        self.conv_block = ConvBlock(channels, bottleneck_dim, kernel_size=1, stride=1, padding=0, bias=bias)
        self.depthwise_conv = nn.Conv2d(bottleneck_dim, bottleneck_dim,
            kernel_size=3, groups=bottleneck_dim, stride=1, padding=1, bias=bias)
        self.conv = nn.Conv2d(bottleneck_dim, out_channels,
            kernel_size=1, stride=1, bias=bias)

        # ins_norm1 normalises the depthwise output, which has bottleneck_dim
        # channels. It was declared with out_channels: numerically harmless
        # while affine=False, but wrong. Recent PyTorch warns about the
        # num_features mismatch, and it would raise if affine were enabled.
        # Non-affine InstanceNorm has no parameters or buffers, so checkpoints
        # are unaffected.
        self.ins_norm1 = nn.InstanceNorm2d(bottleneck_dim)
        self.ins_norm2 = nn.InstanceNorm2d(out_channels)
        self.activation = nn.LeakyReLU(0.2, True)

        initialize_weights(self)

    def forward(self, x):
        out = self.conv_block(x)
        out = self.depthwise_conv(out)
        out = self.ins_norm1(out)
        out = self.activation(out)
        out = self.conv(out)
        out = self.ins_norm2(out)

        return out + x

class Generator(nn.Module):
    """AnimeGAN generator.

    * Encoder: two ``DownConv`` stages, 3 -> 256 channels.
    * Eight ``InvertedResBlock``s at 256 channels.
    * Decoder: two ``UpConv`` stages, then a 1x1 conv to 3 channels and ``tanh``
      (output in [-1, 1]).

    Args:
        dataset: only used to build ``self.name``.
    """

    def __init__(self, dataset=''):
        super(Generator, self).__init__()
        self.name = f'generator_{dataset}'
        use_bias = False

        self.encode_blocks = nn.Sequential(
            ConvBlock(3, 64, bias=use_bias),
            ConvBlock(64, 128, bias=use_bias),
            DownConv(128, bias=use_bias),
            ConvBlock(128, 128, bias=use_bias),
            SeparableConv2D(128, 256, bias=use_bias),
            DownConv(256, bias=use_bias),
            ConvBlock(256, 256, bias=use_bias),
        )

        # Same eight blocks, built in the same order (indices 0..7 in state_dict).
        self.res_blocks = nn.Sequential(
            *[InvertedResBlock(256, 256, bias=use_bias) for _ in range(8)]
        )

        self.decode_blocks = nn.Sequential(
            ConvBlock(256, 128, bias=use_bias),
            UpConv(128, bias=use_bias),
            SeparableConv2D(128, 128, bias=use_bias),
            ConvBlock(128, 128, bias=use_bias),
            UpConv(128, bias=use_bias),
            ConvBlock(128, 64, bias=use_bias),
            ConvBlock(64, 64, bias=use_bias),
            nn.Conv2d(64, 3, kernel_size=1, stride=1, padding=0, bias=use_bias),
            nn.Tanh(),
        )

        initialize_weights(self)

    def forward(self, x):
        features = self.res_blocks(self.encode_blocks(x))
        return self.decode_blocks(features)


class Discriminator(nn.Module):
    """Convolutional discriminator producing a one-channel score map.

    Layout: 3x3 conv (3 -> 32) + LeakyReLU, then ``args.d_layers`` stages of
    [stride-2 conv (c -> 2c), LeakyReLU, conv (2c -> 4c), IN, LeakyReLU],
    then conv + IN + LeakyReLU and a final conv to 1 channel. With
    ``args.use_sn`` every Conv2d is wrapped in spectral normalisation.

    Args:
        args: namespace providing ``dataset``, ``d_layers`` and ``use_sn``.
    """

    def __init__(self, args):
        super(Discriminator, self).__init__()
        self.name = f'discriminator_{args.dataset}'
        self.bias = False
        width = 32

        layers = [
            nn.Conv2d(3, width, kernel_size=3, stride=1, padding=1, bias=self.bias),
            nn.LeakyReLU(0.2, True),
        ]

        for _ in range(args.d_layers):
            layers.extend([
                nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1, bias=self.bias),
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(width * 2, width * 4, kernel_size=3, stride=1, padding=1, bias=self.bias),
                nn.InstanceNorm2d(width * 4),
                nn.LeakyReLU(0.2, True),
            ])
            width *= 4

        layers.extend([
            nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1, bias=self.bias),
            nn.InstanceNorm2d(width),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(width, 1, kernel_size=3, stride=1, padding=1, bias=self.bias),
        ])

        if args.use_sn:
            # Applied in list order, as before.
            layers = [spectral_norm(layer) if isinstance(layer, nn.Conv2d) else layer
                      for layer in layers]

        self.discriminate = nn.Sequential(*layers)

        initialize_weights(self)

    def forward(self, img):
        return self.discriminate(img)

In [None]:
from numpy.lib.arraysetops import isin
import torchvision.models as models
import torch.nn as nn
import torch
import torch.nn as nn
import torch
from PIL import Image
import numpy as np


class VGG19(nn.Module):
    """Frozen VGG19 feature extractor truncated after ``conv4_4`` (pre-ReLU).

    Used only to compute perceptual (content / Gram) losses.

    Args:
        vgg_mean, vgg_std: per-channel (3,) tensors used to normalise inputs.

    ``forward`` maps inputs from [-1, 1] to [0, 1], normalises them with
    mean/std and returns the truncated VGG19 features.
    """

    def __init__(self, vgg_mean, vgg_std):
        super(VGG19, self).__init__()
        self.vgg19 = self.get_vgg19().eval()
        # Freeze the extractor: gradients still flow through it to the inputs,
        # but are no longer computed or accumulated (never zeroed) for its weights.
        for param in self.vgg19.parameters():
            param.requires_grad_(False)
        # Non-persistent buffers follow .to(device) without adding state_dict keys.
        self.register_buffer('mean', vgg_mean.view(-1, 1, 1), persistent=False)
        self.register_buffer('std', vgg_std.view(-1, 1, 1), persistent=False)

    def forward(self, x):
        return self.vgg19(self.normalize_vgg(x))

    @staticmethod
    def get_vgg19(last_layer='conv4_4', pretrained=True):
        """Return torchvision VGG19 ``features`` truncated right after ``last_layer``.

        Layers are named ``conv{block}_{index}``. A MaxPool starts a new block
        with index 0, and each Conv2d increments the index. Layers between convs
        share the previous conv's name, so the match stops at the conv itself.

        Args:
            last_layer: name of the last conv to keep, e.g. ``'conv4_4'``.
            pretrained: load torchvision's pretrained weights. Previously this was
                ``torch.cuda.is_available()``, which silently left CPU runs with
                a randomly initialised VGG (meaningless perceptual losses).
        """
        vgg = models.vgg19(pretrained=pretrained).features
        model_list = []

        block, index = 1, 0
        for layer in vgg.children():
            if isinstance(layer, nn.MaxPool2d):
                block += 1
                index = 0
            elif isinstance(layer, nn.Conv2d):
                index += 1

            model_list.append(layer)
            if f'conv{block}_{index}' == last_layer:
                break

        return nn.Sequential(*model_list)

    def normalize_vgg(self, image):
        """Map ``image`` from [-1, 1] to [0, 1], then apply (x - mean) / std."""
        image = (image + 1.0) / 2.0
        return (image - self.mean) / self.std

In [None]:
import torch

from torch import nn
from torch import optim
import itertools
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.nn.functional import interpolate
from torchvision.transforms import InterpolationMode
import cv2
import numpy as np


import time as t
from tqdm import tqdm

class AnimeGANs(object):
    """Training / inference driver for AnimeGAN.

    Holds the generator ``G``, discriminator ``D``, a VGG19 feature extractor,
    their Adam optimizers and loss weights, all configured from the ``args``
    namespace returned by ``parse_args``.
    """

    def __init__(self, args):
        # Configuration
        self.cpu_count = args.cpu_count
        self.test_dir = args.test_dir
        self.istrain = args.istrain
        self.istest = args.istest
        self.init_train = args.init_train
        self.retrain = args.retrain
        self.epochs = args.epochs
        self.init_epochs = args.init_epochs
        self.data_dir = args.data_dir
        self.dataset = args.dataset
        self.result_dir = args.result_dir
        self.save_interval = args.save_interval
        self.checkpoint_dir = args.checkpoint_dir
        self.save_image_dir = args.save_image_dir
        self.G_lr = args.lr_g
        self.D_lr = args.lr_d
        self.decay_g = args.decay_g
        self.decay_d = args.decay_d
        self.init_lr = args.init_lr
        self.batch_size = args.batch_size
        self.d_noise = args.d_noise
        self.device = args.device
        # Per-channel mean/std used by VGG19.normalize_vgg (the usual ImageNet values).
        self.vgg_mean = torch.tensor([0.485, 0.456, 0.406]).float().to(self.device)
        self.vgg_std = torch.tensor([0.229, 0.224, 0.225]).float().to(self.device)

        # Models
        self.G = Generator(dataset=args.dataset).to(self.device)
        self.D = Discriminator(args).to(self.device)
        self.vgg19 = VGG19(self.vgg_mean, self.vgg_std).to(self.device)
        # Optionally load local VGG19 weights instead, e.g.
        # self.vgg19.load_state_dict(torch.load('/kaggle/input/prevgg19/vgg19.pth'))

        # Optimizers. In init_train mode G is trained with the content loss only, at init_lr.
        g_lr = self.init_lr if self.init_train else self.G_lr
        self.G_optim = optim.Adam(self.G.parameters(), lr=g_lr, betas=(0.5, 0.999))
        self.D_optim = optim.Adam(self.D.parameters(), lr=self.D_lr, betas=(0.5, 0.999))
        # Starting learning rates that the linear decay in train() scales down.
        self._g_base_lr = g_lr
        self._d_base_lr = self.D_lr

        # Losses
        self.huber = nn.SmoothL1Loss().to(self.device)
        self.content_loss = nn.L1Loss().to(self.device)
        self.gram_loss = nn.L1Loss().to(self.device)
        self.color_loss = nn.L1Loss().to(self.device)
        self.gan_loss = args.gan_loss
        self.wadvg = args.wadvg
        self.wadvd = args.wadvd
        self.wcon = args.wcon
        self.wgra = args.wgra
        self.wcol = args.wcol
        self.adv_type = args.gan_loss

        # Discriminator input noise parameters
        self.gaussian_mean = torch.tensor(0.0)
        self.gaussian_std = torch.tensor(0.1)
        # RGB -> YUV matrix: rows are input R, G, B; columns are Y, U, V.
        self._rgb_to_yuv_kernel = torch.tensor([
            [0.299, -0.14714119, 0.61497538],
            [0.587, -0.28886916, -0.51496512],
            [0.114, 0.43601035, -0.10001026],
        ]).float().to(self.device)

    def gaussian_noise(self):
        """Draw ONE scalar from N(gaussian_mean, gaussian_std).

        NOTE(review): adding this to an image shifts every pixel by the same
        amount; it is not per-pixel noise. Confirm that is intended.
        """
        return torch.normal(self.gaussian_mean, self.gaussian_std)

    def load_data(self):
        """Build the shuffled training DataLoader (256x256 center crops in [-1, 1])."""
        trans = [transforms.Resize(286, InterpolationMode.BICUBIC),
                 transforms.CenterCrop(256),
                 transforms.RandomHorizontalFlip(0.5),
                 transforms.ToTensor(),
                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        trans_gray = [transforms.Resize(286, InterpolationMode.BICUBIC),
                      transforms.CenterCrop(256),
                      transforms.RandomHorizontalFlip(0.5),
                      transforms.ToTensor(),
                      transforms.Normalize((0.5), (0.5))]
        data_loader = DataLoader(AnimeDataSet(root=self.data_dir, mode=self.dataset, trans=trans, trans_gray=trans_gray),
                                 batch_size=self.batch_size, shuffle=True, num_workers=self.cpu_count,
                                 pin_memory=True, drop_last=True)
        return data_loader

    def set_inputs(self, input):
        """Move one collated batch dict to ``self.device``."""
        self.source_img = input['source'].to(self.device)
        self.style_img = input['style'].to(self.device)
        self.stg_img = input['style_gray'].to(self.device)
        self.smg_img = input['smooth_gray'].to(self.device)

    @staticmethod
    def _lr_scale(step, epochs):
        """Learning-rate multiplier for 1-based epoch ``step`` of ``epochs``.

        Constant 1.0 for the first ``epochs // 2`` epochs, then decays linearly
        by ``1 / (epochs // 2)`` per epoch, clamped at 0. The previous in-place
        subtraction went negative for odd ``epochs`` and divided by zero for
        ``epochs == 1``.
        """
        half = epochs // 2
        if half == 0 or step <= half:
            return 1.0
        return max(0.0, 1.0 - (step - half) / half)

    def _update_lr(self, step):
        """Set both optimizers' learning rates for epoch ``step``."""
        scale = self._lr_scale(step, self.epochs)
        for group in self.G_optim.param_groups:
            group['lr'] = self._g_base_lr * scale
        for group in self.D_optim.param_groups:
            group['lr'] = self._d_base_lr * scale

    def _train_d_step(self):
        """One LSGAN discriminator update on the current batch; returns the weighted D loss."""
        self.D.train()
        self.D_optim.zero_grad()
        with torch.no_grad():
            fake_img = self.G(self.source_img)
        style_img, stg_img, smg_img = self.style_img, self.stg_img, self.smg_img
        if self.d_noise:
            # Noisy copies stay local. The old in-place `+=` also altered
            # self.stg_img, which the generator step then used as its Gram target.
            fake_img = fake_img + self.gaussian_noise()
            style_img = style_img + self.gaussian_noise()
            stg_img = stg_img + self.gaussian_noise()
            smg_img = smg_img + self.gaussian_noise()
        fake_d = self.D(fake_img)
        real_anime_d = self.D(style_img)
        real_anime_gray_d = self.D(stg_img)
        real_anime_smg_gray_d = self.D(smg_img)
        # LSGAN targets: style images -> 1; generated, grayscale-style and
        # smooth images -> 0.
        loss_real = torch.mean(torch.square(real_anime_d - 1.0))
        loss_fake = torch.mean(torch.square(fake_d))
        loss_gray = torch.mean(torch.square(real_anime_gray_d))
        loss_smooth = torch.mean(torch.square(real_anime_smg_gray_d))
        loss_d = self.wadvd * (loss_real + loss_fake + loss_gray + 0.2 * loss_smooth)
        loss_d.backward()
        self.D_optim.step()
        return loss_d

    def _train_g_step(self):
        """One generator update (adversarial + content + Gram + color); returns the G loss."""
        self.G_optim.zero_grad()
        fake_img = self.G(self.source_img)
        fake_d = self.D(fake_img)
        fake_feat = self.vgg19(fake_img)
        real_feat = self.vgg19(self.source_img)
        style_gray_feat = self.vgg19(self.stg_img)
        adv_loss = torch.mean(torch.square(fake_d - 1.0))
        con_loss = self.content_loss(fake_feat, real_feat)
        gram_loss = self.gram_loss(gram(style_gray_feat), gram(fake_feat))
        # YUV is on the last axis: L1 on Y, Huber on U and V.
        source_yuv = rgb_to_yuv(self.source_img, self._rgb_to_yuv_kernel)
        fake_yuv = rgb_to_yuv(fake_img, self._rgb_to_yuv_kernel)
        col_loss = (self.color_loss(source_yuv[:, :, :, 0], fake_yuv[:, :, :, 0])
                    + self.huber(source_yuv[:, :, :, 1], fake_yuv[:, :, :, 1])
                    + self.huber(source_yuv[:, :, :, 2], fake_yuv[:, :, :, 2]))
        loss_G = adv_loss * self.wadvg + con_loss * self.wcon + gram_loss * self.wgra + col_loss * self.wcol
        loss_G.backward()
        self.G_optim.step()
        return loss_G

    def _train_content_step(self):
        """One generator update with the VGG content loss only; returns that loss."""
        self.G_optim.zero_grad()
        fake_img = self.G(self.source_img)
        real_con = self.vgg19(self.source_img)
        fake_con = self.vgg19(fake_img)
        loss_con = self.content_loss(fake_con, real_con)
        loss_con.backward()
        self.G_optim.step()
        return loss_con

    def _save_samples(self, data_loader, step, num_samples=5):
        """Write a grid of (source, style, generated) columns for ``num_samples`` batches."""
        self.G.eval()
        self.D.eval()
        style = np.zeros((256 * 3, 0, 3))
        with torch.no_grad():
            for _ in range(num_samples):
                data = next(iter(data_loader))  # first batch of a fresh shuffled pass
                real_img = data['source'].to(self.device)
                style_img = data['style'].to(self.device)
                fake_img = self.G(real_img)  # generate the fake image
                column = np.concatenate((RGB2BGR(denorm(real_img[0])),
                                         RGB2BGR(denorm(style_img[0])),
                                         RGB2BGR(denorm(fake_img[0]))), 0)
                style = np.concatenate((style, column), 1)
        cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'img', 'anime_%06d.png' % step), style)
        print("测试图像生成成功！")

    def train(self):
        """Train for ``self.epochs`` epochs, saving samples and a checkpoint every ``save_interval``."""
        print(f"training on {self.device}")
        print('=================train start!=========================')
        data_loader = self.load_data()
        count = len(data_loader)
        start_time = t.time()
        # Previously hard-coded to range(1, 41), ignoring self.epochs (so the
        # decay, which starts after epochs // 2, never ran with the defaults).
        for step in tqdm(range(1, self.epochs + 1)):
            self._update_lr(step)
            for i, data in enumerate(data_loader):
                self.G.train()
                self.set_inputs(data)
                if not self.init_train:
                    loss_d = self._train_d_step()
                    loss_G = self._train_g_step()
                else:
                    loss_con = self._train_content_step()
                t_end = t.time()
                if self.init_train:
                    print(
                        f"epoch:[{step}/{self.epochs}],iter:[{i+1}/{count}],loss_G:{loss_con},G_lr:{self.G_optim.param_groups[0]['lr']},time:{time_change(t_end -start_time)}")
                else:
                    print(
                        f"epoch[{step}/{self.epochs}],iter[{i+1}/{count}],loss_G:{loss_G},loss_D:{loss_d},G_lr:{self.G_optim.param_groups[0]['lr']},D_lr:{self.D_optim.param_groups[0]['lr']},time:{time_change(t_end - start_time)}")
            if step % self.save_interval == 0:
                self._save_samples(data_loader, step)
                self.save_model()

    def save_model(self):
        """Save G and D state dicts to <result_dir>/<dataset>/<checkpoint_dir>/checkpoints_<dataset>.pth."""
        ckpt_dir = os.path.join(self.result_dir, self.dataset, self.checkpoint_dir)
        os.makedirs(ckpt_dir, exist_ok=True)
        params = {"G": self.G.state_dict(), "D": self.D.state_dict()}
        torch.save(params, os.path.join(ckpt_dir, f'checkpoints_{self.dataset}.pth'))
        print("保存模型成功！")

    def load_model(self):
        """Load G and D weights from <test_dir>/checkpoints_<dataset>.pth onto ``self.device``.

        NOTE: torch.load unpickles the file; only load trusted checkpoints.
        """
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        params = torch.load(os.path.join(self.test_dir, f'checkpoints_{self.dataset}.pth'),
                            map_location=self.device)
        self.G.load_state_dict(params['G'])
        self.D.load_state_dict(params['D'])
        print("加载模型成功！")

    def test(self):
        """Load the checkpoint and write generated images (at most 3001) to <result_dir>/<dataset>/test/img."""
        self.load_model()
        data_loader = self.load_data()
        self.G.eval()
        with torch.no_grad():
            for i, data in tqdm(enumerate(data_loader)):
                real_img = data['source'].to(self.device)
                fake_img = self.G(real_img)  # generate the fake image
                fake_img = RGB2BGR(denorm(fake_img[0]))
                cv2.imwrite(os.path.join(self.result_dir, self.dataset, 'test/img', 'style_%06d.png' % i), fake_img)
                if i == 3000:
                    break
        print("测试图像生成成功！")


In [None]:

import argparse
def parse_args(argv=None):
    """Build the training / test configuration.

    Args:
        argv: list of CLI tokens to parse. Defaults to ``[]`` (all defaults),
            presumably so the notebook does not try to parse the Jupyter
            kernel's own arguments.

    Returns:
        argparse.Namespace, after ``check_args`` has created the output folders.
    """
    def str2bool(value):
        # argparse's type=bool turns ANY non-empty string (even "False") into True.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='Shinkai')
    parser.add_argument('--data_dir', type=str, default='/kaggle/input/dataset/datasets')
    parser.add_argument('--result_dir', type=str, default='results')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--cpu_count', type=int, default=1)
    parser.add_argument('--init_epochs', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--test_dir', type=str, default='/kaggle/input/skinkai50')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints')
    parser.add_argument('--save_image_dir', type=str, default='img')
    parser.add_argument('--gan_loss', type=str, default='lsgan', help='lsgan / hinge / bce')
    parser.add_argument('--retrain', type=str2bool, default=True)
    parser.add_argument('--istrain', type=str2bool, default=True)
    parser.add_argument('--init_train', type=str2bool, default=False)
    parser.add_argument('--istest', type=str2bool, default=False)
    parser.add_argument('--use_sn', type=str2bool, default=True)
    parser.add_argument('--save_interval', type=int, default=1)
    parser.add_argument('--lr_g', type=float, default=0.00008)
    parser.add_argument('--lr_d', type=float, default=0.00016)
    parser.add_argument('--decay_g', type=float, default=4.800000000000004e-05)
    parser.add_argument('--decay_d', type=float, default=9.600000000000008e-05)
    parser.add_argument('--init_lr', type=float, default=0.0001)
    parser.add_argument('--wadvg', type=float, default=10.0, help='Adversarial loss weight for G')
    parser.add_argument('--wadvd', type=float, default=10.0, help='Adversarial loss weight for D')
    parser.add_argument('--wcon', type=float, default=1.5, help='Content loss weight')
    parser.add_argument('--wgra', type=float, default=3.0, help='Gram loss weight')
    parser.add_argument('--wcol', type=float, default=30.0, help='Color loss weight')
    parser.add_argument('--d_layers', type=int, default=3, help='Discriminator conv layers')
    parser.add_argument('--d_noise', type=str2bool, default=True)
    parser.add_argument('--device', type=str, default='cuda', choices=['cuda', 'cpu'])
    return check_args(parser.parse_args(args=[] if argv is None else argv))

def check_args(args):
    """Create the output directory tree under ``<result_dir>/<dataset>`` and return ``args``.

    Creates ``args.checkpoint_dir`` (previously a hard-coded ``'checkpoints'``,
    which broke checkpoint saving whenever ``--checkpoint_dir`` was changed),
    ``img`` (training samples), ``test`` and ``test/img``.
    """
    base = os.path.join(args.result_dir, args.dataset)
    for sub_dir in (args.checkpoint_dir, 'img', 'test', os.path.join('test', 'img')):
        check_folder(os.path.join(base, sub_dir))
    return args


def main():
    """Parse the configuration, then optionally train and/or test an AnimeGANs model."""
    args = parse_args()
    gan = AnimeGANs(args)

    if args.istrain:
        # Resume from the checkpoint stored in args.test_dir before training.
        if args.retrain:
            gan.load_model()
        print(f"training on {args.device}")
        gan.train()
        print("train haved finished")

    if args.istest:
        gan.test()
        print("test haved finished")


if __name__ == "__main__":
    main()






In [None]:
import os
import zipfile
def file2zip(packagePath, zipPath):
    """Zip every file under ``packagePath`` into a new deflate archive at ``zipPath``.

    Member names are paths relative to ``packagePath`` with '/' separators.
    Previously names were joined with a hard-coded '\\'. On Linux (e.g.
    Kaggle) that produced entries such as "Shinkai\\img\\x.png", which extract
    as single oddly named files. Names were also derived with ``str.replace``,
    which also rewrites later occurrences of the prefix.
    """
    with zipfile.ZipFile(zipPath, 'w', zipfile.ZIP_DEFLATED) as archive:
        for dir_path, _dir_names, file_names in os.walk(packagePath):
            for file_name in file_names:
                full_name = os.path.join(dir_path, file_name)
                # ZipFile converts os.sep to '/' in arcname.
                archive.write(full_name, os.path.relpath(full_name, packagePath))
if __name__ == "__main__":
    # 文件夹路径
    packagePath = './results'
    zipPath = './model.zip'
    if os.path.exists(zipPath):
        os.remove(zipPath)
    file2zip(packagePath, zipPath)
    print("打包完成")