In [None]:
# 1. MOUNT GOOGLE DRIVE
from google.colab import drive
drive.mount('/content/drive')

# 2. INSTALL/UPGRADE REQUIRED LIBRARIES (pandas, folium, rasterio, python-docx)
!pip install -U pandas
!pip install folium
!pip install rasterio
!pip install python-docx

In [None]:
# 2. IMPORT LIBRARIES
import rasterio
import numpy as np
import ipywidgets as widgets
from ipywidgets import HBox, VBox, Layout, HTML
from IPython.display import display, clear_output
import pandas as pd
import datetime
import getpass
import folium
from sklearn.metrics import confusion_matrix, cohen_kappa_score
from google.colab import files
import docx
from docx.shared import Pt
import random

#==============================================================================
# A. CONFIGURATION - !!! EDIT THIS SECTION !!!
#==============================================================================

# --- Paths to the input rasters in Google Drive ---
# The base image's georeferencing is used to turn sampled pixel positions into
# map coordinates; the three classified images are the ones being assessed.
BASE_IMG_PATH = '/content/drive/My Drive/Accuracy_assessment/Singrauli_Merged_Image.tif'
CNN64_IMG_PATH = '/content/drive/My Drive/Accuracy_assessment/cnn64_them.img'  # reported as the "CNN-64" model
CNN10_IMG_PATH = '/content/drive/My Drive/Accuracy_assessment/cnn10_them.img'  # reported as "CNN-10"; also the image the sampling plan is drawn from
RF_IMG_PATH = '/content/drive/My Drive/Accuracy_assessment/rf_them.img'  # reported as the "Random Forest" model

# --- Class labels ---
# Maps each integer class code (matched against band 1 of the classified
# images) to its display name. The codes also become the button captions
# ("<code>: <name>") during interactive collection.
CLASS_LABELS = {
    1: 'Water',
    2: 'Agriculture & Vegetation',
    3: 'Settlement',
    4: 'Mining',
    5: 'Barren Land',
    6: 'Forest',
}

#==============================================================================
# C. STRATIFIED SAMPLING PLAN
#==============================================================================
def create_stratified_sampling_plan(image_path, class_labels, points_per_class):
    """
    Generate an ordered, stratified sampling plan from a classified raster.

    For every class in ``class_labels``, up to ``points_per_class`` distinct
    pixels whose band-1 value equals the class code are drawn at random
    (without replacement) using the ``random`` module.

    Args:
        image_path: Path of the raster to sample; only band 1 is read.
        class_labels: Mapping of class code -> class name.
        points_per_class: Number of pixels to draw for each class. If a class
            has fewer pixels than this, all of its pixels are used and a
            warning is printed.

    Returns:
        A list of ``(y, x, class_name)`` tuples (row, column, label), grouped
        by class in the iteration order of ``class_labels``.
    """
    print("Creating stratified sampling plan (grouped by class)...")
    sampling_plan = []

    # The band is fully loaded into memory, so the file can be closed right away.
    with rasterio.open(image_path) as src:
        image = src.read(1)

    for class_code, class_name in class_labels.items():
        print(f"  - Finding pixels for class: {class_name}")
        rows, cols = np.where(image == class_code)
        available = len(rows)

        if available < points_per_class:
            print(f"    Warning: Class '{class_name}' has only {available} pixels. Using all of them.")
        points_to_sample = min(available, points_per_class)

        if points_to_sample <= 0:
            continue

        # Sample positions into the coordinate arrays instead of building a
        # Python list of (row, col) tuples for every pixel of the class, which
        # is very costly on large rasters. For a given random seed this picks
        # exactly the same pixels as sampling the materialised list would.
        chosen = random.sample(range(available), points_to_sample)
        sampling_plan.extend((rows[i], cols[i], class_name) for i in chosen)

    print("Sampling plan created successfully.")
    return sampling_plan

