In [1]:
from __future__ import print_function
import zipfile
import os
import pdb
import torch
import h5py
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision import datasets, transforms, utils

import numpy as np
import torch.nn as nn
import torch

# Spatial size (in pixels) that both the RGB input and the depth target are
# resized to before being fed to the network.
output_height = 64
output_width = 64

class TransposeDepthInput(object):
    """Callable transform turning a depth array into a resized log-depth tensor.

    The input is a 3-D numpy array whose last axis is moved to the front
    (presumably height x width x 1 -> 1 x height x width; confirm against the
    caller). The result is bilinearly resized to
    ``(output_height, output_width)`` and passed through a natural log.
    """

    def __call__(self, depth):
        # Move the trailing axis to the front and wrap the array as a tensor.
        channels_first = torch.from_numpy(depth.transpose((2, 0, 1)))
        # interpolate() expects a leading batch dimension.
        batched = channels_first.view(1, *channels_first.shape)
        resized = nn.functional.interpolate(
            batched,
            size=(output_height, output_width),
            mode='bilinear',
            align_corners=False,
        )
        # Drop the batch dimension again and return log-depth.
        return torch.log(resized[0])

# RGB pipeline: resize the PIL image to the network resolution, then convert
# it to a tensor. (The depth map uses a different pipeline, defined below.)
rgb_data_transforms = transforms.Compose(
    [transforms.Resize((output_height, output_width)), transforms.ToTensor()]
)

# Depth pipeline: all of the work happens inside TransposeDepthInput.
depth_data_transforms = transforms.Compose([TransposeDepthInput()])

# Pipeline for images that are only displayed; currently the same steps as
# the RGB pipeline.
input_for_plot_transforms = transforms.Compose(
    [transforms.Resize((output_height, output_width)), transforms.ToTensor()]
)

