# Land Cover Classification using DeepLabV3+ (MobileNetV2)

## 1. Setup Environment

Install and import necessary libraries.

**Note:** If you are using a **TPU accelerator** on Kaggle, uncomment the TPU-specific installation lines in the next cell. Also make sure the dataset is attached to the notebook's input directory (`/kaggle/input`).

In [None]:
# Install base libraries
# NOTE: `!pip` is IPython/Jupyter shell syntax; this cell only runs inside a notebook.
# NOTE(review): versions are unpinned, so the newest transformers/datasets get installed;
# later cells must therefore tolerate renamed/removed arguments in those libraries.
!pip install -q transformers datasets evaluate accelerate Pillow torch torchvision torchaudio numpy matplotlib seaborn scikit-learn

# --- TPU Setup ---
# Uncomment the following lines ONLY if using a TPU accelerator in your Kaggle session
# NOTE(review): the script below is fetched from the pytorch/xla `master` branch; confirm it
# still exists there and that the pinned version matches the installed torch before using it.
# print("Installing TPU-specific libraries using env-setup.py...")
# # Download the setup script
# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# # Run the setup script (adjust version if needed, e.g., "2.1.0", "2.0.0")
# !python pytorch-xla-env-setup.py --version 2.1.0 --apt-packages libomp5 libopenblas-dev
# print("TPU libraries setup complete (if script ran successfully).")

In [None]:
import os
import numpy as np
import pandas as pd
import torch
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from datasets import Dataset, DatasetDict, Image as HFImage
from transformers import AutoModelForSemanticSegmentation, AutoFeatureExtractor, Trainer, TrainingArguments
import evaluate
from huggingface_hub import notebook_login
import random
import torch.nn.functional as F

# --- TPU Imports (conditional) ---
# _TPU_AVAILABLE is True only when torch_xla is importable AND its default device is a TPU.
_TPU_AVAILABLE = False
try:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
except ImportError:
    print("torch_xla not found. Running on CPU/GPU.")
else:
    # A torch.device object is always truthy, so `if xm.xla_device():` can never take the
    # "no TPU" branch. Ask for the hardware type behind the XLA device instead. Creating the
    # device may also raise RuntimeError when no XLA runtime is configured (behaviour varies
    # across torch_xla versions), so treat that as "no TPU" rather than crashing the cell.
    try:
        _xla_hw = xm.xla_device_hw(xm.xla_device())
    except RuntimeError as err:
        _xla_hw = None
        print(f"torch_xla found, but creating an XLA device failed: {err}")
    if _xla_hw == "TPU":
        print("torch_xla found and TPU device is available.")
        _TPU_AVAILABLE = True
    else:
        print("torch_xla found, but no TPU device detected. Check accelerator settings.")

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Load Dataset

Load the DeepGlobe Land Cover Classification dataset. 
On Kaggle, attach it to the notebook as an input dataset. If you run the notebook elsewhere, download it from https://www.kaggle.com/datasets/balraj98/deepglobe-land-cover-classification-dataset and update `dataset_base_dir` below.

**Important:** Ensure the dataset is placed in the Kaggle input directory.

In [None]:
# Define the standard Kaggle input directory
dataset_base_dir = '/kaggle/input/deepglobe-land-cover-classification-dataset'
dataset_root_dir = os.path.join(dataset_base_dir, 'deepglobe')
metadata_path = os.path.join(dataset_root_dir, 'metadata.csv')


def _resolve_dataset_path(rel_path):
    """Join a path from metadata.csv onto ``dataset_root_dir``.

    Rows with no file (pandas reads an empty cell as a float NaN) are mapped to ''.
    Joining them directly would make ``os.path.join`` raise TypeError. An empty string
    is later reported as non-existent by ``os.path.exists``.
    """
    if isinstance(rel_path, str) and rel_path:
        return os.path.join(dataset_root_dir, rel_path)
    return ''


_EMPTY_METADATA_COLUMNS = ['image_id', 'split', 'sat_image_path', 'mask_path']