#==============================================================================
# C2. NEW FUNCTION: RANDOM SAMPLING PLAN
#==============================================================================
def create_random_sampling_plan(image_path, class_labels, total_points):
    """
    Generate a simple random sampling plan from a classified raster.

    Pixels are drawn uniformly at random, without replacement, from all
    band-1 pixels whose value is one of the codes in ``class_labels``.

    Args:
        image_path: Path of the raster to sample; only band 1 is read.
        class_labels: Mapping of class code -> class name. Pixels with any
            other value are never sampled.
        total_points: Number of points requested. If the raster holds fewer
            valid pixels than this, every valid pixel is returned and a
            warning is printed.

    Returns:
        A list of ``(y, x, class_name)`` tuples (row, column, label) in random
        order. Empty when ``total_points`` is not positive or no pixel matches.
    """
    print(f"Creating random sampling plan for {total_points} total points...")

    with rasterio.open(image_path) as src:
        image = src.read(1)

    # Flat positions of every pixel that belongs to a defined class. Sampling
    # directly from these always terminates; rejection sampling could loop
    # forever when fewer valid pixels exist than were requested.
    valid_positions = np.flatnonzero(np.isin(image, list(class_labels)))
    points_to_sample = max(0, min(total_points, valid_positions.size))

    if points_to_sample < total_points:
        print("Warning: Could not find enough valid points. Returning available points.")

    width = image.shape[1]
    sampling_plan = []
    for pick in random.sample(range(valid_positions.size), points_to_sample):
        # Convert the flat (row-major) position back to row/column.
        y, x = divmod(int(valid_positions[pick]), width)
        # .item() turns the NumPy scalar into a plain Python number for the lookup.
        sampling_plan.append((y, x, class_labels[image[y, x].item()]))

    print("Random sampling plan created successfully.")
    return sampling_plan

#==============================================================================
# D. ACCURACY REPORT GENERATION FUNCTION
#==============================================================================
def generate_accuracy_report(df, model_name, classified_col, ref_col, class_labels, image_file_path):
    """
    Build a plain-text accuracy assessment report for one classified image.

    The report contains a header, the error (confusion) matrix laid out with
    classified data in rows and reference data in columns, per-class
    producer's and user's accuracy, overall accuracy, overall kappa and a
    conditional kappa per class.

    Args:
        df: DataFrame with one row per sampled point.
        model_name: Name of the model, written into the report header.
        classified_col: Column of ``df`` holding the model's class codes.
        ref_col: Column of ``df`` holding the reference class codes.
        class_labels: Mapping of class code -> class name. Only these codes
            enter the matrix; points with any other code are ignored.
        image_file_path: Path of the classified image, written into the header.

    Returns:
        A list of strings; each entry is one (possibly multi-line) report
        fragment, meant to be printed or joined with newlines.
    """
    report_lines = []
    class_names = {0: "Background", **class_labels}
    all_class_codes = sorted(class_labels)
    matrix_labels = [f"   {class_names[code]}" for code in all_class_codes]

    report_lines.append("CLASSIFICATION ACCURACY ASSESSMENT REPORT")
    report_lines.append("-----------------------------------------")
    report_lines.append(f"Model Name : {model_name}")
    report_lines.append(f"Image File : {image_file_path}")
    report_lines.append(f"User Name  : {getpass.getuser()}")
    report_lines.append(f"Date       : {datetime.datetime.now(datetime.timezone.utc).astimezone().strftime('%a %b %d %H:%M:%S %Y')}")
    report_lines.append("\n\n")

    y_true = df[ref_col]
    y_pred = df[classified_col]
    # sklearn returns rows = reference (true) and columns = classified
    # (predicted). The report is laid out the other way round (classified in
    # rows, reference in columns), so transpose; otherwise the printed matrix,
    # the totals and the producer's/user's accuracies all end up swapped.
    cm = confusion_matrix(y_true, y_pred, labels=all_class_codes).T

    cm_df = pd.DataFrame(cm, index=matrix_labels, columns=matrix_labels)

    report_lines.append("ERROR MATRIX")
    report_lines.append("------------")
    report_lines.append("                      Reference Data")
    report_lines.append("                      --------------")

    cm_string = cm_df.to_string(header=True, index=True)
    report_lines.append("Classified Data\n---------------")
    report_lines.append(cm_string)

    row_totals = cm.sum(axis=1)  # points per class as classified by the model
    col_totals = cm.sum(axis=0)  # points per class in the reference data

    report_lines.append("\n" + pd.DataFrame({
        'Row Total': row_totals,
        'Column Total': col_totals
    }, index=matrix_labels).T.to_string())

    report_lines.append("\n\n\t\t----- End of Error Matrix -----\n\n")

    report_lines.append("ACCURACY TOTALS")
    report_lines.append("----------------")

    num_correct = np.diag(cm)
    # Producer's accuracy = correct / reference total; user's accuracy =
    # correct / classified total. Classes with a zero total report 0.
    producers_accuracy = np.divide(num_correct, col_totals, out=np.zeros_like(num_correct, dtype=float), where=col_totals != 0)
    users_accuracy = np.divide(num_correct, row_totals, out=np.zeros_like(num_correct, dtype=float), where=row_totals != 0)

    acc_df = pd.DataFrame({
        'Class Name': [class_names[code] for code in all_class_codes],
        'Reference Totals': col_totals,
        'Classified Totals': row_totals,
        'Number Correct': num_correct,
        'Producers Accuracy': [f"{acc:.2%}" for acc in producers_accuracy],
        'Users Accuracy': [f"{acc:.2%}" for acc in users_accuracy]
    })

    report_lines.append(acc_df.to_string(index=False))

    total_correct = np.sum(num_correct)
    total_points = np.sum(cm)
    overall_accuracy = total_correct / total_points if total_points > 0 else 0
    report_lines.append(f"\n\nOverall Classification Accuracy =   {overall_accuracy:.2%}")
    report_lines.append("\n\n\t\t----- End of Accuracy Totals -----\n\n")

    report_lines.append("KAPPA (K^) STATISTICS")
    report_lines.append("---------------------")

    overall_kappa = cohen_kappa_score(y_true, y_pred, labels=all_class_codes)
    report_lines.append(f"Overall Kappa Statistics = {overall_kappa:.4f}\n")

    report_lines.append("Conditional Kappa for each Category.")
    report_lines.append("------------------------------------")
    kappa_data = []
    for i, code in enumerate(all_class_codes):
        observed = num_correct[i]
        # Agreement expected by chance for this class.
        chance = (row_totals[i] * col_totals[i]) / total_points if total_points > 0 else 0
        # Conditional kappa relative to the classified (row) total:
        # (observed - chance) / (classified total - chance).
        denominator = row_totals[i] - chance
        k_cond = (observed - chance) / denominator if denominator != 0 else 0
        kappa_data.append({'Class Name': class_names[code], 'Kappa': f"{k_cond:.4f}"})

    kappa_df = pd.DataFrame(kappa_data)
    report_lines.append(kappa_df.to_string(index=False))
    report_lines.append("\n\n\t\t----- End of Kappa Statistics -----\n")

    return report_lines


