In [1]:
from Libs.myFontLib import *
from Libs.myFontData import *
from Libs.myTrain import *
from EfficientNet.model import EfficientNetEncoder
from EfficientNet.utils import *
from Libs.myLoss import *
%matplotlib inline
fontsInfoFile = "fontsInfo.pkl"
import torch
import gc
import pickle
import time
import os
import torchvision.transforms as transforms
import torch.nn.functional as F
import shutil
from torchinfo import summary
from StyleGAN.network import *
from Libs.mypSp import *
from torch.utils.tensorboard import SummaryWriter

In [9]:
# ---- Training-stage switches ----
# Forwarded later to myPSP.set_for_chara_training / set_for_style_training
# and to trainModel(forCharaTraining=..., forStyleTraining=...).
forCharaTrain, forStyleTrain = True, False
# Passed to myPSP.set_level().
modelLevel = 3

# ---- Data loading ----
# Batch size and worker count shared by every DataLoader in this notebook.
batchSize = 16
workers = 2

# ---- Discriminator dropout ----
# d_dropout is the dropout_p given to Discriminator4 and the initial nowDropout of trainModel.
# NOTE(review): d_dropout_limit is not referenced anywhere in the visible notebook.
d_dropout_limit = 0.75
d_dropout = 0.925

# ---- Character set ----
# Forwarded to FontTools(useKanji=...).
useKanji = False

In [10]:

