In [1]:
import sys
sys.path.append("..")
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter


In [2]:
import utils.dataset as myDataset
import utils.loss as myLoss
import model.model as myModel


In [3]:
arg_batchSize = 32
arg_nEpoch = 10
arg_pretrainedModel = None
# arg_pretrainedModel = "../model/pretrainedModel/final_facedet.pt"
arg_workers = 12
arg_dataset = "../data/"
arg_split = "train"
arg_outName = "facedet.pt"


In [4]:
dataset = myDataset.FaceDataset(datapath = arg_dataset, split = arg_split)
dataloader = torch.utils.data.DataLoader(dataset, shuffle = True, batch_size = arg_batchSize, \
                                         num_workers = arg_workers, drop_last = False)


7050it [01:17, 91.06it/s] 


In [8]:
writer = SummaryWriter("../log/scene")

print("length of dataset: %s" % (len(dataloader)))
batch_num = len(dataloader)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = myModel.FaceKeypointModel()
model.apply(myModel.weights_init)

if arg_pretrainedModel != None:
    model.load_state_dict(torch.load("../model/" + arg_pretrainedModel))
    print("Use model from ../model/" + arg_pretrainedModel)
else:
    print("Use new model")

if not os.path.exists("../model/pretrainedModel"):
    os.makedirs("../model/pretrainedModel")

model.cuda()
# model.train()
optimizer = torch.optim.Adam(model.parameters(),lr=0.0001,betas=(0.9, 0.999))


length of dataset: 221
Use new model


In [9]:
for epoch in tqdm(range(arg_nEpoch)):
    for i, data in tqdm(enumerate(dataloader)):

        image, anno, gtmap= data
        image, anno, gtmap= image.to(device,  dtype=torch.float), anno.to(device), gtmap.to(device, dtype=torch.float)
        image = image/255.0
        
        heatMap = model(image)
        
        loss = myLoss.calLossMSE(heatMap, anno, gtmap)
        
        loss.backward()
        optimizer.step()

        writer.add_scalar("training loss", loss.item(), epoch*len(dataloader) + i)

    print("[ epoch: %d/%d  batch: %d/%d ]  loss: %f" % (epoch, arg_nEpoch, i + 1, batch_num, loss.item()))
    if epoch % 5 == 4:
        torch.save(model.state_dict(), "../model/pretrainedModel/epo" + str(epoch) + arg_outName)
        print("Model saved at ../model/pretrainedModel/epo" + str(epoch) + arg_outName)


221it [01:53,  1.95it/s]
23it [00:12,  1.84it/s]
 10%|█         | 1/10 [02:07<19:06, 127.35s/it]


KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), "../model/pretrainedModel/final_test_" + arg_outName)
print("Model saved at ../model/pretrainedModel/final_" + arg_outName)