In [None]:
# Work from the refvos-ytvos code directory: later cells import `lib` and `args`
# from it and load the checkpoint via the relative path ./checkpoints/.
%cd /kaggle/input/refvos-yt/refvos-ytvos

In [None]:
# Pin the transformers / torch / torchvision versions this notebook targets
# (CUDA 12.1 wheels for torch and torchvision).
# NOTE(review): `%pip` installs into the running kernel's environment and is
# generally preferred over `!pip` inside notebooks.
!pip install transformers==4.38.2
!pip install torch==2.2.1+cu121 --extra-index-url https://download.pytorch.org/whl/cu121
!pip install torchvision==0.17.1+cu121 --extra-index-url https://download.pytorch.org/whl/cu121

In [None]:
# yt-dlp is the downloader invoked by download_video(). `-U` upgrades to the
# latest release, so this dependency is NOT pinned and may change between runs.
!pip install yt-dlp -q -U

In [None]:
# Sanity check: confirm which transformers version the kernel actually loaded
# (a pre-installed copy may need a kernel restart before the pin takes effect).
import transformers
print(transformers.__version__)

In [None]:
import torch
from transformers import BertModel, BertTokenizer
from lib import segmentation
from PIL import Image
import numpy as np
import os
import cv2
from torchvision import transforms as T
import time
import datetime
from scipy.signal import argrelextrema

In [None]:
import subprocess
def _yt_dlp_command(video_link, output_template):
    """Build the yt-dlp argv list used by download_video()."""
    return [
        'yt-dlp',
        # Prefer MP4 video + M4A audio; otherwise take the best available streams.
        '-f', 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/bestvideo+bestaudio',
        # The fallback streams may be WebM; force the merged container to MP4 so
        # the file always ends in .mp4, the extension downstream code reads.
        '--merge-output-format', 'mp4',
        '-o', output_template,
        # '--' ends option parsing so a link starting with '-' is never read
        # as a yt-dlp option.
        '--',
        video_link,
    ]


def download_video(video_link, output_template='/kaggle/working/input_video.%(ext)s'):
    """Download a video with yt-dlp, merged into an MP4 container.

    Args:
        video_link: URL of the video to download. It is passed to yt-dlp as a
            single argv element without a shell, so shell metacharacters in
            the link cannot inject commands.
        output_template: yt-dlp output template. The default produces
            /kaggle/working/input_video.mp4.

    Raises:
        subprocess.CalledProcessError: if yt-dlp exits with a non-zero status,
            so a failed download is not silently followed by processing a
            missing file.
    """
    subprocess.run(_yt_dlp_command(video_link, output_template), check=True)

In [None]:
def smooth(x, window_len=13, window='hanning'):
    """Smooth a 1-D signal by convolving it with a normalised window.

    Before convolving, the signal is extended at each end by
    ``window_len - 1`` samples of its odd (point) reflection about the end
    sample. This minimises transients at the beginning and end of the
    output and keeps a linear trend straight across the boundary.

    input:
        x: the input signal (1-D numpy array). The smoothing is computed in
            float64, so unsigned-integer input cannot wrap around during the
            ``2 * x[0] - x[k]`` reflection.
        window_len: the dimension of the smoothing window
        window: the type of window from 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'
            flat window will produce a moving average smoothing.

    output:
        the smoothed signal, with the same length as ``x``. If ``x`` is
        shorter than ``window_len``, or ``window_len < 3``, ``x`` itself is
        returned unchanged.

    raises:
        ValueError: if ``x`` is not 1-D or ``window`` is not one of the names above.

    example:

    import numpy as np
    t = np.linspace(-2, 2, 41)
    x = np.sin(t)+np.random.randn(len(t))*0.1
    y = smooth(x)

    see also:

    numpy.hanning, numpy.hamming, numpy.bartlett, numpy.blackman, numpy.convolve
    scipy.signal.lfilter

    TODO: the window parameter could be the window itself if an array instead of a string
    """
    if x.ndim != 1:
        raise ValueError("smooth only accepts 1 dimension arrays.")

    if window not in ('flat', 'hanning', 'hamming', 'bartlett', 'blackman'):
        raise ValueError("window must be one of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'")

    if x.size < window_len:
        return x

    if window_len < 3:
        return x

    # Work in float: integer (especially unsigned) input would overflow or
    # wrap in the reflection arithmetic below.
    xf = x.astype(np.float64)

    # Exactly window_len - 1 reflected samples per side, excluding the end
    # sample itself. The padded length is therefore always
    # xf.size + 2 * (window_len - 1), which also holds when xf.size == window_len.
    head = 2 * xf[0] - xf[window_len - 1:0:-1]
    tail = 2 * xf[-1] - xf[-2:-window_len - 1:-1]
    s = np.r_[head, xf, tail]

    if window == 'flat':  # moving average
        w = np.ones(window_len, 'd')
    else:
        w = getattr(np, window)(window_len)
    y = np.convolve(w / w.sum(), s, mode='same')
    # Drop the padding so output sample i is centred on input sample i.
    return y[window_len - 1:-window_len + 1]

