In [2]:

#! pip install albumentations==0.4.6

import torch
from torchvision.transforms.functional import to_tensor
import segmentation_models_pytorch as smp
from typing import Final
import os
from torchvision import transforms as T
import rasterio
import matplotlib.pyplot as plt
import numpy as np
import albumentations as A
import cv2
from albumentations.pytorch import ToTensorV2
import rasterio


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Sanity check: later cells place the model on 'cuda', so PyTorch must see a GPU.
torch.cuda.is_available()

True

## Load model

In [6]:

# Segmentation architectures from segmentation_models_pytorch, keyed by the
# architecture prefix used in the checkpoint names below.
architectures: Final[dict] = {'Unet': smp.Unet,
                              'Unet++': smp.UnetPlusPlus,
                              'MAnet': smp.MAnet,
                              'DeepLabV3+': smp.DeepLabV3Plus}

# Maps the encoder suffix used in the checkpoint names to the encoder
# identifier passed to segmentation_models_pytorch as `encoder_name`.
encoders: Final[dict] = {'mit-b2': 'mit_b2',
                         'mit-b3': 'mit_b3',
                         'efficientnet-b1': 'efficientnet-b1',
                         'efficientnet-b2': 'efficientnet-b2',
                         'efficientnet-b3': 'efficientnet-b3',
                         'efficientnet-b4': 'efficientnet-b4',
                         'efficientnet-b5': 'efficientnet-b5',
                         'timm-res2net50-26w-4s': 'timm-res2net50_26w_4s'}

# Raw string: the Windows path contains `\d`, `\c` and `\m`, which are invalid
# escape sequences in a normal string literal (warning today, error in future
# Python versions). The resulting value is unchanged.
models_root = r'D:\diploma\cv-corruption-research\models'
model_names = ['DeepLabV3+_efficientnet-b4', 'MAnet_efficientnet-b4', 'Unet_mit-b2', 'Unet++_efficientnet-b5']

# Pick which checkpoint to run; names follow the "<architecture>_<encoder>" pattern.
model_name = model_names[0]

# Split once on the first underscore: left part selects the architecture,
# right part selects the encoder.
segm_arch, encoder = model_name.split('_', 1)

# encoder_weights=None: no pretrained encoder weights are fetched, the trained
# checkpoint loaded below supplies all parameters.
# activation=None: the model outputs raw logits for the 4 classes.
model = architectures[segm_arch](in_channels=3, classes=4,
                                 encoder_name=encoders[encoder],
                                 encoder_weights=None,
                                 activation=None).to('cuda')

# map_location keeps deserialization working even if the checkpoint was saved
# from a different device index than the one available here.
checkpoint_path = os.path.join(models_root, f'{model_name}.pth')
model.load_state_dict(torch.load(checkpoint_path, map_location='cuda'))
model.eval();

## Model inference

In [8]:
# Per-channel normalization statistics handed to T.Normalize.
# NOTE(review): these values look like 0-255 pixel statistics, while the batch
# inference loop divides images by 255 before prediction — confirm the scale
# matches what the model was trained with.
normparams = {
    'mean': [105., 109., 100.],
    'std': [53.660343, 51.114082, 51.887432],
}


def get_img_transform(normparams):
    """Build the input preprocessing pipeline.

    Args:
        normparams: keyword arguments forwarded unchanged to ``T.Normalize``
            (here a dict with ``mean`` and ``std``).

    Returns:
        A ``T.Compose`` that first casts the input to a float tensor and then
        normalizes it.
    """
    def as_float_tensor(x):
        return torch.as_tensor(x, dtype=torch.float)

    steps = [
        T.Lambda(as_float_tensor),
        T.Normalize(**normparams),
    ]
    return T.Compose(steps)


def get_mask_inverse_transform():
    """Return a callable that turns a model-output tensor into a NumPy array.

    The returned function detaches the tensor from the autograd graph, moves
    it to host memory and converts it with ``.numpy()``.
    """
    def transform(x):
        host_tensor = x.detach().cpu()
        return host_tensor.numpy()
    return transform


