In [None]:
%matplotlib inline

import random
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
import torch
import dense_correspondence_manipulation.utils.utils as utils
utils.add_dense_correspondence_to_python_path()

import dense_correspondence
from dense_correspondence.evaluation.evaluation import *
import dense_correspondence.correspondence_tools.correspondence_plotter as correspondence_plotter
from dense_correspondence.dataset.dense_correspondence_dataset_masked import ImageType
from dense_correspondence.dataset.dynamic_time_contrast_dataset import DynamicTimeContrastDataset

## Load the dataset and the pretrained dense descriptor network

In [None]:
# LOAD DATASET
# Time-contrastive dataset described by
# config/dense_correspondence/dataset/composite/dynamic.yaml
dataset_config_filename = os.path.join(utils.getDenseCorrespondenceSourceDir(), 'config', 'dense_correspondence',
                                       'dataset', 'composite',
                                       'dynamic.yaml')
dataset_config = utils.getDictFromYamlFilename(dataset_config_filename)
dataset = DynamicTimeContrastDataset(config=dataset_config)

# LOAD DESCRIPTOR NETWORK
# The network is looked up by name in config/dense_correspondence/evaluation/evaluation.yaml.
eval_config_filename = os.path.join(utils.getDenseCorrespondenceSourceDir(), 'config',
                                    'dense_correspondence', 'evaluation', 'evaluation.yaml')
eval_config = utils.getDictFromYamlFilename(eval_config_filename)

utils.set_cuda_visible_devices([0])
dce = DenseCorrespondenceEvaluation(eval_config)
network_name = "sugar_closer_3"
dcn = dce.load_network_from_config(network_name)
# Move to the GPU and switch to eval mode; this network only produces input
# features for the embedding network trained below.
dcn.cuda().eval()
# Function-call form prints the same text on Python 2 and is valid on Python 3.
print("loaded")

In [None]:
# # Test that I can put an image through DCN from TCN dataloader
# metadata, images = dataset[0]
# img_anchor = images[0]
# res_a = dcn.forward_single_image_tensor(img_anchor).data.cpu().numpy()
# from dense_correspondence.evaluation.plotting import normalize_descriptor
# res_a = normalize_descriptor(res_a, dcn.descriptor_image_stats["mask_image"])
# import matplotlib.pyplot as plt
# plt.imshow(res_a)
# plt.show()

In [None]:
# stacked_images = torch.cat([x.unsqueeze(0) for x in images]).cuda() #N, C, H, W
# stacked_descriptor_images = dcn.forward(stacked_images).detach() #N, D, H, W
# #plt.imshow(stacked_descriptor_images[0].data.permute(1,2,0).cpu().numpy())

# # Format the input
# print stacked_images.shape
# print stacked_descriptor_images.shape

# # Stack carefully
# rgb_and_descriptor_images = []
# for i in range(stacked_images.shape[0]):
#     rgb_and_descriptor_images.append(torch.cat([stacked_images[i], stacked_descriptor_images[i]]).unsqueeze(0))

# stacked_all_images = torch.cat(rgb_and_descriptor_images) # N, C+D, H, W
# print stacked_all_images.shape
# print stacked_all_images.requires_grad
# print stacked_images.requires_grad

## Train the time contrast network

In [None]:
import torch.nn as nn
import torch.nn.functional as F