# Check if the dataset path exists
if not os.path.exists(dataset_root_dir):
    print(f"Error: Dataset directory not found at {dataset_root_dir}")
    print("Please ensure the DeepGlobe dataset is correctly placed in the Kaggle input directory.")
    # For the notebook we proceed with an empty frame; later cells skip work on empty splits.
    metadata_df = pd.DataFrame(columns=_EMPTY_METADATA_COLUMNS)  # Dummy df
elif not os.path.isfile(metadata_path):
    print(f"Error: metadata.csv not found at {metadata_path}")
    metadata_df = pd.DataFrame(columns=_EMPTY_METADATA_COLUMNS)  # Dummy df
else:
    print(f"Dataset found at: {dataset_root_dir}")
    metadata_df = pd.read_csv(metadata_path)
    # Prepend the root directory to the (relative) paths in the CSV; missing entries become ''.
    for _path_column in ('sat_image_path', 'mask_path'):
        metadata_df[_path_column] = metadata_df[_path_column].apply(_resolve_dataset_path)

# Class names in class-ID order: a name's position in this list is its class ID.
class_names = [
    'urban_land',
    'agriculture_land',
    'rangeland',
    'forest_land',
    'water',
    'barren_land',
    'unknown',
]
# ID -> name and name -> ID lookups (the model config needs both), plus the class count.
id2label = dict(enumerate(class_names))
label2id = {name: idx for idx, name in id2label.items()}
num_labels = len(id2label)

# Function to load data paths based on split from metadata.csv
def load_data_paths(df, split):
    """Return index-aligned ``(image_paths, mask_paths)`` lists for one split.

    A row is kept only when BOTH its satellite image and its mask exist on disk. This
    keeps the two lists paired: entry i of each list refers to the same tile. Filtering
    the lists independently would shift every later mask onto the wrong image.

    Args:
        df: DataFrame with 'split', 'sat_image_path' and 'mask_path' columns.
        split: Value of the 'split' column to select (e.g. 'train', 'valid', 'test').

    Returns:
        Tuple ``(image_paths, mask_paths)`` of equal-length lists of paths.
    """
    def _exists(path):
        # Rows without a mask may hold NaN/None; os.path.exists raises TypeError on those.
        return isinstance(path, (str, os.PathLike)) and bool(path) and os.path.exists(path)

    split_df = df[df['split'] == split]
    image_paths, mask_paths = [], []
    skipped = 0
    for image_path, mask_path in zip(split_df['sat_image_path'], split_df['mask_path']):
        if _exists(image_path) and _exists(mask_path):
            image_paths.append(image_path)
            mask_paths.append(mask_path)
        else:
            skipped += 1
    print(f"Found {len(image_paths)} images and {len(mask_paths)} masks for split '{split}'.")
    if skipped:
        print(f"  Skipped {skipped} row(s) with a missing image or mask.")
    return image_paths, mask_paths

# Gather (image paths, mask paths) for each metadata.csv split, in train -> valid -> test order.
(train_image_paths, train_mask_paths), (val_image_paths, val_mask_paths), (test_image_paths, test_mask_paths) = [
    load_data_paths(metadata_df, split_name) for split_name in ('train', 'valid', 'test')
]

# Create Hugging Face Datasets
def create_hf_dataset(image_paths, mask_paths):
    """Build a ``datasets.Dataset`` with ``image`` and ``label`` columns cast to ``HFImage``.

    Args:
        image_paths: Paths of the satellite images.
        mask_paths: Paths of the matching RGB masks, in the same order.

    Returns:
        The Dataset. It is empty (and a warning is printed) when either list is
        empty or the two lists differ in length.
    """
    paths_usable = bool(image_paths) and bool(mask_paths) and len(image_paths) == len(mask_paths)
    if not paths_usable:
        print("Warning: Mismatch or empty paths. Creating empty dataset.")
        image_paths, mask_paths = [], []
    dataset = Dataset.from_dict({'image': image_paths, 'label': mask_paths})
    # Cast both path columns to the Image feature so rows come back as images, not strings.
    for column_name in ('image', 'label'):
        dataset = dataset.cast_column(column_name, HFImage())
    return dataset

