In [1]:
# -*- encoding: utf-8 -*-
import IPython.display as display
import pandas as pd
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from models import Generator
from torch import nn
import hickle as hkl
import matplotlib
import matplotlib.pyplot as plt
# Select GPU index 3 as the current CUDA device for this process.
torch.cuda.set_device(3)
# NOTE(review): this environment variable is assigned only after
# torch.cuda.set_device(3) has already run; presumably it no longer influences
# device visibility for this process at that point — confirm, and consider
# keeping only one of the two device-selection mechanisms.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# "cuda" without an index — presumably resolves to the current device chosen above.
device = torch.device("cuda")

class testDataset(Dataset):
    """In-memory test set of paired low-/high-resolution samples from a hickle file.

    Every entry of ``sample_list`` is ``[lr, hr, label_one_hot, dist, chrom]``:

    * ``lr`` / ``hr`` -- the i-th items of the two loaded arrays, after all
      singleton axes were squeezed out and one axis was re-inserted at position 1.
    * ``dist`` -- absolute value of the first metadata field of the sample.
    * ``label_one_hot`` -- length-5 tensor with a 1 at index ``int(dist / 40)``.
    * ``chrom`` -- second metadata field of the sample.
    """

    def __init__(self, data_path='./data/GM12878/test_data_half.hkl'):
        """Load all samples from ``data_path``.

        Args:
            data_path: hickle file containing a ``(lo, hi, distance_chrome)``
                triple. Defaults to the path that used to be hard-coded, so
                ``testDataset()`` behaves as before.
        """
        lo, hi, distance_chrome = hkl.load(data_path)
        # NOTE(review): squeeze() removes *every* singleton axis, so a file that
        # holds exactly one sample would presumably lose its sample axis — confirm.
        lo = np.expand_dims(lo.squeeze(), axis=1)
        hi = np.expand_dims(hi.squeeze(), axis=1)

        self.sample_list = []
        for i in range(len(lo)):
            dist = abs(distance_chrome[i][0])
            label_one_hot = torch.zeros(5)
            label_one_hot[int(dist / 40)] = 1
            chrom = distance_chrome[i][1]
            self.sample_list.append([lo[i], hi[i], label_one_hot, dist, chrom])

        # Report the array shape directly instead of indexing into the first
        # sample, which raised IndexError when the file held no samples.
        print("dataset loaded : " + '*'.join(str(dim) for dim in lo.shape))

    def __getitem__(self, i):
        """Return sample ``i`` as the tuple ``(lr, hr, label_one_hot, dist, chrom)``."""
        return tuple(self.sample_list[i])

    def __len__(self):
        """Number of samples held in memory."""
        return len(self.sample_list)
    

In [2]:
# Build the test set and a loader that yields one sample at a time, in file order.
test_dataset = testDataset()
loader_options = dict(batch_size=1, shuffle=False, num_workers=1, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, **loader_options)

# Instantiate the generator and the MSE criterion, then move both to the chosen device.
generator = Generator(kernel_size=3, n_channels=64, n_blocks=5)
generator = generator.to(device)
mse_loss_criterion = nn.MSELoss().to(device)

dataset loaded : 3173*1*40*40


In [3]:
# Evaluate, on the test set, the `num_to_test` checkpoints with the lowest
# training MSE recorded in the training log.
num_to_test = 30
df = pd.read_excel("./log_wgan.xls", usecols=[0, 1])
df = df.sort_values(by='MSE损失', inplace=False, ascending=True).head(num_to_test)
mse_list_train = df["MSE损失"].values.tolist()
epoch_list = df["epoch"].values.tolist()


def _mean_or_nan(values):
    """Return the float32 mean of a list of floats, or NaN (no warning) if it is empty."""
    if not values:
        return np.float32("nan")
    return np.mean(np.asarray(values, dtype=np.float32))


