# INSTALL

In [None]:
# Install the third-party libraries this notebook needs.
#   transformers   - provides SegformerForSemanticSegmentation (model cell)
#   albumentations - joint image/mask transforms (dataset cell)
# NOTE(review): `datasets` and `evaluate` are installed but not used in the cells shown here.
# `!` lines are IPython shell commands, not Python; `-q` keeps pip's output short.
!pip install transformers datasets evaluate -q
!pip install albumentations -q  # A popular library for data augmentation

    

# CONFIGURATION

In [None]:
import torch
import os

class TrainingConfig:
    """Static paths and hyper-parameters for fine-tuning SegFormer on the CVUSA subset.

    Every setting is a class attribute, so the class is read directly
    (``TrainingConfig.batch_size``) rather than instantiated.
    """

    # ---- Data locations (Kaggle) ----
    # Root of the CVUSA subset attached to the notebook.
    dataset_path = '/kaggle/input/cvusa-subset/polarmap'
    # RGB polar-map images live in the "normal" sub-folder of dataset_path.
    image_dir = os.path.join(dataset_path, 'normal')
    # Single-channel class-ID masks written by the mask-conversion cell.
    mask_dir = '/kaggle/working/corrected_masks'
    # Checkpoints and the final model are written below this folder.
    output_dir = '/kaggle/working/segformer-finetuned-cvusa/'

    # ---- Model ----
    # Hugging Face checkpoint to start from (SegFormer-B5 fine-tuned on Cityscapes).
    model_name = 'nvidia/segformer-b5-finetuned-cityscapes-1024-1024'
    # Number of segmentation classes; must match the class IDs stored in the masks.
    num_classes = 6

    # ---- Optimisation ----
    batch_size = 4          # increase if the GPU has memory to spare
    learning_rate = 6e-5
    num_epochs = 25

    # ---- Runtime ----
    # Prefer the GPU when PyTorch can see one.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Make sure the checkpoint/output folder exists before anything is saved into it,
# then report which device training will run on.
os.makedirs(name=TrainingConfig.output_dir, exist_ok=True)
print("Device:", TrainingConfig.device)

# CSV NAMING FIXES

In [None]:
import csv
import os
from pathlib import Path

def check_csv_structure(csv_path):
    """
    Print a short structural report for a CSV file (stdlib only, no pandas).

    Reads at most the first five rows and decides whether the first row is a
    header: a first row with any cell containing '.png' is treated as data.
    Also counts every row in the file and prints a preview.

    Args:
        csv_path: Path of the CSV file to inspect.

    Returns:
        (preview_rows, has_proper_headers): preview_rows holds at most the first
        five raw rows. Returns (None, None) when the file is missing or empty.
    """
    print(f"📁 Checking CSV structure: {csv_path}")

    if not os.path.exists(csv_path):
        print(f"❌ File not found: {csv_path}")
        return None, None

    # Pull at most five rows for the preview (range is exhausted first, so the
    # reader is never advanced past row five).
    with open(csv_path, 'r', newline='', encoding='utf-8') as handle:
        preview = [row for _, row in zip(range(5), csv.reader(handle))]

    if not preview:
        print("❌ Empty CSV file")
        return None, None

    header_candidate = preview[0]

    print("🔍 Raw file inspection:")
    print(f"  First row:  {','.join(header_candidate)}")
    # A blank second line parses as [] and is not shown.
    if len(preview) > 1 and preview[1]:
        print(f"  Second row: {','.join(preview[1])}")

    # An image filename in the first row means it is data, not a header.
    has_proper_headers = all('.png' not in cell for cell in header_candidate)

    # Second pass over the whole file just to count rows.
    with open(csv_path, 'r', newline='', encoding='utf-8') as handle:
        total_rows = sum(1 for _ in csv.reader(handle))

    if has_proper_headers:
        print("✅ File has proper headers")
        column_names = header_candidate
        body = preview[1:]
        data_count = total_rows - 1
    else:
        print("⚠️  File has NO proper headers - first row is data!")
        print("📝 Using custom column names...")
        column_names = ['satellite_path', 'ground_path', 'duplicate_ground_path']
        body = preview
        data_count = total_rows

    print("\n📊 CSV Info:")
    print(f"  - Total rows: {total_rows}")
    print(f"  - Data rows: {data_count}")
    print(f"  - Columns: {len(column_names)} -> {column_names}")

    print("\n📋 First 5 data rows:")
    for number, row in enumerate(body[:5], start=1):
        print(f"  Row {number}: {row}")

    print("\n🔍 Sample paths analysis:")
    if body:
        for number, (name, value) in enumerate(zip(column_names, body[0]), start=1):
            print(f"  Column {number} ({name}): {value}")

    return preview, has_proper_headers

