# In [None]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import matplotlib.pyplot as plt
import json
import re
import torchvision.models as models
from pathlib import Path # Use pathlib for better path handling
import albumentations as A # For data augmentation
from albumentations.pytorch import ToTensorV2
import warnings
import time
import segmentation_models_pytorch as smp

# --- Constants ---
# Keypoint names for the B3 dataset (9 keypoints); order defines the layout of keypoint vectors.
KEYPOINTS_NAMES = ["wither", "pinbone", "shoulderbone", "front_girth_bottom", "front_girth_top", 
                   "Height_bottom", "Height_top", "rear_girth_bottom", "rear_girth_top"]
NUM_KEYPOINTS = len(KEYPOINTS_NAMES)  # Derived from the list above: 9 (previously 6)
TARGET_SIZE = (224, 224) # (height, width) fed to the weight model; standard size for many pre-trained models
BATCH_SIZE = 64 # Adjusted batch size (tune based on GPU memory)
EPOCHS = 100 # Adjusted epochs (tune based on convergence)
LEARNING_RATE = 0.004 # Adjusted learning rate (tune)
WEIGHT_DECAY = 0.01 # Weight decay for AdamW
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# (height, width) the segmentation network input is resized to
SEGMENTATION_INPUT_SIZE = (512, 512)
# ImageNet mean/std for normalization with pre-trained models
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]
KERAS_MODEL_PATH = r'best_keypoints_model_9pts_limited.keras' # Path to your pre-trained Keras keypoint model
# NOTE(review): the two constants below currently point at the same checkpoint file.
BEST_MODEL_SAVE_PATH = r'best_enhanced_triple_cattle_weight_model_res18_100ep_2600p.pth'
BEST_TRAIN_MODEL = r'best_enhanced_triple_cattle_weight_model_res18_100ep_2600p.pth'
# NOTE(review): this path has a .keras extension but is read with torch.load() as a
# PyTorch state dict further down -- confirm the file really is a PyTorch checkpoint.
SEGMENTATION_MODEL_PATH = r'best_cattle_segmentation_model_6.keras'


# --- Helper Function to Create Segmentation Model ---
def create_model(num_classes=3):
    """Build the segmentation network: an smp U-Net on a ResNet18 encoder.

    Args:
        num_classes: Number of output channels (segmentation classes).

    Returns:
        The freshly constructed ``smp.Unet`` instance (3 input channels,
        encoder weights requested as "imagenet").

    Raises:
        Exception: Any construction error is printed and then re-raised.
    """
    unet_options = dict(
        encoder_name="resnet18",
        encoder_weights="imagenet",
        in_channels=3,
        classes=num_classes,
    )
    try:
        unet = smp.Unet(**unet_options)
        print("Segmentation model created successfully")
        parameter_total = sum(p.numel() for p in unet.parameters())
        print(f"Model parameters: {parameter_total}")
    except Exception as e:
        # Report for the notebook log, but let the caller decide how to recover.
        print(f"Error creating segmentation model: {e}")
        raise
    return unet

# --- Keras Model Loading (for inference keypoint prediction) ---
# Note: This introduces a dependency on TensorFlow/Keras for the inference part
# If you replace keypoint prediction with a PyTorch model or use ground truth, this can be removed.
#
# After this block `keypoint_detector_model` holds the loaded Keras model on success
# and None in every failure case (file missing, TensorFlow not importable, load error).
# Failures are reported through warnings.warn rather than raised, so importing the
# module does not abort.
# NOTE: `get_keypoints_from_image` uses this global as a default argument, which Python
# evaluates once at function definition time; rebinding the global later does not
# change that default.
try:
    # Imported lazily so the rest of the file works without TensorFlow installed.
    from tensorflow.keras.models import load_model
    if Path(KERAS_MODEL_PATH).is_file():
        keypoint_detector_model = load_model(KERAS_MODEL_PATH)
        print(f"Successfully loaded Keras keypoint model from {KERAS_MODEL_PATH}")
    else:
        keypoint_detector_model = None
        warnings.warn(f"Keras keypoint model not found at {KERAS_MODEL_PATH}. "
                      f"`get_keypoints_from_image` will return None.")
except ImportError:
    # Also reached if load_model itself raises ImportError (e.g. a missing optional backend).
    warnings.warn("TensorFlow/Keras not installed. `get_keypoints_from_image` will not function.")
    keypoint_detector_model = None
except Exception as e:
    # Any other failure while importing or deserializing the model.
    warnings.warn(f"Error loading Keras model from {KERAS_MODEL_PATH}: {e}")
    keypoint_detector_model = None


# --- PyTorch Segmentation Model Loading ---
# After this block `segmentation_model` is either a U-Net in eval mode on DEVICE with
# weights restored from SEGMENTATION_MODEL_PATH, or None (checkpoint missing or any error).
# NOTE: the network is constructed before the checkpoint file is checked, so
# create_model() runs (and requests "imagenet" encoder weights) even when the
# checkpoint is absent and the model is then discarded.
# NOTE(review): SEGMENTATION_MODEL_PATH ends in .keras but is read with torch.load()
# as a state dict -- confirm the file is a PyTorch checkpoint.
# NOTE: `segment_image_for_prediction` uses this global as a default argument, which is
# evaluated once at function definition time.
try:
    segmentation_model = create_model(num_classes=3)
    if Path(SEGMENTATION_MODEL_PATH).is_file():
        # Correctly load the state dictionary for a PyTorch model
        segmentation_model.load_state_dict(torch.load(SEGMENTATION_MODEL_PATH, map_location=DEVICE))
        segmentation_model.to(DEVICE)
        segmentation_model.eval()  # Set to evaluation mode
        print(f"Successfully loaded PyTorch segmentation model from {SEGMENTATION_MODEL_PATH}")
    else:
        # No checkpoint: drop the freshly built model so callers see "unavailable".
        segmentation_model = None
        warnings.warn(f"Segmentation model not found at {SEGMENTATION_MODEL_PATH}. Segmentation will not be available.")
except Exception as e:
    # Covers model construction errors as well as checkpoint loading errors.
    warnings.warn(f"Error loading segmentation model: {e}")
    segmentation_model = None


# --- Helper Functions ---

def find_segmented_filename(original_filename: str, segmented_dir: Path) -> str | None:
    """Find the corresponding segmented image by matching the base part of the filename"""
    base_name = Path(original_filename).stem # Get filename without extension
    # Look for files starting with the base name (more robust)
    for f in segmented_dir.glob(f"{base_name}*"):
        return f.name # Return the filename string
    # Fallback: try matching just the prefix if separated by '.' or '_'
    parts = base_name.split('_')
    if len(parts) > 1:
       prefix = parts[0]
       for f in segmented_dir.glob(f"{prefix}*"):
           return f.name
    parts = base_name.split('.')
    if len(parts) > 1:
       prefix = parts[0]
       for f in segmented_dir.glob(f"{prefix}*"):
           return f.name
    return None

