In [1]:
# -*- encoding: utf-8 -*-
import IPython.display as display
import pandas as pd
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from models import ConvolutionalBlock, ResidualBlock
from torch import nn
import hickle as hkl
import matplotlib
import matplotlib.pyplot as plt
# Run everything on one explicitly chosen GPU. CUDA_VISIBLE_DEVICES is not used:
# it only takes effect if set before CUDA is initialised, and set_device()
# triggers that initialisation, so the index is passed to torch directly instead.
GPU_ID = 3  # physical index of the GPU to use
torch.cuda.set_device(GPU_ID)
device = torch.device("cuda", GPU_ID)

class testDataset(Dataset):
    def __init__(self):
        lo, hi, distance_chrome = hkl.load('./data/GM12878/test_data_half.hkl')
        lo = lo.squeeze()
        hi = hi.squeeze()
        lo = np.expand_dims(lo,axis=1)
        hi = np.expand_dims(hi,axis=1)
        self.sample_list = []
        for i in range(len(lo)):
            lr = lo[i]
            hr = hi[i]
            dist = abs(distance_chrome[i][0])
            label_one_hot = torch.zeros(5)
            label_one_hot[int(dist/40)]=1
            chrom = distance_chrome[i][1]
            self.sample_list.append([lr, hr, label_one_hot, dist, chrom])
        print("dataset loaded : " + str(len(lo)) + '*' + str(len(lo[0])) + '*' + str(len(lo[0][0])) + '*' + str(len(lo[0][0][0])))
    def __getitem__(self, i):
        (lr_img, hr_img, label_one_hot, distance, chromosome) = self.sample_list[i]
        return lr_img, hr_img, label_one_hot, distance, chromosome
    def __len__(self):
        return len(self.sample_list)
    
class Generator(nn.Module):
    """Residual conv generator: conditioned low-res input -> single-channel tanh output.

    Pipeline: conv_block1 (ReLU) -> n_blocks ResidualBlocks -> conv_block2
    (batch norm, no activation) + skip from conv_block1 -> conv_block3 ->
    conv_block4 -> 1x1 conv_block5 (tanh).
    """

    def __init__(self, kernel_size=3, n_channels=64, n_blocks=5, in_channels=6):
        """
        Args:
            kernel_size: kernel size of conv_block1-4 (conv_block5 is 1x1).
            n_channels: channel width of the residual trunk.
            n_blocks: number of ResidualBlocks in the trunk.
            in_channels: input channels; the default 6 matches make_input, which
                stacks 1 image channel with 5 distance one-hot planes.
        """
        super(Generator, self).__init__()
        # Attribute names are kept stable: they are the state_dict keys of saved checkpoints.
        self.conv_block1 = ConvolutionalBlock(in_channels=in_channels, out_channels=n_channels,
                                              kernel_size=kernel_size,
                                              batch_norm=False, activation='relu')
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(kernel_size=kernel_size, n_channels=n_channels) for _ in range(n_blocks)])
        self.conv_block2 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels,
                                              kernel_size=kernel_size,
                                              batch_norm=True, activation=None)
        self.conv_block3 = ConvolutionalBlock(in_channels=n_channels, out_channels=128, kernel_size=kernel_size,
                                              batch_norm=False, activation=None)
        self.conv_block4 = ConvolutionalBlock(in_channels=128, out_channels=256, kernel_size=kernel_size,
                                              batch_norm=False, activation=None)
        self.conv_block5 = ConvolutionalBlock(in_channels=256, out_channels=1, kernel_size=1,
                                              batch_norm=False, activation='tanh')

    def forward(self, lr_imgs):
        """Map ``lr_imgs`` of shape (batch, in_channels, H, W) to (batch, 1, H', W').

        NOTE(review): H' == H and W' == W only if ConvolutionalBlock uses 'same'
        padding — confirm in models.py.
        """
        output = self.conv_block1(lr_imgs)  # (batch, n_channels, H', W')
        residual = output  # long skip connection around the residual trunk
        output = self.residual_blocks(output)
        output = self.conv_block2(output)
        output = output + residual
        output = self.conv_block3(output)  # n_channels -> 128, activation=None
        output = self.conv_block4(output)  # 128 -> 256, activation=None
        sr_imgs = self.conv_block5(output)  # 256 -> 1, tanh
        return sr_imgs

def make_input(imgs, distances):
    """Append the distance one-hot label to an image batch as constant planes.

    Args:
        imgs: image batch of shape (batch, C, H, W); C is 1 in this notebook.
        distances: labels of shape (batch, K); K is 5 (one-hot) in this notebook.
            Must share dtype and device with ``imgs`` for ``torch.cat``.

    Returns:
        Tensor of shape (batch, C + K, H, W): ``imgs`` followed by K planes,
        plane k filled with ``distances[:, k]``.
    """
    # (batch, K) -> (batch, K, 1, 1) -> broadcast to the image's own spatial
    # size rather than a hard-coded 40x40; expand() avoids a copy before cat.
    planes = distances[:, :, None, None].expand(-1, -1, imgs.size(2), imgs.size(3))
    return torch.cat((imgs, planes), 1)

In [2]:
# Test-set loader: fixed order, one sample per batch.
test_dataset = testDataset()
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,   # keep sample order identical across checkpoints
    num_workers=1,
    pin_memory=True,
)

# Architecture must match the saved checkpoints; weights are loaded per checkpoint.
generator = Generator(kernel_size=3, n_channels=64, n_blocks=5).to(device)
mse_loss_criterion = nn.MSELoss().to(device)

