In [13]:
import imp
import importlib

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection
from matplotlib.widgets import Slider

from Prediction import dataset
from Prediction import train_eval
from Prediction import visualize

In [3]:
# Re-import train_eval so edits made to the module on disk are picked up
# without restarting the kernel. importlib.reload replaces the deprecated
# imp.reload (the imp module was removed in Python 3.12).
importlib.reload(train_eval)

<module 'Prediction.train_eval' from '/app/Pogona_realtime/Arena/Prediction/train_eval.py'>

In [4]:
# Load the trial data with the 'dlc' source switched off and the other three
# sources enabled.
all_df = dataset.collect_data(
    data_sources={
        'detections': True,
        'timestamps': True,
        'dlc': False,
        'touches': True,
    }
)

46 trials loaded


In [5]:
# Distinct index values of all_df; each one is used below as a trial identifier.
trials = all_df.index.unique()

In [6]:
# Work with the first trial only: its rows, its video file and its homography.
trial = trials[0]
trial_df = all_df.loc[trial]
vid_path = dataset.get_trial_video_path(trial)
homography = dataset.homography_for_trial(trial)
# Callable applied to every frame read from the video before it is displayed.
# NOTE(review): 1920 is presumably the frame width in pixels — confirm against
# visualize.get_correction_fn.
correction_fn = visualize.get_correction_fn(homography, 1920)

In [7]:
# Column labels handed to train_eval.trial_to_samples as model inputs and as
# prediction targets (kept as two separate lists so they can diverge).
input_labels = ['x1', 'y1', 'x2', 'y2']
output_labels = ['x1', 'y1', 'x2', 'y2']
input_dim, output_dim = len(input_labels), len(output_labels)

# Number of timesteps in the input window and in the output window.
inp_seq_len, out_seq_len = 20, 20

In [8]:
def get_vid_frames(vid_path, correction_fn, start, num):
    """Read up to ``num`` consecutive frames from a video, starting at ``start``.

    Args:
        vid_path: Path of the video file, passed to ``cv.VideoCapture``.
        correction_fn: Callable applied to every decoded frame; its return
            value is what gets stored in the result.
        start: Index of the first frame to read (``CAP_PROP_POS_FRAMES``).
        num: Maximum number of frames to read.

    Returns:
        A list of corrected frames. It holds fewer than ``num`` items when the
        video ends (or a frame fails to decode) before ``num`` frames were read.

    Raises:
        IOError: If the video cannot be opened.
    """
    vcap = cv.VideoCapture(vid_path)
    try:
        if not vcap.isOpened():
            raise IOError(f"Could not open video: {vid_path}")

        vcap.set(cv.CAP_PROP_POS_FRAMES, start)
        frames = []
        for _ in range(num):
            ret, frame = vcap.read()
            if not ret:
                # read() signals end-of-stream / decode failure through `ret`;
                # `frame` is not usable then, so stop instead of passing it on.
                break
            frames.append(correction_fn(frame))

        return frames
    finally:
        # Always free the capture handle, even if correction_fn raises.
        vcap.release()

In [9]:
def draw_sequence(arr_X,
                  arr_Y,
                  arr_pred,
                  ax,
                  to_scatter=True,
                  l_alpha=0.35,
                  sctr_s=0.5,
                  sctr_alpha=1,
                  past_c='b', ftr_c='r', pred_c='g', diff_c='k',
                  draw_diffs=True,
                  zoom_x = None,
                  zoom_y = None):
    """Draw past, future and (optionally) predicted trajectories on ``ax``.

    Every array is either one sequence with shape (timesteps, features) or a
    batch with shape (sequences, timesteps, features). Whether the inputs are
    single sequences is decided from ``arr_X`` alone. Only the first two
    features of each timestep are drawn, as (x, y).

    Args:
        arr_X: Past (input) sequence(s).
        arr_Y: Future (ground-truth) sequence(s).
        arr_pred: Predicted sequence(s); may be None when ``draw_diffs`` is False.
        ax: Matplotlib axes to draw on.
        to_scatter: Also mark every individual point with a scatter plot.
        l_alpha: Alpha of the line collections.
        sctr_s: Marker size of the scatter points.
        sctr_alpha: Alpha of the scatter points.
        past_c, ftr_c, pred_c, diff_c: Colors for the past, future, predicted
            and prediction-to-truth difference lines.
        draw_diffs: Draw the prediction and, for each timestep, a segment
            joining the predicted point to the true future point.
        zoom_x: Optional x limits forwarded to ``ax.set_xlim``.
        zoom_y: Optional y limits forwarded to ``ax.set_ylim``.

    Raises:
        ValueError: If ``draw_diffs`` is True but ``arr_pred`` is None.
    """
    if draw_diffs and arr_pred is None:
        raise ValueError("arr_pred is required when draw_diffs=True")

    # Promote single sequences to a batch of one so the code below can always
    # index as [sequence, timestep, feature].
    if len(arr_X.shape) == 2:
        arr_X = arr_X.reshape(1, arr_X.shape[0], arr_X.shape[1])
        arr_Y = arr_Y.reshape(1, arr_Y.shape[0], arr_Y.shape[1])
        if draw_diffs:
            arr_pred = arr_pred.reshape(1, arr_pred.shape[0], arr_pred.shape[1])

    ax.add_collection(LineCollection(segments=list(arr_X[:, :, :2]), colors=[past_c], label='past', alpha=l_alpha))
    ax.add_collection(LineCollection(segments=list(arr_Y[:, :, :2]), colors=[ftr_c], label='future', alpha=l_alpha))
    if draw_diffs:
        ax.add_collection(LineCollection(segments=list(arr_pred[:, :, :2]), colors=[pred_c], label='pred', alpha=l_alpha))
        # Take the number of timesteps from the arrays themselves rather than
        # from a notebook-level global, so any sequence length works.
        n_steps = min(arr_pred.shape[1], arr_Y.shape[1])
        diffs = [np.array([arr_pred[j, i, :2], arr_Y[j, i, :2]])
                 for i in range(n_steps)
                 for j in range(arr_pred.shape[0])]
        ax.add_collection(LineCollection(segments=diffs, colors=[diff_c], label='diff', alpha=l_alpha))

    if to_scatter:
        ax.scatter(arr_X[:, :, 0], arr_X[:, :, 1], s=sctr_s, color=past_c, alpha=sctr_alpha)
        ax.scatter(arr_Y[:, :, 0], arr_Y[:, :, 1], s=sctr_s, color=ftr_c, alpha=sctr_alpha)
        if draw_diffs:
            ax.scatter(arr_pred[:, :, 0], arr_pred[:, :, 1], s=sctr_s, color=pred_c, alpha=sctr_alpha)

    # Compare against None explicitly: limits given as a numpy array have no
    # unambiguous truth value.
    if zoom_x is not None:
        ax.set_xlim(zoom_x)
    if zoom_y is not None:
        ax.set_ylim(zoom_y)