class TimeEmbeddingNetwork(nn.Module):
    """CNN + MLP mapping a stacked RGB+descriptor image to a D-dim embedding.

    Three stages of conv(5x5) -> ReLU -> maxpool(2x2) shrink the input, the
    result is flattened, and three fully connected layers (with dropout after
    the first two) produce the embedding.

    Args:
        D: output embedding dimension.
        in_channels: number of input channels (RGB plus descriptor channels).
            Defaults to 6, the value previously hard-coded.
        flat_dim: number of features after the last pooling stage, i.e.
            3 * H' * W' (conv3 outputs 3 channels). Defaults to 12768, the
            value previously hard-coded, which matches a 480x640 input.
        dropout: dropout probability used after fc1 and fc2.
    """

    def __init__(self, D, in_channels=6, flat_dim=12768, dropout=0.1):
        super(TimeEmbeddingNetwork, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.conv3 = nn.Conv2d(16, 3, 5)
        self.fc1 = nn.Linear(flat_dim, flat_dim)
        self.fc2 = nn.Linear(flat_dim, flat_dim)
        self.fc3 = nn.Linear(flat_dim, D)
        # Registered as a submodule so that net.eval() disables it. Creating
        # it inside forward() left it permanently in training mode. Dropout
        # has no parameters, so state_dict keys are unchanged.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Embed a batch of shape (N, in_channels, H, W) into shape (N, D)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Keep the batch axis explicit: an input of the wrong resolution now
        # fails at fc1 instead of being silently folded into the batch axis.
        x = x.view(x.size(0), -1)
        x = F.relu(self.dropout(self.fc1(x)))
        x = F.relu(self.dropout(self.fc2(x)))
        return self.fc3(x)
    
# Build the 32-dimensional embedding network and move it onto the GPU
# (nn.Module.cuda() moves parameters in place and returns the module).
net = TimeEmbeddingNetwork(D=32)
net = net.cuda()
# out = net(stacked_all_images)
# print out.shape

In [None]:
def triplet_time_contrastive_loss(three_embeddings, alpha=0.2, normalize=True, eps=1e-12):
    """Triplet margin loss for time-contrastive embedding learning.

    Args:
        three_embeddings: indexable sequence (e.g. a (3, D) tensor or a list of
            tensors) whose entries 0, 1 and 2 are the anchor, positive and
            negative embeddings. Any further entries are ignored.
        alpha: margin by which the negative must be farther from the anchor
            than the positive before the loss reaches zero.
        normalize: if True, scale each embedding to unit L2 norm first.
        eps: lower bound on the norm used when normalizing. An all-zero
            embedding therefore stays zero instead of becoming NaN.

    Returns:
        0-dim tensor equal to max(0, ||a - p||^2 - ||a - n||^2 + alpha).
    """
    anchor, positive, negative = three_embeddings[0], three_embeddings[1], three_embeddings[2]
    if normalize:
        anchor, positive, negative = [x / torch.norm(x).clamp(min=eps)
                                      for x in (anchor, positive, negative)]
    positive_dist = (anchor - positive).pow(2).sum()
    negative_dist = (anchor - negative).pow(2).sum()
    # Penalize triplets whose positive is not at least `alpha` closer to the
    # anchor than the negative. The terms were previously swapped, which
    # rewarded pulling the anchor toward the negative.
    return torch.clamp(positive_dist - negative_dist + alpha, min=0)
    
# triplet_time_contrastive_loss(out)

In [None]:
import tensorboard_logger

def setup_tensorboard(log_root="./tensorboard_log_dir/"):
    """Create a tensorboard_logger.Logger writing to a fresh per-run directory.

    Args:
        log_root: parent directory. The run's subdirectory is named by
            utils.get_current_YYYY_MM_DD_hh_mm_ss().

    Returns:
        The tensorboard_logger.Logger instance.
    """
    # `logging` is not imported explicitly anywhere else in this notebook;
    # import it here rather than relying on a star import to provide it.
    import logging

    tensorboard_log_dir = os.path.join(log_root, utils.get_current_YYYY_MM_DD_hh_mm_ss())
    logging.info("setting up tensorboard_logger")
    cmd = "tensorboard --logdir=%s" % (tensorboard_log_dir)
    tb_logger = tensorboard_logger.Logger(tensorboard_log_dir)
    logging.info("tensorboard logger started")
    # `cmd` was previously built but never used; surface it so the user
    # knows how to view this run.
    logging.info("view with: %s", cmd)
    return tb_logger

tb_logger = setup_tensorboard()

In [None]:
import time

import torch.optim as optim

# The loss reads rows 0, 1 and 2 of the network output as one
# (anchor, positive, negative) triplet, so only one triplet per batch is
# supported.
batch_size = 1
assert batch_size == 1, "triplet_time_contrastive_loss consumes a single triplet per batch"
trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=6)

criterion = triplet_time_contrastive_loss
optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)

num_epochs = 10
print_every = 1000  # print the mean running loss every this many mini-batches

print("Starting training...")
start = time.time()
overall_iter = 0
net.train()  # make sure dropout is active while training
for epoch in range(num_epochs):  # loop over the dataset multiple times

    running_loss = 0.0
    for batch_idx, data in enumerate(trainloader, 0):
        overall_iter += 1
        # get the inputs
        metadata, images = data

        stacked_images = torch.cat(list(images)).cuda()  # N, C, H, W
        # dcn is not optimized here; skip building its autograd graph.
        with torch.no_grad():
            stacked_descriptor_images = dcn.forward(stacked_images).detach()  # N, D, H, W
        # Append each image's descriptor channels to its RGB channels.
        stacked_all_images = torch.cat([stacked_images, stacked_descriptor_images], dim=1)  # N, C+D, H, W

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        out = net(stacked_all_images)
        loss = criterion(out)
        loss.backward()
        optimizer.step()

        tb_logger.log_value("loss", loss.item(), overall_iter)

        # print statistics
        running_loss += loss.item()
        if print_every and (batch_idx + 1) % print_every == 0:
            print('[%d, %5d] loss: %.8f' %
                  (epoch + 1, batch_idx + 1, running_loss / print_every))
            running_loss = 0.0

# Single-argument prints behave identically on Python 2 and 3.
print('Finished Training with %d steps' % overall_iter)
print("In time %s" % str(time.time() - start))
torch.save(net.state_dict(), "./new_net.pth")