In [None]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# (e.g. the project-local config / Extraction / Training packages) are reloaded
# before each cell execution and edits to them are picked up without a kernel
# restart.
%load_ext autoreload
%autoreload 2

In [None]:
import glob
import os

import numpy as np
import scipy.ndimage as nd
import skimage
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from skimage import measure, segmentation
from skimage.draw import polygon, rectangle
from skimage.transform import resize
import os
import pickle


import config
import Extraction.mask.maskHelper as mh
from Extraction.clustering import run_clustering
from Training.SNN.model import DataLoaderClustering, Network, NetworkResnet
from Training.SNN.pageHandler import PageHandler, Pixel
from Training.FCN.unet import UNet
from Extraction.page import Page

In [None]:
# Override project configuration for this machine: where the page images live,
# which checkpoint files hold the clustering / mask network weights, and which
# device to run on.
# NOTE(review): these are absolute, machine-specific Windows paths — adjust
# them before running the notebook elsewhere.
config.IMAGE_PATH = "F:\\Code\\Final\\data\\"
config.CLUSTERING_MODEL = "F:\\Code\\Final\\models\\resnet_time_4passes_complete.pth"
config.MASK_MODEL = "F:\\Code\\Final\\models\\mask_time_real.pth"
config.DEVICE = 'cuda'



In [None]:
import xml.etree.ElementTree as ET

# Parse val.xml and create one Page for every <image> element directly below
# the document root, identified by the element's "name" attribute.
xmlTree = ET.parse('F:\\Code\\Final\\models\\val.xml')
root = xmlTree.getroot()
images = root.findall('image')

pages = [Page(image_node.attrib['name']) for image_node in images]

In [None]:
# Display how many Page objects were created from the <image> entries.
len(pages)

In [None]:
# Build both networks and read their checkpoints onto the configured device.
#
# Installation of CUDA:
# https://pytorch.org/get-started/locally/
#
# config.DEVICE is honoured as configured; if "cuda" is requested but no CUDA
# device is usable, fall back to the CPU instead of crashing. config.DEVICE is
# updated as well so that any code reading it sees the device actually in use.
if config.DEVICE == "cuda" and not torch.cuda.is_available():
    print("CUDA requested but not available - falling back to CPU")
    config.DEVICE = "cpu"

device = torch.device(config.DEVICE)

# Clustering network and its saved state dict (applied with load_state_dict
# in a later cell).
cluster_network = NetworkResnet(img_size=config.PATCH_SIZE).to(device)
cluster_model = torch.load(config.CLUSTERING_MODEL, map_location=device)

# Foreground/background mask network and its saved state dict.
mask_network = UNet(n_channels=3, n_classes=1, bilinear=False).to(device)
mask_model = torch.load(config.MASK_MODEL, map_location=device)



In [None]:
# Release unused cached GPU memory held by PyTorch's allocator before the
# inference loop runs.
torch.cuda.empty_cache()

# Put both networks into evaluation mode. eval() returns the module itself, so
# the final call also echoes the mask network's structure as the cell output.
cluster_network.eval()
mask_network.eval()

In [None]:
# Load the saved weights into both networks.
cluster_network.load_state_dict(cluster_model)
mask_network.load_state_dict(mask_model)

# Transform handed to run_clustering: tensor conversion followed by per-channel
# normalisation with mean 0.5 and std 0.5.
convert_cluster = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Transform handed to the mask network helper: resize to 256x256, then tensor.
convert_mask = transforms.Compose([transforms.Resize([256, 256]), transforms.ToTensor()])

# All per-page results are pickled into this folder; create it up front so the
# first open() below does not fail on a fresh checkout.
outfolder = 'results/'
os.makedirs(outfolder, exist_ok=True)

for page in pages:
    patchHelper = page.patchHelper
    patchHelper.generateAllPatches()

    # Calculate Foreground/Background Mask
    mask, _ = mh.run_mask_network(page, mask_network, convert_mask)

    # Calculate Clusters
    clusters = run_clustering(page, patchHelper.patches, cluster_network, convert_cluster)

    # Store results of clustering. The context manager closes the file even if
    # pickling raises.
    with open(outfolder + page.file + ".cluster", 'wb') as cluster_file:
        pickle.dump(clusters, cluster_file)

    # Optional debugging view (disabled): upscale the cluster map to the page
    # size and display it.
    # clustering = resize(clusters, (page.image.size[1], page.image.size[0]), order=0)
    # plt.imshow(clustering); plt.show()

    # Free memory: the patches are not used again for this page.
    del patchHelper.patches

    page.pageHandler.add_clustering(clusters)

    # Erode Mask and Fill Holes
    mask = mh.fix_mask(mask, page.image_size, erode = 3)
    page.pageHandler.add_mask(mask)

    # Optional debugging view (disabled): show the processed mask.
    # plt.imshow(mask); plt.show()

    # Assign clusters to mask and expand mask (clustering first).
    final_mask = mh.combine_cluster_mask(page, mask)
    page.pageHandler.add_final_segments(final_mask)

    # Optional debugging view (disabled): overlay the final segments on the page.
    # plt.rcParams["figure.figsize"] = (20, 20)
    # plt.imshow(page.image)
    # plt.imshow(final_mask, cmap='jet', alpha=0.7)
    # plt.axis("off")
    # plt.show()

    # Store result of combination.
    with open(outfolder + page.file + ".full", 'wb') as full_file:
        pickle.dump(final_mask, full_file)
