In [None]:
from Libs.myFontLib import *
from Libs.myFontData import *
from EfficientNet.utils import *
from StyleGAN.network import *
from Libs.mypSp import *

import torch
import gc
import pickle
import time
import os

from torch.utils.tensorboard import SummaryWriter

# Render matplotlib figures inline in the notebook's cell output.
%matplotlib inline
# File name of the fonts-info file (presumably a pickle, judging by the
# extension); it is not referenced by any of the cells visible here.
fontsInfoFile = "fontsInfo.pkl"

In [None]:
# Number of samples per batch, shared by the train sampler and the validation loader.
batchSize = 2


def _loadPickle(path):
    """Return the object stored in the pickle file at ``path``.

    SECURITY: ``pickle.load`` can execute arbitrary code while deserialising,
    so only point this at files produced by this project.
    """
    with open(path, "br") as pickleFile:
        return pickle.load(pickleFile)


# No empty-list placeholders before loading: if a file is missing the name
# stays undefined, so later cells fail loudly instead of training on `[]`.
compatibleDict = _loadPickle("checker.pkl")
fixedDataset = _loadPickle("fixedDataset.pkl")

# Prefer the first CUDA device and fall back to the CPU when none is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("使用デバイス：", device)

In [None]:
# Training split: starts at index 10 (presumably everything after the
# validation range -- confirm against FontGeneratorDataset).
trainDataset = FontGeneratorDataset(
    FontTools(),
    compatibleDict,
    [5, 11],
    useTensor=True,
    startInd=10,
)
# Validation split: 10 entries from index 0, built from the fixed dataset.
validDataset = FontGeneratorDataset(
    FontTools(),
    compatibleDict,
    [5, 11],
    useTensor=True,
    startInd=0,
    indN=10,
    isForValid=fixedDataset,
)

# Training batches come from the custom sampler; validation uses plain fixed-size batches.
trainBatchSampler = MyPSPBatchSampler(batchSize, trainDataset, japaneseRate=0.5)
trainDataLoader = torch.utils.data.dataloader.DataLoader(
    trainDataset,
    batch_sampler=trainBatchSampler,
)
validDataLoader = torch.utils.data.dataloader.DataLoader(
    validDataset,
    batch_size=batchSize,
)

# Build the model and initialise the custom layers of both encoders,
# character encoder first, then style encoder.
myPSP = MyPSP()
for encoderName in ("chara_encoder", "style_encoder"):
    getattr(myPSP, encoderName).init_original_layer()

In [None]:
def getUpdatedParams(myPSP):
    """Make every parameter trainable and split the parameters into three groups.

    Grouping is decided by substring match on the parameter name:

    * names containing ``"style_gen"`` go to the generator group (this check
      wins even if the name also contains an encoder keyword);
    * otherwise, names containing ``"encode_convs"`` or ``"map2styles"`` go to
      the encoder conv/map group;
    * everything else goes to the encoder main group.

    Args:
        myPSP: Object exposing ``named_parameters()`` yielding ``(name, param)`` pairs.

    Returns:
        ``[generatorParams, encoderMainParams, encoderConvMapParams]``, each a
        list of parameters in iteration order.
    """
    generatorKey = "style_gen"
    encoderConvMapKeys = ("encode_convs", "map2styles")

    generatorParams, encoderMainParams, encoderConvMapParams = [], [], []
    for paramName, tensor in myPSP.named_parameters():
        # Side effect: gradients are switched on for every parameter.
        tensor.requires_grad = True
        if generatorKey in paramName:
            bucket = generatorParams
        elif any(key in paramName for key in encoderConvMapKeys):
            bucket = encoderConvMapParams
        else:
            bucket = encoderMainParams
        bucket.append(tensor)
    return [generatorParams, encoderMainParams, encoderConvMapParams]

