In [1]:
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
from models import MobileNetSkipConcat
import numpy as np

In [2]:
from torch.nn import functional as F
import tqdm
import torchvision
import torchvision.utils as vutils

In [3]:
# Device used by every `.to(device)` call in this notebook.
# GPU 3 is preferred, but on a machine with fewer than four GPUs 'cuda:3' is an
# invalid device and later `.to(device)` calls would fail, so fall back to GPU 0.
# Without CUDA everything runs on the CPU.
preferred_gpu = 3
if torch.cuda.is_available():
    gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
    device = 'cuda:{}'.format(gpu_index)
else:
    device = 'cpu'

### Test Sample

In [4]:
# Random float32 tensor of shape (1, 3, 255, 255), moved to `device`.
# Presumably a dummy batch for trying the model by hand; it is not referenced
# by the other cells visible in this notebook.
# NOTE(review): the name `input` shadows the Python builtin of the same name.
input =  torch.from_numpy(np.random.rand(1,3,255,255).astype(np.float32)).to(device)


In [5]:
from torch.utils.data import Dataset ,DataLoader

### Load the training dataframes with pandas


In [6]:
import pandas as pd
import matplotlib.pylab as plt

In [7]:
# One table per phone model; each is read from a pickle in the working directory.
# NOTE(review): unpickling executes code embedded in the file, so only load
# pickles that come from a trusted source.
df_blackberry, df_iphone, df_sony = (
    pd.read_pickle(pickle_name)
    for pickle_name in ('blackberry.pkl', 'iphone.pkl', 'sony.pkl')
)

# Stack the three tables into a single frame with a fresh 0..n-1 index.
df = pd.concat(
    [df_blackberry, df_iphone, df_sony],
    sort=True,
    ignore_index=True,
)

# Separate table that is handed to ImageFulleDataset, which reads its 'dir' column.
df_full = pd.read_pickle('image_full.pkl')

In [8]:
# Preview the first rows of the combined frame (columns 'high' and 'low').
df.head()

Unnamed: 0,high,low
0,/home/jimmyyoung/dataset/dped/blackberry/train...,/home/jimmyyoung/dataset/dped/blackberry/train...
1,/home/jimmyyoung/dataset/dped/blackberry/train...,/home/jimmyyoung/dataset/dped/blackberry/train...
2,/home/jimmyyoung/dataset/dped/blackberry/train...,/home/jimmyyoung/dataset/dped/blackberry/train...
3,/home/jimmyyoung/dataset/dped/blackberry/train...,/home/jimmyyoung/dataset/dped/blackberry/train...
4,/home/jimmyyoung/dataset/dped/blackberry/train...,/home/jimmyyoung/dataset/dped/blackberry/train...


In [9]:
from PIL import Image
class ImageEnhanceDataset(Dataset):
    """Paired-image dataset backed by the 'high' and 'low' columns of a dataframe.

    Every item is a dict with the keys 'high' and 'low'. The image named in the
    'high' column is passed through the notebook-level ``ToTensor_center``
    pipeline and the one named in the 'low' column through ``ToTensor_Dark``.
    Both pipelines are looked up by name when an item is fetched, so they only
    need to exist before the first item access, not before construction.

    Args:
        df: dataframe whose 'high' and 'low' columns hold image file paths.
        transform: stored as ``self.transform``; ``__getitem__`` does not apply it.
    """

    def __init__(self, df, transform=None):
        self.transform = transform
        self.df = df

    def __len__(self):
        """Number of rows in the backing dataframe."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``{'high': ..., 'low': ...}`` for the row at position ``index``."""

        def load_rgb(column):
            # Positional row lookup, then decode the file as a 3-channel RGB image.
            return Image.open(self.df.iloc[index][column]).convert('RGB')

        high_rgb = load_rgb('high')
        low_rgb = load_rgb('low')

        return {
            'high': ToTensor_center(high_rgb),
            'low': ToTensor_Dark(low_rgb),
        }
    
