In [1]:
import openslide as opsl

import os
import torch
from torch import nn
from torch.autograd import Variable
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import utils
from arch import define_Gen, define_Dis
import kornia
import pandas as pd
import warnings

import torch.nn.functional as F
import numpy as np
import json
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import skimage.transform
import argparse
from scipy.misc import imread, imresize
from PIL import Image

from sklearn.feature_extraction.image import reconstruct_from_patches_2d as reconstruct

In [2]:
warnings.filterwarnings('ignore')

In [3]:
class Arguments:
    """Simple attribute namespace built from a mapping.

    Every key of the mapping passed to the constructor becomes an attribute
    of the instance holding the corresponding value.
    """

    def __init__(self, dictionary):
        """Copy each key/value pair of ``dictionary`` onto ``self`` as an attribute."""
        for name, value in dictionary.items():
            setattr(self, name, value)

In [4]:
# Run configuration. Several of these values are encoded into
# `args.identifier`, which selects the checkpoint/results directories, so they
# must match the run whose checkpoint is loaded later.
args = Arguments({
    'epochs': 30,
    'decay_epoch': 25,
    'batch_size': 16,
    'lr': 0.0002,
    'load_height': 128,
    'load_width': 128,
    'gpu_ids': '0',  # placeholder; replaced with the visible CUDA device indices below
    'crop_height': 128,
    'crop_width': 128,
    'alpha': 5,  # Cyc loss
    'beta': 5,  # Scyc loss
    'gamma': 2,  # Dssim loss
    'delta': 0.1,  # Identity
    'training': True,
    'testing': True,
    'results_dir': '/project/DSone/as3ek/data/ganstain/500/results/',
    'dataset_dir': '/project/DSone/as3ek/data/ganstain/500/',
    'checkpoint_dir': '/project/DSone/as3ek/data/ganstain/500/checkpoint/',
    'norm': 'batch',
    'use_dropout': False,
    'ngf': 64,
    'ndf': 64,
    'gen_net': 'unet_128',
    'dis_net': 'n_layers',
    'self_attn': True,
    'spectral': True,
    'log_freq': 50,
    'custom_tag': 'p100',
    'gen_samples': True,
    'specific_samples': False,
})

# Architecture tags that appear in the run identifier.
tag1 = 'attn' if args.self_attn else 'noattn'
tag2 = 'spectral' if args.spectral else 'nospec'

# Run identifier, e.g.
# unet_128_n_layers_0.0002_batch_attn_spectral_16_128_coefs_5_5_2_0.1_p100
args.identifier = '_'.join(str(part) for part in (
    args.gen_net, args.dis_net, args.lr, args.norm, tag1, tag2,
    args.batch_size, args.load_height, 'coefs',
    args.alpha, args.beta, args.gamma, args.delta, args.custom_tag,
))

# Directory-style paths for this run (the *_dir values end with '/').
args.checkpoint_path = args.checkpoint_dir + args.identifier
args.results_path = args.results_dir + args.identifier

# Use every visible CUDA device; an empty list when none is visible.
args.gpu_ids = list(range(torch.cuda.device_count()))

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
# Build the two generators with identical architectures (Gab first, then Gba),
# configured from `args`; their weights are filled from the checkpoint later.
Gab, Gba = (
    define_Gen(
        input_nc=3,
        output_nc=3,
        ngf=args.ngf,
        netG=args.gen_net,
        norm=args.norm,
        use_dropout=args.use_dropout,
        gpu_ids=args.gpu_ids,
        self_attn=args.self_attn,
        spectral=args.spectral,
    )
    for _ in range(2)
)

Network initialized with weights sampled from N(0,[0.02]).
Network initialized with weights sampled from N(0,[0.02]).


In [6]:
# Restore both generators from the run's `latest.ckpt`; the checkpoint dict
# stores each generator's state under its variable name.
ckpt = utils.load_checkpoint('{}/latest.ckpt'.format(args.checkpoint_path))
for _ckpt_key, _generator in (('Gab', Gab), ('Gba', Gba)):
    _generator.load_state_dict(ckpt[_ckpt_key])

 [*] Loading checkpoint from /project/DSone/as3ek/data/ganstain/500/checkpoint/unet_128_n_layers_0.0002_batch_attn_spectral_16_128_coefs_5_5_2_0.1_p100/latest.ckpt succeed!