In [10]:
# Build the sample arrays for the selected trial; X[i] and Y[i] are indexed
# together below as the i-th input sequence and its matching output sequence.
X, Y = train_eval.trial_to_samples(
    trial_df,
    input_labels,
    output_labels,
    inp_seq_len,
    out_seq_len,
    keep_nans=True,
)

In [65]:
# Sequence to inspect. It is fixed so the figure is reproducible; to browse a
# random sequence instead, use np.random.randint(0, len(X)).
seq_num = 319
seq = X[seq_num], Y[seq_num]
# alpha and cat_seq are not used by the visible cells; kept in case other
# cells rely on them.
alpha = 1/(inp_seq_len+out_seq_len)
cat_seq = np.concatenate(seq)
# NOTE(review): the sample index is used as the first video frame, which
# assumes sample i starts at frame i of the trial video — confirm against
# train_eval.trial_to_samples.
frames = get_vid_frames(vid_path, correction_fn, seq_num, inp_seq_len+out_seq_len)

%matplotlib widget
fig, ax = plt.subplots(1, 1, figsize=(7,8))
plt.subplots_adjust(left=0.25, bottom=0.1)
fig.suptitle(f"Trial: {trial} sequence num: {seq_num}")

ax_im = ax.imshow(frames[0])
draw_sequence(X[seq_num], Y[seq_num], None, ax, draw_diffs=False)
ax_cur_point = ax.scatter(seq[0][0, 0], seq[0][0, 1], c='g')

"""
fig, ax = plt.subplots(1, 1, figsize=(20,20))
for i in range(inp_seq_len+out_seq_len):
    ax.imshow(frames[i], alpha=alpha)
    #axs[i].set_xlim([cat_seq[:, 0].min(), cat_seq[:, 0].max()])
    #axs[i].set_ylim([cat_seq[:, 1].min(), cat_seq[:, 0].max()])

draw_sequence(X[seq_num], Y[seq_num], None, ax, draw_diffs=False, )


axs = axs.flatten()
for i in range(inp_seq_len+out_seq_len):
    ax = axs[i]
    ax.imshow(frames[i])
    #axs[i].set_xlim([cat_seq[:, 0].min(), cat_seq[:, 0].max()])
    #axs[i].set_ylim([cat_seq[:, 1].min(), cat_seq[:, 0].max()])
    draw_sequence(seq[0], seq[1], None, ax, draw_diffs=False)
    if i >= inp_seq_len:
        idx = i - inp_seq_len
        ax.scatter(seq[1][idx, 0], seq[1][idx, 1])
    else:
        ax.scatter(seq[0][i, 0], seq[0][i, 1])
"""

slider_ax = plt.axes([0.25, 0.1, 0.65, 0.03], facecolor='lightgoldenrodyellow')
slider = Slider(slider_ax, "timestep", 0, inp_seq_len+out_seq_len, valinit=0, valstep=1)

def update_timestep(t):
    """Slider callback: show frame ``t`` and move the marker to its point.

    Timesteps below ``inp_seq_len`` index into the input sequence ``seq[0]``;
    later ones index into the output sequence ``seq[1]``.

    Args:
        t: Slider value. It is truncated to an int and clamped to the valid
            range so an out-of-range value cannot raise inside the callback.
    """
    if len(frames) == 0:
        return

    last_valid = min(inp_seq_len + out_seq_len, len(frames)) - 1
    t = max(0, min(int(t), last_valid))

    ax_im.set_data(frames[t])
    if t >= inp_seq_len:
        idx = t - inp_seq_len
        ax_cur_point.set_offsets([seq[1][idx, 0], seq[1][idx, 1]])
    else:
        ax_cur_point.set_offsets([seq[0][t, 0], seq[0][t, 1]])

    fig.canvas.draw_idle()

slider.on_changed(update_timestep)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

0