In [1]:
# !pip install -U segmentation-models-pytorch albumentations --user
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
# import cv2
import matplotlib.pyplot as plt
import linzhutils as lu
from testdataset import Dataset
from torch.utils.data import DataLoader
import albumentations as albu
import segmentation_models_pytorch as smp
import torch
from tqdm import tqdm

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

cuda


In [2]:
DATA_DIR = '/scratch/project_2007251/hel1943/'

IMAGE_DIR = 'small_images'
OUTPUT_DIR = 'output'

lu.checkDir(os.path.join(DATA_DIR, OUTPUT_DIR))


MASK_DIR = IMAGE_DIR

img_dir = os.path.join(DATA_DIR, IMAGE_DIR)
anno_dir = os.path.join(DATA_DIR, MASK_DIR)


def get_validation_augmentation():
    """Add paddings to make image shape divisible by 32"""
    test_transform = [albu.PadIfNeeded(512, 512)]
    return albu.Compose(test_transform)


def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn):
    """Construct preprocessing transform
    
    Args:
        preprocessing_fn (callbale): data normalization function 
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    
    """

    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)

In [3]:
# load best saved checkpoint
model = torch.load('./models/kussi_FPN_efficientnet-b7_7_0.490.pth')
ENCODER = 'efficientnet-b7'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['kussi']
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [4]:
# create test dataset

test_dataset = Dataset(
    img_dir,
    anno_dir,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)

test_dataloader = DataLoader(test_dataset)

In [5]:
# # helper function for data visualization
# def visualize(**images):
#     """PLot images in one row."""
#     n = len(images)
#     plt.figure(figsize=(16, 5))
#     for i, (name, image) in enumerate(images.items()):
#         plt.subplot(1, n, i + 1)
#         plt.xticks([])
#         plt.yticks([])
#         plt.title(' '.join(name.split('_')).title())
#         plt.imshow(image)
#     plt.show()

In [6]:
def save_masks(masks):
    global file_num
    name = f"result_masks_{file_num}.npy"
    filename = f"{os.path.join(DATA_DIR, OUTPUT_DIR, name)}"
    concatenated = np.concatenate(masks, axis=0)
    reshaped = concatenated.reshape((-1, 512, 512))
    np.save(filename, np.array(masks))
    print(f"Saved {filename}")
    file_num += 1

In [7]:
fileList = []
for i in test_dataset.ids:
    fileList.append(os.path.join(img_dir, i))
np.save(os.path.join(DATA_DIR, OUTPUT_DIR, 'file_list.npy'),np.array(fileList))

In [8]:
batch_size = 128
image_batch = []
input_batch = []
mask_batch = []
result_masks = []
file_num = 0
total_size = 0

index = 0

for i in tqdm(range(len(test_dataset))):
    # image_batch.append(image[i])
    input_batch.append(test_dataset[i][0])

    if len(input_batch) == batch_size:
        x_tensor = torch.from_numpy(np.array(input_batch)).to(DEVICE)
        # print(x_tensor.shape)
        pr_mask_batch = model.predict(x_tensor).squeeze().cpu().numpy().round()
        result_masks.append(pr_mask_batch)
        total_size += pr_mask_batch.nbytes
        
        index += 1
        
        # break
        if total_size > 1e9:
            save_masks(result_masks)
            result_masks = []
            total_size = 0

        input_batch = []
        mask_batch = []
