In [None]:
import random
import torch
import glob as gl
import os
import random
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import InterpolationMode
import cv2
import matplotlib.pyplot as plt
#定义数据集
class ImageDataset(Dataset):
    """Unpaired image dataset that serves one random image per folder.

    Five folders are read below ``root``: ``trainA``/``trainB`` (style
    images), ``edgeA``/``edgeB`` (edge variants of the style images) and
    ``sourceA`` or ``sourceB`` (source photos, chosen by ``mode``).

    Args:
        root: Directory that contains the sub-folders listed above.
        trans: List of torchvision transforms; wrapped in ``transforms.Compose``.
        mode: ``'A'`` selects ``sourceA``; any other value selects ``sourceB``.
    """

    def __init__(self,root='',trans=None,mode=None):
        super().__init__()
        self.transform = transforms.Compose(trans)

        self.TA_path = os.path.join(root,"trainA/*")
        self.TB_path = os.path.join(root,"trainB/*")
        self.EA_path = os.path.join(root, "edgeA/*")
        self.EB_path = os.path.join(root, "edgeB/*")
        if mode == 'A':
            self.source_path = os.path.join(root, "sourceA/*")
        else:
            self.source_path=os.path.join(root,'sourceB/*')
        # glob returns files in arbitrary order; sorting makes a seeded
        # ``random`` pick the same files on every run.
        self.list_TA = sorted(gl.glob(self.TA_path))
        self.list_TB = sorted(gl.glob(self.TB_path))
        self.list_EA = sorted(gl.glob(self.EA_path))
        self.list_EB = sorted(gl.glob(self.EB_path))
        self.list_source = sorted(gl.glob(self.source_path))

    @staticmethod
    def _open_rgb(path):
        """Open ``path`` with PIL and force three RGB channels."""
        return Image.open(path).convert('RGB')

    def  __getitem__ (self, index):
        """Return a dict with keys 'source', 'TA', 'TB', 'EA', 'EB'.

        ``index`` is ignored: every folder is sampled independently with
        ``random.choice`` so the five images are unpaired.
        """
        imgTA_path = random.choice(self.list_TA)
        imgTB_path = random.choice(self.list_TB)
        imgEA_path = random.choice(self.list_EA)
        imgEB_path = random.choice(self.list_EB)
        img_path = random.choice(self.list_source)

        imgTA = self._open_rgb(imgTA_path)
        # Fix: the TB sample used to be opened from imgEB_path, so 'TB' was
        # always a copy of the edge-B image and trainB was never used.
        imgTB = self._open_rgb(imgTB_path)
        imgEA = self._open_rgb(imgEA_path)
        imgEB = self._open_rgb(imgEB_path)
        img = self._open_rgb(img_path)

        # Transform order (source first) is kept so random transforms consume
        # the RNG in the same sequence as before.
        img = self.transform(img)
        img_A = self.transform(imgTA)
        img_B = self.transform(imgTB)
        img_C = self.transform(imgEA)
        img_D = self.transform(imgEB)
        return {'source':img,'TA':img_A,'TB':img_B,'EA':img_C,'EB':img_D}

    def __len__(self):
        """Length of the largest of the five file lists."""
        return max(len(self.list_TA),len(self.list_TB),len(self.list_EA),len(self.list_EB),len(self.list_source))

