In [1]:
import sys
import os
sys.path.append(os.path.abspath("../src"))
sys.path.append(os.path.abspath("../"))

In [10]:
import os
import numpy as np
import shutil
import torch
from IPython.display import clear_output
import os
import glob
from tqdm import tqdm
from PIL import Image
from utils import load_tiff_image,load_mask,load_and_normalize_tiff
from config import (
    RAW_DATA_PATH,
    PROCESSED_DATA_PATH,
    PROCESSED_FULLY_CLOUDED_DATA_PATH,
    PROCESSED_FREE_CLOUDED_DATA_PATH,
    PROCESSED_PARTIALLY_CLOUDED_DATA_PATH,
    PROCESSED_FULLY_CLOUDED_MASK_PATH,
    PROCESSED_FREE_CLOUDED_MASK_PATH,
    PROCESSED_PARTIALLY_CLOUDED_MASK_PATH,
    PROCESSED_MISCLASSIFIED_DATA_PATH,
    PROCESSED_MISCLASSIFIED_MASK_PATH,
    create_dirs,
    REVIEW_DIR,
    MODEL_PATH
)
from visualization import plot_image_and_mask,plot_image_and_mask_and_prediction
from models.unet import UNet

# Set matplotlib to display inline
%matplotlib inline

In [4]:
# Prepare output folders via config.create_dirs (presumably creates the processed/review
# directories referenced below — see config for the exact list).
create_dirs()
# Feature flag read by relabel_misclassified: when False, every misclassified pair is queued
# for manual review instead of being auto-relabelled from the model prediction.
RELABEL = False

In [5]:
def calculate_cloud_coverage(mask):
    """Return the percentage (0-100) of mask pixels whose value is > 0.

    Args:
        mask: Array-like cloud mask of any shape; every value > 0 counts as cloud.

    Returns:
        Cloud coverage in percent. An empty mask yields 0.0 rather than NaN
        (dividing by a zero pixel count would otherwise warn and return NaN).
    """
    mask = np.asarray(mask)
    if mask.size == 0:
        return 0.0
    return np.count_nonzero(mask > 0) / mask.size * 100

def classify_cloud_coverage(coverage, full_threshold=80):
    """Map a cloud-coverage percentage to a category label.

    Args:
        coverage: Cloud coverage in percent (as returned by calculate_cloud_coverage).
        full_threshold: Coverage strictly above this value counts as fully clouded.
            Defaults to 80, the value previously hard-coded here.

    Returns:
        "cloud_free" when coverage is exactly 0, "fully_clouded" when coverage exceeds
        ``full_threshold``, otherwise "partially_clouded".
    """
    if coverage == 0:
        return "cloud_free"
    if coverage > full_threshold:
        return "fully_clouded"
    return "partially_clouded"

In [6]:
def copy_to_processed(image_path, mask_path, category):
    """Copy an image/mask pair into the processed folders that belong to ``category``.

    Known categories are "cloud_free", "fully_clouded" and "partially_clouded"; any other
    value (e.g. "misclassified") routes the pair to the misclassified folders. Both files
    keep their original basenames and are copied with shutil.copy2 (metadata preserved,
    existing destination files replaced).
    """
    routing = {
        "cloud_free": (PROCESSED_FREE_CLOUDED_DATA_PATH, PROCESSED_FREE_CLOUDED_MASK_PATH),
        "fully_clouded": (PROCESSED_FULLY_CLOUDED_DATA_PATH, PROCESSED_FULLY_CLOUDED_MASK_PATH),
        "partially_clouded": (PROCESSED_PARTIALLY_CLOUDED_DATA_PATH, PROCESSED_PARTIALLY_CLOUDED_MASK_PATH),
    }
    fallback = (PROCESSED_MISCLASSIFIED_DATA_PATH, PROCESSED_MISCLASSIFIED_MASK_PATH)
    image_dir, mask_dir = routing.get(category, fallback)

    for source, target_dir in ((image_path, image_dir), (mask_path, mask_dir)):
        shutil.copy2(source, os.path.join(target_dir, os.path.basename(source)))