save_masks(result_masks)

  2%|▏         | 1031/50400 [01:08<3:17:21,  4.17it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_0.npy


  4%|▍         | 2050/50400 [02:20<6:49:41,  1.97it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_1.npy


  6%|▌         | 3075/50400 [03:32<4:40:45,  2.81it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_2.npy


  8%|▊         | 4099/50400 [04:44<4:41:12,  2.74it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_3.npy


 10%|█         | 5123/50400 [05:57<4:11:41,  3.00it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_4.npy


 12%|█▏        | 6146/50400 [07:10<5:02:44,  2.44it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_5.npy


 14%|█▍        | 7169/50400 [08:24<6:21:41,  1.89it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_6.npy


 16%|█▋        | 8192/50400 [09:39<6:14:45,  1.88it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_7.npy


 18%|█▊        | 9219/50400 [10:52<4:42:51,  2.43it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_8.npy


 20%|██        | 10243/50400 [12:06<4:02:27,  2.76it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_9.npy


 22%|██▏       | 11266/50400 [13:18<4:30:05,  2.41it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_10.npy


 24%|██▍       | 12291/50400 [14:31<4:47:06,  2.21it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_11.npy


 26%|██▋       | 13313/50400 [15:49<6:33:13,  1.57it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_12.npy


 28%|██▊       | 14340/50400 [17:02<2:54:08,  3.45it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_13.npy


 30%|███       | 15364/50400 [18:15<3:50:19,  2.54it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_14.npy


 33%|███▎      | 16386/50400 [19:26<4:55:36,  1.92it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_15.npy


 35%|███▍      | 17411/50400 [20:40<4:20:43,  2.11it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_16.npy


 37%|███▋      | 18434/50400 [21:55<4:25:07,  2.01it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_17.npy


 39%|███▊      | 19460/50400 [23:13<3:47:16,  2.27it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_18.npy


 41%|████      | 20481/50400 [24:28<5:16:01,  1.58it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_19.npy


 43%|████▎     | 21508/50400 [25:43<3:20:40,  2.40it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_20.npy


 45%|████▍     | 22532/50400 [26:57<3:11:27,  2.43it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_21.npy


 47%|████▋     | 23553/50400 [28:14<2:36:38,  2.86it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_22.npy


 49%|████▉     | 24579/50400 [29:28<3:44:54,  1.91it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_23.npy


 51%|█████     | 25601/50400 [30:45<3:36:02,  1.91it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_24.npy


 53%|█████▎    | 26626/50400 [32:03<3:09:11,  2.09it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_25.npy


 55%|█████▍    | 27652/50400 [33:18<2:21:08,  2.69it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_26.npy


 57%|█████▋    | 28674/50400 [34:34<2:58:23,  2.03it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_27.npy


 59%|█████▉    | 29700/50400 [35:50<2:25:36,  2.37it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_28.npy


 61%|██████    | 30722/50400 [37:06<2:29:51,  2.19it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_29.npy


 63%|██████▎   | 31748/50400 [38:19<1:54:04,  2.72it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_30.npy


 65%|██████▌   | 32771/50400 [39:36<1:53:25,  2.59it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_31.npy


 67%|██████▋   | 33795/50400 [40:52<1:48:53,  2.54it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_32.npy


 69%|██████▉   | 34816/50400 [42:07<2:35:26,  1.67it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_33.npy


 71%|███████   | 35844/50400 [43:21<1:26:02,  2.82it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_34.npy


 73%|███████▎  | 36866/50400 [44:35<1:40:38,  2.24it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_35.npy


 75%|███████▌  | 37892/50400 [45:51<1:14:13,  2.81it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_36.npy


 77%|███████▋  | 38916/50400 [47:05<1:07:53,  2.82it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_37.npy


 79%|███████▉  | 39936/50400 [48:20<1:26:59,  2.00it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_38.npy


 81%|████████▏ | 40961/50400 [49:36<1:19:13,  1.99it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_39.npy


 83%|████████▎ | 41987/50400 [50:53<52:50,  2.65it/s]  

Saved /scratch/project_2007251/hel1943/output/result_masks_40.npy


 85%|████████▌ | 43011/50400 [52:05<49:55,  2.47it/s]  

Saved /scratch/project_2007251/hel1943/output/result_masks_41.npy


 87%|████████▋ | 44031/50400 [53:13<05:43, 18.54it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_42.npy


 89%|████████▉ | 45059/50400 [54:26<38:20,  2.32it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_43.npy


 91%|█████████▏| 46085/50400 [55:38<20:34,  3.49it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_44.npy


 93%|█████████▎| 47107/50400 [56:47<18:48,  2.92it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_45.npy


 95%|█████████▌| 48128/50400 [57:58<17:42,  2.14it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_46.npy


 98%|█████████▊| 49153/50400 [59:09<11:20,  1.83it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_47.npy


100%|█████████▉| 50181/50400 [1:00:22<01:06,  3.31it/s]

Saved /scratch/project_2007251/hel1943/output/result_masks_48.npy


100%|██████████| 50400/50400 [1:00:35<00:00, 13.86it/s]


Saved /scratch/project_2007251/hel1943/output/result_masks_49.npy


In [9]:
# helper function for data visualization
def to_image(x, **kwargs):
    if len(x.shape) == 3:
        return x.transpose(1, 2, 0)
    else: return x

def visualize_batch(img, msk):
    for i in range(img.shape[0]):
        plt.figure(figsize=(16, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(to_image(img[i]))
        plt.subplot(1, 2, 2)
        plt.imshow(to_image(msk[0][i]))
        plt.show()

In [16]:
# for i in image_batch:
#     plt.figure()
#     plt.imshow(to_image(i))
#     plt.show()

In [15]:
# img = np.array(input_batch.copy())
# msk = result_masks.copy()
# visualize_batch(img, msk)

In [1]:
for i in range(len(result_masks[0])):
    plt.figure()
    plt.imshow(result_masks[0][i])
    plt.show()

In [1]:
# import time

# for i in range(20):
#     n = np.random.choice(len(test_dataset))

#     t1 = time.time()
    
#     image_vis = test_dataset_vis[n][0].astype('uint8')
#     image, gt_mask = test_dataset[n]

#     gt_mask = gt_mask.squeeze()

#     x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
#     pr_mask = model.predict(x_tensor)
#     pr_mask = (pr_mask.squeeze().cpu().numpy().round())

#     print(time.time()-t1)
    
#     visualize(image=image_vis,
#               ground_truth_mask=gt_mask,
#               predicted_mask=pr_mask)
