# Pseudo Labeling
<p>
    Pseudo labeling is a semi-supervised learning technique. You train a model on a small amount of annotated data and apply it to a large set of unannotated data. The predictions obtained this way are treated as new annotations, so a new model can be trained on much more data, at the cost that some of these labels may be erroneous.
    </p>
    <p>
    <b> This notebook is used as a playground to test things and implementations for the LargePseudoLabeling script </b>
    </p>

In [None]:
from forward_pass import forward_pass, flatten_out_dict, seconds_to_str
import numpy as np
import sys
import cv2
import matplotlib.pyplot as plt
from fnmatch import fnmatch
import os
import imageio
import shutil
from difflib import SequenceMatcher
import time
%matplotlib inline
import tensorflow as tf
from object_detection.utils import dataset_util
import shelve

# Human-readable class names. The list position is used as the integer class id
# throughout this notebook (detections are looked up via category_names[class_id]).
category_names = ['Bier', 'Bier Mass', 'Weissbier', 'Cola', 'Wasser', 'Curry-Wurst', 'Weisswein',
                   'A-Schorle', 'Jaegermeister', 'Pommes', 'Burger', 'Williamsbirne', 'Alm-Breze', 'Brotzeitkorb',
                   'Kaesespaetzle']

In [None]:
def empty_directory(folder):
    """Delete all regular files that live directly inside ``folder``.

    Sub-directories (and whatever they contain) are left untouched. Errors
    raised while inspecting or deleting a single entry are printed and the
    remaining entries are still processed.
    """
    for entry_name in os.listdir(folder):
        target = os.path.join(folder, entry_name)
        try:
            if not os.path.isfile(target):
                # Directories and other non-file entries are kept on purpose.
                continue
            os.unlink(target)
        except Exception as err:
            print(err)
            
def get_files(path, pattern, not_pattern = None, printout=True):
    """Recursively collect files below ``path`` whose name matches ``pattern``.

    Parameters
    ---------
    path: Root directory that is walked recursively.

    pattern: fnmatch-style pattern a file name has to match.

    not_pattern (Default = None): Optional fnmatch-style pattern; file names
                                  matching it are excluded.

    printout (Default = True): Print how many files were found below ``path``.

    Returns the list of matching file paths in os.walk order.
    """
    found = []
    # The walk variable must not be called `path`: that would overwrite the
    # argument and make the summary below report the last visited directory.
    for current_dir, _subdirs, files in os.walk(path):
        for name in files:
            if not fnmatch(name, pattern):
                continue
            if not_pattern is not None and fnmatch(name, not_pattern):
                continue
            found.append(os.path.join(current_dir, name))
    if printout:
        print("Found %d files in path %s"%(len(found), path))
    return found

"""
Source: https://gist.github.com/soply/f3eec2e79c165e39c9d540e916142ae1
"""
def show_images(images, cols = 1, titles = None):
    """Display a list of images in a single figure with matplotlib.
    
    Parameters
    ---------
    images: List of np.arrays compatible with plt.imshow.
    
    cols (Default = 1): First grid dimension handed to fig.add_subplot; the
                        second one is int(np.ceil(n_images/float(cols))).
                        NOTE: matplotlib treats the first add_subplot argument
                        as the number of ROWS, so despite its name this value
                        controls the rows of the grid. The argument order is
                        kept as it was so existing layouts do not change.
    
    titles: List of titles corresponding to each image. Must have
            the same length as images. Defaults to 'Image (1)', 'Image (2)', ...
    """
    assert((titles is None)or (len(images) == len(titles)))
    n_images = len(images)
    if n_images == 0:
        # Nothing to draw; an empty figure scaled by a factor of 0 is invalid.
        return
    if titles is None: titles = ['Image (%d)' % i for i in range(1,n_images + 1)]
    # add_subplot requires integer grid dimensions; np.ceil returns a float.
    other_dim = int(np.ceil(n_images/float(cols)))
    fig = plt.figure()
    for n, (image, title) in enumerate(zip(images, titles)):
        a = fig.add_subplot(cols, other_dim, n + 1)
        if image.ndim == 2:
            # Single-channel images are rendered with the gray colormap.
            plt.gray()
        plt.imshow(image)
        a.set_title(title)
    # Grow the figure with the number of images so the subplots stay readable.
    fig.set_size_inches(np.array(fig.get_size_inches()) * n_images)
    plt.show()
    