train_dataset = create_hf_dataset(train_image_paths, train_mask_paths)
val_dataset = create_hf_dataset(val_image_paths, val_mask_paths)
test_dataset = create_hf_dataset(test_image_paths, test_mask_paths)

# The fine-tuning cell is skipped entirely when there is no validation data. If the
# 'valid' split produced no usable image/mask pairs, hold out a reproducible fraction
# of the labelled training tiles for validation instead.
val_holdout_fraction = 0.1
if len(val_dataset) == 0 and len(train_dataset) >= 2:
    print(f"No labelled validation data found; holding out {val_holdout_fraction:.0%} of the training set for validation.")
    _holdout = train_dataset.train_test_split(test_size=val_holdout_fraction, seed=42)
    train_dataset, val_dataset = _holdout['train'], _holdout['test']

ds = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset,
    'test': test_dataset
})

print("\nDataset structure:")
print(ds)

## 3. Preprocessing

Define the feature extractor and the transformations. DeepGlobe masks are RGB images, so we need a function that converts each RGB mask into a class-ID mask (values 0–6).

In [None]:
# Hub checkpoint used throughout the notebook; load the preprocessing config that ships with it.
model_checkpoint = "google/deeplabv3_mobilenet_v2_1.0_513"
feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

# Mask colour (R, G, B) -> class ID. Verify against the dataset documentation / by inspecting masks.
rgb_to_id = dict([
    ((0, 255, 255), 0),    # Urban land (Cyan)
    ((255, 255, 0), 1),    # Agriculture land (Yellow)
    ((255, 0, 255), 2),    # Rangeland (Magenta)
    ((0, 255, 0), 3),      # Forest land (Green)
    ((0, 0, 255), 4),      # Water (Blue)
    ((255, 255, 255), 5),  # Barren land (White)
    ((0, 0, 0), 6),        # Unknown (Black)
])
# Reverse lookup (class ID -> colour), used to colourise predicted masks.
id_to_rgb = dict(zip(rgb_to_id.values(), rgb_to_id.keys()))

def rgb_mask_to_class_id_mask(mask_img, threshold=128):
    """Convert an RGB mask image (PIL Image) to a single-channel class-ID image.

    Each channel is binarized at ``threshold`` before lookup. Pixels slightly off a pure
    0/255 value (e.g. resampling or compression noise) then still map to their class
    instead of silently falling into 'unknown'. The DeepGlobe dataset description
    recommends binarizing masks at 128 (verify for your copy). For masks that are already
    pure 0/255 the result is identical to exact colour matching.

    Args:
        mask_img: PIL image of the colour-coded mask.
        threshold: Channel values >= threshold count as 255; lower values count as 0.

    Returns:
        PIL Image (uint8) whose pixel values are class IDs from ``rgb_to_id``. Pixels
        matching no colour are labelled 6 ('unknown').
    """
    rgb = np.asarray(mask_img.convert('RGB'))  # (H, W, 3)
    channel_on = rgb >= threshold
    class_mask = np.full(rgb.shape[:2], 6, dtype=np.uint8)  # Default to 'unknown'
    for colour, class_id in rgb_to_id.items():
        colour_on = np.asarray(colour) >= threshold
        class_mask[np.all(channel_on == colour_on, axis=2)] = class_id
    return Image.fromarray(class_mask)  # Return as PIL Image

