In [1]:
import torch
import cv2
import glob
import os
import random
from torch.utils.data import DataLoader
data_path = '/raid/chensq/Cam_IR/e2e_opt/data/Urban100'

class Dataset_test(torch.utils.data.Dataset):
    def __init__(self, data_path):
        self.gt_path = sorted(glob.glob(data_path + '/*.png'))
        self.device = torch.device('cuda')

    def __getitem__(self, index):
        gt = cv2.imread(self.gt_path[index])
        gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)
        
        # crop
        h, w, _ = gt.shape
        patch_size = 256
        rand_h = random.randint(0, h-patch_size)
        rand_w = random.randint(0, w-patch_size)
        patch = gt[rand_h:rand_h+patch_size, rand_w:rand_w+patch_size, :]

        # degrad
        inp = patch[0::2, 0::2, :]

        # to tensor
        gt = torch.from_numpy(patch).permute(2, 0, 1).float().div(255.).to(self.device)
        inp = torch.from_numpy(inp).permute(2, 0, 1).float().div(255.).to(self.device)

        return {'gt': gt, 'inp': inp}
    
    def __len__(self):
        return len(self.gt_path)
    
DD = Dataset_test(data_path=data_path)
dloader = DataLoader(DD, batch_size=4, shuffle=True, num_workers=1)
for i, train_data in enumerate(dloader):
    gt = train_data['gt']
    inp = train_data['inp']
    print(gt.shape)
    print(inp.shape)

torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 128, 128])
torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 128, 128])
torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 128, 128])
torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 128, 128])
torch.Size([4, 3, 256, 256])
torch.Size([4, 3, 128, 128])


In [2]:
import torch.nn as nn

class my_net(nn.Module):
    def __init__(self, in_chs, out_chs, base_chs):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_chs, out_channels=base_chs, kernel_size=7, padding=3)
        self.conv2 = nn.Conv2d(in_channels=base_chs, out_channels=base_chs, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=base_chs, out_channels=base_chs, kernel_size=3, padding=1)
        # self.transconv = nn.ConvTranspose2d()
        self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)
        self.conv_out = nn.Conv2d(in_channels=base_chs//4, out_channels=out_chs, kernel_size=3, padding=1)
        
    def forward(self, data):
        x1 = self.conv1(data)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x4 = self.pixelshuffle(x3)
        x_out = self.conv_out(x4)
        return x_out
    
net = my_net(3, 3, 96)
inp = torch.randn(4, 3, 128, 128)
out = net(inp)

print(inp.shape)
print(out.shape)

torch.Size([4, 3, 128, 128])
torch.Size([4, 3, 256, 256])


In [4]:
import torch
import cv2
import glob
import os
import random
from torch.utils.data import DataLoader
import torch.nn as nn
data_path = '/raid/chensq/Cam_IR/e2e_opt/data/Urban100'
device = torch.device('cuda:0')

class Dataset_test(torch.utils.data.Dataset):
    def __init__(self, data_path):
        self.gt_path = sorted(glob.glob(data_path + '/*.png'))
        self.device = torch.device('cuda:0')

    def __getitem__(self, index):
        gt = cv2.imread(self.gt_path[index])
        gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)
        
        # crop
        h, w, _ = gt.shape
        patch_size = 256
        rand_h = random.randint(0, h-patch_size)
        rand_w = random.randint(0, w-patch_size)
        patch = gt[rand_h:rand_h+patch_size, rand_w:rand_w+patch_size, :]

        # degrad
        inp = patch[0::2, 0::2, :]

        # to tensor
        gt = torch.from_numpy(patch).permute(2, 0, 1).float().div(255.).to(self.device)
        inp = torch.from_numpy(inp).permute(2, 0, 1).float().div(255.).to(self.device)

        return {'gt': gt, 'inp': inp}
    
    def __len__(self):
        return len(self.gt_path)
    
DD = Dataset_test(data_path=data_path)
dloader = DataLoader(DD, batch_size=4, shuffle=True, num_workers=0)

class my_net(nn.Module):
    def __init__(self, in_chs, out_chs, base_chs):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_chs, out_channels=base_chs, kernel_size=7, padding=3)
        self.conv2 = nn.Conv2d(in_channels=base_chs, out_channels=base_chs, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=base_chs, out_channels=base_chs, kernel_size=3, padding=1)
        # self.transconv = nn.ConvTranspose2d()
        self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)
        self.conv_out = nn.Conv2d(in_channels=base_chs//4, out_channels=out_chs, kernel_size=3, padding=1)
        
    def forward(self, data):
        x1 = self.conv1(data)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x4 = self.pixelshuffle(x3)
        x_out = self.conv_out(x4)
        return x_out

net_su = my_net(3, 3, 96).to(device=device)
optimizer = torch.optim.Adam(net_su.parameters(), lr=1e-4)
loss = torch.nn.MSELoss().to(device=device)
num_epochs = 100
for epoch in range(num_epochs):
    for iter, data in enumerate(dloader):
        gt = train_data['gt']
        inp = train_data['inp']
        out = net_su(inp)

        loss_val = loss(gt, out)
        loss_val.backward()

        optimizer.step()
        optimizer.zero_grad()

        with torch.no_grad():
            print('epochs: {}, iters: {}, loss: {}'.format(epoch, iter, loss_val.cpu().item()))

epochs: 0, iters: 0, loss: 0.29264217615127563
epochs: 0, iters: 1, loss: 0.26108479499816895
epochs: 0, iters: 2, loss: 0.23161998391151428
epochs: 0, iters: 3, loss: 0.20329022407531738
epochs: 0, iters: 4, loss: 0.1755642294883728
epochs: 1, iters: 0, loss: 0.14819592237472534
epochs: 1, iters: 1, loss: 0.12129150331020355
epochs: 1, iters: 2, loss: 0.09542578458786011
epochs: 1, iters: 3, loss: 0.07169298827648163
epochs: 1, iters: 4, loss: 0.05174825340509415
epochs: 2, iters: 0, loss: 0.03773608058691025
epochs: 2, iters: 1, loss: 0.03183145821094513
epochs: 2, iters: 2, loss: 0.03487172722816467
epochs: 2, iters: 3, loss: 0.04420260712504387
epochs: 2, iters: 4, loss: 0.05359891057014465
epochs: 3, iters: 0, loss: 0.05770944058895111
epochs: 3, iters: 1, loss: 0.05547101050615311
epochs: 3, iters: 2, loss: 0.04901361092925072
epochs: 3, iters: 3, loss: 0.04137710854411125
epochs: 3, iters: 4, loss: 0.03493116423487663
epochs: 4, iters: 0, loss: 0.030894950032234192
epochs: 4, it