In [None]:
def trainModel(myPSP, dataLoaders, epochN, writer: SummaryWriter, checkpointFile = "out.cpt"):
    """Train ``myPSP`` with Adam, validating and checkpointing once per epoch.

    If ``checkpointFile`` exists, the model, optimizer, logs and epoch counter
    are restored from it and training resumes with the following epoch.
    Epoch 0 runs the validation phase only (its train phase is skipped).

    Args:
        myPSP: Model called as ``myPSP(beforeCharacter, teachers, alpha)``.
        dataLoaders: Sequence ``[trainLoader, validLoader]``. Each batch is
            read as ``batch[0][0]`` (model input), ``batch[0][1]`` (target the
            output is compared against) and ``batch[1]`` (teachers).
        epochN: Exclusive upper bound of the epoch index.
        writer: TensorBoard writer receiving sample images (``"<phase>/<i>"``)
            and the per-phase mean loss (``"loss/train"``, ``"loss/valid"``).
        checkpointFile: Path the checkpoint is loaded from and saved to.

    Returns:
        ``[trainLosses, validLosses]``: one mean loss per completed phase.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("使用デバイス：", device)

    params = getUpdatedParams(myPSP)
    optimizer = torch.optim.Adam([
        {"params": params[0], "lr": 1e-3},
        {"params": params[1], "lr": 1e-3},
        {"params": params[2], "lr": 1e-3}
    ], betas=(0.0, 0.99), eps = 1e-8)

    myPSP.to(device)
    myPSP.train()
    torch.backends.cudnn.benchmark = True

    # Built once: its only argument is constant, so re-creating it for every
    # batch was redundant work.
    # NOTE(review): assumes MyPSPLoss keeps no per-call state -- confirm.
    criterion = MyPSPLoss(onSharp=0.3)

    # logs[0]: train loss per epoch, logs[1]: validation loss per epoch.
    logs = [[], []]
    start = 0
    if os.path.exists(checkpointFile):
        # map_location lets a checkpoint written on a GPU be resumed on a
        # CPU-only machine (and vice versa).
        checkpoint = torch.load(checkpointFile, map_location=device)
        start = checkpoint["epoch"]+1
        myPSP.load_state_dict(checkpoint["modelStateDict"])
        optimizer.load_state_dict(checkpoint["optStateDict"])
        logs = checkpoint["logs"]

    # Epoch loop
    for epoch in range(start, epochN):
        epochStartTime = time.time()
        print('-------------')
        print('Epoch {}/{}'.format(epoch, epochN))

        for phase in ["train", "val"]:
            isTrain = phase == "train"
            if isTrain:
                if epoch == 0:
                    # Epoch 0 only validates the untrained model.
                    continue
                myPSP.train()
                dataLoader = dataLoaders[0]
            else:
                myPSP.eval()
                dataLoader = dataLoaders[1]
            print('---------')
            print("({})".format(phase))

            # Reset per phase so the validation mean does not include the
            # training losses accumulated earlier in the same epoch.
            phaseLossSum = 0.0
            iteration = 0

            for data in dataLoader:
                beforeCharacter = data[0][0].to(device, torch.float32)
                afterCharacter = data[0][1].to(device, torch.float32)
                teachers = data[1].to(device, torch.float32)
                alpha = torch.ones((1, 1), device=device, dtype=torch.float32)

                with torch.set_grad_enabled(isTrain):
                    fakes = myPSP(beforeCharacter, teachers, alpha)
                    loss = criterion(fakes, afterCharacter)

                    if iteration == 0:
                        # Log up to two samples of the first batch; a smaller
                        # batch must not raise an IndexError.
                        for i in range(min(2, len(beforeCharacter))):
                            writer.add_images("{}/{}".format(phase, i), torch.stack([beforeCharacter[i],\
                             afterCharacter[i], fakes[i], teachers[i, 0, 1, :]]), global_step=epoch)

                    if isTrain:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    iteration += 1
                    # .item() yields a plain float: summing the tensor itself
                    # would keep every iteration's autograd history alive.
                    phaseLossSum += loss.item()

                    # Release per-batch tensors eagerly to keep GPU memory low.
                    del beforeCharacter, afterCharacter, teachers, alpha, data, fakes, loss
                    torch.cuda.empty_cache()
                    gc.collect()

            if iteration == 0:
                # Nothing was iterated, so there is no mean loss to report.
                print("({}) data loader yielded no batches; skipping loss logging".format(phase))
                continue

            # Mean loss of this phase for this epoch
            phaseLoss = phaseLossSum / iteration
            if isTrain:
                logs[0].append(phaseLoss)
                writer.add_scalar("loss/train", phaseLoss, global_step=epoch)
            else:
                logs[1].append(phaseLoss)
                writer.add_scalar("loss/valid", phaseLoss, global_step=epoch)
            epochFinishTime = time.time()
            print('-----')
            print('epoch {} || Epoch_Loss:{:.4f} '.format(
                epoch, phaseLoss))
            print('timer:  {:.4f} sec.'.format(epochFinishTime - epochStartTime))

        checkpoint = {"epoch": epoch,
            "modelStateDict": myPSP.state_dict(),
            "optStateDict": optimizer.state_dict(),
            "logs": logs}
        # Write to a temporary file and swap it in, so an interruption during
        # saving cannot corrupt the only resumable checkpoint.
        tmpCheckpointFile = "{}.tmp".format(checkpointFile)
        torch.save(checkpoint, tmpCheckpointFile)
        os.replace(tmpCheckpointFile, checkpointFile)
    return logs



In [None]:
# Stream TensorBoard events to ./logs2 while training for up to 1000 epochs.
writer = SummaryWriter(log_dir="./logs2")
try:
    logs = trainModel(myPSP, [trainDataLoader, validDataLoader], 1000, writer)
finally:
    # Flush buffered events and release the event file even when training is
    # interrupted (e.g. by stopping the cell), so the last epochs are not lost.
    writer.close()