In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive so that the
# project folder referenced by the following cells is reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import sys
import os

# Project folder on the mounted Drive. Defined once so the import path and the
# working directory cannot drift apart.
PROJECT_DIR = '/content/drive/MyDrive/gcolab/bacteria_segmentation'

# Make modules in the project folder importable. Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
if PROJECT_DIR not in sys.path:
  sys.path.append(PROJECT_DIR)

# Relative paths used later in the notebook (the CSV file and the image
# folders) are resolved from this directory.
os.chdir(PROJECT_DIR)

In [None]:
import pandas as pd
import scipy as sc
import numpy as np
import matplotlib.pyplot as plt
import copy

In [None]:
!pip install -U albumentations

Collecting albumentations
[?25l  Downloading https://files.pythonhosted.org/packages/03/58/63fb1d742dc42d9ba2800ea741de1f2bc6bb05548d8724aa84794042eaf2/albumentations-0.5.2-py3-none-any.whl (72kB)
[K     |████▌                           | 10kB 15.0MB/s eta 0:00:01[K     |█████████                       | 20kB 21.2MB/s eta 0:00:01[K     |█████████████▋                  | 30kB 22.5MB/s eta 0:00:01[K     |██████████████████▏             | 40kB 17.4MB/s eta 0:00:01[K     |██████████████████████▊         | 51kB 9.4MB/s eta 0:00:01[K     |███████████████████████████▏    | 61kB 8.9MB/s eta 0:00:01[K     |███████████████████████████████▊| 71kB 10.0MB/s eta 0:00:01[K     |████████████████████████████████| 81kB 5.3MB/s 
Collecting opencv-python-headless>=4.1.1
[?25l  Downloading https://files.pythonhosted.org/packages/6d/6d/92f377bece9b0ec9c893081dbe073a65b38d7ac12ef572b8f70554d08760/opencv_python_headless-4.5.1.48-cp37-cp37m-manylinux2014_x86_64.whl (37.6MB)
[K     |█████████

In [None]:
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
import torch

import albumentations as A
import albumentations.augmentations.functional as F
from albumentations.pytorch.transforms import ToTensor, ToTensorV2

import cv2

In [45]:
class BacteriaDataset(Dataset):
  """Segmentation dataset pairing each bacteria image with one labeled mask.

  Every row of ``df_data`` supplies a ``filename`` (formatted as a zero-padded
  three-digit stem) and a ``labels`` value that must be a key of
  ``LABELS_MAPPING``. Items are ``(image, mask)`` pairs in which every
  non-zero mask pixel has been replaced by the integer id of the row's label.

  Args:
    df_data: DataFrame with ``filename`` and ``labels`` columns. It is
      deep-copied, so later changes to the caller's frame do not affect
      the dataset.
    images_path: Directory that contains ``<filename>.png`` images and the
      per-image mask folders.
    masks_path_postfix: Suffix appended to the filename stem to form the
      name of the folder holding that image's masks.
    transform: Optional callable invoked as ``transform(image=..., mask=...)``
      that returns a mapping with ``'image'`` and ``'mask'`` entries.
    transform_image: Optional callable invoked as
      ``transform_image(image=...)`` that returns a mapping with an
      ``'image'`` entry.
    transform_mask: Optional callable applied directly to the mask.
  """

  LABELS = ['_background', '_epidermidis', '_pneumoniae', '_aureus', '_moraxella', '_kefir', '_cloacae']
  # Class id of each label is its position in LABELS (background -> 0).
  LABELS_MAPPING = dict(zip(LABELS, range(len(LABELS))))

  def __init__(self, df_data, images_path, masks_path_postfix, transform=None, transform_image=None, transform_mask=None):
    self.df_data = copy.deepcopy(df_data)
    self.images_path = images_path
    self.masks_path_postfix = masks_path_postfix

    self.transform = transform
    self.transform_image = transform_image
    self.transform_mask = transform_mask

  def __len__(self):
    """Return the number of rows in this dataset's own copy of the table."""
    # Must use the instance's frame: relying on a global `df_data` breaks as
    # soon as the dataset is built from a differently named or sized frame.
    return self.df_data.shape[0]

  def __getitem__(self, idx):
    """Load, transform and return the ``(image, mask)`` pair at row ``idx``.

    Raises:
      FileNotFoundError: If the image or the mask file cannot be read.
      KeyError: If the row's label is not a key of ``LABELS_MAPPING``.
    """
    row = self.df_data.iloc[idx]
    filename = format(row['filename'], '03d')
    label = row['labels']

    image_file = os.path.join(self.images_path, f'{filename}.png')
    image = cv2.imread(image_file)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
      raise FileNotFoundError(f'Could not read image file: {image_file}')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Only the mask matching the row's label is loaded; every other class is
    # treated as all-zero background for this image.
    mask_file = os.path.join(
        self.images_path,
        f'{filename}{self.masks_path_postfix}',
        f'{filename}{label}.png',
    )
    mask = cv2.imread(mask_file, cv2.IMREAD_UNCHANGED)
    if mask is None:
      raise FileNotFoundError(f'Could not read mask file: {mask_file}')

    if self.transform is not None:
      transformed = self.transform(image=image, mask=mask)
      image, mask = transformed['image'], transformed['mask']

    # Re-encode the foreground: any non-zero pixel becomes the class id.
    mask[mask != 0] = self.LABELS_MAPPING[label]

    if self.transform_image is not None:
      image = self.transform_image(image=image)['image']

    if self.transform_mask is not None:
      mask = self.transform_mask(mask)

    return image, mask

In [38]:
def vizualize_dataset(bacteria_dataset, df_data, num_samples=2):
  """Plot randomly drawn samples next to one binary mask panel per class.

  Each figure row shows one sample: the image in the first column, followed by
  one panel per entry of ``BacteriaDataset.LABELS`` showing where the mask
  equals that label's class id.

  Args:
    bacteria_dataset: Indexable dataset returning ``(image, mask)``; the image
      is permuted from channel-first to channel-last before display, so it is
      expected to be a 3-D torch tensor.
    df_data: Table whose ``filename`` column is used for the row titles; it is
      indexed positionally with the same index as the dataset.
    num_samples: Number of random samples (figure rows) to draw.

  Raises:
    ValueError: If the dataset is empty.
  """
  dataset_size = len(bacteria_dataset)
  if dataset_size == 0:
    raise ValueError('Cannot visualize an empty dataset')

  labels = BacteriaDataset.LABELS
  # squeeze=False keeps the axes array two-dimensional even for a single row,
  # so the [row, column] indexing below works when num_samples == 1.
  _, ax = plt.subplots(
      nrows=num_samples,
      ncols=1 + len(labels),
      figsize=(80, 5 * num_samples),
      squeeze=False,
  )
  for i in range(num_samples):
    # The upper bound of randint is exclusive, so passing the dataset size
    # lets every index (including the last one) be drawn.
    idx = np.random.randint(dataset_size)
    image, mask = bacteria_dataset[idx]

    image = image.permute(1, 2, 0).numpy()

    ax[i, 0].imshow(image)
    ax[i, 0].set_title(f'image num {df_data.iloc[idx]["filename"]}')
    ax[i, 0].set_axis_off()

    for j, label in enumerate(labels):
      panel = ax[i, 1 + j]
      panel.imshow(mask == BacteriaDataset.LABELS_MAPPING[label])
      panel.set_title(f'{label} mask')
      panel.set_axis_off()

  plt.tight_layout()
  plt.show()

In [30]:
# Joint image/mask augmentation: resize to 320x320, then flips and 90-degree
# rotations (each applied with the library's default probability).
train_transform = A.Compose([
    A.Resize(height=320, width=320),
    A.HorizontalFlip(),
    A.VerticalFlip(),
    A.RandomRotate90(),
])

# Image-only step that converts the augmented array into a torch tensor.
transform_image = A.Compose([ToTensor(sigmoid=False)])

# Mask-only step: wrap the numpy mask as a torch tensor.
transform_mask = torch.from_numpy

In [44]:
# Read the table of filenames/labels and wrap it in the dataset together with
# the augmentation and tensor-conversion transforms.
df_data = pd.read_csv('bacteria_segmentation_eda.csv')
bacteria_dataset = BacteriaDataset(
    df_data=df_data,
    images_path='data//train',
    masks_path_postfix='_masks',
    transform=train_transform,
    transform_image=transform_image,
    transform_mask=transform_mask,
)

# Draw 10 random samples as a visual sanity check of images and masks.
vizualize_dataset(bacteria_dataset, df_data, num_samples=10)

Output hidden; open in https://colab.research.google.com to view.