In [1]:
# !pip install -U segmentation-models-pytorch albumentations --user
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
# import cv2
import matplotlib.pyplot as plt
import linzhutils as lu
from testdataset import Dataset
from torch.utils.data import DataLoader
import albumentations as albu
import segmentation_models_pytorch as smp
import torch
from tqdm import tqdm

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

cuda


In [2]:
DATA_DIR = '/scratch/project_2007251/hel1964/'

IMAGE_DIR = 'small_images'
OUTPUT_DIR = 'output'

lu.checkDir(os.path.join(DATA_DIR, OUTPUT_DIR))


MASK_DIR = IMAGE_DIR

img_dir = os.path.join(DATA_DIR, IMAGE_DIR)
anno_dir = os.path.join(DATA_DIR, MASK_DIR)

print(f'There are {len(os.listdir(img_dir))} files.')


def get_validation_augmentation():
    """Add paddings to make image shape divisible by 32"""
    test_transform = [albu.PadIfNeeded(512, 512)]
    return albu.Compose(test_transform)


def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn):
    """Construct preprocessing transform
    
    Args:
        preprocessing_fn (callbale): data normalization function 
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    
    """

    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)

There are 50400 files.


In [3]:
# load best saved checkpoint
model = torch.load('./models/kussi_FPN_efficientnet-b7_7_0.490.pth')
ENCODER = 'efficientnet-b7'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['kussi']
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [4]:
# create test dataset

test_dataset = Dataset(
    img_dir,
    anno_dir,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)

test_dataloader = DataLoader(test_dataset)

In [5]:
# # helper function for data visualization
# def visualize(**images):
#     """PLot images in one row."""
#     n = len(images)
#     plt.figure(figsize=(16, 5))
#     for i, (name, image) in enumerate(images.items()):
#         plt.subplot(1, n, i + 1)
#         plt.xticks([])
#         plt.yticks([])
#         plt.title(' '.join(name.split('_')).title())
#         plt.imshow(image)
#     plt.show()

In [6]:
def save_masks(masks):
    global file_num
    name = f"result_masks_{file_num}.npy"
    filename = f"{os.path.join(DATA_DIR, OUTPUT_DIR, name)}"
    concatenated = np.concatenate(masks, axis=0)
    reshaped = concatenated.reshape((-1, 512, 512))
    np.save(filename, np.array(masks))
    print(f"Saved {filename}")
    file_num += 1

In [7]:
fileList = []
for i in test_dataset.ids:
    fileList.append(os.path.join(img_dir, i))
np.save(os.path.join(DATA_DIR, OUTPUT_DIR, 'file_list.npy'),np.array(fileList))

In [8]:
# wrong size checker, uncomment this if some images are not 512x512
# import cv2
# list_5120 = []
# for i in tqdm(os.listdir(img_dir)):
#     # print(i)
#     img = cv2.imread(os.path.join(img_dir,i))
#     if img.shape != (512, 512, 3):
#         print(o, img.shape)
#         list_5120.append(fileList[i])
#         # break

In [9]:
batch_size = 128
image_batch = []
input_batch = []
mask_batch = []
result_masks = []
file_num = 0
total_size = 0

index = 0

for i in tqdm(range(len(test_dataset))):
    # image_batch.append(image[i])
    input_batch.append(test_dataset[i][0])

    if len(input_batch) == batch_size:
        x_tensor = torch.from_numpy(np.array(input_batch)).to(DEVICE)
        # print(x_tensor.shape)
        pr_mask_batch = model.predict(x_tensor).squeeze().cpu().numpy().round()
        result_masks.append(pr_mask_batch)
        total_size += pr_mask_batch.nbytes
        
        index += 1
        
        # break
        if total_size > 1e9:
            save_masks(result_masks)
            result_masks = []
            total_size = 0

        input_batch = []
        mask_batch = []