# One entry per evaluated checkpoint: overall mean and mean per distance value.
list_ave = []
list_0 = []
list_40 = []
list_80 = []
list_120 = []
list_160 = []
cnt = 0  # number of checkpoints actually evaluated; later used to slice epoch_list
for epoch in epoch_list:
    model_path = "./result_wgan/best_checkpoint_epoch" + str(epoch).zfill(4) + ".pth"
    # Stop at the first missing checkpoint so that the results stay aligned
    # with epoch_list[0:cnt].
    if not os.path.exists(model_path):
        break
    cnt += 1

    # Load onto the CPU first; load_state_dict copies the weights into the
    # generator's own parameters, wherever those live.
    checkpoint = torch.load(model_path, map_location="cpu")
    generator.load_state_dict(checkpoint['generator'])
    generator = generator.eval()

    # Per-sample losses of the current checkpoint, overall and per distance value.
    ave_mse = []
    MSE_0, MSE_40, MSE_80, MSE_120, MSE_160 = [], [], [], [], []
    per_distance = {0: MSE_0, 40: MSE_40, 80: MSE_80, 120: MSE_120, 160: MSE_160}
    progress_rows = [
        ("MSE_AVG = ", ave_mse),
        ("MSE_0   = ", MSE_0),
        ("MSE_40  = ", MSE_40),
        ("MSE_80  = ", MSE_80),
        ("MSE_120 = ", MSE_120),
        ("MSE_160 = ", MSE_160),
    ]
    for i, (lr_img, hr_img, label, distance, chrom) in enumerate(test_loader):
        # Live progress: running means over the samples processed so far.
        display.clear_output(wait=True)
        print("***开始测试***")
        print("正在测试模型：" + str(cnt) + "/" + str(num_to_test))
        print("正在处理样本 ： " + str(i+1) + "/"+str(len(test_dataset)))
        for row_label, values in progress_rows:
            print(row_label + str(_mean_or_nan(values)))

        lr_img = lr_img.type(torch.FloatTensor).to(device)
        hr_img = hr_img.type(torch.FloatTensor).to(device)
        with torch.no_grad():
            sr_img = generator(lr_img)
            # Keep a plain float rather than a tensor for every sample.
            mse = mse_loss_criterion(sr_img, hr_img).item()
        ave_mse.append(mse)
        # The loader uses batch_size=1, so `distance` holds a single value.
        # Distances other than the five listed ones only count towards the overall mean.
        bucket = per_distance.get(abs(distance).item())
        if bucket is not None:
            bucket.append(mse)
    list_ave.append(_mean_or_nan(ave_mse))
    list_0.append(_mean_or_nan(MSE_0))
    list_40.append(_mean_or_nan(MSE_40))
    list_80.append(_mean_or_nan(MSE_80))
    list_120.append(_mean_or_nan(MSE_120))
    list_160.append(_mean_or_nan(MSE_160))
print("***测试结束***")

***开始测试***
正在测试模型：27/30
正在处理样本 ： 3173/3173
MSE_AVG = 0.032896783
MSE_0   = 0.042761058
MSE_40  = 0.024307575
MSE_80  = 0.027780522
MSE_120 = 0.03336994
MSE_160 = 0.03626112
***测试结束***


In [4]:
# Assemble one row per evaluated checkpoint; the log-derived columns are cut
# to the first `cnt` entries so every column has the same length.
column_names = ["epoch", "mse_train", "mse_test",
                "mse_0", "mse_40", "mse_80", "mse_120", "mse_160"]
column_values = [epoch_list[0:cnt], mse_list_train[0:cnt], list_ave,
                 list_0, list_40, list_80, list_120, list_160]
data = dict(zip(column_names, column_values))

# Save the table in evaluation order, then show it ranked by test MSE (best first).
df = pd.DataFrame(data)
df.to_excel("./model2.xlsx")
df = df.sort_values(by='mse_test').head(num_to_test)
table_text = df.head(100).to_string(index=False)
print(table_text)

 epoch  mse_train  mse_test     mse_0    mse_40    mse_80   mse_120   mse_160
  7557   0.024359  0.024704  0.017214  0.019051  0.024069  0.029579  0.034385
  6151   0.022956  0.028580  0.021048  0.022744  0.029591  0.033903  0.036327
  6148   0.023718  0.028592  0.021283  0.022114  0.029279  0.034292  0.036730
  6094   0.024080  0.028865  0.020235  0.023480  0.030219  0.034436  0.036700
  6095   0.023949  0.028893  0.020062  0.023603  0.030068  0.034489  0.037007
  6150   0.023491  0.029261  0.024331  0.022412  0.028824  0.034323  0.037062
  6096   0.023799  0.029265  0.019802  0.024582  0.030977  0.034667  0.037055
  5091   0.023577  0.029331  0.040429  0.020000  0.022617  0.028751  0.034879
  7612   0.024315  0.029624  0.029188  0.022676  0.027705  0.032999  0.035982
  6092   0.024105  0.029673  0.024585  0.023903  0.029433  0.034214  0.036832
  6091   0.024158  0.029800  0.023278  0.024149  0.030191  0.034812  0.037234
  6090   0.024146  0.029970  0.022427  0.024177  0.030949  0.035

## 