In [None]:
class Frame:
    """A video frame image paired with its frame index."""

    def __init__(self, id, frame):
        # `id` shadows the builtin, but it is also the public attribute name.
        self.id, self.frame = id, frame
def extract_key_frames(video_path):
  """Decode a video and pick keyframes at peaks of the inter-frame difference.

  Each frame is converted to LUV and compared with the previous frame. The
  summed absolute difference forms a 1-D signal, which is smoothed with a
  5-sample window. For every local maximum at difference index k (the
  change between 0-based frames k and k + 1), frame k, the last frame
  before the change, is kept as a keyframe.

  Returns:
    (keyframes, fps, width, height, total_frames), where:
      - keyframes is a list of Frame objects, with id being the 0-based
        frame index and frame the BGR image from OpenCV;
      - fps is the container frame rate as a float;
      - total_frames is the number of frames actually decoded.

  Raises:
    IOError: if OpenCV cannot open ``video_path``.
  """
  cap = cv2.VideoCapture(video_path)
  if not cap.isOpened():
    raise IOError("Could not open video: {}".format(video_path))

  # frames[k] is video frame k + 1; frame_diffs[k] = |frame k+1 - frame k|.
  # NOTE(review): every decoded frame is kept in memory until keyframes are
  # chosen, so very long videos may exhaust RAM.
  frame_diffs = []
  frames = []
  try:
    # Keep the fractional rate (e.g. 29.97): truncating it to int makes the
    # rewritten video drift out of sync with the source.
    fps = cap.get(cv2.CAP_PROP_FPS)
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))

    prev_luv = None
    frame_idx = 0
    ret, img = cap.read()
    while ret:
      luv = cv2.cvtColor(img, cv2.COLOR_BGR2LUV)
      if prev_luv is not None:
        frame_diffs.append(np.sum(cv2.absdiff(luv, prev_luv)))
        frames.append(Frame(frame_idx, img))
      prev_luv = luv
      frame_idx += 1
      ret, img = cap.read()
  finally:
    cap.release()

  # Count what was decoded rather than trusting CAP_PROP_FRAME_COUNT, which
  # is container metadata; keyframe ids are based on the decoded index.
  total_frames = frame_idx

  # float64: np.sum of uint8 yields unsigned ints, which would wrap around in
  # smooth()'s `2 * x[0] - x[k]` edge reflection.
  diff_array = np.asarray(frame_diffs, dtype=np.float64)
  sm_diff_array = smooth(diff_array, 5)
  # argrelextrema (default mode='clip') never reports the first or last
  # index, so peak - 1 is always a valid index into `frames`.
  peaks = argrelextrema(sm_diff_array, np.greater)[0]
  keyframes = [frames[peak - 1] for peak in peaks]
  if not keyframes and frames:
    # No local maximum (e.g. a static or very short clip): fall back to the
    # last frame so there is still an image to fill the output video with.
    keyframes = [frames[-1]]

  print("# of frames : ", total_frames)
  print("# of keyframes : ", len(keyframes))
  return keyframes, fps, width, height, total_frames

