In [1]:
### -*- encoding: utf-8 -*-
import IPython.display as display
import pandas as pd
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from models import ConvolutionalBlock, ResidualBlock
from torch import nn
import hickle as hkl
import matplotlib
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

torch.cuda.set_device(3)
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
device = torch.device("cuda")

class testDataset(Dataset):
    def __init__(self):
        lo, hi, distance_chrome = hkl.load('./data/GM12878/test_data_half.hkl')
        lo = lo.squeeze()
        hi = hi.squeeze()
        lo = np.expand_dims(lo,axis=1)
        hi = np.expand_dims(hi,axis=1)
        self.sample_list = []
        for i in range(len(lo)):
            lr = lo[i]
            hr = hi[i]
            dist = abs(distance_chrome[i][0])
            label_one_hot = torch.zeros(5)
            label_one_hot[int(dist/40)]=1
            chrom = distance_chrome[i][1]
            self.sample_list.append([lr, hr, label_one_hot, dist, chrom])
        print("dataset loaded : " + str(len(lo)) + '*' + str(len(lo[0])) + '*' + str(len(lo[0][0])) + '*' + str(len(lo[0][0][0])))
    def __getitem__(self, i):
        (lr_img, hr_img, label_one_hot, distance, chromosome) = self.sample_list[i]
        return lr_img, hr_img, label_one_hot, distance, chromosome
    def __len__(self):
        return len(self.sample_list)
    
class Generator(nn.Module):
    def __init__(self, kernel_size=3, n_channels=64, n_blocks=5):
        super(Generator, self).__init__()
        self.conv_block1 = ConvolutionalBlock(in_channels=6, out_channels=n_channels, kernel_size=kernel_size,
                                              batch_norm=False, activation='relu')
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(kernel_size=kernel_size, n_channels=n_channels) for i in range(n_blocks)])
        self.conv_block2 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels,
                                              kernel_size=kernel_size,
                                              batch_norm=True, activation=None)
        self.conv_block3 = ConvolutionalBlock(in_channels=n_channels, out_channels=128, kernel_size=kernel_size,
                                              batch_norm=False, activation=None)
        self.conv_block4 = ConvolutionalBlock(in_channels=128, out_channels=256, kernel_size=kernel_size,
                                              batch_norm=False, activation=None)
        self.conv_block5 = ConvolutionalBlock(in_channels=256, out_channels=1, kernel_size=1,
                                              batch_norm=False, activation='tanh')
    def forward(self, lr_imgs):
        output = self.conv_block1(lr_imgs)  # (batch_size, 1, 40, 40)
        residual = output
        output = self.residual_blocks(output)
        output = self.conv_block2(output)
        output = output + residual
        output = self.conv_block3(output)
        output = self.conv_block4(output)
        sr_imgs = self.conv_block5(output)
        return sr_imgs

def make_input(imgs, distances): #imgs batchsize*1*40*40     distances batchsize*5
    dis = distances.unsqueeze(2).unsqueeze(3)
    dis = dis.repeat(1,1,40,40)
    data_input = torch.cat((imgs,dis),1)
    return data_input

In [2]:
test_dataset = testDataset()
test_loader = torch.utils.data.DataLoader(test_dataset,
                                               batch_size=1,
                                               shuffle=False,
                                               num_workers=1,
                                               pin_memory=True)
generator = Generator(kernel_size=3,n_channels=64,n_blocks=5)
generator = generator.to(device)
mse_loss_criterion = nn.MSELoss()
mse_loss_criterion = mse_loss_criterion.to(device)

dataset loaded : 3173*1*40*40