class ImageFulleDataset(Dataset):
    """Dataset that serves a randomly chosen image from the 'dir' column of a dataframe.

    Every item is a dict with the keys 'label' and 'input', both derived from
    the same decoded image: 'label' via the notebook-level ``ToTensor_Full``
    pipeline and 'input' via ``ToTensor_Full_input``. Both pipelines are looked
    up by name when an item is fetched.

    Args:
        df: dataframe whose 'dir' column holds image file paths.
        transform: stored as ``self.transform``; ``__getitem__`` does not apply it.
    """

    def __init__(self, df, transform=None):
        self.transform = transform
        self.df = df

    def __len__(self):
        """Number of rows in the backing dataframe."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``{'label': ..., 'input': ...}`` for a random row.

        ``index`` is not used: the row is drawn uniformly with
        ``np.random.randint`` on every call, so any index value is accepted.
        """
        row_number = np.random.randint(0, len(self.df))
        image = Image.open(self.df.iloc[row_number]['dir']).convert('RGB')

        label_tensor = ToTensor_Full(image)
        input_tensor = ToTensor_Full_input(image)
        return {'label': label_tensor, 'input': input_tensor}

### Dataset

In [10]:
# Paired 'high'/'low' images come from `df`; randomly drawn full images come
# from `df_full`. Neither dataset touches an image file until an item is fetched.
dataset = ImageEnhanceDataset(df)
dataset_full = ImageFulleDataset(df_full)

### Transform  and Loss

In [11]:
# Short handles for the two torchvision namespaces used by every pipeline below.
_tv = torchvision.transforms
_tvf = torchvision.transforms.functional


def _random_brightness(img):
    """Rescale brightness by a factor drawn uniformly from [0.5, 1)."""
    return _tvf.adjust_brightness(img=img, brightness_factor=np.random.uniform(0.5, 1))


def _random_gamma_patch(img):
    """Apply gamma correction with gain 1 and gamma drawn uniformly from [1, 1.5)."""
    return _tvf.adjust_gamma(img=img, gain=1, gamma=np.random.uniform(1, 1.5))


def _random_contrast(img):
    """Rescale contrast by a factor drawn uniformly from [0.2, 0.5)."""
    return _tvf.adjust_contrast(img=img, contrast_factor=np.random.uniform(0.2, 0.5))


def _random_gamma_full(img):
    """Apply gamma correction with gain 1 and gamma drawn uniformly from [0.7, 2.0)."""
    return _tvf.adjust_gamma(img=img, gain=1, gamma=np.random.uniform(0.7, 2.0))


# Degraded 96x96 centre patch: colour jitter, then random brightness, gamma and
# contrast changes, applied in that order before conversion to a tensor.
ToTensor_Dark = _tv.Compose([
    _tv.CenterCrop(96),
    _tv.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01),
    _tv.Lambda(_random_brightness),
    _tv.Lambda(_random_gamma_patch),
    _tv.Lambda(_random_contrast),
    _tv.ToTensor(),
])

# Unmodified 96x96 centre patch.
ToTensor_center = _tv.Compose([
    _tv.CenterCrop(96),
    _tv.ToTensor(),
])

# 800x1184 centre crop, used for the fixed preview image in the training loop.
ToTensor_Test = _tv.Compose([
    _tv.CenterCrop((800, 1184)),
    _tv.ToTensor(),
])

# Plain PIL -> tensor conversion with no cropping.
ToTensor = _tv.Compose([_tv.ToTensor()])

# Full-frame network input: 512x1024 centre crop, colour jitter and a random
# gamma change, then downscaled to 128x256.
ToTensor_Full_input = _tv.Compose([
    _tv.CenterCrop((512, 1024)),
    _tv.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01),
    _tv.Lambda(_random_gamma_full),
    _tv.Resize((128, 256)),
    _tv.ToTensor(),
])

# Full-frame target: same crop and resize as the input pipeline, no colour changes.
ToTensor_Full = _tv.Compose([
    _tv.CenterCrop((512, 1024)),
    _tv.Resize((128, 256)),
    _tv.ToTensor(),
])

In [12]:
# Instantiate the network from the local `models` module with pretrained=False.
# NOTE(review): the meaning of the first argument (3) is defined in
# `models.MobileNetSkipConcat`, which is not part of this notebook.
model = MobileNetSkipConcat(3,pretrained=False)
# Optional warm start from a previously saved checkpoint (currently disabled):
# model.load_state_dict(torch.load('unet_fast.pt'))
# Move the parameters to `device`; as the last expression of the cell this also
# displays the module structure.
model.to(device)

