In [1]:
import os
import sys
import cv2
import tqdm
import time
import random
import scipy.io
import itertools
import numpy as np
from math import ceil
from itertools import chain
from tqdm.contrib import tzip
import matplotlib.pyplot as plt
from skimage.io import imread
from scipy.ndimage.filters import gaussian_filter
from sklearn.model_selection import train_test_split

In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

In [None]:
def equalizeHistnorm(img):
    """Histogram-equalize every channel of an 8-bit image independently.

    Args:
        img: uint8 image, either 2-D (H, W) grayscale or (H, W, C) with the
            channels last. Channel order is preserved.

    Returns:
        For 2-D input, the equalized (H, W) image. Otherwise, the channels
        equalized one by one and merged back in their original order.
    """
    # cv2.equalizeHist only accepts single-channel 8-bit input, so grayscale
    # images are equalized directly instead of being split (splitting a 2-D
    # image yields one channel, which broke the old 3-way unpacking).
    if img.ndim == 2:
        return cv2.equalizeHist(img)
    equalized = [cv2.equalizeHist(channel) for channel in cv2.split(img)]
    return cv2.merge(equalized)

In [27]:
class ShanghaitechDataset(torch.utils.data.Dataset):
    """ShanghaiTech crowd-counting dataset with optional crop-based augmentation.

    Each sample is an ``(image, label)`` pair:
    - ``image`` is the histogram-equalized image passed through ``transform``.
    - ``label`` is a float32 head-location map at 1/4 of the image
      resolution, with 1 at each annotated head and 0 elsewhere.
    Crop augmentations resize the label to ``MASK_SIZE`` x ``MASK_SIZE``.

    Args:
        DATA_PATH: split root ending in a path separator. It must contain
            ``images/<name>`` and ``ground-truth/GT_<stem>.mat``.
        transform: callable applied to every image in ``__getitem__``.
        size: number (or fraction) of samples to keep, drawn with
            ``train_test_split(random_state=42)``. If an integer ``size`` is not
            smaller than the number of available samples, all are kept.
        random_crop_aug: use 9 random 260x260 crops per original image.
        mixed_aug: use 9 random 260x260 crops per original image, each rotated
            with probability 0.5. If combined with ``random_crop_aug``, both
            sets of crops are kept.
    """

    # Side length of the label maps produced by the crop augmentations.
    # NOTE(review): presumably matches the model's output resolution — confirm.
    MASK_SIZE = 66

    def __init__(self, DATA_PATH, transform, size = 1000, random_crop_aug = True, mixed_aug = False):
        raw_images, raw_labels = self.data_loader(DATA_PATH)

        if random_crop_aug or mixed_aug:
            images, labels = [], []
            # Both augmentations crop the full-resolution originals. Chaining
            # them would re-crop 260x260 crops whose labels are no longer 1/4 scale.
            if random_crop_aug:
                aug_images, aug_labels = self.random_crop_aug(raw_images, raw_labels, 9, 260, 260)
                images.extend(aug_images)
                labels.extend(aug_labels)
            if mixed_aug:
                aug_images, aug_labels = self.mixed_aug(raw_images, raw_labels, 9, 260, 260)
                images.extend(aug_images)
                labels.extend(aug_labels)
        else:
            images, labels = raw_images, raw_labels

        # train_test_split rejects an integer test_size >= n_samples; in that
        # case simply keep every sample.
        if isinstance(size, (int, np.integer)) and size >= len(images):
            X, y = images, labels
        else:
            _, X, _, y = train_test_split(images, labels, test_size=size, random_state=42)

        self.X, self.y = X, y
        self.transform = transform

    def __len__(self):
        """Number of samples kept after augmentation and subsampling."""
        return len(self.X)

    def __getitem__(self, index):
        """Return ``(transform(image), label as float32)`` for ``index``."""
        image = self.X[index]
        label = self.y[index]
        image = self.transform(image)
        return image, np.float32(label)

    def density_map(self, label_info, image_shape=(768, 1024)):
        """Build a binary head-location map at 1/4 resolution.

        Args:
            label_info: dict loaded from a ShanghaiTech ``GT_*.mat`` file. The
                point list at ``['image_info'][0][0][0][0][0]`` is iterated as
                ``(x, y)`` pairs.
            image_shape: ``(height, width)`` of the annotated image. The default
                keeps the previous fixed 768x1024 behaviour.

        Returns:
            float64 array of shape ``(height // 4, width // 4)`` with 1 at each
            head location and 0 elsewhere.
        """
        rows, cols = image_shape[0] // 4, image_shape[1] // 4
        label = np.zeros((rows, cols))
        for [i, j] in label_info['image_info'][0][0][0][0][0]:
            # Rounding x/4 or y/4 can land exactly on the border, so clamp into the grid.
            I = min(max(int(np.round(i / 4)), 0), cols - 1)
            J = min(max(int(np.round(j / 4)), 0), rows - 1)
            label[J, I] = 1
        return label

    def _crop_and_resize(self, image, label, width, height):
        """Randomly crop an image/label pair and resize the label to MASK_SIZE x MASK_SIZE."""
        img, mask = self.random_crop(image, label, width, height)
        # Nearest-neighbour keeps the map binary instead of blending head markers.
        mask = cv2.resize(mask, (self.MASK_SIZE, self.MASK_SIZE), interpolation=cv2.INTER_NEAREST)
        return img, mask

    def mixed_aug(self, images, labels, num, width, height):
        """Random crops, each rotated by a random multiple of 90 degrees with probability 0.5.

        Args:
            images, labels: full-resolution images and their 1/4-scale label maps.
            num: number of augmented samples generated per input image.
            width, height: crop size in image pixels.

        Returns:
            ``(images_res, masks_res)`` lists of length ``num * len(images)``.
        """
        images_res, masks_res = [], []
        for image, label in tzip(images, labels):
            for _ in range(num):
                img, mask = self._crop_and_resize(image, label, width, height)
                if np.random.rand() > 0.5:
                    # A quarter turn would swap H/W of a non-square crop and break
                    # batching, so only half turns are used in that case.
                    k = np.random.randint(1, 4) if width == height else 2
                    # Rotate image and label together so head positions stay aligned.
                    # Copy, because rot90 returns a negatively-strided view that
                    # torch.from_numpy cannot wrap.
                    img = np.ascontiguousarray(np.rot90(img, k))
                    mask = np.ascontiguousarray(np.rot90(mask, k))
                images_res.append(img)
                masks_res.append(mask)
        print("Datasets' length after mixed aug:", len(images_res))
        return images_res, masks_res

    def data_loader(self, PATH):
        """Load every image under ``PATH + 'images/'`` with its ground-truth map.

        Images are histogram-equalized with ``equalizeHistnorm``. Labels are
        built with ``density_map`` at the image's own resolution.

        Returns:
            ``(images, labels)`` lists in sorted file-name order.
        """
        # Sort so the sample order, and hence the seeded split in __init__, is
        # reproducible (os.listdir order is arbitrary). Skip hidden entries such
        # as .ipynb_checkpoints, which are not images.
        names = sorted(n for n in os.listdir(PATH + 'images/') if not n.startswith('.'))
        ims = ['images/' + n for n in names]
        lbs = ['ground-truth/' + 'GT_' + os.path.splitext(n)[0] + '.mat' for n in names]
        images, labels = [], []
        for im_path, lb_path in tzip(ims, lbs):
            image = imread(PATH + im_path)
            label_info = scipy.io.loadmat(PATH + lb_path)
            images.append(equalizeHistnorm(image))
            labels.append(self.density_map(label_info, image.shape[:2]))
        print("Datasets' length:", len(images))
        return images, labels

    def random_crop(self, img, mask, width, height):
        """Crop a random ``width`` x ``height`` window from ``img`` and the matching window of ``mask``.

        ``mask`` is assumed to be at 1/4 of the image resolution, so the window
        coordinates are divided by 4 (rounded). ``random.randint`` raises
        ValueError if the image is smaller than the crop.
        """
        x = random.randint(0, img.shape[1] - width)
        y = random.randint(0, img.shape[0] - height)
        img = img[y:y+height, x:x+width]
        mask = mask[int(np.round((y/4))):int(np.round((y+height)/4)), int(np.round(x/4)):int(np.round((x+width)/4))]
        return img, mask

    def random_crop_aug(self, images, labels, num, width, height):
        """Replace each image by ``num`` random crops with MASK_SIZE x MASK_SIZE labels.

        Returns:
            ``(images_res, masks_res)`` lists of length ``num * len(images)``.
        """
        images_res, masks_res = [], []
        for image, label in tzip(images, labels):
            for _ in range(num):
                img, mask = self._crop_and_resize(image, label, width, height)
                images_res.append(img)
                masks_res.append(mask)
        print("Datasets' length after aug:", len(images_res))
        return images_res, masks_res

In [28]:
def imshow(img, t = ''):
    """Show a channels-first image tensor with matplotlib.

    The pixels are mapped back with ``x / 2 + 0.5``. This is presumably the
    inverse of a mean=0.5 / std=0.5 normalization; confirm against the transform.

    Args:
        img: 3-D image with channels on axis 0. It must support ``.numpy()``
            (presumably a CPU torch tensor).
        t: figure title.
    """
    unnormalized = (img / 2 + 0.5).numpy()
    # matplotlib expects (H, W, C); move the channel axis to the end.
    channels_last = unnormalized.transpose(1, 2, 0)
    plt.imshow(channels_last)
    plt.title(t)
    plt.show()

In [None]:
#gaussian_filter(label, 2)