In [1]:
import os
import numpy as np
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from cotracker.models.core.model_utils import get_points_on_a_grid
import cv2
import torch
import torch.nn.functional as F
from scipy.ndimage import zoom
from scipy.interpolate import griddata
from scipy.ndimage import map_coordinates
import math

#tried to clear memory because of OOM errors
import gc
torch.cuda.empty_cache()
gc.collect()

DEFAULT_DEVICE = (
    "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
)

In [2]:
def load_video(file, size=(256,256)):
    """
    Load a video from the given file path, resize it to the specified size,
    and return the video tensor and frames. 
    args:
        file (str): Path to the video file.
        size (tuple): Desired size for the video frames (width, height).
    returns:
        video (torch.Tensor): A tensor containing the video frames, resized to the specified size.
        video_frames (list): A list of resized video frames as numpy arrays.
    """
    video_path = file
    video = read_video_from_path(video_path)
    video_frames = [cv2.resize(frame, size) for frame in video]
    video = torch.from_numpy(video).permute(0, 3, 1, 2).float()
    
    video = F.interpolate(video, size=size, mode='bilinear', align_corners=False)[None]

    return video, video_frames

In [3]:
def cotrack_model(video, grid_size):
    """"Run the CoTracker model on the provided video and grid points.
    Model defines which cotracker model is used for tracking.
    args:
        video (torch.Tensor): A tensor containing the video frames.
        grid_size (int): The size of the grid for tracking.
        points_x (list): List of x-coordinates for reference points.
        points_y (list): List of y-coordinates for reference points.
    returns:
        pred_tracks (torch.Tensor): Predicted tracks from the model.
        pred_visibility (torch.Tensor): Predicted visibility from the model.
        grid_pts (torch.Tensor): Points on the grid used for tracking.
    """
    model = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").to("cuda")
    model = model.to(DEFAULT_DEVICE)
    video = video.to(DEFAULT_DEVICE)
    model.model.model_resolution = video.shape[3:]
    grid_pts = get_points_on_a_grid(
                grid_size, model.model.model_resolution
            )

    pred_tracks, pred_visibility = model(
        video,
        grid_size=grid_size,
        grid_query_frame=0,
        backward_tracking=True,
    )
    return pred_tracks, pred_visibility, grid_pts

In [4]:
def warping(predicted_tracks, frames):      
    """
    Warps the frames of a video based on predicted tracks.
    args:       
        predicted_tracks (torch.Tensor): Predicted tracks of shape (B, T, G, D).
        frames (torch.Tensor): Video frames of shape (T_orig, H, W, C).
    returns:
        warpeds (list): List of warped frames.
    """
    T_orig, H, W, C = frames.shape
    B, T, G, D = predicted_tracks.shape
    grid_size = int(math.sqrt(G))

    velocity = predicted_tracks[0].reshape(T, grid_size, grid_size, 2) #(24, 32, 32, 2)
    real_velocity = velocity-velocity[0] # (24, 32, 32, 2)
    v = real_velocity.transpose(0, 3, 1, 2) # (24, 2, 32, 32)
    vp = zoom(v, (1, 1, W/grid_size, H/grid_size))  #(24, 2, 256, 256) 
        
    warpeds = [frames[0][...,0]]

    for i in range(1,T_orig):
        grid_x, grid_y = np.meshgrid(np.arange(W), np.arange(H))
        grid_x = grid_x.astype(np.float32)
        grid_y = grid_y.astype(np.float32)

        phi = np.diff(vp,axis=0)[0:i].sum(0)
        grid_x += phi[0]
        grid_y += phi[1]
  
        warped = map_coordinates(frames[i][...,0].astype(np.float32), [grid_y, grid_x], order=3, mode='nearest')
        warpeds.append(warped)
       
    return np.array(warpeds)

In [8]:
torch.cuda.empty_cache()
gc.collect()

#define the video file path
video_file = "../../data/input/low_movement/1czi.tif"
if not os.path.exists(video_file):
    print(f"Video file {video_file} does not exist.")
else:
    print(f"Video file {video_file} exists.")

# Apply motion correction algorithm by
# loading the video, running the CoTracker model, and warping the frames

vid = load_video(video_file, size=(256, 256))
pred_tracks = cotrack_model(vid[0], grid_size=16)
result = warping(pred_tracks[0].cpu().numpy(), np.array(vid[1]))

Video file ../../data/input/low_movement/1czi.tif exists.


Using cache found in C:\Users\morit/.cache\torch\hub\facebookresearch_co-tracker_main


RuntimeError: shape '[1, 1, 3, 384, 512]' is invalid for input of size 100663296

In [None]:
# Print out warped results
for i in range(result.shape[0]):
    cv2.imwrite(f"output/frame_{i:04d}.png", result[i])  # Save each frame as an image
print("Warped frames saved to output directory.")

In [2]:
import numpy as np

nx, ny = (3, 2)

x = np.linspace(0, 1, nx)
print(x.shape)

y = np.linspace(0, 1, ny)
print(y.shape)

xv, yv = np.meshgrid(x, y)

xv.shape

(3,)
(2,)


(2, 3)