In [1]:
import os

import torch
import torch.nn.functional as F
import torch.optim as optim

from data_loader import *
from model_hourglass import StackedHourGlass
from utils import to_cuda

In [2]:
# Path to data.
# The dataset root defaults to the original local layout but can be overridden
# with the FACES_3D_DATA_DIR environment variable, so the notebook can run on
# machines that do not have this exact drive/folder structure.
DATA_ROOT = os.environ.get('FACES_3D_DATA_DIR', 'E:\\Datasets\\3DFaces\\300W-3D-ALL')

img_path = os.path.join(DATA_ROOT, 'images')    # input face images
mat_path = os.path.join(DATA_ROOT, '3d-scans')  # matching 3D scan files

In [3]:
# Calling data loaders

# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# are repeatable between runs.
SEED = 0
torch.manual_seed(SEED)

trainset = FacesWith3DCoords(
    images_dir=img_path, mats_dir=mat_path, transform=True
)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=4, shuffle=True, num_workers=2
)

# NOTE(review): nOutputs=200 presumably has to match the number of channels in
# the target volumes produced by FacesWith3DCoords — TODO confirm.
model = StackedHourGlass(nChannels=224, nStack=2, nModules=2, numReductions=4, nOutputs=200)
model.cuda()
model.train()

# Definition of loss and optimizer.
# The training loop minimises an elementwise binary cross-entropy between the
# sigmoid of the network output and the target volumes. BCEWithLogitsLoss is
# the criterion that matches that objective; CrossEntropyLoss does not.
criterion = torch.nn.BCEWithLogitsLoss()

optimizer = optim.RMSprop(model.parameters(), lr=0.001)
# Multiply the learning rate by 0.9 every 5 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)

In [None]:
def train(num_epochs=100, log_every=1, save_every=5, checkpoint_dir="trained_models"):
    """Train the notebook-level ``model`` on ``trainloader``.

    Relies on the globals ``model``, ``trainloader``, ``optimizer`` and
    ``scheduler`` defined in the setup cell. Each batch is scored with binary
    cross-entropy between the sigmoid of the network output and the target
    volumes; gradients are clipped to [-5, 5] before each optimizer step.

    Args:
        num_epochs: Number of passes over the training set.
        log_every: Print the mean loss over this many batches.
        save_every: Save a checkpoint every ``save_every`` completed epochs.
            The final epoch is always saved.
        checkpoint_dir: Directory for checkpoints; created if it is missing.
    """
    # torch.save does not create parent directories.
    os.makedirs(checkpoint_dir, exist_ok=True)
    num_batches = len(trainloader)

    for epoch in range(num_epochs):  # loop over the dataset multiple times
        print("=== Epoch", epoch, "===")

        running_loss, epoch_loss = 0.0, 0.0

        for i, (images, volumes, _) in enumerate(trainloader, 1):
            images = to_cuda(images, True)
            volumes = to_cuda(volumes, True)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # The *_with_logits variant fuses the sigmoid into the loss, which
            # is numerically stabler than sigmoid followed by
            # binary_cross_entropy (and F.sigmoid is deprecated).
            logits = model(images)
            loss = F.binary_cross_entropy_with_logits(logits, volumes)

            loss.backward()

            torch.nn.utils.clip_grad_value_(model.parameters(), 5)
            optimizer.step()

            # Read the scalar once per batch rather than once per accumulator.
            batch_loss = loss.item()
            running_loss += batch_loss
            epoch_loss += batch_loss

            if i % log_every == 0:
                # get_last_lr() reports the LR actually used for this step;
                # get_lr() is deprecated and wrong for StepLR outside step().
                print('[%2d, %5d/%5d] loss: %.8f lr %.8f' % (
                    epoch, i, num_batches, running_loss / log_every,
                    scheduler.get_last_lr()[0]))
                running_loss = 0.0

        # Guard against an empty loader.
        print("EPOCH AVG", epoch_loss / max(num_batches, 1))

        # The LR scheduler must step *after* the epoch's optimizer steps;
        # stepping first decays the learning rate one epoch early.
        scheduler.step()

        # Name checkpoints by completed epochs and always keep the final model.
        completed = epoch + 1
        if completed % save_every == 0 or completed == num_epochs:
            torch.save(
                model.state_dict(),
                os.path.join(checkpoint_dir, "hourglass_%d_epochs" % completed),
            )

train()