In [7]:
def is_already_processed(image_filename, search_dirs=None):
    """Return True if ``image_filename`` already exists in any processed data folder.

    Args:
        image_filename: Basename of the image to look for.
        search_dirs: Folders to search. Defaults to the cloud-free, fully-clouded,
            partially-clouded AND misclassified data folders. The misclassified folder
            was previously omitted, so re-running validate_images re-validated those images,
            double-counted them and could copy them into a second category.

    Returns:
        True as soon as one folder contains the file, False otherwise.
    """
    if search_dirs is None:
        search_dirs = (
            PROCESSED_FREE_CLOUDED_DATA_PATH,
            PROCESSED_FULLY_CLOUDED_DATA_PATH,
            PROCESSED_PARTIALLY_CLOUDED_DATA_PATH,
            PROCESSED_MISCLASSIFIED_DATA_PATH,
        )
    return any(os.path.exists(os.path.join(dir_path, image_filename)) for dir_path in search_dirs)

In [None]:
def _bucket_pairs_by_category(image_paths, mask_paths, categories):
    """Group (image, mask) path pairs by the coverage category of their mask, loading each mask once."""
    import collections

    queues = {category: collections.deque() for category in categories}
    for img_path, mask_path in zip(image_paths, mask_paths):
        category = classify_cloud_coverage(calculate_cloud_coverage(load_mask(mask_path)))
        if category in queues:
            queues[category].append((img_path, mask_path))
    return queues


def _model_pixel_accuracy(model, device, img_path, true_mask, model_threshold):
    """Return the fraction of pixels where the thresholded model prediction equals ``true_mask > 0``."""
    image_np = load_and_normalize_tiff(img_path)
    image_tensor = torch.from_numpy(image_np).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(image_tensor)

    # The model output is treated as logits: sigmoid -> probability -> binary mask.
    pred_mask_np = torch.sigmoid(output).squeeze(0).squeeze(0).cpu().numpy()
    binary_pred_mask = (pred_mask_np > model_threshold).astype(np.uint8)
    true_binary = (true_mask > 0).astype(np.uint8)
    return np.sum(true_binary == binary_pred_mask) / true_mask.size


def validate_images(image_paths, mask_paths, model=None, device=None, samples_per_class=300, 
                   model_threshold=0.5, similarity_threshold=0.85):
    """Sort raw image/mask pairs into processed category folders via model or manual validation.

    Pairs are bucketed by the cloud coverage of their ground-truth mask and visited
    round-robin (cloud_free -> partially_clouded -> fully_clouded) so the classes fill
    evenly. Pairs already present in a processed folder are skipped. For each remaining pair:

    * if ``model`` and ``device`` are both given, the model prediction is binarized at
      ``model_threshold`` and compared pixel-wise with the ground-truth mask. Pixel accuracy
      >= ``similarity_threshold`` copies the pair into its category folder; lower accuracy
      copies it into the misclassified folder;
    * otherwise, or if the model step raises, the pair is plotted and the user answers
      y (accept) / n (misclassified) / skip / quit.

    Args:
        image_paths: Image file paths; ``image_paths[i]`` pairs with ``mask_paths[i]``.
        mask_paths: Ground-truth mask file paths.
        model: Optional segmentation model returning logits for a batch of one image.
        device: torch device the model runs on; required together with ``model``.
        samples_per_class: A category stops accepting pairs once this many were accepted.
        model_threshold: Probability threshold that binarizes the model output.
        similarity_threshold: Minimum pixel accuracy for automatic acceptance.

    Returns:
        dict with the number of pairs copied into "cloud_free", "fully_clouded",
        "partially_clouded" and "misclassified" during this call.
    """
    counts = {
        "cloud_free": 0,
        "fully_clouded": 0,
        "partially_clouded": 0,
        "misclassified": 0  # For failed validations
    }
    model_available = model is not None and device is not None

    categories = ["cloud_free", "partially_clouded", "fully_clouded"]
    queues = _bucket_pairs_by_category(image_paths, mask_paths, categories)

    def has_work(category):
        return counts[category] < samples_per_class and len(queues[category]) > 0

    # Loop only while some category can still accept a pair AND has pairs left. Checking
    # "any queue non-empty" alone spun forever once a full category still had leftovers.
    while any(has_work(c) for c in categories):
        for category in categories:
            if not has_work(category):
                continue

            img_path, mask_path = queues[category].popleft()

            if is_already_processed(os.path.basename(img_path)):
                continue

            true_mask = load_mask(mask_path)
            coverage = calculate_cloud_coverage(true_mask)
            original_category = classify_cloud_coverage(coverage)

            # Automated validation attempt
            validated = False
            if model_available:
                try:
                    accuracy = _model_pixel_accuracy(model, device, img_path, true_mask, model_threshold)
                    if accuracy >= similarity_threshold:
                        copy_to_processed(img_path, mask_path, original_category)
                        counts[original_category] += 1
                        print(f"✅ Auto-validated {os.path.basename(img_path)} as {original_category} (Accuracy: {accuracy:.2f})")
                    else:
                        copy_to_processed(img_path, mask_path, "misclassified")
                        counts["misclassified"] += 1
                        print("❌ Moved to misclassified")
                    validated = True
                except Exception as e:
                    print(f"⚠ Model failed on {img_path}: {str(e)}")

            # Manual validation if the model is unavailable or failed
            if not validated:
                image = load_tiff_image(img_path)
                plot_image_and_mask(image, true_mask, f"True: {original_category} ({coverage:.2f}%)")

                response = input(f"Validate as {original_category}? (y/n/skip/quit): ").strip().lower()

                if response == 'quit':
                    print("Early termination. Current counts:", counts)
                    return counts
                if response == 'skip':
                    continue
                if response == 'y':
                    copy_to_processed(img_path, mask_path, original_category)
                    counts[original_category] += 1
                else:  # 'n' or any other input
                    copy_to_processed(img_path, mask_path, "misclassified")
                    counts["misclassified"] += 1
                    print("❌ Moved to misclassified")

            print(f"Current counts: {counts}")
            clear_output(wait=True)

            # Early exit if all categories are filled
            if all(counts[c] >= samples_per_class for c in categories):
                break

    print("\n=== Final Validation Counts ===")
    for k, v in counts.items():
        print(f"{k:>20}: {v}")

    return counts

