# [See the HTML version for the rendered outputs.](http://rawgit.com/MatthewKleinsmith/portraitseg/master/portraitseg/tests/visually_test_Flickr_Dataset_and_dataloader.html)
# Run this notebook to reproduce them.

In [None]:
%matplotlib inline

import os
# Move two directories up from the notebook's folder. This is presumably the
# repository root, since the notebook lives under portraitseg/tests/. The
# `portraitseg` import below and relative paths such as DATA_DIR then
# resolve from there. TODO confirm.
os.chdir("../../")

# NOTE(review): this star import is relied on to supply every otherwise
# undefined name used below (torch, nn, optim, np, random, Variable, time,
# get_train_valid_loader, show_input_output_target, denormalizer). Verify
# against portraitseg/common.py.
from portraitseg.common import *

class Net(nn.Module):
    """Tiny one-layer network; appears to be a placeholder for pipeline tests.

    A single 3x3 convolution (stride 1, padding 1) maps a 3-channel input to
    a 1-channel output of the same height and width.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Attribute name `conv1` is kept: it determines the state_dict keys.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=1,
                               kernel_size=3, padding=1)

    def forward(self, x):
        """Apply the convolution to a batch `x`."""
        out = self.conv1(x)
        return out


# Seed every RNG the pipeline may touch so runs are repeatable. The order is
# torch, then Python's `random`, then NumPy. Seeding happens before the model
# is built, so its initial weights are reproducible too.
SEED = 0
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(SEED)

# Dataset locations, relative to the working directory set by os.chdir.
DATA_DIR = "../data/"
FLICKR_DIR = os.path.join(DATA_DIR, "portraits/flickr/")

# Smoke-test switch: when True, each epoch's training loop stops after its
# first 100-batch loss report.
TESTING_PIPES = True

# Hyperparameters
BATCH_SIZE = 1      # samples per batch
LR = 1e-2           # SGD learning rate
NB_EPOCHS = 2       # full passes over the training loader
AUGMENT = False     # forwarded to get_train_valid_loader

# Model, objective and optimiser. Module.cuda() moves the parameters in
# place, so it must run before the optimiser captures net.parameters().
net = Net()
net.cuda()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(params=net.parameters(), lr=LR)

# Train/validation DataLoaders. valid_size=0.2 presumably holds out 20% of
# the Flickr data for validation.
_loader_opts = dict(
    batch_size=BATCH_SIZE,
    augment=AUGMENT,
    random_seed=SEED,
    valid_size=0.2,
    show_sample=True,
    num_workers=1,
    pin_memory=True,
)
trn_loader, val_loader = get_train_valid_loader(FLICKR_DIR, **_loader_opts)

# Train for NB_EPOCHS epochs. After each epoch, report the validation loss
# and show the network's prediction on one fixed validation sample.
# `val_outputs` collects that prediction from every epoch, so the display
# can show how it evolves during training.
if len(val_loader) == 0:
    # Fail fast with a clear message. Otherwise the run would die with a
    # ZeroDivisionError only after a full training epoch.
    raise ValueError("val_loader is empty; cannot validate or visualise.")

# Index of the validation batch to visualise. The second batch (index 1) is
# kept as before, but clamped so a single-batch loader still yields a sample
# instead of leaving portraits_v undefined (NameError below).
VIS_BATCH_IDX = min(1, len(val_loader) - 1)

val_outputs = []
portraits_v, masks_v = None, None  # captured once, during epoch 0
for epoch in range(NB_EPOCHS):

    start = time()
    running_loss = 0.0
    print("\n[Epoch, batches]")
    for i, sample_batch in enumerate(trn_loader):

        portraits, masks = sample_batch
        portraits, masks = Variable(portraits).cuda(), Variable(masks).cuda()

        outputs = net(portraits)
        loss = loss_fn(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # NOTE(review): `loss.data[0]` is the pre-0.4 PyTorch idiom. On newer
        # versions, indexing a 0-dim tensor raises; use `loss.item()` there.
        running_loss += loss.data[0]
        if i % 100 == 99:
            # Mean loss over the last window of 100 batches.
            print("[%d, %4d] loss: %.3f" % (epoch+1, i+1, running_loss/100))
            running_loss = 0.0
            if TESTING_PIPES:
                # Smoke-test mode: one reporting window per epoch is enough.
                break

    # Test on validation set
    running_loss_val = 0.0
    for i, sample_batch in enumerate(val_loader):
        portraits, masks = sample_batch
        portraits, masks = Variable(portraits).cuda(), Variable(masks).cuda()
        outputs = net(portraits)
        loss = loss_fn(outputs, masks)
        running_loss_val += loss.data[0]
        if epoch == 0 and i == VIS_BATCH_IDX:
            # Keep the same sample for every epoch's visualisation.
            portraits_v = portraits
            masks_v = masks
    # Mean of the per-batch losses over the whole validation loader.
    print("Validation loss: %.3f" % (running_loss_val/len(val_loader)))
    print("Epoch duration: %.2f seconds" % (time() - start))
    # detach(): these predictions are only passed to the display helper, so
    # don't keep each epoch's autograd graph alive in the list.
    val_outputs.append(net(portraits_v).detach())
    show_input_output_target(portraits_v,
                             val_outputs,
                             masks_v,
                             denormalizer)

print("Training complete.")