In [None]:
from utils import *
from utils.bisector_utils import *
from utils.other_utils import *
from utils.spectral import *
from utils.sam_utils import *
from utils.visualization_utils import *
import numpy as np

In [None]:
import torch
import numpy as np
from moge.model.v2 import MoGeModel
import os

# Prefer the GPU, but fall back to CPU so the notebook can still load the model on
# machines without CUDA (inference will be much slower there).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").to(device)

def get_3d_coordinates_for_frame(frames, model, device, frame_idx):
    """Run MoGe on a single frame and return its point map, mask and intrinsics.

    Args:
        frames: indexable collection of images. ``frames[frame_idx]`` is divided by
            255 and permuted with ``(2, 0, 1)``, so it is presumably an HxWx3 array
            with 0-255 values (TODO confirm).
        model: object exposing ``infer(image)`` that returns a mapping with the keys
            ``"points"``, ``"mask"`` and ``"intrinsics"`` (here a ``MoGeModel``).
        device: torch device on which the input tensor is created.
        frame_idx: index of the frame to process.

    Returns:
        Tuple ``(points, mask, intrinsics)`` taken from the ``model.infer`` output.
    """
    scaled_frame = frames[frame_idx] / 255
    frame_tensor = torch.tensor(scaled_frame, dtype=torch.float32, device=device)
    chw_input = frame_tensor.permute(2, 0, 1)  # channels-first layout for the model

    prediction = model.infer(chw_input)
    return tuple(prediction[key] for key in ("points", "mask", "intrinsics"))

In [None]:
from pnp_RANSAC_first_frame import main

## Get frames

In [None]:
import pickle

# NOTE: pickle.load can execute arbitrary code; only load frame dumps from a trusted source.
with open("Data/frames_0.pkl", "rb") as f:
    frames = pickle.load(f)

In [None]:
# Set to True to load the example dump. It is large, and `example` is not referenced
# by any other cell shown in this notebook.
example_is_needed = False

if example_is_needed:
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open("Data/example_0.pkl", "rb") as f:
        example = pickle.load(f)

## Experiment

In [None]:
import copy
import os

# --- Configuration ----------------------------------------------------------------
grid_size = 20
num_initializations = 5
result_frequency = 1
max_iterations = 10  # NOTE(review): not passed to `main` anywhere in this cell

first_tresh = 10.0   # RANSAC threshold for the first pass
other_tresh = 5.0    # RANSAC threshold for the follow-up pass(es)
third_iter = False   # set True to run one more pass on what pass 2 left over

vid = 'video_0'
path = vid
# Later cells write GIF/JSON outputs under `path`; make sure the directory exists.
os.makedirs(path, exist_ok=True)

# Entry of the `main` results that holds the per-pass details read below.
RESULTS_KEY = ('track_based', 'frame_zero_to_all')


def _frame_pair_details(results):
    """Return ``results[RESULTS_KEY][2]['frame_pair_details']``.

    The code below treats a falsy value as "this pass produced nothing".
    """
    return results[RESULTS_KEY][2]['frame_pair_details']


def _mean_track_error(details, tracks_key):
    """Mean of ``details['global_track_errors']`` over the distinct ids in ``details[tracks_key]``.

    Ids without a recorded error are skipped; returns NaN when none remain.
    """
    errors = details['global_track_errors']
    # dict.fromkeys de-duplicates ids while keeping their order.
    values = [errors[tid] for tid in dict.fromkeys(details[tracks_key]) if tid in errors]
    return np.mean(values) if values else np.nan


resss = []  # results of every pass, in order

# --- Pass 1: fit on all tracks -----------------------------------------------------
all_results, frames, pred_tracks, pred_visibility = main(
    frames=frames, grid_size=grid_size, num_initializations=num_initializations,
    ransac_threshold=first_tresh, result_frequency=result_frequency,
)
resss.append(all_results)
outlier_tracks = [get_outlier_tracks(all_results)]

# --- Pass 2: re-run with visibility derived from pass-1 results --------------------
# Inputs are deep-copied so that `pred_visibility` / `all_results`, which later cells
# use, cannot be mutated by the helper.
new_pred_visibility = create_modified_visibility(
    copy.deepcopy(pred_visibility), copy.deepcopy(all_results)
)

updated_all_results, frames, pred_tracks, updated_pred_visibility = main(
    frames=frames, grid_size=grid_size, num_initializations=num_initializations,
    ransac_threshold=other_tresh, result_frequency=result_frequency,
    pred_tracks=pred_tracks, pred_visibility=new_pred_visibility, verbose=True, use_sam=True,
)
resss.append(updated_all_results)

if _frame_pair_details(updated_all_results):
    outlier_tracks.append(get_outlier_tracks(updated_all_results))

