In [1]:
""" To use Google Drive with Colab, 
1. set use_google_drive to True, and
2. specify a directory in Google Drive (Modify as in your Google Drive)
(You will need to authorize manually.)
"""
use_google_drive = True
workdir = '/content/drive/My Drive/Colab/MegaDepth/'


import os

try:
    if use_google_drive:
        from google.colab import drive
        drive.mount('/content/drive')
        # Create target directory & all intermediate directories if don't exists
        if not os.path.exists(workdir):
            os.makedirs(workdir)
            print('## Directory: ' , workdir ,  ' was created.') 
        os.chdir(workdir)
        print('## Current working directory: ', os.getcwd())
except:
    print('Run the code without using Google Drive.')
        
try:    
    print('## Check the uptime. (Google Colab reboots every 12 hours)')
    !cat /proc/uptime | awk '{print "Uptime is " $1 /60 /60 " hours (" $1 " sec)"}'
    print('## Check the GPU info')
    !nvidia-smi
    print('## Check the OS') 
    !cat /etc/issue
    print('## Check the Python version') 
    !python --version
    print('## Check the memory')
    !free -h
    print('## Check the disk')
    !df -h
except:
    print('Run the code assuming the environment is not Google Colab.')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
## Current working directory:  /content/drive/My Drive/Colab/MegaDepth
## Check the uptime. (Google Colab reboots every 12 hours)
Uptime is 4.03395 hours (14522.23 sec)
## Check the GPU info
Sun Apr 21 10:03:57 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.56       Driver Version: 410.79       CUDA Version: 10.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   53C    P8    16W /  70W |      0MiB / 15079MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
        

In [0]:
""" .base_model """

import os
import torch

