In [1]:
import PIL
from tqdm import tqdm
import numpy as np
import argparse
import cv2
import os
import torch
from interact_tools import SamControler
from tracker.base_tracker import BaseTracker

In [2]:
class TrackingAnything():
    """Bundles a ``SamControler`` (click-driven first-frame segmentation) with a
    ``BaseTracker`` (mask propagation over the following frames).

    Both collaborators are project-local classes; their exact semantics are
    defined in ``interact_tools`` and ``tracker.base_tracker``.
    """

    def __init__(self,
                 sam_checkpoint="./sam_checkpoint/sam_vit_h_4b8939.pth",
                 sam_model_type="vit_h",
                 xmem_checkpoint="./xmem_checkpoint/XMem.pth",
                 device="cpu"):
        """Build the two underlying models.

        All arguments are optional; the defaults reproduce the values that
        were previously hard-coded, so ``TrackingAnything()`` behaves as before.

        Args:
            sam_checkpoint: first positional argument handed to ``SamControler``.
            sam_model_type: second positional argument handed to ``SamControler``.
            xmem_checkpoint: first positional argument handed to ``BaseTracker``.
            device: passed as the third positional argument of ``SamControler``
                and as ``device=`` of ``BaseTracker`` (presumably a torch device
                string such as "cpu" or "cuda" -- confirm against those classes).
        """
        self.samcontroler = SamControler(sam_checkpoint, sam_model_type, device)
        self.xmem = BaseTracker(xmem_checkpoint, device=device)

    def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
        """Forward a click prompt to ``SamControler.first_frame_click``.

        Returns:
            The ``(mask, logit, painted_image)`` triple produced by the controller.
        """
        mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
        return mask, logit, painted_image

    def generator(self, images: list, template_mask:np.ndarray):
        """Run the tracker over ``images`` in order.

        The first frame is tracked together with ``template_mask``; every later
        frame is passed to ``self.xmem.track`` on its own.

        Args:
            images: sequence of frames, processed from index 0 upwards.
            template_mask: mask supplied alongside the first frame only.

        Returns:
            Three lists ``(masks, logits, painted_images)`` with one entry per
            input frame (all empty when ``images`` is empty).
        """
        masks = []
        logits = []
        painted_images = []
        for index, frame in enumerate(tqdm(images, desc="Tracking image")):
            if index == 0:
                # Only the first call receives the template mask.
                mask, logit, painted_image = self.xmem.track(frame, template_mask)
            else:
                mask, logit, painted_image = self.xmem.track(frame)
            masks.append(mask)
            logits.append(logit)
            painted_images.append(painted_image)
        return masks, logits, painted_images

In [4]:
# Dump every frame of the training video to numbered JPEG files
# (1.jpg, 2.jpg, ...) inside output_folder.
video_path = '.././dataset/train01.mp4'
output_folder = './train01_frame'
os.makedirs(output_folder, exist_ok=True)
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    # Raise instead of calling exit(): inside a Jupyter kernel exit() asks the
    # kernel itself to shut down, which throws away all notebook state.
    raise RuntimeError(f"Error: Could not open video: {video_path}")

# File names are 1-based: the first decoded frame is written as 1.jpg.
frame_count = 1

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # No more frames could be read.
            break
        frame_filename = os.path.join(output_folder, f'{frame_count}.jpg')
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(frame_filename, frame):
            raise IOError(f"Could not write frame {frame_count} to {frame_filename}")
        frame_count += 1
finally:
    # Release the capture handle even if reading or writing failed.
    cap.release()

print("Video processing complete.")

Video processing complete.


In [7]:
import pandas as pd

# Load the ground-truth annotation table for train01 and print the total of
# every column.
df = pd.read_csv('.././ground_truth/CATARACTS_2017/train_gt/train01.csv')
column_names = df.columns
for column_name in column_names:
    # Deliberately not called `sum`: that would shadow the builtin sum() for
    # every later cell run in this kernel.
    column_total = df[column_name].sum()
    print(f"Sum of {column_name} column: {column_total}")

