In [2]:
import os
import cv2
import torch
from model import VideoSaliencyModel
import argparse
from utils import *
from os.path import join
from torchvision import transforms
import yaml
from PIL import Image

import concurrent.futures
import concurrent.futures
import os
from os.path import join
import torch
from torch.utils.data import DataLoader , Dataset
from moviepy import VideoFileClip

In [None]:
# --- Run configuration -------------------------------------------------------
# Input clips, output location for the saliency maps, loader batch size and the
# checkpoint to load.
VIDEO_DIR = '/ssd_scratch/cvit/sarthak395/outputs/IyMdcXl4vag/raw_videos'
OUTPUT_DIR = '/ssd_scratch/cvit/sarthak395/outputs/IyMdcXl4vag/saliency_results'
BATCH_SIZE = 16
FILE_WEIGHT = '/home2/sarthak395/Sony_Shorts_Creator/ViNet_Saliency/saved_models/ViNet_DHF1K.pt'

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
class SaliencyPredictionDataset(Dataset):
    """Sliding-window clip dataset: one sample per frame of every video.

    Each sample is a window of ``len_temporal`` consecutive frames associated
    with a target frame ``i``:

    * for ``i >= len_temporal - 1`` the window is the ``len_temporal`` frames
      ending at ``i`` (``flip=False``);
    * for the first ``len_temporal - 1`` frames there is not enough history, so
      the window is the ``len_temporal`` frames *starting* at ``i`` and it is
      reversed along the temporal axis (``flip=True``) so that frame ``i`` still
      ends up last.

    Args:
        video_dir: Directory whose entries are all treated as video files.
        save_dir: Root directory used to build the per-frame output paths
            (``<save_dir>/<video stem>/frame_XXXX.png``). Nothing is written
            by this class.
        len_temporal: Number of frames per window.

    Attributes built on construction:
        video_list: Sorted directory listing of ``video_dir``.
        samples: List of ``(video_path, start_frame, end_frame, save_path, flip)``
            tuples, ``start_frame``/``end_frame`` inclusive.
    """

    def __init__(self , video_dir , save_dir , len_temporal = 32):
        self.video_dir = video_dir
        self.len_temporal = len_temporal
        self.save_dir = save_dir
        # contains the video names : ['IyMdcXl4vag_1.mp4' , 'IyMdcXl4vag_2.mp4' , ...]
        self.video_list = sorted(os.listdir(video_dir))
        self._create_samples()

    def _create_samples(self):
        """Populate ``self.samples`` with one window description per frame."""
        self.samples = []
        window = self.len_temporal
        for video_name in self.video_list:
            video_path = os.path.join(self.video_dir , video_name)

            # Only the metadata is needed here; always release the reader.
            clip = VideoFileClip(video_path)
            try:
                num_frames = int(clip.fps * clip.duration)
            finally:
                clip.close()

            # splitext keeps names with several dots distinct
            # ('a.1.mp4' -> 'a.1'), unlike split('.')[0].
            frame_dir = os.path.join(self.save_dir , os.path.splitext(video_name)[0])

            for i in range(num_frames):
                # The first (window - 1) frames have no full history: look
                # forward instead and mark the window for temporal flipping.
                flip = i < window - 1
                if flip:
                    start_frame , end_frame = i , i + window - 1
                else:
                    start_frame , end_frame = i - window + 1 , i
                save_path = os.path.join(frame_dir , f'frame_{i:04d}.png')
                self.samples.append((video_path , start_frame , end_frame , save_path , flip))

        print(f'Total samples created : {len(self.samples)}')

    def __len__(self):
        """Return the total number of windows over all videos."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Decode and preprocess the window described by ``self.samples[idx]``.

        Returns:
            ``(clip, idx)`` where ``clip`` is a float tensor laid out as
            (3 , len_temporal , H , W) and ``idx`` is the index passed in, so
            the caller can look up the sample's save path.

        Raises:
            AssertionError: if the stored window does not span ``len_temporal``
                frames.
            RuntimeError: if no frame could be decoded for the window.
        """
        video_path , start_frame , end_frame , _save_path , flip = self.samples[idx]

        # Validate before touching the file so a bad sample opens no reader.
        window_len = end_frame - start_frame + 1
        assert window_len == self.len_temporal , f'window length = {window_len} != {self.len_temporal}'

        video = VideoFileClip(video_path)
        try:
            fps = video.fps
            subclip = video.subclipped(start_frame / fps , (end_frame + 1) / fps)
            # Keep at most len_temporal frames; extra ones are discarded.
            raw_frames = list(subclip.iter_frames(fps=fps, dtype='uint8'))[:self.len_temporal]
        finally:
            # Release the ffmpeg reader even if decoding fails.
            video.close()

        if not raw_frames:
            # An IndexError here would be mistaken for "end of dataset" by
            # plain sequence iteration, so fail loudly instead.
            raise RuntimeError(
                f'no frames decoded from {video_path} for frames {start_frame}-{end_frame}'
            )

        frames = [self.torch_transform(Image.fromarray(frame).convert('RGB'))[0] for frame in raw_frames]

        # Pad short windows with zero frames: in front for backward-looking
        # windows, behind for forward-looking (flipped) ones.
        missing = self.len_temporal - len(frames)
        if missing > 0:
            padding = [torch.zeros_like(frames[0]) for _ in range(missing)]
            frames = frames + padding if flip else padding + frames

        clip_tensor = torch.stack(frames, dim=0).float() # (len_temporal , 3 , H , W)
        clip_tensor = clip_tensor.permute((1,0,2,3)) # (3 , len_temporal , H , W)
        if flip:
            # Reverse the temporal axis so the target frame is last.
            clip_tensor = torch.flip(clip_tensor , [1])

        return clip_tensor , idx

    def torch_transform(self , img):
        """Resize, tensorize and normalize one PIL image.

        Returns:
            ``(tensor, size)`` where ``size`` is ``img.size`` of the input
            image before resizing.
        """
        # Build the pipeline once per instance instead of once per frame.
        img_transform = getattr(self , '_img_transform' , None)
        if img_transform is None:
            img_transform = transforms.Compose([
                    transforms.Resize((224, 384)),
                    transforms.ToTensor(),
                    # presumably the ImageNet channel mean/std — confirm against the model's training setup
                    transforms.Normalize(
                        [0.485, 0.456, 0.406],
                        [0.229, 0.224, 0.225]
                    )
            ])
            self._img_transform = img_transform
        sz = img.size
        img = img_transform(img)
        return img, sz