def _loadPickle(path):
    """Return the object stored in the pickle file at ``path``.

    SECURITY: pickle.load can execute arbitrary code while deserialising, so this
    must only be used on files produced by this project, never on untrusted input.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


# No placeholder assignments before loading: if a file is missing, the name stays
# undefined and later cells fail loudly instead of silently running on an empty list.
compatibleDict = _loadPickle("checker.pkl")
fixedDataset = _loadPickle("fixedDataset.pkl")
styleDict = _loadPickle("styleChecker.pkl")

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("使用デバイス：", device)

使用デバイス： cuda:0


In [11]:
# --- Datasets ---------------------------------------------------------------
# The training set starts at index 10 and enables augmentation; the validation
# set covers startInd=0 / indN=10 and is pinned to the pre-computed fixedDataset.
# NOTE(review): presumably the first 10 entries are held out for validation - confirm
# against FontGeneratorDataset.
trainDataset = FontGeneratorDataset(
    FontTools(useKanji=useKanji),
    compatibleDict,
    [3, 3],
    styleDict,
    useTensor=True,
    startInd=10,
    augmentationP=[0.3, 0.3, 0],
    originalAugmentationP=[0.02, 0.05, 0.02, 0.04, 0.02, 0.05],
)
validDataset = FontGeneratorDataset(
    FontTools(useKanji=useKanji),
    compatibleDict,
    [5, 5],
    styleDict,
    useTensor=True,
    startInd=0,
    indN=10,
    isForValid=fixedDataset,
)

# --- Loaders ----------------------------------------------------------------
# Training batches come from the project's batch sampler; validation uses plain batching.
trainDataLoader = torch.utils.data.DataLoader(
    trainDataset,
    batch_sampler=MyPSPBatchSampler(batchSize, trainDataset, japaneseRate=0.7),
    num_workers=workers,
    pin_memory=True,
)
validDataLoader = torch.utils.data.DataLoader(
    validDataset,
    batch_size=batchSize,
    num_workers=workers,
    pin_memory=True,
)

# One list entry per line of the text file, surrounding whitespace stripped.
with open("Libs/difficult_list2.txt", "r", encoding="utf-8") as f:
    charaList = [entry.strip() for entry in f]
charaTrainDataset = MyPSPCharaDataset(charaList)
charaDataLoader = torch.utils.data.DataLoader(
    charaTrainDataset,
    batch_size=batchSize,
    shuffle=True,
    num_workers=workers,
    pin_memory=True,
)

# --- Models -----------------------------------------------------------------
myPSP = MyPSP(ver=4, dropout_p=0.0, useBNform2s=True, useBin=True)
myPSP.chara_encoder.init_original_layer()
myPSP.style_encoder.init_original_layer()
# Character-training mode is switched off here; a later cell sets it from forCharaTrain.
myPSP.set_for_chara_training(False)
gen_settings = get_setting_json()
discriminator = Discriminator4(dropout_p=d_dropout)
charaDiscriminator = CharaDiscriminator(ver=4)
styleDiscriminator = StyleDiscriminator()

Loaded pretrained weights for efficientnet-b0
Loaded pretrained weights for efficientnet-b0
Loaded pretrained weights for efficientnet-b3


In [18]:
def trainModel(myPSP, D, charaDis, styleDis, dataLoaders, epochN, writer: SummaryWriter, forCharaTraining = False, forStyleTraining = False,
     inheritOnlyModel = False,  checkpointFile = "out.cpt", checkpointFormat = "cpts/output{}.cpt", useFakeBackLog = False,
      lookIntermidiate = False, charaDisCheckpointFile = "", dCheck = "", nowDropout = 0.0, changeDropout = False, checkGradNow = False):
    """Run the train/validation loop for the generator and its discriminators.

    Each epoch runs a "train" phase followed by a "val" phase (epoch 0 runs "val"
    only), logs losses through ``writer``, saves a checkpoint and then calls
    ``updateDropout``.

    Args:
        myPSP: Generator model; trained through ``forwardG``.
        D: Discriminator, used via ``g_wgan_loss`` / ``d_wgan_loss2``. Only moved to
            the device and trained when neither chara nor style training is active.
        charaDis: Character discriminator, trained with an MSE loss when
            ``forCharaTraining`` is True.
        styleDis: Style discriminator; its optimizer is stepped together with the
            generator's.
        dataLoaders: Indexable of loaders: ``[0]`` training, ``[1]`` validation,
            ``[2]`` character-only loader (picked with probability ~0.8 per training
            epoch when ``forCharaTraining`` is True).
        epochN: Exclusive upper bound of the epoch loop.
        writer: TensorBoard ``SummaryWriter`` used for images and scalars.
        forCharaTraining: Character-training mode.
        forStyleTraining: Style-training mode.
        inheritOnlyModel: Forwarded to ``loadCheckpoints``.
        checkpointFile: Checkpoint path to load from and to save to. Before each save
            the existing file is copied to ``cpts/before.cpt``.
        checkpointFormat: Format string producing the save path on every 20th epoch.
        useFakeBackLog: Enables ``trainWithBackLog`` and the fakes back log; forced to
            False in chara/style training.
        lookIntermidiate: Forwarded to ``getIntermidiateHandlers``; when True the
            returned hook handles are removed after each forward pass.
        charaDisCheckpointFile: Forwarded to ``loadCheckpoints``.
        dCheck: Forwarded to ``loadCheckpoints``.
        nowDropout: Current dropout value, forwarded to and updated by ``updateDropout``.
        changeDropout: Forwarded to ``updateDropout``.
        checkGradNow: When True, gradients are written on the first iteration of every
            epoch instead of only every 10th epoch.

    Returns:
        None.
    """
    trainCharaAndCharaDis = False
    trainCharaDis = forCharaTraining
    emergencySave = False # whether an emergency save was already made after the balance broke down
    forUnderTraining = forCharaTraining or forStyleTraining
    torch.cuda.empty_cache()
    gc.collect()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("使用デバイス：", device)

    dropoutChangeCount = 0
    firstTeacherSize = 4 # limit: loading a large tensor first makes the run crash

    useWSGradient = True # forwarded as useGradient to d_wgan_loss2 (presumably a gradient penalty - confirm)
    useDforG = True
    noiseP = 0.0

    # D is trained on iterations where iteration % trainRate == trainRateC.
    trainRate = 5
    trainRateC = 0 # 0


    # Loss weights handed to forwardG / applied to the D term of the generator loss.
    SquareLossFactor = 10
    fakeRawFactor = 0.002
    charaDisFactor = 1000 # 100
    styleLossFactor = 5
    DforGFactor = 40
    d_optimFun = torch.optim.AdamW
    torch.backends.cudnn.benchmark = True
    useFakeBackLog = (not forUnderTraining) and useFakeBackLog

    modelsList = [myPSP, D, styleDis, charaDis]
    optimizer_d_lr = 3e-5
    optimizersList = getOptimizers(modelsList, forCharaTraining, trainCharaDis, d_optimFun, optimizer_d_lr)
    optimizer, optimizer_d, optimizer_styleDis, optimizer_charaDis = optimizersList

    charaDisLoss = torch.nn.MSELoss()
    myPSP.to(device)

    myPSP.train()
    if(trainCharaDis or not forUnderTraining):
        charaDis.to(device)
        charaDis.train()
    styleDis.to(device)
    styleDis.train()
    if(not forUnderTraining):
        D.to(device)
        D.train()


    start = loadCheckpoints(checkpointFile, modelsList, optimizersList, \
        dCheck, charaDisCheckpointFile,\
        inheritOnlyModel, forUnderTraining, trainCharaDis)
    
    genIntermidiateList = getGenIntermidiateLayers(myPSP)
    disIntermidiateList = None
    if not forUnderTraining: 
        disIntermidiateList = getDisIntermidiatelayers(D)
    train_d_correct = 0

    # Bound up front: the train phase (the only place these are assigned) is skipped
    # at epoch 0, so the end-of-epoch save would otherwise hit an undefined name.
    fakesBackLog = None
    nowFakesBackPath = None

    # epoch loop
    for epoch in range(start, epochN):
        trainD = True
        epochStartTime = time.time()
        epochGLoss = 0
        epochDLoss = 0
        print('-------------')
        print('Epoch {}/{}'.format(epoch, epochN))

        usingCharaDataLoader = False

        for phase in ["train", "val"]:
            epochCharaLoss = 0
            GLossDict = initGLossDict()
            epochGLoss = 0
            epochDLoss = 0
            epochDLossList = np.zeros(3)
            dataLoader = None
            if phase== "train":
                # Epoch 0 runs validation only.
                if epoch == 0:
                    continue
                myPSP.train()
                # D.train()
                dataLoader = dataLoaders[0]
                # In chara training, use the character-only loader for ~80% of the epochs.
                if(forCharaTraining and random.random() > 0.2 and not trainCharaAndCharaDis):
                    dataLoader = dataLoaders[2]
                    usingCharaDataLoader = True
                else:
                    usingCharaDataLoader = False
            else:
                myPSP.eval()
                # D.eval()
                dataLoader = dataLoaders[1]
                usingCharaDataLoader = False
            print('---------')
            print("({})".format(phase))

            # Discriminator accuracy counters
            discriminator_problems_n = discriminator_problems_n_b = 0 # number of inputs seen
            discriminator_correct_n =  discriminator_correct_n_b =  0 # number of those classified correctly
            TCorrectN = 0
            if(useFakeBackLog and phase == "train"):
                # Re-train the Discriminator on the fakes back log
                epochDLoss, fakesBackLog, nowFakesBackPath, scores = trainWithBackLog(phase, device, D, optimizer_d, noiseP, useWSGradient)
                discriminator_problems_n_b, discriminator_correct_n_b = scores
            gc.collect()
            torch.cuda.empty_cache()
            
            iteration = 0
            d_iteration = 0
            for data in dataLoader:
                beforeCharacter = None
                if usingCharaDataLoader:
                    beforeCharacter = data
                else: 
                    beforeCharacter =  data[0][0]
                # Batch size comes from the selected input tensor; indexing data[0][0]
                # on the character loader's bare tensor would read an image dimension.
                minibatch_size = beforeCharacter.size()[0]
                alpha = torch.ones((1, 1))
                beforeCharacter =  beforeCharacter.to(device, torch.float32, non_blocking=True)
                alpha = alpha.to(device, torch.float32, non_blocking=True)
                afterCharacter = beforeCharacter
                teachers = None
                styleLabel  = None
                # label_real = (torch.ones((minibatch_size, )) + 0.6 * (torch.rand((minibatch_size, )) - 0.5)).to(device)
                # label_fake = (torch.zeros((minibatch_size, )) + 0.3 * torch.rand((minibatch_size, )) ).to(device)
                if not forCharaTraining or trainCharaAndCharaDis:
                    afterCharacter = data[0][1]
                    teachers = data[1][:, :, 1] # teachers = data[1][:, :, 1]
                    styleLabel = data[2]
                    if(iteration <= 2):
                        teachers = teachers[:, :firstTeacherSize] # limit the size loaded on the first iterations
                    afterCharacter =  afterCharacter.to(device, torch.float32, non_blocking=True)
                    
                    teachers =  teachers.to(device, torch.float32, non_blocking=True)
                    styleLabel = styleLabel.to(device, torch.float32, non_blocking=True)

                with torch.set_grad_enabled(phase == "train"):
                    # Generator Loss

                    # Hooks for inspecting intermediate-layer outputs
                    Dhandles, Ghandles =  getIntermidiateHandlers(genIntermidiateList, disIntermidiateList, writer, iteration, epoch, \
                            lookIntermidiate, checkGradNow,  forUnderTraining)
                    
                    factors = [SquareLossFactor, fakeRawFactor, styleLossFactor, charaDisFactor]
                    iterGLoss, featureT, fakes = forwardG(myPSP, styleDis, charaDis, charaDisLoss, beforeCharacter, teachers, afterCharacter,\
                        alpha, styleLabel, GLossDict, factors, \
                        forCharaTraining, forStyleTraining)
                    torch.cuda.empty_cache()
                    gc.collect()
                    
                    if(lookIntermidiate):
                        for handle in Ghandles:
                            handle.remove()

                    if(not forUnderTraining and useDforG):
                        # Add the discriminator's term to the generator loss.
                        fakes = transforms.Normalize(FontGeneratorDataset.IMAGE_MEAN, FontGeneratorDataset.IMAGE_VAR)(fakes)
                        beforeCharacterN, fakesN, teachersN = MyPSPAugmentation.getNoisedImages([beforeCharacter, fakes, teachers], noiseP,device)
                        d_fake = D(fakesN, teachersN, alpha)
                        if(lookIntermidiate):
                            for handle in Dhandles:
                                handle.remove()
                        iterGLoss += DforGFactor * g_wgan_loss(d_fake)
                        del beforeCharacterN, d_fake, fakesN, teachersN
                    epochGLoss += iterGLoss.item()
                    
                    if phase == "train":
                        iterGLoss.backward()
                        del iterGLoss
                        if(iteration == 0 and (epoch % 10 == 0 or checkGradNow)):
                            writeGeneratorGradients(myPSP, writer, epoch, styleDis)
                        optimizer.step()
                        optimizer_styleDis.step()
                        optimizer.zero_grad()
                        optimizer_styleDis.zero_grad()
                        # Drop gradients the generator backward pass left on the other networks.
                        charaDis.zero_grad()
                        D.zero_grad()

                    if(not forUnderTraining):
                        fakes = fakes.detach() # reused below by the Discriminator
                    # Log sample images for the first three iterations of the phase.
                    if iteration <= 2 and not forStyleTraining:
                        for i in range(2):
                            image = None
                            if forCharaTraining:
                                image = torch.stack([beforeCharacter[i*3], fakes[i*3], beforeCharacter[i*3+1], fakes[i*3+1], 
                                                beforeCharacter[i*3+2], fakes[i*3+2]])
                            else:
                                image = torch.stack([beforeCharacter[i],\
                                    afterCharacter[i], 1-(fakes * FontGeneratorDataset.IMAGE_VAR + FontGeneratorDataset.IMAGE_MEAN )[i],
                                    teachers[i, 0, :], teachers[i, 1, :]])
                            writer.add_images("{}/{}".format(phase, i+iteration*2), image, global_step=epoch)
                            del image
                    
                    torch.cuda.empty_cache()
                    gc.collect()

                    
                    # Discriminators from here on
                    if(trainCharaDis):
                        featureO = charaDis(afterCharacter)
                        c_loss = charaDisLoss(featureO, featureT)
                        epochCharaLoss += c_loss.item()
                        if phase == "train":
                            c_loss.backward()
                            optimizer_charaDis.step()
                            optimizer_charaDis.zero_grad()
                        del featureT, featureO, c_loss
                    
                    if(not forUnderTraining and trainD and (iteration % trainRate == trainRateC or phase == "val")):
                        fakesN, afterCharacterN, teachersN =\
                             MyPSPAugmentation.getNoisedImages([fakes, afterCharacter, teachers], noiseP, device)
                        if(phase == "train"):
                            # This step is memory hungry, so shrink the minibatch
                            if(useWSGradient):
                                if(minibatch_size > 2):
                                    minibatch_size = 2
                                else:
                                    minibatch_size = 1
                                if(teachersN.shape[1] >= 4):
                                    teachersN = teachersN[:, : 3]
                            else:
                                minibatch_size = 4
                            # beforeCharacterN = beforeCharacterN[:minibatch_size]
                            fakesN = fakesN[:minibatch_size]
                            afterCharacterN = afterCharacterN[:minibatch_size]
                            teachersN = teachersN[:minibatch_size]

                        d_loss, discCorrectN, lossList, tcorrect, fcorrect = d_wgan_loss2(D, None, afterCharacterN,\
                             fakesN, teachersN, alpha, phase, useGradient=useWSGradient, useBefore=False)
                        TCorrectN += tcorrect
                        epochDLossList += lossList
                        epochDLoss += d_loss.item()
                        discriminator_problems_n += minibatch_size*2
                        discriminator_correct_n += discCorrectN
                        del  afterCharacterN, teachersN,  fakesN

                        if phase == "train":
                            d_loss.backward()
                            del d_loss
                            if(d_iteration == 0 and (epoch % 10 == 0 or checkGradNow)):
                                writeDiscriminatorGradients(D, writer, epoch)
                            optimizer_d.step()
                            optimizer_d.zero_grad()               
                        d_iteration += 1      

                    iteration += 1
                    print("\riter {:4}/{}".format(iteration, len(dataLoader)), end="")
                    # add to fakesBackLog (first training batch of every 10th epoch)
                    if(epoch % 10 == 0 and phase == "train" and iteration == 1 and useFakeBackLog):
                        beforeCharacter = beforeCharacter.cpu()
                        afterCharacter = afterCharacter.cpu()
                        fakes = fakes.cpu()
                        teachers = teachers.cpu()
                        fakesBackLog.append([beforeCharacter, afterCharacter, fakes, teachers])
                    
                    del beforeCharacter, afterCharacter, teachers, alpha, data, fakes


            # Per-phase losses for this epoch
            if(d_iteration == 0):
                d_loss = np.nan
            else:
                d_loss =  epochDLoss / d_iteration
            g_loss =  epochGLoss / iteration
            c_loss = epochCharaLoss / iteration
            GLossDict = {key: GLossDict[key] / iteration for key in GLossDict}   

            discriminator_ns = [[discriminator_correct_n , discriminator_problems_n],
                                [discriminator_correct_n_b , discriminator_problems_n_b ]]
            d_correct_rate = printResults(d_loss, discriminator_ns, g_loss, GLossDict, c_loss, epochDLossList, TCorrectN,  epochStartTime, \
                                epoch, forUnderTraining, trainD
                                )
            outputWriter(writer,d_loss, d_correct_rate, g_loss, GLossDict, c_loss, phase, epoch, \
                forUnderTraining, trainD)
            if( phase == "train"): # used later when updating the dropout rate
                train_d_correct = d_correct_rate
        
        # Keep the previous checkpoint as a backup before overwriting it.
        if(os.path.exists(checkpointFile)):
            shutil.copy(checkpointFile, "cpts/before.cpt")
        # Every 20th epoch switches the save path to a new, epoch-numbered file.
        if epoch % 20 == 0:
            checkpointFile = checkpointFormat.format(epoch)
        checkpoint = {"epoch": epoch, 
            "modelStateDict": myPSP.state_dict(), 
            "discriminatorStateDict": D.state_dict(),
            "charaDiscriminatorStateDict": charaDis.state_dict(),
            "styleDiscriminatorStateDict": styleDis.state_dict(),
            "optStateDict": optimizer.state_dict(), 
            "optDStateDict": optimizer_d.state_dict(),
            "optCDStateDict": optimizer_charaDis.state_dict() if trainCharaDis  else None,
            "optSDStateDict": optimizer_styleDis.state_dict()
            }
        torch.save(checkpoint, checkpointFile)

        D, optimizer_d,  disIntermidiateList, dropoutChangeCount, nowDropout, trainRate, trainRateC, emergencySave = \
            updateDropout(D, optimizer_d, disIntermidiateList, writer, checkpoint, checkpointFile, train_d_correct,
                            nowDropout, dropoutChangeCount, epoch,\
                            trainRate,  trainRateC, device, d_optimFun, optimizer_d_lr,\
                            trainD, emergencySave, forUnderTraining, changeDropout)
        
        # Persist the back log only once a train phase has actually loaded it.
        if(epoch % 3 == 0 and useFakeBackLog and fakesBackLog is not None):
            torch.save({FAKES_BACK_LOG_KEY: fakesBackLog}, nowFakesBackPath)
    return 



In [19]:
# Apply the configuration-cell switches to the generator before training starts.
# NOTE(review): whether these setters are order-sensitive cannot be seen from this
# notebook, so the original call order is kept.
myPSP.set_level(modelLevel)
# Overrides the set_for_chara_training(False) issued when the model was built.
myPSP.set_for_chara_training(forCharaTrain)
myPSP.set_for_style_training(forStyleTrain)

In [None]:
# TensorBoard output goes to ./logs1.
writer = SummaryWriter(log_dir="./logs1")

# trainModel ends with a bare `return`, so `logs` is always None; the name is kept
# only so that anything referring to it keeps working.
logs = trainModel(
    myPSP,
    discriminator,
    charaDiscriminator,
    styleDiscriminator,
    [trainDataLoader, validDataLoader, charaDataLoader],
    100000,
    writer,
    checkpointFile="cpts/output0.cpt",
    # The fakes back log is only used when neither chara nor style training is active.
    useFakeBackLog=not (forStyleTrain or forCharaTrain),
    forCharaTraining=forCharaTrain,
    forStyleTraining=forStyleTrain,
    lookIntermidiate=False,
    nowDropout=d_dropout,
    changeDropout=True,
    checkGradNow=False,
)