#==============================================================================
# E. SCRIPT LOGIC - Interactive Collection
#==============================================================================
# One dict per classified point, appended by on_button_clicked().
reference_data = []
# (row, col) of the point currently shown to the user; set by next_point().
current_point = None

# --- Open all datasets ---
# They stay open for the whole interactive session and are closed by
# next_point() once the last point has been collected.
_opened_datasets = []
try:
    for _img_path in (BASE_IMG_PATH, CNN64_IMG_PATH, CNN10_IMG_PATH, RF_IMG_PATH):
        _opened_datasets.append(rasterio.open(_img_path))
except Exception as e:
    print(f"Error opening files: {e}")
    # Release whatever was opened before the failure, then stop: continuing
    # would only fail later with a confusing NameError on the dataset names.
    for _dataset in _opened_datasets:
        _dataset.close()
    raise

base_dataset, cnn64_dataset, cnn10_dataset, rf_dataset = _opened_datasets

# Widgets shared by the collection callbacks: the map/buttons are rendered
# into output_container, progress_label shows the running status text.
output_container = widgets.Output()
progress_label = widgets.Label()

def _read_class_at(dataset, x_geo, y_geo):
    """Return the band-1 value of ``dataset`` at a map coordinate.

    Only the single pixel is read from disk (1x1 window) instead of the whole
    band. A coordinate that falls outside the raster yields 0, the code the
    accuracy report names "Background"; plain array indexing would instead
    wrap negative indices around and return an unrelated pixel.
    """
    from rasterio.windows import Window

    row, col = dataset.index(x_geo, y_geo)
    if not (0 <= row < dataset.height and 0 <= col < dataset.width):
        return 0
    return dataset.read(1, window=Window(int(col), int(row), 1, 1))[0, 0]