In [None]:
def text_to_emb(tokenizer, exp, max_tokens=10):
  """Tokenize a referring expression into fixed-length input ids and mask.

  Args:
    tokenizer: object with an ``encode(text=..., add_special_tokens=True)``
      method returning a list of token ids (e.g. a BertTokenizer).
    exp: referring expression. It is lower-cased and runs of whitespace are
      collapsed to single spaces before encoding.
    max_tokens: fixed sequence length. Longer encodings are hard-cut to this
      length; for BERT this can drop the trailing [SEP] token, presumably
      matching how the checkpoint was trained (TODO confirm). Shorter
      encodings are right-padded with id 0.

  Returns:
    (emb, atten): tensors of shape (1, max_tokens) holding the token ids and
    the attention mask (1 for real tokens, 0 for padding).
  """
  exp = " ".join(exp.lower().split())
  input_ids = list(tokenizer.encode(text=exp, add_special_tokens=True))[:max_tokens]
  num_pad = max_tokens - len(input_ids)
  padded_input_ids = input_ids + [0] * num_pad
  attention_mask = [1] * len(input_ids) + [0] * num_pad
  emb = torch.tensor(padded_input_ids).unsqueeze(0)
  atten = torch.tensor(attention_mask).unsqueeze(0)
  return emb, atten

In [None]:
def evaluate(keyframes, fps, width, height, total_frames, model, bert_model, emb, atten, transformer, device, output_path):
  """Segment the referred object in each keyframe and write an overlay video.

  The referring expression is encoded once with ``bert_model``; the hidden
  state of the first token is used as the sentence embedding. Each keyframe
  is run through ``model``, and pixels whose per-pixel argmax class is
  non-zero are blended 50/50 with green.

  The output video holds one image per frame index:
    - every index up to and including a keyframe's ``id`` that has not yet
      been written shows that keyframe's overlay;
    - indices after the last keyframe repeat the last overlay, up to
      ``total_frames``.

  Args:
    keyframes: Frame objects, expected in increasing ``id`` order, whose
      ``frame`` is a BGR uint8 image as decoded by OpenCV. Their images are
      modified in place by the overlay.
    fps, width, height: output video frame rate and frame size.
    total_frames: number of frames the output video should contain.
    model: segmentation model called as ``model(image, embedding)`` and
      returning a 3-tuple whose first item is a dict with key 'out'.
    bert_model: text encoder whose first output is the last hidden state.
    emb, atten: token ids and attention mask, e.g. from text_to_emb();
      presumably shaped (1, T) or (B, 1, T), with dim 1 squeezed if it is a
      singleton. TODO confirm.
    transformer: callable mapping an RGB HWC image to a CHW tensor.
    device: torch device (or device string) to run on.
    output_path: path of the MP4 file to write.

  Raises:
    ValueError: if ``keyframes`` is empty while ``total_frames`` > 0, since
      there would be no image to write.
  """
  if not keyframes and total_frames > 0:
    raise ValueError("evaluate() needs at least one keyframe to write {} frames".format(total_frames))

  model.eval()
  # Disable dropout in the text encoder as well.
  bert_model.eval()
  color = np.array((0, 255, 0))
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
  out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
  try:
    with torch.no_grad():
      # The expression is the same for every keyframe, so encode it once.
      emb = emb.squeeze(1).to(device)
      atten = atten.squeeze(1).to(device)
      last_hidden_states = bert_model(emb, attention_mask=atten)[0]
      embedding = last_hidden_states[:, 0, :]

      written = 0
      ff = None
      for kf in keyframes:
        ff = kf.frame
        # OpenCV decodes BGR, but the model's ImageNet-style normalisation
        # expects RGB channel order; `ff` itself stays BGR for the writer.
        rgb = cv2.cvtColor(ff, cv2.COLOR_BGR2RGB)
        image = transformer(rgb).unsqueeze(0).to(device)
        output, _, _ = model(image, embedding)
        pred = output['out'].argmax(1).squeeze(0).cpu().numpy()
        fg = pred > 0.5
        ff[fg] = ff[fg] * 0.5 + color * 0.5
        while written <= kf.id:
          out.write(ff)
          written += 1
      # Pad the tail of the video with the last overlay.
      while written < total_frames:
        out.write(ff)
        written += 1
  finally:
    out.release()

