In [1]:
"""
Title           :create_lmdb.py
Description     :This script divides the training images into 2 sets
    and stores them in lmdb databases for training and validation.
Author          :Xikun Zhang
usage           :python create_lmdb.py
python_version  :3.5
"""

import glob
import os
import random
import shutil
import sys
from os.path import *

import cv2
import numpy as np

In [10]:
data_dir = '../CS446-project'
# Load the training inputs and their label matrix directly, so `train_x` /
# `train_y` always name arrays (never the file paths they were loaded from).
# Recorded shapes: X (4602, 26, 31, 23), y (4602, 19).
train_x = np.load(join(data_dir, 'train_X.npy'))
train_y = np.load(join(data_dir, 'train_binary_Y.npy'))
# Later cells use one shared index into both arrays, so their sample counts must match.
assert len(train_x) == len(train_y), 'train_X and train_binary_Y have different sample counts'
train_x.shape

(4602, 26, 31, 23)

In [11]:
# Shape of the training label matrix; (4602, 19) in the recorded output.
np.shape(train_y)

(4602, 19)

In [16]:
# Inputs for the held-out set (no labels are loaded for it in this notebook).
# Recorded shape (1971, 26, 31, 23): same per-sample shape as train_x.
test_x = np.load(join(data_dir, 'valid_test_X.npy'))
test_x.shape

(1971, 26, 31, 23)

In [17]:
# print(sys.path)
# try:
#     user_paths = os.environ['PYTHONPATH'].split(os.pathsep)
# except KeyError:
#     user_paths = []
#
# print(user_paths)


import caffe
from caffe.proto import caffe_pb2
import lmdb

# Nominal 227x227 image size from the original image pipeline (see the
# commented-out transform_img). NOTE(review): the arrays loaded above are
# 26x31x23 per sample, so these values do not describe this dataset.
IMAGE_WIDTH, IMAGE_HEIGHT = 227, 227


# def transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT):
#     # Histogram Equalization
#     img[:, :, 0] = cv2.equalizeHist(img[:, :, 0])
#     img[:, :, 1] = cv2.equalizeHist(img[:, :, 1])
#     img[:, :, 2] = cv2.equalizeHist(img[:, :, 2])

#     # Image Resizing
#     img = cv2.resize(img, (img_width, img_height), interpolation=cv2.INTER_CUBIC)

#     return img


def make_datum(img, label):
    """Wrap one image array in a Caffe ``Datum`` message.

    ``img`` is taken to be laid out (height, width, channels). For images read
    with cv2 the channel order is BGR rather than RGB. The array is rolled to
    (channels, height, width), the order ``Datum`` stores. The
    ``channels``/``height``/``width`` fields are read from that rolled shape,
    so they always describe the bytes actually stored. For a 227x227x3 image
    this gives exactly the previously hard-coded 3/227/227.

    Parameters
    ----------
    img : numpy.ndarray
        3-D array with axes (height, width, channels).
    label : int
        Value stored in ``Datum.label``.

    Returns
    -------
    caffe_pb2.Datum
        ``uint8`` pixels are stored as raw bytes in ``data``. Any other dtype
        is stored as float32 values in ``float_data``, because Caffe interprets
        ``data`` as uint8 and would misread e.g. float bytes.
    """
    chw = np.rollaxis(np.asarray(img), 2)
    channels, height, width = chw.shape
    datum = caffe_pb2.Datum(
        channels=channels,
        width=width,
        height=height,
        label=label)
    if chw.dtype == np.uint8:
        # tobytes() emits the same C-order bytes as the deprecated tostring().
        datum.data = chw.tobytes()
    else:
        datum.float_data.extend(chw.astype(np.float32).ravel().tolist())
    return datum

In [None]:
# Output locations. Every LMDB below is deleted and rebuilt on each run.
train_lmdb = join(data_dir, 'train_lmdb')
validation_lmdb = join(data_dir, 'validation_lmdb')
test_lmdb = join(data_dir, 'test_lmdb')
# Caffe's Datum.label field holds a single integer, which cannot hold a row of
# train_y (one column per label). Label rows are therefore written to these
# companion LMDBs, under the same keys as the matching image datums.
train_label_lmdb = join(data_dir, 'train_label_lmdb')
validation_label_lmdb = join(data_dir, 'validation_label_lmdb')

# Training sample i goes to validation when i % VALIDATION_EVERY == 0 (1 in 6).
VALIDATION_EVERY = 6
# LMDB map_size (maximum database size in bytes), unchanged from the original.
LMDB_MAP_SIZE = int(1e12)
# Datum.label written on image datums; the real labels live in *_label_lmdb.
PLACEHOLDER_LABEL = 0


def _write_datums(db_path, keyed_datums):
    """Rebuild the LMDB at ``db_path`` from an iterable of ``(index, Datum)`` pairs.

    Any existing database directory is removed first. Each datum is stored
    under the zero-padded 5-digit key of its index, all in a single write
    transaction. The environment is closed even if writing fails.

    Returns the number of datums written.
    """
    # Replaces os.system('rm -rf ' + path): no shell, no string-built command.
    shutil.rmtree(db_path, ignore_errors=True)
    env = lmdb.open(db_path, map_size=LMDB_MAP_SIZE)
    written = 0
    try:
        with env.begin(write=True) as txn:
            for idx, datum in keyed_datums:
                # LMDB orders keys bytewise, so zero padding keeps index order
                # (for indices below 100000). py-lmdb requires bytes keys on Py3.
                key = '{:0>5d}'.format(idx).encode('ascii')
                txn.put(key, datum.SerializeToString())
                written += 1
    finally:
        env.close()
    return written


def _label_datum(label_row):
    """Wrap one label row in a Datum of shape (len(row), 1, 1), values in float_data."""
    values = np.asarray(label_row, dtype=np.float32).ravel()
    return caffe_pb2.Datum(channels=values.size, height=1, width=1,
                           float_data=values.tolist())


sample_indices = range(len(train_x))
splits = [
    ('train', train_lmdb, train_label_lmdb,
     [i for i in sample_indices if i % VALIDATION_EVERY != 0]),
    ('validation', validation_lmdb, validation_label_lmdb,
     [i for i in sample_indices if i % VALIDATION_EVERY == 0]),
]

for split_name, image_db, label_db, indices in splits:
    print('Creating {}_lmdb'.format(split_name))
    n_written = _write_datums(
        image_db, ((i, make_datum(train_x[i], PLACEHOLDER_LABEL)) for i in indices))
    _write_datums(label_db, ((i, _label_datum(train_y[i])) for i in indices))
    print('  {} samples'.format(n_written))

# No labels are loaded for test_x in this notebook, so only an image LMDB is written.
print('Creating test_lmdb')
n_written = _write_datums(
    test_lmdb,
    ((i, make_datum(test_x[i], PLACEHOLDER_LABEL)) for i in range(len(test_x))))
print('  {} samples'.format(n_written))

print('\nFinished processing all images')
