In [8]:
import torch
from torchvision import models
import torch.nn.functional as F
from net.unet_model import UNet 
import numpy as np

In [42]:
from torch.nn import functional as F
import tqdm
import torchvision
import torchvision.utils as vutils

In [9]:
# Prefer GPU index 2 (the original hard-coded choice), but only when that
# device actually exists: selecting 'cuda:2' on a host with fewer than three
# visible GPUs makes every later `.to(device)` call fail.  Fall back to the
# first GPU, and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = 'cuda:2' if torch.cuda.device_count() > 2 else 'cuda:0'
else:
    device = 'cpu'

### Test Sample

In [10]:
# Sample tensor: one random float32 array of shape (1, 3, 255, 255) with
# values drawn by np.random.rand, wrapped as a tensor and moved to `device`.
# NOTE: the name `input` shadows the Python builtin of the same name.
_sample_shape = (1, 3, 255, 255)
_sample_array = np.random.rand(*_sample_shape).astype(np.float32)
input = torch.from_numpy(_sample_array).to(device)


In [11]:
from torch.utils.data import Dataset ,DataLoader

### Load the training dataframes with pandas


In [12]:
import pandas as pd
import matplotlib.pylab as plt

In [38]:
# One pickled dataframe per phone model, read in the order
# blackberry -> iphone -> sony.
# NOTE(review): read_pickle can execute arbitrary code while loading; only
# use pickle files from a trusted source.
df_blackberry, df_iphone, df_sony = (
    pd.read_pickle(pickle_path)
    for pickle_path in ('blackberry.pkl', 'iphone.pkl', 'sony.pkl')
)

# Stack the three per-phone frames into one frame with a fresh 0..N-1 index
# and sorted column order.
per_phone_frames = [df_blackberry, df_iphone, df_sony]
df = pd.concat(per_phone_frames, sort=True, ignore_index=True)

# Separate frame of full images; its 'dir' column is what the full-image
# dataset reads.
df_full = pd.read_pickle('image_full.pkl')

In [41]:
# Preview the first rows of the combined frame (columns 'high' and 'low'
# hold image file paths).
df.head()

Unnamed: 0,high,low
0,/home/jimmyyoung/dataset/dped/blackberry/train...,/home/jimmyyoung/dataset/dped/blackberry/train...
1,/home/jimmyyoung/dataset/dped/blackberry/train...,/home/jimmyyoung/dataset/dped/blackberry/train...
2,/home/jimmyyoung/dataset/dped/blackberry/train...,/home/jimmyyoung/dataset/dped/blackberry/train...
3,/home/jimmyyoung/dataset/dped/blackberry/train...,/home/jimmyyoung/dataset/dped/blackberry/train...
4,/home/jimmyyoung/dataset/dped/blackberry/train...,/home/jimmyyoung/dataset/dped/blackberry/train...


In [None]:
from PIL import Image
class ImageEnhanceDataset(Dataset):
    """Paired low/high image samples listed in a dataframe.

    Each row of ``df`` supplies two file paths under the ``'high'`` and
    ``'low'`` columns.  Both files are opened with PIL, converted to RGB and
    sent through the module-level ``ToTensor_center`` (high) and
    ``ToTensor_Dark`` (low) pipelines.

    Args:
        df: dataframe with ``'high'`` and ``'low'`` path columns.
        transform: stored on the instance; ``__getitem__`` does not apply it.
    """

    def __init__(self, df, transform=None):
        self.transform = transform
        self.df = df

    def __len__(self):
        """Return the number of rows (image pairs) in the dataframe."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``{'high': ..., 'low': ...}`` for the row at ``index``."""
        row = self.df.iloc[index]

        target_rgb = Image.open(row['high']).convert('RGB')
        source_rgb = Image.open(row['low']).convert('RGB')

        # Same order as before: the high image is transformed first, then
        # the low image goes through the randomised degradation pipeline.
        target_out = ToTensor_center(target_rgb)
        source_out = ToTensor_Dark(source_rgb)

        return {'high': target_out, 'low': source_out}
    