In [None]:
def get_transform():
  """Build the preprocessing pipeline applied to every model input frame.

  The pipeline converts an HWC image (PIL image or uint8 array) into a CHW
  float tensor in [0, 1], then normalises each channel with the standard
  ImageNet mean and std.
  """
  channel_mean = [0.485, 0.456, 0.406]
  channel_std = [0.229, 0.224, 0.225]
  return T.Compose([
      T.ToTensor(),
      T.Normalize(mean=channel_mean, std=channel_std),
  ])

In [None]:
def _print_stage_time(stage, start_time):
  """Print the whole-second wall-clock time elapsed since start_time for a stage."""
  elapsed = str(datetime.timedelta(seconds=int(time.time() - start_time)))
  print('{} Total time: {}'.format(stage, elapsed))


def _load_models(checkpoint_path):
  """Build the segmentation model and BERT encoder, load their weights, and get a tokenizer.

  Returns:
    (model, bert_model, tokenizer, device)
  """
  # `args` is the repository's own argument module, importable from the
  # working directory.
  from args import get_parser
  # Pass argv as a list so a checkpoint path containing spaces survives.
  args = get_parser().parse_args(['--resume', checkpoint_path])
  device = torch.device(args.device)
  model = segmentation.__dict__[args.model](num_classes=2,
      aux_loss=False,
      pretrained=False,
      args=args)
  model.to(device)
  bert_model = BertModel.from_pretrained(args.ck_bert)
  bert_model.to(device)
  # NOTE(review): torch.load unpickles the file, which can execute arbitrary
  # code; only load checkpoints from a trusted source.
  checkpoint = torch.load(args.resume, map_location='cpu')
  bert_model.load_state_dict(checkpoint['bert_model'])
  model.load_state_dict(checkpoint['model'])
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
  return model, bert_model, tokenizer, device


def main(video_path, text_input, output_path, checkpoint_path='./checkpoints/model_davis.pth'):
  """Run the full pipeline: keyframing, model loading, then segmentation video output.

  Args:
    video_path: input video file.
    text_input: referring expression describing the object to segment.
    output_path: where the overlay video is written.
    checkpoint_path: RefVOS checkpoint holding 'model' and 'bert_model'
      state dicts. The default path is relative to the working directory.
  """
  start_time = time.time()
  keyframes, fps, width, height, total_frames = extract_key_frames(video_path)
  _print_stage_time('keyframing', start_time)

  start_time = time.time()
  model, bert_model, tokenizer, device = _load_models(checkpoint_path)
  _print_stage_time('loading models', start_time)

  start_time = time.time()
  transformer = get_transform()
  emb, atten = text_to_emb(tokenizer, text_input)
  evaluate(keyframes, fps, width, height, total_frames, model, bert_model, emb, atten, transformer, device, output_path)
  _print_stage_time('searching', start_time)

In [None]:
import json

# The job file is a JSON object holding the video URL ('link') and the
# referring expression to segment ('query').
# `with` closes the file even if json.load raises on malformed input.
with open('/kaggle/input/youtube-links/video_link.json') as f:
    data = json.load(f)
video_link = data['link']
query = data['query']

download_video(video_link)


In [None]:
# Segment the object described by `query` in the downloaded video and write
# the overlay video next to it.
main(
    video_path="/kaggle/working/input_video.mp4",
    text_input=query,
    output_path="/kaggle/working/output_video.mp4",
)