def preprocess_data(examples):
    """Turn a batch of raw image/mask examples into model inputs.

    Args:
        examples: Batched dict from ``Dataset.map(batched=True)`` whose 'image' and
            'label' columns hold PIL images (both columns are cast to ``HFImage``).

    Returns:
        The feature extractor's output for the batch. It is expected to hold
        ``pixel_values`` and, when the processor supports segmentation maps, ``labels``.
    """
    images = [img.convert("RGB") for img in examples['image']]
    labels = [rgb_mask_to_class_id_mask(mask) for mask in examples['label']]
    # Pass the masks by keyword: the base image-processor __call__ accepts only `images`
    # positionally, so a second positional argument would not reach `segmentation_maps`.
    return feature_extractor(images, segmentation_maps=labels, return_tensors="pt")

print("\nApplying preprocessing...")
if len(ds['train']) > 0:
    # Drop the raw PIL 'image'/'label' columns. With remove_unused_columns=False the Trainer
    # would otherwise pass them to the data collator (which treats a 'label' key specially)
    # and on to the model's forward().
    processed_ds = ds.map(
        preprocess_data,
        batched=True,
        batch_size=4,
        remove_columns=ds['train'].column_names,
    )
    # Return torch tensors so the default collator can stack pixel_values/labels directly.
    processed_ds.set_format(type="torch")
else:
    print("Skipping preprocessing as datasets are empty.")
    processed_ds = ds

print("\nProcessed dataset structure (first element example):")
if len(processed_ds['train']) > 0:
    print(processed_ds['train'][0])
else:
    print("Train dataset is empty.")

## 4. Model Definition

Load a pre-trained DeepLabV3+ model and configure it for our specific number of classes.

In [None]:
# Load the pre-trained checkpoint and configure its segmentation head for our classes.
model = AutoModelForSemanticSegmentation.from_pretrained(
    model_checkpoint,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    # Allows loading even though the checkpoint's head size differs from num_labels.
    ignore_mismatched_sizes=True,
    # compute_metrics and the confusion matrix both ignore the 'unknown' class; make the
    # training loss ignore it too, so the optimized objective matches the evaluated one.
    semantic_loss_ignore_index=label2id['unknown'],
)

if _TPU_AVAILABLE:
    device = xm.xla_device()
    print(f"Using TPU device: {device}")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using non-TPU device: {device}")

model.to(device)
print(f"Model loaded on {device}")

**Note:** The model loaded here (`google/deeplabv3_mobilenet_v2_1.0_513`) is **DeepLabV3+** with a **MobileNetV2** backbone. This architecture is known for providing a good balance between segmentation accuracy and computational efficiency, making it suitable for faster inference compared to larger transformer models.

## 5. Training Configuration

Set up `TrainingArguments` and define the evaluation metric (Mean Intersection over Union - mIoU).

In [None]:
metric = evaluate.load("mean_iou")

def compute_metrics(eval_pred):
    """Compute mean-IoU style segmentation metrics for the Trainer.

    Args:
        eval_pred: ``(logits, labels)`` pair from ``Trainer``; each may be a torch
            tensor or a numpy array. Logits are assumed to be
            (batch, num_labels, h, w), matching the argmax over dim 1 below.

    Returns:
        Dict starting with ``mean_iou``, ``mean_accuracy`` and ``overall_accuracy``
        (0.0 if absent), followed by every other value the metric returned, plus
        per-class ``iou_<name>`` / ``accuracy_<name>`` entries. Class 6 ('unknown')
        is passed as ``ignore_index``.
    """
    raw_logits, label_arr = eval_pred

    # Normalize: labels -> numpy, logits -> CPU torch tensor.
    if isinstance(label_arr, torch.Tensor):
        label_arr = label_arr.detach().cpu().numpy()
    logit_t = raw_logits.detach().cpu() if isinstance(raw_logits, torch.Tensor) else raw_logits
    if isinstance(logit_t, np.ndarray):
        logit_t = torch.from_numpy(logit_t)

    # The logits may be at a lower resolution than the masks; resize before argmax.
    target_hw = label_arr.shape[-2:]
    if logit_t.shape[-2:] != target_hw:
        logit_t = F.interpolate(logit_t, size=target_hw, mode='bilinear', align_corners=False)

    predicted = logit_t.argmax(dim=1).numpy()

    scores = metric.compute(
        predictions=predicted,
        references=label_arr,
        num_labels=num_labels,
        ignore_index=6,
        reduce_labels=False,
    )

    # Flatten the per-class arrays into named scalar entries.
    iou_by_class = scores.pop('per_category_iou', [0.0] * num_labels)
    acc_by_class = scores.pop('per_category_accuracy', [0.0] * num_labels)
    for class_idx, class_name in id2label.items():
        scores[f"iou_{class_name}"] = iou_by_class[class_idx]
        scores[f"accuracy_{class_name}"] = acc_by_class[class_idx]

    summary = {key: scores.get(key, 0.0) for key in ("mean_iou", "mean_accuracy", "overall_accuracy")}
    summary.update(scores)
    return summary