def on_button_clicked(b):
    """Record the class chosen for the current point and advance.

    The reference class is parsed from the clicked button's caption
    ("<code>: <name>"). The class each model assigned at the same location is
    looked up and one record is appended to ``reference_data`` before the next
    point is shown.

    Args:
        b: The clicked ``widgets.Button``.
    """
    reference_class = int(b.description.split(':')[0])

    # (row, col) of the point being shown, as stored by next_point().
    y_orig, x_orig = current_point

    # Pixel position -> map coordinate in the base dataset's CRS.
    # NOTE(review): the sampling plan is built from the CNN-10 image but is
    # georeferenced here with base_dataset's transform; that is only correct
    # if both rasters share the same pixel grid - confirm.
    x_geo, y_geo = base_dataset.xy(y_orig, x_orig)

    # Each model is sampled through its own transform, so rasters with
    # different extents or resolutions are still read at the same location.
    reference_data.append({
        'ref_x': x_orig, 'ref_y': y_orig, 'reference_class': reference_class,
        'cnn64_class': _read_class_at(cnn64_dataset, x_geo, y_geo),
        'cnn10_class': _read_class_at(cnn10_dataset, x_geo, y_geo),
        'rf_class': _read_class_at(rf_dataset, x_geo, y_geo)
    })
    next_point()

# One entry per assessed model:
# (model name, Word heading, console banner line, DataFrame column, image path)
_MODEL_REPORTS = (
    ("CNN-64", 'CNN-64 MODEL REPORT',
     "                   CNN-64 MODEL REPORT                         ",
     'cnn64_class', CNN64_IMG_PATH),
    ("CNN-10", 'CNN-10 MODEL REPORT',
     "                   CNN-10 MODEL REPORT                         ",
     'cnn10_class', CNN10_IMG_PATH),
    ("Random Forest", 'RANDOM FOREST MODEL REPORT',
     "                 RANDOM FOREST MODEL REPORT                      ",
     'rf_class', RF_IMG_PATH),
)


def _pixel_to_lat_lon(dataset, row, col):
    """Return ``(lat, lon)`` in WGS84 for a pixel of ``dataset``.

    ``dataset.xy`` yields coordinates in the dataset's own CRS, which are only
    longitude/latitude when that CRS is geographic. They are reprojected to
    EPSG:4326 so the web map is centred correctly for projected rasters too.
    Without CRS information the raw coordinates are used unchanged.
    """
    x_geo, y_geo = dataset.xy(row, col)
    if dataset.crs is None:
        return y_geo, x_geo

    from rasterio.warp import transform as warp_transform

    lons, lats = warp_transform(dataset.crs, "EPSG:4326", [x_geo], [y_geo])
    return lats[0], lons[0]


def _write_and_download_reports(final_df, doc_filename):
    """Print every model's report, save them all to a Word file and download it."""
    doc = docx.Document()

    for position, (model_name, heading, banner, column, image_path) in enumerate(_MODEL_REPORTS):
        report_text = generate_accuracy_report(final_df, model_name, column, 'reference_class', CLASS_LABELS, image_path)
        # Every banner except the first is separated from the previous report.
        separator = "" if position == 0 else "\n\n"
        print(separator + "=" * 58 + "\n" + banner + "\n" + "=" * 58)
        for line in report_text:
            print(line)
        doc.add_heading(heading, level=1)
        paragraph = doc.add_paragraph()
        # Monospaced font keeps the text tables aligned in Word.
        paragraph.add_run('\n'.join(report_text)).font.name = 'Courier New'
        doc.add_page_break()

    doc.add_heading('Raw Reference Data', level=1)
    doc.add_paragraph("The raw data collected during the interactive session.")
    table = doc.add_table(rows=1, cols=len(final_df.columns))
    table.style = 'Table Grid'
    for j, col_name in enumerate(final_df.columns):
        table.cell(0, j).text = col_name
    for _, row in final_df.iterrows():
        row_cells = table.add_row().cells
        for j, val in enumerate(row):
            row_cells[j].text = str(val)

    doc.save(doc_filename)
    print(f"\n\nReports successfully generated and saved to '{doc_filename}'.")
    print("Downloading the Word document...")
    files.download(doc_filename)


