In [8]:
import argparse
import os

import torch
import torch.backends.cudnn
from torch.nn import DataParallel
from torch.optim.rmsprop import RMSprop
from torch.utils.data import DataLoader
from tqdm import trange, tqdm

from model import hg1, hg2, hg8
from datasets.mpii import Mpii
from train import do_training_epoch, do_validation_epoch
from utils.logger import Logger
from utils.misc import save_checkpoint, adjust_learning_rate

def main():
    """Train the selected ``hg*`` model on the Mpii dataset and validate it each epoch.

    All settings are fixed local constants (this notebook does not parse
    command-line arguments). For every epoch the function:

    1. updates the learning rate via ``adjust_learning_rate``,
    2. runs one training epoch on the training split,
    3. runs one validation epoch on the validation split,
    4. prints the metrics and appends them to ``<checkpoint>/log.txt``,
    5. saves a checkpoint, flagging it as best when validation accuracy improves.

    Raises:
        ValueError: if ``arch`` is not one of ``'hg1'``, ``'hg2'``, ``'hg8'``.
    """
    # ---- run configuration -------------------------------------------------
    epochs = 200
    arch = 'hg1'
    checkpoint = 'checkpoint'   # directory for checkpoints and the log file
    image_path = 'images'
    input_shape = (256, 256)
    # Passed to adjust_learning_rate; presumably the epochs at which the LR is
    # scaled by `gamma` -- confirm against utils.misc.
    schedule = [60, 90]
    gamma = 0.1
    lr = 1e-4
    workers = 4
    train_batch = 6
    test_batch = 6
    snapshot = 0

    # Select the hardware device to run on.
    if torch.cuda.is_available():
        device = torch.device('cuda', torch.cuda.current_device())
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # Disable gradient calculations by default.
    # NOTE(review): this relies on do_training_epoch re-enabling gradients
    # internally (the captured run trains without error, so it appears to) --
    # confirm against train.py.
    torch.set_grad_enabled(False)

    # create checkpoint dir
    os.makedirs(checkpoint, exist_ok=True)

    # Map architecture name -> constructor.
    architectures = {'hg1': hg1, 'hg2': hg2, 'hg8': hg8}
    if arch not in architectures:
        # The original message referenced an undefined `args.arch`.
        raise ValueError('unrecognised model architecture: ' + arch)
    model = architectures[arch](pretrained=False)

    model = DataParallel(model).to(device)

    # Use the configured `lr` rather than a duplicated literal so the two
    # cannot drift apart.
    optimizer = RMSprop(model.parameters(), lr=lr)

    best_acc = 0

    # create data loaders
    train_dataset = Mpii(image_path, is_train=True, inp_res=input_shape)
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_batch, shuffle=True,
        num_workers=workers, pin_memory=True
    )

    val_dataset = Mpii(image_path, is_train=False, inp_res=input_shape)
    val_loader = DataLoader(
        val_dataset,
        batch_size=test_batch, shuffle=False,
        num_workers=workers, pin_memory=True
    )

    # Resuming from a checkpoint is not supported here, so always start a
    # fresh log. Previously `logger` was used below without being created.
    logger = Logger(os.path.join(checkpoint, 'log.txt'))
    logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    # try/finally so the log file is closed even if training is interrupted
    # (e.g. KeyboardInterrupt from the notebook's stop button).
    try:
        # train and eval
        for epoch in trange(epochs, desc='Overall', ascii=True):
            lr = adjust_learning_rate(optimizer, epoch, lr, schedule, gamma)

            # train for one epoch
            train_loss, train_acc = do_training_epoch(train_loader, model, device, Mpii.DATA_INFO,
                                                      optimizer,
                                                      acc_joints=Mpii.ACC_JOINTS)

            # evaluate on the validation set (the original mistakenly passed
            # train_loader here, so "validation" metrics were training metrics)
            valid_loss, valid_acc, predictions = do_validation_epoch(val_loader, model, device,
                                                                     Mpii.DATA_INFO, False,
                                                                     acc_joints=Mpii.ACC_JOINTS)

            # print metrics
            tqdm.write(f'[{epoch + 1:3d}/{epochs:3d}] lr={lr:0.2e} '
                       f'train_loss={train_loss:0.4f} train_acc={100 * train_acc:0.2f} '
                       f'valid_loss={valid_loss:0.4f} valid_acc={100 * valid_acc:0.2f}')

            # append logger file
            logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])

            # remember best acc and save checkpoint
            is_best = valid_acc > best_acc
            best_acc = max(valid_acc, best_acc)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': arch,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, predictions, is_best, checkpoint=checkpoint, snapshot=snapshot)
    finally:
        logger.close()



In [9]:
# Run the full training/validation loop; the captured output below ends in a
# KeyboardInterrupt, i.e. this run was stopped before finishing.
main()

Overall:   0%|                                                   | 0/200 [00:00<?, ?it/s]
Train:   0%|                                                    | 0/3708 [00:00<?, ?it/s][A
Train:   0%|                         | 0/3708 [00:09<?, ?it/s, Loss: 0.0667, Acc:   0.00][A
Train:   0%|              | 1/3708 [00:09<10:08:32,  9.85s/it, Loss: 0.0667, Acc:   0.00][A
Train:   0%|              | 1/3708 [00:10<10:08:32,  9.85s/it, Loss: 0.0654, Acc:   0.00][A
Train:   0%|               | 2/3708 [00:10<4:16:26,  4.15s/it, Loss: 0.0654, Acc:   0.00][A
Train:   0%|               | 2/3708 [00:10<4:16:26,  4.15s/it, Loss: 0.0568, Acc:   0.00][A
Train:   0%|               | 3/3708 [00:10<2:23:53,  2.33s/it, Loss: 0.0568, Acc:   0.00][A
Train:   0%|               | 3/3708 [00:10<2:23:53,  2.33s/it, Loss: 0.0468, Acc:   0.00][A
Train:   0%|               | 4/3708 [00:10<1:31:04,  1.48s/it, Loss: 0.0468, Acc:   0.00][A
Train:   0%|               | 4/3708 [00:10<1:31:04,  1.48s/it, Loss: 0.04

KeyboardInterrupt: 