import inspect  # only used to probe which eval-strategy keyword TrainingArguments accepts

train_batch_size = 16 if _TPU_AVAILABLE else 8
eval_batch_size = 16 if _TPU_AVAILABLE else 8
print(f"Using Train Batch Size: {train_batch_size}, Eval Batch Size: {eval_batch_size}")

output_dir_name = "./deeplabv3-mobilenetv2-finetuned-deepglobe"

# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers releases and
# the old name was later dropped. The install cell does not pin a version, so pick
# whichever keyword the installed TrainingArguments accepts.
_training_arg_names = inspect.signature(TrainingArguments).parameters
_eval_strategy_kwarg = "eval_strategy" if "eval_strategy" in _training_arg_names else "evaluation_strategy"

training_args = TrainingArguments(
    output_dir=output_dir_name,
    learning_rate=5e-5,
    num_train_epochs=20,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=eval_batch_size,
    save_total_limit=2,
    **{_eval_strategy_kwarg: "epoch"},
    save_strategy="epoch",  # must match the eval strategy for load_best_model_at_end
    logging_strategy="steps",
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="mean_iou",
    push_to_hub=False,
    # Keep pixel_values/labels columns even though they are not in the model signature check.
    remove_unused_columns=False,
    fp16=torch.cuda.is_available() and not _TPU_AVAILABLE,
    tpu_num_cores=8 if _TPU_AVAILABLE else None,
    dataloader_num_workers=2,
    report_to="none"
)

## 6. Fine-tuning

Instantiate the `Trainer` and start the fine-tuning process. You might want to log in to Hugging Face if you plan to push the model.

In [None]:
# Optional: Login to Hugging Face Hub if push_to_hub=True
# notebook_login()

# Check if datasets are valid before creating Trainer
train_data_available = 'train' in processed_ds and len(processed_ds['train']) > 0
eval_data_available = 'validation' in processed_ds and len(processed_ds['validation']) > 0

# `trainer` stays None when training is skipped; the evaluation/visualization cells test for that.
trainer = None
if train_data_available and eval_data_available:
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=processed_ds["train"],
        eval_dataset=processed_ds["validation"],
        compute_metrics=compute_metrics,
    )

    print("Starting training...")
    # With load_best_model_at_end=True (TrainingArguments cell), the best checkpoint by
    # mean_iou is restored into `trainer.model` when training ends.
    train_results = trainer.train()

    # Only the master process (or the single CPU/GPU process) writes to disk.
    if not _TPU_AVAILABLE or xm.is_master_ordinal():
        print("Saving model and state...")
        # No path given: presumably saves to training_args.output_dir, which the
        # visualization cell reloads from. Verify against the Trainer docs.
        trainer.save_model()
        trainer.save_state()
        print("Model and state saved.")
    else:
        print(f"Skipping save on TPU replica {xm.get_ordinal()}")

    # TPU barrier: every replica waits here until the save above has finished.
    # Keep it after the save and before any further work.
    if _TPU_AVAILABLE:
        xm.rendezvous('save_model_done')

    print("Training finished.")
    print("Training Results:", train_results)

    if not _TPU_AVAILABLE or xm.is_master_ordinal():
        metrics = train_results.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)

    # Runs on every process; only the master logs/saves the numbers.
    print("\nEvaluating final model on validation set...")
    eval_metrics = trainer.evaluate()
    if not _TPU_AVAILABLE or xm.is_master_ordinal():
        trainer.log_metrics("eval", eval_metrics)
        trainer.save_metrics("eval", eval_metrics)