save_masks(result_masks)

  2%|▏         | 1026/50400 [01:20<7:32:49,  1.82it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_0.npy


  4%|▍         | 2051/50400 [02:41<5:30:14,  2.44it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_1.npy


  6%|▌         | 3074/50400 [04:01<6:40:23,  1.97it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_2.npy


  8%|▊         | 4098/50400 [05:20<5:21:44,  2.40it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_3.npy


 10%|█         | 5124/50400 [06:41<5:45:57,  2.18it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_4.npy


 12%|█▏        | 6147/50400 [07:59<5:26:23,  2.26it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_5.npy


 14%|█▍        | 7172/50400 [09:20<4:40:01,  2.57it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_6.npy


 16%|█▋        | 8194/50400 [10:41<6:57:05,  1.69it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_7.npy


 18%|█▊        | 9217/50400 [12:01<5:43:05,  2.00it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_8.npy


 20%|██        | 10243/50400 [13:21<4:18:10,  2.59it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_9.npy


 22%|██▏       | 11265/50400 [14:40<7:17:25,  1.49it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_10.npy


 24%|██▍       | 12290/50400 [15:58<5:31:54,  1.91it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_11.npy


 26%|██▋       | 13315/50400 [17:17<3:54:22,  2.64it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_12.npy


 28%|██▊       | 14338/50400 [18:36<4:39:24,  2.15it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_13.npy


 30%|███       | 15363/50400 [19:55<3:54:25,  2.49it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_14.npy


 33%|███▎      | 16386/50400 [21:13<3:56:43,  2.39it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_15.npy


 35%|███▍      | 17410/50400 [22:33<4:27:08,  2.06it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_16.npy


 37%|███▋      | 18435/50400 [23:50<3:18:59,  2.68it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_17.npy


 39%|███▊      | 19457/50400 [25:07<5:12:55,  1.65it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_18.npy


 41%|████      | 20483/50400 [26:27<3:58:04,  2.09it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_19.npy


 43%|████▎     | 21506/50400 [27:48<5:06:02,  1.57it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_20.npy


 45%|████▍     | 22529/50400 [29:09<4:52:58,  1.59it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_21.npy


 47%|████▋     | 23554/50400 [30:31<4:55:43,  1.51it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_22.npy


 49%|████▉     | 24580/50400 [31:50<2:41:16,  2.67it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_23.npy


 51%|█████     | 25603/50400 [33:12<3:33:53,  1.93it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_24.npy


 53%|█████▎    | 26628/50400 [34:32<2:22:26,  2.78it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_25.npy


 55%|█████▍    | 27651/50400 [35:55<4:30:05,  1.40it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_26.npy


 57%|█████▋    | 28675/50400 [37:16<2:44:26,  2.20it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_27.npy


 59%|█████▉    | 29697/50400 [38:38<5:02:29,  1.14it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_28.npy


 61%|██████    | 30722/50400 [39:59<2:41:02,  2.04it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_29.npy


 63%|██████▎   | 31742/50400 [41:15<14:26, 21.52it/s]  

Saved /scratch/project_2007251/hel1964/output/result_masks_30.npy


 65%|██████▌   | 32771/50400 [42:41<1:56:27,  2.52it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_31.npy


 67%|██████▋   | 33794/50400 [44:01<2:01:44,  2.27it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_32.npy


 69%|██████▉   | 34819/50400 [45:22<1:54:12,  2.27it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_33.npy


 71%|███████   | 35841/50400 [46:42<2:57:22,  1.37it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_34.npy


 73%|███████▎  | 36869/50400 [48:06<1:28:10,  2.56it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_35.npy


 75%|███████▌  | 37891/50400 [49:21<1:34:51,  2.20it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_36.npy


 77%|███████▋  | 38914/50400 [50:37<1:06:38,  2.87it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_37.npy


 79%|███████▉  | 39937/50400 [51:48<1:26:52,  2.01it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_38.npy


 81%|████████▏ | 40962/50400 [52:58<1:13:24,  2.14it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_39.npy


 83%|████████▎ | 41987/50400 [54:09<1:15:13,  1.86it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_40.npy


 85%|████████▌ | 43008/50400 [55:20<56:36,  2.18it/s]  

Saved /scratch/project_2007251/hel1964/output/result_masks_41.npy


 87%|████████▋ | 44030/50400 [56:29<04:27, 23.83it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_42.npy


 89%|████████▉ | 45058/50400 [57:50<41:00,  2.17it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_43.npy


 91%|█████████▏| 46083/50400 [59:04<26:29,  2.72it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_44.npy


 93%|█████████▎| 47108/50400 [1:00:19<26:22,  2.08it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_45.npy


 96%|█████████▌| 48132/50400 [1:01:33<13:45,  2.75it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_46.npy


 98%|█████████▊| 49154/50400 [1:02:46<08:34,  2.42it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_47.npy


100%|█████████▉| 50179/50400 [1:04:05<01:39,  2.21it/s]

Saved /scratch/project_2007251/hel1964/output/result_masks_48.npy


100%|██████████| 50400/50400 [1:04:19<00:00, 13.06it/s]


Saved /scratch/project_2007251/hel1964/output/result_masks_49.npy


In [10]:
# # helper function for data visualization
# def to_image(x, **kwargs):
#     if len(x.shape) == 3:
#         return x.transpose(1, 2, 0)
#     else: return x

# def visualize_batch(img, msk):
#     for i in range(img.shape[0]):
#         plt.figure(figsize=(16, 5))
#         plt.subplot(1, 2, 1)
#         plt.imshow(to_image(img[i]))
#         plt.subplot(1, 2, 2)
#         plt.imshow(to_image(msk[0][i]))
#         plt.show()

 60%|██████    | 60/100 [00:00<00:00, 758006.75it/s]


In [None]:
# for i in image_batch:
#     plt.figure()
#     plt.imshow(to_image(i))
#     plt.show()

In [None]:
# img = np.array(input_batch.copy())
# msk = result_masks.copy()
# visualize_batch(img, msk)

In [None]:
# for i in range(len(result_masks[0])):
#     plt.figure()
#     plt.imshow(result_masks[0][i])
#     plt.show()

In [None]:
# import time

# for i in range(20):
#     n = np.random.choice(len(test_dataset))

#     t1 = time.time()
    
#     image_vis = test_dataset_vis[n][0].astype('uint8')
#     image, gt_mask = test_dataset[n]

#     gt_mask = gt_mask.squeeze()

#     x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
#     pr_mask = model.predict(x_tensor)
#     pr_mask = (pr_mask.squeeze().cpu().numpy().round())

#     print(time.time()-t1)
    
#     visualize(image=image_vis,
#               ground_truth_mask=gt_mask,
#               predicted_mask=pr_mask)
