In [None]:
#import dependencies

import torchvision
import model.c2d as c2d
import os
import numpy as np
from torchvision.utils import save_image, make_grid
from torchvision.models import vgg19

import torch.nn as nn
import torch
import math
import cv2
from datetime import datetime
import time
import pandas as pd
from torch.utils.data import DataLoader

In [None]:
#Dataset class for generating video snippet samples from raw videos.
class trafficvidset(torch.utils.data.Dataset):
    """Dataset of fixed-length video snippets labelled with synchronized audio levels.

    Every ``.mp4`` in ``vid_folder`` is cut into ``duration``-second snippets
    whose start times are ``clip_delta`` seconds apart. Each snippet is paired
    with the row of ``audio_file`` whose timestamp (column 0) equals the
    video's recording start plus the snippet's mid-point second.
    """

    def __init__(self, vid_folder, audio_file, vid_fps, duration, cols=None,
                 clip_delta=1, normalize_examples=320):
        """Index every snippet start position of every video in ``vid_folder``.

        Args:
            vid_folder: Folder containing videos named ``YYYY-MM-DD_HH-MM-SS.mp4``
                (e.g. ``2022-04-13_12-14-33.mp4``). The name is parsed as the
                recording start (local time) and is required to synchronize
                the audio levels in the CSV with the correct video frames.
            audio_file: Header-less CSV in which column 0 is a unix timestamp
                and the remaining columns are frequency-domain audio levels.
            vid_fps: Frame rate of the source videos.
            duration: Duration of each sampled snippet, in seconds.
            cols: Audio columns used as the target. ``None`` selects every
                column except the timestamp column.
            clip_delta: Seconds between the starts of two consecutive snippets
                of the same video; trades dataset size against coverage.
            normalize_examples: Stored as ``self.norm_eg``; not otherwise used
                inside this class.
        """
        self.vid_fps = vid_fps
        self.duration = duration
        self.df = pd.read_csv(audio_file, header=None)
        self.cols = self.df.columns[1:] if cols is None else cols
        # Previously hard-coded to 320, silently ignoring the argument.
        self.norm_eg = normalize_examples
        self.curr_eg = 0
        self.running_norm = None
        self.running_std = None
        self.clips = []
        # Sorted so clip indices are reproducible; os.listdir order is arbitrary.
        for vid_file in sorted(os.listdir(vid_folder)):
            if not vid_file.endswith('.mp4'):
                continue
            # time.mktime interprets the parsed struct_time as local time.
            vid_time_stamp = int(time.mktime(
                datetime.strptime(vid_file[:-4], '%Y-%m-%d_%H-%M-%S').timetuple()))
            vid_obj = cv2.VideoCapture(os.path.join(vid_folder, vid_file))
            frame_count = vid_obj.get(cv2.CAP_PROP_FRAME_COUNT)
            # Latest start second that still leaves room for a full snippet.
            last_second = int(frame_count / vid_fps) - duration
            # [capture, frame count, recording start (unix s), last start second];
            # shared by all clips of this video.
            vid_tuple = [vid_obj, frame_count, vid_time_stamp, last_second]
            self.clips.extend(
                [start, vid_tuple] for start in range(0, last_second + 1, clip_delta))
        print("no. of clips - " + str(len(self.clips)))

    def __len__(self):
        return len(self.clips)

    def __getitem__(self, id):
        """Return ``(x, y)`` for clip ``id``, or for the next clip that has audio.

        Clips without an audio row at their mid-point timestamp are skipped by
        advancing cyclically to the following clip.

        Returns:
            x: float32 array of shape ``(6 * n_frames, H, W // 2)`` with
                ``n_frames = duration * vid_fps`` (assuming 3-channel
                ``(H, W, 3)`` frames from cv2 and an even width ``W``). The
                channel axis holds the left halves of all frames followed by
                the right halves.
            y: float32 1-D array of the selected audio columns.

        Raises:
            IndexError: ``id`` is out of range (including an empty dataset).
            RuntimeError: no clip has a matching audio row, or a frame of the
                selected snippet could not be read.
        """
        start_second, vid_tuple, y = self._find_labelled_clip(id)
        frames = self._read_frames(vid_tuple, start_second)
        x_arr = np.concatenate(frames, axis=2).astype(np.float32)
        half = int(x_arr.shape[1] / 2)
        # Split each frame into left/right halves and stack them along channels.
        x_arr = np.concatenate([x_arr[..., :half, :], x_arr[..., half:, :]], axis=2)
        return np.transpose(x_arr, (2, 0, 1)), y.astype(np.float32)[0]

    def _find_labelled_clip(self, idx):
        """Return ``(start_second, vid_tuple, y)`` of the first clip with audio, from ``idx`` wrapping around."""
        n_clips = len(self.clips)
        # At least one attempt so an invalid idx / empty dataset raises IndexError.
        for _ in range(max(n_clips, 1)):
            start_second, vid_tuple = self.clips[idx]
            # Audio label is taken at the snippet's mid-point second.
            audio_tstamp = vid_tuple[2] + start_second + math.ceil(self.duration / 2)
            y = self.df[self.df[0] == audio_tstamp][self.cols].to_numpy()
            if y.shape[0] > 0:
                return start_second, vid_tuple, y
            idx = idx + 1 if idx < n_clips - 1 else 0
        raise RuntimeError(
            "no clip has an audio row matching its timestamp; "
            "check that the CSV covers the videos' recording times")

    def _read_frames(self, vid_tuple, start_second):
        """Seek to ``start_second`` and return the snippet's decoded frames in order."""
        vid_obj, frame_count, _, last_second = vid_tuple
        # int() keeps frame indices valid for non-integer frame rates.
        start_index = int(start_second * self.vid_fps)
        n_frames = int(self.duration * self.vid_fps)
        vid_obj.set(cv2.CAP_PROP_POS_FRAMES, start_index)
        frames = []
        for j in range(start_index, start_index + n_frames):
            retval, frame = vid_obj.read()
            if not retval:
                # Fail loudly instead of appending None and crashing in np.concatenate.
                raise RuntimeError(
                    f"failed to read frame {j} (snippet start frame {start_index}, "
                    f"start second {start_second}, video frame count {frame_count}, "
                    f"last start second {last_second})")
            frames.append(frame)
        return frames
    

