In [1]:
import xml.etree.ElementTree as ET
import matplotlib.pylab as plt
import numpy as np
import random
import csv
import os
from math import floor

# Temporal windowing parameters (in frames).
WIN_LENGTH = 16
WIN_STRIDE = 8
VID_FRAME_LIMIT = 600

# NOTE(review): presumably the crop side length (pixels) and clip depth
# (frames) expected downstream — confirm.
SPACE_DIM = 112
TIME_DIM = 16

ROOT_DIR = '/notebooks/Thesis/annotations/'
# Alternative output directory named after the windowing parameters:
# DIR = ROOT_DIR + 'len' + str(WIN_LENGTH) + 'strd' + str(WIN_STRIDE) + 'lim' + str(VID_FRAME_LIMIT) + '/'
DIR = ROOT_DIR + 'sampleregions/'

# The annotation file path is relative to the current working directory.
tree = ET.parse('annotations_reformatted.xml')
videos = tree.getroot()

# Task ids of every <video> whose <status> is 'accepted'.
id_list = [
    video.get('taskid')
    for video in videos.iter('video')
    if video.find('status').text == 'accepted'
]

# Fixed seed so the train/valid/test split below is reproducible across runs.
random.Random(40).shuffle(id_list)

i = 0
num_vids = len(id_list)

def write_anno_interval(writer, video, prefix, ignore_file, min_length=8):
    """Write one CSV row per sufficiently long timeline interval of *video*.

    Each child of the video's ``timeline`` element is read positionally:
    child 0 holds the start frame, child 1 the end frame and child 2 the
    category. An interval is skipped when it ends before ``WIN_LENGTH`` or
    spans fewer than ``min_length`` frames (``end - start``); each skip is
    printed and appended as one line to *ignore_file*. Kept intervals are
    written as ``[name, start, end, label]`` with ``start`` clamped up to
    ``WIN_LENGTH - 1``.

    Args:
        writer: object with a ``writerow`` method (e.g. ``csv.writer``).
        video: the ``<video>`` XML element.
        prefix: tag printed at the start of every skip message.
        ignore_file: writable text file that receives the skip log.
        min_length: minimum ``end - start`` for an interval to be kept
            (default 8, the previously hard-coded threshold).

    NOTE(review): ``label_dict`` (category -> label) is not defined in this
    file; it is assumed to come from another notebook cell — confirm.
    """
    name = video.find('name').text

    for interval in video.find('timeline'):
        start = int(interval[0].text)
        end = int(interval[1].text)
        category = interval[2].text

        if end < WIN_LENGTH or (end - start) < min_length:  # exclude brief intervals
            log = "{}: excluding {} {} interval: {}-{}".format(prefix, name, category, start, end)
            print(log)
            ignore_file.write(log + '\n')
            continue

        # NOTE(review): the length check above runs before this clamp, so a
        # kept interval may be shorter than min_length once clamped.
        start = max(WIN_LENGTH - 1, start)  # ensure interval has enough prior frames
        writer.writerow([name, str(start), str(end), str(label_dict[category])])

def write_anno(writer, video):
    """Write one CSV row per sampled frame of *video*, labelled by its interval.

    Frames are sampled every ``WIN_STRIDE`` frames, starting at
    ``WIN_LENGTH - 1`` and stopping before the video's ``length``. Each
    sampled frame gets the ``category`` of the first timeline interval whose
    end frame (child 1 of the interval element) is >= that frame. Rows are
    written as ``[name, frame, label]``.

    If a sampled frame lies beyond the end of the last timeline interval,
    sampling stops there. The old code raised IndexError in that case, and
    also raised on an empty timeline.

    NOTE(review): ``label_dict`` (category -> label) is not defined in this
    file; it is assumed to come from another notebook cell — confirm.
    """
    name = video.find('name').text
    intervals = list(video.find('timeline'))
    duration = int(video.find('length').text)

    idx = 0
    frame_num = WIN_LENGTH - 1
    while frame_num < duration:
        # Advance to the first interval that has not ended before frame_num.
        while idx < len(intervals) and frame_num > int(intervals[idx][1].text):
            idx += 1
        if idx == len(intervals):
            break  # frame lies past the annotated timeline; nothing to label
        category = intervals[idx].find('category').text
        writer.writerow([name, str(frame_num), str(label_dict[category])])
        frame_num += WIN_STRIDE

# NOTE(review): MAX_BEFORE_CRASH is not referenced anywhere in this file, and
# add_to_train below hard-codes a 128-frame look-back instead. Confirm which
# value is intended before wiring it in.
MAX_BEFORE_CRASH = 64


