In [1]:
import numpy as np
import random
import time
import cv2
import os

from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from torch import Tensor
from torch.autograd import Variable
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torch

%run Models.ipynb
%run LossFunctions.ipynb
%run DataLoader.ipynb

Random Seed:  5845


# Start Training

### Step 0: 宣告訓練參數

In [2]:
## Optimiser settings
LR = 1e-4          # Adam learning rate
Beta1 = 0.9        # Adam beta1

## Data loading
BatchSize = 32     # samples per optimisation step
Workers = 4        # DataLoader worker processes

## Hardware / output
NGpu = 1           # passed to the Generator / Discriminator constructors
Model_path = "/home/mj/HardDisk/Github/Image_Compressor/Model/"  # directory handed to save_checkpoint

### Step 1: 讀取訓練資料（受損 / 原圖）

In [3]:
# Root directory of the training set (the original note called it the degraded-image path).
# `Image` is defined in DataLoader.ipynb; the training loop indexes each batch by 'Dcp' and 'Ori'.
train_path = "/home/mj/HardDisk/Github/Image_Compressor/Dataset/Training_Data"
# valid_path = "/home/mj/HardDisk/Github/Image_Compressor/Dataset/Validation_Data"

train_data_loader = DataLoader(
    Image(train_path),
    batch_size=BatchSize,
    shuffle=True,          # reshuffle the samples every epoch
    num_workers=Workers,
    pin_memory=True,       # page-locked host buffers for faster host-to-GPU copies
)
# valid_data_loader = DataLoader(Image(valid_path), batch_size=BatchSize, shuffle=True, num_workers=Workers, pin_memory=True)

### Step 2: 讀取 Models

In [4]:
print("[*] Loading Models")

# Prefer the first CUDA device; fall back to the CPU when no GPU is visible.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Module.to() and Module.apply() both return the module itself, so each network is
# built, moved to `device`, and initialised with `weights_init` in a single chain.
netG = Generator(NGpu).to(device).apply(weights_init)
print(netG)

netD = Discriminator(NGpu).to(device).apply(weights_init)
print(netD)

