# Extracting frames from CARLA simulator videos and overlaying the crossing predictions

In [None]:
import numpy as np

from torch import from_numpy
from torch import cuda
from torch import no_grad
from torch import optim

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg

%matplotlib inline

from Code.GNN import *
from Code.SkeletonsDataset import *
from Code.ModelTrainEvaluate import *

import cv2
from PIL import Image

## Loading the skeletons dataset

In [None]:
# Which split of the CARLA data to load, and how the model target is set up.
subset = 'test'
numberOfClasses = 2
info = 2

# Skeleton sequences for the chosen split, min-max normalized, labelled by
# the 'crossing' column; rows without detections are kept.
csv_path = f'Data/CARLA/{subset}_preprocessed.csv'
dataset = SkeletonsDataset(
    csv_path,
    numberOfJoints=26,
    normalization='minmax',
    target='crossing',
    info=info,
    remove_undetected=False,
)

In [None]:
# Distinct video identifiers, in order of first appearance in the data.
video_column = dataset.loadedData['video']
videos_list = video_column.unique().tolist()

videos_list

In [None]:
dataset.loadedData.loc[:, ['video', 'frame', 'skeleton', 'crossing']]

In [None]:
# Size the model input from the first sample of the dataset.
t0 = dataset[0]
# First entry of the sample's temporal node features; indexed below as
# (nodes, features) — presumably one time step, TODO confirm.
t1 = t0.x_temporal[0]

numberOfNodes, embed_dim = t1.shape[0], t1.shape[1]

for caption, amount in (('Number of nodes per skeleton:', numberOfNodes),
                        ('Number of features per node:', embed_dim)):
    print(caption, amount)

## Loading the trained model

In [None]:
# Import torch explicitly instead of relying on it leaking in through the
# `from Code.GNN import *` star import.
import torch

model = SpatialTemporalGNN(embed_dim, numberOfClasses, numberOfNodes)

# Alternative checkpoint: 'exportedModels/CARLA/Approach_2-3'
model_path = 'exportedModels/CARLA/Full train dataset/SpatialTemporal - 5 frames/Epoch_199'
# map_location remaps the stored tensors to CPU, so a checkpoint saved on a
# GPU also loads on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
# Inference only: put layers such as dropout / batch-norm (if the model has
# any) into evaluation mode. Harmless if `predict` already calls eval().
model.eval()

## Making the crossing/not-crossing prediction

In [None]:
device = torch.device('cpu')

# DataLoader's default shuffle=False keeps batches in dataset order, so the
# outputs line up row-by-row with dataset.loadedData.
loader = DataLoader(dataset, batch_size=500)

# Run the trained model over every sample on the CPU.
predictions, groundtruth = predict(model, loader, device)

## Loading the videos and exporting each result as a GIF

In [None]:
import os

# Maximum number of GIFs to export.
quantity = 20
# Display time of each GIF frame, in milliseconds.
gif_frame_duration_ms = 30

videos_dir = 'Data/CARLA-videos/'
output_dir = os.path.join('Videos_results', 'CARLA', subset)
# Saving into a missing directory would fail, so create it up front.
os.makedirs(output_dir, exist_ok=True)

# Only used for membership tests.
gen_clips = set(os.listdir(videos_dir))

# Video id of every dataset row. Row i is assumed to be sample i of `dataset`,
# i.e. the same index used by `predictions` / `groundtruth` (the DataLoader
# does not shuffle) — TODO confirm against SkeletonsDataset.
videos_list_frames = dataset.loadedData['video'].tolist()


def _class_label(value):
    """Return the display label for a binary crossing value (non-zero means crossing)."""
    return "Crossing" if int(value) else "Not-crossing"


exported = 0

for video_id in videos_list:

    video_file_id = video_id.replace('clips/', '').replace('.mp4', '')

    if (video_file_id + '.mp4') not in gen_clips:
        continue

    print('Starting processing of video', video_file_id)

    # A video's rows are assumed contiguous and in frame order, so the k-th
    # video frame maps to row `video_first_dataset_row + k`. Bounding k by the
    # video's row count prevents reading the next video's labels, or running
    # past the end of `predictions` for the last video.
    video_first_dataset_row = videos_list_frames.index(video_id)
    video_num_rows = videos_list_frames.count(video_id)

    video_outputs = []

    video = cv2.VideoCapture(os.path.join(videos_dir, video_file_id + '.mp4'))
    try:
        frame_i = 0
        while frame_i < video_num_rows:
            ret, frame = video.read()
            if not ret:
                break

            dataset_row = video_first_dataset_row + frame_i
            frame_prediction = _class_label(predictions[dataset_row])
            frame_groundtruth = _class_label(groundtruth[dataset_row])

            im_title = "Prediction: " + frame_prediction + "\nGroundtruth: " + frame_groundtruth

            try:
                fig = dataset.showSkeleton(videoNum=video_id, frameNum=frame_i, showLegend=False, frameImage=frame,
                                           normalizedSkeletons=False, title=im_title, show=False,
                                           prediction=frame_prediction, groundtruth=frame_groundtruth)
            except Exception as err:
                # Keep the original stop-on-failure behavior for this clip,
                # but report why instead of silently swallowing the error.
                print('Stopping video', video_file_id, 'at frame', frame_i, '-', repr(err))
                break

            try:
                # Rasterize the figure off-screen and keep it as a PIL image.
                canvas = FigureCanvasAgg(fig)
                canvas.draw()
                frame_result = np.asarray(canvas.buffer_rgba()).astype(np.uint8)
                video_outputs.append(Image.fromarray(frame_result))
                canvas.get_renderer().clear()
            finally:
                # Always free the figure so figures do not pile up across frames.
                plt.close(fig)

            frame_i += 1
    finally:
        video.release()

    if not video_outputs:
        # Unreadable video or no renderable frame: nothing to save.
        print('No frames rendered for video', video_file_id, '- skipped\n')
        continue

    # Export the prediction result as a GIF (loop=0 makes it loop forever):
    video_outputs[0].save(os.path.join(output_dir, video_file_id + '.gif'), save_all=True,
                          append_images=video_outputs[1:], duration=gif_frame_duration_ms, loop=0)

    exported += 1

    print('Exported video:', video_file_id, ' - Clip', str(exported) + '/' + str(quantity), '\n')

    if exported == quantity:
        break