In [1]:
# !pip install -U segmentation-models-pytorch albumentations --user
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
# import cv2
import matplotlib.pyplot as plt
import linzhutils as lu
from testdataset import Dataset
from torch.utils.data import DataLoader
import albumentations as albu
import segmentation_models_pytorch as smp
import torch
from tqdm import tqdm

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

cuda


In [10]:
DATA_DIR = '/scratch/project_2007251/hel1996/'

IMAGE_DIR = 'small_images'
OUTPUT_DIR = 'output'

lu.checkDir(os.path.join(DATA_DIR, OUTPUT_DIR))


MASK_DIR = IMAGE_DIR

img_dir = os.path.join(DATA_DIR, IMAGE_DIR)
anno_dir = os.path.join(DATA_DIR, MASK_DIR)


def get_validation_augmentation():
    """Add paddings to make image shape divisible by 32"""
    test_transform = [albu.PadIfNeeded(512, 512)]
    return albu.Compose(test_transform)


def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn):
    """Construct preprocessing transform
    
    Args:
        preprocessing_fn (callbale): data normalization function 
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    
    """

    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)

In [11]:
# load best saved checkpoint
model = torch.load('./models/kussi_FPN_efficientnet-b7_7_0.490.pth')
ENCODER = 'efficientnet-b7'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['kussi']
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [12]:
# create test dataset

test_dataset = Dataset(
    img_dir,
    anno_dir,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)

test_dataloader = DataLoader(test_dataset)

In [13]:
# # helper function for data visualization
# def visualize(**images):
#     """PLot images in one row."""
#     n = len(images)
#     plt.figure(figsize=(16, 5))
#     for i, (name, image) in enumerate(images.items()):
#         plt.subplot(1, n, i + 1)
#         plt.xticks([])
#         plt.yticks([])
#         plt.title(' '.join(name.split('_')).title())
#         plt.imshow(image)
#     plt.show()

In [14]:
def save_masks(masks):
    global file_num
    name = f"result_masks_{file_num}.npy"
    filename = f"{os.path.join(DATA_DIR, OUTPUT_DIR, name)}"
    concatenated = np.concatenate(masks, axis=0)
    reshaped = concatenated.reshape((-1, 512, 512))
    np.save(filename, np.array(masks))
    print(f"Saved {filename}")
    file_num += 1

In [15]:
fileList = []
for i in test_dataset.ids:
    fileList.append(os.path.join(img_dir, i))
np.save(os.path.join(DATA_DIR, OUTPUT_DIR, 'file_list.npy'),np.array(fileList))

In [20]:
list_5120 = []
for i in tqdm(range(0,len(fileList))):
    img = test_dataset[i]
    if img[0].shape != (3, 512, 512):
        print(fileList[i])
        list_5120.append(fileList[i])
        # break

  0%|          | 97/37566 [00:01<07:55, 78.79it/s]

/scratch/project_2007251/hel1996/small_images/hel1996_00_10.png


  0%|          | 113/37566 [00:04<1:00:58, 10.24it/s]

/scratch/project_2007251/hel1996/small_images/hel1996_00_100.png


  1%|▏         | 510/37566 [00:11<46:49, 13.19it/s]  

/scratch/project_2007251/hel1996/small_images/hel1996_00_50.png


  2%|▏         | 717/37566 [00:16<37:31, 16.36it/s]

/scratch/project_2007251/hel1996/small_images/hel1996_00_80.png


  3%|▎         | 1118/37566 [00:23<31:55, 19.03it/s]

/scratch/project_2007251/hel1996/small_images/hel1996_100_110.png


  3%|▎         | 1199/37566 [00:24<08:34, 70.62it/s]

/scratch/project_2007251/hel1996/small_images/hel1996_100_30.png


  3%|▎         | 1215/37566 [00:28<1:00:24, 10.03it/s]

/scratch/project_2007251/hel1996/small_images/hel1996_100_40.png


  4%|▍         | 1421/37566 [00:33<36:50, 16.35it/s]  

/scratch/project_2007251/hel1996/small_images/hel1996_100_70.png


  5%|▍         | 1718/37566 [00:39<34:57, 17.09it/s]

/scratch/project_2007251/hel1996/small_images/hel1996_10_10.png


  5%|▌         | 1925/37566 [00:43<35:54, 16.54it/s]

/scratch/project_2007251/hel1996/small_images/hel1996_10_20.png


  6%|▌         | 2084/37566 [00:46<13:03, 45.28it/s]


KeyboardInterrupt: 

In [17]:
batch_size = 128
image_batch = []
input_batch = []
mask_batch = []
result_masks = []
file_num = 0
total_size = 0

index = 0

for i in tqdm(range(len(test_dataset))):
    # image_batch.append(image[i])
    input_batch.append(test_dataset[i][0])

    if len(input_batch) == batch_size:
        x_tensor = torch.from_numpy(np.array(input_batch)).to(DEVICE)
        # print(x_tensor.shape)
        pr_mask_batch = model.predict(x_tensor).squeeze().cpu().numpy().round()
        result_masks.append(pr_mask_batch)
        total_size += pr_mask_batch.nbytes
        
        index += 1
        
        # break
        if total_size > 1e9:
            save_masks(result_masks)
            result_masks = []
            total_size = 0

        input_batch = []
        mask_batch = []
save_masks(result_masks)

  x_tensor = torch.from_numpy(np.array(input_batch)).to(DEVICE)
  0%|          | 127/37566 [00:05<26:52, 23.21it/s]


ValueError: could not broadcast input array from shape (3,512,512) into shape (3,)

In [None]:
# # helper function for data visualization
# def to_image(x, **kwargs):
#     if len(x.shape) == 3:
#         return x.transpose(1, 2, 0)
#     else: return x

# def visualize_batch(img, msk):
#     for i in range(img.shape[0]):
#         plt.figure(figsize=(16, 5))
#         plt.subplot(1, 2, 1)
#         plt.imshow(to_image(img[i]))
#         plt.subplot(1, 2, 2)
#         plt.imshow(to_image(msk[0][i]))
#         plt.show()

In [None]:
# for i in image_batch:
#     plt.figure()
#     plt.imshow(to_image(i))
#     plt.show()

In [None]:
# img = np.array(input_batch.copy())
# msk = result_masks.copy()
# visualize_batch(img, msk)

In [None]:
# for i in range(len(result_masks[0])):
#     plt.figure()
#     plt.imshow(result_masks[0][i])
#     plt.show()

In [None]:
# import time

# for i in range(20):
#     n = np.random.choice(len(test_dataset))

#     t1 = time.time()
    
#     image_vis = test_dataset_vis[n][0].astype('uint8')
#     image, gt_mask = test_dataset[n]

#     gt_mask = gt_mask.squeeze()

#     x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
#     pr_mask = model.predict(x_tensor)
#     pr_mask = (pr_mask.squeeze().cpu().numpy().round())

#     print(time.time()-t1)
    
#     visualize(image=image_vis,
#               ground_truth_mask=gt_mask,
#               predicted_mask=pr_mask)
