In [43]:
import os
from tqdm import tqdm
# Directory of grayscale validation images that the inference cell lists and reads
# (relative to the notebook's working directory).
data_dir = "dataset/val/image_gray"

In [44]:
import torch.nn as nn
import torch

def _weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    elif isinstance(m, nn.ConvTranspose2d):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.01)
        m.bias.data.zero_()

class Block(nn.Module):
    """One conv stage: 5x5 convolution (stride 1, padding 2) followed by in-place ReLU.

    With stride 1 and padding 2, the 5x5 kernel keeps height and width unchanged.
    The stage is stored as ``self.block`` (an ``nn.Sequential`` whose conv is item 0),
    so checkpoint keys are ``block.0.weight`` / ``block.0.bias``. Keep that layout,
    or existing checkpoints will no longer match.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        conv = nn.Conv2d(in_channel, out_channel, kernel_size=5, stride=1, padding=2)
        self.block = nn.Sequential(conv, nn.ReLU(inplace=True))
        self.apply(_weights_init)

    def forward(self, x):
        """Apply the conv + ReLU stage to ``x``."""
        return self.block(x)

class DCNN(nn.Module):
    """Plain convolutional network with single-channel input and output.

    Layout: stem conv (1 -> 64) + ReLU, then ``block_num - 1`` ``Block(64, 64)``
    stages, then an output conv (64 -> 1) with no activation. Every conv is
    5x5 / stride 1 / padding 2, so height and width are preserved.

    The attribute names ``input``, ``blocks`` and ``out`` are checkpoint key
    prefixes and must not be renamed.
    """

    def __init__(self, block_num):
        super().__init__()
        self.input = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
        )
        # Together with the stem this gives block_num conv+ReLU stages in total.
        self.blocks = nn.ModuleList(Block(64, 64) for _ in range(block_num - 1))
        self.out = nn.Sequential(nn.Conv2d(64, 1, kernel_size=5, stride=1, padding=2))
        self.apply(_weights_init)

    def forward(self, x):
        """Run the stem, each middle Block in order, then the output conv."""
        features = self.input(x)
        for stage in self.blocks:
            features = stage(features)
        return self.out(features)

class RS(nn.Module):
    """Refinement network stacked on two DCNN sub-networks.

    ``IRS`` (a 16-stage DCNN) maps the 1-channel input to ``map1``, and ``ISMP``
    (a 6-stage DCNN) maps ``map1`` to ``map2``. The input, ``map1`` and ``map2``
    are concatenated on the channel axis (3 channels) and passed through a stem
    conv, ``block_num - 1`` Blocks and a 1-channel output conv.

    The attribute names (``IRS``, ``ISMP``, ``input``, ``blocks``, ``out``) are
    checkpoint key prefixes and must not be renamed.
    """

    def __init__(self, block_num):
        super(RS, self).__init__()
        self.IRS = DCNN(16)
        self.ISMP = DCNN(6)
        self.input = nn.Sequential(
            nn.Conv2d(3, 64, 5, 1, 2),
            nn.ReLU(inplace=True),
        )
        self.blocks = nn.ModuleList()
        for _ in range(block_num - 1):
            self.blocks.append(Block(64, 64))
        self.out = nn.Sequential(
            nn.Conv2d(64, 1, 5, 1, 2),
        )
        # This also re-initialises IRS and ISMP, so pretrained weights have to be
        # loaded (initModel) after construction.
        self.apply(_weights_init)

    def initModel(self, device, path='./model/IRS_best_psnr.pkl'):
        """Load pretrained weights into the ``IRS`` sub-network only.

        Args:
            device: forwarded to ``torch.load`` as ``map_location``.
            path: checkpoint file. The default is the previously hard-coded path.
                The checkpoint must be a dict with ``'model'``, ``'epoch'``,
                ``'best_loss'`` and ``'best_psnr'`` entries.
        """
        # torch.load unpickles the file: only load checkpoints from a trusted source.
        checkpoint = torch.load(path, map_location=device)
        result = self.IRS.load_state_dict(checkpoint['model'], strict=False)
        # strict=False silently skips keys that do not match. Report them, so a
        # checkpoint whose keys do not fit (e.g. a different prefix) is not
        # mistaken for a successful load.
        if result.missing_keys or result.unexpected_keys:
            print('IRS load_state_dict: %d missing, %d unexpected keys (first: %s)'
                  % (len(result.missing_keys), len(result.unexpected_keys),
                     (list(result.missing_keys) + list(result.unexpected_keys))[:3]))
        print('IRS model load done, epoch:%d, best_loss:%f, best psnr:%f'
              % (checkpoint['epoch'], checkpoint['best_loss'], checkpoint['best_psnr']))

    def forward(self, x):
        """Return ``(map2, out)``.

        ``x`` must have a single channel: the IRS stem is ``Conv2d(1, ...)``, and
        ``x``, ``map1`` and ``map2`` must concatenate to the 3 channels that
        ``self.input`` expects.
        """
        map1 = self.IRS(x)
        map2 = self.ISMP(map1)
        y = self.input(torch.cat((x, map1, map2), 1))
        for block in self.blocks:
            y = block(y)
        out = self.out(y)
        return map2, out

# Follows RS: the output of RS is used as GRS's input to generate the final image.
class GRS(nn.Module):
    """Final generation network chained after RS.

    The forward pass runs ``RS`` (built with 16 stages) and keeps its second
    output. It passes that through ``GS`` (a 16-stage DCNN), then through a stem
    conv, ``block_num - 1`` Blocks and a 1-channel output conv.

    The attribute names (``RS``, ``GS``, ``input``, ``blocks``, ``out``) are
    checkpoint key prefixes and must not be renamed.
    """

    def __init__(self, block_num):
        super(GRS, self).__init__()
        self.RS = RS(16)
        self.GS = DCNN(16)
        self.input = nn.Sequential(
            nn.Conv2d(1, 64, 5, 1, 2),
            nn.ReLU(inplace=True),
        )
        self.blocks = nn.ModuleList()
        for _ in range(block_num - 1):
            self.blocks.append(Block(64, 64))
        self.out = nn.Sequential(
            nn.Conv2d(64, 1, 5, 1, 2),
        )
        # This also re-initialises RS and GS, so pretrained weights have to be
        # loaded (initModel) after construction.
        self.apply(_weights_init)

    def initModel(self, device, path='./model/RS_best_psnr_ori - 副本.pkl'):
        """Load pretrained weights into the ``RS`` sub-network only.

        Args:
            device: forwarded to ``torch.load`` as ``map_location``.
            path: checkpoint file. The default is the previously hard-coded path.
                The checkpoint must be a dict with ``'model'``, ``'epoch'``,
                ``'best_loss'`` and ``'best_psnr'`` entries.
        """
        # NOTE(review): '副本' means 'copy'. The default file name looks like a
        # duplicated checkpoint; confirm it is the intended one.
        # torch.load unpickles the file: only load checkpoints from a trusted source.
        checkpoint = torch.load(path, map_location=device)
        result = self.RS.load_state_dict(checkpoint['model'], strict=False)
        # strict=False silently skips keys that do not match. Report them, so a
        # checkpoint whose keys do not fit is not mistaken for a successful load.
        if result.missing_keys or result.unexpected_keys:
            print('RS load_state_dict: %d missing, %d unexpected keys (first: %s)'
                  % (len(result.missing_keys), len(result.unexpected_keys),
                     (list(result.missing_keys) + list(result.unexpected_keys))[:3]))
        print('RS model load done, epoch:%d, best_loss:%f, best psnr:%f'
              % (checkpoint['epoch'], checkpoint['best_loss'], checkpoint['best_psnr']))

    def forward(self, x):
        """Run RS, feed its second output through GS, then through this net's own conv trunk."""
        _, out = self.RS(x)
        out = self.GS(out)
        y = self.input(out)
        for block in self.blocks:
            y = block(y)
        out = self.out(y)
        return out