class trafficvidset_optflow(torch.utils.data.Dataset):
    """Snippet dataset like ``trafficvidset`` that keeps only frame channels 0 and 2.

    Presumably intended for videos that store optical flow in those two
    channels (inferred from the class name only — confirm with the data
    pipeline). Each item is labelled with the audio row whose timestamp equals
    the video's recording start plus the snippet's mid-point second.
    """

    def __init__(self, vid_folder, audio_file, vid_fps, duration, cols=None, clip_delta=1):
        """Index every snippet start position of every video in ``vid_folder``.

        Args:
            vid_folder: Folder of ``YYYY-MM-DD_HH-MM-SS.mp4`` videos; the name
                is parsed (as local time) to synchronize video with audio.
            audio_file: Header-less CSV; column 0 is a unix timestamp, the rest
                are audio levels.
            vid_fps: Frame rate of the source videos.
            duration: Snippet duration in seconds.
            cols: Audio columns used as target; ``None`` means all but column 0.
            clip_delta: Seconds between consecutive snippet starts in a video.
        """
        self.vid_fps = vid_fps
        self.duration = duration
        self.df = pd.read_csv(audio_file, header=None)
        self.cols = self.df.columns[1:] if cols is None else cols
        self.clips = []
        # Sorted so clip indices are reproducible; os.listdir order is arbitrary.
        for vid_file in sorted(os.listdir(vid_folder)):
            if not vid_file.endswith('.mp4'):
                continue
            vid_time_stamp = int(time.mktime(
                datetime.strptime(vid_file[:-4], '%Y-%m-%d_%H-%M-%S').timetuple()))
            vid_obj = cv2.VideoCapture(os.path.join(vid_folder, vid_file))
            frame_count = vid_obj.get(cv2.CAP_PROP_FRAME_COUNT)
            # Latest start second that still leaves room for a full snippet.
            last_second = int(frame_count / vid_fps) - duration
            # [capture, frame count, recording start (unix s), last start second]
            vid_tuple = [vid_obj, frame_count, vid_time_stamp, last_second]
            self.clips.extend(
                [start, vid_tuple] for start in range(0, last_second + 1, clip_delta))
        print("no. of clips - " + str(len(self.clips)))

    def __len__(self):
        return len(self.clips)

    def __getitem__(self, id):
        """Return ``(x, y)`` for clip ``id``, or for the next clip that has audio.

        Returns:
            x: float32 array of shape ``(4 * n_frames, H, W // 2)`` with
                ``n_frames = duration * vid_fps`` (assuming ``(H, W, C)``
                frames from cv2 and an even width ``W``): channels 0 and 2 of
                each frame's left half, followed by those of the right half.
            y: float32 1-D array of the selected audio columns.

        Raises:
            IndexError: ``id`` is out of range (including an empty dataset).
            RuntimeError: no clip has a matching audio row, or a frame of the
                selected snippet could not be read.
        """
        start_second, vid_tuple, y = self._find_labelled_clip(id)
        frames = self._read_frames(vid_tuple, start_second)
        x_arr = np.concatenate(frames, axis=2).astype(np.float32)
        half = int(x_arr.shape[1] / 2)
        # Split each frame into left/right halves and stack them along channels.
        x_arr = np.concatenate([x_arr[..., :half, :], x_arr[..., half:, :]], axis=2)
        return np.transpose(x_arr, (2, 0, 1)), y.astype(np.float32)[0]

    def _find_labelled_clip(self, idx):
        """Return ``(start_second, vid_tuple, y)`` of the first clip with audio, from ``idx`` wrapping around."""
        n_clips = len(self.clips)
        # At least one attempt so an invalid idx / empty dataset raises IndexError.
        for _ in range(max(n_clips, 1)):
            start_second, vid_tuple = self.clips[idx]
            # Audio label is taken at the snippet's mid-point second.
            audio_tstamp = vid_tuple[2] + start_second + math.ceil(self.duration / 2)
            y = self.df[self.df[0] == audio_tstamp][self.cols].to_numpy()
            if y.shape[0] > 0:
                return start_second, vid_tuple, y
            idx = idx + 1 if idx < n_clips - 1 else 0
        raise RuntimeError(
            "no clip has an audio row matching its timestamp; "
            "check that the CSV covers the videos' recording times")

    def _read_frames(self, vid_tuple, start_second):
        """Seek to ``start_second`` and return channels 0 and 2 of each snippet frame."""
        vid_obj, frame_count, _, last_second = vid_tuple
        # int() keeps frame indices valid for non-integer frame rates.
        start_index = int(start_second * self.vid_fps)
        n_frames = int(self.duration * self.vid_fps)
        vid_obj.set(cv2.CAP_PROP_POS_FRAMES, start_index)
        frames = []
        for j in range(start_index, start_index + n_frames):
            retval, frame = vid_obj.read()
            if not retval:
                # Fail loudly instead of indexing into None.
                raise RuntimeError(
                    f"failed to read frame {j} (snippet start frame {start_index}, "
                    f"start second {start_second}, video frame count {frame_count}, "
                    f"last start second {last_second})")
            frames.append(frame[..., [0, 2]])
        return frames
        

        
