In [1]:
import torch
import ffmpeg
from transnet import TransNetV2
from inference import predict_video

model = TransNetV2()
state_dict = torch.load("transnetv2-pytorch-weights.pth")
model.load_state_dict(state_dict)


  state_dict = torch.load("transnetv2-pytorch-weights.pth")


<All keys matched successfully>

In [2]:
with torch.no_grad():
    # shape: batch dim x video frames x frame height x frame width x RGB (not BGR) channels
    scenes = predict_video('../../Dataset/00100.mp4', model)
    print("Detected scenes:", scenes)

[TransNetV2] Processing video frames 6480/6480Detected scenes: [[   0  433]
 [ 434  495]
 [ 496  553]
 [ 554  646]
 [ 647  809]
 [ 810  972]
 [ 973 1041]
 [1042 1110]
 [1111 1217]
 [1218 1497]
 [1498 1564]
 [1565 1617]
 [1618 1682]
 [1683 1725]
 [1726 2170]
 [2171 2212]
 [2213 2279]
 [2280 2338]
 [2339 2424]
 [2425 2572]
 [2573 2669]
 [2670 2754]
 [2755 2826]
 [2827 3236]
 [3237 3341]
 [3342 4011]
 [4012 4083]
 [4084 4144]
 [4145 4204]
 [4205 4299]
 [4300 4387]
 [4388 4413]
 [4414 4471]
 [4472 4579]
 [4580 4639]
 [4640 4832]
 [4833 4896]
 [4897 5115]
 [5116 5252]
 [5253 5444]
 [5445 5557]
 [5558 5658]
 [5659 5677]
 [5678 5770]
 [5771 5871]
 [5872 6012]
 [6013 6070]
 [6071 6256]
 [6257 6448]
 [6449 6479]]


In [3]:
import cv2
print(cv2.__version__)

4.11.0


In [4]:
import cv2
import numpy as np
import torch
import os
from PIL import Image, ImageTk
import tkinter as tk
from tkinter import ttk

def extract_frame(video_path, frame_idx):
    cap = cv2.VideoCapture(video_path)
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return Image.fromarray(frame_rgb)

def visualize_shots_scrollable(video_path, scenes_data):
    if isinstance(scenes_data, torch.Tensor):
        scenes_data = scenes_data.cpu().numpy()
    elif isinstance(scenes_data, list):
        scenes_data = np.array(scenes_data)

    root = tk.Tk()
    root.title("Shot Boundary Viewer")
    def on_closing():
        root.quit()
        root.destroy()

    root.protocol("WM_DELETE_WINDOW", on_closing)

    main_frame = ttk.Frame(root)
    main_frame.pack(fill=tk.BOTH, expand=True)

    canvas = tk.Canvas(main_frame)
    scrollbar = ttk.Scrollbar(main_frame, orient=tk.VERTICAL, command=canvas.yview)
    scrollable_frame = ttk.Frame(canvas)

    scrollable_frame.bind(
        "<Configure>",
        lambda e: canvas.configure(
            scrollregion=canvas.bbox("all")
        )
    )

    canvas.create_window((0, 0), window=scrollable_frame, anchor="nw")
    canvas.configure(yscrollcommand=scrollbar.set)

    canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
    scrollbar.pack(side=tk.RIGHT, fill=tk.Y)

    for idx, (start, end) in enumerate(scenes_data):
        start_img = extract_frame(video_path, int(start))
        end_img = extract_frame(video_path, int(end))

        if start_img is None or end_img is None:
            continue

        start_img_tk = ImageTk.PhotoImage(start_img.resize((160, 90)))
        end_img_tk = ImageTk.PhotoImage(end_img.resize((160, 90)))

        pair_frame = ttk.Frame(scrollable_frame, padding=10)
        pair_frame.pack(fill=tk.X)

        ttk.Label(pair_frame, text=f"Shot {idx+1}: [{start}-{end}]").pack(anchor="w")

        images_frame = ttk.Frame(pair_frame)
        images_frame.pack()

        ttk.Label(images_frame, image=start_img_tk).pack(side=tk.LEFT, padx=5)
        ttk.Label(images_frame, image=end_img_tk).pack(side=tk.LEFT, padx=5)

        # Store reference to avoid garbage collection
        pair_frame.start_img_tk = start_img_tk
        pair_frame.end_img_tk = end_img_tk

    root.mainloop()

if __name__ == '__main__':
    video_file = '../../Dataset/00001.mp4'
    #example_scenes = [[0, 433], [434, 495], [496, 553], [554, 646]]
    with torch.no_grad():
    # shape: batch dim x video frames x frame height x frame width x RGB (not BGR) channels
        scenes = predict_video(video_file, model)
    scenes_np = np.array(scenes)

    if os.path.exists(video_file):
        visualize_shots_scrollable(video_file, scenes_np)
    else:
        print(f"Video file not found: {video_file}")


[TransNetV2] Processing video frames 6824/6824

In [64]:
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)

image = preprocess(Image.open("cat.jpg")).unsqueeze(0).to(device)
#image1 = preprocess(Image.open("car.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    print("Image features:", image_features.shape)
    print("Text features:", text_features.shape)
    im = torch.nn.functional.normalize(image_features[0], dim=0)
    print(im.min(), im.max(), im.mean())
    print("Diff: ", torch.nn.functional.cosine_similarity(im , torch.nn.functional.normalize(text_features, dim=0), dim=-1).item())
    
    logits_per_image, logits_per_text = model(torch.cat([image, image1], dim=0), text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    probs1 = logits_per_text.softmax(dim=-1).cpu().numpy()
    print("Image features:", probs1)

print("Top predictions:", probs)


Image features: torch.Size([1, 768])
Text features: torch.Size([1, 768])
tensor(-0.4855) tensor(0.5251) tensor(-0.0007)
Diff:  0.17194582521915436
Image features: [[9.9999309e-01 6.9734533e-06]]
Top predictions: [[1.]
 [1.]]


In [81]:
import faiss
import numpy as np
import scipy
import faiss.contrib.torch_utils

x_torch = torch.nn.functional.normalize(image_features, p=2, dim=1)
q_torch = torch.nn.functional.normalize(text_features, p=2, dim=1)

# If you need them back as numpy arrays for faiss
x = x_torch.numpy(force=True)
q = q_torch.numpy(force=True)

index = faiss.IndexFlatIP(x.shape[1])
index.add(x)
distance, idx = index.search(q, 5)
print('Distance by FAISS:{}'.format(distance))

# Cosine similarity for comparison
result = 1 - scipy.spatial.distance.cosine(x[0], q[0])
print('Cosine similarity:', result)

ValueError: input not a numpy array