In [1]:
import sys
sys.path.append("..")
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter


In [None]:
import utils.dataset as myDataset
import utils.loss as myLoss
import model.model as myModel


In [None]:

# ---- optimisation hyperparameters ----
arg_nEpoch = 50       # number of passes over the training set
arg_batchSize = 8     # samples per batch fed to the DataLoader
arg_workers = 8       # DataLoader worker processes

# ---- data / checkpoint locations ----
arg_dataset = "../data/"
arg_split = "train"
arg_pretrainedModel = None   # checkpoint file name under ../model/ to resume from; None = fresh model
arg_outName = "facedet.pt"   # suffix used for saved checkpoint file names


In [None]:

# Build the face dataset and wrap it in a shuffling, multi-worker loader.
# drop_last=False keeps the final partial batch, so the batch count is
# ceil(len(dataset) / arg_batchSize).
dataset = myDataset.FaceDataset()
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=arg_batchSize,
    shuffle=True,
    num_workers=arg_workers,
    drop_last=False,
)

In [None]:
# Training cell: builds the model and optimizer, runs the epoch loop, logs the
# per-batch loss to TensorBoard and checkpoints weights under
# ../model/pretrainedModel/.
writer = SummaryWriter('../log/scene')

# len(dataloader) is already the number of batches per epoch; the number of
# samples comes from the dataset the loader wraps.
print('length of dataset: %s' % (len(dataloader.dataset)))
batch_num = len(dataloader)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

save_dir = '../model/pretrainedModel'
save_every = 30  # write an intermediate checkpoint every `save_every` epochs

# train process
model = myModel.FaceKeypointModel()

# The config default is the Python value None; the string 'None' is also
# accepted so either spelling means "start from a freshly initialised model".
if arg_pretrainedModel not in (None, 'None'):
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    model.load_state_dict(torch.load('../model/' + arg_pretrainedModel, map_location=device))
    print('Use model from ../model/' + arg_pretrainedModel)
else:
    print('Use new model')

os.makedirs(save_dir, exist_ok=True)

model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999))

for epoch in tqdm(range(arg_nEpoch)):
    # Wrapping the loader (not enumerate) lets tqdm know the total batch count.
    for i, data in enumerate(tqdm(dataloader, leave=False)):
        image, anno = data

        # Reset gradients from the previous step; otherwise they accumulate.
        optimizer.zero_grad()
        # TODO(review): the model is never called and calLoss() receives no
        # inputs, so `image` / `anno` are unused -- the forward pass appears
        # to be missing. Confirm calLoss's expected arguments.
        loss = myLoss.calLoss()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        # tqdm.write prints without breaking the active progress bars.
        tqdm.write('[ epoch: %d/%d  batch: %d/%d ]  loss: %f' % (epoch, arg_nEpoch, i + 1, batch_num, loss_value))
        writer.add_scalar('training loss', loss_value, epoch * batch_num + i)

    if epoch % save_every == save_every - 1:
        ckpt_path = os.path.join(save_dir, 'epo' + str(epoch) + arg_outName)
        torch.save(model.state_dict(), ckpt_path)
        print('Model saved at ' + ckpt_path)

final_path = os.path.join(save_dir, 'final_' + arg_outName)
torch.save(model.state_dict(), final_path)
print('Model saved at ' + final_path)

# Make sure all logged scalars reach disk even if the kernel stays alive.
writer.flush()