class trafficvidset_normalization(torch.utils.data.Dataset):
    """Snippet dataset like ``trafficvidset`` that also estimates normalization stats.

    At construction it samples random frames and stores per-channel mean/std
    of the left (``norm1``/``std1``) and right (``norm2``/``std2``) frame
    halves. ``__getitem__`` does NOT apply them; presumably the caller does
    (TODO confirm).
    """

    def __init__(self, vid_folder, audio_file, vid_fps, duration, cols=None,
                 clip_delta=1, normalize_clips=150, frames_per_clip=2):
        """Index snippet start positions and estimate normalization statistics.

        Args:
            vid_folder: Folder of ``YYYY-MM-DD_HH-MM-SS.mp4`` videos; the name
                is parsed (as local time) to synchronize video with audio.
            audio_file: Header-less CSV; column 0 is a unix timestamp, the rest
                are audio levels.
            vid_fps: Frame rate of the source videos.
            duration: Snippet duration in seconds.
            cols: Audio columns used as target; ``None`` means all but column 0.
            clip_delta: Seconds between consecutive snippet starts in a video.
            normalize_clips: Number of clips drawn (with replacement) for the
                statistics.
            frames_per_clip: Random frames read from each drawn clip's video.

        Raises:
            ValueError: no clips were found, or no sample frame could be read.
        """
        self.vid_fps = vid_fps
        self.duration = duration
        self.df = pd.read_csv(audio_file, header=None)
        self.cols = self.df.columns[1:] if cols is None else cols
        self.clips = []
        # Sorted so clip indices are reproducible; os.listdir order is arbitrary.
        for vid_file in sorted(os.listdir(vid_folder)):
            if not vid_file.endswith('.mp4'):
                continue
            vid_time_stamp = int(time.mktime(
                datetime.strptime(vid_file[:-4], '%Y-%m-%d_%H-%M-%S').timetuple()))
            vid_obj = cv2.VideoCapture(os.path.join(vid_folder, vid_file))
            frame_count = vid_obj.get(cv2.CAP_PROP_FRAME_COUNT)
            # Latest start second that still leaves room for a full snippet.
            last_second = int(frame_count / vid_fps) - duration
            # [capture, frame count, recording start (unix s), last start second]
            vid_tuple = [vid_obj, frame_count, vid_time_stamp, last_second]
            self.clips.extend(
                [start, vid_tuple] for start in range(0, last_second + 1, clip_delta))
        print("no. of clips - " + str(len(self.clips)))
        self._compute_normalization(normalize_clips, frames_per_clip)

    def _compute_normalization(self, normalize_clips, frames_per_clip):
        """Set ``norm1/std1`` (left half) and ``norm2/std2`` (right half) from random frames."""
        if not self.clips:
            raise ValueError("no clips found; cannot estimate normalization statistics")
        metric_frames = []
        for clip_idx in np.random.randint(0, len(self.clips), normalize_clips):
            vid_obj, frame_count = self.clips[clip_idx][1][:2]
            # int(): cv2 reports the frame count as a float.
            for frame_id in np.random.randint(0, int(frame_count), frames_per_clip):
                vid_obj.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
                retval, frame = vid_obj.read()
                # Check retval before touching frame: it is None on failure.
                if retval:
                    metric_frames.append(frame)
        if not metric_frames:
            raise ValueError("no sample frame could be read; cannot estimate normalization statistics")
        metric_frames = np.array(metric_frames)
        # Assumes (N, H, W, C) stacked frames; axis 2 is the frame width.
        half = int(metric_frames.shape[2] / 2)
        left = metric_frames[..., :half, :]
        right = metric_frames[..., half:, :]
        # Per-channel statistics over samples, rows and columns.
        self.norm1 = np.mean(left, (0, 1, 2), keepdims=True)
        self.std1 = np.std(left, (0, 1, 2), keepdims=True)
        self.norm2 = np.mean(right, (0, 1, 2), keepdims=True)
        self.std2 = np.std(right, (0, 1, 2), keepdims=True)

    def __len__(self):
        return len(self.clips)

    def __getitem__(self, id):
        """Return ``(x, y)`` for clip ``id``, or for the next clip that has audio.

        Returns:
            x: float32 array of shape ``(6 * n_frames, H, W // 2)`` with
                ``n_frames = duration * vid_fps`` (assuming 3-channel
                ``(H, W, 3)`` frames from cv2 and an even width ``W``): the
                left halves of all frames followed by the right halves.
                Not normalized.
            y: float32 1-D array of the selected audio columns.

        Raises:
            IndexError: ``id`` is out of range (including an empty dataset).
            RuntimeError: no clip has a matching audio row, or a frame of the
                selected snippet could not be read.
        """
        start_second, vid_tuple, y = self._find_labelled_clip(id)
        frames = self._read_frames(vid_tuple, start_second)
        x_arr = np.concatenate(frames, axis=2).astype(np.float32)
        half = int(x_arr.shape[1] / 2)
        # Split each frame into left/right halves and stack them along channels.
        x_arr = np.concatenate([x_arr[..., :half, :], x_arr[..., half:, :]], axis=2)
        return np.transpose(x_arr, (2, 0, 1)), y.astype(np.float32)[0]

    def _find_labelled_clip(self, idx):
        """Return ``(start_second, vid_tuple, y)`` of the first clip with audio, from ``idx`` wrapping around."""
        n_clips = len(self.clips)
        # At least one attempt so an invalid idx / empty dataset raises IndexError.
        for _ in range(max(n_clips, 1)):
            start_second, vid_tuple = self.clips[idx]
            # Audio label is taken at the snippet's mid-point second.
            audio_tstamp = vid_tuple[2] + start_second + math.ceil(self.duration / 2)
            y = self.df[self.df[0] == audio_tstamp][self.cols].to_numpy()
            if y.shape[0] > 0:
                return start_second, vid_tuple, y
            idx = idx + 1 if idx < n_clips - 1 else 0
        raise RuntimeError(
            "no clip has an audio row matching its timestamp; "
            "check that the CSV covers the videos' recording times")

    def _read_frames(self, vid_tuple, start_second):
        """Seek to ``start_second`` and return the snippet's decoded frames in order."""
        vid_obj, frame_count, _, last_second = vid_tuple
        # int() keeps frame indices valid for non-integer frame rates.
        start_index = int(start_second * self.vid_fps)
        n_frames = int(self.duration * self.vid_fps)
        vid_obj.set(cv2.CAP_PROP_POS_FRAMES, start_index)
        frames = []
        for j in range(start_index, start_index + n_frames):
            retval, frame = vid_obj.read()
            if not retval:
                # Fail loudly instead of appending None and crashing in np.concatenate.
                raise RuntimeError(
                    f"failed to read frame {j} (snippet start frame {start_index}, "
                    f"start second {start_second}, video frame count {frame_count}, "
                    f"last start second {last_second})")
            frames.append(frame)
        return frames