# In [None]:
import os
import cv2
import torch
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from sam2.build_sam import build_sam2_video_predictor
import numpy as np
from segmentation_module import SegmentationModule
from video_processor import VideoProcessor

# Initialize Flask and enable CORS on all routes (flask_cors defaults) so a
# separately served front end can call this API.
app = Flask(__name__)
CORS(app)

# Storage locations and model assets.
UPLOAD_FOLDER = 'uploads/'
FRAME_FOLDER = 'frames/'
MODEL_CHECKPOINT = 'models/sam2_hiera_large.pt'
# SAM 2 pairs the "large" checkpoint with the sam2_hiera_l.yaml config
# (lowercase letter L); "sam2_hiera_1.yaml" (digit one) matches no SAM 2 config.
# NOTE(review): this path uses 'model/' while the checkpoint lives in
# 'models/' -- confirm which directory actually holds the config.
MODEL_CONFIG = 'model/sam2_hiera_l.yaml'

# Create the working folders if they don't exist. exist_ok=True avoids the
# check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(FRAME_FOLDER, exist_ok=True)

# Initialize the Video Processor and Segmentation Module once at import time;
# the request handlers in this module share these instances.
video_processor = VideoProcessor()
segmentation_module = SegmentationModule()

def _safe_upload_name(filename):
    """Return a filename that is safe to join onto UPLOAD_FOLDER, or None.

    ``filename`` comes from the client, so any directory components are
    stripped; this keeps a name such as ``../../app.py`` from escaping the
    upload folder. Empty names and the special entries ``.``/``..`` are
    rejected because they do not name a file.
    """
    if not filename:
        return None
    name = os.path.basename(filename)
    if name in ('', '.', '..'):
        return None
    return name


# Route to upload video
@app.route('/upload', methods=['POST'])
def upload_video():
    """Save an uploaded video and extract its frames.

    Expects a multipart form with the video in the ``video`` field. The file is
    written into UPLOAD_FOLDER under a sanitized name and then passed to
    ``video_processor.video_to_frames``.

    Returns:
        A JSON success message (200), or a JSON error with status 400 when no
        file field is present or the submitted filename is empty/unusable.
    """
    if 'video' not in request.files:
        return jsonify({'error': 'No video file provided'}), 400

    file = request.files['video']
    filename = _safe_upload_name(file.filename)
    if filename is None:
        # Browsers send an empty filename when the user picked no file; saving
        # it would target the upload directory itself.
        return jsonify({'error': 'No valid video filename provided'}), 400

    file_path = os.path.join(UPLOAD_FOLDER, filename)
    file.save(file_path)

    # Extract frames
    video_processor.video_to_frames(file_path)

    return jsonify({'message': 'Video uploaded and frames extracted successfully'})

def _to_list(value):
    """Convert ``value`` to a plain list for JSON serialization.

    Objects exposing ``tolist`` (numpy arrays, torch tensors) use it; any
    other iterable, such as a Python list, is copied with ``list``.
    """
    if hasattr(value, 'tolist'):
        return value.tolist()
    return list(value)


# Route to predict a mask from point clicks on one frame
@app.route('/predict_mask', methods=['POST'])
def predict_mask():
    """Predict the mask of one object from point clicks on a single frame.

    Expects a JSON object with the keys ``frame_idx``, ``points``, ``labels``
    and ``obj_id``. ``points`` and ``labels`` are converted with ``np.array``
    before being passed to ``segmentation_module.add_click``. Presumably
    ``points`` is a list of [x, y] pairs and ``labels`` one label per point --
    TODO confirm against the front end.

    Returns:
        ``{'mask_logits': [...], 'obj_ids': [...]}`` with status 200, or a
        JSON error with status 400 when the body is not a JSON object or a
        required key is missing.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    required = ('frame_idx', 'points', 'labels', 'obj_id')
    missing = [key for key in required if key not in data]
    if missing:
        return jsonify({'error': 'Missing required field(s): ' + ', '.join(missing)}), 400

    frame_idx = data['frame_idx']
    points = np.array(data['points'])
    labels = np.array(data['labels'])
    obj_id = data['obj_id']

    # Perform mask prediction
    out_obj_ids, out_mask_logits = segmentation_module.add_click(frame_idx, obj_id, points, labels)

    # add_click's return types are not visible here; SAM 2's video predictor
    # may report object ids as a plain list (no .tolist()), so convert
    # tolerantly.
    return jsonify({'mask_logits': _to_list(out_mask_logits), 'obj_ids': _to_list(out_obj_ids)}), 200

# Route to propagate mask
@app.route('/propagate_mask', methods=['POST'])
def propagate_mask():
    """Propagate the current masks over all frames and return them as JSON.

    Delegates entirely to ``segmentation_module.propagate_segmentation()``
    and serializes its result with ``jsonify``, responding with status 200.
    """
    return jsonify(segmentation_module.propagate_segmentation()), 200

# Main driver
if __name__ == '__main__':
    # Debug stays on by default for local development, as before; set
    # FLASK_DEBUG to 0/false/no to turn it off. Never expose the debug server
    # on a public interface.
    debug = os.environ.get('FLASK_DEBUG', '1').lower() not in ('0', 'false', 'no')
    # With the reloader enabled, this module is executed again in a child
    # process, which would construct VideoProcessor/SegmentationModule a
    # second time. Disable it so they are only built once.
    app.run(debug=debug, port=5000, use_reloader=False)