dataset loaded : 3173*1*40*40


In [3]:
# Evaluate the `num_to_test` checkpoints with the lowest training MSE on the
# test set: overall MSE plus MSE per genomic distance bucket.
num_to_test = 200
DISTANCE_BUCKETS = (0, 40, 80, 120, 160)  # distances reported separately
PROGRESS_EVERY = 100  # redraw the progress panel every N samples

train_log = pd.read_excel("./log_cgan.xls", usecols=[0, 1])
train_log = train_log.sort_values(by='MSE损失', ascending=True).head(num_to_test)
candidate_epochs = train_log["epoch"].values.tolist()
candidate_mse_train = train_log["MSE损失"].values.tolist()


def _running_mean(total, count):
    """Return total / count, or NaN when nothing has been accumulated yet."""
    return total / count if count else float("nan")


def _show_progress(model_idx, n_models, sample_idx, n_samples, totals, counts):
    """Clear the cell output and print progress plus the running MSE per bucket."""
    display.clear_output(wait=True)
    print("***开始测试***")  # "*** testing started ***"
    print("正在测试模型：" + str(model_idx) + "/" + str(n_models))  # model being tested
    print("正在处理样本 ： " + str(sample_idx) + "/" + str(n_samples))  # sample being processed
    print("MSE_AVG = " + str(_running_mean(totals["avg"], counts["avg"])))
    for d in DISTANCE_BUCKETS:
        print(("MSE_" + str(d)).ljust(7) + " = " + str(_running_mean(totals[d], counts[d])))


def _evaluate_checkpoint(model, loader, model_idx, n_models):
    """Return {"avg": mean test MSE, d: mean test MSE at distance d for d in DISTANCE_BUCKETS}.

    Buckets with no samples map to NaN. Assumes the loader uses batch_size=1
    (``distance.item()`` needs exactly one value per batch).
    """
    keys = ("avg",) + DISTANCE_BUCKETS
    # Running sums/counts make each progress update O(1) instead of
    # re-averaging the whole history of losses.
    totals = dict.fromkeys(keys, 0.0)
    counts = dict.fromkeys(keys, 0)
    n_samples = len(loader.dataset)
    with torch.no_grad():
        for i, (lr_img, hr_img, label, distance, _chrom) in enumerate(loader):
            # Clearing and redrawing output is slow; refresh only periodically.
            if i % PROGRESS_EVERY == 0:
                _show_progress(model_idx, n_models, i + 1, n_samples, totals, counts)
            lr_img = lr_img.type(torch.FloatTensor).to(device)
            hr_img = hr_img.type(torch.FloatTensor).to(device)
            sr_img = model(make_input(lr_img, label.to(device)))
            mse = mse_loss_criterion(sr_img, hr_img).item()
            totals["avg"] += mse
            counts["avg"] += 1
            dist = abs(distance.item())
            if dist in DISTANCE_BUCKETS:
                totals[dist] += mse
                counts[dist] += 1
    _show_progress(model_idx, n_models, n_samples, n_samples, totals, counts)
    return {key: _running_mean(totals[key], counts[key]) for key in keys}


list_ave, list_0, list_40, list_80, list_120, list_160 = [], [], [], [], [], []
bucket_lists = dict(zip(DISTANCE_BUCKETS, (list_0, list_40, list_80, list_120, list_160)))
# Only epochs whose checkpoint was actually evaluated, so these stay aligned
# with the result lists above.
epoch_list = []
mse_list_train = []
skipped = []
generator = generator.eval()
for cnt, (epoch, mse_train) in enumerate(zip(candidate_epochs, candidate_mse_train), start=1):
    model_path = "./result_cgan/best_checkpoint_epoch" + str(epoch).zfill(4) + ".pth"
    # Candidates are ordered by training loss, not by epoch, so one missing
    # file says nothing about the rest: skip it rather than stop.
    if not os.path.exists(model_path):
        skipped.append(model_path)
        continue
    # map_location keeps tensors off whichever GPU the checkpoint was saved from.
    checkpoint = torch.load(model_path, map_location=device)
    generator.load_state_dict(checkpoint['generator'])
    means = _evaluate_checkpoint(generator, test_loader, cnt, len(candidate_epochs))
    epoch_list.append(epoch)
    mse_list_train.append(mse_train)
    list_ave.append(means["avg"])
    for d, bucket in bucket_lists.items():
        bucket.append(means[d])
print("***测试结束***")  # "*** testing finished ***"
if skipped:
    print("skipped missing checkpoints: " + ", ".join(skipped))


KeyError: 'MSE损失h'

In [None]:
# Tabulate train vs. test MSE for every evaluated checkpoint, save the table,
# and show the checkpoints ranked by average test MSE.
# Trim to the number of evaluated checkpoints so all columns have equal length
# even if evaluation stopped before the end of `epoch_list`.
n_tested = len(list_ave)
data = {"epoch"            : epoch_list[:n_tested],
        "mse_train"        : mse_list_train[:n_tested],
        "mse_average_test" : list_ave,
        "mse_0"            : list_0,
        "mse_40"           : list_40,
        "mse_80"           : list_80,
        "mse_120"          : list_120,
        "mse_160"          : list_160}
df = pd.DataFrame(data)
df.to_excel("./model2.xlsx")  # saved in evaluation order (best training MSE first)
df = df.sort_values(by='mse_average_test', ascending=True)
df.head(50)

## 