In [None]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import cv2
from skimage.transform import resize
from sklearn.cluster import KMeans

import torch
from torchvision import models

# Seed reused for KMeans, np.random.seed and DataFrame.sample so repeated runs
# produce the same clusters and the same split.
random_state = 42
# Number of rows sampled into the validation split.
valid_size = 83

# Dataset root: one sub-directory per image id, each holding a masks.png;
# metadata.csv is also written into this directory.
dataset_dir = 'dataset_new'

In [None]:
def preprocess_image(img, target_size=(128, 128)):
    """Resize an image and pack it into a one-item, channels-first batch tensor.

    Parameters
    ----------
    img : numpy.ndarray
        Image with the channel axis last; three dimensions are required by
        the transpose below.
    target_size : tuple of int
        Output shape handed to ``skimage.transform.resize``.

    Returns
    -------
    torch.Tensor
        Float tensor with a leading batch axis of size 1, moved to the GPU
        when CUDA is available.

    Notes
    -----
    The returned tensor does not track gradients. Wrap model calls in
    ``torch.no_grad()`` when it is used for inference only.
    """
    # Move the last (channel) axis to the front, as torch conv layers expect.
    x = resize(img, target_size, mode='constant').transpose(2, 0, 1)
    x = torch.FloatTensor(x).unsqueeze(0)
    # No torch.autograd.Variable(..., volatile=True) wrapper: tensors and
    # Variables are merged since PyTorch 0.4 and `volatile` has no effect.
    if torch.cuda.is_available():
        x = x.cuda()
    return x

In [None]:
def vgg_extractor():
    """Build a feature extractor from the pretrained VGG16 convolutional stack.

    The network is moved to the GPU when CUDA is available and switched to
    eval mode. The returned ``torch.nn.Sequential`` wraps every layer of
    ``vgg16.features`` except the final one.
    """
    vgg = models.vgg16(pretrained=True)
    vgg = vgg.cuda() if torch.cuda.is_available() else vgg
    vgg.eval()
    feature_layers = list(vgg.features.children())
    # Drop the last layer of the `features` stack.
    return torch.nn.Sequential(*feature_layers[:-1])

In [None]:
def cluster_features(features, n_clusters=10):
    """Run K-Means over ``features`` and return the per-row cluster labels.

    Parameters
    ----------
    features : array-like
        Data passed straight to ``KMeans.fit``.
    n_clusters : int
        Number of clusters to fit.

    Returns
    -------
    numpy.ndarray
        The fitted estimator's ``labels_`` attribute.
    """
    # Seeded with the notebook-level `random_state` for repeatable clusters.
    estimator = KMeans(n_clusters=n_clusters, random_state=random_state)
    return estimator.fit(features).labels_

In [None]:
def get_vgg_clusters(dataset_dir, meta):
    """Cluster the dataset's ``masks.png`` images by their VGG16 features.

    Parameters
    ----------
    dataset_dir : str
        Root directory containing one sub-directory per image id.
    meta : pandas.DataFrame
        Frame with an ``image_id`` column naming those sub-directories.

    Returns
    -------
    numpy.ndarray
        One cluster label per row of ``meta``, in the same order.

    Raises
    ------
    FileNotFoundError
        If ``masks.png`` for an image id is missing or cannot be decoded.
    """
    extractor = vgg_extractor()

    features = []
    # Inference only: skip autograd bookkeeping during the forward passes.
    with torch.no_grad():
        for image_id in tqdm(meta['image_id'].values):
            filepath = os.path.join(dataset_dir, image_id, 'masks.png')
            img = cv2.imread(filepath)
            # cv2.imread signals failure by returning None instead of raising.
            if img is None:
                raise FileNotFoundError('Could not read image: ' + filepath)
            # Scale the 8-bit pixel values into [0, 1] before resizing.
            x = preprocess_image(img / 255.0)
            feature = extractor(x)
            # One flat feature vector per image.
            features.append(feature.detach().cpu().numpy().flatten())
    features = np.stack(features, axis=0)

    return cluster_features(features)

In [None]:
# Collect one record per image directory, then build the frame in one go
# (growing a DataFrame row by row is quadratic, and DataFrame.append no longer
# exists in pandas 2.x).
metadata_records = []

for image_id in tqdm(sorted(os.listdir(dataset_dir))):
    image_dir = os.path.join(dataset_dir, image_id)
    # metadata.csv is written into dataset_dir by this notebook, so on a
    # re-run the listing contains entries that are not image directories.
    if not os.path.isdir(image_dir):
        continue

    masks_path = os.path.join(image_dir, 'masks.png')
    masks = cv2.imread(masks_path, cv2.IMREAD_UNCHANGED)
    # cv2.imread signals failure by returning None instead of raising.
    if masks is None:
        raise FileNotFoundError('Could not read mask image: ' + masks_path)

    height, width = masks.shape[:2]
    metadata_records.append({'image_id': image_id,
                             'height': int(height),
                             'width': int(width),
                             # Largest pixel value found in the mask image.
                             'count': int(masks.max())})

# 'split' has no values yet; it is filled in when the train/valid split is made.
df_metadata = pd.DataFrame(metadata_records,
                           columns=['image_id', 'height', 'width', 'count', 'split'])

vgg_features_clusters = get_vgg_clusters(dataset_dir, df_metadata)
df_metadata['vgg_features_clusters'] = vgg_features_clusters

In [None]:
# Seed the global NumPy RNG so anything downstream that uses it is repeatable.
np.random.seed(random_state)

# NOTE(review): the validation set is drawn entirely from a single VGG cluster
# (cluster 0) -- confirm this is intended rather than sampling across clusters.
valid_cluster_id = 0
valid_pool = df_metadata[df_metadata['vgg_features_clusters'] == valid_cluster_id]
if len(valid_pool) < valid_size:
    raise ValueError('Cluster %d has %d images, fewer than valid_size=%d'
                     % (valid_cluster_id, len(valid_pool), valid_size))

valid = valid_pool.sample(n=valid_size, random_state=random_state).copy()
# Every row not sampled for validation goes to training.
train = df_metadata.loc[~df_metadata.index.isin(valid.index)].copy()

valid['split'] = 'valid'
train['split'] = 'train'

# The two splits must partition the metadata exactly.
assert len(train) + len(valid) == len(df_metadata)

df_metadata = pd.concat([train, valid])
df_metadata = df_metadata.sort_values('image_id')

In [None]:
# Persist the metadata (including the split assignment) next to the image
# directories; the row index carries no information, so it is not written.
meta_filepath = os.path.join(dataset_dir, 'metadata.csv')
df_metadata.to_csv(meta_filepath, index=False)