In [None]:
import os, platform

def _detect_linux_distribution():
    """Return the name of the running Linux distribution, or '' if unknown.

    ``platform.dist()`` was deprecated in Python 3.5 and removed in 3.8, so it
    is used only when the interpreter still provides it; otherwise the ``NAME``
    field of ``/etc/os-release`` is read instead.
    """
    legacy_dist = getattr(platform, 'dist', None)
    if legacy_dist is not None:
        return legacy_dist()[0]
    try:
        with open('/etc/os-release') as os_release:
            for line in os_release:
                key, _, value = line.strip().partition('=')
                if key == 'NAME':
                    # Values in os-release may be wrapped in single or double quotes.
                    return value.strip('"\'')
    except OSError:
        # No readable os-release file (e.g. not a Linux host): distribution is unknown.
        pass
    return ''


distr = _detect_linux_distribution()

# Root folder of the dataset; which location is used depends on the detected distribution.
PATH = os.path.expanduser('~/datasets/letsdance') if distr == 'Ubuntu' else '/run/media/nast/DATA/letsdance'
# Training split, given relative to PATH.
TRAIN_PATH = "letsdance_split/train"

print("dataset path:", PATH)

In [None]:
from scipy import misc
import matplotlib.pyplot as plt
%matplotlib inline

# One hand-picked training frame, used as a quick visual sanity check of the dataset.
path = "letsdance_split/train/ballet/Et31LySAxf0_020_0266.jpg"
# scipy.misc.imread was removed in SciPy 1.2; matplotlib (already imported in
# this cell) also decodes the file into a NumPy array.
image = plt.imread(os.path.join(PATH, path))
plt.imshow(image)
plt.show()

In [None]:
import numpy as np
# Report the array shape together with its smallest and largest pixel value.
print(image.shape, np.min(image), np.max(image))

In [None]:
from collections import Counter


# A vector of filenames.


def get_file_names_in_dataset(dataset_path, root=None):
    """List the files of every class folder found in a dataset split.

    Args:
        dataset_path: path of the split, relative to ``root``.
        root: folder the split path is resolved against. Defaults to the
            notebook-level ``PATH`` when omitted.

    Returns:
        dict mapping each class (sub-folder) name to the sorted list of
        file names inside that sub-folder.
    """
    if root is None:
        root = PATH
    dataset_dir = os.path.join(root, dataset_path)
    file_names_in_dataset = {}
    for cl in os.listdir(dataset_dir):
        class_dir = os.path.join(dataset_dir, cl)
        # Plain files sitting next to the class folders are not classes;
        # calling os.listdir on them would raise.
        if not os.path.isdir(class_dir):
            continue
        file_names_in_dataset[cl] = sorted(os.listdir(class_dir))
    return file_names_in_dataset


def video_name_from_file_name(file_name):
    """Return the part of a frame file name that precedes its last underscore.

    A name without any underscore yields the empty string.
    """
    head, _sep, _tail = file_name.rpartition('_')
    return head


def get_num_of_frames_in_videos(list_of_file_names):
    """Count how many of the given file names belong to each video.

    Returns a Counter keyed by video name (as derived by
    video_name_from_file_name) whose values are the number of file names
    sharing that video name.
    """
    return Counter(video_name_from_file_name(fn) for fn in list_of_file_names)
 
    
def select_videos_with_N_frames(list_of_file_names, N):
    """Return the names of the videos that have exactly N frame files.

    Args:
        list_of_file_names: frame file names of one dance.
        N: required number of frames per video.

    Returns:
        tuple of video names, in the order they are first encountered.
        The tuple is empty when no video has exactly N frames (the previous
        ``zip(*...)`` unpacking raised ValueError in that case).
    """
    nfr = get_num_of_frames_in_videos(list_of_file_names)
    return tuple(video_name for video_name, num_frames in nfr.items() if num_frames == N)


def select_video_names_for_dances(file_names_in_dataset, N):
    """Pick, for every dance, the same number of videos that have N frames.

    The per-dance quota is the smallest count of N-frame videos found in any
    dance; each dance keeps the first ``quota`` of its N-frame videos in
    sorted-name order.
    """
    candidates = {
        dance_name: select_videos_with_N_frames(file_names, N)
        for dance_name, file_names in file_names_in_dataset.items()
    }
    quota = min(len(names) for names in candidates.values())
    return {
        dance_name: sorted(names)[:quota]
        for dance_name, names in candidates.items()
    }


