In [None]:
from __future__ import print_function

import argparse
import os
import random
import shutil
from os.path import join

import numpy as np
import torch
import torch.optim as optim
from torch.utils import data as Data

from learning.datasets import BDIDerivedDataset
from learning.models import BDI_3D_Conv, BDI_3D_Conv_Simple
from learning.loops import train, validate, test
import config
import learning.dataset_config as learning_config
# from tensorboard_logger import Logger

In [None]:
# Run configuration. These variables replace the command-line options of train.py, e.g.
#   python train.py --datasets BDI_DERIVED --output <dir> --batch 50 --epochs 20 --lr 1e-4
# (That command is shell, not Python, so it is kept here only as a comment.)

dataroot = config.ALIGNED_FACES_FOLDER      # path to dataset; template formatted with exp=<experiment name>
datasets = "BDI_DERIVED"                    # datasets used for training and validation
workers = 4                                 # number of data loading workers
batch = 75                                  # input batch size
epochs = 20                                 # number of epochs to run
cpu = False                                 # run without cuda
seed = random.randint(1, 10000)             # manual seed; set a fixed int (e.g. 1) to reproduce a run
checkpoint = None                           # location of the checkpoint to load, or None to train from scratch
output = '/timo/datasets/Dua/3Dcnn/output'  # folder to output model checkpoints
print_freq = 10                             # print frequency
visualize = False                           # evaluate a (checkpointed) model on the validation set instead of training
lr = 1e-4                                   # learning rate
submedian = False                           # forwarded to BDIDerivedDataset(submedian=...)
overwrite = False                           # overwrite the prediction data or not
augment = False                             # do data augmentation or not (enables flip on the training set)

# Record the seed so a randomly seeded run can be reproduced later.
print('seed:', seed)

if visualize and checkpoint:
    # Evaluation mode: write results next to the checkpoint (its path minus the
    # extension, e.g. '.../model_epoch_7.pth' -> '.../model_epoch_7').
    # With no checkpoint, the training cell raises a dedicated error instead.
    output = os.path.splitext(checkpoint)[0]

# Handle the random seed for every RNG in use.
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(int(seed))
if not cpu:
    torch.cuda.manual_seed_all(int(seed))

# Create the output folder. A fresh training run refuses to clobber a folder that
# already contains saved models, but recreates a folder that holds none.
if os.path.exists(output):
    if not visualize and checkpoint is None:
        if any("model" in f for f in os.listdir(output)):
            raise RuntimeError("Output folder {} already exist.".format(output))
        shutil.rmtree(output)
        os.makedirs(output)
else:
    os.makedirs(output)

# Note: the TensorBoard logging and the args.txt / source-copy dump that train.py
# performed are not done in this notebook.

In [None]:
# Resolve which experiment folders form the training and validation splits of
# the selected dataset.
train_exp, validate_exp = learning_config.get_train_val_folders(name=datasets)

# `dataroot` is a path template with an `{exp}` placeholder; expand it once per
# experiment name to get the concrete data folders.
train_dataset = BDIDerivedDataset(
    folders=list(map(lambda exp_name: dataroot.format(exp=exp_name), train_exp)),
    submedian=submedian,
    flip=augment,  # augmentation on the training split follows the config flag
)

# The validation split is never flipped. In visualize mode the dataset is also
# asked to return indices (presumably so predictions can be matched back to
# samples — confirm in BDIDerivedDataset).
validate_dataset = BDIDerivedDataset(
    folders=list(map(lambda exp_name: dataroot.format(exp=exp_name), validate_exp)),
    submedian=submedian,
    flip=False,
    return_idx=visualize,
)

In [None]:
# ---- data loaders ----
train_loader = Data.DataLoader(
    dataset=train_dataset,
    batch_size=batch,
    shuffle=True,
    num_workers=workers,
    pin_memory=False,
)
validate_loader = Data.DataLoader(
    dataset=validate_dataset,
    batch_size=batch,
    shuffle=False,
    num_workers=workers,
    pin_memory=False,
)

# ---- model ----
device = torch.device('cpu' if cpu else 'cuda:0')

# The checkpoint (if any) is loaded BEFORE the optimizer is built. This ensures
# the optimizer tracks the parameters of the model that is actually trained.
start_epoch = 1
if checkpoint:
    if not os.path.isfile(checkpoint):
        raise RuntimeError("no checkpoint found at '{}'".format(checkpoint))
    # Checkpoints are whole pickled models (see torch.save(model, ...) below).
    # map_location allows a GPU-saved checkpoint to be loaded with cpu=True.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects
    # whole-model pickles; pass weights_only=False there (trusted files only).
    model = torch.load(checkpoint, map_location=device)
    if not visualize:
        # Resume after the epoch encoded in the file name: model_epoch_7.pth -> 8.
        stem = os.path.splitext(os.path.basename(checkpoint))[0]
        start_epoch = int(stem.split('_')[-1]) + 1
else:
    model = BDI_3D_Conv_Simple()
    # model = BDI_3D_Conv()
model.to(device)

print('# of params:', sum(p.numel() for p in model.parameters()))

optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss().to(device)

# train/validate/test take the namespace that argparse used to produce. Which
# fields they read is not visible here, so every former option is passed through.
args = argparse.Namespace(
    dataroot=dataroot,
    datasets=datasets,
    workers=workers,
    batch=batch,
    epochs=epochs,
    cpu=cpu,
    seed=seed,
    checkpoint=checkpoint,
    output=output,
    print_freq=print_freq,
    visualize=visualize,
    lr=lr,
    submedian=submedian,
    overwrite=overwrite,
    augment=augment,
)

if visualize:
    # Evaluation mode.
    if not checkpoint:
        raise RuntimeWarning("visualizing a random initialized model")
    if overwrite:
        shutil.rmtree(output)
    test(model, validate_loader, 0, args, criterion=criterion)
else:
    # Train and validate loop. The returned losses are currently unused because
    # TensorBoard logging is disabled.
    for epoch in range(start_epoch, epochs + 1):
        train_loss = train(model, train_loader, optimizer, epoch, args, criterion=criterion)
        val_loss = validate(model, validate_loader, epoch, args, criterion=criterion)
        # Checkpoint the whole model after every epoch.
        torch.save(model, os.path.join(output, 'model_epoch_{}.pth'.format(epoch)))