In [11]:
# Try to load the trained UNet for automated validation. Any failure leaves `model` as None,
# which makes validate_images fall back to manual validation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = None

if not os.path.exists(MODEL_PATH):
    print("No model found at specified path, using manual validation")
else:
    try:
        model = UNet(n_channels=4, n_classes=1)
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
        model.to(device)
        model.eval()
        print("Model loaded successfully for automated validation")
    except Exception as e:
        print(f"Failed to load model, will use manual validation only: {str(e)}")
        model = None

# Collect .tif files from RAW_DATA_PATH/data and RAW_DATA_PATH/masks. Pairs are matched by
# sorted position, so both folders must list their files in the same order.
image_files = sorted(name for name in os.listdir(os.path.join(RAW_DATA_PATH, "data")) if name.endswith(".tif"))
mask_files = sorted(name for name in os.listdir(os.path.join(RAW_DATA_PATH, "masks")) if name.endswith(".tif"))

assert len(image_files) == len(mask_files), "Mismatch between number of images and masks"

image_paths = [os.path.join(RAW_DATA_PATH, "data", name) for name in image_files]
mask_paths = [os.path.join(RAW_DATA_PATH, "masks", name) for name in mask_files]

# Run validation; the returned counts dict is displayed as the cell output.
validate_images(
    image_paths=image_paths,
    mask_paths=mask_paths,
    model=model,
    device=device,
    samples_per_class=10000,
    model_threshold=0.35,
    similarity_threshold=0.85   
)


=== Final Validation Counts ===
          cloud_free: 514
       fully_clouded: 4160
   partially_clouded: 4106
       misclassified: 1792


{'cloud_free': 514,
 'fully_clouded': 4160,
 'partially_clouded': 4106,
 'misclassified': 1792}