else:
    print("Skipping training as train or validation dataset is empty or invalid.")

## 7. Evaluation & Visualization

Evaluate the fine-tuned model on the test set and visualize some predictions.

In [None]:
test_data_available = 'test' in processed_ds and len(processed_ds['test']) > 0
# Flattened per-pixel predictions/labels for the confusion matrix; stay empty if evaluation is skipped.
all_preds = []
all_labels = []
test_metrics = None

if trainer is not None and test_data_available:
    print("\nEvaluating on the test set...")
    # predict() also runs compute_metrics, so test_results.metrics has the mIoU breakdown.
    test_results = trainer.predict(processed_ds['test'])
    test_metrics = test_results.metrics

    if not _TPU_AVAILABLE or xm.is_master_ordinal():
        print("\nTest Set Evaluation Results:")
        print(test_metrics)
        trainer.log_metrics("test", test_metrics)
        trainer.save_metrics("test", test_metrics)

    logits = test_results.predictions
    labels = test_results.label_ids

    # Normalize: logits -> CPU torch tensor (needed for F.interpolate), labels -> numpy.
    if isinstance(logits, torch.Tensor) and logits.device.type == 'xla':
        logits = logits.cpu()
    if isinstance(logits, np.ndarray):
        logits = torch.from_numpy(logits)

    if isinstance(labels, torch.Tensor) and labels.device.type == 'xla':
        labels = labels.cpu()
    if isinstance(labels, torch.Tensor):
        labels = labels.numpy()

    # Same resize-then-argmax as compute_metrics: bring logits to the mask resolution.
    # NOTE(review): this upsamples the whole test set's logits in one tensor; memory grows
    # with (num_images x num_labels x H x W).
    logits_shape = logits.shape[-2:]
    labels_shape = labels.shape[-2:]
    if logits_shape != labels_shape:
        upsampled_logits = F.interpolate(
            logits,
            size=labels_shape,
            mode='bilinear',
            align_corners=False
        )
    else:
        upsampled_logits = logits

    # 1-D arrays of per-pixel class IDs across the whole test set.
    all_preds = upsampled_logits.argmax(dim=1).detach().cpu().numpy().flatten()
    all_labels = labels.flatten()

else:
    print("Skipping test set evaluation as trainer was not initialized or test dataset is empty.")

# Pixel-level confusion matrix on the test set (master process only), excluding 'unknown' ground truth.
if (not _TPU_AVAILABLE or xm.is_master_ordinal()) and len(all_preds) > 0 and len(all_labels) > 0:
    ignore_idx = 6  # 'unknown', the same class compute_metrics passes as ignore_index
    valid_indices = all_labels != ignore_idx
    filtered_labels = all_labels[valid_indices]
    filtered_preds = all_preds[valid_indices]

    print(f"\nGenerating Confusion Matrix (ignoring class {ignore_idx}: '{id2label.get(ignore_idx, 'N/A')}')")
    if len(filtered_labels) > 0:
        # labels=0..num_labels-2 and class_names[:-1] both rely on 'unknown' being the last class ID.
        # NOTE(review): pixels *predicted* as class 6 are not in `labels`; scikit-learn appears
        # to drop such samples from the matrix rather than count them. Confirm if that matters.
        cm = confusion_matrix(filtered_labels, filtered_preds, labels=list(range(num_labels-1)))
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names[:-1], yticklabels=class_names[:-1])
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.title('Confusion Matrix (Test Set)')
        plt.show()
    else:
        print("Skipping confusion matrix: No valid labels after filtering.")
elif not _TPU_AVAILABLE or xm.is_master_ordinal():
    print("Skipping confusion matrix generation (no predictions/labels available).")