MobileNetSkipConcat(
  (conv0): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace)
  )
  (conv1): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace)
    (3): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU6(inplace)
  )
  (conv2): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace)
    (3): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (4): BatchNorm2d(128,

In [22]:
# Feature extractor for the VGG loss term in the training loop: only the
# convolutional `features` part of a pretrained VGG-11 is kept.
vgg =torchvision.models.vgg11(pretrained=True)
vgg = vgg.features.to(device)

# The optimizer in this notebook is built from `model.parameters()` only, so
# these weights are never updated. Freezing them stops every `loss.backward()`
# from computing and accumulating gradients for the VGG weights, while
# gradients still flow through the activations back to `model`.
vgg.eval()
for _vgg_param in vgg.parameters():
    _vgg_param.requires_grad_(False)

### Optimizer and Loss

In [13]:

# Element-wise losses averaged over all elements (reduction='mean').
# NOTE(review): L1loss is not referenced by the training loop visible in this notebook.
L1loss = torch.nn.L1Loss(reduction='mean')
MSEloss = torch.nn.MSELoss(reduction='mean')

# Adam with the library-default hyperparameters, over the parameters of `model` only.
optimizer = torch.optim.Adam(model.parameters())

class ST_loss(nn.Module):
    """Fixed Sobel edge-magnitude extractor for 3-channel image batches.

    Despite the name, ``forward`` does not return a scalar loss. It returns a
    single-channel map of half the Sobel gradient magnitude, summed over the
    three input channels; the caller compares two such maps with a loss.

    The filters are constants:
    * the convolutions have no bias (a default ``nn.Conv2d`` bias is randomly
      initialised, which would make the map differ from run to run);
    * the weights are frozen, so ``backward()`` does not compute gradients for them.

    Args:
        Lambda: stored as ``self.Lambda``; not used by ``forward``.
        eps: small constant added under the square root. The derivative of
            ``sqrt`` is unbounded at zero, so without it perfectly flat regions
            would produce NaN gradients.
    """

    def __init__(self, Lambda=0.3, eps=1e-12):
        super(ST_loss, self).__init__()
        sobel_x = [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]
        sobel_y = [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]]
        self.conv_x = self._fixed_conv(sobel_x)
        self.conv_y = self._fixed_conv(sobel_y)
        self.Lambda = Lambda
        self.eps = eps

    @staticmethod
    def _fixed_conv(kernel_3x3):
        """Build a bias-free, frozen 3->1 channel conv applying ``kernel_3x3`` to each channel."""
        conv = nn.Conv2d(3, 1, kernel_size=3, stride=1, padding=1, bias=False)
        # Same 3x3 kernel for each of the 3 input channels -> weight shape (1, 3, 3, 3).
        weight = torch.tensor([kernel_3x3] * 3, dtype=torch.float32).unsqueeze(0)
        conv.weight = nn.Parameter(weight, requires_grad=False)
        return conv

    def forward(self, input):
        """Return the edge-magnitude map of ``input``.

        Args:
            input: float tensor of shape (N, 3, H, W).

        Returns:
            Tensor of shape (N, 1, H, W). The outermost rows and columns are set
            to zero, because the zero padding of the convolutions creates
            artificial edges along the image border.
        """
        grd_x = self.conv_x(input)
        grd_y = self.conv_y(input)
        out = torch.sqrt(grd_x ** 2 + grd_y ** 2 + self.eps) / 2
        out[:, :, 0, :] = 0
        out[:, :, -1, :] = 0
        out[:, :, :, 0] = 0
        out[:, :, :, -1] = 0
        return out
# Single shared instance with the default constructor arguments, moved to
# `device` so it can be applied to the training tensors.
st_loss = ST_loss().to(device)

In [28]:
from tensorboardX import SummaryWriter
# No log directory is passed, so tensorboardX chooses its default location.
# The training loop writes image grids and scalar losses through this writer.
writer = SummaryWriter()