def _fix_ground_path(ground_path):
    """Strip 'input' from the filename of *ground_path* and swap a trailing .png for .jpg.

    Only the last '/'-separated component is rewritten, so directory names that
    happen to contain 'input' (e.g. '/kaggle/input/...') are left intact.
    Example: 'streetview/input0026840.png' -> 'streetview/0026840.jpg'.
    """
    directory, sep, filename = ground_path.rpartition('/')
    filename = filename.replace('input', '')
    if filename.endswith('.png'):
        filename = filename[:-len('.png')] + '.jpg'
    return directory + sep + filename


def _fix_polarmap_path(third_path):
    """Re-root the filename of *third_path* under 'polarmap/normal/'.

    Example: 'some/dir/input0026840.png' -> 'polarmap/normal/input0026840.png'.
    An empty cell is returned unchanged instead of becoming 'polarmap/normal/'.
    """
    filename = Path(third_path).name
    return f"polarmap/normal/{filename}" if filename else third_path


def _fix_row(row):
    """Return the fixed version of one data row; rows with < 3 columns are returned as-is.

    Column 1 is kept, columns 2 and 3 are rewritten, extra columns are kept unchanged.
    """
    if len(row) < 3:
        return row
    return [row[0], _fix_ground_path(row[1]), _fix_polarmap_path(row[2]), *row[3:]]