def next_point():
    """Show the next point to classify, or finish the session.

    While points remain, the output area is redrawn with a satellite map
    centred on the next planned point and one button per class. Once every
    planned point has been classified, the datasets are closed and the
    accuracy reports are printed, saved to a Word document and downloaded.
    """
    global current_point

    with output_container:
        clear_output(wait=True)

        num_collected = len(reference_data)

        if num_collected >= num_points_total:
            progress_label.value = "Data collection complete! Generating reports..."
            for dataset in (base_dataset, cnn64_dataset, cnn10_dataset, rf_dataset):
                dataset.close()
            _write_and_download_reports(pd.DataFrame(reference_data), "3-model_accuracy_report.docx")
            return

        progress_label.value = f"Classifying point {num_collected + 1} of {num_points_total}..."

        y, x, expected_class_name = sampling_plan[num_collected]
        current_point = (y, x)
        lat, lon = _pixel_to_lat_lon(base_dataset, y, x)

        # Google satellite tiles as the backdrop for visual interpretation.
        m = folium.Map(
            location=[lat, lon],
            zoom_start=18,
            tiles='https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}',
            attr='Google',
            height=400,
            width=800
        )
        folium.Marker([lat, lon], popup="Classify this exact point").add_to(m)

        map_widget = HTML(m._repr_html_())
        buttons = []
        for class_val, class_name in CLASS_LABELS.items():
            # on_button_clicked() parses the class code back out of this caption.
            button = widgets.Button(description=f"{class_val}: {class_name}", button_style='primary', layout=Layout(width='auto'))
            button.on_click(on_button_clicked)
            buttons.append(button)
        button_box = HBox(buttons, layout=Layout(justify_content='center', flex_flow='wrap'))

        title_html = HTML(f"<h3><b>Point {num_collected + 1}/{num_points_total}</b>: Please verify this point for class: <b>{expected_class_name}</b></h3>")

        display(VBox([title_html, map_widget, button_box], layout=Layout(width='810px')))


# import matplotlib.pyplot as plt

# #==============================================================================
# # B. (NEW) VIEW CLASSIFIED IMAGES
# #==============================================================================

# print("Loading and displaying images for review...")

# # --- List of images and their titles for plotting ---
# images_to_view = {
#     'Base Image': BASE_IMG_PATH,
#     'CNN-64 Classified': CNN64_IMG_PATH,
#     'CNN-10 Classified': CNN10_IMG_PATH,
#     'Random Forest Classified': RF_IMG_PATH
# }

# # --- Create a figure with 2x2 subplots ---
# fig, axes = plt.subplots(2, 2, figsize=(12, 12))
# axes = axes.flatten() # Flatten the 2x2 array for easy iteration

# # --- Loop through the images, open, and display them ---
# for i, (title, path) in enumerate(images_to_view.items()):
#     try:
#         with rasterio.open(path) as src:
#             # Read the first band of the image
#             image_data = src.read(1)
#             # Display the image on the corresponding subplot
#             ax = axes[i]
#             im = ax.imshow(image_data, cmap='viridis') # You can change the colormap if you like
#             ax.set_title(title)
#             ax.set_axis_off() # Hide the x and y axes
#     except Exception as e:
#         print(f"Could not open or display image {path}: {e}")
#         axes[i].set_title(f"Error loading: {title}")
#         axes[i].set_axis_off()

# plt.tight_layout()
# plt.show()

# print("\nImage review complete. Proceeding with the script...")

# --- Main execution ---
def _prompt_positive_int(prompt):
    """Keep asking until the user enters a whole number greater than zero."""
    while True:
        raw = input(prompt).strip()
        try:
            value = int(raw)
        except ValueError:
            print(f"'{raw}' is not a whole number. Please try again.")
            continue
        if value > 0:
            return value
        print("Please enter a number greater than zero.")


try:
    # Ask the user for the sampling method until a valid one is given.
    method = ''
    while method not in ['random', 'stratified']:
        method = input("Select sampling method (random / stratified): ").lower().strip()

    sampling_plan = []

    # Both plans are drawn from the CNN-10 classified image.
    if method == 'stratified':
        points_per_class = _prompt_positive_int("How many reference points do you want to collect PER CLASS? ")
        print("\nINFO: Generating stratified sampling plan from CNN-10 image.")
        sampling_plan = create_stratified_sampling_plan(CNN10_IMG_PATH, CLASS_LABELS, points_per_class)
    elif method == 'random':
        total_points = _prompt_positive_int("How many TOTAL random reference points do you want to collect? ")
        print("\nINFO: Generating random sampling plan from CNN-10 image.")
        sampling_plan = create_random_sampling_plan(CNN10_IMG_PATH, CLASS_LABELS, total_points)

    # Read by next_point() to know when collection is finished.
    num_points_total = len(sampling_plan)

    if num_points_total == 0:
        print("Warning: No sampling points were generated. Please check your input. Exiting.")
    else:
        # Shuffle the plan so the classes are presented in mixed order.
        random.shuffle(sampling_plan)
        print(f"\nStarting data collection for a total of {num_points_total} points...")
        display(progress_label, output_container)
        next_point()

except Exception as e:
    # Top-level boundary of the notebook cell: report the problem instead of a raw traceback.
    print(f"An error occurred: {e}. Please check your setup, file paths, and input values.")


