#### 导入库

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import torchvision.models as models
from Net import *
from Loss import *
from DataLoader import *
from torch.utils.data import DataLoader
import time

#### 基本参数

In [2]:
size=256
train_batch_size = 4
start_epochs = 0
learning_rate = 0.0002
# 总共训练200个epoch
num_epochs = 5
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
save_point = 5
# 1:分解模型
model_choose = 4

#### 损失函数

In [3]:
consLoss = nn.MSELoss()
recLoss = nn.MSELoss()
colorLoss = nn.MSELoss()
hazeLoss = nn.MSELoss()
# structure-aware TV loss
smoothLoss = TVLoss()

#### 数据缓存入内存，加快读入速度

In [4]:
class DataPrefetcher():

    def __init__(self, loader):
        self.loader = iter(loader)
        self.preload()

    def preload(self):
        try:
            self.batch = next(self.loader)
        except StopIteration:
            self.batch = None
            return

    def next(self):
        batch = self.batch
        self.preload()
        return batch

#### 模型1_分解模型

In [5]:
def train_1(start_epoch):
    print("模型导入中")
    model = Retinex_Decomposition_net().to(device)
    if start_epoch != 0:
        model_path = './checkpoints/Retinex_Decomposition_net/epoch_' + str(start_epoch) + '.pth'
        model.load_state_dict(torch.load(model_path))
    print("模型导入完成")
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    total_loss = 0
    for epoch in range(start_epoch+1, num_epochs+1):
        print("epoch: ", epoch)
        L_no_light_path = r"/data/underwater/UIALN/Synthetic_dataset/dataset_no_AL"
        L_light_path = r"/data/underwater/UIALN/Synthetic_dataset/dataset_with_AL/train"
        dataset = retinex_decomposition_data(L_no_light_path, L_light_path)
        train_loader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True, num_workers=0)
        start_time = time.time()
        prefetcher = DataPrefetcher(train_loader)
        batch = prefetcher.next()
        i = 0
        epoch_loss = 0
        while batch is not None:
            i += 1
            L_no_light = batch[0].to(device)
            L_light = batch[1].to(device)
            L_no_light_hat = model(L_no_light)
            # 每个batch中的第一个是I_no_light_hat，第二个是R_no_light_hat，它们的shape都是[batch_size, 1, 256, 256]，batch不改变
            I_no_light_hat, R_no_light_hat = torch.split(L_no_light_hat, 1, dim=1)
            L_light_hat = model(L_light)
            I_light_hat, R_light_hat = torch.split(L_light_hat, 1, dim=1)
            loss_1 = consLoss(R_light_hat, R_no_light_hat)
            loss_2_1 = recLoss(I_light_hat*R_light_hat, L_light)
            loss_2_2 = recLoss(I_no_light_hat*R_no_light_hat, L_no_light)
            loss_3 = smoothLoss(I_light_hat, R_light_hat)
            loss_4 = smoothLoss(I_no_light_hat, R_no_light_hat)
            loss = loss_1 + loss_2_1 + loss_2_2 + loss_3 + loss_4
            epoch_loss += loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch = prefetcher.next()
        if (epoch + 1) % save_point == 0:
            state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            torch.save(state, './checkpoints/Retinex_Decomposition_net/epoch_' + str(epoch) + '.pth')
        time_epoch = time.time() - start_time
        epoch_loss = epoch_loss*1.0/i
        total_loss += epoch_loss
        print("==>No: {} epoch, time: {:.2f}, loss: {:.5f}".format(epoch, time_epoch / 60, epoch_loss))
        with open("output.txt", "a") as f:
            f.write("==>No: {} epoch, time: {:.2f}, loss: {:.5f}\n".format(epoch, time_epoch / 60, epoch_loss))
    print("total_loss:",total_loss*1.0/num_epochs-start_epochs)

#### 模型2_光照校正模型