In [3]:
num_to_test = 20
df = pd.read_excel("./log_cgan4.xls", usecols=[0, 1])
df = df.sort_values(by='MSE损失',inplace=False,ascending=True).head(num_to_test)
mse_list_train = df["MSE损失"].values.tolist()
epoch_list = df["epoch"].values.tolist()
list_ave = []
list_0 = []
list_40 = []
list_80 = []
list_120 = []
list_160 = []
cnt = 0
for epoch in epoch_list :
    model_path = "./result_cgan4/best_checkpoint_epoch" + str(epoch).zfill(4) + ".pth"
    if os.path.exists(model_path):
        cnt = cnt + 1
    else:
        break
    checkpoint = torch.load(model_path)
    generator.load_state_dict(checkpoint['generator'])
    generator = generator.eval()

    ave_mse = []
    MSE_0 = []
    MSE_40 = []
    MSE_80 = []
    MSE_120 = []
    MSE_160 = []
    ave_psnr = []
    PSNR_0 = []
    PSNR_40 = []
    PSNR_80 = []
    PSNR_120 = []
    PSNR_160 = []
    ave_ssim = []
    SSIM_0 = []
    SSIM_40 = []
    SSIM_80 = []
    SSIM_120 = []
    SSIM_160 = []
    for i, (lr_img, hr_img, label, distance, chrom) in enumerate(test_loader):
        display.clear_output(wait=True)
        print("***开始测试***")
        print("正在测试模型：" + str(cnt) + "/" + str(num_to_test))
        print("正在处理样本 ： " + str(i+1) + "/"+str(len(test_dataset)))
        print("MSE_AVG = " + str(np.mean(ave_mse))  + "  PSNR_AVG = " + str(np.mean(ave_psnr))+ "  SSIM_AVG = " + str(np.mean(ave_ssim)))
        print("MSE_0   = " + str(np.mean(MSE_0))  + "  PSNR_0   = " + str(np.mean(PSNR_0))+ "  SSIM_0   = " + str(np.mean(SSIM_0)))
        print("MSE_40  = " + str(np.mean(MSE_40)) + "  PSNR_40  = " + str(np.mean(PSNR_40))+ "  SSIM_40  = " + str(np.mean(SSIM_40)))
        print("MSE_80  = " + str(np.mean(MSE_80)) + "  PSNR_80  = " + str(np.mean(PSNR_80))+ "  SSIM_80  = " + str(np.mean(SSIM_80)))
        print("MSE_120 = " + str(np.mean(MSE_120))+ "  PSNR_120 = " + str(np.mean(PSNR_120))+ "  SSIM_120 = " + str(np.mean(SSIM_120)))
        print("MSE_160 = " + str(np.mean(MSE_160))+ "  PSNR_160 = " + str(np.mean(PSNR_160))+ "  SSIM_160 = " + str(np.mean(SSIM_160)))
        lr_img = lr_img.type(torch.FloatTensor).to(device)
        hr_img = hr_img.type(torch.FloatTensor).to(device)
        label = label.to(device)
        G_input = make_input(lr_img, label)
        with torch.no_grad():
            sr_img = generator(G_input.detach())
        mse = mse_loss_criterion(sr_img , hr_img).to('cpu')
        sr_img = sr_img.squeeze().to("cpu").numpy()
        hr_img = hr_img.squeeze().to("cpu").numpy()
        
        psnr = peak_signal_noise_ratio(hr_img, sr_img)
        ssim = structural_similarity(hr_img, sr_img)
    
        ave_mse.append(mse)
        ave_psnr.append(psnr)
        ave_ssim.append(ssim)
        if abs(distance) == 0:
            MSE_0.append(mse)
            PSNR_0.append(psnr)
            SSIM_0.append(ssim)
        elif abs(distance) == 40:
            MSE_40.append(mse)
            PSNR_40.append(psnr)
            SSIM_40.append(ssim)
        elif abs(distance) == 80:
            MSE_80.append(mse)
            PSNR_80.append(psnr)
            SSIM_80.append(ssim)
        elif abs(distance) == 120:
            MSE_120.append(mse)
            PSNR_120.append(psnr)
            SSIM_120.append(ssim)
        elif abs(distance) == 160:
            MSE_160.append(mse)
            PSNR_160.append(psnr)
            SSIM_160.append(ssim)
    list_ave.append(np.mean(ave_mse))
    list_0.append(np.mean(MSE_0))
    list_40.append(np.mean(MSE_40))
    list_80.append(np.mean(MSE_80))
    list_120.append(np.mean(MSE_120))
    list_160.append(np.mean(MSE_160))
print("***测试结束***")


***开始测试***
正在测试模型：20/20
正在处理样本 ： 3173/3173
MSE_AVG = 0.060361207  PSNR_AVG = 20.567508585190602  SSIM_AVG = 0.2835752793499226
MSE_0   = 0.017268976  PSNR_0   = 24.221341487281993  SSIM_0   = 0.5493183538855767
MSE_40  = 0.020786356  PSNR_40  = 23.716214049159284  SSIM_40  = 0.2574529514092742
MSE_80  = 0.031026667  PSNR_80  = 21.480572533032053  SSIM_80  = 0.23161828928022926
MSE_120 = 0.16813369  PSNR_120 = 14.099685757831402  SSIM_120 = 0.21531540741854288
MSE_160 = 0.06862084  PSNR_160 = 18.993057109872762  SSIM_160 = 0.1506556700942406
***测试结束***


In [4]:
data = {"epoch" : epoch_list[0:cnt],
        "mse_train" :mse_list_train[0:cnt],
        "mse_test" : list_ave,
        "mse_0"   : list_0,
        "mse_40"  : list_40,
        "mse_80"  : list_80,
        "mse_120" : list_120,
        "mse_160" : list_160}
df=pd.DataFrame(data)
df.to_excel("./model.xlsx")
df = df.sort_values(by='mse_test',inplace=False,ascending=True).head(num_to_test)
print(df.head(100).to_string(index=False))

 epoch  mse_train  mse_test     mse_0    mse_40    mse_80   mse_120   mse_160
   350   0.023942  0.030828  0.024474  0.025532  0.024071  0.039380  0.041526
   107   0.022577  0.032742  0.023131  0.024741  0.034988  0.023786  0.058327
   262   0.023227  0.034745  0.020599  0.026460  0.021168  0.051144  0.056029
   101   0.022371  0.035308  0.030408  0.020239  0.025251  0.044163  0.057905
   158   0.024053  0.040380  0.039550  0.039290  0.033946  0.049828  0.039456
   256   0.022671  0.040932  0.030568  0.030481  0.030490  0.077777  0.036270
   303   0.022210  0.042930  0.026186  0.058267  0.025590  0.050601  0.054866
   424   0.022793  0.043140  0.033827  0.021147  0.033731  0.078778  0.049717
   221   0.023099  0.044235  0.059389  0.070482  0.023049  0.029547  0.037320
   425   0.022146  0.045731  0.022478  0.043221  0.071793  0.045295  0.046627
    85   0.024414  0.048719  0.032531  0.024722  0.054262  0.058875  0.075285
   264   0.019699  0.053782  0.025173  0.057369  0.052566  0.058

## 