In [2]:
import rawpy
import imageio
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
import numpy as np
from torchvision import models
from matplotlib import pyplot as plt
import torch
from torchsummary import summary
from collections import defaultdict
import torch.nn.functional as F
import torch
from torchsummary import summary
import torch
import torch.nn as nn
from tqdm import tqdm
import random
import pickle
from torch.utils.tensorboard import SummaryWriter

In [3]:
# Load the paired raw / RGB image stacks (presumably pickled lists of
# per-image arrays, stacked here along a new leading axis).
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open('/mnt/d/dng/raw_list.pkl', 'rb') as f:
    raw_array = np.stack(pickle.load(f))
# Saturate raw values above 255, then rescale by 1/255 as float32.
raw_array[raw_array > 255] = 255
raw_array = raw_array.astype(np.float32) / 255.0

with open('/mnt/d/dng/rgb_list.pkl', 'rb') as f:
    rgb_array = np.stack(pickle.load(f))
# Rescale RGB targets by 1/255 as float32.
rgb_array = rgb_array.astype(np.float32) / 255.0

In [20]:
class SimDataset(Dataset):
    """Random, aligned crops from paired raw / RGB image arrays.

    Item ``idx`` uses image ``idx % len(raw_array)``. The same ``size`` x ``size``
    window is cut from the raw image and from its RGB target. The window's
    top-left corner lies on a multiple of ``stride``, and every valid window
    (including the last one) can be selected.

    Args:
        raw_array: indexable stack of raw images; ``raw_array[i]`` is indexed as
            ``[row, col]`` and fed to ``ToTensor`` (presumably (N, H, W) float32 --
            TODO confirm).
        rgb_array: matching RGB targets, indexed the same way (presumably
            (N, H, W, 3) float32 -- TODO confirm).
        count: value returned by ``__len__`` (number of crops per pass).
        transform: stored on the instance but currently not applied.
        size: crop side length in pixels (default 128).
        stride: crop offsets are multiples of this value (default 2).
    """

    def __init__(self, raw_array, rgb_array, count, transform=None, size=128, stride=2):
        self.size = size
        self.stride = stride
        self.count = count
        self.raw_array = raw_array
        self.rgb_array = rgb_array
        self.transform = transform  # NOTE(review): stored but never used below
        self.source_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.479], [0.212]),  # dng
        ])
        self.target_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # imagenet
        ])

    def __len__(self):
        return self.count

    def __getitem__(self, idx):
        """Return ``[source_tensor, target_tensor]`` for a random aligned crop.

        Raises:
            ValueError: if the selected image is smaller than ``size`` on either axis.
        """
        stride = self.stride
        size = self.size
        sample_num = idx % len(self.raw_array)
        raw_image = self.raw_array[sample_num]
        rgb = self.rgb_array[sample_num]
        height, width = raw_image.shape[0], raw_image.shape[1]
        if height < size or width < size:
            raise ValueError(
                f"image {sample_num} has shape {raw_image.shape[:2]}, "
                f"smaller than crop size {size}"
            )
        # Largest k with stride * k + size <= dim, so the last valid window is reachable.
        x = random.randint(0, (height - size) // stride)
        y = random.randint(0, (width - size) // stride)
        rows = slice(stride * x, stride * x + size)
        cols = slice(stride * y, stride * y + size)
        raw_crop = raw_image[rows, cols]
        rgb_crop = rgb[rows, cols]
        return [self.source_transform(raw_crop), self.target_transform(rgb_crop)]


In [21]:
# Hold out the first two image pairs for validation and train on the rest.
# `count` is what each dataset reports as its length (crops per pass).
val_set = SimDataset(raw_array[:2], rgb_array[:2], count=128)
train_set = SimDataset(raw_array[2:], rgb_array[2:], count=512)

In [27]:
batch_size = 32

# Training crops are shuffled; validation order is left as-is.
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
# Validate on the held-out val_set -- previously this wrapped train_set, so
# "validation" loss was measured on training images.
val_dataloader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)

In [7]:
# ImageNet statistics used by SimDataset.target_transform.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Exact inverse of Normalize(mean, std): x * std + mean == Normalize(-mean/std, 1/std).
# Derived from the constants so the blue-channel std can't drift (it was typed 0.255).
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
    std=[1 / s for s in IMAGENET_STD],
)


In [None]:
# Visual sanity check of one (raw input, RGB target) training pair.
raw, rgb = train_set[0]
fig, (ax_raw, ax_rgb) = plt.subplots(1, 2, figsize=(10, 5))
# The raw crop has one channel: drop it so imshow draws a grayscale image.
ax_raw.imshow(raw.squeeze(0), cmap='gray')
ax_raw.set_title('raw input (normalized)')
# Undo the ImageNet normalization and clamp to [0, 1] so colors display correctly.
ax_rgb.imshow(inv_normalize(rgb).permute(1, 2, 0).clamp(0, 1))
ax_rgb.set_title('RGB target')
plt.show()

In [None]:
# `reverse_transform` was never defined; undo the ImageNet normalization with
# inv_normalize, move channels last for matplotlib, and clamp to [0, 1].
plt.imshow(inv_normalize(rgb).permute(1, 2, 0).clamp(0, 1));

In [23]:
def convrelu(in_channels, out_channels, kernel, padding):
    """Build a Conv2d -> BatchNorm2d -> ReLU(inplace) stack.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution (and normalized by BN).
        kernel: convolution kernel size.
        padding: zero-padding passed to the convolution.

    Returns:
        nn.Sequential whose children are indexed 0 (conv), 1 (bn), 2 (relu).
    """
    stages = [
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*stages)