class NYUDataset(Dataset):
    """Dataset of (image, depth) pairs read from an HDF5 file.

    The file must contain the datasets ``'images'`` and ``'depths'``. The
    requested split selects a fixed index range of both, which is loaded into
    memory once at construction time.

    Args:
        filename: path of the HDF5 file to read.
        type: split name; one of ``"training"``, ``"validation"``, ``"test"``.
        rgb_transform: optional callable applied to the PIL image.
        depth_transform: optional callable applied to the depth array.

    Raises:
        ValueError: if ``type`` is not one of the known split names.
    """

    # Index range of every split inside the 'images' / 'depths' datasets.
    _SPLIT_SLICES = {
        "training": slice(0, 1024),
        "validation": slice(1024, 1248),
        "test": slice(1248, None),
    }

    def __init__(self, filename, type, rgb_transform = None, depth_transform = None):
        # Validate before touching the file so a typo fails immediately
        # instead of leaving the instance without images/depths.
        if type not in self._SPLIT_SLICES:
            raise ValueError(
                "unknown dataset type %r; expected one of %s"
                % (type, sorted(self._SPLIT_SLICES))
            )
        split = self._SPLIT_SLICES[type]
        # Slicing an h5py dataset copies the data into numpy arrays, so the
        # file handle can be released as soon as both slices are read.
        with h5py.File(filename, 'r') as f:
            self.images = f['images'][split]
            self.depths = f['depths'][split]
        self.rgb_transform = rgb_transform
        self.depth_transform = depth_transform

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'depth': ...}`` for sample ``idx``."""
        image = self.images[idx]
        # Reverse the axis order before handing the array to PIL
        # (assumes samples are stored as channels x width x height - TODO confirm).
        image = image.transpose((2, 1, 0))
        image = Image.fromarray(image)
        if self.rgb_transform:
            image = self.rgb_transform(image)
        depth = self.depths[idx]
        # Add a singleton axis and reverse the axis order so the depth map
        # ends up with a trailing axis of length 1.
        depth = np.reshape(depth, (1, depth.shape[0], depth.shape[1]))
        depth = depth.transpose((2, 1, 0))
        if self.depth_transform:
            depth = self.depth_transform(depth)
        sample = {'image': image, 'depth': depth}
        return sample

  from ._conv import register_converters as _register_converters


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNetConvBlock(nn.Module):
    """Two stacked convolutions, each followed by an activation.

    Args:
        in_size: number of input channels.
        out_size: number of output channels of both convolutions.
        kernel_size: convolution kernel size (default 3). Padding is fixed at
            1, so the spatial size is preserved only for ``kernel_size=3``.
        activation: callable applied after each convolution (default ReLU).
    """

    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
        self.activation = activation

    def forward(self, x):
        """Apply conv1 -> activation -> conv2 -> activation to ``x``."""
        # The first layer is registered as ``conv1``; there is no ``conv``
        # attribute on this block.
        out = self.activation(self.conv1(x))
        out = self.activation(self.conv2(out))
        return out

class UNetUpBlock(nn.Module):
    """Decoder stage: upsample, concatenate a skip tensor, then convolve twice.

    ``in_size`` is the channel count of the tensor being upsampled and also
    the channel count expected after concatenation, so the skip tensor must
    contribute ``in_size - out_size`` channels. ``space_dropout`` is accepted
    but not used by this block.
    """

    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
        super(UNetUpBlock, self).__init__()
        # Stride-2 transposed convolution with a 2x2 kernel.
        self.up = nn.ConvTranspose2d(in_size, out_size, 2, stride=2)
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, padding=1)
        self.activation = activation

    def forward(self, x, bridge):
        """Upsample ``x``, join it with ``bridge`` on the channel axis and refine."""
        merged = torch.cat([self.up(x), bridge], 1)
        hidden = self.activation(self.conv(merged))
        return self.activation(self.conv2(hidden))

class UNet(nn.Module):
    """Encoder/decoder network with skip connections.

    Four conv-block + max-pool stages widen the features from 3 to 512
    channels, a fifth conv block reaches 1024 channels, and four up-blocks
    mirror the encoder back down to 64 channels. A final 1x1 convolution
    produces a single output channel.
    """

    def __init__(self):
        super(UNet, self).__init__()

        self.activation = F.relu

        # One 2x2 max-pool per encoder stage.
        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)
        self.pool3 = nn.MaxPool2d(2)
        self.pool4 = nn.MaxPool2d(2)

        # Encoder (note: the first block takes 3 input channels).
        self.conv_block1_64 = UNetConvBlock(3, 64)
        self.conv_block64_128 = UNetConvBlock(64, 128)
        self.conv_block128_256 = UNetConvBlock(128, 256)
        self.conv_block256_512 = UNetConvBlock(256, 512)
        self.conv_block512_1024 = UNetConvBlock(512, 1024)

        # Decoder.
        self.up_block1024_512 = UNetUpBlock(1024, 512)
        self.up_block512_256 = UNetUpBlock(512, 256)
        self.up_block256_128 = UNetUpBlock(256, 128)
        self.up_block128_64 = UNetUpBlock(128, 64)

        # 1x1 convolution down to a single output channel.
        self.last = nn.Conv2d(64, 1, 1)

    def forward(self, x):
        """Run ``x`` through the encoder, the bottleneck and the decoder."""
        # Encoder: keep every stage's output for the matching skip connection.
        enc1 = self.conv_block1_64(x)
        enc2 = self.conv_block64_128(self.pool1(enc1))
        enc3 = self.conv_block128_256(self.pool2(enc2))
        enc4 = self.conv_block256_512(self.pool3(enc3))

        bottleneck = self.conv_block512_1024(self.pool4(enc4))

        # Decoder: each stage consumes the encoder output of the same depth.
        dec = self.up_block1024_512(bottleneck, enc4)
        dec = self.up_block512_256(dec, enc3)
        dec = self.up_block256_128(dec, enc2)
        dec = self.up_block128_64(dec, enc1)

        return self.last(dec)

In [3]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, utils
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

from torch.autograd import Variable
from logger import Logger
import pdb
import os
import re
import numpy as np

# Training settings
# parser = argparse.ArgumentParser(description='PyTorch depth map prediction example')
# parser.add_argument('model_folder', type=str, default='trial', metavar='F',
#                      help='In which folder do you want to save the model')
# parser.add_argument('--data', type=str, default='data', metavar='D',
#                      help="folder where data is located. train_data.zip and test_data.zip need to be found in the folder")
# parser.add_argument('--batch-size', type = int, default = 32, metavar = 'N',
#                      help='input batch size for training (default: 8)')
# parser.add_argument('--epochs', type=int, default = 10, metavar='N',
#                       help='number of epochs to train (default: 10)')
# parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
#                      help='learning rate (default: 0.01)')
# parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
#                      help='SGD momentum (default: 0.5)')
# parser.add_argument('--seed', type=int, default=1, metavar='S',
#                      help='random seed (default: 1)')
# parser.add_argument('--log-interval', type=int, default=10, metavar='N',
#                      help='how many batches to wait before logging training status')
# parser.add_argument('--suffix', type=str, default='', metavar='D',
#                      help='suffix for the filename of models and output files')
# args = parser.parse_args()

# Run configuration (replaces the commented-out argparse options above).
data = 'data'                  # data folder name
batch_size = 8                 # samples per batch for both loaders
epochs = 20                    # number of passes of the training loop
lr = 1e-4                      # learning rate
momentum = 0.5                 # SGD momentum
seed = 1                       # random seed
log_interval = 10              # log every N-th batch
suffix = ''                    # suffix for model / output file names
model_folder = 'local-unet'    # sub-folder name used for logs, plots and checkpoints

#torch.manual_seed(seed)

### Data Initialization and Loading
# from data import initialize_data, rgb_data_transforms, depth_data_transforms, output_height, output_width
#initialize_data(args.data) # extracts the zip files, makes a validation set

# from data import NYUDataset, rgb_data_transforms, depth_data_transforms, input_for_plot_transforms, output_height, output_width

# train_rgb_loader = torch.utils.data.DataLoader(datasets.ImageFolder(args.data + '/train_images/rgb/', transform = rgb_data_transforms), batch_size=args.batch_size, shuffle=False, num_workers=1)
# train_depth_loader = torch.utils.data.DataLoader(datasets.ImageFolder(args.data + '/train_images/depth/', transform = depth_data_transforms), batch_size=args.batch_size, shuffle=False, num_workers=1)
# val_rgb_loader = torch.utils.data.DataLoader(datasets.ImageFolder(args.data + '/val_images/rgb/', transform = rgb_data_transforms), batch_size=args.batch_size, shuffle=False, num_workers=1)
# val_depth_loader = torch.utils.data.DataLoader(datasets.ImageFolder(args.data + '/val_images/depth/', transform = depth_data_transforms), batch_size=args.batch_size, shuffle=False, num_workers=1)

# Both splits are read from the same 'nyu_depth_v2_labeled.mat' file and share
# the same transforms; neither loader shuffles and both load in the main
# process (num_workers = 0).
train_dataset = NYUDataset('nyu_depth_v2_labeled.mat',
                           'training',
                           rgb_transform = rgb_data_transforms,
                           depth_transform = depth_data_transforms)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size = batch_size,
                                           shuffle = False,
                                           num_workers = 0)

val_dataset = NYUDataset('nyu_depth_v2_labeled.mat',
                         'validation',
                         rgb_transform = rgb_data_transforms,
                         depth_transform = depth_data_transforms)
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size = batch_size,
                                         shuffle = False,
                                         num_workers = 0)

# NOTE(review): this import rebinds the name UNet, replacing the class defined
# in the previous cell with whatever model.py provides.
from model import UNet
model = UNet()
# model.cuda()

def rel_error(output, target):
    """Mean squared error between log10(output) and log10(target).

    A small epsilon is added to both tensors before the logarithm so exact
    zeros do not produce ``-inf``.

    Args:
        output: predicted tensor.
        target: ground-truth tensor of the same shape.

    Returns:
        Scalar tensor holding the MSE of the two log10-transformed tensors.
    """
    # NOTE(review): values <= -eps still give NaN under log10; confirm the
    # inputs are strictly positive (the depth transform in this notebook
    # already returns log-depth, which can be negative).
    eps = 0.000001
    log_target = torch.log10(target + eps)
    log_output = torch.log10(output + eps)
    return F.mse_loss(log_output, log_target)
    # Previous variant, kept for reference: mean(|output - target| / target)
    #diff = (output-target)/target
    #diff = torch.abs(diff)
    #return diff.mean()

def custom_loss_function(output, target):
    """Per-sample squared-error loss with a mean-offset correction, summed over the batch.

    For every sample, with ``d = target - output`` and ``n`` the number of
    pixels (height * width of the tensors), the loss is

        sum(d^2) / n  -  0.5 * (sum(d))^2 / n^2

    Args:
        output: predicted tensor of shape (batch, channels, height, width).
        target: ground-truth tensor of the same shape.

    Returns:
        Scalar tensor: the per-sample losses summed over the batch.
    """
    # di = torch.log(target) - torch.log(output)
    di = target - output
    # Take the pixel count from the tensors themselves so the loss stays
    # correctly scaled if the network resolution differs from the notebook's
    # output_height / output_width constants.
    n = di.shape[-2] * di.shape[-1]
    di2 = torch.pow(di, 2)
    first_term = torch.sum(di2, (1, 2, 3)) / n
    second_term = 0.5 * torch.pow(torch.sum(di, (1, 2, 3)), 2) / (n ** 2)
    loss = first_term - second_term
    return loss.sum()

# Objective used by train_Unet; the commented-out lines are alternative losses.
loss_function = custom_loss_function
#loss_function = F.mse_loss
#loss_function = F.smooth_l1_loss
#loss_function = rel_error
# Use the configured learning rate instead of repeating the literal, so
# changing `lr` in the settings actually affects training.
optimizer = optim.Adam(model.parameters(), amsgrad=True, lr=lr)
#optimizer = optim.SGD(model.parameters(), lr = 0.0001, momentum=0.99)
#optimizer = optim.Adamax(model.parameters())
# NOTE(review): dtype is not referenced by any function in this cell.
dtype=torch.cuda.FloatTensor
# Scalar summaries are written under ./logs/<model_folder>.
logger = Logger('./logs/' + model_folder)

def display_images(images):
    """Tile ``images`` into one grid with ``utils.make_grid`` and show it."""
    tiled = utils.make_grid(images)
    # imshow wants the channel axis last, so move it from the front to the back.
    tiled_channels_last = tiled.cpu().detach().numpy().transpose((1, 2, 0))
    plt.imshow(tiled_channels_last)
    plt.show()

def format_data_for_display(tensor):
    """Rescale ``tensor`` symmetrically around 0.5 for display.

    The tensor is divided by the larger of its maximum and the absolute value
    of its minimum, then halved and shifted by 0.5, so zero maps to 0.5 and
    the values end up in [0, 1].
    """
    highest = tensor.max()
    lowest_magnitude = abs(tensor.min())
    # Pick whichever extreme is larger as the normalisation scale.
    scale = lowest_magnitude if lowest_magnitude > highest else highest
    return tensor / scale / 2 + 0.5

def plot_grid(fig, plot_input, output, actual_output, row_no):
    """Draw ``row_no`` rows of (input, output, actual_output) images on ``fig``.

    The grid is created with four columns per row; only the first three are
    filled. For ``output`` and ``actual_output`` the first channel of each
    sample is shown.
    """
    columns = 4
    grid = ImageGrid(fig, 141, nrows_ncols=(row_no, columns), axes_pad=0.05, label_mode="1")
    for row in range(row_no):
        first_cell = row * columns
        # Column 0: the input image, with its channel axis moved last.
        grid[first_cell].imshow(
            np.transpose(plot_input[row], (1, 2, 0)), interpolation="nearest")
        # Column 1: first channel of `output`.
        grid[first_cell + 1].imshow(
            np.transpose(output[row][0].detach().cpu().numpy(), (0, 1)), interpolation="nearest")
        # Column 2: first channel of `actual_output`.
        grid[first_cell + 2].imshow(
            np.transpose(actual_output[row][0].detach().cpu().numpy(), (0, 1)), interpolation="nearest")

def train_Unet(epoch):
    """Run the training loop for one epoch.

    For each processed batch the model is updated with ``loss_function``, a
    grid of input / target / prediction images is saved under ``plots/`` and,
    every ``log_interval`` batches, the loss is logged and printed.

    Args:
        epoch: epoch number, used in the log tag, the plot file name and the
            progress message.
    """
    model.train()
    for batch_idx, batch in enumerate(train_loader):
        rgb, depth = batch['image'], batch['depth']
        optimizer.zero_grad()
        output = model(rgb)
        # Keep only the first depth channel, shaped (batch, 1, H, W).
        target = depth[:, 0, :, :].view(depth.shape[0], 1, output_height, output_width)
        loss = loss_function(output, target)
        loss.backward()
        optimizer.step()

        # Use a dedicated name for the figure; calling it `F` would shadow the
        # torch.nn.functional alias inside this function.
        fig = plt.figure(1, (30, 60))
        fig.subplots_adjust(left=0.05, right=0.95)
        # NOTE(review): target is passed in plot_grid's `output` slot and the
        # prediction in `actual_output`; confirm the intended column order.
        # len(rgb) rather than batch_size keeps a short final batch from
        # indexing past the end of the tensors.
        plot_grid(fig, rgb, target, output, len(rgb))
        fig.savefig("plots/train_" + model_folder + "_" + str(epoch) + "_" + str(batch_idx) + ".jpg")
        plt.show()
        # Figure number 1 is reused on every call; close it so axes from
        # earlier batches/epochs do not pile up on the same figure.
        plt.close(fig)

        if batch_idx % log_interval == 0:
            training_tag = "training loss epoch:" + str(epoch)
            logger.scalar_summary(training_tag, loss.item(), batch_idx)

            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(rgb), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
        # NOTE(review): this stops after the first batch, so every epoch trains
        # on a single batch only. Kept as-is because removing it would also
        # write one large plot per batch; drop it (and throttle the plotting)
        # for a real training run.
        if batch_idx == 0: break

def validate_Unet(step=None):
    """Evaluate the model on the validation loader and log the averages.

    Computes, averaged over all validation batches, the ``rel_error`` loss
    and the MSE between prediction and target.

    Args:
        step: x-axis value for the logged scalar. When omitted, the
            module-level ``epoch`` variable of the training loop is used.
    """
    print('validating unet')
    model.eval()
    rel_total = 0.0
    mse_total = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in val_loader:
            rgb, depth = batch['image'], batch['depth']
            output = model(rgb)
            # Keep only the first depth channel, shaped (batch, 1, H, W).
            target = depth[:, 0, :, :].view(depth.shape[0], 1, output_height, output_width)
            # Accumulate every batch (not just the last one) as plain floats.
            rel_total += rel_error(output, target).item()
            mse_total += F.mse_loss(output, target).item()
            num_batches += 1
#           if num_batches == 3: break

    if num_batches == 0:
        print('\nValidation set is empty; nothing to evaluate\n')
        return

    # Divide by the number of batches, not by the index of the last batch
    # (which is one too small and zero for a single-batch loader).
    validation_loss = rel_total / num_batches
    rel_loss = validation_loss
    # NOTE(review): despite the name this is the mean MSE, not its square root.
    rms_loss = mse_total / num_batches
    if step is None:
        step = epoch  # loop variable of the training driver
    logger.scalar_summary("validation loss", validation_loss, step)
    print('\nValidation set: Average loss: {:.6f} {:.6f} {:.6f}\n'.format(validation_loss, rel_loss, rms_loss))

folder_name = "models/" + model_folder
# makedirs also creates the parent "models" directory (os.mkdir would fail if
# it is missing). "plots" is where train_Unet saves its figures.
for directory in (folder_name, "plots"):
    if not os.path.isdir(directory):
        os.makedirs(directory)

for epoch in range(1, epochs + 1):
    print("********* Training the Unet Model **************")
    train_Unet(epoch)
    # Checkpoint every 25 epochs and always after the last epoch; with
    # epochs < 25 the periodic condition alone would never save anything.
    if epoch % 25 == 0 or epoch == epochs:
        model_file = folder_name + "/" + 'model_' + str(epoch) + '.pth'
        torch.save(model.state_dict(), model_file)
#    validate_Unet()

********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
********* Training the Unet Model **************