# --- Optional pass 3 ---------------------------------------------------------------
if third_iter:
    # Capture the previous pass's details before `updated_all_results` is overwritten.
    previous_details = _frame_pair_details(updated_all_results)
    # Mean error of the tracks the previous pass rejected, under that pass's model.
    error_old_to_compare = (
        _mean_track_error(previous_details, 'global_outlier_tracks')
        if previous_details else np.nan
    )

    new_pred_visibility = create_modified_visibility(
        copy.deepcopy(updated_pred_visibility), copy.deepcopy(updated_all_results)
    )

    updated_all_results, frames, pred_tracks, updated_pred_visibility = main(
        frames=frames, grid_size=grid_size, num_initializations=num_initializations,
        ransac_threshold=other_tresh, result_frequency=result_frequency,
        pred_tracks=pred_tracks, pred_visibility=new_pred_visibility, verbose=True, use_sam=True,
    )
    resss.append(updated_all_results)

    latest_details = _frame_pair_details(updated_all_results)
    if latest_details:
        # Mean error of the tracks this pass accepted, to compare with error_old_to_compare.
        error_new_to_compare = _mean_track_error(latest_details, 'global_inlier_tracks')
        outlier_tracks.append(get_outlier_tracks(updated_all_results))
    else:
        error_new_to_compare = np.nan
        print('EMPTY. STOPPING')

## Obtain GIF

In [None]:
# Render the tracks together with the per-pass `outlier_tracks` as an animated GIF
# written to <path>/<vid>.gif. `duration` is presumably milliseconds per frame (TODO confirm).
# NOTE(review): a `figsize=(12, 8)` argument was previously commented out here;
# presumably it can be passed to override the figure size.
create_hierarchical_outliers_gif(
    frames=frames,
    pred_tracks=pred_tracks,
    pred_visibility=pred_visibility,
    outlier_tracks=outlier_tracks,
    output_path=f'{path}/{vid}.gif',
    duration=600,
)

## Export JSON file for 3D visualization

In [None]:
# Export the first frames' 3D points, plus the tracks labelled from `outlier_tracks`,
# to a JSON file for the web viewer. The frame count is capped at the number of frames
# actually available, so short clips do not request out-of-range indices.
n_3d_frames = 10
frame_indices = list(range(min(n_3d_frames, len(frames))))
labels = create_labels_from_outlier_tracks(outlier_tracks, pred_tracks.shape)

save_multi_frame_points_for_web_with_colored_tracks_and_error_lines(
    frames, model, device,  # use the notebook-wide device instead of a hard-coded 'cuda'
    path + "/visualization_3d.json",
    frame_indices=frame_indices,
    max_points_per_frame=15000,
    verbose=True,
    pred_tracks=pred_tracks,
    pred_visibility=pred_visibility,
    track_labels=labels,
    all_results=all_results,  # results from the first pass
)

## Obtain Reprojection Error Results

In [None]:
# Number of leading frames in the reprojection-error preview. Frames, tracks and
# visibility are all sliced with this one value so they stay aligned (axis 1 of the
# track/visibility arrays is presumably the time axis — TODO confirm).
n_preview_frames = 5

create_reprojection_error_visualization(
    frames=frames[:n_preview_frames],
    pred_tracks=pred_tracks[:, :n_preview_frames, :, :],
    pred_visibility=pred_visibility[:, :n_preview_frames, :],
    outlier_tracks=outlier_tracks,
    all_results=all_results,
    model=model,
    device=device,
    output_path=path + '/reprojection_errors.gif',
    duration=600,
    error_threshold=10.0,
    show_error_lines=True,
    show_inliers=True,
    show_outliers=True,
)

In [None]:
# Reprojection-error view across all passes: `resss` holds one results object per pass.
# NOTE(review): `outlier_tracks` only gets an entry for a later pass when that pass
# produced non-empty details, whereas `resss` gets one for every pass. The two lists
# can therefore differ in length; confirm the helper handles that.
create_multi_iteration_reprojection_visualization(
    frames=frames,
    pred_tracks=pred_tracks,
    pred_visibility=pred_visibility,
    outlier_tracks=outlier_tracks,
    all_results_list=resss,
    model=model,
    device=device,
    output_path=path + '/multi_iteration_reprojection_errors.gif',
    duration=600,
    error_threshold=10.0,
    show_error_lines=True,
)

## Obtain Spectral Clustering Results

In [None]:
# Spectral-clustering baseline. This overwrites `labels` from the 3D-export cell.
# '0debug' is passed through unchanged (presumably a run tag — TODO confirm).
labels = run_spectral(frames, '0debug', model, device, path)

## Run SAM-based post-processing

In [None]:
# Number of leading frames passed to the SAM analysis. Frames, tracks and visibility
# are all sliced with this one value so they stay aligned.
n_preview_frames = 5

results = run_sam_outlier_analysis(
    frames=frames[:n_preview_frames],
    tracks=pred_tracks[:, :n_preview_frames, :, :],
    visibility=pred_visibility[:, :n_preview_frames, :],
    outlier_tracks_level0=get_outlier_tracks(all_results),  # outliers of the first pass
    sam_checkpoint_path="sam_vit_b_01ec64.pth",
    min_outliers=10,
    outlier_ratio=0.5,
    output_dir=path,
)