# Model

In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        # 4 input channels
        self.conv11 = nn.Conv2d(4, 32, 3, padding=1)
        self.conv12 = nn.Conv2d(32, 32, 3, padding=1)
        
        self.conv21 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv22 = nn.Conv2d(64, 64, 3, padding=1)
        
        self.conv31 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv32 = nn.Conv2d(128, 128, 3, padding=1)
        
        self.conv41 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv42 = nn.Conv2d(256, 256, 3, padding=1)
        
        self.conv51 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv52 = nn.Conv2d(512, 512, 3, padding=1)
        
        # in_channels, out_channels, kernel-size(need check...)
        self.deconv5 = nn.ConvTranspose2d(512, 256, (2, 2), (2,2))
        
        self.conv61 = nn.Conv2d(512, 256, 3, padding=1)
        self.conv62 = nn.Conv2d(256, 256, 3, padding=1)
        
        self.deconv6 = nn.ConvTranspose2d(256, 128, (2,2), (2,2))
        
        self.conv71 = nn.Conv2d(256, 128, 3, padding=1)
        self.conv72 = nn.Conv2d(128, 128, 3, padding=1)
        
        self.deconv7 = nn.ConvTranspose2d(128, 64, (2,2), (2,2))
        
        self.conv81 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv82 = nn.Conv2d(64, 64, 3, padding=1)
        
        self.deconv8 = nn.ConvTranspose2d(64, 32, (2,2), (2,2))
        
        self.conv91 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv92 = nn.Conv2d(32, 32, 3, padding=1)
        
        self.conv10 = nn.Conv2d(32, 12, 3, padding=1)
        
        self.pixelshuffle = nn.PixelShuffle(2)
        
    
    def forward(self, x):
        x = F.leaky_relu(self.conv11(x))
        x_conv1 = F.leaky_relu(self.conv12(x))
        x = F.max_pool2d(x_conv1, (2, 2))
        
        x = F.leaky_relu(self.conv21(x))
        x_conv2 = F.leaky_relu(self.conv22(x))
        x = F.max_pool2d(x_conv2, (2, 2))
        
        x = F.leaky_relu(self.conv31(x))
        x_conv3 = F.leaky_relu(self.conv32(x))
        x = F.max_pool2d(x_conv3, (2, 2))
        
        x = F.leaky_relu(self.conv41(x))
        x_conv4 = F.leaky_relu(self.conv42(x))
        x = F.max_pool2d(x_conv4, (2, 2))
        
        x = F.leaky_relu(self.conv51(x))
        x = F.leaky_relu(self.conv52(x))
        
        deconv = self.deconv5(x)
        #print x_conv4.size(), deconv.size()
        up6 = self.concat(deconv, x_conv4, 256)
        #print "afterconcat:", up6.size()
        
        x = F.leaky_relu(self.conv61(up6))
        x = F.leaky_relu(self.conv62(x))
        
        deconv = self.deconv6(x)
        up7 = self.concat(deconv, x_conv3, 128)
        
        x = F.leaky_relu(self.conv71(up7))
        x = F.leaky_relu(self.conv72(x))
        
        deconv = self.deconv7(x)
        up8 = self.concat(deconv, x_conv2, 64)
        
        x = F.leaky_relu(self.conv81(up8))
        x = F.leaky_relu(self.conv82(x))
        
        deconv = self.deconv8(x)
        up9 = self.concat(deconv, x_conv1, 32)
        
        x = F.leaky_relu(self.conv91(up9))
        x = F.leaky_relu(self.conv92(x))
        
        x = self.conv10(x)
        
        out = self.pixelshuffle(x)
        
        return out
        
        
        
    def concat(self, deconv, x2, output_channels):
        deconv_out = torch.cat((deconv, x2), 1)
        return deconv_out
    
    

net = ConvNet()
print(net)