def load_keypoints_data(json_annotation_path: Path) -> dict[str, list[float]]:
    """Load per-image keypoints (x, y) from a COCO-style JSON annotation file.

    Each annotation's ``keypoints`` list is read as consecutive
    ``[x, y, visibility]`` triplets. The visibility flag is ignored: the x and y
    of every keypoint are kept whether or not it is marked visible. Only the
    first ``NUM_KEYPOINTS`` complete triplets are used; an image is stored only
    if exactly ``NUM_KEYPOINTS`` keypoints were obtained, otherwise a warning is
    issued and the image is skipped. A trailing incomplete triplet is dropped
    instead of raising ``IndexError``.

    If several annotations reference the same image, the last one wins.

    Args:
        json_annotation_path: Path to the annotation JSON file.

    Returns:
        Mapping from image file name to a flat list ``[x1, y1, x2, y2, ...]``
        of length ``NUM_KEYPOINTS * 2``.

    Raises:
        FileNotFoundError: If ``json_annotation_path`` is not an existing file.
        KeyError: If the JSON lacks the ``images`` or ``annotations`` sections,
            or an entry lacks a required field.
    """
    if not json_annotation_path.is_file():
        raise FileNotFoundError(f"Annotation file not found: {json_annotation_path}")
    with open(json_annotation_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    keypoints_data = {}
    image_id_to_filename = {img['id']: img['file_name'] for img in data['images']}

    # Debug info to help understand the data structure
    print(f"JSON contains {len(data.get('images', []))} images and {len(data.get('annotations', []))} annotations")

    for cat in data.get('categories', []):
        if 'keypoints' in cat:
            print(f"Category {cat['name']} has {len(cat['keypoints'])} keypoints: {cat['keypoints']}")

    expected_keypoints_length = NUM_KEYPOINTS * 3  # x, y, visibility for each keypoint

    for annotation in data['annotations']:
        image_id = annotation['image_id']
        if image_id not in image_id_to_filename:
            continue  # annotation refers to an image that is not listed

        file_name = image_id_to_filename[image_id]
        keypoints = annotation['keypoints']  # Format: [x1, y1, v1, x2, y2, v2, ...]

        # Show a few samples while the result is still small, to expose the raw format.
        if len(keypoints_data) < 5:
            print(f"Sample keypoints for {file_name}: {keypoints[:6]}...")

        if len(keypoints) != expected_keypoints_length:
            print(f"Warning: {file_name} has {len(keypoints)} values, expected {expected_keypoints_length} (x,y,v for {NUM_KEYPOINTS} keypoints)")

        # Use only complete (x, y, v) triplets, capped at the expected count.
        usable = min(len(keypoints), expected_keypoints_length)
        usable -= usable % 3
        keypoints_2d = []
        for i in range(0, usable, 3):
            keypoints_2d.extend(keypoints[i:i + 2])  # keep x, y; drop visibility

        if len(keypoints_2d) == NUM_KEYPOINTS * 2:
            keypoints_data[file_name] = keypoints_2d
        else:
            warnings.warn(f"Image {file_name} has {len(keypoints_2d)//2} keypoints, expected {NUM_KEYPOINTS}. Skipping.")

    print(f"Successfully loaded keypoints for {len(keypoints_data)} images.")
    return keypoints_data

def create_filename_mapping(original_dir: Path, segmented_dir: Path) -> dict[str, str]:
    """Pair every file in ``original_dir`` with its segmented counterpart.

    Each regular file directly inside ``original_dir`` is looked up with
    ``find_segmented_filename``; files without a match are left out. Progress
    and a short summary (up to five example pairs) are printed.

    Args:
        original_dir: Directory holding the original images.
        segmented_dir: Directory holding the segmented images.

    Returns:
        Dict mapping original file name -> segmented file name.
    """
    original_files = [entry.name for entry in original_dir.glob('*') if entry.is_file()]
    print(f"Found {len(original_files)} files in original directory.")

    mapping = {}
    for orig_name in tqdm(original_files, desc="Mapping filenames"):
        seg_name = find_segmented_filename(orig_name, segmented_dir)
        if not seg_name:
            continue
        mapping[orig_name] = seg_name

    # File names within one directory are unique, so len(mapping) is the match count.
    print(f"Created mapping for {len(mapping)} out of {len(original_files)} original files.")
    if not mapping:
        print("No mappings created. Check paths and filename patterns.")
        return mapping

    print("Mapping examples:")
    for orig, seg in list(mapping.items())[:5]:
        print(f"  {orig} -> {seg}")
    return mapping

# --- Data Augmentation and Dataset ---

def get_transforms(is_train: bool = True) -> A.Compose:
    """Build the Albumentations pipeline for training or for validation/inference.

    Both variants end with resize to ``TARGET_SIZE``, normalization with
    ``NORM_MEAN``/``NORM_STD`` and conversion to a tensor. The training variant
    additionally starts with random flip, rotation, brightness/contrast and
    shift/scale augmentations. Keypoints are passed as (x, y) pairs and are
    kept even when an augmentation moves them outside the image.

    Args:
        is_train: True for the augmented training pipeline, False otherwise.

    Returns:
        The composed transform.
    """
    random_augmentations = []
    if is_train:
        random_augmentations = [
            A.HorizontalFlip(p=0.5),
            A.Rotate(limit=15, p=0.3, border_mode=cv2.BORDER_CONSTANT, value=0),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
            # rotate_limit=0: rotation is already handled by A.Rotate above
            A.ShiftScaleRotate(shift_limit=0.06, scale_limit=0.1, rotate_limit=0, p=0.3,
                               border_mode=cv2.BORDER_CONSTANT, value=0),
        ]

    # Deterministic tail shared by both pipelines.
    finishing_steps = [
        A.Resize(height=TARGET_SIZE[0], width=TARGET_SIZE[1]),
        A.Normalize(mean=NORM_MEAN, std=NORM_STD),
        ToTensorV2(),  # HWC array -> CHW tensor; value scaling is done by Normalize
    ]

    keypoint_settings = A.KeypointParams(format='xy', label_fields=[], remove_invisible=False)
    return A.Compose(random_augmentations + finishing_steps, keypoint_params=keypoint_settings)

class TripleInputCattleDataset(Dataset):
    """Dataset yielding (original image, segmented image, keypoints, weight) samples.

    Every sample is a 4-tuple of tensors:
      * original image after ``transform``,
      * segmented image after ``transform``,
      * flat keypoint vector ``[x1, y1, ...]`` normalized to [0, 1] by ``TARGET_SIZE``,
      * scalar weight target.

    When a sample cannot be built (no segmented mapping, unreadable image,
    transform failure) a warning is issued and an all-zero placeholder sample
    with weight 0.0 is returned instead; filtering such files before building
    the dataset is preferable.

    NOTE(review): ``transform`` is invoked separately for the original and the
    segmented image, so any random augmentations in it are drawn independently
    for the two images.
    """

    def __init__(self, original_dir: Path, segmented_dir: Path, filenames: list[str],
                 weights: list[float], keypoints_data: dict[str, list[float]],
                 filename_map: dict[str, str], transform: A.Compose):
        self.original_dir = original_dir
        self.segmented_dir = segmented_dir
        self.filenames = filenames
        self.weights = np.array(weights)  # numpy array for positional indexing
        self.keypoints_data = keypoints_data
        self.filename_map = filename_map
        self.transform = transform

    def __len__(self) -> int:
        return len(self.filenames)

    def load_image(self, img_path: Path) -> np.ndarray | None:
        """Read ``img_path`` as an RGB array; warn and return None on failure."""
        if not img_path.is_file():
            warnings.warn(f"Image file not found: {img_path}")
            return None
        bgr = cv2.imread(str(img_path))
        if bgr is None:
            warnings.warn(f"Failed to load image: {img_path}")
            return None
        # OpenCV reads BGR; the transforms/models expect RGB.
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _placeholder_sample() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the all-zero sample used when a real one cannot be produced."""
        blank_img = torch.zeros((3, TARGET_SIZE[0], TARGET_SIZE[1]), dtype=torch.float32)
        blank_kpts = torch.zeros(NUM_KEYPOINTS * 2, dtype=torch.float32)
        zero_weight = torch.tensor(0.0, dtype=torch.float32)
        # The same image tensor object stands in for both image inputs.
        return blank_img, blank_img, blank_kpts, zero_weight

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | None:
        orig_filename = self.filenames[idx]
        weight = self.weights[idx]

        # --- Resolve paths ---
        orig_path = self.original_dir / orig_filename
        seg_filename = self.filename_map.get(orig_filename)
        if not seg_filename:
            warnings.warn(f"No segmented mapping for {orig_filename}. Skipping item.")
            return self._placeholder_sample()
        seg_path = self.segmented_dir / seg_filename

        # --- Read both images ---
        orig_image = self.load_image(orig_path)
        seg_image = self.load_image(seg_path)
        if orig_image is None or seg_image is None:
            warnings.warn(f"Failed loading images for {orig_filename}. Skipping item.")
            return self._placeholder_sample()

        # --- Keypoints as [x, y] pairs (the layout Albumentations expects) ---
        if orig_filename in self.keypoints_data:
            kp_flat = self.keypoints_data[orig_filename]
            keypoints_list_xy = [[kp_flat[i], kp_flat[i + 1]] for i in range(0, len(kp_flat), 2)]
        else:
            warnings.warn(f"No keypoints found for {orig_filename} in keypoints_data. Using zeros.")
            keypoints_list_xy = [[0.0, 0.0]] * NUM_KEYPOINTS

        # --- Transform: original image together with its keypoints, segmented image alone ---
        try:
            transformed = self.transform(image=orig_image, keypoints=keypoints_list_xy)
            orig_tensor = transformed['image']
            transformed_keypoints = transformed['keypoints']  # now in output-image pixels

            transformed_seg = self.transform(image=seg_image, keypoints=[])
            seg_tensor = transformed_seg['image']
        except Exception as e:
            warnings.warn(f"Error during augmentation for {orig_filename}: {e}. Skipping.")
            return self._placeholder_sample()

        # --- Scale keypoints to [0, 1] relative to the output size ---
        out_height, out_width = TARGET_SIZE[0], TARGET_SIZE[1]
        keypoints_norm_flat = []
        for kp_x, kp_y in transformed_keypoints:
            keypoints_norm_flat += [kp_x / out_width, kp_y / out_height]

        expected_len = NUM_KEYPOINTS * 2
        if len(keypoints_norm_flat) != expected_len:
            warnings.warn(f"Keypoint length mismatch after transform for {orig_filename} ({len(keypoints_norm_flat)}). Padding with zeros.")
            # Pad when short (a negative count adds nothing), then cut when long.
            keypoints_norm_flat.extend([0.0] * (expected_len - len(keypoints_norm_flat)))
            keypoints_norm_flat = keypoints_norm_flat[:expected_len]

        keypoints_tensor = torch.tensor(
            [np.clip(value, 0.0, 1.0) for value in keypoints_norm_flat],
            dtype=torch.float32,
        )

        # --- Regression target ---
        weight_tensor = torch.tensor(weight, dtype=torch.float32)

        return orig_tensor, seg_tensor, keypoints_tensor, weight_tensor

# --- Enhanced Model Architecture ---

class EnhancedTripleInputCattleWeightCNN(nn.Module):
    """Three-branch regressor predicting a single weight value.

    Branches:
      * original image  -> ResNet18 (classification head removed) -> BatchNorm1d
      * segmented image -> ResNet18 (classification head removed) -> BatchNorm1d
      * keypoints       -> small MLP producing a 256-dim feature

    The three feature vectors are concatenated and passed through a fully
    connected head that outputs one value per sample.

    Args:
        num_keypoints: Number of keypoints; the keypoint input has
            ``num_keypoints * 2`` values (x, y per keypoint).
        pretrained: If True, both ResNet18 backbones start from
            ``ResNet18_Weights.DEFAULT``; otherwise they are randomly initialized.
    """

    def __init__(self, num_keypoints: int = NUM_KEYPOINTS, pretrained: bool = True):
        super().__init__()

        backbone_weights = models.ResNet18_Weights.DEFAULT if pretrained else None

        # --- Original image branch ---
        self.orig_backbone = models.resnet18(weights=backbone_weights)
        orig_feature_dim = self.orig_backbone.fc.in_features
        self.orig_backbone.fc = nn.Identity()  # expose pooled features instead of class logits
        self.orig_bn = nn.BatchNorm1d(orig_feature_dim)

        # --- Segmented image branch (separate backbone, same recipe) ---
        self.seg_backbone = models.resnet18(weights=backbone_weights)
        seg_feature_dim = self.seg_backbone.fc.in_features
        self.seg_backbone.fc = nn.Identity()
        self.seg_bn = nn.BatchNorm1d(seg_feature_dim)

        # --- Keypoints branch ---
        keypoint_feature_dim = 256
        keypoint_layers = [
            nn.Linear(num_keypoints * 2, 128),  # x and y for each keypoint
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(128, keypoint_feature_dim),
            nn.BatchNorm1d(keypoint_feature_dim),
            nn.ReLU(inplace=True),
        ]
        self.keypoints_fc = nn.Sequential(*keypoint_layers)

        # --- Fusion head ---
        fused_dim = orig_feature_dim + seg_feature_dim + keypoint_feature_dim
        head_layers = [
            nn.BatchNorm1d(fused_dim),  # normalize the concatenated features first
            nn.Linear(fused_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(128, 1),  # single regression output
        ]
        self.combined_fc = nn.Sequential(*head_layers)

    def forward(self, orig_img: torch.Tensor, seg_img: torch.Tensor, keypoints: torch.Tensor) -> torch.Tensor:
        """Run all three branches and return the head output, one value per sample."""
        orig_vec = self.orig_bn(self.orig_backbone(orig_img))
        seg_vec = self.seg_bn(self.seg_backbone(seg_img))
        keypoint_vec = self.keypoints_fc(keypoints)

        fused = torch.cat((orig_vec, seg_vec, keypoint_vec), dim=1)
        return self.combined_fc(fused)

# --- Training Function ---

def _plot_training_history(train_losses, val_losses, val_mae_list, val_rmse_list):
    """Save loss and validation-metric curves to enhanced_training_history_B3.png."""
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training Loss (Huber)')
    plt.plot(val_losses, label='Validation Loss (Huber)')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')

    plt.subplot(1, 2, 2)
    plt.plot(val_mae_list, label='Validation MAE')
    plt.plot(val_rmse_list, label='Validation RMSE')
    plt.xlabel('Epochs')
    plt.ylabel('Error (kg)')
    plt.legend()
    plt.title('Validation Metrics')

    plt.tight_layout()
    plt.savefig('enhanced_training_history_B3.png')
    print("Saved training history plot to enhanced_training_history_B3.png")
    plt.close()


def train_model(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader,
                epochs: int, lr: float, weight_decay: float, device: torch.device,
                model_save_path: Path):
    """Train the model and save the state dict with the best validation RMSE.

    Uses Huber loss, AdamW, cosine LR annealing (down to ``0.4 * lr``) and
    gradient-norm clipping at 1.0. After every epoch the model is evaluated on
    ``val_loader``; whenever validation RMSE improves, ``model.state_dict()`` is
    written to ``model_save_path``. A training-history plot is saved at the end.

    Epoch losses are averaged over the batches actually processed (batches that
    are None or have empty keypoints are skipped and not counted). If an epoch
    processes no batches the corresponding loss is NaN; if validation yields no
    predictions a warning is issued, the metrics are NaN and no checkpoint is
    saved for that epoch. The LR printed for an epoch is the one used during it.

    Args:
        model: Network called as ``model(orig_imgs, seg_imgs, keypoints)``.
        train_loader: Loader yielding ``(orig_imgs, seg_imgs, keypoints, weights)``.
        val_loader: Loader with the same batch structure, used for evaluation.
        epochs: Number of epochs (also the cosine schedule length).
        lr: Initial learning rate.
        weight_decay: AdamW weight decay.
        device: Device the batches are moved to.
        model_save_path: Destination of the best checkpoint.

    Returns:
        ``(train_losses, val_losses, val_mae_list, val_rmse_list)``, one entry per epoch.
    """
    criterion = nn.HuberLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=lr * 0.4)

    train_losses, val_losses, val_mae_list, val_rmse_list = [], [], [], []
    best_val_rmse = float('inf')

    print(f"Starting training for {epochs} epochs on {device}...")

    for epoch in range(epochs):
        # --- Training Phase ---
        model.train()
        running_loss = 0.0
        train_batches = 0  # batches that actually contributed to running_loss
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]", ncols=100)

        for batch in train_pbar:
            if batch is None:
                continue
            orig_imgs, seg_imgs, keypoints, weights = batch
            if keypoints.nelement() == 0:
                continue

            orig_imgs = orig_imgs.to(device)
            seg_imgs = seg_imgs.to(device)
            keypoints = keypoints.to(device)
            weights = weights.to(device).view(-1, 1)  # match the (N, 1) model output

            optimizer.zero_grad()
            outputs = model(orig_imgs, seg_imgs, keypoints)
            loss = criterion(outputs, weights)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            train_batches += 1
            train_pbar.set_postfix(loss=f"{batch_loss:.4f}")

        epoch_train_loss = running_loss / train_batches if train_batches else float('nan')
        train_losses.append(epoch_train_loss)

        # --- Validation Phase ---
        model.eval()
        val_loss = 0.0
        val_batches = 0
        all_preds, all_targets = [], []
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Valid]", ncols=100)

        with torch.no_grad():
            for batch in val_pbar:
                if batch is None:
                    continue
                orig_imgs, seg_imgs, keypoints, weights = batch
                if keypoints.nelement() == 0:
                    continue

                orig_imgs = orig_imgs.to(device)
                seg_imgs = seg_imgs.to(device)
                keypoints = keypoints.to(device)
                weights = weights.to(device).view(-1, 1)

                outputs = model(orig_imgs, seg_imgs, keypoints)
                batch_loss = criterion(outputs, weights).item()
                val_loss += batch_loss
                val_batches += 1
                val_pbar.set_postfix(loss=f"{batch_loss:.4f}")

                all_preds.extend(outputs.cpu().numpy().flatten())
                all_targets.extend(weights.cpu().numpy().flatten())

        epoch_val_loss = val_loss / val_batches if val_batches else float('nan')
        val_losses.append(epoch_val_loss)

        # Calculate validation metrics (sklearn raises on empty input, so guard it)
        if all_preds:
            mae = mean_absolute_error(all_targets, all_preds)
            rmse = np.sqrt(mean_squared_error(all_targets, all_preds))
            r2 = r2_score(all_targets, all_preds)
        else:
            warnings.warn(f"Epoch {epoch+1}: no validation predictions were produced; metrics set to NaN.")
            mae = rmse = r2 = float('nan')
        val_mae_list.append(mae)
        val_rmse_list.append(rmse)

        # Capture the LR used in this epoch before the scheduler moves to the next one.
        epoch_lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}, "
              f"Val MAE: {mae:.2f} kg, Val RMSE: {rmse:.2f} kg, Val R2: {r2:.3f}, LR: {epoch_lr:.6f}")

        # Save best model based on validation RMSE (a NaN RMSE never compares as better)
        if rmse < best_val_rmse:
            best_val_rmse = rmse
            torch.save(model.state_dict(), model_save_path)
            print(f"  => Saved best model to {model_save_path} (RMSE: {best_val_rmse:.2f} kg)")

    print(f"Training completed. Best validation RMSE: {best_val_rmse:.2f} kg")

    _plot_training_history(train_losses, val_losses, val_mae_list, val_rmse_list)

    return train_losses, val_losses, val_mae_list, val_rmse_list


def get_keypoints_from_image(image_path: Path, model=keypoint_detector_model, target_size: tuple[int, int]=TARGET_SIZE) -> np.ndarray | None:
    """Predict keypoints for one image with the Keras keypoint detector.

    The image is read with OpenCV, resized to ``target_size`` (height, width),
    scaled to [0, 1] and passed to ``model.predict``. The first prediction is
    reshaped to (-1, 2), rescaled by ``original size / target size`` and clipped
    to the original image bounds. Every failure is printed and yields None.

    NOTE(review): the rescaling treats the model output as pixel coordinates in
    ``target_size`` space, and the image is fed in OpenCV's BGR channel order
    without conversion -- confirm both match how the detector was trained.

    Args:
        image_path: Image file to analyse.
        model: Object exposing ``predict``; defaults to the module-level detector
            as it was when this function was defined.
        target_size: (height, width) the model expects.

    Returns:
        Array of (x, y) rows in original-image pixel coordinates, one row per
        keypoint, or None on any error.
    """
    print(f"\nPredicting keypoints for: {image_path}")

    if model is None:
        print("Error: Keras keypoint model not loaded. Cannot predict keypoints.")
        return None

    if not Path(image_path).is_file():
        print(f"Error: Image file not found: {image_path}")
        return None

    # Stage 1: read the image and remember its original size
    try:
        source_image = cv2.imread(str(image_path))
        if source_image is None:
            print(f"Error: Failed to load image: {image_path}")
            return None

        orig_height, orig_width = source_image.shape[:2]
        print(f"Original image size: {orig_width}x{orig_height}")
    except Exception as e:
        print(f"Error loading image: {e}")
        return None

    # Stage 2: build the single-image batch for the model
    try:
        resized = cv2.resize(source_image, (target_size[1], target_size[0]))  # cv2 wants (width, height)
        model_batch = np.expand_dims(resized / 255.0, axis=0)
        print(f"Model input shape: {model_batch.shape}")
    except Exception as e:
        print(f"Error preparing image for prediction: {e}")
        return None

    # Stage 3: run the detector
    try:
        started = time.time()
        raw_prediction = model.predict(model_batch, verbose=0)
        elapsed = time.time() - started
        print(f"Prediction completed in {elapsed:.4f} sec")
        print(f"Model output shape: {raw_prediction.shape}")
    except Exception as e:
        print(f"Error during Keras model prediction: {e}")
        return None

    # Stage 4: map the prediction back to original-image coordinates
    try:
        points = raw_prediction[0].reshape(-1, 2)

        if len(points) != NUM_KEYPOINTS:
            print(f"Error: Expected {NUM_KEYPOINTS} kps, got {len(points)}")
            # Too many: keep the first NUM_KEYPOINTS. Too few: give up.
            points = points[:NUM_KEYPOINTS]
            if len(points) < NUM_KEYPOINTS:
                return None

        target_h, target_w = target_size
        x_factor = orig_width / target_w if target_w > 0 else 1.0
        y_factor = orig_height / target_h if target_h > 0 else 1.0

        keypoints_original = points * np.array([x_factor, y_factor])

        # Keep every coordinate inside the original image.
        keypoints_original[:, 0] = np.clip(keypoints_original[:, 0], 0, orig_width - 1)
        keypoints_original[:, 1] = np.clip(keypoints_original[:, 1], 0, orig_height - 1)

        print(f"Predicted keypoints shape (original coords): {keypoints_original.shape}")
        print(f"First few keypoints (original coords): {keypoints_original[:3]}")

        return keypoints_original

    except Exception as e:
        print(f"Error processing keypoint predictions: {e}")
        return None

def segment_image_for_prediction(image_path: Path,
                                 model=segmentation_model, 
                                 device=DEVICE, 
                                 seg_target_size: tuple[int, int] = SEGMENTATION_INPUT_SIZE,
                                 debug_save_dir: Path | None = None) -> np.ndarray | None:
    """
    Segment an image with the PyTorch segmentation model and black out non-cow pixels.

    The image is resized to ``seg_target_size`` (height, width), normalized with
    NORM_MEAN/NORM_STD and passed through ``model``. The per-pixel argmax over the
    output channels gives a class mask; pixels equal to class index 1 form a binary
    mask that is resized back to the original resolution (nearest neighbour) and
    applied to the original image.

    Args:
        image_path: Image file to segment. Must be a ``Path`` when ``debug_save_dir``
            is given, because ``.stem`` is read from it.
        model: Segmentation network; defaults to the module-level model as it was
            when this function was defined.
        device: Device the input tensor is moved to.
        seg_target_size: (height, width) of the network input.
        debug_save_dir: If given, intermediate masks and the result are written there
            as PNG files (the directory is created if needed).

    Returns:
        The original image in BGR order with every pixel outside the class-1 mask set
        to zero, or None if the model is missing, the file cannot be read, or any
        step fails.
    """
    print(f"Segmenting image for prediction: {image_path}")
    # Base name for debug files; only evaluated when a debug directory was supplied.
    debug_filename_base = image_path.stem if debug_save_dir else None

    if model is None:
        print("Segmentation model not loaded. Returning None.")
        return None
    
    if not Path(image_path).is_file():
        print(f"Error: Image file not found: {image_path}")
        return None

    # 1. Load Image
    try:
        image = cv2.imread(str(image_path))
        if image is None:
            print(f"Error: Failed to load image: {image_path}")
            return None
        orig_height, orig_width = image.shape[:2]
        orig_image_bgr = image.copy()  # untouched BGR copy that the mask is applied to at the end
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # RGB version fed to the network
        print(f"DEBUG: Original image dimensions: {orig_width}x{orig_height}")
    except Exception as e:
        print(f"Error loading image: {e}")
        return None

    # 2. Preprocess using Albumentations
    try:
        seg_transform = A.Compose([
            A.Resize(height=seg_target_size[0], width=seg_target_size[1]),
            A.Normalize(mean=NORM_MEAN, std=NORM_STD),
            ToTensorV2(),
        ])
        transformed = seg_transform(image=image_rgb)
        # unsqueeze(0) adds the batch dimension: a batch of one image
        image_tensor = transformed['image'].unsqueeze(0).to(device)
        print(f"DEBUG: Tensor shape for segmentation model: {image_tensor.shape}")
    except Exception as e:
        print(f"Error during preprocessing for segmentation: {e}")
        return None

    # 3. Predict Segmentation Mask
    try:
        start_time = time.time()
        with torch.no_grad():
            output_logits = model(image_tensor) # (N, C, H, W) logits
        elapsed = time.time() - start_time
        print(f"Segmentation completed in {elapsed:.4f} sec")
        print(f"DEBUG: Segmentation model output shape (logits): {output_logits.shape}")

        # 4. Post-process using argmax
        # argmax over the channel axis picks a class per pixel; squeeze drops the size-1 batch axis
        pred_mask = torch.argmax(output_logits, dim=1).squeeze().cpu().numpy().astype(np.uint8)
        print(f"DEBUG: Mask shape after argmax: {pred_mask.shape}")
        print(f"DEBUG: Unique classes in predicted mask (at {seg_target_size}): {np.unique(pred_mask)}")

        if debug_save_dir and debug_filename_base:
            debug_save_dir.mkdir(parents=True, exist_ok=True)
            # Stretch class ids so the largest one is close to 255 and the mask is viewable
            mask_viz = (pred_mask * (255 // max(1, np.max(pred_mask)) )).astype(np.uint8) if np.max(pred_mask) > 0 else pred_mask
            cv2.imwrite(str(debug_save_dir / f"{debug_filename_base}_class_mask_{seg_target_size[0]}x{seg_target_size[1]}.png"), mask_viz)
            print(f"DEBUG: Saved class mask to {debug_save_dir / f'{debug_filename_base}_class_mask_{seg_target_size[0]}x{seg_target_size[1]}.png'}")

        # 5. Create binary mask for the COW (class 1) and resize
        cow_class_index = 1 # Make sure this is the correct index for 'cow'
        cow_mask_binary = (pred_mask == cow_class_index).astype(np.uint8) * 255
        # INTER_NEAREST keeps the mask strictly binary (0 or 255) after resizing
        cow_mask_resized = cv2.resize(cow_mask_binary, (orig_width, orig_height), interpolation=cv2.INTER_NEAREST)
        print(f"DEBUG: Cow mask dimensions after resize: {cow_mask_resized.shape}")
        print(f"DEBUG: Unique values in final cow mask (0 or 255): {np.unique(cow_mask_resized)}")

        if debug_save_dir and debug_filename_base:
             cv2.imwrite(str(debug_save_dir / f"{debug_filename_base}_cow_binary_mask_{orig_width}x{orig_height}.png"), cow_mask_resized)
             print(f"DEBUG: Saved final binary cow mask")

        # 6. Apply mask to original image
        # Bitwise AND with a 0/255 mask keeps masked-in pixels unchanged and zeroes the rest
        mask_3channel = cv2.cvtColor(cow_mask_resized, cv2.COLOR_GRAY2BGR)
        segmented_image = cv2.bitwise_and(orig_image_bgr, mask_3channel)

        if debug_save_dir and debug_filename_base:
            cv2.imwrite(str(debug_save_dir / f"{debug_filename_base}_segmented_cow.png"), segmented_image)
            print(f"DEBUG: Saved segmented cow image")

        return segmented_image # Return BGR image with mask applied

    except Exception as e:
        # Best effort: report the failure with a traceback and signal it with None
        print(f"Error during segmentation or post-processing: {e}")
        import traceback
        traceback.print_exc()
        print("Returning None as a fallback.")
        return None

# --- Prediction Function ---

def predict_weight(model: nn.Module, orig_img_path: Path, seg_img_input: Path | np.ndarray | None,
                   keypoints_xy: np.ndarray, device: torch.device,
                   target_size: tuple[int, int] = TARGET_SIZE) -> float | None:
    """Predict the weight for one image from its original view, segmented view and keypoints.

    Args:
        model: Network called as ``model(orig_tensor, seg_tensor, keypoints_tensor)``;
            it is switched to eval mode here.
        orig_img_path: Path of the original image, read with ``cv2.imread``.
        seg_img_input: Segmented image source. A path (``Path`` or ``str``) is read
            from disk, an array is used as-is (it is converted with ``COLOR_BGR2RGB``,
            so a 3-channel BGR array is expected), and ``None`` triggers
            ``segment_image_for_prediction(orig_img_path)``.
        keypoints_xy: Pixel coordinates, ideally of shape ``(NUM_KEYPOINTS, 2)``.
            Extra rows/columns are dropped and missing rows are padded with zeros.
        device: Device the input tensors are moved to.
        target_size: Not used by this function; kept so existing callers that pass
            it keep working.

    Returns:
        The predicted weight as a float, or ``None`` when any stage fails. Failures
        are reported with ``print`` and never raised to the caller.
    """
    model.eval()
    transform = get_transforms(is_train=False)

    # --- Validate keypoints first ---
    # Done before any image I/O so an unusable keypoint set does not cost a
    # segmentation pass.
    if keypoints_xy is None:
        print("Prediction Error: Keypoints are None")
        return None

    try:
        keypoints = np.asarray(keypoints_xy, dtype=np.float64)

        if keypoints.ndim != 2 or keypoints.shape[1] < 2:
            # Nothing sensible can be recovered from a flat or single-column array.
            print(f"Prediction Error: Expected {NUM_KEYPOINTS} keypoints in (N, 2) format, got {keypoints.shape}")
            return None

        if keypoints.shape != (NUM_KEYPOINTS, 2):
            print(f"Prediction Error: Expected {NUM_KEYPOINTS} keypoints in (N, 2) format, got {keypoints.shape}")
            # Best-effort repair: keep the first NUM_KEYPOINTS rows and the (x, y)
            # columns, then pad missing rows with (0, 0) in a single step.
            keypoints = keypoints[:NUM_KEYPOINTS, :2]
            missing_rows = NUM_KEYPOINTS - keypoints.shape[0]
            if missing_rows > 0:
                keypoints = np.vstack([keypoints, np.zeros((missing_rows, 2), dtype=keypoints.dtype)])
    except Exception as e:
        print(f"Error processing keypoints for weight prediction: {e}")
        return None

    # --- Load and preprocess both images ---
    try:
        orig_image = cv2.imread(str(orig_img_path))
        if orig_image is None:
            print(f"Prediction Error: Could not load original image {orig_img_path}")
            return None
        orig_image_rgb = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
        orig_height, orig_width = orig_image.shape[:2]

        # --- Handle segmented input ---
        seg_image = None
        if isinstance(seg_img_input, (str, Path)):  # A path on disk
            seg_image = cv2.imread(str(seg_img_input))
            if seg_image is None:
                print(f"Warning: Could not load segmented image {seg_img_input}, attempting to segment on-the-fly...")
                seg_image = segment_image_for_prediction(orig_img_path)
        elif isinstance(seg_img_input, np.ndarray):  # An already-loaded image
            seg_image = seg_img_input
        elif seg_img_input is None:  # Nothing supplied: segment now
            print("Segmented image not provided, performing automatic segmentation...")
            seg_image = segment_image_for_prediction(orig_img_path)

        if seg_image is None:
            print("Error: Could not obtain a segmented image. Weight prediction is not possible.")
            return None

        seg_image_rgb = cv2.cvtColor(seg_image, cv2.COLOR_BGR2RGB)

        # The same (non-training) transform is applied to both views.
        orig_tensor = transform(image=orig_image_rgb)['image'].unsqueeze(0).to(device)
        seg_tensor = transform(image=seg_image_rgb)['image'].unsqueeze(0).to(device)

    except Exception as e:
        print(f"Error loading/preprocessing images for weight prediction: {e}")
        return None

    # --- Normalize keypoints to [0, 1] relative to the original image size ---
    try:
        scale = np.array([orig_width, orig_height], dtype=np.float64)
        normalized_keypoints = np.clip(keypoints / scale, 0.0, 1.0)
        # Row-major ravel yields [x0, y0, x1, y1, ...].
        keypoints_tensor = torch.tensor(
            normalized_keypoints.ravel(), dtype=torch.float32
        ).unsqueeze(0).to(device)
    except Exception as e:
        print(f"Error processing keypoints for weight prediction: {e}")
        return None

    # --- Predict Weight ---
    try:
        with torch.no_grad():
            predicted_weight = model(orig_tensor, seg_tensor, keypoints_tensor).item()
        return predicted_weight
    except Exception as e:
        print(f"Error during weight prediction: {e}")
        return None

def visualize_prediction(orig_img_path: Path, keypoints: np.ndarray, predicted_weight: float):
    """Show the original image with the keypoints overlaid and the predicted weight as title.

    Args:
        orig_img_path: Path of the image to display.
        keypoints: Iterable of points; element 0 is used as x and element 1 as y.
        predicted_weight: Value shown in the title, formatted with two decimals.

    Returns:
        None. Any failure is printed and swallowed so a visualization problem
        never interrupts the caller.
    """
    fig = None
    try:
        orig_image = cv2.imread(str(orig_img_path))
        if orig_image is None:
            # cv2.imread reports failure by returning None; say so explicitly
            # instead of surfacing the opaque cvtColor assertion that would follow.
            print(f"Error during visualization: Could not load image {orig_img_path}")
            return
        orig_image_rgb = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)

        fig = plt.figure(figsize=(8, 8))
        plt.imshow(orig_image_rgb)

        # One scatter call for all points instead of one call per keypoint.
        xs = [kp[0] for kp in keypoints]
        ys = [kp[1] for kp in keypoints]
        plt.scatter(xs, ys, c='red', s=20, marker='o')

        plt.title(f"Predicted Weight: {predicted_weight:.2f} kg", fontsize=14)
        plt.axis('off')
        plt.show()
    except Exception as e:
        print(f"Error during visualization: {e}")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        if fig is not None:
            plt.close(fig)

# --- Main Execution ---
def main():
    """Run the end-to-end pipeline: data checks, optional training, inference report.

    Steps:
      1. Verify that the hard-coded dataset directories, the CSV and the JSON
         annotation file exist.
      2. Load keypoints (JSON), the original->segmented filename map and the
         weights (CSV, rows with missing or non-positive weight dropped).
      3. Keep rows that have an image, keypoints and a segmentation file, cap the
         list at ``max_samples`` and split it 80/20 into train/validation.
      4. Build datasets, loaders and the model; train only if ``RUN_TRAINING``.
      5. Load ``BEST_TRAIN_MODEL`` when present and evaluate on validation images
         drawn at random without replacement, saving per-image figures and an
         actual-vs-predicted scatter plot under ``inference_visualizations/``.

    Returns:
        None. Setup failures are printed and end the run early.
    """
    print(f"Using device: {DEVICE}")

    # --- Configuration ---
    original_dir = Path(r"C:\Users\andrey\.cache\kagglehub\datasets\sadhliroomyprime\cattle-weight-detection-model-dataset-12k\versions\3\www.acmeai.tech Dataset - BMGF-LivestockWeight-CV\Vector\B3\Side\data\images")
    segmented_dir = Path(r"C:\Users\andrey\.cache\kagglehub\datasets\sadhliroomyprime\cattle-weight-detection-model-dataset-12k\versions\3\www.acmeai.tech Dataset - BMGF-LivestockWeight-CV\Pixel\B3\annotations")
    csv_file = Path(r"cow_weight_data.csv")
    json_annotation_file = Path(r"coco_side_filtered.json")

    # --- Pre-checks ---
    if not original_dir.is_dir():
        print(f"Error: Original image directory not found: {original_dir}")
        return
    if not segmented_dir.is_dir():
        print(f"Error: Segmented image directory not found: {segmented_dir}")
        return
    if not csv_file.is_file():
        print(f"Error: CSV file not found: {csv_file}")
        return
    if not json_annotation_file.is_file():
        print(f"Error: JSON annotation file not found: {json_annotation_file}")
        return

    # --- Load Data ---
    print("Loading keypoints data from JSON file...")
    try:
        keypoints_data = load_keypoints_data(json_annotation_file)
        if not keypoints_data:
            print("Error: No keypoints loaded. Check JSON file format and content.")
            return
    except Exception as e:
        print(f"Error loading keypoints JSON: {e}")
        return

    print("Creating filename mapping between original and segmented images...")
    filename_map = create_filename_mapping(original_dir, segmented_dir)
    if not filename_map:
        print("Error: Could not map original to segmented filenames. Check directories and naming patterns.")
        return

    print(f"Loading weights from CSV: {csv_file}")
    try:
        df = pd.read_csv(csv_file)
        df = df.dropna(subset=['weight', 'filename'])
        df = df[df['weight'] > 0]
        print(f"Loaded {len(df)} rows from CSV after initial cleaning.")
    except Exception as e:
        print(f"Error loading or processing CSV file: {e}")
        return

    # --- Prepare Dataset ---
    filenames_all, weights_all = [], []
    print("Filtering dataset for available data (images, segmentation, keypoints)...")
    filtered_out_count = 0
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Filtering Data"):
        filename = str(row['filename'])
        weight = float(row['weight'])

        # Drop the row when it has no mapping, no keypoints, no original image, or
        # a mapped segmentation file that does not exist on disk.
        if (filename not in filename_map or
            filename not in keypoints_data or
            not (original_dir / filename).is_file() or
            (filename_map.get(filename) and not (segmented_dir / filename_map[filename]).is_file())):
            filtered_out_count += 1
            continue

        filenames_all.append(filename)
        weights_all.append(weight)

    # Cap the dataset size; the first max_samples surviving rows are kept.
    max_samples = 2700
    filenames_all = filenames_all[:max_samples]
    weights_all = weights_all[:max_samples]

    print(f"Filtered dataset: {len(filenames_all)} samples remaining. ({filtered_out_count} removed).")
    if not filenames_all:
        print("Error: No valid samples found after filtering. Cannot proceed.")
        return

    # --- Split Data ---
    X_train, X_val, y_train, y_val = train_test_split(
        filenames_all, weights_all, test_size=0.2, random_state=42
    )
    print(f"Train samples: {len(X_train)}, Validation samples: {len(X_val)}")

    # --- Create Datasets and DataLoaders ---
    train_transform = get_transforms(is_train=True)
    val_transform = get_transforms(is_train=False)

    train_dataset = TripleInputCattleDataset(
        original_dir, segmented_dir, X_train, y_train, keypoints_data, filename_map, transform=train_transform
    )
    val_dataset = TripleInputCattleDataset(
        original_dir, segmented_dir, X_val, y_val, keypoints_data, filename_map, transform=val_transform
    )

    num_workers = 0 if os.name == 'nt' else 2
    # Pinned host memory only helps host->GPU copies; skip it on CPU-only runs.
    use_pinned_memory = DEVICE.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=use_pinned_memory, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=use_pinned_memory)

    # --- Initialize Model ---
    model = EnhancedTripleInputCattleWeightCNN(num_keypoints=NUM_KEYPOINTS, pretrained=True).to(DEVICE)

    # --- Train Model ---
    # Set to True to run training, False to skip and go to inference
    RUN_TRAINING = False
    if RUN_TRAINING:
        print("Starting model training...")
        train_model(
            model, train_loader, val_loader,
            epochs=EPOCHS, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY,
            device=DEVICE, model_save_path=Path(BEST_MODEL_SAVE_PATH)
        )
    else:
        print("Skipping model training as per configuration.")

    # --- Inference Example ---
    print("\n--- Running Inference Example ---")

    if not Path(BEST_TRAIN_MODEL).is_file():
        print(f"Warning: Trained model file {BEST_TRAIN_MODEL} not found. Inference might not be meaningful.")
    else:
        print(f"Loading best model weights from {BEST_TRAIN_MODEL}")
        model.load_state_dict(torch.load(BEST_TRAIN_MODEL, map_location=DEVICE))

    model.eval()

    mae_list_inference, all_actual_weights_inf, all_predicted_weights_inf = [], [], []

    num_inference_samples = min(350, len(X_val))
    if num_inference_samples == 0:
        print("No samples available in validation set for inference.")
        return

    print("\n" + "*"*30)
    print("WARNING: Inference uses keypoints predicted by the separate Keras model.")
    print("This differs from training, which used ground truth keypoints from JSON.")
    print("Performance might differ.")
    print("*"*30 + "\n")

    # Draw distinct validation indices. Sampling with replacement would score some
    # images several times and skip others, which distorts the reported MAE.
    sample_indices = np.random.choice(len(X_val), size=num_inference_samples, replace=False)

    for i, idx in enumerate(sample_indices):
        test_filename = X_val[idx]
        actual_weight = y_val[idx]
        orig_path = original_dir / test_filename

        print(f"\nPredicting for: {test_filename} (Sample {i+1}/{num_inference_samples})")
        print(f"Actual weight: {actual_weight:.2f} kg")

        # 1. Get keypoints
        # NOTE(review): keypoint_detector_model is not defined in this function;
        # presumably a module-level object created elsewhere in the file - confirm.
        predicted_keypoints_xy = get_keypoints_from_image(orig_path, model=keypoint_detector_model)
        if predicted_keypoints_xy is None:
            print("   -> Failed to predict keypoints for this image.")
            continue

        # 2. Predict weight, using automatic segmentation (seg_img_input=None)
        predicted_weight = predict_weight(
            model=model,
            orig_img_path=orig_path,
            seg_img_input=None, # Let predict_weight handle segmentation
            keypoints_xy=predicted_keypoints_xy,
            device=DEVICE
        )

        if predicted_weight is None:
            print("   -> Failed to predict weight for this image.")
            continue

        all_actual_weights_inf.append(actual_weight)
        all_predicted_weights_inf.append(predicted_weight)

        # Visualize the prediction
        # NOTE(review): segmentation_model is likewise expected to exist at module
        # level. The image is segmented a second time here purely for display.
        vis_seg_image = segment_image_for_prediction(orig_path, model=segmentation_model, device=DEVICE, debug_save_dir=Path("inference_visualizations/segmentation_debug"))

        if vis_seg_image is not None:
            plt.figure(figsize=(12, 6))

            # Original image with keypoints
            plt.subplot(1, 2, 1)
            orig_image_vis = cv2.cvtColor(cv2.imread(str(orig_path)), cv2.COLOR_BGR2RGB)
            plt.imshow(orig_image_vis)
            for kp_idx, kp in enumerate(predicted_keypoints_xy):
                plt.scatter(kp[0], kp[1], c='red', s=30, marker='o')
                plt.text(kp[0] + 5, kp[1] + 5, KEYPOINTS_NAMES[kp_idx % NUM_KEYPOINTS], color='white', backgroundcolor='red', fontsize=7)
            plt.title(f"Original Image\nPredicted Weight: {predicted_weight:.2f} kg")
            plt.axis('off')

            # Segmented image
            plt.subplot(1, 2, 2)
            vis_seg_image_rgb = cv2.cvtColor(vis_seg_image, cv2.COLOR_BGR2RGB)
            plt.imshow(vis_seg_image_rgb)
            plt.title(f"Segmented Image\nActual Weight: {actual_weight:.2f} kg")
            plt.axis('off')

            plt.tight_layout()
            vis_save_dir = Path("inference_visualizations")
            vis_save_dir.mkdir(parents=True, exist_ok=True)
            plt.savefig(vis_save_dir / f"{test_filename}_prediction_visualization.png")
            print(f"   -> Saved individual visualization to {vis_save_dir / f'{test_filename}_prediction_visualization.png'}")
            plt.close()
        else:
            # No segmentation available for display: show keypoints on the original only.
            visualize_prediction(orig_path, predicted_keypoints_xy, predicted_weight)

        error = abs(predicted_weight - actual_weight)
        mae_list_inference.append(error)
        print(f"   Predicted weight: {predicted_weight:.2f} kg")
        print(f"   Error: {error:.2f} kg")

    if mae_list_inference:
        avg_mae_inf = sum(mae_list_inference) / len(mae_list_inference)
        print(f"\nAverage MAE on {len(mae_list_inference)} inference samples: {avg_mae_inf:.2f} kg")
    else:
        print("\nNo successful inference predictions were made.")

    # --- Actual vs. Predicted Scatter Plot ---
    if all_actual_weights_inf and all_predicted_weights_inf:
        plt.figure(figsize=(10, 8))
        plt.scatter(all_actual_weights_inf, all_predicted_weights_inf, alpha=0.7, edgecolor='k', label=f'Predictions (MAE: {avg_mae_inf:.2f} kg)' if mae_list_inference else 'Predictions')

        min_val = min(min(all_actual_weights_inf, default=0), min(all_predicted_weights_inf, default=0))
        data_max_val = max(max(all_actual_weights_inf, default=1), max(all_predicted_weights_inf, default=1))
        # Axes are capped at 300 so a few extreme values do not squash the plot.
        graph_max_val = min(data_max_val, 300)
        if graph_max_val <= min_val:
            # Degenerate range (e.g. every value is above the cap): use the data
            # maximum so the axes are not inverted or collapsed.
            graph_max_val = data_max_val

        plt.plot([min_val, graph_max_val], [min_val, graph_max_val], 'r--', lw=2, label='Ideal Prediction (y=x)')
        plt.xlim(min_val, graph_max_val)
        plt.ylim(min_val, graph_max_val)

        plt.xlabel("Actual Weight (kg)", fontsize=12)
        plt.ylabel("Predicted Weight (kg)", fontsize=12)
        plt.title(f"Actual vs. Predicted Weights on Inference Set ({len(all_actual_weights_inf)} samples)", fontsize=14)
        plt.legend(fontsize=10)
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()

        scatter_plot_path = Path("inference_visualizations") / 'actual_vs_predicted_weights_scatter.png'
        scatter_plot_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(scatter_plot_path)
        print(f"\nSaved actual vs. predicted weights scatter plot to {scatter_plot_path}")
        plt.show()
        plt.close()
    else:
        print("\nNot enough data to generate actual vs. predicted weights scatter plot.")

# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