Sum of Frame column: 103399390
Sum of biomarker column: 0
Sum of Charleux canula column: 0
Sum of hydrodissection canula column: 190.0
Sum of Rycroft canula column: 802.5
Sum of viscoelastic cannula column: 223.5
Sum of cotton column: 0
Sum of capsulorhexis cystotome column: 618.0
Sum of Bonn forceps column: 180.0
Sum of capsulorhexis forceps column: 59.5
Sum of Troutman forceps column: 0
Sum of needle holder column: 0
Sum of irrigation/aspiration handpiece column: 2970.0
Sum of phacoemulsifier handpiece column: 2036.0
Sum of vitrectomy handpiece column: 0
Sum of implant injector column: 219.0
Sum of primary incision knife column: 77.5
Sum of secondary incision knife column: 109.5
Sum of micromanipulator column: 3252.5
Sum of suture needle column: 0
Sum of Mendez ring column: 0
Sum of Vannas scissors column: 0


In [8]:
# Rows of the annotation table whose 'hydrodissection canula' entry is not zero.
df_hc = df.loc[df['hydrodissection canula'].ne(0)]
df_hc

Unnamed: 0,Frame,biomarker,Charleux canula,hydrodissection canula,Rycroft canula,viscoelastic cannula,cotton,capsulorhexis cystotome,Bonn forceps,capsulorhexis forceps,...,irrigation/aspiration handpiece,phacoemulsifier handpiece,vitrectomy handpiece,implant injector,primary incision knife,secondary incision knife,micromanipulator,suture needle,Mendez ring,Vannas scissors
2259,2260,0,0,0.5,0.0,0.0,0,0.0,0.0,0.0,...,0.0,0.0,0,0.0,0.0,0.0,0.0,0,0,0
2260,2261,0,0,0.5,0.0,0.0,0,0.0,0.0,0.0,...,0.0,0.0,0,0.0,0.0,0.0,0.0,0,0,0
2261,2262,0,0,0.5,0.0,0.0,0,0.0,0.0,0.0,...,0.0,0.0,0,0.0,0.0,0.0,0.0,0,0,0
2262,2263,0,0,0.5,0.0,0.0,0,0.0,0.0,0.0,...,0.0,0.0,0,0.0,0.0,0.0,0.0,0,0,0
2263,2264,0,0,1.0,0.0,0.0,0,0.0,0.0,0.0,...,0.0,0.0,0,0.0,0.0,0.0,0.0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2447,2448,0,0,1.0,0.0,0.0,0,0.0,0.0,0.0,...,0.0,0.0,0,0.0,0.0,0.0,0.0,0,0,0
2448,2449,0,0,1.0,0.0,0.0,0,0.0,0.0,0.0,...,0.0,0.0,0,0.0,0.0,0.0,0.0,0,0,0
2449,2450,0,0,1.0,0.0,0.0,0,0.0,0.0,0.0,...,0.0,0.0,0,0.0,0.0,0.0,0.0,0,0,0
2450,2451,0,0,0.5,0.0,0.0,0,0.0,0.0,0.0,...,0.0,0.0,0,0.0,0.0,0.0,0.0,0,0,0


In [None]:
# Load the extracted frames from output_folder, ordered by their numeric file
# name, and keep them in memory as RGB arrays.
file_list = os.listdir(output_folder)

allowed_extensions = ['.jpg']
filtered_files = [f for f in file_list if os.path.splitext(f)[1].lower() in allowed_extensions]
# Sort numerically (1, 2, ..., 10) rather than lexicographically (1, 10, 2).
sorted_files = sorted(filtered_files, key=lambda x: int(os.path.splitext(x)[0]))

image_list = []

for file_name in sorted_files:
    # Join with the same folder that was listed above.
    file_path = os.path.join(output_folder, file_name)
    img = cv2.imread(file_path)
    # Check for a failed read BEFORE converting: cvtColor cannot take None.
    if img is None:
        print(f"Error loading {file_name}")
        continue
    image_list.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

print(f"Total images loaded: {len(image_list)}")

In [None]:
import cv2
import matplotlib.pyplot as plt

# Preview the candidate click points on the first frame.
# `image` stays a reference to the untouched first frame.
image = image_list[0]

points = [(580, 600), (450, 500), (1340, 690), (1500, 1000)]

# Draw on a copy so image_list[0] is not permanently marked (and so
# re-running this cell with other points does not accumulate dots).
image_rgb = image.copy()

for point in points:
    # Frames in image_list are already RGB, so red is (255, 0, 0) here and no
    # further BGR->RGB conversion is needed before plotting.
    cv2.circle(image_rgb, point, radius=10, color=(255, 0, 0), thickness=-1)

plt.imshow(image_rgb)
plt.title("Marked Points")
plt.axis('on')
plt.show()