In [None]:
def relabel_misclassified(data_folder, mask_folder, model, device, model_threshold=0.5, 
                         confidence_threshold=0.3, agreement_threshold=0.2, relabel=None):
    """Re-examine misclassified pairs with the model and auto-relabel them or queue them for review.

    Image and mask files (``*.tif``) are paired by sorted position. For each pair, the model's
    sigmoid output is binarized at ``model_threshold`` and two scores are computed:

    * confidence: mean over pixels of |p - 0.5| (0 = every pixel at 0.5, 0.5 = fully saturated);
    * pixel_agreement: fraction of pixels where the prediction equals ``mask > 0``.

    A pair is "extreme" when confidence < ``confidence_threshold`` or agreement <
    ``agreement_threshold``. Non-extreme pairs are relabelled with the predicted mask via
    replace_and_move when relabelling is enabled. Every other pair is moved to the review
    queue via save_for_review.

    Args:
        data_folder: Folder holding the misclassified images.
        mask_folder: Folder holding the corresponding masks.
        model: Segmentation model returning logits for a batch of one image.
        device: torch device the model runs on.
        model_threshold: Probability threshold that binarizes the prediction.
        confidence_threshold: Minimum mean |p - 0.5| for a non-extreme case.
        agreement_threshold: Minimum pixel agreement for a non-extreme case.
        relabel: Enable auto-relabelling. None (default) uses the notebook-level RELABEL flag.

    Returns:
        dict with 'auto_relabeled', 'needs_review' and per-category 'categories' counts.

    Raises:
        ValueError: if the folders hold different numbers of .tif files. Positional pairing
            would then match images with the wrong masks and move/overwrite the wrong files.
    """
    if relabel is None:
        relabel = RELABEL

    image_paths = sorted(glob.glob(os.path.join(data_folder, '*.tif')))
    mask_paths = sorted(glob.glob(os.path.join(mask_folder, '*.tif')))
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f"Found {len(image_paths)} images but {len(mask_paths)} masks; "
            "refusing to pair them by position"
        )

    results = {
        'auto_relabeled': 0,
        'needs_review': 0,
        'categories': {
            'cloud_free': 0,
            'partially_clouded': 0,
            'fully_clouded': 0
        }
    }

    for img_path, mask_path in tqdm(zip(image_paths, mask_paths), total=len(image_paths), desc="Processing"):
        image = load_and_normalize_tiff(img_path)
        true_mask = load_mask(mask_path)
        true_binary = (true_mask > 0).astype(np.uint8)

        with torch.no_grad():
            pred_prob = model(torch.from_numpy(image).unsqueeze(0).to(device)).sigmoid().squeeze().cpu().numpy()

        pred_binary = (pred_prob > model_threshold).astype(np.uint8)

        confidence = np.abs(pred_prob - 0.5).mean()
        pixel_agreement = np.mean(true_binary == pred_binary)
        is_extreme = (confidence < confidence_threshold) or (pixel_agreement < agreement_threshold)

        if not is_extreme and relabel:
            # Auto-relabel confident cases using the predicted mask
            new_category = classify_cloud_coverage(calculate_cloud_coverage(pred_binary))
            if replace_and_move(img_path, mask_path, pred_binary, new_category):
                results['auto_relabeled'] += 1
                results['categories'][new_category] += 1
        else:
            # Queue for manual review (original mask plus prediction)
            if save_for_review(img_path, mask_path, true_mask, pred_binary):
                results['needs_review'] += 1

    print("\n=== Results ===")
    print(f"Auto-relabeled: {results['auto_relabeled']}")
    print(f"Needs review: {results['needs_review']}")
    print("Category distribution:")
    for cat, count in results['categories'].items():
        print(f"  {cat}: {count}")

    return results

def replace_and_move(img_path, mask_path, new_mask, new_category):
    """Move an image into its new category folder and replace its mask with ``new_mask``.

    Destinations are PROCESSED_DATA_PATH/data/<new_category>/<image name> and
    PROCESSED_DATA_PATH/masks/<new_category>/<old mask name>. Both folders are created if
    needed. The order is deliberate: save the new mask, move the image, and only then delete
    the old mask. The previous order deleted the old mask first, so any later failure lost it.

    Returns:
        True on success. False on failure, after a best-effort rollback of the image move and
        the newly written mask.
    """
    new_img_path = None
    new_mask_path = None
    new_mask_written = False
    try:
        base_dir = PROCESSED_DATA_PATH
        new_img_dir = os.path.join(base_dir, 'data', new_category)
        new_mask_dir = os.path.join(base_dir, 'masks', new_category)
        os.makedirs(new_img_dir, exist_ok=True)
        os.makedirs(new_mask_dir, exist_ok=True)

        new_img_path = os.path.join(new_img_dir, os.path.basename(img_path))
        new_mask_path = os.path.join(new_mask_dir, os.path.basename(mask_path))

        # save_mask signals failure by returning False rather than raising.
        if not save_mask(new_mask, new_mask_path):
            raise OSError(f"could not save new mask to {new_mask_path}")
        new_mask_written = True

        shutil.move(img_path, new_img_path)
    except Exception as e:
        print(f"Error processing {img_path}: {str(e)}")
        # Attempt to undo partial operations
        try:
            if new_img_path and os.path.exists(new_img_path) and not os.path.exists(img_path):
                shutil.move(new_img_path, img_path)
            if (new_mask_written and os.path.exists(new_mask_path)
                    and os.path.abspath(new_mask_path) != os.path.abspath(mask_path)):
                os.remove(new_mask_path)
        except Exception as undo_error:
            print(f"Warning: rollback failed for {img_path}: {str(undo_error)}")
        return False

    # Remove the old mask last (best effort). Skip it when the new mask was written to the
    # same path, otherwise we would delete the replacement we just saved.
    if os.path.abspath(mask_path) != os.path.abspath(new_mask_path) and os.path.exists(mask_path):
        try:
            os.remove(mask_path)
        except Exception as e:
            print(f"Warning: Could not remove old mask {mask_path}: {str(e)}")

    return True