def add_to_train(writer, taskid):
    """Write the sampling regions for one training video.

    Looks up the ``<video>`` element with this ``taskid`` in the module-level
    ``videos`` tree. It writes one CSV row per (time step, vertical bin,
    horizontal bin), in that nesting order::

        [taskid, name, t, t + 31, hstart, hend, wstart, wend]

    * Time: ``t`` steps by 32. It starts 128 frames before ``crashstart``
      (clamped at 0) and stops before
      ``min(crashsettled + 32, length - 32 - 16)``.
    * Space: the range ``[0, height - SPACE_DIM]`` is split into
      ``floor(height / SPACE_DIM / 2)`` equal bins (width is handled the same
      way). Each row gives one bin's inclusive ``[start, end]`` range.
      Presumably these are crop top-left offsets; confirm. A frame smaller
      than ``2 * SPACE_DIM`` along an axis yields no bins, and so no rows.

    Raises:
        AttributeError: if no video with this ``taskid`` exists, or a
            required child element is missing.
    """
    video = videos.find(f".//video[@taskid='{taskid}']")
    name = video.find('name').text
    width = int(video.find('width').text)
    height = int(video.find('height').text)
    length = int(video.find('length').text)
    # One bin per 2 * SPACE_DIM pixels along each axis.
    nW = floor(width / SPACE_DIM / 2)
    nH = floor(height / SPACE_DIM / 2)
    crashstart = int(video.find('crashstart').text)
    Tstart = max(crashstart - 128, 0)  # start sampling within 128 frames of impact frame
    crashsettled = int(video.find('crashsettled').text)
    # NOTE(review): the trailing 16 may be TIME_DIM — confirm.
    Tend = min(crashsettled + 32, length - 32 - 16)

    # Bin boundaries do not depend on t, so compute them once.
    span_h = height - SPACE_DIM
    span_w = width - SPACE_DIM
    h_bins = [(int(round(span_h / nH * k)), int(round(span_h / nH * (k + 1))) - 1)
              for k in range(nH)]
    w_bins = [(int(round(span_w / nW * k)), int(round(span_w / nW * (k + 1))) - 1)
              for k in range(nW)]

    for t in range(Tstart, Tend, 32):
        for hstart, hend in h_bins:
            for wstart, wend in w_bins:
                writer.writerow([taskid, name, t, t + 31, hstart, hend, wstart, wend])

def add_to_valid(writer, taskid):
    """Write fixed-grid sampling positions for one validation video.

    Looks up the ``<video>`` element with this ``taskid`` in the module-level
    ``videos`` tree. It writes one CSV row per (time step, row, column), in
    that nesting order::

        [taskid, name, t, t, h, h, w, w]

    * Time: ``t`` steps by 32. It starts 32 frames before ``crashstart``
      (clamped at 0) and stops before ``crashsettled``.
    * Space: ``h`` and ``w`` step by ``SPACE_DIM`` over every offset for which
      ``offset + SPACE_DIM`` still fits inside the frame. This includes a
      tile that ends exactly on the last pixel; the old exclusive bound
      dropped that tile.

    Raises:
        AttributeError: if no video with this ``taskid`` exists, or a
            required child element is missing.
    """
    video = videos.find(f".//video[@taskid='{taskid}']")
    name = video.find('name').text
    width = int(video.find('width').text)
    height = int(video.find('height').text)
    crashstart = int(video.find('crashstart').text)
    Tstart = max(crashstart - 32, 0)  # start sampling within 32 frames of impact frame
    crashsettled = int(video.find('crashsettled').text)
    Tend = crashsettled

    # "+ 1" makes the last offset that still fits inclusive.
    h_offsets = range(0, height - SPACE_DIM + 1, SPACE_DIM)
    w_offsets = range(0, width - SPACE_DIM + 1, SPACE_DIM)

    for t in range(Tstart, Tend, 32):
        for h in h_offsets:
            for w in w_offsets:
                writer.writerow([taskid, name, t, t, h, h, w, w])

def add_to_test(writer, taskid):
    """Write a single ``[taskid, name]`` row for the video with this ``taskid``.

    The video is looked up in the module-level ``videos`` tree. An
    AttributeError propagates if no matching ``<video>`` (or its ``<name>``)
    exists.
    """
    match = videos.find(f".//video[@taskid='{taskid}']")
    video_name = match.find('name').text
    row = [taskid, video_name]
    writer.writerow(row)
                
# Create the output directory, including ROOT_DIR if it is missing. An
# existing directory is fine.
os.makedirs(DIR, exist_ok=True)

# Split the shuffled ids by position modulo 10. Index % 10 == 0 goes to
# validation, index % 10 == 1 goes to test, and the other 8 of every 10 go
# to training. The `with` statement closes all three files even if a writer
# raises. (The ignored-interval log used by write_anno_interval is not
# produced by this driver.)
with open(DIR + 'anno_train.csv', 'w', newline='') as train_file, \
        open(DIR + 'anno_valid.csv', 'w', newline='') as valid_file, \
        open(DIR + 'anno_test.csv', 'w', newline='') as test_file:
    train_writer = csv.writer(train_file)
    valid_writer = csv.writer(valid_file)
    test_writer = csv.writer(test_file)

    for i, taskid in enumerate(id_list):
        if i % 10 == 0:
            add_to_valid(valid_writer, taskid)
        elif i % 10 == 1:
            add_to_test(test_writer, taskid)
        else:
            add_to_train(train_writer, taskid)