### Dataloader 

In [16]:
def _seed_numpy_in_worker(worker_id):
    """Give each DataLoader worker process its own NumPy random stream.

    The augmentation pipelines in this notebook draw their random factors from
    ``np.random``. PyTorch assigns every worker a distinct torch seed but does
    not reseed NumPy, so worker processes can start from the same NumPy state
    and then apply identical "random" augmentations. Deriving the NumPy seed
    from the per-worker torch seed avoids that.

    Args:
        worker_id: index of the worker, supplied by the DataLoader (unused here,
            because ``torch.initial_seed()`` already differs per worker).
    """
    np.random.seed(torch.initial_seed() % 2 ** 32)


# Shuffled mini-batches of 16 paired patches, loaded by 4 worker processes.
dataloader = DataLoader(dataset, batch_size=16,
                        shuffle=True, num_workers=4,
                        worker_init_fn=_seed_numpy_in_worker)


### Train Loop

In [30]:
num_image = len(df)   # number of training pairs; not used by the loop below
num_batch = 5         # number of full passes over `dataloader`
itercount = 0         # optimisation steps taken so far; used as the TensorBoard step

# The preview image is the same at every logging step, so decode, crop and
# upload it once instead of re-reading '4.jpg' inside the loop.
test_image = ToTensor_Test(Image.open('4.jpg').convert('RGB')).unsqueeze(0).to(device)

for b in range(num_batch):
    for i,data in enumerate(dataloader):

        # Degraded patches are the network input; the untouched patches are the target.
        sample = data['low'].to(device)
        label  = data['high'].to(device)

        # One extra full-frame pair per step, batched to size 1.
        # (`dataset_full` draws a random image regardless of the index passed.)
        full = dataset_full[i]
        full_label = full['label'].unsqueeze(0).to(device)
        full_input = full['input'].unsqueeze(0).to(device)

        out = model(sample)
        out_full = model(full_input)

        # The targets are constants with respect to `model`, so compute their
        # edge map and VGG features without recording an autograd graph.
        with torch.no_grad():
            label_edges = st_loss(label)
            label_features = vgg(label)

        pixle_loss_patch = MSEloss(out,label)
        pixle_loss_full = MSEloss(out_full,full_label)
        loss_st = MSEloss(st_loss(out),label_edges)
        loss_vgg = MSEloss(vgg(out),label_features)
        loss =  0.1*pixle_loss_patch+pixle_loss_full+0.1*loss_st+loss_vgg

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        itercount += 1

        # Every 1000th batch of an epoch (including the first): log image grids,
        # the individual loss terms and the fixed preview image.
        if i%1000==0:
            patch_grid = torch.cat((sample,out.detach(),label),dim=0)
            patch_grid = vutils.make_grid(patch_grid, normalize=True, scale_each=True)
            writer.add_image('Sample : Out : Ground', patch_grid, itercount)

            full_grid = torch.cat((full_input,out_full.detach(),full_label),dim=0)
            full_grid = vutils.make_grid(full_grid, normalize=True, scale_each=True)
            writer.add_image('Full', full_grid, itercount)

            writer.add_scalar('data/pixel loss full', pixle_loss_full.item(),itercount)
            writer.add_scalar('data/pixel loss_patch', pixle_loss_patch.item(),itercount)
            writer.add_scalar('data/loss_st', loss_st.item(),itercount)
            writer.add_scalar('data/loss_vgg', loss_vgg.item(),itercount)

            # Run the preview in inference mode. In training mode this forward
            # pass would update the BatchNorm running statistics with a single
            # large image, and would record an autograd graph that is never used.
            # The previous mode is restored afterwards.
            was_training = model.training
            model.eval()
            with torch.no_grad():
                test_out = model(test_image)
            model.train(was_training)

            test_grid = vutils.make_grid(test_out, normalize=True, scale_each=True)
            writer.add_image('Test', test_grid, itercount)

In [31]:
# Save only the parameters and buffers (state_dict) to 'fast_depth.pt' in the
# working directory; loading them later needs a MobileNetSkipConcat instance to
# call `load_state_dict` on.
torch.save(model.state_dict(),'fast_depth.pt')