def save_mask(mask, path):
    """Write ``mask`` to ``path`` as a TIFF image.

    Non-uint8 masks are multiplied by 255 and cast to uint8 (so 0/1 or probability masks
    become 0..255). uint8 masks are written unchanged. 3-D arrays are squeezed before saving.

    Returns:
        True on success, False (after printing the error) on any exception.
    """
    try:
        pixels = mask if mask.dtype == np.uint8 else (mask * 255).astype(np.uint8)
        if pixels.ndim == 3:
            pixels = pixels.squeeze()
        Image.fromarray(pixels).save(path, format='TIFF')
    except Exception as exc:
        print(f"Error saving mask to {path}: {str(exc)}")
        return False
    return True

def save_for_review(img_path, mask_path, true_mask, pred_mask):
    """Move an image/mask pair into REVIEW_DIR and store the model prediction next to it.

    For an image named <stem>.tif this produces, inside REVIEW_DIR:
    <stem>_image.tif (moved image), <stem>_true_mask.tif (moved original mask) and
    <stem>_pred_mask.tif (new file written from ``pred_mask``).

    Args:
        img_path: Image to move into the review queue.
        mask_path: Its current mask file (moved, not copied).
        true_mask: Unused; kept for interface compatibility.
        pred_mask: Predicted mask array to save via save_mask.

    Returns:
        True on success. False on failure, after trying to move the files back.
    """
    new_img_path = None
    new_true_mask_path = None
    try:
        base_name = os.path.splitext(os.path.basename(img_path))[0]
        new_img_path = os.path.join(REVIEW_DIR, f"{base_name}_image.tif")
        new_true_mask_path = os.path.join(REVIEW_DIR, f"{base_name}_true_mask.tif")
        pred_mask_path = os.path.join(REVIEW_DIR, f"{base_name}_pred_mask.tif")

        os.makedirs(REVIEW_DIR, exist_ok=True)
        shutil.move(img_path, new_img_path)
        shutil.move(mask_path, new_true_mask_path)

        # save_mask returns False instead of raising; without the predicted mask the review
        # item is unusable (manual_review loads it), so treat that as a failure and roll back.
        if not save_mask(pred_mask, pred_mask_path):
            raise OSError(f"could not save predicted mask to {pred_mask_path}")

        return True

    except Exception as e:
        print(f"Error moving to review queue {img_path}: {str(e)}")

        # Attempt to undo partial moves
        if new_img_path and os.path.exists(new_img_path) and not os.path.exists(img_path):
            shutil.move(new_img_path, img_path)
        if new_true_mask_path and os.path.exists(new_true_mask_path) and not os.path.exists(mask_path):
            shutil.move(new_true_mask_path, mask_path)

        return False

