In [1]:
# %load test.py
import IPython.display as display
import torch
from torch.utils.data import Dataset
from models import ConvolutionalBlock, ResidualBlock
from torch import nn
import numpy as np
import time
from PIL import Image
import cv2 as cv
import os
import hickle as hkl
import shutil

# Restrict this process to the first GPU. CUDA_VISIBLE_DEVICES is only honoured if it
# is set before CUDA is initialised, so it must precede any torch.cuda device call.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Guard so the notebook still runs on CPU-only machines (the device cell falls back to CPU).
if torch.cuda.is_available():
    torch.cuda.set_device(0)
test_data_path = './data/GM12878/test_data_half.hkl'

In [2]:
class Generator(nn.Module):
    """Super-resolution generator: conv stem -> residual trunk with a long skip -> conv head.

    The input is the low-resolution image concatenated channel-wise with the
    distance one-hot planes built by ``make_input``. The head ends in a 1x1
    convolution to a single channel with a ``'tanh'`` activation.

    Args:
        kernel_size: kernel size of every convolution except the final 1x1 projection.
        n_channels: width of the stem and of the residual trunk.
        n_blocks: number of ``ResidualBlock`` modules in the trunk.
        in_channels: number of input channels. Defaults to 6 (1 image channel +
            5 distance one-hot channels, as assembled by ``make_input``).

    Note:
        Attribute names (``conv_block1`` ... ``conv_block5``, ``residual_blocks``)
        determine the ``state_dict`` keys, so they must not be renamed if saved
        checkpoints are to keep loading.
    """

    def __init__(self, kernel_size=3, n_channels=64, n_blocks=5, in_channels=6):
        super().__init__()

        # First convolutional block (stem): in_channels -> n_channels, ReLU, no batch norm.
        self.conv_block1 = ConvolutionalBlock(in_channels=in_channels, out_channels=n_channels,
                                              kernel_size=kernel_size,
                                              batch_norm=False, activation='relu')

        # Stack of residual blocks; each residual block contains its own skip connection.
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(kernel_size=kernel_size, n_channels=n_channels) for _ in range(n_blocks)])

        # Second convolutional block: closes the trunk before the long skip is added back.
        self.conv_block2 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels,
                                              kernel_size=kernel_size,
                                              batch_norm=True, activation=None)

        # Final three convolutional blocks (head): n_channels -> 128 -> 256 -> 1.
        # NOTE(review): conv_block3 and conv_block4 have no activation, so together they are
        # a purely linear map; kept as-is because existing checkpoints depend on this layout.
        self.conv_block3 = ConvolutionalBlock(in_channels=n_channels, out_channels=128, kernel_size=kernel_size,
                                              batch_norm=False, activation=None)
        self.conv_block4 = ConvolutionalBlock(in_channels=128, out_channels=256, kernel_size=kernel_size,
                                              batch_norm=False, activation=None)
        self.conv_block5 = ConvolutionalBlock(in_channels=256, out_channels=1, kernel_size=1,
                                              batch_norm=False, activation='tanh')

    def forward(self, lr_imgs):
        """Map a (batch, in_channels, H, W) input to a (batch, 1, H', W') output.

        H' == H and W' == W presumably hold because the printed convolutions use
        stride 1 with 'same' padding -- confirm against ``ConvolutionalBlock``.
        """
        output = self.conv_block1(lr_imgs)  # (batch_size, n_channels, H, W)
        residual = output
        output = self.residual_blocks(output)
        output = self.conv_block2(output)
        # Long skip connection around the whole residual trunk.
        output = output + residual
        output = self.conv_block3(output)
        output = self.conv_block4(output)
        sr_imgs = self.conv_block5(output)
        return sr_imgs

def make_input(imgs, distances):
    """Stack each image with its distance encoding broadcast to full spatial planes.

    Args:
        imgs: tensor of shape (batch, C, H, W) -- (batch, 1, 40, 40) in this notebook.
        distances: tensor of shape (batch, K) -- the 5-way one-hot distance label here.

    Returns:
        Tensor of shape (batch, C + K, H, W); channel C + k is filled with distances[:, k].
    """
    height, width = imgs.shape[-2:]
    # Broadcast each scalar label over an H x W plane, using the image's own size rather
    # than a hard-coded 40 x 40 so other patch sizes work too. torch.cat copies the view.
    dis = distances.unsqueeze(2).unsqueeze(3).expand(-1, -1, height, width)
    data_input = torch.cat((imgs, dis), 1)
    return data_input