ConvNet(
  (conv11): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv12): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv21): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv22): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv31): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv32): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv41): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv42): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv51): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv52): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (deconv5): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
  (conv61): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv62): Conv2d(256, 256, kernel_size=(3, 3)

A helper function， to calculate padding.

In [41]:
def padding(i=32, p=2, k=3, s=1, d=1):
        """
            o = output
            p = padding
            k = kernel_size
            s = stride
            d = dilation
        """
        o = (i + 2*p - k - (k-1)*(d-1))/s + 1
        return o

In [43]:
padding(p=1,i=512)

512.0

In [6]:
from __future__ import division
import os, time, scipy.io
import numpy as np
import rawpy
import glob

input_dir = './dataset/Sony/short/'
gt_dir = './dataset/Sony/long/'
checkpoint_dir = './result_Sony/'
result_dir = './result_Sony/'

# get train IDs
train_fns = glob.glob(gt_dir + '0*.ARW')
train_ids = [int(os.path.basename(train_fn)[0:5]) for train_fn in train_fns]

ps = 512  # patch size for training
save_freq = 500

DEBUG = 0
if DEBUG == 1:
    save_freq = 2
    train_ids = train_ids[0:5]

In [7]:
def pack_raw(raw):
    # pack Bayer image to 4 channels
    im = raw.raw_image_visible.astype(np.float32)
    im = np.maximum(im - 512, 0) / (16383 - 512)  # subtract the black level

    im = np.expand_dims(im, axis=2)
    img_shape = im.shape
    H = img_shape[0]
    W = img_shape[1]

    out = np.concatenate((im[0:H:2, 0:W:2, :],
                          im[0:H:2, 1:W:2, :],
                          im[1:H:2, 1:W:2, :],
                          im[1:H:2, 0:W:2, :]), axis=2)
    return out

In [8]:
# Raw data takes long time to load. Keep them in memory after loaded.
gt_images = [None] * 6000
input_images = {}
input_images['300'] = [None] * len(train_ids)
input_images['250'] = [None] * len(train_ids)
input_images['100'] = [None] * len(train_ids)

In [9]:
g_loss = np.zeros((5000, 1))

allfolders = glob.glob('./result/*0')
lastepoch = 0
for folder in allfolders:
    lastepoch = np.maximum(lastepoch, int(folder[-4:]))

learning_rate = 1e-4

In [11]:
import torch.optim as optim

criterion = nn.L1Loss()
optimizer = optim.Adam(net.parameters(), lr=1e-4)

In [53]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
net.to(device)

cuda:0


ConvNet(
  (conv11): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv12): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv21): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv22): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv31): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv32): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv41): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv42): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv51): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv52): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (deconv5): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
  (conv61): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv62): Conv2d(256, 256, kernel_size=(3, 3)

# Training

In [68]:
for epoch in range(lastepoch, 4001):
    if os.path.isdir("result/%04d" % epoch):
        continue
    cnt = 0
    if epoch > 2000:
        learning_rate = 1e-5
        optimizer = optim.Adam(net.parameters(), lr=learning_rate)
        
    net.to(device)
    num_ids = len(train_ids)
    running_loss = 0.0
    for ind in np.random.permutation(num_ids):
        # get the path from image id
        train_id = train_ids[ind]
        in_files = glob.glob(input_dir + '%05d_00*.ARW' % train_id)
        in_path = in_files[np.random.random_integers(0, len(in_files) - 1)]
        in_fn = os.path.basename(in_path)

        gt_files = glob.glob(gt_dir + '%05d_00*.ARW' % train_id)
        gt_path = gt_files[0]
        gt_fn = os.path.basename(gt_path)
        in_exposure = float(in_fn[9:-5])
        gt_exposure = float(gt_fn[9:-5])
        ratio = min(gt_exposure / in_exposure, 300)

        st = time.time()
        cnt += 1

        if input_images[str(ratio)[0:3]][ind] is None:
            raw = rawpy.imread(in_path)
            input_images[str(ratio)[0:3]][ind] = np.expand_dims(pack_raw(raw), axis=0) * ratio

            gt_raw = rawpy.imread(gt_path)
            im = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
            gt_images[ind] = np.expand_dims(np.float32(im / 65535.0), axis=0)

        # crop
        H = input_images[str(ratio)[0:3]][ind].shape[1]
        W = input_images[str(ratio)[0:3]][ind].shape[2]

        xx = np.random.randint(0, W - ps)
        yy = np.random.randint(0, H - ps)
        input_patch = input_images[str(ratio)[0:3]][ind][:, yy:yy + ps, xx:xx + ps, :]
        gt_patch = gt_images[ind][:, yy * 2:yy * 2 + ps * 2, xx * 2:xx * 2 + ps * 2, :]

        if np.random.randint(2, size=1)[0] == 1:  # random flip
            input_patch = np.flip(input_patch, axis=1)
            gt_patch = np.flip(gt_patch, axis=1)
        if np.random.randint(2, size=1)[0] == 1:
            input_patch = np.flip(input_patch, axis=2)
            gt_patch = np.flip(gt_patch, axis=2)
        if np.random.randint(2, size=1)[0] == 1:  # random transpose
            input_patch = np.transpose(input_patch, (0, 2, 1, 3))
            gt_patch = np.transpose(gt_patch, (0, 2, 1, 3))

        input_patch = np.minimum(input_patch, 1.0)

        input_patch = torch.tensor(input_patch.copy())
        input_patch = input_patch.permute(0, 3, 1, 2)
        gt_patch = torch.tensor(gt_patch.copy())
        gt_patch = gt_patch.permute(0, 3, 1, 2)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        input_patch = input_patch.to(device)
        gt_patch = gt_patch.to(device)
        
        # forward + backward + optimize
        output = net(input_patch)
        loss = criterion(output, gt_patch)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if cnt == (num_ids - 1):
            print("epoch: %d, loss: %.4f"%(epoch, running_loss/num_ids))

        if epoch % save_freq == 0:
            if not os.path.isdir(result_dir + '%04d' % epoch):
                os.makedirs(result_dir + '%04d' % epoch)

            temp = np.concatenate((gt_patch.cpu().detach().numpy()[0, :, :, :], output.cpu().detach().numpy()[0, :, :, :]), axis=1)
            scipy.misc.toimage(temp * 255, high=255, low=0, cmin=0, cmax=255).save(
                result_dir + '%04d/%05d_00_train_%d.jpg' % (epoch, train_id, ratio))
            
    torch.save(net.cpu(), checkpoint_dir + 'sid_torchversion_model.ckpt')

  app.launch_new_instance()
`toimage` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use Pillow's ``Image.fromarray`` directly instead.
  app.launch_new_instance()
  app.launch_new_instance()


epoch: 0, loss: 0.2199
epoch: 1, loss: 0.2278
epoch: 2, loss: 0.2241
epoch: 3, loss: 0.2384
epoch: 4, loss: 0.2349
epoch: 5, loss: 0.2233
epoch: 6, loss: 0.2535
epoch: 7, loss: 0.2321
epoch: 8, loss: 0.2356
epoch: 9, loss: 0.2328
epoch: 10, loss: 0.2291
epoch: 11, loss: 0.2321
epoch: 12, loss: 0.2288
epoch: 13, loss: 0.2325
epoch: 14, loss: 0.2354
epoch: 15, loss: 0.2361
epoch: 16, loss: 0.2324
epoch: 17, loss: 0.2333
epoch: 18, loss: 0.2288
epoch: 19, loss: 0.2388
epoch: 20, loss: 0.2277
epoch: 21, loss: 0.2460
epoch: 22, loss: 0.2327
epoch: 23, loss: 0.2294
epoch: 24, loss: 0.2254
epoch: 25, loss: 0.2281
epoch: 26, loss: 0.2269
epoch: 27, loss: 0.2388
epoch: 28, loss: 0.2289
epoch: 29, loss: 0.2182
epoch: 30, loss: 0.2265
epoch: 31, loss: 0.2294
epoch: 32, loss: 0.2307
epoch: 33, loss: 0.2251
epoch: 34, loss: 0.2328
epoch: 35, loss: 0.2495
epoch: 36, loss: 0.2393
epoch: 37, loss: 0.2312
epoch: 38, loss: 0.2275
epoch: 39, loss: 0.2255
epoch: 40, loss: 0.2355
epoch: 41, loss: 0.2349
ep

epoch: 333, loss: 0.2255
epoch: 334, loss: 0.2298
epoch: 335, loss: 0.2353
epoch: 336, loss: 0.2369
epoch: 337, loss: 0.2384
epoch: 338, loss: 0.2389
epoch: 339, loss: 0.2388
epoch: 340, loss: 0.2401
epoch: 341, loss: 0.2279
epoch: 342, loss: 0.2254
epoch: 343, loss: 0.2404
epoch: 344, loss: 0.2191
epoch: 345, loss: 0.2390
epoch: 346, loss: 0.2293
epoch: 347, loss: 0.2282
epoch: 348, loss: 0.2294
epoch: 349, loss: 0.2376
epoch: 350, loss: 0.2315
epoch: 351, loss: 0.2364
epoch: 352, loss: 0.2365
epoch: 353, loss: 0.2247
epoch: 354, loss: 0.2312
epoch: 355, loss: 0.2351
epoch: 356, loss: 0.2350
epoch: 357, loss: 0.2339
epoch: 358, loss: 0.2231
epoch: 359, loss: 0.2256
epoch: 360, loss: 0.2403
epoch: 361, loss: 0.2257
epoch: 362, loss: 0.2372
epoch: 363, loss: 0.2357
epoch: 364, loss: 0.2392
epoch: 365, loss: 0.2283
epoch: 366, loss: 0.2350
epoch: 367, loss: 0.2346
epoch: 368, loss: 0.2354
epoch: 369, loss: 0.2240
epoch: 370, loss: 0.2250
epoch: 371, loss: 0.2339
epoch: 372, loss: 0.2346


epoch: 661, loss: 0.2263
epoch: 662, loss: 0.2424
epoch: 663, loss: 0.2259
epoch: 664, loss: 0.2323
epoch: 665, loss: 0.2272
epoch: 666, loss: 0.2364
epoch: 667, loss: 0.2399
epoch: 668, loss: 0.2306
epoch: 669, loss: 0.2235
epoch: 670, loss: 0.2338
epoch: 671, loss: 0.2241
epoch: 672, loss: 0.2315
epoch: 673, loss: 0.2362
epoch: 674, loss: 0.2298
epoch: 675, loss: 0.2271
epoch: 676, loss: 0.2350
epoch: 677, loss: 0.2368
epoch: 678, loss: 0.2316
epoch: 679, loss: 0.2232
epoch: 680, loss: 0.2356
epoch: 681, loss: 0.2239
epoch: 682, loss: 0.2321
epoch: 683, loss: 0.2310
epoch: 684, loss: 0.2312
epoch: 685, loss: 0.2381
epoch: 686, loss: 0.2287
epoch: 687, loss: 0.2272
epoch: 688, loss: 0.2292
epoch: 689, loss: 0.2311
epoch: 690, loss: 0.2489
epoch: 691, loss: 0.2349
epoch: 692, loss: 0.2400
epoch: 693, loss: 0.2281
epoch: 694, loss: 0.2385
epoch: 695, loss: 0.2336
epoch: 696, loss: 0.2226
epoch: 697, loss: 0.2338
epoch: 698, loss: 0.2361
epoch: 699, loss: 0.2354
epoch: 700, loss: 0.2356


epoch: 989, loss: 0.2388
epoch: 990, loss: 0.2324
epoch: 991, loss: 0.2285
epoch: 992, loss: 0.2228
epoch: 993, loss: 0.2280
epoch: 994, loss: 0.2340
epoch: 995, loss: 0.2316
epoch: 996, loss: 0.2421
epoch: 997, loss: 0.2284
epoch: 998, loss: 0.2317
epoch: 999, loss: 0.2357
epoch: 1000, loss: 0.2331
epoch: 1001, loss: 0.2377
epoch: 1002, loss: 0.2287
epoch: 1003, loss: 0.2385
epoch: 1004, loss: 0.2246
epoch: 1005, loss: 0.2388
epoch: 1006, loss: 0.2260
epoch: 1007, loss: 0.2317
epoch: 1008, loss: 0.2303
epoch: 1009, loss: 0.2201
epoch: 1010, loss: 0.2449
epoch: 1011, loss: 0.2290
epoch: 1012, loss: 0.2377
epoch: 1013, loss: 0.2322
epoch: 1014, loss: 0.2323
epoch: 1015, loss: 0.2261
epoch: 1016, loss: 0.2399
epoch: 1017, loss: 0.2351
epoch: 1018, loss: 0.2268
epoch: 1019, loss: 0.2315
epoch: 1020, loss: 0.2233
epoch: 1021, loss: 0.2363
epoch: 1022, loss: 0.2452
epoch: 1023, loss: 0.2281
epoch: 1024, loss: 0.2310
epoch: 1025, loss: 0.2334
epoch: 1026, loss: 0.2391
epoch: 1027, loss: 0.23

epoch: 1305, loss: 0.2332
epoch: 1306, loss: 0.2509
epoch: 1307, loss: 0.2338
epoch: 1308, loss: 0.2346
epoch: 1309, loss: 0.2249
epoch: 1310, loss: 0.2397
epoch: 1311, loss: 0.2458
epoch: 1312, loss: 0.2287
epoch: 1313, loss: 0.2401
epoch: 1314, loss: 0.2247
epoch: 1315, loss: 0.2379
epoch: 1316, loss: 0.2285
epoch: 1317, loss: 0.2413
epoch: 1318, loss: 0.2358
epoch: 1319, loss: 0.2425
epoch: 1320, loss: 0.2305
epoch: 1321, loss: 0.2298
epoch: 1322, loss: 0.2270
epoch: 1323, loss: 0.2435
epoch: 1324, loss: 0.2250
epoch: 1325, loss: 0.2347
epoch: 1326, loss: 0.2302
epoch: 1327, loss: 0.2278
epoch: 1328, loss: 0.2352
epoch: 1329, loss: 0.2269
epoch: 1330, loss: 0.2306
epoch: 1331, loss: 0.2270
epoch: 1332, loss: 0.2211
epoch: 1333, loss: 0.2339
epoch: 1334, loss: 0.2383
epoch: 1335, loss: 0.2439
epoch: 1336, loss: 0.2297
epoch: 1337, loss: 0.2447
epoch: 1338, loss: 0.2397
epoch: 1339, loss: 0.2217
epoch: 1340, loss: 0.2372
epoch: 1341, loss: 0.2283
epoch: 1342, loss: 0.2310
epoch: 1343,

epoch: 1621, loss: 0.2385
epoch: 1622, loss: 0.2286
epoch: 1623, loss: 0.2478
epoch: 1624, loss: 0.2446
epoch: 1625, loss: 0.2354
epoch: 1626, loss: 0.2310
epoch: 1627, loss: 0.2380
epoch: 1628, loss: 0.2319
epoch: 1629, loss: 0.2348
epoch: 1630, loss: 0.2300
epoch: 1631, loss: 0.2355
epoch: 1632, loss: 0.2359
epoch: 1633, loss: 0.2330
epoch: 1634, loss: 0.2260
epoch: 1635, loss: 0.2356
epoch: 1636, loss: 0.2278
epoch: 1637, loss: 0.2244
epoch: 1638, loss: 0.2412
epoch: 1639, loss: 0.2314
epoch: 1640, loss: 0.2328
epoch: 1641, loss: 0.2366
epoch: 1642, loss: 0.2415
epoch: 1643, loss: 0.2262
epoch: 1644, loss: 0.2329
epoch: 1645, loss: 0.2206
epoch: 1646, loss: 0.2389
epoch: 1647, loss: 0.2384
epoch: 1648, loss: 0.2333
epoch: 1649, loss: 0.2313
epoch: 1650, loss: 0.2282
epoch: 1651, loss: 0.2275
epoch: 1652, loss: 0.2407
epoch: 1653, loss: 0.2374
epoch: 1654, loss: 0.2398
epoch: 1655, loss: 0.2279
epoch: 1656, loss: 0.2346
epoch: 1657, loss: 0.2284
epoch: 1658, loss: 0.2339
epoch: 1659,

epoch: 1937, loss: 0.2409
epoch: 1938, loss: 0.2360
epoch: 1939, loss: 0.2316
epoch: 1940, loss: 0.2332
epoch: 1941, loss: 0.2300
epoch: 1942, loss: 0.2274
epoch: 1943, loss: 0.2359
epoch: 1944, loss: 0.2304
epoch: 1945, loss: 0.2333
epoch: 1946, loss: 0.2360
epoch: 1947, loss: 0.2395
epoch: 1948, loss: 0.2321
epoch: 1949, loss: 0.2321
epoch: 1950, loss: 0.2279
epoch: 1951, loss: 0.2430
epoch: 1952, loss: 0.2332
epoch: 1953, loss: 0.2260
epoch: 1954, loss: 0.2265
epoch: 1955, loss: 0.2305
epoch: 1956, loss: 0.2374
epoch: 1957, loss: 0.2415
epoch: 1958, loss: 0.2340
epoch: 1959, loss: 0.2354
epoch: 1960, loss: 0.2261
epoch: 1961, loss: 0.2208
epoch: 1962, loss: 0.2318
epoch: 1963, loss: 0.2256
epoch: 1964, loss: 0.2272
epoch: 1965, loss: 0.2344
epoch: 1966, loss: 0.2310
epoch: 1967, loss: 0.2306
epoch: 1968, loss: 0.2341
epoch: 1969, loss: 0.2365
epoch: 1970, loss: 0.2351
epoch: 1971, loss: 0.2256
epoch: 1972, loss: 0.2341
epoch: 1973, loss: 0.2333
epoch: 1974, loss: 0.2328
epoch: 1975,

epoch: 2253, loss: 0.0473
epoch: 2254, loss: 0.0450
epoch: 2255, loss: 0.0451
epoch: 2256, loss: 0.0457
epoch: 2257, loss: 0.0470
epoch: 2258, loss: 0.0469
epoch: 2259, loss: 0.0450
epoch: 2260, loss: 0.0458
epoch: 2261, loss: 0.0449
epoch: 2262, loss: 0.0452
epoch: 2263, loss: 0.0463
epoch: 2264, loss: 0.0446
epoch: 2265, loss: 0.0451
epoch: 2266, loss: 0.0470
epoch: 2267, loss: 0.0449
epoch: 2268, loss: 0.0456
epoch: 2269, loss: 0.0461
epoch: 2270, loss: 0.0451
epoch: 2271, loss: 0.0449
epoch: 2272, loss: 0.0453
epoch: 2273, loss: 0.0437
epoch: 2274, loss: 0.0457
epoch: 2275, loss: 0.0439
epoch: 2276, loss: 0.0455
epoch: 2277, loss: 0.0461
epoch: 2278, loss: 0.0449
epoch: 2279, loss: 0.0446
epoch: 2280, loss: 0.0441
epoch: 2281, loss: 0.0444
epoch: 2282, loss: 0.0456
epoch: 2283, loss: 0.0450
epoch: 2284, loss: 0.0455
epoch: 2285, loss: 0.0452
epoch: 2286, loss: 0.0444
epoch: 2287, loss: 0.0436
epoch: 2288, loss: 0.0455
epoch: 2289, loss: 0.0458
epoch: 2290, loss: 0.0439
epoch: 2291,

epoch: 2569, loss: 0.0403
epoch: 2570, loss: 0.0407
epoch: 2571, loss: 0.0399
epoch: 2572, loss: 0.0413
epoch: 2573, loss: 0.0407
epoch: 2574, loss: 0.0390
epoch: 2575, loss: 0.0408
epoch: 2576, loss: 0.0397
epoch: 2577, loss: 0.0405
epoch: 2578, loss: 0.0399
epoch: 2579, loss: 0.0408
epoch: 2580, loss: 0.0404
epoch: 2581, loss: 0.0405
epoch: 2582, loss: 0.0402
epoch: 2583, loss: 0.0401
epoch: 2584, loss: 0.0421
epoch: 2585, loss: 0.0410
epoch: 2586, loss: 0.0404
epoch: 2587, loss: 0.0395
epoch: 2588, loss: 0.0400
epoch: 2589, loss: 0.0404
epoch: 2590, loss: 0.0394
epoch: 2591, loss: 0.0408
epoch: 2592, loss: 0.0415
epoch: 2593, loss: 0.0403
epoch: 2594, loss: 0.0414
epoch: 2595, loss: 0.0397
epoch: 2596, loss: 0.0406
epoch: 2597, loss: 0.0406
epoch: 2598, loss: 0.0411
epoch: 2599, loss: 0.0402
epoch: 2600, loss: 0.0412
epoch: 2601, loss: 0.0403
epoch: 2602, loss: 0.0385
epoch: 2603, loss: 0.0404
epoch: 2604, loss: 0.0395
epoch: 2605, loss: 0.0401
epoch: 2606, loss: 0.0404
epoch: 2607,

epoch: 2885, loss: 0.0361
epoch: 2886, loss: 0.0360
epoch: 2887, loss: 0.0374
epoch: 2888, loss: 0.0366
epoch: 2889, loss: 0.0355
epoch: 2890, loss: 0.0352
epoch: 2891, loss: 0.0363
epoch: 2892, loss: 0.0352
epoch: 2893, loss: 0.0361
epoch: 2894, loss: 0.0358
epoch: 2895, loss: 0.0356
epoch: 2896, loss: 0.0351
epoch: 2897, loss: 0.0358
epoch: 2898, loss: 0.0353
epoch: 2899, loss: 0.0356
epoch: 2900, loss: 0.0367
epoch: 2901, loss: 0.0374
epoch: 2902, loss: 0.0354
epoch: 2903, loss: 0.0348
epoch: 2904, loss: 0.0366
epoch: 2905, loss: 0.0361
epoch: 2906, loss: 0.0351
epoch: 2907, loss: 0.0353
epoch: 2908, loss: 0.0361
epoch: 2909, loss: 0.0344
epoch: 2910, loss: 0.0352
epoch: 2911, loss: 0.0350
epoch: 2912, loss: 0.0348
epoch: 2913, loss: 0.0355
epoch: 2914, loss: 0.0351
epoch: 2915, loss: 0.0349
epoch: 2916, loss: 0.0349
epoch: 2917, loss: 0.0366
epoch: 2918, loss: 0.0356
epoch: 2919, loss: 0.0365
epoch: 2920, loss: 0.0363
epoch: 2921, loss: 0.0369
epoch: 2922, loss: 0.0350
epoch: 2923,

epoch: 3201, loss: 0.0338
epoch: 3202, loss: 0.0337
epoch: 3203, loss: 0.0337
epoch: 3204, loss: 0.0338
epoch: 3205, loss: 0.0335
epoch: 3206, loss: 0.0340
epoch: 3207, loss: 0.0343
epoch: 3208, loss: 0.0335
epoch: 3209, loss: 0.0343
epoch: 3210, loss: 0.0344
epoch: 3211, loss: 0.0336
epoch: 3212, loss: 0.0326
epoch: 3213, loss: 0.0336
epoch: 3214, loss: 0.0334
epoch: 3215, loss: 0.0338
epoch: 3216, loss: 0.0337
epoch: 3217, loss: 0.0337
epoch: 3218, loss: 0.0338
epoch: 3219, loss: 0.0334
epoch: 3220, loss: 0.0348
epoch: 3221, loss: 0.0329
epoch: 3222, loss: 0.0344
epoch: 3223, loss: 0.0341
epoch: 3224, loss: 0.0331
epoch: 3225, loss: 0.0345
epoch: 3226, loss: 0.0331
epoch: 3227, loss: 0.0336
epoch: 3228, loss: 0.0346
epoch: 3229, loss: 0.0339
epoch: 3230, loss: 0.0341
epoch: 3231, loss: 0.0334
epoch: 3232, loss: 0.0343
epoch: 3233, loss: 0.0339
epoch: 3234, loss: 0.0342
epoch: 3235, loss: 0.0340
epoch: 3236, loss: 0.0347
epoch: 3237, loss: 0.0341
epoch: 3238, loss: 0.0339
epoch: 3239,

epoch: 3517, loss: 0.0314
epoch: 3518, loss: 0.0324
epoch: 3519, loss: 0.0319
epoch: 3520, loss: 0.0329
epoch: 3521, loss: 0.0320
epoch: 3522, loss: 0.0317
epoch: 3523, loss: 0.0330
epoch: 3524, loss: 0.0325
epoch: 3525, loss: 0.0324
epoch: 3526, loss: 0.0326
epoch: 3527, loss: 0.0327
epoch: 3528, loss: 0.0314
epoch: 3529, loss: 0.0317
epoch: 3530, loss: 0.0328
epoch: 3531, loss: 0.0310
epoch: 3532, loss: 0.0329
epoch: 3533, loss: 0.0316
epoch: 3534, loss: 0.0317
epoch: 3535, loss: 0.0319
epoch: 3536, loss: 0.0325
epoch: 3537, loss: 0.0313
epoch: 3538, loss: 0.0320
epoch: 3539, loss: 0.0331
epoch: 3540, loss: 0.0329
epoch: 3541, loss: 0.0315
epoch: 3542, loss: 0.0315
epoch: 3543, loss: 0.0323
epoch: 3544, loss: 0.0324
epoch: 3545, loss: 0.0332
epoch: 3546, loss: 0.0314
epoch: 3547, loss: 0.0334
epoch: 3548, loss: 0.0327
epoch: 3549, loss: 0.0315
epoch: 3550, loss: 0.0316
epoch: 3551, loss: 0.0328
epoch: 3552, loss: 0.0318
epoch: 3553, loss: 0.0333
epoch: 3554, loss: 0.0313
epoch: 3555,

epoch: 3833, loss: 0.0311
epoch: 3834, loss: 0.0311
epoch: 3835, loss: 0.0308
epoch: 3836, loss: 0.0316
epoch: 3837, loss: 0.0309
epoch: 3838, loss: 0.0298
epoch: 3839, loss: 0.0310
epoch: 3840, loss: 0.0308
epoch: 3841, loss: 0.0312
epoch: 3842, loss: 0.0305
epoch: 3843, loss: 0.0307
epoch: 3844, loss: 0.0309
epoch: 3845, loss: 0.0307
epoch: 3846, loss: 0.0308
epoch: 3847, loss: 0.0293
epoch: 3848, loss: 0.0308
epoch: 3849, loss: 0.0313
epoch: 3850, loss: 0.0311
epoch: 3851, loss: 0.0313
epoch: 3852, loss: 0.0314
epoch: 3853, loss: 0.0306
epoch: 3854, loss: 0.0307
epoch: 3855, loss: 0.0301
epoch: 3856, loss: 0.0315
epoch: 3857, loss: 0.0307
epoch: 3858, loss: 0.0304
epoch: 3859, loss: 0.0304
epoch: 3860, loss: 0.0310
epoch: 3861, loss: 0.0316
epoch: 3862, loss: 0.0309
epoch: 3863, loss: 0.0315
epoch: 3864, loss: 0.0310
epoch: 3865, loss: 0.0309
epoch: 3866, loss: 0.0308
epoch: 3867, loss: 0.0310
epoch: 3868, loss: 0.0313
epoch: 3869, loss: 0.0307
epoch: 3870, loss: 0.0300
epoch: 3871,

# Testing

In [70]:
checkpoint_dir = './checkpoint/SID/Sony/'
result_dir = './result_Sony/'

# get test IDs
test_fns = glob.glob(gt_dir + '/1*.ARW')
test_ids = [int(os.path.basename(test_fn)[0:5]) for test_fn in test_fns]

ckpt = torch.load(checkpoint_dir + 'sid_torchversion_model.ckpt')
ckpt.eval()
ckpt.to(device)

if ckpt:
    print('successfully loaded model.')

if not os.path.isdir(result_dir + 'final/'):
    os.makedirs(result_dir + 'final/')

successfully loaded model.


In [73]:
for test_id in test_ids:
    # test the first image in each sequence
    in_files = glob.glob(input_dir + '%05d_00*.ARW' % test_id)
    for k in range(len(in_files)):
        in_path = in_files[k]
        in_fn = os.path.basename(in_path)
        print(in_fn)
        gt_files = glob.glob(gt_dir + '%05d_00*.ARW' % test_id)
        gt_path = gt_files[0]
        gt_fn = os.path.basename(gt_path)
        in_exposure = float(in_fn[9:-5])
        gt_exposure = float(gt_fn[9:-5])
        ratio = min(gt_exposure / in_exposure, 300)

        raw = rawpy.imread(in_path)
        input_full = np.expand_dims(pack_raw(raw), axis=0) * ratio

        im = raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
        # scale_full = np.expand_dims(np.float32(im/65535.0),axis = 0)*ratio
        scale_full = np.expand_dims(np.float32(im / 65535.0), axis=0)

        gt_raw = rawpy.imread(gt_path)
        im = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
        gt_full = np.expand_dims(np.float32(im / 65535.0), axis=0)

        input_full = np.minimum(input_full, 1.0)
        
        #------- Generate Img ---------#
        input_full = torch.tensor(input_full.copy())
        input_full = input_full.permute(0, 3, 1, 2)
        input_full = input_full.to(device)
        
        output = ckpt(input_full)
        output = output.cpu().detach().numpy()
        #----------- Done -------------# 
        
        output = np.minimum(np.maximum(output, 0), 1)

        output = output[0, :, :, :]
        gt_full = gt_full[0, :, :, :]
        scale_full = scale_full[0, :, :, :]
        scale_full = scale_full * np.mean(gt_full) / np.mean(
            scale_full)  # scale the low-light image to the same mean of the groundtruth

        scipy.misc.toimage(output * 255, high=255, low=0, cmin=0, cmax=255).save(
            result_dir + 'final/%5d_00_%d_out.png' % (test_id, ratio))
        scipy.misc.toimage(scale_full * 255, high=255, low=0, cmin=0, cmax=255).save(
            result_dir + 'final/%5d_00_%d_scale.png' % (test_id, ratio))
        scipy.misc.toimage(gt_full * 255, high=255, low=0, cmin=0, cmax=255).save(
            result_dir + 'final/%5d_00_%d_gt.png' % (test_id, ratio))

10178_00_0.1s.ARW


`toimage` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use Pillow's ``Image.fromarray`` directly instead.
`toimage` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use Pillow's ``Image.fromarray`` directly instead.
`toimage` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use Pillow's ``Image.fromarray`` directly instead.


10178_00_0.033s.ARW
10178_00_0.04s.ARW
10082_00_0.1s.ARW
10185_00_0.04s.ARW
10185_00_0.1s.ARW
10185_00_0.033s.ARW
10035_00_0.1s.ARW
10035_00_0.04s.ARW
10170_00_0.1s.ARW
10055_00_0.04s.ARW
10055_00_0.1s.ARW
10034_00_0.1s.ARW
10034_00_0.04s.ARW
10139_00_0.1s.ARW
10093_00_0.1s.ARW
10103_00_0.1s.ARW
10226_00_0.033s.ARW
10226_00_0.1s.ARW
10226_00_0.04s.ARW
10077_00_0.1s.ARW
10003_00_0.04s.ARW
10003_00_0.1s.ARW
10213_00_0.04s.ARW
10213_00_0.033s.ARW
10213_00_0.1s.ARW
10217_00_0.04s.ARW
10217_00_0.1s.ARW
10217_00_0.033s.ARW
10203_00_0.04s.ARW
10203_00_0.033s.ARW
10203_00_0.1s.ARW
10016_00_0.1s.ARW
10016_00_0.04s.ARW
10006_00_0.1s.ARW
10006_00_0.04s.ARW
10193_00_0.1s.ARW
10193_00_0.04s.ARW
10193_00_0.033s.ARW
10228_00_0.04s.ARW
10228_00_0.033s.ARW
10228_00_0.1s.ARW
10074_00_0.1s.ARW
10162_00_0.1s.ARW
10040_00_0.1s.ARW
10040_00_0.04s.ARW
10111_00_0.1s.ARW
10101_00_0.1s.ARW
10032_00_0.04s.ARW
10032_00_0.1s.ARW
10187_00_0.04s.ARW
10187_00_0.033s.ARW
10187_00_0.1s.ARW
10176_00_0.1s.ARW
10140_00_0.