In [None]:
# Index every frame of every video in VIDEO_DIR (default 32-frame windows).
saliency_prediction_dataset = SaliencyPredictionDataset(
    video_dir=VIDEO_DIR,
    save_dir=OUTPUT_DIR,
)

In [None]:
# Inspect one sample. The index is clamped to the last sample so this cell
# also works when the videos yield fewer than 2031 samples (an empty dataset
# still raises IndexError).
sample_idx = min(2030 , len(saliency_prediction_dataset) - 1)
sample = saliency_prediction_dataset[sample_idx]
print("sample shape : " , sample[0].shape)
# sample[1] is the dataset index; entry 3 of the stored tuple is the save path.
print("sample save path : " , saliency_prediction_dataset.samples[sample[1]][3])

In [None]:
# Batch the windows in dataset order (no shuffling).
saliency_prediction_dataloader = DataLoader(
    dataset=saliency_prediction_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
# Pull the first batch from the dataloader and look at its layout.
batch_iterator = iter(saliency_prediction_dataloader)
sample = next(batch_iterator)
print(sample[0].shape) # (BATCH_SIZE , 3 , len_temporal , H , W)

In [None]:
# Build the saliency network and restore the trained weights.
# NOTE(review): num_clips=32 presumably has to match the dataset's len_temporal (default 32) — confirm.
model = VideoSaliencyModel(
    transformer_in_channel=32,
    nhead=4,
    use_upsample=True,
    num_hier=3,
    num_clips=32
)
# map_location remaps the checkpoint's tensors onto the selected device, so a
# checkpoint saved from a GPU run still loads when `device` is the CPU.
state_dict = torch.load(FILE_WEIGHT , map_location=device)
model.load_state_dict(state_dict)
model.to(device)
# Inference mode for layers such as dropout / batch-norm.
model.eval()

In [None]:
# GET A PREDICTION FOR THE FIRST BATCH
test_sample = next(iter(saliency_prediction_dataloader))
# Pure inference: disable autograd so no graph / activations are kept around.
with torch.no_grad():
    test_output = model(test_sample[0].to(device))
print(test_output.shape) # expected (BATCH_SIZE , 1 , H , W)

In [3]:
# Open one raw video and read its native resolution from the first frame.
clip = VideoFileClip('/ssd_scratch/cvit/sarthak395/outputs/IyMdcXl4vag/raw_videos/IyMdcXl4vag_1.mp4')

# get_frame returns an array indexed (rows , cols , ...) = (H , W , ...)
first_frame = clip.get_frame(0)
frame_height , frame_width = first_frame.shape[0] , first_frame.shape[1]
img_size = (frame_width , frame_height) # (W , H)
print(img_size)

{'video_found': True, 'audio_found': True, 'metadata': {'major_brand': 'isom', 'minor_version': '512', 'compatible_brands': 'isomiso2avc1mp41', 'encoder': 'Lavf61.1.100'}, 'inputs': [{'streams': [{'input_number': 0, 'stream_number': 0, 'stream_type': 'video', 'language': None, 'default': True, 'size': [854, 480], 'bitrate': 817, 'fps': 25.0, 'codec_name': 'h264', 'profile': '(High)', 'metadata': {'Metadata': '', 'handler_name': 'VideoHandler', 'vendor_id': '[0][0][0][0]', 'encoder': 'Lavc61.3.100 libx264'}}, {'input_number': 0, 'stream_number': 1, 'stream_type': 'audio', 'language': None, 'default': True, 'fps': 44100, 'bitrate': 128, 'metadata': {'Metadata': '', 'handler_name': 'SoundHandler', 'vendor_id': '[0][0][0][0]'}}], 'input_number': 0}], 'duration': 32.0, 'bitrate': 952, 'start': 0.0, 'default_video_input_number': 0, 'default_video_stream_number': 0, 'video_codec_name': 'h264', 'video_profile': '(High)', 'video_size': [854, 480], 'video_bitrate': 817, 'video_fps': 25.0, 'defau