In [7]:
# Put both generators in inference mode (`Module.eval()` is `train(False)`),
# so e.g. batch-norm layers (args.norm == 'batch') use their running statistics.
Gab.train(False)
Gba.train(False)
print('Eval mode')

Eval mode


In [8]:
# Translate every slide in PATH with one of the two generators:
# - tile level 0 into size x size patches;
# - run each patch through the generator;
# - save each translated patch as a PNG;
# - stitch the translated patches into one PNG per slide.
#
# This cell no longer uses scipy.misc.imresize (removed in SciPy 1.3); at the
# default size it only resized 256x256 tiles to 256x256, i.e. it was a no-op.

transform = transforms.Compose([
    # Map each channel from [0, 1] to [-1, 1].
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

to_pink = True  # True -> translate with Gba, False -> translate with Gab
PATH = '/project/DSone/biopsy_images/chrc_data_case_preserved/train/EE/'
size = 256  # tile edge in level-0 pixels; tiles are fed to the generator at this size
target = '/scratch/as3ek/normalized_svs/'               # stitched per-slide PNGs
target_path = '/scratch/as3ek/normalized_svs/patches/'  # per-patch PNGs, one folder per slide
SKIP_FILES = {'130370_6722_001.svs'}  # excluded from this run (reason not recorded here)


def _region_to_tensor(region, norm_tf, device):
    """Turn an OpenSlide region (PIL image) into a normalized (1, 3, H, W) float tensor on `device`."""
    rgb = np.asarray(region.convert('RGB'))   # (H, W, 3) uint8; alpha channel dropped
    chw = rgb.transpose(2, 0, 1) / 255.       # (3, H, W) float in [0, 1]
    return norm_tf(torch.FloatTensor(chw).to(device)).unsqueeze(0)


def _tensor_to_pil(batch01):
    """Turn the first image of an (N, 3, H, W) tensor with values in [0, 1] into an RGB PIL image."""
    hwc = batch01.detach().cpu().numpy()[0].transpose(1, 2, 0)
    # Round and clip like torchvision.utils.save_image so the stitched slide
    # matches the per-patch PNGs; a bare uint8 cast truncates and would wrap
    # any value that falls outside [0, 1].
    return Image.fromarray(np.clip(hwc * 255 + 0.5, 0, 255).astype(np.uint8))


generator = Gba if to_pink else Gab
os.makedirs(target, exist_ok=True)

for file in sorted(os.listdir(PATH)):
    if file in SKIP_FILES:
        continue
    # splitext keeps dots inside the stem; split('.')[0] would truncate them.
    stem = os.path.splitext(file)[0]
    target_folder = os.path.join(target_path, stem)
    # makedirs also creates target_path itself if it is missing.
    os.makedirs(target_folder, exist_ok=True)

    slide = opsl.OpenSlide(os.path.join(PATH, file))
    try:
        n_cols = slide.dimensions[0] // size
        n_rows = slide.dimensions[1] // size
        # Partial tiles at the right/bottom edges are dropped.
        new_dims = (n_cols * size, n_rows * size)
        # NOTE: the stitched canvas holds the whole (cropped) level-0 slide in memory.
        joined_image = Image.new('RGB', new_dims)
        for vert_count in range(n_rows):
            y_cord = size * vert_count
            for hor_count in range(n_cols):
                x_cord = size * hor_count
                region = slide.read_region((x_cord, y_cord), 0, (size, size))
                patch = _region_to_tensor(region, transform, device)
                # Inference only: do not record an autograd graph.
                with torch.no_grad():
                    # presumably maps generator output from [-1, 1] to [0, 1]
                    out = (generator(patch) + 1) / 2
                filename = os.path.join(target_folder, '%s_%d_%d.png' % (stem, x_cord, y_cord))
                torchvision.utils.save_image(out, filename)
                joined_image.paste(_tensor_to_pil(out), (x_cord, y_cord))
        joined_image.save(os.path.join(target, stem + '.png'))
    finally:
        # Release the OpenSlide handle even if a tile fails.
        slide.close()