In [1]:
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
import torch,numpy as np
import torch.utils.data as data
from glob import glob
import albumentations as A
from matplotlib import pyplot as plt


# Preprocessing applied to every RGB tile before inference. The mean/std are the
# standard ImageNet channel statistics; per the original note, they are the same
# color normalization that was used while training.
transform = A.Compose(
    [
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
            max_pixel_value=255.0,
            p=1.0,
        ),
        ToTensorV2(),  # numpy HWC image -> torch CHW tensor
    ]
)



In [12]:
from patchify import unpatchify, patchify
import os, cv2, torch
from glob import glob
import csv
from skimage.measure import label, regionprops, regionprops_table
import segmentation_models_pytorch as smp

# Architecture used for training (kept for reference only). The checkpoint below
# holds the whole pickled model object, so this definition is not needed to load it:
# model = smp.UnetPlusPlus(
#     encoder_name="efficientnet-b4",   # encoder, e.g. mobilenet_v2 or efficientnet-b7
#     encoder_weights="imagenet",       # ImageNet pre-trained encoder initialization
#     in_channels=3,                    # model input channels (3 for RGB)
#     classes=1,                        # model output channels (number of classes)
# )
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"The current device is {device}")

# torch.load unpickles an arbitrary Python object: only load checkpoints from a
# trusted source. map_location lets a checkpoint saved on GPU load on a CPU-only host.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects full-model
# pickles like this one; pass weights_only=False there if loading fails.
model = torch.load('upp_trained_model.pth', map_location=device)  # load trained model

model.to(device)
model.eval()  # inference mode for dropout / batch-norm layers
os.makedirs('out', exist_ok=True)  # output root directory


for file in glob(os.path.join('./test','*.jpg')):  #for each image in test
    basename = os.path.basename(file)
    
    orig_image = cv2.imread(file)  # read image
    h = orig_image.shape[0]
    w = orig_image.shape[1]
    
    
    jgw_name = file[:-2]+'gw'     # read  jgw file
    with open(jgw_name,'r') as f:
        entries = [float(line.split()[0]) for line in f]
    eastings_origin = entries[4]
    northings_origin = entries[5]
    eastings_step = entries[0]
    northings_step = entries[3]
    
        
    for ratio in [0.485, 1]:   # run inference at different resolutions
        if not os.path.exists('out/'+str(ratio)):
            os.makedirs('out/'+str(ratio))
        
        full_image = cv2.resize(orig_image, None, fx= ratio, fy= ratio, interpolation= cv2.INTER_LINEAR)


        extra_h = 256 - (full_image.shape[0]%256)
        extra_w = 256 - (full_image.shape[1]%256)
        padded_full = cv2.copyMakeBorder(full_image.copy(),0,extra_h,0,extra_w,cv2.BORDER_CONSTANT,value=(255,0,0))
        padded_full_rgb = cv2.cvtColor(padded_full, cv2.COLOR_BGR2RGB)
        
        patches = patchify(padded_full_rgb, (256,256,3), step=256) #split high res image to patches for inference
        #patches = patches.reshape((:,:,1,256,256,3))
        patch_results = []
        for x_patch in patches: #run inference on each patch
            row_patches = []
            for xy_patch in x_patch:
                curr_patch = np.array(xy_patch[0])
                #print(curr_patch.shape)
                transformed =  transform(image=curr_patch)['image'][None,:]
                transformed = transformed.to(device)
                #print(transformed.shape)
                out = model(transformed)
                outer = torch.max(out,1).indices.cpu().detach().numpy()[0]
                row_patches.append(outer*255)


            patch_results.append(row_patches.copy())


        patch_results = np.array(patch_results,dtype='uint8')
        #print(patch_results.shape,patches.shape,padded_full.shape[:-1])
        reconstructed_image = unpatchify(patch_results, padded_full.shape[:-1])  #stitch masks of each patch
        #print(reconstructed_image.shape)
        

        label_img = label(reconstructed_image)
        regions = regionprops(label_img)  #segment groups as regions
        
        #bring to original size
        p_out = padded_full[:full_image.shape[0],:full_image.shape[1],:]
        p_out = cv2.resize(p_out, None, fx=1/ratio, fy= 1/ratio, interpolation= cv2.INTER_LINEAR)
        reconstructed_image = reconstructed_image[:full_image.shape[0],:full_image.shape[1]]
        reconstructed_image = cv2.resize(reconstructed_image, None, fx=1/ratio, fy= 1/ratio, interpolation= cv2.INTER_LINEAR)
        
 
        
        bbox_bng = []
        for i,r in enumerate(regions):  #treat each region as a detection- draw bbox in image and write to csv
            #print(i,r.centroid,r.area)
            if r.area<400:   #skip small detections
                continue
            minr, minc, maxr, maxc = r.bbox[0]/ratio, r.bbox[1]/ratio, r.bbox[2]/ratio, r.bbox[3]/ratio
            
            cv2.rectangle(p_out, (int(minc), int(minr)), (int(maxc),int(maxr)), (0,69, 255),2)
            
            
            bbox_bng.append([i, eastings_origin+(minc*eastings_step), northings_origin+(minr*northings_step),
                            eastings_origin+(maxc*eastings_step), northings_origin+(maxr*northings_step)])
            # format of csv file : [id of detection,easting start, northing start, easting end, northing end ]

        cv2.imwrite('out/'+str(ratio)+'/'+basename,p_out) # post-processed image
        with open('out/'+str(ratio)+'/'+basename[:-4]+'.csv', 'w', encoding='UTF8') as f: # write bng coordinates of detections in csv
            writer = csv.writer(f)
            for row in bbox_bng:
            # write a row to the csv file
                writer.writerow(row)
        
        
            
    
    
    
    

The current device is cuda:0