class ImageFulleDataset(Dataset):
    """Full images drawn at random from a dataframe of file paths.

    ``df`` must have a ``'dir'`` column holding image paths.  Every access
    picks a random row, so the ``index`` passed to ``__getitem__`` has no
    influence on which image is returned.

    Args:
        df: dataframe with a ``'dir'`` path column.
        transform: stored on the instance; ``__getitem__`` does not apply it.
    """

    def __init__(self, df, transform=None):
        self.transform = transform
        self.df = df

    def __len__(self):
        """Return the number of rows in the dataframe."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``{'label': ..., 'input': ...}`` for a randomly chosen row.

        ``index`` is accepted for the Dataset protocol but ignored.
        """
        row_number = np.random.randint(0, len(self.df))
        image_path = self.df.iloc[row_number]['dir']
        picture = Image.open(image_path).convert('RGB')

        # The label pipeline runs before the randomised input pipeline,
        # both on the same source picture.
        clean = ToTensor_Full(picture)
        degraded = ToTensor_Full_input(picture)
        return {'label': clean, 'input': degraded}

### Dataset

In [None]:
# Paired low/high dataset over the combined per-phone frame, and the
# random full-image dataset over the frame loaded from 'image_full.pkl'.
dataset = ImageEnhanceDataset(df=df)
dataset_full = ImageFulleDataset(df=df_full)

### Transforms

In [None]:
# Short aliases for the two torchvision namespaces used below.
_tfm = torchvision.transforms
_tfm_fn = torchvision.transforms.functional


def _random_brightness(img):
    """Scale brightness by a factor drawn from np.random.uniform(0.5, 1)."""
    return _tfm_fn.adjust_brightness(img=img, brightness_factor=np.random.uniform(0.5, 1))


def _random_contrast(img):
    """Scale contrast by a factor drawn from np.random.uniform(0.2, 0.8)."""
    return _tfm_fn.adjust_contrast(img=img, contrast_factor=np.random.uniform(0.2, 0.8))


def _random_gamma(lo, hi):
    """Build a callable applying adjust_gamma (gain 1) with gamma drawn from np.random.uniform(lo, hi)."""
    def _apply(img):
        return _tfm_fn.adjust_gamma(img=img, gain=1, gamma=np.random.uniform(lo, hi))
    return _apply


# Low-quality patch pipeline: 96-pixel centre crop, colour jitter, then random
# brightness, gamma and contrast changes (in that order), then ToTensor.
ToTensor_Dark = _tfm.Compose([
    _tfm.CenterCrop(96),
    _tfm.ColorJitter(brightness=0.3, contrast=0.2, saturation=0.3, hue=0.05),
    _tfm.Lambda(_random_brightness),
    _tfm.Lambda(_random_gamma(1, 2)),
    _tfm.Lambda(_random_contrast),
    _tfm.ToTensor(),
])

# High-quality patch pipeline: same 96-pixel centre crop, no augmentation.
ToTensor_center = _tfm.Compose([
    _tfm.CenterCrop(96),
    _tfm.ToTensor(),
])

# Preview-image pipeline: centre crop to (800, 1184), then ToTensor.
ToTensor_Test = _tfm.Compose([
    _tfm.CenterCrop((800, 1184)),
    _tfm.ToTensor(),
])

# Plain conversion with no cropping.
ToTensor = _tfm.Compose([_tfm.ToTensor()])

# Full-image input pipeline: centre crop to (512, 1024), colour jitter,
# random gamma in [0.7, 2.5), resize to (128, 256), then ToTensor.
ToTensor_Full_input = _tfm.Compose([
    _tfm.CenterCrop((512, 1024)),
    _tfm.ColorJitter(brightness=0.5, contrast=0.2, saturation=0.3, hue=0.05),
    _tfm.Lambda(_random_gamma(0.7, 2.5)),
    _tfm.Resize((128, 256)),
    _tfm.ToTensor(),
])

# Full-image label pipeline: same crop and resize, no augmentation.
ToTensor_Full = _tfm.Compose([
    _tfm.CenterCrop((512, 1024)),
    _tfm.Resize((128, 256)),
    _tfm.ToTensor(),
])

### Optimizer and Loss