class BaseModel():
    """Minimal model base class: holds options and provides checkpoint helpers.

    Subclasses override the no-op hooks (forward, test, optimize_parameters, ...)
    and use save_network / load_network to persist weights as
    ``<opt.checkpoints_dir>/<opt.name>/<epoch_label>_net_<network_label>.pth``.
    """

    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        """Read common settings from an options object.

        Args:
            opt: options object exposing ``gpu_ids``, ``isTrain``,
                ``checkpoints_dir`` and ``name``.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # Tensor constructor for the target device: CUDA floats when GPUs are configured.
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    @staticmethod
    def _checkpoint_filename(network_label, epoch_label):
        """Checkpoint file name shared by save_network and load_network."""
        return '%s_net_%s.pth' % (epoch_label, network_label)

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        """Save ``network.state_dict()`` to ``<save_dir>/<epoch>_net_<label>.pth``.

        The network is moved to the CPU for saving (``network.cpu()`` is in-place
        for modules) and moved back to ``gpu_ids[0]`` afterwards when CUDA is
        available. The file name matches what ``load_network`` reads, so the two
        helpers round-trip.
        """
        os.makedirs(self.save_dir, exist_ok=True)
        save_path = os.path.join(self.save_dir,
                                 self._checkpoint_filename(network_label, epoch_label))
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            # Module.cuda() takes the device as its (positional) argument; the old
            # 'device_id=' keyword is not accepted.
            network.cuda(gpu_ids[0])

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        """Load and return the object stored at ``<save_dir>/<epoch>_net_<label>.pth``.

        ``network`` is not modified; callers apply the returned state_dict
        themselves, e.g. ``network.load_state_dict(...)``.
        """
        save_path = os.path.join(self.save_dir,
                                 self._checkpoint_filename(network_label, epoch_label))
        print(save_path)
        return torch.load(save_path)

    def update_learning_rate(self):
        """Hook for learning-rate scheduling; no-op in the base class."""
        pass


In [0]:
""" pytorch_DIW_scratch """


import torch
import torch.nn as nn
from torch.autograd import Variable
from functools import reduce

class LambdaBase(nn.Sequential):
    """Container whose children are parallel branches over one shared input.

    ``fn`` is kept as ``lambda_func``; subclasses decide how it is applied to the
    branch outputs produced by ``forward_prepare``.
    """

    def __init__(self, fn, *args):
        super().__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        """Run every child on ``input`` and return their outputs as a list.

        With no children the input is passed through unchanged, which lets a
        childless LambdaReduce fold a list coming from a preceding LambdaMap.
        """
        branch_outputs = [branch(input) for branch in self._modules.values()]
        if not branch_outputs:
            return input
        return branch_outputs

class Lambda(LambdaBase):
    """Apply ``lambda_func`` once to the prepared input as a whole."""

    def forward(self, input):
        prepared = self.forward_prepare(input)
        return self.lambda_func(prepared)

class LambdaMap(LambdaBase):
    """Apply ``lambda_func`` to each prepared item, returning a list."""

    def forward(self, input):
        return [self.lambda_func(item) for item in self.forward_prepare(input)]

class LambdaReduce(LambdaBase):
    """Fold the prepared items pairwise with ``lambda_func`` (e.g. concat or sum)."""

    def forward(self, input):
        prepared = self.forward_prepare(input)
        return reduce(self.lambda_func, prepared)


# Hourglass depth network, expressed as one nested nn.Sequential.
# - Input has 3 channels (first Conv2d(3,...)); output has 1 channel (last Conv2d(64,1,...)).
# - LambdaMap(lambda x: x, a, b)  ("ConcatTable") feeds the same tensor to branches a and b
#   and returns [a(x), b(x)]; the LambdaReduce(lambda x,y: x+y) after it ("CAddTable") sums
#   the two branch outputs element-wise.
# - LambdaReduce(lambda x,y,dim=1: torch.cat(...), b0, b1, b2, b3)  ("Concat") is an
#   inception-style block: four parallel branches (a 1x1 conv, and 1x1 convs followed by a
#   larger kxk conv) concatenated along dim 1 (channels).
# - There are four nested 2x poolings (MaxPool2d twice, AvgPool2d twice), each undone by an
#   UpsamplingNearest2d(scale_factor=2). The input height and width must therefore be
#   divisible by 16; otherwise the element-wise sums see mismatched shapes.
# - BatchNorm2d(C,1e-05,0.1,False) is (num_features, eps, momentum, affine=False), so it has
#   no learnable scale or shift. Only the stem BatchNorm2d(128) uses the default affine=True.
# The exact nesting and ordering determine the state_dict keys of the pretrained checkpoint,
# so do not restructure this definition without re-mapping those keys.
pytorch_DIW_scratch = nn.Sequential( # Sequential,
	# Stem: 7x7 conv, 3 -> 128 channels, stride 1 / padding 3 keeps the spatial size.
	nn.Conv2d(3,128,(7, 7),(1, 1),(3, 3)),
	nn.BatchNorm2d(128),
	nn.ReLU(),
	nn.Sequential( # Sequential,
		# Level 1 (input resolution): [down path, skip path] -> summed below.
		LambdaMap(lambda x: x, # ConcatTable,
			# Level-1 down path: MaxPool 2x ... UpsamplingNearest 2x; ends with 64 channels.
			nn.Sequential( # Sequential,
				nn.MaxPool2d((2, 2),(2, 2)),
				LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
					nn.Sequential( # Sequential,
						nn.Conv2d(128,32,(1, 1)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
					),
					nn.Sequential( # Sequential,
						nn.Conv2d(128,32,(1, 1)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
						nn.Conv2d(32,32,(3, 3),(1, 1),(1, 1)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
					),
					nn.Sequential( # Sequential,
						nn.Conv2d(128,32,(1, 1)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
						nn.Conv2d(32,32,(5, 5),(1, 1),(2, 2)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
					),
					nn.Sequential( # Sequential,
						nn.Conv2d(128,32,(1, 1)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
						nn.Conv2d(32,32,(7, 7),(1, 1),(3, 3)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
					),
				),
				LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
					nn.Sequential( # Sequential,
						nn.Conv2d(128,32,(1, 1)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
					),
					nn.Sequential( # Sequential,
						nn.Conv2d(128,32,(1, 1)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
						nn.Conv2d(32,32,(3, 3),(1, 1),(1, 1)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
					),
					nn.Sequential( # Sequential,
						nn.Conv2d(128,32,(1, 1)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
						nn.Conv2d(32,32,(5, 5),(1, 1),(2, 2)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
					),
					nn.Sequential( # Sequential,
						nn.Conv2d(128,32,(1, 1)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
						nn.Conv2d(32,32,(7, 7),(1, 1),(3, 3)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
					),
				),
				nn.Sequential( # Sequential,
					# Level 2 (1/2 resolution): [down path, skip path] -> summed below.
					LambdaMap(lambda x: x, # ConcatTable,
						# Level-2 down path: MaxPool 2x ... UpsamplingNearest 2x.
						nn.Sequential( # Sequential,
							nn.MaxPool2d((2, 2),(2, 2)),
							LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
								nn.Sequential( # Sequential,
									nn.Conv2d(128,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(128,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(32,32,(3, 3),(1, 1),(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(128,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(32,32,(5, 5),(1, 1),(2, 2)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(128,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(32,32,(7, 7),(1, 1),(3, 3)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
								),
							),
							# 128 -> 4 x 64 = 256 channels.
							LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
								nn.Sequential( # Sequential,
									nn.Conv2d(128,64,(1, 1)),
									nn.BatchNorm2d(64,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(128,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(32,64,(3, 3),(1, 1),(1, 1)),
									nn.BatchNorm2d(64,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(128,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(32,64,(5, 5),(1, 1),(2, 2)),
									nn.BatchNorm2d(64,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(128,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(32,64,(7, 7),(1, 1),(3, 3)),
									nn.BatchNorm2d(64,1e-05,0.1,False),
									nn.ReLU(),
								),
							),
							nn.Sequential( # Sequential,
								# Level 3 (1/4 resolution): [skip path, down path] -> summed below.
								LambdaMap(lambda x: x, # ConcatTable,
									# Level-3 skip path (no pooling).
									nn.Sequential( # Sequential,
										LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
											nn.Sequential( # Sequential,
												nn.Conv2d(256,64,(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,32,(1, 1)),
												nn.BatchNorm2d(32,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(32,64,(3, 3),(1, 1),(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,32,(1, 1)),
												nn.BatchNorm2d(32,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(32,64,(5, 5),(1, 1),(2, 2)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,32,(1, 1)),
												nn.BatchNorm2d(32,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(32,64,(7, 7),(1, 1),(3, 3)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
										),
										LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
											nn.Sequential( # Sequential,
												nn.Conv2d(256,64,(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,64,(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(64,64,(3, 3),(1, 1),(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,64,(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(64,64,(7, 7),(1, 1),(3, 3)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,64,(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(64,64,(11, 11),(1, 1),(5, 5)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
										),
									),
									# Level-3 down path: AvgPool 2x ... UpsamplingNearest 2x.
									nn.Sequential( # Sequential,
										nn.AvgPool2d((2, 2),(2, 2)),
										LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
											nn.Sequential( # Sequential,
												nn.Conv2d(256,64,(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,32,(1, 1)),
												nn.BatchNorm2d(32,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(32,64,(3, 3),(1, 1),(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,32,(1, 1)),
												nn.BatchNorm2d(32,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(32,64,(5, 5),(1, 1),(2, 2)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,32,(1, 1)),
												nn.BatchNorm2d(32,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(32,64,(7, 7),(1, 1),(3, 3)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
										),
										LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
											nn.Sequential( # Sequential,
												nn.Conv2d(256,64,(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,32,(1, 1)),
												nn.BatchNorm2d(32,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(32,64,(3, 3),(1, 1),(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,32,(1, 1)),
												nn.BatchNorm2d(32,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(32,64,(5, 5),(1, 1),(2, 2)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,32,(1, 1)),
												nn.BatchNorm2d(32,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(32,64,(7, 7),(1, 1),(3, 3)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
										),
										nn.Sequential( # Sequential,
											# Level 4 (1/8 resolution): [skip path, down path] -> summed below.
											LambdaMap(lambda x: x, # ConcatTable,
												# Level-4 skip path (no pooling).
												nn.Sequential( # Sequential,
													LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
														nn.Sequential( # Sequential,
															nn.Conv2d(256,64,(1, 1)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
														nn.Sequential( # Sequential,
															nn.Conv2d(256,32,(1, 1)),
															nn.BatchNorm2d(32,1e-05,0.1,False),
															nn.ReLU(),
															nn.Conv2d(32,64,(3, 3),(1, 1),(1, 1)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
														nn.Sequential( # Sequential,
															nn.Conv2d(256,32,(1, 1)),
															nn.BatchNorm2d(32,1e-05,0.1,False),
															nn.ReLU(),
															nn.Conv2d(32,64,(5, 5),(1, 1),(2, 2)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
														nn.Sequential( # Sequential,
															nn.Conv2d(256,32,(1, 1)),
															nn.BatchNorm2d(32,1e-05,0.1,False),
															nn.ReLU(),
															nn.Conv2d(32,64,(7, 7),(1, 1),(3, 3)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
													),
													LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
														nn.Sequential( # Sequential,
															nn.Conv2d(256,64,(1, 1)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
														nn.Sequential( # Sequential,
															nn.Conv2d(256,32,(1, 1)),
															nn.BatchNorm2d(32,1e-05,0.1,False),
															nn.ReLU(),
															nn.Conv2d(32,64,(3, 3),(1, 1),(1, 1)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
														nn.Sequential( # Sequential,
															nn.Conv2d(256,32,(1, 1)),
															nn.BatchNorm2d(32,1e-05,0.1,False),
															nn.ReLU(),
															nn.Conv2d(32,64,(5, 5),(1, 1),(2, 2)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
														nn.Sequential( # Sequential,
															nn.Conv2d(256,32,(1, 1)),
															nn.BatchNorm2d(32,1e-05,0.1,False),
															nn.ReLU(),
															nn.Conv2d(32,64,(7, 7),(1, 1),(3, 3)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
													),
												),
												# Level-4 down path (innermost): AvgPool 2x ... UpsamplingNearest 2x.
												nn.Sequential( # Sequential,
													nn.AvgPool2d((2, 2),(2, 2)),
													LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
														nn.Sequential( # Sequential,
															nn.Conv2d(256,64,(1, 1)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
														nn.Sequential( # Sequential,
															nn.Conv2d(256,32,(1, 1)),
															nn.BatchNorm2d(32,1e-05,0.1,False),
															nn.ReLU(),
															nn.Conv2d(32,64,(3, 3),(1, 1),(1, 1)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
														nn.Sequential( # Sequential,
															nn.Conv2d(256,32,(1, 1)),
															nn.BatchNorm2d(32,1e-05,0.1,False),
															nn.ReLU(),
															nn.Conv2d(32,64,(5, 5),(1, 1),(2, 2)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
														nn.Sequential( # Sequential,
															nn.Conv2d(256,32,(1, 1)),
															nn.BatchNorm2d(32,1e-05,0.1,False),
															nn.ReLU(),
															nn.Conv2d(32,64,(7, 7),(1, 1),(3, 3)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
													),
													LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
														nn.Sequential( # Sequential,
															nn.Conv2d(256,64,(1, 1)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
														nn.Sequential( # Sequential,
															nn.Conv2d(256,32,(1, 1)),
															nn.BatchNorm2d(32,1e-05,0.1,False),
															nn.ReLU(),
															nn.Conv2d(32,64,(3, 3),(1, 1),(1, 1)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
														nn.Sequential( # Sequential,
															nn.Conv2d(256,32,(1, 1)),
															nn.BatchNorm2d(32,1e-05,0.1,False),
															nn.ReLU(),
															nn.Conv2d(32,64,(5, 5),(1, 1),(2, 2)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
														nn.Sequential( # Sequential,
															nn.Conv2d(256,32,(1, 1)),
															nn.BatchNorm2d(32,1e-05,0.1,False),
															nn.ReLU(),
															nn.Conv2d(32,64,(7, 7),(1, 1),(3, 3)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
													),
													LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
														nn.Sequential( # Sequential,
															nn.Conv2d(256,64,(1, 1)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
														nn.Sequential( # Sequential,
															nn.Conv2d(256,32,(1, 1)),
															nn.BatchNorm2d(32,1e-05,0.1,False),
															nn.ReLU(),
															nn.Conv2d(32,64,(3, 3),(1, 1),(1, 1)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
														nn.Sequential( # Sequential,
															nn.Conv2d(256,32,(1, 1)),
															nn.BatchNorm2d(32,1e-05,0.1,False),
															nn.ReLU(),
															nn.Conv2d(32,64,(5, 5),(1, 1),(2, 2)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
														nn.Sequential( # Sequential,
															nn.Conv2d(256,32,(1, 1)),
															nn.BatchNorm2d(32,1e-05,0.1,False),
															nn.ReLU(),
															nn.Conv2d(32,64,(7, 7),(1, 1),(3, 3)),
															nn.BatchNorm2d(64,1e-05,0.1,False),
															nn.ReLU(),
														),
													),
													nn.UpsamplingNearest2d(scale_factor=2),
												),
											),
											# Sum level-4 skip and down paths.
											LambdaReduce(lambda x,y: x+y), # CAddTable,
										),
										LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
											nn.Sequential( # Sequential,
												nn.Conv2d(256,64,(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,32,(1, 1)),
												nn.BatchNorm2d(32,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(32,64,(3, 3),(1, 1),(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,32,(1, 1)),
												nn.BatchNorm2d(32,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(32,64,(5, 5),(1, 1),(2, 2)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,32,(1, 1)),
												nn.BatchNorm2d(32,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(32,64,(7, 7),(1, 1),(3, 3)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
										),
										LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
											nn.Sequential( # Sequential,
												nn.Conv2d(256,64,(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,64,(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(64,64,(3, 3),(1, 1),(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,64,(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(64,64,(7, 7),(1, 1),(3, 3)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
											nn.Sequential( # Sequential,
												nn.Conv2d(256,64,(1, 1)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
												nn.Conv2d(64,64,(11, 11),(1, 1),(5, 5)),
												nn.BatchNorm2d(64,1e-05,0.1,False),
												nn.ReLU(),
											),
										),
										nn.UpsamplingNearest2d(scale_factor=2),
									),
								),
								# Sum level-3 skip and down paths.
								LambdaReduce(lambda x,y: x+y), # CAddTable,
							),
							LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
								nn.Sequential( # Sequential,
									nn.Conv2d(256,64,(1, 1)),
									nn.BatchNorm2d(64,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(256,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(32,64,(3, 3),(1, 1),(1, 1)),
									nn.BatchNorm2d(64,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(256,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(32,64,(5, 5),(1, 1),(2, 2)),
									nn.BatchNorm2d(64,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(256,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(32,64,(7, 7),(1, 1),(3, 3)),
									nn.BatchNorm2d(64,1e-05,0.1,False),
									nn.ReLU(),
								),
							),
							# 256 -> 4 x 32 = 128 channels, matching the level-2 skip path.
							LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
								nn.Sequential( # Sequential,
									nn.Conv2d(256,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(256,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(32,32,(3, 3),(1, 1),(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(256,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(32,32,(5, 5),(1, 1),(2, 2)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(256,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(32,32,(7, 7),(1, 1),(3, 3)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
								),
							),
							nn.UpsamplingNearest2d(scale_factor=2),
						),
						# Level-2 skip path (no pooling).
						nn.Sequential( # Sequential,
							LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
								nn.Sequential( # Sequential,
									nn.Conv2d(128,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(128,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(32,32,(3, 3),(1, 1),(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(128,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(32,32,(5, 5),(1, 1),(2, 2)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(128,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(32,32,(7, 7),(1, 1),(3, 3)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
								),
							),
							LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
								nn.Sequential( # Sequential,
									nn.Conv2d(128,32,(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(128,64,(1, 1)),
									nn.BatchNorm2d(64,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(64,32,(3, 3),(1, 1),(1, 1)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(128,64,(1, 1)),
									nn.BatchNorm2d(64,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(64,32,(7, 7),(1, 1),(3, 3)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
								),
								nn.Sequential( # Sequential,
									nn.Conv2d(128,64,(1, 1)),
									nn.BatchNorm2d(64,1e-05,0.1,False),
									nn.ReLU(),
									nn.Conv2d(64,32,(11, 11),(1, 1),(5, 5)),
									nn.BatchNorm2d(32,1e-05,0.1,False),
									nn.ReLU(),
								),
							),
						),
					),
					# Sum level-2 down and skip paths.
					LambdaReduce(lambda x,y: x+y), # CAddTable,
				),
				LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
					nn.Sequential( # Sequential,
						nn.Conv2d(128,32,(1, 1)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
					),
					nn.Sequential( # Sequential,
						nn.Conv2d(128,64,(1, 1)),
						nn.BatchNorm2d(64,1e-05,0.1,False),
						nn.ReLU(),
						nn.Conv2d(64,32,(3, 3),(1, 1),(1, 1)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
					),
					nn.Sequential( # Sequential,
						nn.Conv2d(128,64,(1, 1)),
						nn.BatchNorm2d(64,1e-05,0.1,False),
						nn.ReLU(),
						nn.Conv2d(64,32,(5, 5),(1, 1),(2, 2)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
					),
					nn.Sequential( # Sequential,
						nn.Conv2d(128,64,(1, 1)),
						nn.BatchNorm2d(64,1e-05,0.1,False),
						nn.ReLU(),
						nn.Conv2d(64,32,(7, 7),(1, 1),(3, 3)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
					),
				),
				# 128 -> 4 x 16 = 64 channels, matching the level-1 skip path.
				LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
					nn.Sequential( # Sequential,
						nn.Conv2d(128,16,(1, 1)),
						nn.BatchNorm2d(16,1e-05,0.1,False),
						nn.ReLU(),
					),
					nn.Sequential( # Sequential,
						nn.Conv2d(128,32,(1, 1)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
						nn.Conv2d(32,16,(3, 3),(1, 1),(1, 1)),
						nn.BatchNorm2d(16,1e-05,0.1,False),
						nn.ReLU(),
					),
					nn.Sequential( # Sequential,
						nn.Conv2d(128,32,(1, 1)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
						nn.Conv2d(32,16,(7, 7),(1, 1),(3, 3)),
						nn.BatchNorm2d(16,1e-05,0.1,False),
						nn.ReLU(),
					),
					nn.Sequential( # Sequential,
						nn.Conv2d(128,32,(1, 1)),
						nn.BatchNorm2d(32,1e-05,0.1,False),
						nn.ReLU(),
						nn.Conv2d(32,16,(11, 11),(1, 1),(5, 5)),
						nn.BatchNorm2d(16,1e-05,0.1,False),
						nn.ReLU(),
					),
				),
				nn.UpsamplingNearest2d(scale_factor=2),
			),
			# Level-1 skip path (no pooling): 128 -> 4 x 16 = 64 channels.
			nn.Sequential( # Sequential,
				LambdaReduce(lambda x,y,dim=1: torch.cat((x,y),dim), # Concat,
					nn.Sequential( # Sequential,
						nn.Conv2d(128,16,(1, 1)),
						nn.BatchNorm2d(16,1e-05,0.1,False),
						nn.ReLU(),
					),
					nn.Sequential( # Sequential,
						nn.Conv2d(128,64,(1, 1)),
						nn.BatchNorm2d(64,1e-05,0.1,False),
						nn.ReLU(),
						nn.Conv2d(64,16,(3, 3),(1, 1),(1, 1)),
						nn.BatchNorm2d(16,1e-05,0.1,False),
						nn.ReLU(),
					),
					nn.Sequential( # Sequential,
						nn.Conv2d(128,64,(1, 1)),
						nn.BatchNorm2d(64,1e-05,0.1,False),
						nn.ReLU(),
						nn.Conv2d(64,16,(7, 7),(1, 1),(3, 3)),
						nn.BatchNorm2d(16,1e-05,0.1,False),
						nn.ReLU(),
					),
					nn.Sequential( # Sequential,
						nn.Conv2d(128,64,(1, 1)),
						nn.BatchNorm2d(64,1e-05,0.1,False),
						nn.ReLU(),
						nn.Conv2d(64,16,(11, 11),(1, 1),(5, 5)),
						nn.BatchNorm2d(16,1e-05,0.1,False),
						nn.ReLU(),
					),
				),
			),
		),
		# Sum level-1 down and skip paths (both 64 channels).
		LambdaReduce(lambda x,y: x+y), # CAddTable,
	),
	# Head: 3x3 conv, 64 -> 1 output channel.
	nn.Conv2d(64,1,(3, 3),(1, 1),(1, 1)),
)

In [0]:
""" .HG_model """

import numpy as np
import torch
import os
from torch.autograd import Variable
# from .base_model import BaseModel
import sys
# import pytorch_DIW_scratch

class HGModel(BaseModel):
    """Evaluation wrapper around the hourglass depth network ``pytorch_DIW_scratch``.

    On construction it loads the 'best_generalization' weights through
    ``BaseModel.load_network`` and keeps the network on the GPU as ``self.netG``.
    It provides SDR (ordinal point-pair) and scale-invariant RMSE evaluation.
    """

    def name(self):
        return 'HGModel'

    def __init__(self, opt):
        BaseModel.initialize(self, opt)

        print("===========================================LOADING Hourglass NETWORK====================================================")
        model = pytorch_DIW_scratch
        # Wrapped in DataParallel *before* loading, so the checkpoint's keys are
        # expected to carry DataParallel's 'module.' prefix. Use device_ids=[0, 1]
        # for two GPUs.
        model = torch.nn.parallel.DataParallel(model, device_ids=[0])
        # 'best_vanila' is the alternative checkpoint name mentioned by the original author.
        model_parameters = self.load_network(model, 'G', 'best_generalization')
        model.load_state_dict(model_parameters)
        self.netG = model.cuda()

    def batch_classify(self, z_A_arr, z_B_arr, ground_truth):
        """Label point pairs by predicted depth ratio and count disagreements.

        Args:
            z_A_arr, z_B_arr: 1-D tensors of predicted depths at points A and B.
            ground_truth: 1-D CPU tensor of ordinal labels; 0 means "equal depth",
                nonzero means ordered (presumably +1/-1 -- TODO confirm sign
                convention against the data loader). It is NOT modified.

        Returns:
            (error_list, count_list), each ordered [equal, inequal, total].
        """
        threshold = 1.1
        depth_ratio = torch.div(z_A_arr, z_B_arr).cpu()

        # Ratios inside [1/threshold, threshold] keep label 0 ("equal").
        estimated_labels = torch.zeros(depth_ratio.size(0))
        estimated_labels[depth_ratio > threshold] = 1
        estimated_labels[depth_ratio < (1 / threshold)] = -1

        # 1 wherever the estimated label disagrees with the ground truth.
        diff = estimated_labels - ground_truth
        diff[diff != 0] = 1

        inequal_mask = ground_truth != 0

        # errors
        inequal_error_count = torch.sum(diff[inequal_mask])
        error_count = torch.sum(diff)
        equal_error_count = error_count - inequal_error_count

        # totals -- counted from the mask instead of overwriting ground_truth,
        # which aliases the caller's targets['sdr_gt'] data.
        total_count = depth_ratio.size(0)
        inequal_count_total = torch.sum(inequal_mask.type_as(ground_truth))
        equal_total_count = total_count - inequal_count_total

        error_list = [equal_error_count, inequal_error_count, error_count]
        count_list = [equal_total_count, inequal_count_total, total_count]

        return error_list, count_list

    def computeSDR(self, prediction_d, targets):
        """Accumulate SDR error and sample counts over every image in a batch.

        Args:
            prediction_d: batch of network outputs; each prediction_d[i, :, :] is
                exponentiated (i.e. treated as log-depth) and squeezed to a 2-D map.
            targets: dict with 'has_SfM_feature' (per-image flag; images without
                it are skipped), 'sdr_xA'/'sdr_xB' (column indices),
                'sdr_yA'/'sdr_yB' (row indices) and 'sdr_gt' (ordinal labels).

        Returns:
            (total_error, total_samples), each ordered [equal, inequal, total].
        """
        total_error = [0, 0, 0]
        total_samples = [0, 0, 0]

        for i in range(prediction_d.size(0)):
            if targets['has_SfM_feature'][i] == False:
                continue

            x_A_arr = targets["sdr_xA"][i].squeeze(0)
            x_B_arr = targets["sdr_xB"][i].squeeze(0)
            y_A_arr = targets["sdr_yA"][i].squeeze(0)
            y_B_arr = targets["sdr_yB"][i].squeeze(0)

            predict_depth = torch.exp(prediction_d[i, :, :]).squeeze(0)
            ground_truth = targets["sdr_gt"][i]

            # depth[y[k], x[k]] for every pair k. Same values as the Lua port
            # predict_depth:index(2, x):gather(1, y:view(1, -1)), but without
            # building the H x K intermediate tensor.
            z_A_arr = predict_depth[y_A_arr.view(-1).cuda(), x_A_arr.view(-1).cuda()]
            z_B_arr = predict_depth[y_B_arr.view(-1).cuda(), x_B_arr.view(-1).cuda()]

            error_list, count_list = self.batch_classify(z_A_arr, z_B_arr, ground_truth)

            for j in range(3):
                total_error[j] += error_list[j]
                total_samples[j] += count_list[j]

        return total_error, total_samples

    def evaluate_SDR(self, input_, targets):
        """Run the network on ``input_`` and return SDR (errors, samples) counts."""
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            input_images = Variable(input_.cuda())
            prediction_d = self.netG.forward(input_images)

        total_error, total_samples = self.computeSDR(prediction_d.data, targets)

        return total_error, total_samples

    def rmse_Loss(self, log_prediction_d, mask, log_gt):
        """Scale-invariant RMSE between log-depth maps over masked pixels.

        With d = (log_prediction_d - log_gt) * mask and N = sum(mask), returns
        sqrt(sum(d^2)/N - (sum(d)/N)^2).
        """
        N = torch.sum(mask)
        log_d_diff = torch.mul(log_prediction_d - log_gt, mask)
        s1 = torch.sum(torch.pow(log_d_diff, 2)) / N
        s2 = torch.pow(torch.sum(log_d_diff), 2) / (N * N)

        # For a 0/1 mask, s1 - s2 is a variance and therefore >= 0. Float
        # rounding can push it slightly negative, so clamp before sqrt to avoid
        # NaN. A real NaN (e.g. an empty mask) still propagates.
        data_loss = torch.clamp(s1 - s2, min=0)

        return torch.sqrt(data_loss)

    def evaluate_RMSE(self, input_images, prediction_d, targets):
        """Sum the per-image scale-invariant RMSE over a batch.

        Args:
            input_images: unused; kept for interface compatibility.
            prediction_d: predicted log-depth, sliced per image as prediction_d[i, :, :].
            targets: dict with 'mask_0' (pixel mask) and 'gt_0' (ground-truth
                depth; its log is taken here).

        Returns:
            (summed loss as element 0 of a 1-element CUDA tensor, image count).
        """
        count = 0
        total_loss = torch.zeros(1).cuda()
        mask_0 = Variable(targets['mask_0'].cuda(), requires_grad=False)
        # NOTE(review): if gt_0 holds 0 at masked-out pixels, log gives -inf and
        # (-inf) * 0 = NaN inside rmse_Loss -- confirm the loader never emits zeros.
        d_gt_0 = torch.log(Variable(targets['gt_0'].cuda(), requires_grad=False))

        for i in range(mask_0.size(0)):
            total_loss += self.rmse_Loss(prediction_d[i, :, :], mask_0[i, :, :], d_gt_0[i, :, :])
            count += 1

        return total_loss.data[0], count

    def evaluate_sc_inv(self, input_, targets):
        """Run the network on ``input_`` and return (summed sc-inv RMSE, image count)."""
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            input_images = Variable(input_.cuda())
            prediction_d = self.netG.forward(input_images)
            rmse_loss, count = self.evaluate_RMSE(input_images, prediction_d, targets)

        return rmse_loss, count

    def switch_to_train(self):
        self.netG.train()

    def switch_to_eval(self):
        self.netG.eval()



In [0]:
""" models.models """


def create_model(opt):
    """Build an ``HGModel`` from ``opt``, announce it on stdout, and return it."""
    created = HGModel(opt)
    print("model [%s] was created" % (created.name()))
    return created


In [0]:
""" data.data_loader """

def CreateDataLoader(_root, _list_dir, _input_height, _input_width, is_flip = True, shuffle =  True):
    """Construct and return an ``AlignedDataLoader`` with the given settings.

    The import sits inside the function, so ``data.aligned_data_loader`` is
    only needed when a loader is actually built.
    """
    from data.aligned_data_loader import AlignedDataLoader
    return AlignedDataLoader(_root, _list_dir, _input_height, _input_width, is_flip, shuffle)

In [12]:
import torch
import sys
from torch.autograd import Variable
import numpy as np
# from options.train_options import TrainOptions
# opt = TrainOptions().parse()  # set CUDA_VISIBLE_DEVICES before import torch
# from options.test_options import TestOptions
# opt = TestOptions().parse()  # set CUDA_VISIBLE_DEVICES before import torch
import easydict
# Hand-built stand-in for the TestOptions CLI parser; presumably covers every
# field BaseModel.initialize reads -- TODO confirm.
opt = easydict.EasyDict(dict(
    gpu_ids='0',
    isTrain=False,
    checkpoints_dir='./checkpoints/',
    name='test_local',
))

from skimage import io
from skimage.transform import resize

# Demo input image and the resolution it is resized to before inference.
img_path = 'demo.jpg'
input_height, input_width = 384, 512

# create_model is defined in an earlier cell of this notebook.
model = create_model(opt)

def test_simple(model, image_path=None, output_path='demo.png'):
    """Predict depth for one image and save a normalised inverse-depth map.

    Args:
        model: object with ``switch_to_eval()`` and a ``netG`` network, e.g. the
            result of ``create_model(opt)``.
        image_path: image to read; defaults to the notebook-level ``img_path``
            (looked up at call time).
        output_path: file the visualisation is written to.

    Returns:
        The saved inverse-depth map as a numpy array scaled so its max is 1.
    """
    if image_path is None:
        image_path = img_path

    print("============================= TEST ============================")
    model.switch_to_eval()

    img = np.float32(io.imread(image_path)) / 255.0
    # The HWC -> CHW transpose below needs a channel axis. Expand grayscale
    # images and drop an alpha channel (the network presumably expects RGB).
    if img.ndim == 2:
        img = np.stack([img, img, img], axis=-1)
    elif img.shape[-1] == 4:
        img = img[:, :, :3]
    img = resize(img, (input_height, input_width), order=1)
    input_img = torch.from_numpy(np.transpose(img, (2, 0, 1))).contiguous().float().unsqueeze(0)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        input_images = Variable(input_img.cuda())
        pred_log_depth = torch.squeeze(model.netG.forward(input_images))

    pred_depth = torch.exp(pred_log_depth)

    # Visualise inverse depth, so no sky segmentation is needed. Visualising
    # depth directly would require masking the sky first, because the CNN's
    # depth for sky pixels is arbitrary.
    pred_inv_depth = (1 / pred_depth).data.cpu().numpy()
    # A percentile instead of the max may give a better visualisation.
    pred_inv_depth = pred_inv_depth / np.amax(pred_inv_depth)

    io.imsave(output_path, pred_inv_depth)
    return pred_inv_depth


test_simple(model)
print("We are done")


./checkpoints/test_local/best_generalization_net_G.pth
model [HGModel] was created


  warn("The default mode, 'constant', will be changed to 'reflect' in "
  warn("Anti-aliasing will be enabled by default in skimage 0.15 to "


We are done


  .format(dtypeobj_in, dtypeobj_out))