In [45]:
# Compute device used by the model and its inputs: CUDA when PyTorch can see it, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [46]:
import cv2
import numpy as np
from torchvision import transforms
import torch
from PIL import Image
# Input preprocessing. ToTensor turns a PIL image into a float CHW tensor (8-bit
# pixels scaled to [0, 1]). Normalize then computes (x - 0.5) / 0.5, mapping that
# range onto [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [49]:
# Restore the RS model from its checkpoint, run it on one validation image and
# write the histogram-equalised prediction to out.png.
model = RS(16).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
state_dict = torch.load('./model/RS_best_psnr_ori - 副本.pkl', map_location=device)
load_result = model.load_state_dict(state_dict['model'], strict=False)
# strict=False silently skips keys that do not match the model. Report them, so a
# checkpoint that does not fit this architecture is not mistaken for a good load.
if load_result.missing_keys or load_result.unexpected_keys:
    print('RS load_state_dict: missing=%s unexpected=%s'
          % (load_result.missing_keys, load_result.unexpected_keys))
print('RS model load done, epoch:%d, best_loss:%f, best psnr:%f'
      % (state_dict['epoch'], state_dict['best_loss'], state_dict['best_psnr']))
model.eval()
PSNR = []  # NOTE(review): never filled in this cell; presumably meant for per-image PSNR
with torch.no_grad():
    # os.listdir order is filesystem-dependent. Sort it so the [2:3] slice picks
    # the same image on every machine.
    for img in tqdm(sorted(os.listdir(data_dir))[2:3]):
        ht_img_path = os.path.join(data_dir, img)
        # The context manager closes the file handle. convert('L') guarantees the
        # single channel the model's first conv expects.
        with Image.open(ht_img_path) as pil_image:
            image = transform(pil_image.convert('L'))
        image = image.unsqueeze(0)  # add the batch dimension
        map2, out = model(image.to(device))
        # tanh squashes the raw prediction into (-1, 1); rescale it to (0, 255).
        out = (torch.tanh(out) + 1) * 127.5
        out = out.cpu().numpy()  # .cpu() is a no-op for tensors already on the CPU
        for im in out:
            # (1, H, W) -> (H, W, 1) -> (H, W): equalizeHist needs a 2-D uint8 image.
            im = np.squeeze(np.transpose(im, (1, 2, 0)))
            im = cv2.equalizeHist(im.astype(np.uint8))
            # NOTE(review): every image overwrites out.png; fine while only one is selected.
            if not cv2.imwrite('out.png', im):
                raise OSError('cv2.imwrite failed to write out.png')


RS model load done, epoch:395, best_loss:-1.000000, best psnr:30.509077


100%|██████████| 1/1 [00:01<00:00,  1.30s/it]