def fix_csv_paths(input_csv_path, output_csv_path=None):
    """
    Fix the CSV file paths according to the requirements (no pandas):
    - Column 2: remove 'input' from the filename and change .png to .jpg
      (streetview/input0026840.png -> streetview/0026840.jpg)
    - Column 3: re-root the filename under polarmap/normal/ (-> polarmap/normal/input{ID}.png)

    Args:
        input_csv_path: CSV to read. check_csv_structure decides whether its
            first row is a header.
        output_csv_path: Where to write the fixed CSV. Defaults to
            /kaggle/working/<input stem>_fixed<input suffix>.

    Returns:
        (fixed_rows, headers) on success, or None if the input file is missing
        or empty. A header row is always written to the output file, even when
        the input had none.
    """
    # Prints a structural report; yields (None, None) for a missing/empty file.
    preview_rows, has_proper_headers = check_csv_structure(input_csv_path)
    if preview_rows is None:
        return None

    # The structure check only looked at five rows; now read everything.
    print(f"🔄 Reading complete file: {input_csv_path}")
    with open(input_csv_path, 'r', newline='', encoding='utf-8') as file:
        all_rows = list(csv.reader(file))

    print(f"📊 Complete file loaded: {len(all_rows)} total rows")

    print("\n🔧 Applying transformations...")

    if has_proper_headers:
        headers = all_rows[0]
        data_rows = all_rows[1:]
        print(f"  - Using existing headers: {headers}")
    else:
        headers = ['satellite_path', 'ground_path', 'polarmap_path']
        data_rows = all_rows  # All rows are data
        print(f"  - Using custom headers: {headers}")

    print(f"  - Processing {len(data_rows)} data rows...")
    print("  - Fixing column 2: removing 'input' from filenames AND changing .png to .jpg")
    print(f"  - Fixing column 3: changing to polarmap/normal/input{{ID}}.png")

    fixed_rows = []
    for row_number, row in enumerate(data_rows, start=1):
        if len(row) < 3:
            print(f"⚠️  Row {row_number} has fewer than 3 columns: {row}")
        fixed_rows.append(_fix_row(row))

    # Show a before/after comparison of the first data row.
    if data_rows:
        print("\n📊 Transformation Results:")
        print("Original sample row:")
        for col, val in enumerate(data_rows[0][:3], start=1):
            print(f"  Col {col}: {val}")

        print("\nFixed sample row:")
        for col, val in enumerate(fixed_rows[0][:3], start=1):
            print(f"  Col {col}: {val}")

    # /kaggle/input is read-only, so default to the writable working directory.
    if output_csv_path is None:
        input_path = Path(input_csv_path)
        output_csv_path = Path("/kaggle/working") / f"{input_path.stem}_fixed{input_path.suffix}"

    with open(output_csv_path, 'w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(headers)
        writer.writerows(fixed_rows)

    print(f"\n✅ Fixed CSV saved to: {output_csv_path}")
    print("📁 Location: Kaggle working directory (writable)")
    print(f"📊 Output summary: {len(headers)} columns, {len(fixed_rows)} data rows")

    return fixed_rows, headers

def process_train_test_csvs(train_csv_path, test_csv_path):
    """
    Fix the paths in both the training and the test/validation CSV files (no pandas).

    Args:
        train_csv_path: Path of the training CSV.
        test_csv_path: Path of the test (or validation) CSV.

    Returns:
        (train_result, test_result): each is what fix_csv_paths returned,
        i.e. (fixed_rows, headers) on success or None when that file was
        missing or empty.
    """
    print("="*60)
    print("🚀 Processing Training and Test CSV Files")
    print("="*60)

    print("\n📚 TRAINING CSV:")
    train_result = fix_csv_paths(train_csv_path)

    print("\n" + "="*60)

    print("\n🧪 TEST CSV:")
    test_result = fix_csv_paths(test_csv_path)

    print("\n" + "="*60)

    # fix_csv_paths returns None on failure; only claim success when both worked.
    failed = [
        label
        for label, result in (("training", train_result), ("test", test_result))
        if result is None
    ]
    if failed:
        print(f"❌ Could not process the {' and '.join(failed)} CSV file(s) - see messages above.")
    else:
        print("✅ Both CSV files processed successfully!")

    return train_result, test_result

def load_csv_data(csv_path):
    """
    Load a CSV (e.g. one written by fix_csv_paths) back into memory (no pandas).

    The first row is treated as the header row.

    Args:
        csv_path: Path of the CSV file.

    Returns:
        (data_rows, headers): data_rows is the list of rows after the header
        (each a list of str) and headers is the first row. An empty file yields
        ([], []) instead of raising StopIteration.
    """
    with open(csv_path, 'r', newline='', encoding='utf-8') as file:
        reader = csv.reader(file)
        headers = next(reader, [])  # [] when the file has no rows at all
        data_rows = list(reader)

    return data_rows, headers

# Example usage:
if __name__ == "__main__":
    # Input CSVs shipped with the Kaggle dataset.
    train_csv = "/kaggle/input/cvusa-subset/train-19zl.csv"
    test_csv = "/kaggle/input/cvusa-subset/val-19zl.csv"  # Note: this is val, not test

    # Preview the structure of both files.
    print("Checking CSV structures...")
    check_csv_structure(train_csv)
    print("\n" + "="*40)
    check_csv_structure(test_csv)

    print("\n" + "="*60)

    # Fix both files - saves to /kaggle/working/
    train_result, test_result = process_train_test_csvs(train_csv, test_csv)

    # Derive the output paths with fix_csv_paths' default naming rule
    # (/kaggle/working/<stem>_fixed<suffix>) instead of hard-coding them again.
    train_fixed_csv, test_fixed_csv = (
        Path("/kaggle/working") / f"{Path(src).stem}_fixed{Path(src).suffix}"
        for src in (train_csv, test_csv)
    )

    print("\n📁 Fixed files are now available at:")
    print(f"   Training: {train_fixed_csv}")
    print(f"   Test:     {test_fixed_csv}")

    if train_result is None:
        # Loading now would crash, or silently read a stale file from an earlier run.
        print("\n⚠️  Skipping load example: the training CSV could not be processed.")
    else:
        print("\n🔄 Example: Loading fixed training data...")
        train_data, train_headers = load_csv_data(train_fixed_csv)
        print(f"   Headers: {train_headers}")
        print(f"   Data rows: {len(train_data)}")
        print(f"   Sample: {train_data[0] if train_data else 'No data'}")


# CONVERSION

In [None]:
import numpy as np
from PIL import Image
import os
from tqdm import tqdm

# --- 1. FINALIZED CONFIGURATION ---

# Folder holding the original RGB-coded segmentation masks (read-only input).
SOURCE_MASK_DIR = "/kaggle/input/cvusa-subset/polarmap/segmap/"
# Folder that receives the single-channel class-ID masks produced below.
TARGET_MASK_DIR = "/kaggle/working/corrected_masks/"

# Class ID -> (R, G, B) colour of that class in the source masks.
# The class ID is the position in this list, so the order is significant.
FINAL_COLOR_MAP = dict(enumerate([
    (255, 0, 0),        # 0: red
    (0, 0, 255),        # 1: dark blue
    (255, 255, 255),    # 2: white
    (0, 255, 255),      # 3: light blue (cyan)
    (0, 255, 0),        # 4: green
    (255, 255, 0),      # 5: yellow
]))

# --- 2. CONVERSION LOGIC (No need to edit below) ---

def convert_rgb_to_class_id(rgb_array, color_map):
    """
    Map every pixel of an RGB mask to the class ID of its nearest palette colour.

    Args:
        rgb_array: Array of shape (H, W, 3) with RGB values, e.g. the uint8
            result of ``np.array(Image.open(path).convert("RGB"))``.
        color_map: Dict ``{class_id: (R, G, B)}``; class IDs must fit in uint8.

    Returns:
        uint8 array of shape (H, W) holding, per pixel, the class ID whose colour
        has the smallest squared Euclidean distance (ties go to the entry that
        comes first in color_map).
    """
    # Compute in a signed type: with uint8 operands `a - b` and `** 2` wrap
    # modulo 256 (e.g. 0 - 255 == 1), which sends near-miss colours to the wrong
    # class. int32 comfortably holds the maximum 3 * 255**2.
    palette = np.array(list(color_map.values()), dtype=np.int32)   # (K, 3)
    class_ids = np.array(list(color_map.keys()), dtype=np.uint8)   # (K,)
    pixels = np.asarray(rgb_array, dtype=np.int32)                 # (H, W, 3)

    # Squared distance from every pixel to every palette colour -> (H, W, K).
    diffs = pixels[:, :, np.newaxis, :] - palette[np.newaxis, np.newaxis, :, :]
    distances = np.sum(diffs ** 2, axis=3)
    closest_class_indices = np.argmin(distances, axis=2)

    return class_ids[closest_class_indices]

# --- 3. AUTOMATIC CHECK AND EXECUTION ---
# Converts every RGB mask in SOURCE_MASK_DIR into a single-channel class-ID mask
# in TARGET_MASK_DIR, skipping the work when the existing output already matches.

print("🚀 Starting Mask Preparation Process...")
os.makedirs(TARGET_MASK_DIR, exist_ok=True)

source_files = sorted(os.listdir(SOURCE_MASK_DIR))
target_files = os.listdir(TARGET_MASK_DIR)
needs_conversion = False

# Cheap check first: an empty target dir or a different file count means a
# previous run is missing or incomplete.
if not target_files or len(target_files) != len(source_files):
    print("⚠️ Target directory is empty or file count mismatches. Full conversion required.")
    needs_conversion = True
else:
    # Spot-check: re-convert the first source mask in memory and compare it with
    # the stored result, to detect output made with a different FINAL_COLOR_MAP.
    print("🔍 Target directory found. Verifying first mask to check for correctness...")

    first_source_path = os.path.join(SOURCE_MASK_DIR, source_files[0])
    first_target_path = os.path.join(TARGET_MASK_DIR, source_files[0])

    if not os.path.exists(first_target_path):
        # Same file count but different names: the target dir does not hold our output.
        print(f"❌ Verification failed! Expected mask not found: {first_target_path}")
        print("Full conversion is required.")
        needs_conversion = True
    else:
        source_img_rgb = np.array(Image.open(first_source_path).convert("RGB"))
        expected_class_ids = convert_rgb_to_class_id(source_img_rgb, FINAL_COLOR_MAP)

        existing_class_ids = np.array(Image.open(first_target_path))

        # array_equal is False on any shape or value difference.
        if not np.array_equal(expected_class_ids, existing_class_ids):
            print("❌ Verification failed! Existing masks were made with a different color map.")
            print("Full conversion is required.")
            needs_conversion = True
        else:
            print("✅ Verification successful! Masks are already correct and up-to-date.")

# --- Run the full conversion only if necessary ---
if needs_conversion:
    print(f"\n🔧 Converting {len(source_files)} masks from '{SOURCE_MASK_DIR}'...")

    for filename in tqdm(source_files, desc="Converting All Masks"):
        source_path = os.path.join(SOURCE_MASK_DIR, filename)

        rgb_array = np.array(Image.open(source_path).convert("RGB"))
        class_id_array = convert_rgb_to_class_id(rgb_array, FINAL_COLOR_MAP)

        # Saved under the source filename; PIL picks the format from the extension.
        # NOTE(review): only lossless if the source masks are .png — confirm.
        target_path = os.path.join(TARGET_MASK_DIR, filename)
        Image.fromarray(class_id_array).save(target_path)

    print("\n✅ Conversion complete!")
    print(f"All corrected masks are saved in: '{TARGET_MASK_DIR}'")
else:
    print("\n👍 No action needed. Your masks are ready for training.")

# THE DATASET 

In [None]:
# Cell 3: The Dataset Class (Corrected for Segmentation)

from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import os
import csv
import albumentations as A
from albumentations.pytorch import ToTensorV2

class SegmentationCvusaDataset(Dataset):
    """
    Pair satellite (polar-map) images with their segmentation masks.

    The image list comes from a CSV file (the train/val split). Each mask path is
    derived from the image filename by replacing 'input' with 'output' and
    looking it up in ``mask_dir``. Rows whose image or mask file does not exist
    are skipped.

    Args:
        csv_path: CSV with a header row; column 3 holds the image path relative
            to ``data_root`` (e.g. 'polarmap/normal/input0000008.png').
        data_root: Directory that the CSV's relative image paths are resolved against.
        transform: Optional callable taking ``image=`` and ``mask=`` keyword
            arguments and returning a dict with 'image' and 'mask' keys
            (an Albumentations ``Compose`` in this notebook).
        mask_dir: Directory holding the single-channel class-ID masks. Defaults
            to the directory previously hard-coded here.
    """
    def __init__(self, csv_path, data_root, transform=None,
                 mask_dir='/kaggle/working/corrected_masks'):
        self.data_root = data_root
        self.transform = transform
        self.mask_dir = mask_dir
        self.image_mask_pairs = []

        print(f"📂 Loading dataset from: {csv_path}")

        with open(csv_path, 'r', newline='', encoding='utf-8') as file:
            reader = csv.reader(file)
            # Skip the header row; an empty file just yields no pairs instead of
            # raising StopIteration.
            next(reader, None)

            for row in reader:
                if len(row) < 3:
                    continue

                # Column 3 is the image path relative to data_root.
                satellite_relative_path = row[2]
                image_full_path = os.path.join(self.data_root, satellite_relative_path)

                # e.g. "input0000008.png" -> "output0000008.png"
                base_filename = os.path.basename(satellite_relative_path)
                mask_filename = base_filename.replace('input', 'output')
                mask_full_path = os.path.join(self.mask_dir, mask_filename)

                if os.path.exists(image_full_path) and os.path.exists(mask_full_path):
                    self.image_mask_pairs.append((image_full_path, mask_full_path))

        print(f"✅ Found {len(self.image_mask_pairs)} valid image-mask pairs.")

    def __len__(self):
        return len(self.image_mask_pairs)

    def __getitem__(self, idx):
        """Return (image, mask) for pair *idx*, transformed when a transform is set.

        Without a transform both are NumPy arrays (RGB image, single-channel mask).
        """
        image_path, mask_path = self.image_mask_pairs[idx]

        image = np.array(Image.open(image_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"))  # Masks are single-channel

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            # Segmentation losses need int64 targets. torch.as_tensor is a no-op
            # for tensors (ToTensorV2) and also handles transforms returning NumPy.
            mask = torch.as_tensor(transformed['mask']).long()

        return image, mask

# --- Define Transformations ---
# Albumentations applies the same spatial transforms to the image and its mask,
# so the two stay pixel-aligned. No augmentation yet: resize to 512x512,
# normalise with ImageNet mean/std, then convert to torch tensors.
# To augment, add steps such as A.HorizontalFlip() to the list below.
transform = A.Compose(
    [
        A.Resize(height=512, width=512),
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
        ToTensorV2(),
    ]
)

# model import

In [None]:
# Cell 4: Model and Optimizer Definition

from transformers import SegformerForSemanticSegmentation
import torch

# --- 1. Load the pre-trained SegFormer with a head sized for our classes ---
print("🧠 Loading pre-trained SegFormer model...")

# num_labels sets the size of the new classification layer;
# ignore_mismatched_sizes=True allows loading a checkpoint whose head has a
# different number of labels (the key to replacing the final layer).
model = SegformerForSemanticSegmentation.from_pretrained(
    TrainingConfig.model_name,
    num_labels=TrainingConfig.num_classes,
    ignore_mismatched_sizes=True,
)
model = model.to(TrainingConfig.device)

# --- 2. Optimizer: AdamW over all parameters at the configured learning rate ---
optimizer = torch.optim.AdamW(params=model.parameters(), lr=TrainingConfig.learning_rate)

print("✅ Model and optimizer created successfully.")

# TRAINING

In [None]:
# Cell 5: Full training + validation loop with mixed precision (torch.cuda.amp).
# NOTE(review): newer PyTorch versions deprecate torch.cuda.amp.GradScaler/autocast
# in favour of torch.amp; they still work but may emit warnings.
from torch.cuda.amp import GradScaler, autocast

from torch.utils.data import DataLoader
from tqdm import tqdm

# --- Paths: fixed CSVs written by the CSV cell, and the dataset root ---
TRAIN_CSV_PATH = "/kaggle/working/train-19zl_fixed.csv"
VALID_CSV_PATH = "/kaggle/working/val-19zl_fixed.csv"
DATA_ROOT = "/kaggle/input/cvusa-subset"
CORRECTED_MASK_DIR = "/kaggle/working/corrected_masks/" # NOTE(review): not passed to SegmentationCvusaDataset below

# --- Create Datasets and DataLoaders ---
print("\n🚀 Creating train and validation datasets for segmentation...")
train_dataset = SegmentationCvusaDataset(
    csv_path=TRAIN_CSV_PATH,
    data_root=DATA_ROOT,
    transform=transform,
)

valid_dataset = SegmentationCvusaDataset(
    csv_path=VALID_CSV_PATH,
    data_root=DATA_ROOT,
    transform=transform,
)

# Guard: an empty dataset would also make the loss averages below divide by zero.
if len(train_dataset) == 0 or len(valid_dataset) == 0:
    print("\n❌ ERROR: One or both datasets are empty. Cannot proceed with training.")
else:
    train_loader = DataLoader(
        train_dataset,
        batch_size=TrainingConfig.batch_size,
        shuffle=True
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=TrainingConfig.batch_size,
        shuffle=False
    )
    print(f"\n✅ Dataloaders created successfully.")

    scaler = GradScaler() # one scaler for the whole run; it carries the loss-scale state between steps

    # --- Training ---
    print("\n--- Starting Fine-Tuning ---")
    for epoch in range(TrainingConfig.num_epochs):
        # `model` and `optimizer` come from the model cell.
        model.train()
        total_loss = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{TrainingConfig.num_epochs}")

        for images, masks in progress_bar:
            images = images.to(TrainingConfig.device)
            masks = masks.to(TrainingConfig.device)

            # Forward pass in mixed precision; passing labels makes the model
            # return its loss in outputs.loss.
            with autocast():
                outputs = model(pixel_values=images, labels=masks)
                loss = outputs.loss

            # Clear gradients left over from the previous step.
            optimizer.zero_grad()

            # Backward on the scaled loss so small fp16 gradients do not underflow.
            scaler.scale(loss).backward()

            # Unscales the gradients and runs optimizer.step(), skipping the step
            # if any gradient is inf/NaN.
            scaler.step(optimizer)

            # Adjust the loss scale for the next iteration.
            scaler.update()

            total_loss += loss.item()
            progress_bar.set_postfix({'loss': loss.item()})

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} finished. Average Training Loss: {avg_loss:.4f}")

        # --- Validation (full precision, no gradients) ---
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for images, masks in valid_loader:
                images = images.to(TrainingConfig.device)
                masks = masks.to(TrainingConfig.device)
                outputs = model(pixel_values=images, labels=masks)
                total_val_loss += outputs.loss.item()
        avg_val_loss = total_val_loss / len(valid_loader)
        print(f"Average Validation Loss: {avg_val_loss:.4f}")

        # --- Save a checkpoint every epoch ---
        # NOTE(review): one full model copy per epoch; check disk quota for large models.
        model.save_pretrained(os.path.join(TrainingConfig.output_dir, f"checkpoint-epoch-{epoch+1}"))

    print("\n--- Fine-Tuning Finished ---")
    model.save_pretrained(os.path.join(TrainingConfig.output_dir, "final_model"))
    print(f"Final model saved to {os.path.join(TrainingConfig.output_dir, 'final_model')}")