In [None]:
# 16  CELL 1: Setup and Image Loading
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Initialize the SAM2 predictor from the locally downloaded checkpoint, falling
# back to the Hugging Face hub. Later cells reuse `device` and `predictor`.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Written by the hf_hub_download cell (# 15); run that cell first on a fresh machine.
checkpoint_path = "./checkpoints/sam2.1_hiera_tiny.pt"
# Hydra config matching the *tiny* SAM 2.1 checkpoint. build_sam2 resolves config
# names inside the installed sam2 package, where SAM 2.1 configs sit under
# configs/sam2.1/ (a bare "sam2.1_hiera_t.yaml" is not found there).
SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_t.yaml"
SAM2_HF_MODEL_ID = "facebook/sam2.1-hiera-tiny"

try:
    # Fail early with a readable message instead of an opaque loader error.
    if not Path(checkpoint_path).is_file():
        raise FileNotFoundError(f"checkpoint not found at {checkpoint_path}")
    sam2_model = build_sam2(SAM2_CONFIG, checkpoint_path, device=device)
    predictor = SAM2ImagePredictor(sam2_model)
    print("✓ SAM2 predictor loaded successfully from local checkpoint")
except Exception as e:
    print(f"✗ Error loading predictor from checkpoint: {e}")
    # Fallback: fetch config + weights from the Hugging Face hub. `device` is
    # passed explicitly because the SAM2 builders otherwise default to CUDA,
    # which fails on the CPU path selected above.
    try:
        predictor = SAM2ImagePredictor.from_pretrained(SAM2_HF_MODEL_ID, device=device)
        print("✓ SAM2 predictor loaded successfully (fallback)")
    except Exception as e2:
        print(f"✗ Fallback also failed: {e2}")
        predictor = None

# Load and display the chart.
# NOTE(review): absolute, machine-specific path — consider making it relative
# to the notebook or configurable.
image_path = "/Users/adamaslan/code/ai-fin-opt2/ai-fin3/AAPL_band_anomalies_1-hour.png"

try:
    # The context manager closes the file; convert() returns an in-memory copy.
    with Image.open(image_path) as raw_image:
        input_image_pil = raw_image.convert("RGB")
    input_image_np = np.array(input_image_pil)  # (H, W, 3) uint8 array, RGB order
    print(f"✓ Image loaded: {image_path}")
    print(f"Image dimensions: {input_image_np.shape}")

    # Axes and a light grid stay on so pixel coordinates can be read off the
    # plot when choosing prompt boxes/points for segmentation.
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(input_image_pil)
    ax.set_title("Original Financial Chart - Click coordinates for segmentation")
    ax.set_xlabel("x (pixels)")
    ax.set_ylabel("y (pixels)")
    ax.grid(True, alpha=0.3)
    plt.show()

except Exception as e:
    print(f"✗ Error loading image: {e}")
    input_image_pil = None
    input_image_np = None


In [None]:
# 15 — Download the SAM 2.1 tiny checkpoint into ./checkpoints.
# Run this before the setup cell on a fresh machine; that cell loads the weights
# from the same local path. hf_hub_download returns the local file path.
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(
    repo_id="facebook/sam2.1-hiera-tiny",
    filename="sam2.1_hiera_tiny.pt",
    local_dir="./checkpoints",
)


In [None]:
# 17 — Segment a region of the chart with a SAM2 box prompt.
# Depends on `predictor`, `device`, and `image_path` defined in the setup cell.

# EXAMPLE prompt only: a hypothetical Bollinger Band squeeze region in pixel
# coordinates, [[x_min, y_min, x_max, y_max]]. Replace it with coordinates read
# off the gridded chart plotted by the setup cell.
squeeze_box_example = np.array([[50, 100, 200, 150]])

if predictor is not None:
    # 1. Load the chart as a 3-channel RGB image.
    try:
        with Image.open(image_path) as raw_image:
            input_image_pil = raw_image.convert("RGB")
        print(f"Successfully loaded image: {image_path}")
    except FileNotFoundError:
        print(f"Error: Image file not found at {image_path}")
        input_image_pil = None
    except Exception as e:
        print(f"Error loading image: {e}")
        input_image_pil = None

    if input_image_pil is not None:
        try:
            # bfloat16 autocast only on CUDA; on CPU the model runs in plain
            # float32 (CPU autocast does not accept float32 as a target dtype,
            # so requesting it only produced a warning).
            with torch.inference_mode(), torch.autocast(
                device_type=device, dtype=torch.bfloat16, enabled=(device == "cuda")
            ):
                # 2. Embed the image. SAM2ImagePredictor.set_image takes an RGB,
                # HWC array (or a PIL image) — do not flip channels to BGR.
                input_image_rgb = np.array(input_image_pil)
                predictor.set_image(input_image_rgb)
                print("Image set in predictor.")

                # 3. Prompt with the example box only (no points, no mask input).
                print(f"Using example prompt box: {squeeze_box_example}")

                # 4. Predict; multimask_output=True asks for several candidate masks.
                masks, scores, logits = predictor.predict(
                    box=squeeze_box_example,
                    multimask_output=True,
                )
            print(f"Prediction complete. Number of masks generated: {len(masks)}")
        except Exception as e:
            print(f"Error during prediction: {e}")
        else:
            # 5. Overlay each candidate mask on the chart, highest score first.
            for rank, mask_idx in enumerate(np.argsort(scores)[::-1], start=1):
                fig, ax = plt.subplots(figsize=(12, 8))
                ax.imshow(input_image_pil)
                # Mask out zero pixels so only the segmented region is tinted.
                ax.imshow(np.ma.masked_equal(masks[mask_idx], 0), alpha=0.5, cmap="jet")
                ax.set_title(f"Mask {rank}, Score: {scores[mask_idx]:.2f}")
                ax.axis("off")
                plt.show()
else:
    print("Predictor not loaded. Cannot proceed with image processing and prediction.")