In [None]:
# Build the network first and move it to the target device.  The optimizer
# below is constructed from `model.parameters()`, so `model` must already
# exist; with the previous cell order (optimizer before model) a fresh
# Restart & Run All stopped here with a NameError.
model = UNet(n_channels=3,n_classes=3)
model.to(device)

# Mean-reduced L1 and MSE criteria.
# NOTE: despite its name, `MSEloss_sum` is created with reduction='mean'.
L1loss = torch.nn.L1Loss(reduction='mean')
MSEloss_sum = torch.nn.MSELoss(reduction='mean')

# Adam over all model parameters, using the library's default settings.
optimizer = torch.optim.Adam(model.parameters())

In [None]:
from tensorboardX import SummaryWriter
# Event writer that the training loop uses for its image and scalar logging.
# Created with no arguments, so the log location is the library default.
writer = SummaryWriter()


### Dataloader 

In [None]:
def _seed_worker(worker_id):
    """Give each DataLoader worker process its own NumPy random state.

    The low-image pipeline draws its brightness, gamma and contrast factors
    from ``np.random``.  PyTorch assigns every worker a distinct torch seed,
    but (at least on older releases) NumPy's global state is copied unchanged
    into each worker, so all workers would produce the same sequence of
    augmentation factors.  Deriving the NumPy seed from the worker's torch
    seed removes that duplication.

    Args:
        worker_id: index of the worker process (unused; the per-worker torch
            seed already differs between workers).
    """
    np.random.seed(torch.initial_seed() % 2**32)


# Shuffled batches of 16 pairs, loaded by 4 worker processes.
dataloader = DataLoader(dataset, batch_size=16,
                        shuffle=True, num_workers=4,
                        worker_init_fn=_seed_worker)


### Train Loop

In [None]:
num_image = len(df)
num_batch = 10   # number of complete passes over `dataloader`
itercount = 0    # global step counter used as the logging x-axis

# The preview image never changes, so read and preprocess it once instead of
# re-opening the file at every logging step.
with Image.open('4.jpg') as preview_file:
    preview_input = ToTensor_Test(preview_file.convert('RGB')).unsqueeze(0).to(device)

for b in range(num_batch):
    for i,data in enumerate(dataloader):

        # Patch batch: degraded input and its clean target.
        sample = data['low'].to(device)
        label  = data['high'].to(device)

        # One full-image pair per step, with a batch dimension added.
        full = dataset_full[i]
        full_label = full['label'].unsqueeze(0).to(device)
        full_input = full['input'].unsqueeze(0).to(device)

        out = model(sample)
        out_full = model(full_input)

        pixle_loss_patch = L1loss(out,label)
        pixle_loss_full = L1loss(out_full,full_label)

        # Total loss is the patch loss plus the full-image loss.  The
        # original line added an undefined name (`pixle_loss`), which raised
        # NameError on the first iteration.
        loss = pixle_loss_patch + pixle_loss_full

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        itercount = itercount+1

        if i%100==0:
            # Logging only: no gradients are needed for anything below, so
            # avoid building a graph for the large preview forward pass.
            with torch.no_grad():
                x = torch.cat((sample,out.detach(),label),dim=0)
                x = vutils.make_grid(x, normalize=True, scale_each=True)
                writer.add_image('Sample : Out : Ground', x, itercount)

                y = torch.cat((full_input,out_full.detach(),full_label),dim=0)
                y = vutils.make_grid(y, normalize=True, scale_each=True)
                writer.add_image('Full', y, itercount)

                # 'data/pixel loss' carries the full-image loss (the
                # original referenced the same undefined `pixle_loss` here).
                writer.add_scalar('data/pixel loss', pixle_loss_full.item(),itercount)
                writer.add_scalar('data/pixel loss_patch', pixle_loss_patch.item(),itercount)

                # NOTE(review): the model is left in its current mode for
                # this forward pass, as in the original — confirm whether
                # eval mode is wanted for the preview.
                preview_out = model(preview_input)
                x = vutils.make_grid(preview_out, normalize=True, scale_each=True)
                writer.add_image('Test', x, itercount)

In [None]:
# Write the model's state_dict (not the module object itself) to 'unet3.pt'.
torch.save(model.state_dict(),'unet3.pt')