[*] Loading Models
Generator(
  (Conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (Conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (Conv3): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (ReLU): ReLU()
  (BN): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (Tanh): Tanh()
  (RUnit1): Residual_Unit_GN(
    (Conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (Conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (BN1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (BN2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (ReLU): ReLU()
  )
  (RUnit2): Residual_Unit_GN(
    (Conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (Conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (BN1): BatchNorm2d(64, eps=1e-05, 

### Step 3: 初始化參數配置

In [5]:
print("[*] Initialized Parameters")
# One Adam optimizer per network, betas = (Beta1, 0.999).
optimizer_G = torch.optim.Adam(netG.parameters(), lr=LR, betas=(Beta1, 0.999))
# optimizer_D must own the DISCRIMINATOR's parameters. Wrapping netG's parameters here
# would make every discriminator step update the generator while netD never trains.
optimizer_D = torch.optim.Adam(netD.parameters(), lr=LR, betas=(Beta1, 0.999))
# Summed (not averaged) squared error: the loss value scales with batch size and image area.
# torch.nn is used explicitly because this notebook never imports `nn` itself.
MSELoss = torch.nn.MSELoss(reduction='sum').to(device)

[*] Initialized Parameters


### Step 4: 先稍微訓練 Generative Network

In [None]:
print("[*] Start Training Generative Network")
num_epochs = 10
# A checkpoint is written for EVERY batch whose generator loss is at or below this value
# (not once per epoch). Because MSELoss uses reduction='sum', the loss grows with the batch
# size and image area, so this threshold must be retuned if either changes.
SAVE_LOSS_THRESHOLD = 2000

# Make sure BatchNorm layers use batch statistics, even if an evaluation cell switched netG to eval mode.
netG.train()
for epoch in range(num_epochs):
    for i, data in enumerate(train_data_loader, 0):
        # Move the (decompressed, original) image pair to the training device as float tensors.
        dcp = data['Dcp'].to(device, dtype=torch.float, non_blocking=True)
        ori = data['Ori'].to(device, dtype=torch.float, non_blocking=True)

        # Baseline for comparison: error of the un-restored input against the original.
        with torch.no_grad():
            target_loss = MSELoss(dcp, ori)

        # Reset the gradients, then take one generator step on the pixel-wise MSE.
        optimizer_G.zero_grad()
        # NOTE(review): scaling by 255 presumably maps a normalised generator output (the printed
        # Generator contains a Tanh) onto raw 0-255 pixel values in `ori`; confirm.
        gen = netG(dcp) * 255
        loss_MSE = MSELoss(gen, ori)
        loss_MSE.backward()
        optimizer_G.step()

        # Convert to a Python float once; it is used both for logging and for the checkpoint test.
        loss_value = loss_MSE.item()
        print("Epoch: {}/{} Batch: {}/{} G Loss: {} Target Loss: {}".format(epoch+1, num_epochs, i+1, len(train_data_loader), loss_value, target_loss.item()), end="\r")

        # Checkpoint this batch if it reached the threshold; save_checkpoint receives the
        # 1-based epoch and the 0-based batch index.
        if loss_value <= SAVE_LOSS_THRESHOLD:
            state = {
                'epoch': epoch+1,
                'netG': netG.state_dict(),
                'netD': netD.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
            }
            save_checkpoint(state, epoch+1, i, Model_path)

[*] Start Training Generative Network
Epoch: 2/10 Batch: 59409/63876 G Loss: 320432.8125 Target Loss: 278304.000

### Step 5: 加入 Discriminative Network 一起訓練，儲存權重

In [None]:
# NOTE(review): this stacked (adversarial) training cell is disabled. Before re-enabling it:
#   - `num_epochs` is only defined in the Step 4 cell; run that cell first or define it here.
#   - loss_ADV_G is computed from netD(gen.detach()), so the adversarial term sends no
#     gradient to netG; re-evaluate netD(gen) without .detach() for the generator step.
#   - confirm optimizer_D was built from netD.parameters() (not netG.parameters()).
#   - here gen = netG(dcp) is not scaled by 255, unlike the Step 4 and validation cells.
# print("[*] Training Stacked Network")
# for epoch in range(num_epochs):
#     for i, data in enumerate(train_data_loader, 0):
#         # Fetch the decompressed images and their original (ground-truth) counterparts
#         dcp, ori = data['Dcp'].to(device, dtype=torch.float, non_blocking=True), data['Ori'].to(device, dtype=torch.float, non_blocking=True)
        
#         #####################
#         ## Train the Discriminator
#         ####################
#         # Reset the gradients of all variables
#         optimizer_D.zero_grad()
#         gen = netG(dcp)
        
#         # Compute the adversarial losses
#         loss_ADV_G, loss_ADV_D = ADVLoss(netD(ori), netD(gen.detach()))
        
#         # Update the Discriminator parameters
#         loss_ADV_D.backward(retain_graph=True)
#         optimizer_D.step()
        
        
#         #################
#         ## Train the Generator
#         ################
#         # Reset the gradients of all variables
#         optimizer_G.zero_grad()
        
#         # Compute the Generator's combined loss (MSE + adversarial)
#         loss_MSE = MSELoss(gen, ori)
#         loss_Mix = loss_MSE + loss_ADV_G
        
#         # Update the Generator parameters
#         loss_Mix.backward()
#         optimizer_G.step()
        
        
#         print("Epoch: {}/{} Batch: {}/{} G Loss: {} D Loss: {}".format(epoch+1, num_epochs, i+1, len(train_data_loader), loss_Mix.item(), loss_ADV_D.item()), end="\r")


### Step 6: 儲存 Final 權重

In [None]:
class Test(Dataset):
    """Validation dataset pairing a JPEG-decompressed image with its original PNG.

    Expects ``root_dir/JPG/<name>.jpg`` and ``root_dir/Original/<name>.png``. Both images
    are read as single-channel grayscale and returned with a leading channel axis, i.e.
    ``{'Dcp': (1, H, W) array, 'Ori': (1, H, W) array}``.

    Args:
        root_dir: dataset root containing the ``JPG`` and ``Original`` folders.
        limit: keep only the first ``limit`` files (in sorted order); ``None`` keeps all.
            Defaults to 1, matching the original quick-check behaviour.
    """

    def __init__(self, root_dir, limit=1):
        self.dcp_dir = os.path.join(root_dir, "JPG")
        self.ori_dir = os.path.join(root_dir, "Original")
        # os.listdir order is arbitrary; sort so the subset chosen by `limit` is reproducible.
        names = sorted(os.listdir(self.dcp_dir))
        self.file_names = names if limit is None else names[:limit]

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # splitext only strips the final extension, so names containing dots survive intact.
        file_name = os.path.splitext(self.file_names[idx])[0]
        img_dcp = self._read_gray(os.path.join(self.dcp_dir, "{}.jpg".format(file_name)))
        img_ori = self._read_gray(os.path.join(self.ori_dir, "{}.png".format(file_name)))
        return {'Dcp': img_dcp, 'Ori': img_ori}

    @staticmethod
    def _read_gray(path):
        """Read `path` as grayscale and return it with shape (1, H, W); raise if unreadable."""
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals a missing/unreadable file by returning None rather than raising.
        if img is None:
            raise FileNotFoundError("Could not read image: {}".format(path))
        return img.reshape(1, *img.shape)

In [None]:
valid_path = "/home/mj/HardDisk/Github/Image_Compressor/Dataset/Validation_Data"  ## validation set root (contains JPG/ and Original/)
# Evaluation order does not matter, so keep it deterministic.
valid_data_loader = DataLoader(Test(valid_path), batch_size=1, shuffle=False, num_workers=1, pin_memory=True)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
netG = Generator(1).to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine (and vice versa).
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
state = torch.load('/home/mj/HardDisk/Github/Image_Compressor/Model/Epoch_1_10000.pth', map_location=device)
netG.load_state_dict(state['netG'])
# Evaluation mode: BatchNorm uses its running statistics instead of the size-1 batch's,
# and does not overwrite those statistics while validating.
netG.eval()

# torch.nn is used explicitly because this notebook never imports `nn` itself.
MSELoss = torch.nn.MSELoss(reduction='sum').to(device)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for data in valid_data_loader:
        dcp = data['Dcp'].to(device, dtype=torch.float, non_blocking=True)
        ori = data['Ori'].to(device, dtype=torch.float, non_blocking=True)

        gen = netG(dcp) * 255
        loss_MSE = MSELoss(gen, ori)
        print(loss_MSE.item())



# # cv2.imshow("Img", dep_img)
# # cv2.waitKey(0)
# # cv2.destroyAllWindows()

In [None]:
# state = torch.load('/home/mj/HardDisk/Github/Image_Compressor/Model/Epoch_2.pth')
# print(state)

In [None]:
# dcp = cv2.imread("/home/mj/HardDisk/Github/Image_Compressor/Dataset/Validation_Data/Decompressed/00000001.pgm", 0)
# ori = cv2.imread("/home/mj/HardDisk/Github/Image_Compressor/Dataset/Validation_Data/Original_pgm/00000001.pgm", 0)

# print(dcp)
# print(ori)

# # cv2.imshow("DEP", dcp)
# # cv2.imshow("ORI", ori)

# # cv2.waitKey(0)
# # cv2.destroyAllWindows()