def visualize_predictions(num_samples=5):
    """Plot input image, ground-truth mask and predicted mask for random test tiles.

    Reloads the model saved in ``training_args.output_dir`` on CPU and falls back to the
    in-memory ``model`` (moved to CPU) if that fails. Draws one 1x3 figure per sample.

    Args:
        num_samples: Maximum number of random test samples to plot.
    """
    if trainer is None or not test_data_available or len(test_image_paths) == 0:
        print("Skipping visualization: Trainer not available, test data missing, or no test image paths.")
        return

    print(f"\nVisualizing predictions for {num_samples} random test samples...")

    print(f"Loading best model from {training_args.output_dir} for visualization...")
    try:
        viz_model = AutoModelForSemanticSegmentation.from_pretrained(training_args.output_dir).cpu()
    except Exception as e:
        # Best-effort fallback: keep going with the weights already in memory.
        print(f"Error loading saved model for visualization: {e}. Using current model state on CPU.")
        viz_model = model.cpu()
    viz_model.eval()

    # Read the global under a separate local name. Assigning to `feature_extractor` inside
    # this function would make that name local for the whole body, so reading it first
    # would raise UnboundLocalError and force a reload from the Hub on every call.
    extractor = globals().get("feature_extractor")
    if extractor is None:
        print("Error: feature_extractor not found. Reloading...")
        extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

    # Image and mask lists are index-aligned; never index past the shorter one.
    num_available = min(len(test_image_paths), len(test_mask_paths))
    if num_available == 0:
        print("Skipping visualization: no test masks available.")
        return
    indices = random.sample(range(num_available), min(num_samples, num_available))

    # Legend entries are the same for every figure, so build them once.
    legend_patches = [plt.Rectangle((0, 0), 1, 1, fc=np.array(color) / 255.0) for color in id_to_rgb.values()]
    legend_labels = [f"{idx}: {name}" for idx, name in id2label.items()]

    for i in indices:
        image_path = test_image_paths[i]
        mask_path = test_mask_paths[i]

        try:
            image = Image.open(image_path).convert("RGB")
            true_mask_rgb = Image.open(mask_path).convert("RGB")
        except OSError as e:
            print(f"Error loading image/mask {i}: {e}")
            continue

        pixel_values = extractor(image, return_tensors="pt").pixel_values

        with torch.no_grad():
            logits = viz_model(pixel_values=pixel_values).logits

        # PIL's size is (width, height); interpolate wants (height, width).
        image_hw = image.size[::-1]
        if logits.shape[-2:] != image_hw:
            logits = F.interpolate(logits, size=image_hw, mode="bilinear", align_corners=False)

        pred_mask_id = logits.argmax(dim=1).squeeze().numpy()

        # Colourise class IDs with the same palette as the ground-truth masks.
        pred_mask_rgb = np.zeros((*pred_mask_id.shape, 3), dtype=np.uint8)
        for class_id, color in id_to_rgb.items():
            pred_mask_rgb[pred_mask_id == class_id] = color

        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        fig.suptitle(f"Sample {i}: {os.path.basename(image_path)}")
        panels = (
            (image, "Input Image"),
            (true_mask_rgb, "True Mask (RGB)"),
            (pred_mask_rgb, "Predicted Mask (RGB)"),
        )
        for ax, (picture, title) in zip(axes, panels):
            ax.imshow(picture)
            ax.set_title(title)
            ax.axis('off')

        fig.legend(legend_patches, legend_labels, loc='lower center', ncol=len(id2label), bbox_to_anchor=(0.5, -0.05))

        plt.tight_layout(rect=[0, 0.05, 1, 0.95])
        plt.show()

# Only the master process (or the single CPU/GPU process) draws figures.
if _TPU_AVAILABLE and not xm.is_master_ordinal():
    print(f"Skipping visualization on TPU replica {xm.get_ordinal()}")
else:
    visualize_predictions(num_samples=5)