In [29]:
import os
from tqdm import tqdm
import cv2
from glob import glob
import torch
from torch.utils.data import Dataset as BaseDataset
import albumentations as A
import segmentation_models_pytorch as smp

In [30]:
# Collect the input image paths. glob() returns matches in arbitrary order,
# so sort them to process the files in the same sequence on every run.
img_path_list = sorted(glob(os.path.normpath('./saved_images/*')))

# Set variables for model
DATA_DIR = img_path_list  # NOTE: a list of image file paths, not a directory
ENCODER = 'resnet34'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['background', 'soybean', 'stake']
ACTIVATION = 'softmax2d'
DEVICE = 'cpu' # cuda or cpu

# Checkpoint holding the trained model's state dict.
model_parameters_path = './best_model.pth'

In [31]:
# Define dataset for inference  
class Dataset(BaseDataset):
    """Inference dataset that loads images from a list of file paths.

    Each item is read with OpenCV, converted from BGR to RGB, optionally
    augmented, normalised with fixed mean/std values, optionally preprocessed,
    and returned as a ``torch.Tensor``.

    Args:
        data_dir: Indexable sequence of image file paths.
        augmentation: Optional callable invoked as ``augmentation(image=...)``
            that returns a mapping with an ``'image'`` entry.
        preprocessing: Optional callable with the same calling convention,
            applied after normalisation.
    """

    def __init__(self, data_dir, augmentation=None, preprocessing=None):
        self.image_paths = data_dir
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        # Built once here rather than on every __getitem__ call.
        self._normalize = A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], p=1.0)

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, transform and return the image at position ``idx``.

        Raises:
            OSError: If OpenCV cannot read or decode the file.
        """
        image_path = self.image_paths[idx]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise OSError(f"Could not read or decode image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.augmentation is not None:
            image = self.augmentation(image=image)['image']

        # NOTE(review): if `preprocessing` also applies mean/std normalisation
        # (e.g. an encoder preprocessing function), the image is normalised
        # twice. Confirm this matches the training pipeline before changing it.
        image = self._normalize(image=image)['image']

        if self.preprocessing is not None:
            image = self.preprocessing(image=image)['image']

        return torch.from_numpy(image)

In [32]:
# Prepare model
def prepare_model():
    """Build the U-Net, load the trained parameters and return it in eval mode.

    The architecture is configured from the module-level ``ENCODER``,
    ``ENCODER_WEIGHTS``, ``CLASSES`` and ``ACTIVATION`` settings; the weights
    are read from ``model_parameters_path`` and placed on ``DEVICE``.

    Returns:
        The model with trained parameters loaded, switched to evaluation mode.
    """
    # Define model
    model = smp.Unet(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        in_channels=3,
        classes=len(CLASSES),
        activation=ACTIVATION
    )

    model.to(DEVICE)

    # Install trained parameters in model.
    # map_location remaps the stored tensors onto DEVICE, so a checkpoint that
    # was saved on a GPU can still be loaded on a CPU-only machine.
    state_dict = torch.load(model_parameters_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.eval()

    return model

# Create dataset
def create_dataset():
    """Assemble the inference ``Dataset`` over the paths in ``DATA_DIR``.

    The augmentation step resizes every image to 512x512. The preprocessing
    step applies the function returned by
    ``smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)`` and then
    reorders the axes from HWC to CHW while casting to float32.
    """
    def _channels_first(x, **kwargs):
        # Move the last axis to the front and cast to float32.
        return x.transpose(2, 0, 1).astype('float32')

    # Inference-time augmentation: a fixed-size resize only.
    resize_pipeline = A.Compose([A.Resize(width=512, height=512, p=1.0)], p=1.0)

    # Preprocessing function for the configured encoder and weights.
    encoder_preprocess = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
    preprocessing_pipeline = A.Compose(
        [
            A.Lambda(image=encoder_preprocess, p=1.0),
            A.Lambda(image=_channels_first, mask=_channels_first, p=1.0),
        ],
        p=1.0,
    )

    return Dataset(
        data_dir=DATA_DIR,
        augmentation=resize_pipeline,
        preprocessing=preprocessing_pipeline,
    )

In [33]:
# Folder that receives the predicted masks.
mask_folder_path = './test_output_masks'

# Build the trained network and the inference dataset.
model = prepare_model()
dataset = create_dataset()

In [34]:
def inference(model, dataset, output_dir=None):
    """Predict a soybean mask for every image in ``dataset`` and save it as PNG.

    For each item the model output is reduced to a binary mask of the pixels
    whose highest-scoring class is ``'soybean'``. The mask is resized back to
    the source image's resolution and written to
    ``<output_dir>/<image name without extension>.png`` with values 0 and 255.

    Args:
        model: Object exposing ``predict(batch)`` that returns a tensor with a
            leading batch axis followed by one channel per entry in ``CLASSES``.
        dataset: Indexable dataset yielding one image tensor per item and
            exposing the matching file paths as ``dataset.image_paths``.
        output_dir: Destination folder. Defaults to the module-level
            ``mask_folder_path``. Created if it does not exist.

    Raises:
        OSError: If a source image cannot be re-read or a mask cannot be written.
    """
    if output_dir is None:
        output_dir = mask_folder_path
    # cv2.imwrite does not create missing folders; it only returns False.
    os.makedirs(output_dir, exist_ok=True)

    soybean_index = CLASSES.index('soybean')

    for i in tqdm(range(len(dataset)), ncols=80):
        # Take the path from the dataset itself so image and path cannot drift.
        image_path = dataset.image_paths[i]

        x_tensor = dataset[i].to(DEVICE).unsqueeze(0)
        pr_mask = model.predict(x_tensor)
        # Drop only the batch axis; the class axis stays first.
        pr_mask = pr_mask.squeeze(0).cpu().numpy()

        # Pixels where soybean is the arg-max class -> 255, everything else 0.
        soybean_mask = (pr_mask.argmax(axis=0) == soybean_index).astype('uint8') * 255

        # The original resolution is only available from the source file.
        raw_image = cv2.imread(image_path)
        if raw_image is None:
            raise OSError(f"Could not read or decode image: {image_path}")
        height, width = raw_image.shape[:2]

        # Nearest-neighbour keeps the mask strictly binary; cv2 takes (width, height).
        soybean_mask = cv2.resize(soybean_mask, (width, height), interpolation=cv2.INTER_NEAREST)

        # splitext strips only the final extension, so "a.b.jpg" becomes "a.b.png".
        stem = os.path.splitext(os.path.basename(image_path))[0]
        pr_mask_path = os.path.normpath(os.path.join(output_dir, stem + '.png'))  # Save masks as PNG format
        if not cv2.imwrite(pr_mask_path, soybean_mask):
            raise OSError(f"Could not write mask: {pr_mask_path}")

# Run the prediction loop over every image and write the soybean masks.
inference(model=model, dataset=dataset)

  0%|                                                     | 0/6 [00:00<?, ?it/s]

tensor([[[[-2.1257, -2.1272, -2.1272,  ..., -2.1278, -2.1266, -2.1269],
          [-2.1266, -2.1248, -2.1260,  ..., -2.1257, -2.1266, -2.1263],
          [-2.1266, -2.1260, -2.1251,  ..., -2.1269, -2.1269, -2.1275],
          ...,
          [-2.1307, -2.1316, -2.1304,  ..., -2.1287, -2.1284, -2.1290],
          [-2.1292, -2.1292, -2.1292,  ..., -2.1284, -2.1295, -2.1281],
          [-2.1292, -2.1292, -2.1295,  ..., -2.1295, -2.1295, -2.1287]],

         [[-2.0174, -2.0165, -2.0174,  ..., -2.0196, -2.0202, -2.0205],
          [-2.0171, -2.0165, -2.0177,  ..., -2.0196, -2.0211, -2.0208],
          [-2.0168, -2.0177, -2.0168,  ..., -2.0205, -2.0202, -2.0205],
          ...,
          [-2.0214, -2.0217, -2.0214,  ..., -2.0419, -2.0410, -2.0428],
          [-2.0217, -2.0214, -2.0217,  ..., -2.0410, -2.0425, -2.0422],
          [-2.0223, -2.0223, -2.0226,  ..., -2.0428, -2.0425, -2.0422]],

         [[-1.7724, -1.7724, -1.7721,  ..., -1.7751, -1.7754, -1.7757],
          [-1.7724, -1.7721, -

 17%|███████▌                                     | 1/6 [00:00<00:03,  1.31it/s]

tensor([[[[-2.1284, -2.1269, -2.1284,  ..., -2.1204, -2.1222, -2.1207],
          [-2.1295, -2.1290, -2.1298,  ..., -2.1213, -2.1222, -2.1219],
          [-2.1281, -2.1284, -2.1284,  ..., -2.1222, -2.1222, -2.1234],
          ...,
          [-2.1284, -2.1290, -2.1295,  ..., -2.1372, -2.1383, -2.1386],
          [-2.1295, -2.1298, -2.1298,  ..., -2.1366, -2.1369, -2.1375],
          [-2.1295, -2.1290, -2.1295,  ..., -2.1366, -2.1372, -2.1378]],

         [[-2.0199, -2.0196, -2.0214,  ..., -2.0196, -2.0211, -2.0199],
          [-2.0202, -2.0205, -2.0202,  ..., -2.0205, -2.0208, -2.0202],
          [-2.0205, -2.0208, -2.0205,  ..., -2.0208, -2.0208, -2.0208],
          ...,
          [-2.0199, -2.0202, -2.0208,  ..., -2.0524, -2.0520, -2.0520],
          [-2.0196, -2.0199, -2.0214,  ..., -2.0520, -2.0517, -2.0511],
          [-2.0196, -2.0196, -2.0208,  ..., -2.0517, -2.0527, -2.0527]],

         [[-1.7739, -1.7733, -1.7745,  ..., -1.7742, -1.7751, -1.7739],
          [-1.7736, -1.7745, -

 33%|███████████████                              | 2/6 [00:01<00:02,  1.33it/s]

tensor([[[[-2.1369, -2.1372, -2.1369,  ..., -2.1507, -2.1489, -2.1471],
          [-2.1372, -2.1366, -2.1380,  ..., -2.1492, -2.1489, -2.1471],
          [-2.1372, -2.1369, -2.1375,  ..., -2.1501, -2.1477, -2.1477],
          ...,
          [-2.1319, -2.1313, -2.1331,  ..., -2.1339, -2.1325, -2.1328],
          [-2.1319, -2.1316, -2.1325,  ..., -2.1336, -2.1328, -2.1336],
          [-2.1301, -2.1307, -2.1328,  ..., -2.1363, -2.1339, -2.1351]],

         [[-2.0530, -2.0533, -2.0527,  ..., -2.0545, -2.0542, -2.0524],
          [-2.0533, -2.0533, -2.0530,  ..., -2.0545, -2.0545, -2.0511],
          [-2.0530, -2.0520, -2.0517,  ..., -2.0554, -2.0527, -2.0514],
          ...,
          [-2.0294, -2.0309, -2.0315,  ..., -2.0245, -2.0251, -2.0257],
          [-2.0294, -2.0284, -2.0288,  ..., -2.0263, -2.0260, -2.0260],
          [-2.0300, -2.0297, -2.0284,  ..., -2.0266, -2.0254, -2.0245]],

         [[-1.8165, -1.8168, -1.8161,  ..., -1.8171, -1.8155, -1.8143],
          [-1.8168, -1.8165, -

 50%|██████████████████████▌                      | 3/6 [00:02<00:02,  1.35it/s]

tensor([[[[-2.1290, -2.1284, -2.1284,  ..., -2.1196, -2.1178, -2.1196],
          [-2.1272, -2.1290, -2.1290,  ..., -2.1196, -2.1187, -2.1196],
          [-2.1287, -2.1287, -2.1304,  ..., -2.1190, -2.1199, -2.1187],
          ...,
          [-2.1290, -2.1275, -2.1281,  ..., -2.1269, -2.1272, -2.1272],
          [-2.1269, -2.1263, -2.1281,  ..., -2.1287, -2.1281, -2.1278],
          [-2.1275, -2.1278, -2.1275,  ..., -2.1269, -2.1246, -2.1266]],

         [[-2.0217, -2.0214, -2.0211,  ..., -2.0159, -2.0150, -2.0150],
          [-2.0196, -2.0205, -2.0217,  ..., -2.0156, -2.0143, -2.0159],
          [-2.0208, -2.0214, -2.0196,  ..., -2.0147, -2.0147, -2.0153],
          ...,
          [-2.0177, -2.0183, -2.0180,  ..., -2.0404, -2.0416, -2.0413],
          [-2.0183, -2.0177, -2.0183,  ..., -2.0413, -2.0407, -2.0407],
          [-2.0186, -2.0189, -2.0183,  ..., -2.0404, -2.0392, -2.0413]],

         [[-1.7748, -1.7754, -1.7757,  ..., -1.7703, -1.7691, -1.7694],
          [-1.7745, -1.7748, -

 67%|██████████████████████████████               | 4/6 [00:02<00:01,  1.37it/s]

tensor([[[[-2.1301, -2.1301, -2.1301,  ..., -2.1375, -2.1383, -2.1366],
          [-2.1325, -2.1307, -2.1322,  ..., -2.1378, -2.1380, -2.1369],
          [-2.1301, -2.1304, -2.1304,  ..., -2.1386, -2.1380, -2.1383],
          ...,
          [-2.1298, -2.1284, -2.1290,  ..., -2.0861, -2.0858, -2.0856],
          [-2.1292, -2.1281, -2.1307,  ..., -2.0853, -2.0850, -2.0864],
          [-2.1287, -2.1298, -2.1298,  ..., -2.0856, -2.0861, -2.0858]],

         [[-2.0254, -2.0254, -2.0235,  ..., -2.0309, -2.0306, -2.0288],
          [-2.0242, -2.0235, -2.0254,  ..., -2.0309, -2.0294, -2.0300],
          [-2.0245, -2.0242, -2.0223,  ..., -2.0309, -2.0297, -2.0288],
          ...,
          [-2.0232, -2.0223, -2.0214,  ..., -2.0002, -2.0002, -2.0002],
          [-2.0226, -2.0208, -2.0211,  ..., -1.9999, -2.0006, -2.0009],
          [-2.0232, -2.0229, -2.0223,  ..., -2.0002, -2.0009, -2.0006]],

         [[-1.7785, -1.7785, -1.7767,  ..., -1.7839, -1.7839, -1.7836],
          [-1.7779, -1.7773, -

 83%|█████████████████████████████████████▌       | 5/6 [00:03<00:00,  1.39it/s]

tensor([[[[-2.1292, -2.1284, -2.1284,  ..., -2.1345, -2.1339, -2.1334],
          [-2.1284, -2.1284, -2.1272,  ..., -2.1351, -2.1366, -2.1336],
          [-2.1278, -2.1269, -2.1281,  ..., -2.1351, -2.1348, -2.1360],
          ...,
          [-2.1336, -2.1336, -2.1336,  ..., -2.1322, -2.1307, -2.1325],
          [-2.1345, -2.1348, -2.1339,  ..., -2.1304, -2.1304, -2.1331],
          [-2.1336, -2.1342, -2.1336,  ..., -2.1319, -2.1319, -2.1313]],

         [[-2.0235, -2.0242, -2.0238,  ..., -2.0254, -2.0257, -2.0245],
          [-2.0245, -2.0242, -2.0232,  ..., -2.0251, -2.0254, -2.0254],
          [-2.0245, -2.0242, -2.0238,  ..., -2.0263, -2.0248, -2.0260],
          ...,
          [-2.0269, -2.0266, -2.0263,  ..., -2.0211, -2.0205, -2.0211],
          [-2.0272, -2.0269, -2.0269,  ..., -2.0205, -2.0205, -2.0211],
          [-2.0275, -2.0272, -2.0272,  ..., -2.0220, -2.0220, -2.0217]],

         [[-1.7788, -1.7788, -1.7794,  ..., -1.7800, -1.7800, -1.7797],
          [-1.7791, -1.7788, -

100%|█████████████████████████████████████████████| 6/6 [00:04<00:00,  1.38it/s]
