In [None]:
from torchvision import datasets, models, transforms, utils
from cc_torch import connected_components_labeling
import numpy as np
import torch
import os 
import time 

from skimage.measure import label, regionprops
import matplotlib.patches as mpatches
from skimage.color import label2rgb
import matplotlib.pyplot as plt
from PIL import Image 

In [None]:
def write_patch(image, label, idx, area_cutoff, out_dir='patches'):
    """Save the filled mask of every sufficiently large labelled region as a PNG.

    Files are written to ``<out_dir>/img_<idx>/img_<idx>_patch_<j>.png``, where
    ``j`` numbers the kept regions (in ``regionprops`` order) starting at 0.
    The per-image directory is created even if no region passes the cutoff.

    Args:
        image: Overlay/intensity image for the same frame. Currently unused;
            kept so existing callers keep working.
        label: Integer label image as accepted by
            ``skimage.measure.regionprops`` (label 0 is ignored there).
            Note: inside this function the name shadows the module-level
            ``skimage.measure.label`` import.
        idx: Image index, used in the output directory and file names.
        area_cutoff: Regions whose ``area`` (pixel count) is ``<= area_cutoff``
            are skipped.
        out_dir: Root directory for the patches (default ``'patches'``).
    """
    patch_dir = os.path.join(out_dir, 'img_%i' % idx)
    # makedirs creates missing parents and, with exist_ok=True, does not fail
    # if the directory already exists (no separate exists()/mkdir() race).
    os.makedirs(patch_dir, exist_ok=True)

    kept_regions = (region for region in regionprops(label)
                    if region.area > area_cutoff)
    for j, region in enumerate(kept_regions):
        # filled_image: the region's mask cropped to its bounding box, holes filled.
        patch_path = os.path.join(patch_dir, 'img_%i_patch_%i.png' % (idx, j))
        Image.fromarray(region.filled_image).save(patch_path)

In [None]:
# Root folder read by torchvision's ImageFolder (one sub-folder per class).
data_dir = 'data_diffract'
# Worker processes used by the DataLoader.
num_workers = 4
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Force a single grayscale channel, then convert to a tensor
# (ToTensor scales 8-bit images to floats in [0, 1]).
data_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])

In [None]:
if torch.cuda.device_count() > 1:
    # Informational only: nothing below distributes work across GPUs.
    print("Using", torch.cuda.device_count(),"gpus!")

image_dataset = datasets.ImageFolder(data_dir, data_transforms)

# Load the whole dataset as one batch; every image must fit in memory at once
# and (for default collation) all images must share the same size.
dataloader = torch.utils.data.DataLoader(image_dataset, num_workers=num_workers,
                                         batch_size=len(image_dataset))
images, labels = next(iter(dataloader))

# ToTensor() yields floats in [0, 1]. Casting those directly to uint8 truncates
# every value below 1.0 to 0, so only saturated (255) pixels become foreground.
# Make that threshold explicit (same result for [0, 1] inputs) so it is visible
# and tunable; the labeling step needs a uint8 0/1 mask.
FOREGROUND_LEVEL = 1.0
images = (images >= FOREGROUND_LEVEL).to(device, torch.uint8)

In [None]:
# Regions with area <= AREA_CUTOFF pixels are not written out.
AREA_CUTOFF = 1

for idx, image in enumerate(images):
    # Each batch element is (1, H, W) (one grayscale channel). Drop the channel
    # axis instead of hard-coding a 2048x2048 reshape so other sizes work too.
    image = image.squeeze(0)
    assert image.dim() == 2, f"expected a single-channel image, got shape {tuple(image.shape)}"
    # NOTE(review): cc_torch is understood to require a uint8 CUDA tensor with
    # even H and W — confirm the input data satisfies this.
    cc_out = connected_components_labeling(image).cpu().numpy()
    image = image.cpu().numpy()
    cc_image_overlay = label2rgb(cc_out, image=image, bg_label=0)
    write_patch(cc_image_overlay, cc_out, idx, area_cutoff=AREA_CUTOFF)