def get_predictor(model,
                  transform,
                  inv_transform,
                  device: str = 'cpu'):
    """Wrap ``model`` in a single-sample inference function.

    Args:
        model: the network to run.
        transform: callable applied to the raw sample; its result is given a
            leading batch dimension and moved to ``device``.
        inv_transform: callable applied to the first element of the model
            output before it is returned.
        device: device the model and the input batch are moved to.

    Returns:
        ``predictor(sample)``, which runs one forward pass without gradient
        tracking and returns ``inv_transform`` of the first batch element.
    """
    def predictor(sample):
        # Eval mode and device placement are re-applied on every call.
        model.eval()
        model.to(device)
        with torch.no_grad():
            batch = torch.unsqueeze(transform(sample), 0).to(device)
            first_output = model(batch)[0]
            return inv_transform(first_output)
    return predictor

# Default tile size (height, width); padded images are multiples of this.
sample_size = (1024, 1024)


def add_border(img, tile_size=None):
    """Zero-pad an image so both spatial sides are multiples of the tile size.

    The image is converted with ``to_tensor`` and then indexed as
    (channels, height, width). It is centred inside the padded canvas; when
    the padding is odd, the extra pixel goes to the bottom/right side.

    A side that is already an exact multiple of the tile size is left
    unpadded (the previous ``int(size / tile) + 1`` formula added a whole
    extra tile in that case).

    NOTE(review): torchvision's ``to_tensor`` is documented to rescale uint8
    arrays to [0, 1] while leaving float arrays as-is — confirm callers pass
    data in the scale the model expects.

    Args:
        img: image accepted by ``to_tensor`` (the notebook passes an
            (H, W, C) NumPy array).
        tile_size: optional (height, width) tile; defaults to the module-level
            ``sample_size``.

    Returns:
        float32 NumPy array of shape (C, padded_H, padded_W).
    """
    tile_h, tile_w = sample_size if tile_size is None else tile_size

    img = to_tensor(img)
    channels, height, width = img.shape

    # Ceiling division; max(1, ...) keeps at least one tile for an empty side.
    padded_h = max(1, -(-height // tile_h)) * tile_h
    padded_w = max(1, -(-width // tile_w)) * tile_w

    # Floor half of the padding goes top/left, the remainder bottom/right.
    top = (padded_h - height) // 2
    left = (padded_w - width) // 2

    new_img = torch.zeros((channels, padded_h, padded_w))
    new_img[:, top:top + height, left:left + width] = img

    return new_img.numpy()


# Assemble the inference function: normalize the input, run the model on the
# GPU, and hand back the output as a NumPy array.
predictor = get_predictor(
    model,
    get_img_transform(normparams),
    get_mask_inverse_transform(),
    'cuda',
)


In [None]:
# Open the raster inside a context manager so the file handle is released
# as soon as the pixels have been read.
with rasterio.open("true_images/1.tif") as image:
    # Reorder the axes from (0, 1, 2) to (1, 2, 0) so channels come last,
    # which is what add_border / imshow are given below.
    sample = image.read().transpose(1, 2, 0)
# Alternative to padding: crop to the nearest multiple instead,
# e.g. sample = sample[:992, :944, :]
mask = predictor(add_border(sample))  # raw logits, shape = (4, H, W) of the padded image
fig, axes = plt.subplots(1, 2)
axes[0].imshow(sample)
axes[0].set_title("Input image")
axes[1].imshow(mask[3, :, :])
axes[1].set_title("Logits, channel 3")
plt.show()

In [None]:
# Number of images to process; files are expected at true_images/1.tif .. true_images/<pic_num>.tif
pic_num = 14

for i in range(pic_num):
    # Context manager closes each raster once its pixels are read, so file
    # handles do not accumulate over the loop.
    with rasterio.open(f"true_images/{i + 1}.tif") as src:
        # Channels-last layout, scaled by 1/255.
        # NOTE(review): the division assumes 8-bit imagery — confirm.
        image = src.read().transpose(1, 2, 0) / 255

    mask = predictor(add_border(image))  # raw logits, shape = (4, H, W) of the padded image
    fig, axes = plt.subplots(1, 2)

    axes[0].imshow(image)
    axes[0].set_title(f"Input image {i + 1}")
    axes[1].imshow(mask[3, :, :])
    axes[1].set_title("Logits, channel 3")
    plt.show()
    # Release the figure explicitly so memory does not grow with every image.
    plt.close(fig)