def similar(a, b):
    """Return the difflib similarity ratio of ``a`` and ``b`` (1.0 means identical)."""
    matcher = SequenceMatcher(isjunk=None, a=a, b=b)
    return matcher.ratio()

### Find given video files and extract all frames as images

In [None]:
# Cameras to load. 'CamStereoL', 'Cam2' and 'CamStereoR' are currently disabled
# (StereoR is ignored because it is not in sync).
cams = ['Cam1']
# Days of May 2018 to load, e.g. 26 -> '2018-05-26'.
days = range(18, 19)

# videos_list[cam_index][day_index] -> list of video file paths
videos_list = [[] for _ in cams]
for day_of_month in days:
    date_dir = '2018-05-' + str(day_of_month)
    for cam_videos, cam_name in zip(videos_list, cams):
        video_dir = ('/nfs/students/winter-term-2018/project_2/video_data/videos/'
                     + cam_name + '/' + date_dir)
        cam_videos.append(get_files(video_dir, "*.mp4", not_pattern='*._*', printout=False))

for cam_index, cam_videos in enumerate(videos_list):
    total_files = sum(len(day_videos) for day_videos in cam_videos)
    print("Found %d days with %d total files for cam %d"%(len(cam_videos), total_files, cam_index))

# Sort every day's file list in place so the same index refers to the same
# position in the sorted order for each camera.
for cam_videos in videos_list:
    for day_videos in cam_videos:
        day_videos.sort()
#    print("----------------")

In [None]:
# Per camera: list of (frames, H, W, 3) arrays, one per video. They are stacked
# into a single array per camera at the end of this cell.
img_data = [[] for _ in cams]
plot_stuff = True            # collect the middle frame of every video and plot them at the end
show_time_in_image = False   # plot only the crop [:35, :340] (presumably the timestamp overlay)
print_every = 2              # progress output every n-th video (only when plot_stuff is False)
picked_videos = range(0,10)  # indices into each day's sorted video list
all_images = False           # True: use every video index that exists for all cameras on a day
max_frames = 90              # at most this many frames are read from a single video

data_count = 0

id_str = '_' + ''.join('%d_' % i for i in picked_videos)

plots = []
titles = []
print("Starting")
for day in range(len(days)):
    if all_images:
        picked_videos = range(min([len(videos_list[i][day]) for i in range(len(cams))]))
    t_d = time.time()
    for c, i in enumerate(picked_videos):
        for j, cam in enumerate(videos_list):
            t_v = time.time()
            video = cam[day][i]
            images_video = []
            cap = cv2.VideoCapture(video)
            try:
                success, image = cap.read()
                count = 0
                while success and count < max_frames:
                    # OpenCV decodes frames as BGR; everything downstream works on RGB.
                    images_video.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
                    success, image = cap.read()
                    count += 1
            finally:
                # Always free the decoder/file handle, even if reading fails.
                cap.release()
            # The loop above caps count at max_frames, so only the
            # "too few frames" case can actually occur.
            if count < max_frames:
                print("WARNING: Less than %d images extracted"%max_frames)
            if not images_video:
                # An unreadable video would break both the preview (index into
                # an empty list) and the final np.vstack, so leave it out.
                print("WARNING: No frames could be read from %s, skipping"%video)
                continue
            try:
                img_data[j].append(np.array(images_video))
            except Exception as e:
                print("Exception: %s"%str(e))
                print("Data size: %f"%sys.getsizeof(img_data[j]))
            data_count += 1
            if plot_stuff:
                titles.append('Video %d - Day %d - Cam %d'%(i,day,j))
                middle_frame = images_video[int(len(images_video)/2)]
                if show_time_in_image:
                    plots.append(middle_frame[:35,:340])
                else:
                    plots.append(middle_frame)
            else:
                if c % print_every == 0:
                    t_v = time.time() - t_v
                    print("%d videos loaded for cam %d (%.1fs per video)"%(c,j, t_v))
    t_d = time.time() - t_d
    print("Completed loading videos of day %d/%d\t Exp. time left: %s"%(day+1, len(days), seconds_to_str(t_d*(len(days)-day-1))))