def _file_reviewed_item(review_dir, img_path, base_name, selected_mask, new_category, review_mask_paths):
    """Move one reviewed image and write its chosen mask into the category folders.

    Targets are <root>/data/<new_category>/<base_name>.tif and
    <root>/masks/<new_category>/<base_name>.tif, where <root> is
    os.path.dirname(os.path.dirname(review_dir)). The item's review mask files are deleted only
    after both writes succeeded. Returns True on success, False after a best-effort rollback.
    """
    # NOTE(review): the root comes from review_dir's location, while replace_and_move uses
    # PROCESSED_DATA_PATH. Confirm both resolve to the same tree; a trailing slash on
    # review_dir shifts this root one level up, because dirname() then only strips the slash.
    base_dir = os.path.dirname(os.path.dirname(review_dir))
    new_img_dir = os.path.join(base_dir, 'data', new_category)
    new_mask_dir = os.path.join(base_dir, 'masks', new_category)
    os.makedirs(new_img_dir, exist_ok=True)
    os.makedirs(new_mask_dir, exist_ok=True)

    # Keep the full original stem. Previously base_name.split('_')[0] truncated names containing
    # underscores, so different images collided and overwrote each other.
    original_name = f"{base_name}.tif"
    new_img_path = os.path.join(new_img_dir, original_name)
    new_mask_path = os.path.join(new_mask_dir, original_name)

    mask_written = False
    try:
        # Write the mask first: save_mask returns False instead of raising. Previously the image
        # was moved and the review files deleted even when the mask could not be saved.
        if not save_mask(selected_mask, new_mask_path):
            raise OSError(f"could not save mask to {new_mask_path}")
        mask_written = True
        shutil.move(img_path, new_img_path)
    except Exception as e:
        print(f"Error during file operations: {str(e)}")
        # Try to undo partial operations
        if os.path.exists(new_img_path) and not os.path.exists(img_path):
            shutil.move(new_img_path, img_path)
        if mask_written and os.path.exists(new_mask_path):
            os.remove(new_mask_path)
        return False

    # Delete exactly this item's review masks. The old glob(f"{base_name}*") also matched other
    # items sharing the prefix (e.g. "scene_1" vs "scene_10_image.tif") and deleted them.
    for leftover in review_mask_paths:
        if os.path.exists(leftover):
            try:
                os.remove(leftover)
            except OSError as e:
                print(f"Warning: could not remove review file {leftover}: {str(e)}")
    return True


def manual_review(review_dir):
    """Interactively resolve review-queue items by choosing the original or the predicted mask.

    Each item is the triple written by save_for_review: <stem>_image.tif,
    <stem>_true_mask.tif and <stem>_pred_mask.tif. For every item the image and both masks are
    plotted and the user answers o (original mask), p (predicted mask), s (skip) or q (quit).
    On o/p the category is recomputed from the chosen mask's coverage and the item is filed
    via _file_reviewed_item.
    """
    image_paths = sorted(glob.glob(os.path.join(review_dir, '*_image.tif')))

    for img_path in tqdm(image_paths, desc="Manual Review"):
        # Strip only the trailing "_image" tag added by save_for_review; str.replace removed every
        # occurrence, breaking stems that themselves contain "_image".
        stem = os.path.splitext(os.path.basename(img_path))[0]
        base_name = stem[:-len('_image')] if stem.endswith('_image') else stem

        true_mask_path = os.path.join(review_dir, f"{base_name}_true_mask.tif")
        pred_mask_path = os.path.join(review_dir, f"{base_name}_pred_mask.tif")

        image = load_tiff_image(img_path)
        true_mask = load_mask(true_mask_path)
        pred_mask = load_mask(pred_mask_path)

        true_coverage = calculate_cloud_coverage(true_mask)
        pred_coverage = calculate_cloud_coverage(pred_mask)

        plot_image_and_mask_and_prediction(image, true_mask, pred_mask, f"Reviewing {base_name}")

        while True:
            decision = input("Use which mask? (o=original, p=predicted, s=skip, q=quit): ").strip().lower()

            if decision == 'q':
                return
            if decision == 's':
                break
            if decision not in ('o', 'p'):
                print("Invalid input. Use o/p/s/q")
                continue

            use_original = decision == 'o'
            selected_mask = true_mask if use_original else pred_mask
            coverage = true_coverage if use_original else pred_coverage
            new_category = classify_cloud_coverage(coverage)

            if _file_reviewed_item(review_dir, img_path, base_name, selected_mask,
                                   new_category, (true_mask_path, pred_mask_path)):
                print(f"Moved to {new_category} using {'original' if decision == 'o' else 'predicted'} mask")
            break

In [None]:
# Load the trained UNet. Relabelling needs a model, so unlike the validation cell there is no
# manual fallback and a missing/invalid MODEL_PATH raises here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(n_channels=4, n_classes=1).to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

# Auto-process misclassified images from the same config folders validate_images writes to
# (previously hardcoded relative paths that could silently diverge from config).
results = relabel_misclassified(
    data_folder=PROCESSED_MISCLASSIFIED_DATA_PATH,
    mask_folder=PROCESSED_MISCLASSIFIED_MASK_PATH,
    model=model,
    device=device
)

In [None]:
# Manual review if needed: open the interactive reviewer only when relabel_misclassified
# queued at least one item (it counts successful save_for_review calls into REVIEW_DIR).
if results['needs_review'] > 0:
    manual_review(REVIEW_DIR)