class ResNetUNet(nn.Module):
    """U-Net style image-to-image network with a ResNet-18 encoder.

    The encoder reuses torchvision's ``resnet18`` stages. Its first convolution is
    replaced so it accepts ``in_channels`` input planes. The decoder repeatedly
    upsamples by 2x, concatenates the matching encoder feature map (passed
    through a 1x1 conv-BN-ReLU), and applies a 3x3 conv-BN-ReLU. A separate
    full-resolution branch on the raw input is concatenated at the last stage.

    Args:
        n_class: number of output channels of the final 1x1 convolution.
        in_channels: number of input channels (default 1, as before).

    Note:
        The encoder reduces spatial size by 32 in total, so input height and
        width should be multiples of 32 for the skip concatenations to line up.
    """

    def __init__(self, n_class, in_channels=1):
        super().__init__()

        # NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in
        # favour of `weights=...`; kept for compatibility with the installed version.
        self.base_model = models.resnet18(pretrained=True)

        # For torchvision's ResNet-18: conv1, bn1, relu, maxpool, layer1..layer4, avgpool, fc.
        self.base_layers = list(self.base_model.children())

        # conv + bn1 + relu -> (N, 64, H/2, W/2). The pretrained 3-channel conv1 is
        # swapped for a freshly initialised conv taking `in_channels` planes.
        self.layer0 = nn.Sequential(*self.base_layers[:3])
        self.layer0[0] = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        # maxpool + layer1 -> (N, 64, H/4, W/4)
        self.layer1 = nn.Sequential(*self.base_layers[3:5])
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        # layer2 -> (N, 128, H/8, W/8)
        self.layer2 = self.base_layers[5]
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        # layer3 -> (N, 256, H/16, W/16)
        self.layer3 = self.base_layers[6]
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        # layer4 -> (N, 512, H/32, W/32)
        self.layer4 = self.base_layers[7]
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder convs: input channels = skip features + upsampled decoder features.
        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        # Full-resolution branch computed directly from the network input.
        self.conv_original_size0 = convrelu(in_channels, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        """Map an (N, in_channels, H, W) batch to (N, n_class, H, W)."""
        # Full-resolution features from the input itself.
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)

        # Encoder.
        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        # Decoder: upsample, concatenate the projected skip connection, fuse.
        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)

        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        # Back to input resolution; merge with the full-resolution branch.
        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)

        return out



In [24]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Network with 3 output channels, placed on the selected device.
model = ResNetUNet(3).to(device)


In [25]:
# Smoke test: one training batch through the model; output must match the targets' shape.
# no_grad so the global `out` doesn't keep the batch's autograd graph alive on the device.
inputs, labels = next(iter(train_dataloader))
with torch.no_grad():
    out = model(inputs.to(device))
assert out.shape == labels.shape, (out.shape, labels.shape)

In [None]:
# torchsummary defaults to device="cuda"; pass the actual device type so this also runs on CPU.
summary(model, input_size=(1, 64, 64), device=device.type)

In [26]:
import copy
import time

import torch
import torch.optim as optim
from torch.optim import lr_scheduler

# First GPU when CUDA is available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

num_class = 3  # output channels of the network

# Fresh model for training, with AdamW at a fixed initial learning rate.
model = ResNetUNet(num_class).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# Multiply the learning rate by 0.8 every 10 calls to scheduler.step().
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)


cuda:0


In [None]:
num_epochs = 200
writer = SummaryWriter(comment="size 128 batch 32")
mse_loss = nn.MSELoss()


def _mean(values):
    """Arithmetic mean of a list of floats; NaN for an empty list."""
    return sum(values) / len(values) if values else float('nan')


try:
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))

        # ---- training pass ----
        model.train()
        train_losses = []
        progress = tqdm(train_dataloader)
        for inputs, labels in progress:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = mse_loss(outputs, labels)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Store a Python float, not the loss tensor: keeping the tensor would retain
            # every batch's autograd graph (and activations) until the epoch ends.
            batch_loss = loss.item()
            train_losses.append(batch_loss)
            progress.set_description(str(batch_loss))

        train_loss = _mean(train_losses)
        writer.add_scalar('loss train', train_loss, epoch)
        # Learning rate actually used during this epoch (scheduler steps at the end).
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        # ---- validation pass ----
        model.eval()
        val_losses = []
        progress = tqdm(val_dataloader)
        with torch.no_grad():
            for inputs, labels in progress:
                inputs = inputs.to(device)
                labels = labels.to(device)
                batch_loss = mse_loss(model(inputs), labels).item()
                val_losses.append(batch_loss)
                progress.set_description(str(batch_loss))

        val_loss = _mean(val_losses)
        writer.add_scalar('loss val', val_loss, epoch)

        print("train loss:", train_loss, "val loss:", val_loss)
        print(optimizer.param_groups[0]['lr'])

        # Since PyTorch 1.1 the scheduler must step after optimizer.step(); stepping it
        # at the start of the epoch (as before) skips the first value of the schedule.
        scheduler.step()
finally:
    # Flush pending TensorBoard events even if training is interrupted.
    writer.close()





In [29]:
# Pickle the whole module object. Loading it back requires the ResNetUNet class
# definition to be importable under the same name.
torch.save(obj=model, f='resnet18_128_4l_03xx.pt')