In [1]:
# %load test.py
#!/usr/bin/env python
import torch
from torch.utils.data import Dataset
import numpy as np
from models import Generator
import time
from PIL import Image
import cv2 as cv
import os
import hickle as hkl
import shutil

# Restrict this process to GPU 0. CUDA_VISIBLE_DEVICES is only honoured if it
# is set before CUDA is initialised, so it must come before any torch.cuda call
# that touches the driver. Device selection is guarded so the notebook still
# runs on CPU-only machines (the later `device` definition falls back to CPU).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if torch.cuda.is_available():
    torch.cuda.set_device(0)

In [2]:
def save_img(img, i, label, out_dir='./test'):
    """Save a single-channel tensor as a red-only RGB JPEG.

    The tensor is squeezed, mapped through ``127.5 * x + 127.5`` (i.e. from
    [-1, 1] to [0, 255]), clipped to the valid byte range, written into the
    red channel (green and blue are zero) and saved as
    ``<out_dir>/<iii>_<label>.jpg``.

    Args:
        img: torch tensor whose squeezed form is a 2-D map; assumed to be
            normalised to roughly [-1, 1] — TODO confirm against training.
        i: integer index, zero-padded to 3 digits in the file name.
        label: suffix appended to the file name (e.g. 'lr', 'hr', 'sr').
        out_dir: directory to write into; defaults to './test'.
    """
    arr = img.squeeze().cpu().numpy()
    # Clip before the uint8 cast: values outside [-1, 1] (e.g. an unbounded
    # generator output) would otherwise wrap around instead of saturating.
    gray = np.clip(127.5 * arr + 127.5, 0, 255).astype(np.uint8)
    # Same result as GRAY2RGB followed by zeroing the G and B channels.
    rgb = np.zeros(gray.shape + (3,), dtype=np.uint8)
    rgb[..., 0] = gray
    # An (H, W, 3) uint8 array always produces an 'RGB' image, so no mode
    # conversion is needed before saving as JPEG.
    file_name = str(i).zfill(3) + '_' + label + '.jpg'
    Image.fromarray(rgb).save(os.path.join(out_dir, file_name))

class test_Dataset(Dataset):
    """Paired (low-res, high-res) test samples loaded from a hickle file.

    The file is expected to unpack into three items ``(lo, hi, _)``; the third
    is ignored. Each item is returned as ``(lr_img, hr_img)``, each with a
    singleton axis inserted at position 0 of the sample (channel axis).
    """

    def __init__(self, path='./data/GM12878/test_data_half.hkl'):
        """Load the arrays and pair them sample-by-sample.

        Args:
            path: hickle file to read; defaults to the GM12878 test file.
        """
        lo, hi, _ = hkl.load(path)
        # Drop singleton axes, then re-insert a single channel axis at dim 1.
        # NOTE(review): if a file holds exactly one sample, squeeze() also
        # drops the sample axis — confirm that never happens for test files.
        lo = np.expand_dims(lo.squeeze(), axis=1)
        hi = np.expand_dims(hi.squeeze(), axis=1)
        # zip stops at the shorter array, so unmatched trailing samples are
        # dropped rather than paired with nothing.
        self.sample_list = [[lr, hr] for lr, hr in zip(lo, hi)]
        # Report the number of usable pairs followed by the per-sample shape;
        # using .shape also works for empty files.
        dims = (len(self.sample_list),) + tuple(lo.shape[1:])
        print("dataset loaded : " + '*'.join(str(d) for d in dims))

    def __getitem__(self, i):
        """Return the ``(lr_img, hr_img)`` pair at index ``i``."""
        lr_img, hr_img = self.sample_list[i]
        return lr_img, hr_img

    def __len__(self):
        """Number of paired samples."""
        return len(self.sample_list)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = (torch.device("cuda") if torch.cuda.is_available()
          else torch.device("cpu"))
# Limit consulted by the generation loop to stop early.
max_num = 1000
# Running index of generated samples; incremented by the generation loop and
# used in the output file names.
cnt = 0

if __name__ == '__main__':
    # Start from an empty output directory. ignore_errors lets the very first
    # run (when ./test does not exist yet) succeed instead of raising.
    shutil.rmtree('./test', ignore_errors=True)
    os.makedirs('./test', exist_ok=True)

    # Other checkpoints that have been evaluated with this script:
    #   ./result/checkpoint_epoch280.pth
    #   ./result_wgan/checkpoint_epoch600.pth
    #   ./result_wgan/best_checkpoint_epoch031.pth
    #   ./result/best_checkpoint_epoch0381.pth
    checkpoint_path = "./result_gan/best_checkpoint_epoch20109.pth"
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    test_dataset = test_Dataset()
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=1,
                                              # pinning only helps host->GPU copies
                                              pin_memory=torch.cuda.is_available())
    generator = Generator(kernel_size=3, n_channels=64, n_blocks=5)
    generator = generator.to(device)
    generator.load_state_dict(checkpoint['generator'])
    generator.eval()

    print("***开始生成***")  # "*** start generating ***"
    for i, (lr_img, hr_img) in enumerate(test_loader):
        lr_img = lr_img.float()
        # Keep only samples whose low-res matrix is exactly symmetric;
        # torch.equal compares the whole (squeezed, 2-D) matrix to its
        # transpose in one call.
        mat = lr_img.squeeze()
        if not torch.equal(mat, mat.t()):
            continue
        cnt = cnt + 1
        lr_img = lr_img.to(device)
        hr_img = hr_img.float().to(device)
        with torch.no_grad():
            sr_img = generator(lr_img)

        save_img(lr_img, cnt, 'lr')
        save_img(hr_img, cnt, 'hr')
        save_img(sr_img, cnt, 'sr')
        # Stop once max_num sample triples have been written. Checking the
        # saved count after the save means the `continue` above can never
        # bypass this limit.
        if cnt >= max_num:
            break
    print("***生成结束***")  # "*** generation finished ***"

dataset loaded : 3173*1*40*40
***开始生成***
***生成结束***