In [6]:
def train_2(start_epoch):
    print("模型导入")
    # 前置模型
    model_1 = Retinex_Decomposition_net().to(device)
    model1_path = './save_model/Retinex_Light_Correction_net.pth'
    model_1.load_state_dict(torch.load(model1_path)['model'])
    # 后置模型
    model_2 = Illumination_Correction().to(device)
    if start_epoch != 0:
        model2_path = './checkpoints/Illumination_Correction/epoch_' + str(start_epoch) + '.pth'
        model_2.load_state_dict(torch.load(model2_path))
    print("模型导入完成")
    model_1.eval()
    model_2.train()
    optimizer = torch.optim.Adam(model_2.parameters(), lr=learning_rate)
    total_loss = 0
    for epoch in range(start_epoch+1, num_epochs+1):
        print("epoch: ", epoch)
        L_no_light_path = r"/data/underwater/UIALN/Synthetic_dataset/dataset_no_AL"
        L_light_path = r"/data/underwater/UIALN/Synthetic_dataset/dataset_with_AL/train"
        dataset = retinex_decomposition_data(L_no_light_path, L_light_path)
        train_loader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True, num_workers=0)
        start_time = time.time()
        prefetcher = DataPrefetcher(train_loader)
        batch = prefetcher.next()
        i = 0
        epoch_loss = 0
        while batch is not None:
            i+=1
            L_no_light = batch[0].to(device)
            L_light = batch[1].to(device)
            temp = model_1(L_light)
            I_light, R_light = torch.split(temp, 1, dim=1)
            temp = model_1(L_no_light)
            I_no_light, R_no_light = torch.split(temp, 1, dim=1)
            I_delight_hat = model_2(torch.cat((I_light, R_light), dim=1))
            # 感觉论文这里有点问题，之后问一下
            loss_1 = recLoss(I_delight_hat*R_light, L_no_light)
            loss_2 = consLoss(R_light, R_no_light)
            loss = loss_1 + loss_2
            epoch_loss += loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch = prefetcher.next()
        if (epoch + 1) % save_point == 0:
            state = {'model': model_2.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            torch.save(state, './checkpoints/Illumination_Correction/epoch_' + str(epoch) + '.pth')
        time_epoch = time.time() - start_time
        epoch_loss = epoch_loss*1.0/i
        total_loss += epoch_loss
        print("==>No: {} epoch, time: {:.2f}, loss: {:.5f}".format(epoch, time_epoch / 60, epoch_loss))
        with open("output.txt", "a") as f:
            f.write("==>No: {} epoch, time: {:.2f}, loss: {:.5f}\n".format(epoch, time_epoch / 60, epoch_loss))
    print("total_loss:",total_loss*1.0/num_epochs-start_epochs)


#### 模型3_AL区域自导向色彩恢复模块

In [7]:
def train_3(start_epoch):
    print("模型导入")
    # 前置双模型
    model_1 = Retinex_Decomposition_net().to(device)
    model1_path = './save_model/Retinex_Light_Correction_net.pth'
    model_1.load_state_dict(torch.load(model1_path)['model'])
    model_2 = Illumination_Correction().to(device)
    model2_path = './save_model/Illumination_Correction_net.pth'
    model_2.load_state_dict(torch.load(model2_path)['model'])
    # 后置模型
    model_3 = AL_Area_Selfguidance_Color_Correction().to(device)
    if start_epoch != 0:
        model3_path = './checkpoints/AL_Area_Selfguidance_Color_Correction/epoch_' + str(start_epoch) + '.pth'
        model_3.load_state_dict(torch.load(model3_path))
    print("模型导入完成")
    model_1.eval()
    model_2.eval()
    model_3.train()
    optimizer = torch.optim.Adam(model_3.parameters(), lr=learning_rate)
    total_loss = 0
    for epoch in range(start_epoch+1, num_epochs+1):
        print("epoch: ", epoch)
        # ABcc_path = r"./dataset/UIALN_datasest/train_data/dataset_with_AL/train"
        # gt_path = r"./dataset/UIALN_datasest/train_data/labels/raw"
        ABcc_path = r"/data/underwater/UIEB-EUVP-LSUI2/train/input"
        gt_path = r"/data/underwater/UIEB-EUVP-LSUI2/train/target"
        dataset = AL_data(ABcc_path, gt_path, size=size)
        train_loader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True, num_workers=0)
        start_time = time.time()
        prefetcher = DataPrefetcher(train_loader)
        batch = prefetcher.next()
        i = 0
        epoch_loss = 0
        while batch is not None:
            i+=1
            ABcc = batch[0].to(device)
            gt = batch[1].to(device)
            L = batch[2].to(device)
            temp = model_1(L)
            I_light, R_light = torch.split(temp, 1, dim=1)
            I_delight = model_2(temp)
            M_image = I_light - I_delight
            ABcc_hat = model_3(M_image, ABcc)
            loss = colorLoss(ABcc_hat, gt)
            epoch_loss += loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch = prefetcher.next()
        if (epoch + 1) % save_point == 0:
            state = {'model': model_3.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            torch.save(state, './checkpoints/AL_Area_Selfguidance_Color_Correction/epoch_' + str(epoch) + '.pth')
        time_epoch = time.time() - start_time
        epoch_loss = epoch_loss*1.0/i
        total_loss += epoch_loss
        print("==>No: {} epoch, time: {:.2f}, loss: {:.5f}".format(epoch, time_epoch / 60, epoch_loss))
        with open("output.txt", "a") as f:
            f.write("==>No: {} epoch, time: {:.2f}, loss: {:.5f}\n".format(epoch, time_epoch / 60, epoch_loss))
    print("total_loss:",total_loss*1.0/num_epochs-start_epochs)

In [8]:
def train_4(start_epoch):
    print("模型导入")
    # 前置模型
    model_1 = Retinex_Decomposition_net().to(device)
    model1_path = './save_model/Retinex_Light_Correction_net.pth'
    model_1.load_state_dict(torch.load(model1_path)['model'])
    model_2 = Illumination_Correction().to(device)
    model2_path = './save_model/Illumination_Correction_net.pth'
    model_2.load_state_dict(torch.load(model2_path)['model'])
    model_3 = AL_Area_Selfguidance_Color_Correction().to(device)
    model3_path = './save_model/AL_Area_Selfguidance_Color_Correction_net.pth'
    model_3.load_state_dict(torch.load(model3_path)['model'])
    
    # 后置模型
    model_4 = Detail_Enhancement().to(device)
    model_fusion = Channels_Fusion().to(device)
    if start_epoch != 0:
        model4_path = './checkpoints/Detail_Enhancement/epoch_' + str(start_epoch) + '.pth'
        model_4.load_state_dict(torch.load(model4_path))
        model_fusion_path = './checkpoints/Channels_Fusion/epoch_' + str(start_epoch) + '.pth'
        model_fusion.load_state_dict(torch.load(model_fusion_path))
    print("模型导入完成")
    model_1.eval()
    model_2.eval()
    model_3.eval()
    model_4.train()
    model_fusion.train()
    
    optimizer_4 = torch.optim.Adam(model_4.parameters(), lr=learning_rate)
    optimizer_fusion = torch.optim.Adam(model_fusion.parameters(), lr=learning_rate)
    total_loss = 0
    for epoch in range(start_epoch+1, num_epochs+1):
        print("epoch: ", epoch)
        # ABcc_path = r"./dataset/UIALN_datasest/train_data/dataset_with_AL/train"
        # gt_path = r"./dataset/UIALN_datasest/train_data/labels/raw"
        ABcc_path = r"/data/underwater/UIEB-EUVP-LSUI2/train/input"
        gt_path = r"/data/underwater/UIEB-EUVP-LSUI2/train/target"
        dataset = Detail_Enhancement_data(ABcc_path, gt_path, size=size)
        train_loader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True, num_workers=0)
        start_time = time.time()
        prefetcher = DataPrefetcher(train_loader)
        batch = prefetcher.next()
        i = 0
        epoch_loss = 0
        while batch is not None:
            i+=1
            ABcc = batch[0].to(device)
            L = batch[1].to(device)
            gt_L_tensor = batch[2].to(device)
            gt = batch[3].to(device)
            
            temp = model_1(L)
            I_light, R_light = torch.split(temp, 1, dim=1)
            I_delight = model_2(temp)
            M_image = I_light - I_delight
            ABcc = model_3(M_image, ABcc)
            L_delight = I_delight * R_light
            
            L_en_hat = model_4(L_delight)   # enhanced L
            LAB_hat = torch.cat((L_en_hat, ABcc), dim=1)
            LAB_hat = model_fusion(LAB_hat)
            
            loss_haze = hazeLoss(gt_L_tensor, L_en_hat)
            loss_recons = recLoss(gt, LAB_hat)
            final_loss = loss_haze + loss_recons
            epoch_loss += final_loss
            
            optimizer_fusion.zero_grad()
            optimizer_4.zero_grad()
            final_loss.backward()
            optimizer_fusion.step()
            
            
            # final_loss.backward()
            optimizer_4.step()                
            
            batch = prefetcher.next()
        if epoch % save_point == 0:
            state = {'model': model_4.state_dict(), 'optimizer': optimizer_4.state_dict(), 'epoch': epoch}
            torch.save(state, './checkpoints/Detail_Enhancement/epoch_' + str(epoch) + '.pth')
            state = {'model': model_fusion.state_dict(), 'optimizer': optimizer_fusion.state_dict(), 'epoch': epoch}
            torch.save(state, './checkpoints/Channels_Fusion/epoch_' + str(epoch) + '.pth')
        time_epoch = time.time() - start_time
        epoch_loss = epoch_loss*1.0/i
        total_loss += epoch_loss
        print("==>No: {} epoch, time: {:.2f}, loss: {:.5f}".format(epoch, time_epoch / 60, epoch_loss))
        with open("output.txt", "a") as f:
            f.write("==>No: {} epoch, time: {:.2f}, loss: {:.5f}\n".format(epoch, time_epoch / 60, epoch_loss))
    print("total_loss:",total_loss*1.0/num_epochs-start_epochs)

#### 主函数-判定训练哪个模型

In [9]:
if __name__ == '__main__':
    print(torch.cuda.is_available())
    if model_choose == 1:
        train_1(start_epochs)
    elif model_choose == 2:
        train_2(start_epochs)
    elif model_choose == 3:
        train_3(start_epochs)
    elif model_choose == 4:
        train_4(start_epochs)
    else:
        print("model_choose error")

True
模型导入
模型导入完成
epoch:  1
==>No: 1 epoch, time: 4.79, loss: 0.05011
epoch:  2
==>No: 2 epoch, time: 4.81, loss: 0.03047
epoch:  3
==>No: 3 epoch, time: 4.85, loss: 0.02686
epoch:  4
==>No: 4 epoch, time: 4.82, loss: 0.02421
epoch:  5


RuntimeError: Parent directory ./checkpoints/Detail_Enhancement does not exist.