In [14]:
import torch
from torchvision.transforms.functional import to_tensor
import segmentation_models_pytorch as smp
from typing import Final
import os
from torchvision import transforms as T
import rasterio
import matplotlib.pyplot as plt
import numpy as np
import albumentations as A
import cv2
from albumentations.pytorch import ToTensorV2
import rasterio
from torchvision.transforms import ToTensor, ToPILImage
from segmentation_models_pytorch.utils.metrics import IoU
from sklearn.metrics import precision_score, recall_score, f1_score

In [15]:
torch.cuda.is_available()

True

## Import model

In [16]:

architectures: Final[dict] = {'Unet': smp.Unet,
                              'Unet++': smp.UnetPlusPlus,
                              'MAnet': smp.MAnet,
                              'DeepLabV3+': smp.DeepLabV3Plus}

encoders: Final[dict] = {'mit-b2': 'mit_b2',
                         'mit-b3': 'mit_b3',
                         'efficientnet-b1': 'efficientnet-b1',
                         'efficientnet-b2': 'efficientnet-b2',
                         'efficientnet-b3': 'efficientnet-b3',
                         'efficientnet-b4': 'efficientnet-b4',
                         'efficientnet-b5': 'efficientnet-b5',
                         'timm-res2net50-26w-4s': 'timm-res2net50_26w_4s'}

models_root = 'D:\diploma\cv-corruption-research\models'
model_names = ['DeepLabV3+_efficientnet-b4' , 'MAnet_efficientnet-b4', 'Unet_mit-b2', 'Unet++_efficientnet-b5']

model_name = model_names[3]

segm_arch = model_name.split('_')[0]
encoder = model_name.split('_')[1]

model = architectures[segm_arch](in_channels = 3, classes=4, 
                                 encoder_name = encoders[encoder],
                                 encoder_weights=None,
                                 activation = None).to('cuda')

model.load_state_dict(torch.load(os.path.join(models_root, f'{model_name}.pth')))
model.eval();

## Inference model

In [17]:
normparams= {'mean': [105.  , 109.,  100.], 'std': [53.660343, 51.114082, 51.887432]}

def get_img_transform(normparams):
    return T.Compose([T.Lambda(lambda x: torch.as_tensor(x, dtype=torch.float)),
                      T.Normalize(**normparams)])


def get_mask_inverse_transform():
    def transform(x):
        x = x.detach().cpu().numpy()
        return x
    return transform


def get_predictor(model,
                  transform,
                  inv_transform,
                  device: str = 'cpu'):
    def predictor(sample):
        model.eval()
        model.to(device)
        with torch.no_grad():
            return inv_transform(model(torch.unsqueeze(transform(sample), 0).to(device))[0])
    return predictor

sample_size=(1024,1024)

def add_border(img):
        img = to_tensor(img)

        old_size = (img.shape[1], img.shape[2])

        new_size = ((int(old_size[0] / sample_size[0]) + 1)*sample_size[0], 
                    (int(old_size[1] / sample_size[1]) + 1)*sample_size[1])
            
        new_img = torch.zeros((img.shape[0], new_size[0], new_size[1]))
        
        add_x = [int((new_size[0]-old_size[0])/2), int((new_size[0]-old_size[0])/2)]
        add_y = [int((new_size[1]-old_size[1])/2), int((new_size[1]-old_size[1])/2)]

        if new_size[0]-np.sum(add_x) != img.shape[1]:
            add_x[1] = add_x[1] + 1

        if new_size[1]-np.sum(add_y) != img.shape[2]:
            add_y[1] = add_y[1] + 1

        new_img[:, add_x[0]:new_size[0]-add_x[1], 
                   add_y[0]:new_size[1]-add_y[1]] = img

        return new_img.detach().cpu().numpy()


predictor = get_predictor(model,
                          transform=get_img_transform(normparams),
                          inv_transform=get_mask_inverse_transform(),
                          device='cuda')



# Cropp images for GPU optimization

In [18]:
pic_num = 13
IoU_res, p, r, f1 = [],[],[],[]

for i in range(pic_num):
    image = rasterio.open("true_images/"+ str(i+1) +".tif")
    true_mask = rasterio.open("true_rastr_masks/"+ str(i+1) +".tif")

    # cropp images for GPU optimization
    x = image.read()
    x = torch.tensor(x)
    kernel_size, stride = 1024, 1024
    patches = x.unfold(1, kernel_size, stride).unfold(2, kernel_size, stride)
    patches = patches.contiguous().view(patches.size(0), -1, kernel_size, kernel_size)
    patches = patches.swapaxes(0,1).swapaxes(1,3)
    
     # cropp mask for GPU optimization
    true_mask = true_mask.read()
    true_mask = torch.tensor(true_mask)
    kernel_size, stride = 1024, 1024
    mask_patches = true_mask.unfold(1, kernel_size, stride).unfold(2, kernel_size, stride)
    mask_patches = mask_patches.contiguous().view(mask_patches.size(0), -1, kernel_size, kernel_size)
    mask_patches = mask_patches.swapaxes(0,1).swapaxes(1,3)
    val_idx = []

    # drop pics with no buildings
    for i in range(len(mask_patches)):
        if mask_patches[i].sum() != 0:
            val_idx.append(i)
    
    for j in val_idx:
        
        image = patches[j].swapaxes(0,2).swapaxes(1,2)
        true_mask = mask_patches[j]

        mask = predictor(image) #returns mask with raw logits, shape = (4, H, W)

        mask = torch.argmax(torch.from_numpy(mask), dim=0)
        mask = mask == 3

        true_mask = true_mask[:,:,0]

        '''
        fig, axes= plt.subplots(1,3)
        axes[0].imshow(image.swapaxes(0,1).swapaxes(1,2))
        axes[1].imshow(true_mask)
        axes[2].imshow(mask)

        axes[0].set_title('Image')
        axes[1].set_title('True mask')
        axes[2].set_title('Predict mask')
        plt.show()
        '''
        
        iou = IoU()
        IoU_ = iou(mask,true_mask).item()
        
        true_mask, mask = true_mask.numpy().flatten(), mask.flatten()
        precision = precision_score(true_mask, mask)
        recall = recall_score(true_mask, mask)
        f1_ = f1_score(true_mask, mask)
        
        p.append(precision)
        r.append(recall)
        f1.append(f1_)
        IoU_res.append(IoU_)

        #print(f1_,precision,recall,IoU_)
    
print(np.mean(f1), np.mean(p), np.mean(r), np.mean(IoU_res))

  _warn_prf(average, modifier, msg_start, len(result))


0.4465620995949126 0.6928671612478626 0.3873576126052838 0.3459532115301117
