# Extracting the JAAD images from videos and showing the crossing prediction

In [None]:
import numpy as np

from torch import from_numpy
from torch import cuda
from torch import no_grad
from torch import optim

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg

%matplotlib inline

from Code.GNN import *
from Code.SkeletonsDataset import *
from Code.ModelTrainEvaluate import *

import cv2
from PIL import Image

## Loading the skeletons dataset

In [2]:
# Split to load and the settings shared with the model cells further down.
subset = 'test'
numberOfClasses = 2
info = 2

# Annotations file for the selected split.
annotations_csv = 'Data/' + subset + '_annotations_with_skeletons.csv'

# Undetected skeletons are kept (remove_undetected=False).
dataset = SkeletonsDataset(
    annotations_csv,
    normalization='minmax',
    target='cross',
    info=info,
    remove_undetected=False,
)

In [3]:
# Video identifier of every row in the loaded annotations.
videos_list = dataset.loadedData['video'].tolist()

# Distinct video identifiers, in sorted order.
sorted(set(videos_list))

['video_0001',
 'video_0003',
 'video_0004',
 'video_0007',
 'video_0008',
 'video_0009',
 'video_0010',
 'video_0011',
 'video_0012',
 'video_0014',
 'video_0017',
 'video_0020',
 'video_0021',
 'video_0023',
 'video_0024',
 'video_0027',
 'video_0028',
 'video_0030',
 'video_0032',
 'video_0035',
 'video_0037',
 'video_0038',
 'video_0039']

In [4]:
# Preview the columns used in the rest of the notebook.
dataset.loadedData.loc[:, ['video', 'frame', 'skeleton', 'cross']]

Unnamed: 0,video,frame,skeleton,cross
0,video_0001,0,"[[1448.57, 674.503, 0.924437], [1442.74, 698.0...",not-crossing
1,video_0001,1,"[[1451.47, 674.493, 0.906491], [1445.66, 698.0...",not-crossing
2,video_0001,2,"[[1451.58, 674.604, 0.856983], [1448.49, 698.0...",not-crossing
3,video_0001,3,"[[1457.41, 677.444, 0.82119], [1448.66, 698.11...",not-crossing
4,video_0001,4,"[[1466.12, 677.516, 0.846119], [1451.56, 700.9...",not-crossing
...,...,...,...,...
5559,video_0039,355,"[[0.0, 0.0, 0.0], [609.777, 795.251, 0.774795]...",crossing
5560,video_0039,356,"[[595.029, 771.685, 0.15024], [600.952, 795.13...",crossing
5561,video_0039,357,"[[0.0, 0.0, 0.0], [597.887, 795.25, 0.869133],...",crossing
5562,video_0039,358,"[[0.0, 0.0, 0.0], [592.049, 792.265, 0.860434]...",crossing


In [5]:
# Take the first sample and the first entry of its `x_temporal` attribute.
t0 = dataset[0]
t1 = t0.x_temporal[0]

# Axis 0 is read as the node count, axis 1 as the per-node feature size.
numberOfNodes, embed_dim = t1.shape[0], t1.shape[1]

print('Number of nodes per skeleton:', numberOfNodes)
print('Number of features per node:', embed_dim)

Number of nodes per skeleton: 25
Number of features per node: 3


## Loading the trained model

In [6]:
# Rebuild the network with the dimensions read from the dataset above.
model = SpatialTemporalGNN(embed_dim, numberOfClasses, numberOfNodes)

# This notebook only runs inference, so put any layers that behave differently
# during training into evaluation mode.
# NOTE(review): harmless if predict() already does this — confirm in Code.ModelTrainEvaluate.
model.eval()

# map_location='cpu' lets a checkpoint saved from a CUDA session load on a
# CPU-only machine (the prediction cell below runs on the CPU).
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load('exportedModels/Approach_2-3', map_location='cpu'))

<All keys matched successfully>

## Making the crossing/not-crossing prediction

In [7]:
# Inference is run on the CPU.
device = torch.device('cpu')

# shuffle=False is the default; it is spelled out here because the export cell
# looks predictions up by dataset row index.
loader = DataLoader(dataset, shuffle=False, batch_size=500)

predictions, groundtruth = predict(model, loader, device)

## Loading the videos and exporting the results as GIFs

In [8]:
# Render the first 10 videos of the subset as GIFs: every frame is drawn with its
# skeleton and titled with the predicted and ground-truth crossing labels.

videos_list = dataset.loadedData['video'].tolist()

# Number of rows that can be looked up safely in every per-row sequence.
rows_available = min(len(videos_list), len(predictions), len(groundtruth))

for video_id in sorted(set(videos_list))[0:10]:

    video = cv2.VideoCapture("Data/JAAD-videos/" + subset + "/" + video_id + ".mp4")

    try:
        # A missing or unreadable file yields no frames; report it instead of failing later.
        if not video.isOpened():
            print('Skipped video (could not be opened):', video_id)
            continue

        # First row in the dataset where the video starts.
        # NOTE(review): the lookups below assume the rows of a video are contiguous and
        # that the row offset equals the frame number — confirm against SkeletonsDataset.
        video_first_dataset_row = videos_list.index(video_id)

        video_outputs = []
        frame_i = 0

        while True:

            row_i = video_first_dataset_row + frame_i

            # Stop before reading labels that belong to another video or lie past the end.
            if row_i >= rows_available or videos_list[row_i] != video_id:
                break

            ret, frame = video.read()
            if not ret:
                break

            frame_prediction = "Crossing" if int(predictions[row_i]) else "Not-crossing"
            frame_groundtruth = "Crossing" if int(groundtruth[row_i]) else "Not-crossing"

            im_title = "Prediction: " + frame_prediction + "\nGroundtruth: " + frame_groundtruth

            try:

                fig = dataset.showSkeleton(videoNum=video_id, frameNum=frame_i, showLegend=False,
                                           frameImage=frame, normalizedSkeletons=False, title=im_title, show=False)

                # Rasterise the figure off-screen and keep a copy of the RGBA pixels.
                canvas = FigureCanvasAgg(fig)
                canvas.draw()
                frame_result = np.asarray(canvas.buffer_rgba()).astype(np.uint8)

                video_outputs.append(Image.fromarray(frame_result))

                # Free the figure so memory does not grow with the number of frames.
                canvas.get_renderer().clear()
                plt.close(fig)

            except Exception as err:

                # Keep the best-effort behaviour (stop this video, export what was rendered),
                # but say why instead of swallowing the error silently.
                print('Stopped', video_id, 'at frame', frame_i, '-', repr(err))
                break

            frame_i = frame_i + 1

        # Nothing rendered means there is nothing to export.
        if not video_outputs:
            print('Skipped video (no frames rendered):', video_id)
            continue

        # Export the prediction result as a GIF:
        video_outputs[0].save("Videos_results/" + subset + "/" + video_id + ".gif", save_all=True,
                              append_images=video_outputs[1:], duration=30, loop=0)

        print('Exported video:', video_id)

    finally:
        # Always release the capture handle, including on skip or error.
        video.release()

Exported video: video_0001
Exported video: video_0003
Exported video: video_0004
Exported video: video_0007
Exported video: video_0008
Exported video: video_0009
Exported video: video_0010
Exported video: video_0011
Exported video: video_0012
Exported video: video_0014