class testDataset(Dataset):
    """Paired low/high-resolution test patches loaded from a hickle file.

    The file must unpack into ``(lo, hi, distance_chrome)``; ``distance_chrome[i][0]``
    is the sample's distance and ``distance_chrome[i][1]`` its chromosome.
    Each item is ``(lr_img, hr_img, label_one_hot, distance, chromosome)``, where
    the images carry a leading channel axis of size 1.
    """

    def __init__(self, data_path=None):
        """Load every sample into memory.

        Args:
            data_path: hickle file to read; defaults to the notebook-level ``test_data_path``.
        """
        if data_path is None:
            data_path = test_data_path
        # NOTE(review): hickle files can carry pickled objects -- only load trusted data.
        lo, hi, distance_chrome = hkl.load(data_path)
        lo = self._with_channel_axis(lo)
        hi = self._with_channel_axis(hi)
        self.sample_list = []
        for i in range(len(lo)):
            lr = lo[i]
            hr = hi[i]
            dist = distance_chrome[i][0]
            label_one_hot = torch.zeros(5)
            # NOTE(review): the index int(-|dist| / 40) is 0 or negative, so for multiples of
            # 40 the mapping is 0->0, 40->4, 80->3, 120->2, 160->1 (negative indices wrap);
            # 200 wraps back onto slot 0 and >= 240 raises IndexError. Presumably this mirrors
            # the training-time encoding the checkpoint expects -- confirm before changing.
            label_one_hot[int(-abs(dist) / 40)] = 1
            chrom = distance_chrome[i][1]
            self.sample_list.append([lr, hr, label_one_hot, dist, chrom])
        # Report the first four dimensions (N*1*H*W); safe even when the file is empty.
        print("dataset loaded : " + '*'.join(str(d) for d in lo.shape[:4]))

    @staticmethod
    def _with_channel_axis(arr):
        """Drop singleton axes other than the sample axis, then insert a channel axis at 1.

        Equivalent to ``np.expand_dims(arr.squeeze(), axis=1)`` for multi-sample data, but
        a file holding a single sample keeps its sample axis instead of losing it.
        """
        arr = np.asarray(arr)
        singleton_axes = tuple(ax for ax in range(1, arr.ndim) if arr.shape[ax] == 1)
        if singleton_axes:
            arr = np.squeeze(arr, axis=singleton_axes)
        return np.expand_dims(arr, axis=1)

    def __getitem__(self, i):
        (lr_img, hr_img, label_one_hot, distance, chromosome) = self.sample_list[i]
        return lr_img, hr_img, label_one_hot, distance, chromosome

    def __len__(self):
        return len(self.sample_list)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# map_location lets a checkpoint saved on GPU load on a CPU-only machine as well.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
checkpoint = torch.load("./model/best_checkpoint_epoch3840.pth", map_location=device)
test_dataset = testDataset()
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=1,
                                          shuffle=False,
                                          num_workers=1,
                                          # pinned host memory only speeds up host->GPU copies
                                          pin_memory=device.type == "cuda")
generator = Generator(kernel_size=3, n_channels=64, n_blocks=5)
generator = generator.to(device)
generator.load_state_dict(checkpoint['generator'])
generator.eval()

dataset loaded : 3173*1*40*40


Generator(
  (conv_block1): ConvolutionalBlock(
    (conv_block): Sequential(
      (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): ReLU()
    )
  )
  (residual_blocks): Sequential(
    (0): ResidualBlock(
      (conv_block1): ConvolutionalBlock(
        (conv_block): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (conv_block2): ConvolutionalBlock(
        (conv_block): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (1): ResidualBlock(
      (conv_block1): ConvolutionalBlock(
        (conv_block): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), b

In [4]:
def save_img(img, i, label, out_dir='./test'):
    """Write a single-channel map in [-1, 1] to ``<out_dir>/<iii>_<label>.jpg``.

    The map is scaled to [0, 255] and placed in the red channel, while green and
    blue are fixed at 255 (low values render cyan, high values white).

    Args:
        img: tensor (on any device) that squeezes to a 2-D (H, W) map.
        i: sample index, zero-padded to 3 digits in the file name.
        label: file-name suffix such as 'lr', 'hr' or 'sr'.
        out_dir: destination directory; it must already exist.
    """
    gray = img.squeeze().cpu().numpy()
    # Map [-1, 1] -> [0, 255]; clip so out-of-range values saturate instead of
    # wrapping around on the uint8 cast below.
    red = np.clip(127.5 * gray + 127.5, 0, 255)
    # NOTE(review): the original computed `white = 255 - red` but never used it; a
    # white-to-red heatmap may have been intended (G = B = 255 - red). Confirm.
    saturated = np.full_like(red, 255)
    rgb = np.stack((red, saturated, saturated), axis=-1)
    filename = str(i).zfill(3) + '_' + label + '.jpg'
    Image.fromarray(np.uint8(rgb)).save(os.path.join(out_dir, filename))
    
# Start from an empty output directory; it may not exist yet on the first run.
if os.path.isdir('./test'):
    shutil.rmtree('./test')
os.makedirs('./test', exist_ok=True)
cnt = 0
for i, (lr_img, hr_img, label, distance, _chrom) in enumerate(test_loader):
    display.clear_output(wait=True)
    # Only samples with distance 0 are rendered; batch_size=1 makes the tensor test scalar.
    if not distance == 0:
        continue
    cnt = cnt + 1
    print("***开始生成***")
    print("正在处理：" + str(i) + "/" + str(len(test_dataset)))
    lr_img = lr_img.type(torch.FloatTensor).to(device)
    hr_img = hr_img.type(torch.FloatTensor).to(device)
    label = label.to(device)
    G_input = make_input(lr_img, label)
    # Inference only: no autograd graph is needed (no_grad makes .detach() redundant).
    with torch.no_grad():
        sr_img = generator(G_input)

    save_img(lr_img, cnt, 'lr')
    save_img(hr_img, cnt, 'hr')
    save_img(sr_img, cnt, 'sr')

print("***生成结束***")

***生成结束***