In [None]:
import time as t
import os
import random
import numpy as np
import torch
from torch.autograd import Variable
#初始化
def weights_init(m):
    """Re-initialise a single module in place; meant for ``Module.apply``.

    The module is matched by class name:
      * names containing ``Conv``: weight ~ N(0, 0.02), bias (if any) set to 0;
      * names containing ``BatchNorm2d``: weight ~ N(1, 0.02), bias set to 0.
    Every other module is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        # bias is None when the layer was built with bias=False
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0)
        return
    if 'BatchNorm2d' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
#时间转化
def time_change(time):
    """Format an elapsed duration as ``HHhMMmSSs``.

    Args:
        time: Elapsed time in seconds (int or float); fractions are dropped
            and negative values are treated as zero.

    Returns:
        A string such as ``"01h01m01s"``. Hours are not wrapped at 24, so a
        25-hour run prints ``"25h00m00s"``.

    The previous implementation passed the duration through ``localtime``,
    which interprets it as an epoch timestamp: the result was shifted by the
    machine's UTC offset and wrapped after one day.
    """
    total_seconds = max(int(time), 0)
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{hours:02d}h{minutes:02d}m{seconds:02d}s"
# De-normalisation: map values from [-1, 1] back to the 0-255 pixel range
def denorm(x):
    """Undo the ``(v - 0.5) / 0.5`` normalisation and return an HWC numpy array.

    ``x`` is scaled with ``(x * 0.5 + 0.5) * 255``, moved to the CPU, detached
    from autograd, converted to numpy and its axes reordered with
    ``transpose(1, 2, 0)`` (so a 3-D, channel-first input is expected).
    """
    pixels = (x * 0.5 + 0.5) * 255.0
    as_array = pixels.cpu().detach().numpy()
    return as_array.transpose(1, 2, 0)
def RGB2BGR(x):
    """Convert an RGB image array to BGR channel order with ``cv2.cvtColor``."""
    bgr = cv2.cvtColor(x, cv2.COLOR_RGB2BGR)
    return bgr
#创建文件目录
def check_folder(log_dir):
    """Create ``log_dir`` (with parents) unless the path already exists.

    Returns:
        ``log_dir`` unchanged, so the call can be used inline.
    """
    if os.path.exists(log_dir):
        return log_dir
    os.makedirs(log_dir)
    return log_dir
#图片池

class ImagePool():
    """Fixed-size history of generated images.

    ``query`` returns, for each incoming image, either that image or an older
    one drawn from the buffer, so a consumer sees a mix of current and past
    generator outputs. With ``pool_size == 0`` the pool is a pass-through.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        if pool_size > 0:
            # buffer state only exists for a non-empty pool
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch the same size as ``images`` mixed with history.

        While the buffer is not full every image is stored and returned
        as-is. Afterwards each image is, with probability 0.5, swapped with a
        random buffer entry (the old entry is returned, the new one stored);
        otherwise it is returned directly without being stored.
        """
        if self.pool_size == 0:
            return images
        selected = []
        # ``.data`` iterates over the batch without autograd history
        for sample in images.data:
            sample = torch.unsqueeze(sample, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs += 1
                self.images.append(sample)
                selected.append(sample)
                continue
            if random.uniform(0, 1) > 0.5:
                slot = random.randint(0, self.pool_size-1)
                previous = self.images[slot].clone()
                self.images[slot] = sample
                selected.append(previous)
            else:
                selected.append(sample)
        return Variable(torch.cat(selected, 0))

In [None]:
import torch
from torch import nn
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torch.utils.data  import DataLoader
from tqdm import tqdm
import numpy  as np
#定义残差块
class ResBlock(nn.Module):
    """Residual block that keeps channel count and spatial size.

    Two stages of reflection padding + 3x3 convolution + instance norm, with
    a single-parameter PReLU after the first stage; the input is added to the
    result.
    """

    def __init__(self, channels, use_bias=False):
        super().__init__()
        first_stage = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3, 1, 0, bias=use_bias),
            nn.InstanceNorm2d(channels),
            nn.PReLU(num_parameters=1),
        ]
        second_stage = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3, 1, 0, bias=use_bias),
            nn.InstanceNorm2d(channels),
        ]
        # attribute name is part of the checkpoint key layout
        self.Res_block = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        residual = self.Res_block(x)
        return x + residual
#定义生成器-Encoder
class GenEncoder(nn.Module):
    """Shared generator encoder: 7x7 stem, two stride-2 convs, residual blocks.

    Args:
        hw: Channel width of the stem; the output has ``hw * 4`` channels.
        n_block: Number of ``ResBlock`` modules appended after downsampling
            (must be an int; the ``None`` default is not usable).
        norm: Normalisation layer class applied after each stride-2 conv.
        use_bias: Whether the stem and stride-2 convolutions carry a bias.
    """

    def __init__(self, hw=64, n_block=None, norm=nn.InstanceNorm2d, use_bias=False):
        super().__init__()
        # stem: reflection padding keeps the spatial size through the 7x7 conv
        stem = [nn.ReflectionPad2d(3),
                nn.Conv2d(3, hw, 7, 1, 0, bias=use_bias)]

        # two stride-2 stages, each doubling the channel count
        downsample = []
        width = hw
        for _ in range(2):
            downsample += [
                nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                norm(width * 2),
                nn.ReLU(True),
            ]
            width = width * 2

        # residual trunk at hw * 4 channels
        trunk = [ResBlock(width) for _ in range(n_block)]
        self.model = nn.Sequential(*stem, *downsample, *trunk)

    def forward(self, x):
        # with hw=64 a (1, 3, 256, 256) input becomes (1, 256, 64, 64)
        return self.model(x)
#定义解码器
class GenDecoder(nn.Module):
    """Per-style generator decoder: residual blocks, two 2x upsamplings, output conv.

    Args:
        hw: Base channel width; the decoder expects ``hw * 4`` input channels.
        out_channels: Channels of the generated image.
        n_block: Number of leading ``ResBlock`` modules (must be an int).
        norm: Normalisation layer class used after each conv pair.
        use_bias: Accepted for signature compatibility; not used by this class.

    The channel schedule after the residual blocks is
    ``hw*4 -> hw*2 -> hw -> hw/2 -> hw/4 -> out_channels``.

    Fix: the schedule used to be derived from ``2 ** (n_block // 2)``, which
    equals the real feature width (``hw * 4``) only for ``n_block`` 4 or 5;
    any other block count produced mismatching conv channels. It is now
    derived from the feature width itself, so the layer layout (and therefore
    checkpoint keys) is unchanged for the previously working settings.
    """

    def __init__(self, hw=64, out_channels=3, n_block=None, norm=nn.InstanceNorm2d, use_bias=False):
        super().__init__()
        res = hw * 4
        # residual trunk
        model = [ResBlock(res) for _ in range(n_block)]

        half, quarter = res // 2, res // 4
        eighth, sixteenth = res // 8, res // 16

        # first upsampling
        model += [nn.Upsample(scale_factor=2, mode='bilinear'),
                  nn.Conv2d(res, half, kernel_size=3, stride=1, padding=1),
                  nn.Conv2d(half, half, kernel_size=3, stride=1, padding=1),
                  norm(half),
                  nn.ReLU(True)]
        # second upsampling
        model += [nn.Upsample(scale_factor=2, mode='bilinear'),
                  nn.Conv2d(half, quarter, kernel_size=3, stride=1, padding=1),
                  nn.Conv2d(quarter, quarter, kernel_size=3, stride=1, padding=1),
                  norm(quarter),
                  nn.ReLU(True)]
        # additional full-resolution layers
        model += [nn.Conv2d(quarter, eighth, kernel_size=3, stride=1, padding=1),
                  nn.Conv2d(eighth, eighth, kernel_size=3, stride=1, padding=1),
                  norm(eighth), nn.ReLU(True)]
        model += [nn.Conv2d(eighth, sixteenth, kernel_size=3, stride=1, padding=1),
                  nn.Conv2d(sixteenth, sixteenth, kernel_size=3, stride=1, padding=1),
                  norm(sixteenth), nn.ReLU(True)]
        # output projection; Tanh bounds the image to [-1, 1]
        model += [nn.Conv2d(sixteenth, out_channels, 7, 1, 3), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
#定义生成器               
class Generator(nn.Module):
    """One shared encoder feeding two style-specific decoders.

    Args:
        n_domian: Accepted but not used by this class.
        E_block: Residual block count passed to ``GenEncoder``.
        D_block: Residual block count passed to each ``GenDecoder``.
    """

    def __init__(self, n_domian=None, E_block=5, D_block=4):
        super(Generator, self).__init__()
        # Each sub-network stays wrapped in a one-element Sequential so the
        # parameter names keep their "<Name>.0." prefix.
        self.Encoder = nn.Sequential(GenEncoder(n_block=E_block))
        first_decoder = GenDecoder(n_block=D_block)
        second_decoder = GenDecoder(n_block=D_block)
        self.Decoder1 = nn.Sequential(first_decoder)
        self.Decoder2 = nn.Sequential(second_decoder)

    def encoder(self, x):
        """Encode ``x`` with the shared encoder."""
        return self.Encoder(x)  # type: ignore

    def decoders(self, x, n):
        """Decode features with ``Decoder1`` when ``n == 0``, else ``Decoder2``."""
        branch = self.Decoder1 if n == 0 else self.Decoder2
        return branch(x)

    def forward(self, x, n):
        features = self.encoder(x)
        return self.decoders(features, n)
#patchgan  
class PatchGAN_D(nn.Module):
    """Convolutional patch discriminator producing a map of probabilities.

    Args:
        hw: Base channel width.
        out_channels: Channel count of the *input* image (name kept as is).
        norm: Normalisation layer class used inside the two downsampling stages.
        use_bias: Accepted but not used; every convolution here has a bias.

    Two stride-2 convolutions reduce the spatial size by 4; the final
    1-channel convolution is followed by a Sigmoid.
    """

    def __init__(self, hw=64, out_channels=3, norm=nn.InstanceNorm2d, use_bias=False):
        super(PatchGAN_D, self).__init__()
        slope = 0.2

        # full-resolution stem
        stem = [
            nn.Conv2d(out_channels, hw, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(slope, True),
        ]
        # first downsampling stage: hw -> hw*2 (stride 2) -> hw*4
        stage_one = [
            nn.Conv2d(hw, hw * 2, kernel_size=3, stride=2, padding=1, bias=True),
            nn.LeakyReLU(slope, True),
            nn.Conv2d(hw * 2, hw * 4, kernel_size=3, stride=1, padding=1, bias=True),
            norm(hw * 4),
            nn.LeakyReLU(slope, True),
        ]
        # second downsampling stage: hw*4 -> hw*4 (stride 2) -> hw*8
        stage_two = [
            nn.Conv2d(hw * 4, hw * 4, kernel_size=3, stride=2, padding=1, bias=True),
            nn.LeakyReLU(slope, True),
            nn.Conv2d(hw * 4, hw * 8, kernel_size=3, stride=1, padding=1, bias=True),
            norm(hw * 8),
            nn.LeakyReLU(slope, True),
        ]
        # per-patch real/fake probability
        head = [
            nn.Conv2d(hw * 8, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*stem, *stage_one, *stage_two, *head)

    def forward(self, x):
        return self.model(x)


#判别器结构patchGAN
class Discrimintor1(nn.Module):
    """Discriminator wrapper holding one ``PatchGAN_D`` inside a Sequential."""

    def __init__(self):
        super(Discrimintor1, self).__init__()
        # one-element Sequential keeps the "model.0." parameter prefix
        self.model = nn.Sequential(PatchGAN_D())

    def forward(self, x):
        return self.model(x)

    def init_train(self, grad=None):
        """Set ``requires_grad`` of every parameter to ``grad``."""
        for weight in self.model.parameters():
            weight.requires_grad = grad
class Discrimintor2(nn.Module):
    """Second discriminator wrapper; same structure as a single ``PatchGAN_D``."""

    def __init__(self):
        super(Discrimintor2, self).__init__()
        # one-element Sequential keeps the "model.0." parameter prefix
        self.model = nn.Sequential(PatchGAN_D())

    def forward(self, x):
        return self.model(x)

    def init_train(self, grad=None):
        """Set ``requires_grad`` of every parameter to ``grad``."""
        for weight in self.model.parameters():
            weight.requires_grad = grad
# Auxiliary style classifier
class aux_classfier(nn.Module):
    """Auxiliary classifier that maps an image to ``num_class`` logits.

    Five stride-2 4x4 convolutions shrink the input by a factor of 32 and
    widen the channels from ``dim`` to ``dim * 16``; a final 8x8 convolution
    collapses the remaining map. ``forward`` reshapes to ``(N, num_class)``,
    which only works when that final map is 1x1, i.e. for 256x256 inputs.

    Args:
        dim: Channel width of the first convolution.
        num_class: Number of output classes.
        use_bias: Whether the final convolution has a bias.

    Fix: the final convolution hard-coded 512 input channels, which is only
    correct for ``dim=32``; it now uses ``dim * 16``.
    """

    def __init__(self,dim=32,num_class=None,use_bias=False):
        super().__init__()
        layer=[nn.Conv2d(3,dim,kernel_size=4,stride=2,padding=1),
               nn.LeakyReLU(0.01,True)]
        width=dim
        for _ in range(4):
            layer+=[nn.Conv2d(width,width*2,kernel_size=4,stride=2,padding=1),
                    nn.LeakyReLU(0.01,True)]
            width=width*2
        # width is now dim * 16 (512 for the default dim=32)
        layer+=[nn.Conv2d(width,num_class,kernel_size=8,bias=use_bias)]
        self.layer=nn.Sequential(*layer)

    def forward(self,x):
        out=self.layer(x)
        # drop the two singleton spatial dimensions
        return out.view(out.size(0),out.size(1))

    def init_train(self,grad=None):
        """Set ``requires_grad`` of every parameter to ``grad``."""
        for param in self.layer.parameters():
            param.requires_grad = grad

#VGG19
class VGG19(nn.Module):
    """VGG-19 layout used as a feature extractor.

    The full network (features + classifier) is built so a standard VGG-19
    state dict can be loaded, but ``forward`` only runs the first 26 feature
    layers.
    """

    def __init__(self, batch_norm=False, num_classes=1000):
        super(VGG19, self).__init__()
        # ints are conv output channels, 'M' is a 2x2 max-pool
        self.cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512,
                    'M']
        self.batch_norm = batch_norm
        self.num_clases = num_classes
        self.features = self.make_layers(self.cfg, self.batch_norm)
        head = [
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def make_layers(self, cfg, batch_norm=False):
        """Build the convolutional stack described by ``cfg``.

        Each int adds a 3x3 conv (+ optional BatchNorm2d) + ReLU; each 'M'
        adds a 2x2 max-pool. The input is assumed to have 3 channels.
        """
        stack = []
        channels_in = 3
        for entry in cfg:
            if entry == 'M':
                stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            conv = nn.Conv2d(channels_in, entry, kernel_size=3, padding=1)
            stack.append(conv)
            if batch_norm:
                stack.append(nn.BatchNorm2d(entry))
            stack.append(nn.ReLU(inplace=True))
            channels_in = entry
        return nn.Sequential(*stack)

    def forward(self, x):
        # With batch_norm=False the 26th layer is the conv4_4 convolution,
        # so the result is its pre-activation output.
        for layer in list(self.features.children())[:26]:
            x = layer(x)
        return x

In [None]:

import torch
from torch import optim
import itertools
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.nn.functional import interpolate
import cv2
import numpy as np
class MSCartoonGAN(object):
    """Training / inference driver for the two-style cartoon GAN.

    Owns the generator (shared encoder, two decoders), one discriminator per
    style, an auxiliary style classifier, a VGG19 feature extractor for the
    content loss, their optimizers, and the data / checkpoint plumbing.

    Fixes compared with the previous revision:
      * the edge-B discriminator loss was computed on ``real_b`` instead of
        ``edge_b``;
      * the third-scale content loss for style B was computed on the style-A
        image;
      * ``classifier_optim.step()`` was never called, so the classifier never
        left its initial weights;
      * in ``train_init`` mode the generator gradients were never zeroed and
        accumulated across iterations;
      * tensors were created with ``torch.cuda.*`` constructors, which broke
        ``device='cpu'``;
      * the pooled fake-B batch is now detached like fake-A;
      * the generator is run once per generator step instead of twice, the
        critics are frozen before that step, VGG19 is frozen, and sampling /
        testing run under ``torch.no_grad()``.
    """

    def __init__(self, args):
        # run configuration
        self.device = args.device
        self.dir = args.dir
        self.dataset = args.dataset
        self.data_path = args.data_path
        self.test_path = args.test_path
        self.isTrain = args.isTrain
        self.isTest = args.isTest
        self.train_init = args.train_init
        self.n_domains = args.n_domain
        # hyper-parameters
        self.input_c = args.input_c
        self.hw = args.hw
        self.b1 = args.b1
        self.b2 = args.b2
        self.lr = args.lr
        self.init_lr = args.init_lr
        self.batch_size = args.batch_size
        self.save_pred = args.save_pred
        self.iter = args.iter
        self.weight_content = args.weight_content
        self.weight_classifer = args.weight_classifer
        self.weight_decay = args.weight_decay
        self.mode = args.mode
        # networks
        self.G = Generator(args.n_domain).to(self.device)
        self.D1 = Discrimintor1().to(self.device)
        self.D2 = Discrimintor2().to(self.device)
        self.classifier = aux_classfier(num_class=args.n_class).to(self.device)
        self.vgg19 = VGG19().to(self.device)
        # weight initialisation
        self.G.apply(weights_init)
        self.D1.apply(weights_init)
        self.D2.apply(weights_init)
        self.classifier.apply(weights_init)
        # map_location lets the checkpoint load on a CPU-only machine too
        self.vgg19.load_state_dict(
            torch.load('/kaggle/input/vgg19-model/vgg19.pth', map_location=self.device))
        # VGG19 is only a fixed feature extractor: no optimizer owns it, so
        # freeze it to avoid computing weight gradients nobody uses.
        self.vgg19.eval()
        for param in self.vgg19.parameters():
            param.requires_grad = False
        # optimizers; the generator uses init_lr during content pre-training
        if self.train_init:
            self.G_optim = optim.Adam(self.G.parameters(), lr=self.init_lr, betas=(self.b1, self.b2))
        else:
            self.G_optim = optim.Adam(self.G.parameters(), lr=self.lr, betas=(self.b1, self.b2))
        self.D1_optim = optim.Adam(self.D1.parameters(), lr=self.lr, betas=(self.b1, self.b2))
        self.D2_optim = optim.Adam(self.D2.parameters(), lr=self.lr, betas=(self.b1, self.b2))
        self.classifier_optim = optim.Adam(self.classifier.parameters(), lr=self.lr, betas=(self.b1, self.b2))
        # losses
        self.lcon = nn.L1Loss().to(self.device)
        self.ladv = nn.BCELoss().to(self.device)
        self.lsty = nn.CrossEntropyLoss().to(self.device)
        self.lambda_con = self.weight_content
        self.lambda_cla = self.weight_classifer
        # input placeholders (overwritten by set_input); allocated on the
        # configured device instead of unconditionally on CUDA
        shape = (self.batch_size, self.input_c, self.hw, self.hw)
        self.real_img = torch.empty(shape, device=self.device)
        self.real_A = torch.empty(shape, device=self.device)
        self.real_B = torch.empty(shape, device=self.device)
        self.edge_A = torch.empty(shape, device=self.device)
        self.edge_B = torch.empty(shape, device=self.device)
        # one history pool of generated images per style
        self.fake_pools = [ImagePool(args.pool_size) for _ in range(self.n_domains)]

    def load_data(self, mode=None):
        """Build a shuffled DataLoader over ``self.data_path``.

        ``mode`` overrides ``self.mode`` when given (it used to be ignored).
        """
        trans = [transforms.Resize(286, InterpolationMode.BICUBIC),
                 transforms.CenterCrop(256),
                 transforms.RandomHorizontalFlip(0.5),
                 transforms.ToTensor(),
                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        dataset = ImageDataset(root=self.data_path, trans=trans,
                               mode=self.mode if mode is None else mode)
        return DataLoader(dataset, batch_size=self.batch_size, drop_last=True, shuffle=True)

    def set_input(self, input):  # input is a dictionary recording images
        """Move one batch dict ('source', 'TA', 'TB', 'EA', 'EB') to the device."""
        self.real_img = input['source'].to(self.device)
        self.real_A = input['TA'].to(self.device)
        self.real_B = input['TB'].to(self.device)
        self.edge_A = input['EA'].to(self.device)
        self.edge_B = input['EB'].to(self.device)

    def Gen_pool(self):
        """Generate both styles from ``self.real_img`` into ``fake_a`` / ``fake_b``."""
        encode = self.G.encoder(self.real_img)
        self.fake_a = self.G.decoders(encode, 0).clone()
        self.fake_b = self.G.decoders(encode, 1).clone()

    def Gen_pool_test(self, input):
        """Generate both styles for ``input`` and return them as a tuple."""
        encode = self.G.encoder(input)
        fake_a = self.G.decoders(encode, 0).clone()
        fake_b = self.G.decoders(encode, 1).clone()
        return fake_a, fake_b

    def _class_labels(self, count, value):
        """Integer class-label vector of length ``count`` on the run device."""
        return torch.full((count,), value, dtype=torch.long, device=self.device)

    @staticmethod
    def _down_up(img, down, up):
        """Bilinearly shrink by ``down`` then enlarge by ``up`` (low-pass copy)."""
        small = interpolate(img, scale_factor=down, mode='bilinear')
        return interpolate(small, scale_factor=up, mode='bilinear')

    def _discriminator_step(self):
        """One update of D1, D2 and the classifier; returns (loss_D1, loss_D2)."""
        self.D1.train()
        self.D2.train()
        self.classifier.train()
        self.D1.init_train(True)
        self.D2.init_train(True)
        self.classifier.init_train(True)
        self.D1_optim.zero_grad()
        self.D2_optim.zero_grad()
        self.classifier_optim.zero_grad()
        # the generator is not updated here, so no graph is needed
        with torch.no_grad():
            self.Gen_pool()
        fake_a = self.fake_pools[0].query(self.fake_a)
        fake_b = self.fake_pools[1].query(self.fake_b)
        # real style images -> 1
        real_a = self.D1(self.real_A)
        real_a_logit = self.ladv(real_a, torch.ones_like(real_a))
        real_b = self.D2(self.real_B)
        real_b_logit = self.ladv(real_b, torch.ones_like(real_b))
        # edge variants of the style images -> 0
        edge_a = self.D1(self.edge_A)
        edge_a_logit = self.ladv(edge_a, torch.zeros_like(edge_a))
        edge_b = self.D2(self.edge_B)
        edge_b_logit = self.ladv(edge_b, torch.zeros_like(edge_b))
        # generated images -> 0
        fake_a = self.D1(fake_a.detach())
        fake_a_logit = self.ladv(fake_a, torch.zeros_like(fake_a))
        fake_b = self.D2(fake_b.detach())
        fake_b_logit = self.ladv(fake_b, torch.zeros_like(fake_b))
        # classification: style A = 0, style B = 1, edge images = 2
        class_real_a = self.classifier(self.real_A)
        class_a_logit = self.lsty(class_real_a, self._class_labels(class_real_a.size(0), 0))
        class_real_b = self.classifier(self.real_B)
        class_b_logit = self.lsty(class_real_b, self._class_labels(class_real_b.size(0), 1))
        class_edge_a = self.classifier(self.edge_A)
        class_edge_logitA = self.lsty(class_edge_a, self._class_labels(class_edge_a.size(0), 2))
        class_edge_b = self.classifier(self.edge_B)
        class_edge_logitB = self.lsty(class_edge_b, self._class_labels(class_edge_b.size(0), 2))
        # per-style totals
        loss_D1 = (real_a_logit+edge_a_logit+fake_a_logit)/3+(class_edge_logitA+class_a_logit)/2*self.lambda_cla
        loss_D2 = (real_b_logit+edge_b_logit+fake_b_logit)/3+(class_edge_logitB+class_b_logit)/2*self.lambda_cla
        loss_D1.backward()
        loss_D2.backward()
        self.D1_optim.step()
        self.D2_optim.step()
        # both totals carry classifier terms, so its optimizer must step too
        self.classifier_optim.step()
        return loss_D1, loss_D2

    def _generator_step(self):
        """One generator update; returns the generator loss."""
        # freeze the critics so backward only produces generator gradients
        self.D1.init_train(grad=False)
        self.D2.init_train(grad=False)
        self.classifier.init_train(grad=False)
        self.G_optim.zero_grad()
        self.Gen_pool()
        # content loss, scale 1: VGG19 features at full resolution
        real_con1 = self.vgg19(self.real_img).detach()
        loss_con1_a = self.lcon(self.vgg19(self.fake_a), real_con1)
        loss_con1_b = self.lcon(self.vgg19(self.fake_b), real_con1)
        # content loss, scale 2: images low-passed through half resolution
        real_con2 = self._down_up(self.real_img, 0.5, 2).detach()
        loss_con2_a = self.lcon(self._down_up(self.fake_a, 0.5, 2), real_con2)
        loss_con2_b = self.lcon(self._down_up(self.fake_b, 0.5, 2), real_con2)
        # content loss, scale 3: images low-passed through quarter resolution
        real_con3 = self._down_up(self.real_img, 0.25, 4).detach()
        loss_con3_a = self.lcon(self._down_up(self.fake_a, 0.25, 4), real_con3)
        loss_con3_b = self.lcon(self._down_up(self.fake_b, 0.25, 4), real_con3)
        if not self.train_init:
            # adversarial terms: generated images should be scored as real
            fake_g_a = self.D1(self.fake_a)
            fake_a_g_logit = self.ladv(fake_g_a, torch.ones_like(fake_g_a))
            fake_g_b = self.D2(self.fake_b)
            fake_b_g_logit = self.ladv(fake_g_b, torch.ones_like(fake_g_b))
            # classification terms: each decoder should hit its style class
            class_ga = self.classifier(self.fake_a)
            class_ga_logit = self.lsty(class_ga, self._class_labels(class_ga.size(0), 0))
            class_gb = self.classifier(self.fake_b)
            class_gb_logit = self.lsty(class_gb, self._class_labels(class_gb.size(0), 1))
            loss_G = fake_a_g_logit+fake_b_g_logit+(class_ga_logit+class_gb_logit)/2*self.lambda_cla+\
                     ((loss_con1_a+loss_con1_b)+(loss_con2_a+loss_con2_b)+(loss_con3_a+loss_con3_b))/3*self.lambda_con
        else:
            # pre-training: content loss only
            loss_G = (loss_con1_a+loss_con1_b+loss_con2_a+loss_con2_b+loss_con3_a+loss_con3_b)/3*self.lambda_con
        loss_G.backward()
        self.G_optim.step()
        return loss_G

    def _save_samples(self, step):
        """Write two preview strips (source / style reference / output) for ``step``."""
        data_loader = self.load_data()
        # each strip stacks three images vertically and grows to the right
        style1 = np.zeros((self.hw * 3, 0, 3))
        style2 = np.zeros((self.hw * 3, 0, 3))
        self.G.eval(), self.D1.eval(), self.D2.eval(), self.classifier.eval()
        with torch.no_grad():
            for _ in range(5):
                # the loader is shuffled, so every pass yields a fresh first batch
                data = next(iter(data_loader))
                real_img = data['source'].to(self.device)
                real_A = data['TA'].to(self.device)
                real_B = data['TB'].to(self.device)
                fake_a, fake_b = self.Gen_pool_test(real_img)
                style1 = np.concatenate((style1, np.concatenate((RGB2BGR(denorm(real_img[0])),
                                                                 RGB2BGR(denorm(real_A[0])),
                                                                 RGB2BGR(denorm(fake_a[0]))), 0)), 1)
                style2 = np.concatenate((style2, np.concatenate((RGB2BGR(denorm(real_img[0])),
                                                                 RGB2BGR(denorm(real_B[0])),
                                                                 RGB2BGR(denorm(fake_b[0]))), 0)), 1)
        cv2.imwrite(os.path.join(self.dir, self.dataset, 'img', 'style1_%06d.png' % step), style1)
        cv2.imwrite(os.path.join(self.dir, self.dataset, 'img', 'style2_%06d.png' % step), style2)
        print("测试图像生成成功！")

    def train(self):
        """Run ``self.iter`` epochs; every ``save_pred`` epochs save previews and a checkpoint.

        With ``train_init`` set only the generator is trained (content loss);
        otherwise each iteration updates the critics first, then the generator.
        """
        start_time = t.time()
        data_loader = self.load_data()
        print("----------------------train start!----------------")
        print("------------init paser!-----------------------")
        count = len(data_loader)
        for step in tqdm(range(1, self.iter+1)):
            # sampling at the end of an epoch switches G to eval mode
            self.G.train()
            for i, data in enumerate(data_loader):
                self.set_input(data)
                if not self.train_init:
                    loss_D1, loss_D2 = self._discriminator_step()
                loss_G = self._generator_step()
                end_time = t.time()
                if self.train_init:
                    print(f"epoch:[{step}/{self.iter}],iter:[{i+1}/{count}],loss_G:{loss_G},G_lr:{self.G_optim.param_groups[0]['lr']},time:{time_change(end_time-start_time)}")
                else:
                    print(f"epoch[{step}/{self.iter}],iter[{i+1}/{count}],loss_G:{loss_G},loss_D1:{loss_D1},loss_D2:{loss_D2},time:{time_change(end_time-start_time)}")
            if step % self.save_pred == 0:
                self._save_samples(step)
                self.save_model()

    def save_model(self):
        """Save G, D1, D2 and classifier state dicts into one checkpoint file."""
        params = {}
        params["G"] = self.G.state_dict()
        params["D1"] = self.D1.state_dict()
        params["D2"] = self.D2.state_dict()
        params["classifer"] = self.classifier.state_dict()
        torch.save(params, os.path.join(self.dir, self.dataset + 'cartoonmodel.pt'))
        print("保存模型成功！")

    def load_model(self):
        """Restore the generator and classifier from ``self.test_path``.

        NOTE(review): D1/D2 are written by ``save_model`` but are not restored
        here, so resumed training restarts the discriminators from fresh
        weights - confirm this is intended.
        """
        params = torch.load(os.path.join(self.test_path, 'cartooncartoonmodel.pt'),
                            map_location=self.device)
        self.G.load_state_dict(params['G'])
        self.classifier.load_state_dict(params['classifer'])
        print("加载模型成功！")

    def test(self):
        """Load the checkpoint and write the first output of batches 0..100 for both styles."""
        self.load_model()
        data_loader = self.load_data()
        self.G.eval()
        with torch.no_grad():
            for i, data in tqdm(enumerate(data_loader)):
                real_img = data['source'].to(self.device)
                fake_a, fake_b = self.Gen_pool_test(real_img)
                fake_a = RGB2BGR(denorm(fake_a[0]))
                fake_b = RGB2BGR(denorm(fake_b[0]))
                cv2.imwrite(os.path.join(self.dir, self.dataset, 'test/testA','style1_%06d.png' % i), fake_a)
                cv2.imwrite(os.path.join(self.dir, self.dataset, 'test/testB','style2_%06d.png' % i), fake_b)
                if i==100:
                    break
        print("测试图像生成成功！")

In [None]:
import argparse
def parse_args(argv=None):
    """Build the option namespace for the run and create the output folders.

    Args:
        argv: Optional list of command-line style tokens. ``None`` (the
            default) parses an empty list, i.e. every option takes its
            default - the behaviour needed inside a notebook, where
            ``sys.argv`` belongs to the kernel.

    Returns:
        The parsed namespace after it has been passed through ``check_args``.

    Fixes:
      * boolean options used ``type=bool``, which turns any non-empty string
        (including ``"False"``) into ``True``; they now parse true/false text;
      * ``--b1`` / ``--b2`` were declared ``int`` although they are fractions,
        so supplying them explicitly raised an error;
      * the description string was passed as the program name.
    """
    def str2bool(value):
        """Parse common true/false spellings; reject anything else."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))

    desc='pytorch of MS-CartoonGAN'
    parser=argparse.ArgumentParser(description=desc)
    parser.add_argument('--device',type=str,default='cuda',choices=['cuda','cpu'])
    parser.add_argument('--input_c', type=int, default=3)
    parser.add_argument('--n_domain', type=int, default=2)
    parser.add_argument('--n_class', type=int, default=3)
    parser.add_argument('--pool_size', type=int, default=50)
    parser.add_argument('--init_lr',type=float,default=0.0002)
    parser.add_argument('--lr', type=float, default=0.00001, help='learning rate for ADAM')
    parser.add_argument('--hw', type=int, default=256)
    parser.add_argument('--mode', type=str, default='A')
    parser.add_argument('--dir',type=str,default='result')
    parser.add_argument('--dataset', type=str, default='cartoon')
    parser.add_argument('--data_path',type=str,default='/kaggle/input/mscartoon/Cartoon/cartoon')
    parser.add_argument('--test_path',type=str,default='/kaggle/input/model20d')
    parser.add_argument('--isTrain',type=str2bool,default=True)
    parser.add_argument('--retrain', type=str2bool, default=True)
    parser.add_argument('--isTest', type=str2bool, default=False)
    parser.add_argument('--train_init', type=str2bool, default=False)
    parser.add_argument('--b1',type=float,default=0.5)
    parser.add_argument('--b2', type=float, default=0.999)
    parser.add_argument('--batch_size',type=int,default=4)
    parser.add_argument('--save_pred',type=int,default=1)
    parser.add_argument('--iter',type=int,default=25)
    parser.add_argument('--weight_content',type=float,default=0.2)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--weight_classifer', type=float, default=0.5)
    tokens = [] if argv is None else list(argv)
    return check_args(parser.parse_args(args=tokens))
def check_args(args):
    """Create the output folder tree under ``<dir>/<dataset>`` and return ``args``.

    Folders: ``model``, ``img``, ``test``, ``test/testA`` and ``test/testB``.
    """
    base = os.path.join(args.dir, args.dataset)
    sub_folders = (('model',), ('img',), ('test',), ('test', 'testA'), ('test', 'testB'))
    for parts in sub_folders:
        check_folder(os.path.join(base, *parts))
    return args
def main():
    """Parse options, build the model and run the requested phases.

    Training (optionally preceded by loading a checkpoint when ``retrain`` is
    set) runs when ``isTrain`` is true; testing runs when ``isTest`` is true.
    """
    options = parse_args()
    model = MSCartoonGAN(options)
    if options.isTrain:
        if options.retrain:
            # resume from the stored checkpoint before training
            model.load_model()
        print(f"training on {options.device}")
        model.train()
        print("train haved finished")
    if options.isTest:
        model.test()
        print("test haved finished")
# Entry point: run the pipeline when this cell / script is executed directly.
if __name__=="__main__":
    main()

In [None]:
import os
import zipfile
def file2zip(packagePath, zipPath):
    """Compress every file below ``packagePath`` into the archive ``zipPath``.

    Args:
        packagePath: Directory whose content is archived (recursively).
        zipPath: Destination ``.zip`` file; overwritten if it exists.

    Entries are stored under their path relative to ``packagePath``.

    Fixes: entry names were built with a literal backslash, which on
    non-Windows systems produced names such as ``\\a.txt`` instead of a
    folder structure; the archive was also left open when a write failed,
    and the local variable shadowed the builtin ``zip``.
    """
    archive_abs = os.path.abspath(zipPath)
    with zipfile.ZipFile(zipPath, 'w', zipfile.ZIP_DEFLATED) as archive:
        for path, dirNames, fileNames in os.walk(packagePath):
            for name in fileNames:
                fullName = os.path.join(path, name)
                # never add the archive to itself when it lives inside packagePath
                if os.path.abspath(fullName) == archive_abs:
                    continue
                # zipfile converts the OS separator to '/' for the entry name
                archive.write(fullName, os.path.relpath(fullName, packagePath))
# Package the generated style-A test images into a zip archive; an existing
# archive at the target path is removed first.
if __name__ == "__main__":
    # folder to archive
    packagePath = './result/cartoon/test/testA'
    zipPath = './model40AB.zip'
    if os.path.exists(zipPath):
        os.remove(zipPath)
    file2zip(packagePath, zipPath)
    print("打包完成")