def select_file_names_for_work(file_names_in_dataset, N):
    """Keep only the frame files that belong to the selected N-frame videos.

    Args:
        file_names_in_dataset: dict mapping dance name to its list of frame
            file names.
        N: required number of frames per video.

    Returns:
        dict mapping dance name to the frame file names (original order
        preserved) whose video was chosen by select_video_names_for_dances.
    """
    video_names = select_video_names_for_dances(file_names_in_dataset, N)
    selected_file_names = {}
    for dance, list_of_file_names in file_names_in_dataset.items():
        # A set gives constant-time membership tests; scanning the list of
        # video names once per frame was needlessly quadratic.
        wanted_videos = set(video_names[dance])
        selected_file_names[dance] = [fn for fn in list_of_file_names
                                      if video_name_from_file_name(fn) in wanted_videos]
    return selected_file_names
        
    
def _print_dataset_summary(file_names_by_dance):
    """Print, for every dance, its number of frame files and of distinct videos."""
    for dance, loffn in file_names_by_dance.items():
        print(dance,
              'total number of frames: {}'.format(len(loffn)),
              'number of videos: {}'.format(len(get_num_of_frames_in_videos(loffn))),
              end='\n\n', sep='\n')


# Only videos with exactly this many frames are kept for training.
FRAMES_PER_VIDEO = 300

file_names_in_dataset = get_file_names_in_dataset(TRAIN_PATH)

print("Before filtering")
_print_dataset_summary(file_names_in_dataset)
print('*********\n\nAfter filtering')
file_names_for_train = select_file_names_for_work(file_names_in_dataset, FRAMES_PER_VIDEO)
_print_dataset_summary(file_names_for_train)

# NOTE(review): never filled or read in the visible cells; kept in case a later cell relies on it.
videos_with_300_frames = {}




In [None]:
import tensorflow as tf

# Target number of examples per training batch; it is divided by NUM_DANCES when batching below.
BATCH_SIZE = 30
# file_names_for_train has one entry per dance.
NUM_DANCES = len(file_names_for_train)
print(NUM_DANCES)
# Number of frame files of the first dance in the dict; used below as the shuffle buffer size.
# NOTE(review): presumably every dance holds the same number of frames after filtering — confirm.
NUM_FRAMES_PER_DANCE = len(list(file_names_for_train.values())[0])

def _parse_function(filename, label):
    """Read the file at `filename`, decode it as JPEG and pass `label` through.

    Returns the (decoded image tensor, label) pair.
    """
    raw_bytes = tf.read_file(filename)
    # Resizing is currently disabled:
    # image_resized = tf.image.resize_images(image_decoded, [28, 28])
    return tf.image.decode_jpeg(raw_bytes), label

datasets_by_dance = {}

# Dances are visited in sorted-name order; a dance's position is its integer label.
for idx, (dance, loffn) in enumerate(sorted(file_names_for_train.items())):
    labels = tf.constant([idx] * len(loffn))
    dance_dir = os.path.join(PATH, TRAIN_PATH, dance)
    filenames = tf.constant([os.path.join(dance_dir, fn) for fn in loffn])
    per_dance = tf.data.Dataset.from_tensor_slices((filenames, labels))
    datasets_by_dance[dance] = per_dance.shuffle(NUM_FRAMES_PER_DANCE).map(_parse_function)

# Each element of the zipped dataset holds one (image, label) pair per dance.
dance_zip = tf.data.Dataset.zip(tuple(datasets_by_dance.values()))
train_dataset = dance_zip.batch(BATCH_SIZE // NUM_DANCES)


In [None]:
# Iterate over the un-batched zipped dataset: every element is a tuple with
# one (image, label) pair per dance.
iterator = dance_zip.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    # Ten full passes over the data; the iterator is re-initialised before each pass.
    for _ in range(10):
        sess.run(iterator.initializer)
        i = 0
        while True:
            try:
                res = sess.run(next_element)
            except tf.errors.OutOfRangeError:
                # The dataset is exhausted: this pass is over.
                break
            if i < 5:
                # Show the image of the first dance for the first five elements only.
                print(i)
                array = res[0][0]
                plt.imshow(array)
                plt.show()
            i += 1
        print('*' * 10)