if plot_stuff and plots:
    show_images(plots, cols=len(days)*len(picked_videos), titles = titles)
# Concatenate the per-video arrays along the frame axis: one array per camera.
img_data = [np.vstack(img_data[i]) for i in range(len(img_data))]

### Classify loaded data with model trained on manually annotated data

In [None]:
# Frozen inference graph of the detector that is used to generate the pseudo labels.
model = 'train/ssd_julius_mobilefpn_large/frozen_inference_graph.pb'
# Run the detector on all loaded frames of the first camera. Later cells read the
# 'detection_boxes', 'detection_classes' and 'detection_scores' entries per frame.
out_dict = forward_pass(model, img_data[0], gpu='0', BATCH_SIZE=1, return_images=False)

In [None]:
# Quick inspection of the detector output: available keys and the number of
# per-frame entries stored under 'detection_classes'.
print(out_dict.keys())
print(len(out_dict['detection_classes']))

### Filter labels
<p>
    To suppress noise in the annotations, we only keep frames whose detections stay identical over a minimum number of consecutive frames. Note that frames without any annotation at all (empty frames) are not considered.
    </p>

In [None]:
thresh=0.5              # minimum detection score for a box to be counted
filter_value = 2        # a run must satisfy (run length - 1) > filter_value to be kept
previous_detection = None
last_change = 0         # index of the frame at which the current run started
begin = 0               # first frame of the current run that would be stored
already_printed = False
filtered_image_assignments = [] #list of tuples with ([begin img, end img], dictionary)
print_stuff = False
n_frames = len(img_data[0])
for i in range(n_frames):
    clss = out_dict['detection_classes'][i]
    scr = out_dict['detection_scores'][i]
    # Per-class count of confident detections in this frame; None if there are none.
    detection = None
    for j in range(len(clss)):
        if scr[j] < thresh:
            continue
        if detection is None:
            detection = {class_id: 0 for class_id in range(len(category_names))}
        detection[int(clss[j])] += 1
    if detection == previous_detection:
        # Same class counts as in the previous frame: the run continues.
        if detection is not None and i - last_change > filter_value and not already_printed:
            out_str = "Detected items for image %i"%(i-filter_value)
            for key in detection:
                if detection[key] > 0:
                    out_str += "\n  - %d %s"%(detection[key], category_names[key])
            if print_stuff:
                print(out_str)
            already_printed = True
    else:
        # The counts changed: store the run that just ended if it was long enough.
        if previous_detection is not None and i - last_change - 1 > filter_value:
            if print_stuff:
                print("until image %d\n"%i)
            filtered_image_assignments.append(([begin, i], previous_detection))
        last_change = i
        already_printed = False
        # The first frame of a new run (the transition frame) is not stored.
        begin = i+1
    previous_detection = detection
# A run that lasts until the very last frame never sees a "change" inside the
# loop, so it has to be stored here; otherwise the trailing run is silently lost.
if previous_detection is not None and n_frames - last_change - 1 > filter_value:
    if print_stuff:
        print("until image %d\n"%n_frames)
    filtered_image_assignments.append(([begin, n_frames], previous_detection))
print("Found %d filtered assignments"%len(filtered_image_assignments))

In [None]:
# Preview the first ten filtered runs: show the middle frame of each run for
# every camera and list the detected items.
cam_titles = ['Cam %d'%cam_index for cam_index in range(len(cams))]
for (run_start, run_end), counts in filtered_image_assignments[:10]:
    middle = int((run_end + run_start)/2)
    show_images([frames[middle] for frames in img_data], titles=cam_titles)
    print("%d frames"%(run_end - run_start))
    for class_id, amount in counts.items():
        if amount > 0:
            print(' - %d x %s'%(amount, category_names[class_id]))

In [None]:
# Shape of a single frame of the first camera.
print(img_data[0][0].shape)

### Create new dataset with pseudo labeled data

In [None]:

def _bytes_feature(value):
    """Wrap a single string / bytes value in a tf.train.Feature holding a one-element BytesList."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)

def create_tf_example(index, empty=False, thresh=0.5):
    """Build a tf.train.Example for frame ``index`` of the first camera.

    Parameters
    ---------
    index: Position of the frame in img_data[0] and of its detections in out_dict.

    empty (Default = False): If True, the example is written without any
                             bounding boxes / classes.

    thresh (Default = 0.5): Detections with a score below this value are dropped.

    Returns the tf.train.Example containing the PNG-encoded image and the
    detections of that frame.

    Raises ValueError if the frame cannot be PNG-encoded.
    """
    example = img_data[0][index]
    height = example.shape[0] # Image height
    width = example.shape[1] # Image width
    filename = str.encode('') # Filename of the image. Empty if image is not from file
    # The frames in img_data were converted to RGB when they were loaded, but
    # cv2.imencode interprets 3-channel arrays as BGR. Convert back first,
    # otherwise red and blue end up swapped in the stored PNG.
    to_encode = example
    if example.ndim == 3 and example.shape[2] == 3:
        to_encode = cv2.cvtColor(example, cv2.COLOR_RGB2BGR)
    encoded_ok, encoded_buffer = cv2.imencode('.png', to_encode)
    if not encoded_ok:
        raise ValueError("Could not PNG-encode image %d"%index)
    encoded_image_data = encoded_buffer.tobytes() # Encoded image bytes
    image_format = str.encode('png') # b'jpeg' or b'png'

    xmins = [] # List of normalized left x coordinates in bounding box (1 per box)
    xmaxs = [] # List of normalized right x coordinates in bounding box
             # (1 per box)
    ymins = [] # List of normalized top y coordinates in bounding box (1 per box)
    ymaxs = [] # List of normalized bottom y coordinates in bounding box
             # (1 per box)
    classes_text = [] # List of string class name of bounding box (1 per box)
    classes = [] # List of integer class id of bounding box (1 per box)
    
    if not empty:
        # Look the detections up with the `index` argument; relying on a
        # global loop variable only works while the caller happens to use the
        # same variable name.
        bxs = out_dict['detection_boxes'][index]
        clss = out_dict['detection_classes'][index]
        scr = out_dict['detection_scores'][index]
        for j in range(len(clss)):
            if scr[j] < thresh:
                continue
            # Cast to a plain int: it is used as a list index and stored in an
            # int64 feature, and the filter cell casts the class id the same way.
            # NOTE(review): ids are stored exactly as the detector returned
            # them - confirm they match the label map used for training.
            class_id = int(clss[j])
            classes.append(class_id)
            classes_text.append(str.encode(category_names[class_id]))
            # Box layout is [ymin, xmin, ymax, xmax].
            ymins.append(float(bxs[j][0]))
            xmins.append(float(bxs[j][1]))
            ymaxs.append(float(bxs[j][2]))
            xmaxs.append(float(bxs[j][3]))

    tf_example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(filename),
      'image/source_id': dataset_util.bytes_feature(filename),
      'image/encoded': dataset_util.bytes_feature(encoded_image_data),
      'image/format': dataset_util.bytes_feature(image_format),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example

def del_all_flags(FLAGS):
    """Remove every flag currently registered on ``FLAGS``.

    The flag names are copied into a list first so that deleting entries does
    not modify the mapping while it is being iterated.
    """
    for flag_name in list(FLAGS._flags()):
        FLAGS.__delattr__(flag_name)


In [None]:
# Smoke test: wrap the raw bytes of one frame in a feature and display the resulting type.
type(_bytes_feature(img_data[0][0].tobytes()))

In [None]:
# If True, the frames between two filtered runs are written as examples
# without any annotation.
include_empty = False

add_name = '_notebook_test'
# Flags survive re-running the cell; clear them so they can be defined again.
del_all_flags(tf.flags.FLAGS)
flags = tf.app.flags
flags.DEFINE_string('f', '', 'kernel')
flags.DEFINE_string('output_path_train', 'data/test'+add_name, 'Path to output TFRecord train')
flags.DEFINE_string('output_path_eval', 'data/test'+add_name, 'Path to output TFRecord eval')
flags.DEFINE_string('output_path', 'data/test'+add_name, 'Path to output TFRecord')
FLAGS = tf.app.flags.FLAGS
print(FLAGS.output_path_train)

out_file = FLAGS.output_path
if os.path.isfile(out_file):
    os.remove(out_file)
writer = tf.python_io.TFRecordWriter(FLAGS.output_path)

# Index of the first frame that has not been written yet.
last = 0
try:
    for filtered_label in filtered_image_assignments:
        run_start, run_end = filtered_label[0]
        if include_empty:
            for i in range(last, run_start):
                tf_example = create_tf_example(i, empty=True)
                writer.write(tf_example.SerializeToString())
        for i in range(run_start, run_end):
            tf_example = create_tf_example(i)
            writer.write(tf_example.SerializeToString())
        # Advance the marker; otherwise the "empty" range of every following
        # run would start at frame 0 again and re-write already handled frames.
        last = run_end
finally:
    # Close the record file even if creating an example fails.
    writer.close()

### As a sanity check we read data again, display it and compare to manually labeled dataset

In [None]:
# Close a session left over from a previous run of this cell before opening a new one.
if 'sess' in locals() and sess is not None:
    sess.close()
# TFRecord files that can be inspected in the cells below.
new_data = '/nfs/students/winter-term-2018/project_2/models/research/object_detection/training_folder_lsml/data/pseudo_labels_days-23-28'
old_data = '/nfs/students/winter-term-2018/project_2/models/research/object_detection/training_folder_lsml/data/pseudo_labels_days-18-old_encoding'
verified_data = '/nfs/students/winter-term-2018/project_2/models/research/object_detection/training_folder_lsml/data/train_data_beer_sn_reviewed'
test_data = '/nfs/students/winter-term-2018/project_2/models/research/object_detection/training_folder_lsml/data/test_notebook_test'
#print(img_data[0][0].shape)
# An empty CUDA_VISIBLE_DEVICES hides all GPUs from this process.
os.environ["CUDA_VISIBLE_DEVICES"] = ''
# The interactive session installs itself as default, which img.eval() below relies on.
sess = tf.InteractiveSession()

In [None]:
record_iterator = tf.python_io.tf_record_iterator(path=verified_data)

num = 5000      # stop after this many records
every_n = 10    # decode and keep the image of every n-th record for display
images = []
titles = []
count = 0
print("Start")
time_start = time.time()
# Number of annotations per class name.
stats = {c: 0 for c in category_names}
for string_record in record_iterator:
    example = tf.train.Example()
    example.ParseFromString(string_record)
    d = dict(example.features.feature)
    text = [x.decode() for x in d['image/object/class/text'].bytes_list.value]
    for t in text:
        if t not in stats:
            # Report labels that are missing from category_names. This check
            # has to come before counting, otherwise an unknown label raises
            # a KeyError and is never reported.
            print(t)
        stats[t] = stats.get(t, 0) + 1
    if count % every_n == 0:
        try:
            # NOTE(review): the records written by create_tf_example in this
            # notebook are PNG encoded - confirm decode_jpeg handles the
            # dataset that is being read here.
            img = tf.image.decode_jpeg(d['image/encoded'].bytes_list.value[0])
            img = img.eval()
            images.append(img)
            print('.', end='')
            titles.append(str(text))
        except Exception as e:
            print("EXCEPTION")
            print(e)
    
    # Stop once `num` records have been processed.
    count += 1
    if count >= num:
        break
print("\nDone")
print(stats)
print("%d datapoints took %s"%(count, seconds_to_str(time.time()-time_start)))


In [None]:
# Summarize the class distribution of the records read above, as text and as a bar chart.
total_annotations = sum(stats.values())
print("Total of %d items in record (%d annotations)"%(count, total_annotations))
for label, amount in stats.items():
    print('%d \tx %s'%(amount, label))
plt.bar(stats.keys(), stats.values())
plt.xticks(rotation='vertical')
plt.show()

In [None]:
# Display the decoded sample images in the index range [n, m) together with their labels.
n, m = 0, count
shown_images = images[n:m]
shown_titles = titles[n:m]
show_images(shown_images, titles=shown_titles, cols=len(shown_images))