<a href="https://colab.research.google.com/github/gee-community/geemap/blob/master/docs/notebooks/00_ee_auth_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"/></a>


## Earth Engine Automatic Authentication on Google Colab

### Step 1: Locating the Earth Engine Token

1. Locate the Earth Engine token on your computer by navigating to the following file path based on your operating system:

    - Windows: C:\\Users\\USERNAME\\.config\\earthengine\\credentials
    - Linux: /home/USERNAME/.config/earthengine/credentials
    - macOS: /Users/USERNAME/.config/earthengine/credentials

2. Open the credentials file and copy the entire content to the clipboard.

    **Note:** Ensure that you do not share the content of the credentials file with others to prevent unauthorized access to your Earth Engine account.

### Step 2: Creating the Secret in Google Colab

1. Open your Google Colab notebook and click on the `secrets` tab.
2. Create a new secret with the name `EARTHENGINE_TOKEN`.
3. Paste the content from the clipboard into the `Value` input box of the created secret.
4. Toggle the button on the left to allow notebook access to the secret.

![](https://i.imgur.com/Z9R08uU.png)

### Step 3: Installing the Required Version of geemap

Ensure that you have installed geemap version 0.29.3 or later, as only these versions support the automatic authentication feature.

In [None]:
%pip install -U geemap

### Step 4: Automatic Authentication with geemap

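To automatically authenticate Earth Engine using the EARTHENGINE_TOKEN in your Google Colab notebook, run the following code. Note that the cell below calls `ee.Initialize(project=...)` directly and then runs a model-training experiment, so replace the placeholder project ID, bucket, and file paths with your own values first: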

In [None]:
%%time
# Import required libraries
import tensorflow as tf
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import timm
import matplotlib.pyplot as plt
import seaborn as sns
import shap
from google.cloud import storage
from google.colab import userdata # Import userdata to access secrets
import ee # Import Earth Engine
import os
import logging
from sklearn.metrics import accuracy_score, f1_score

# Set up logging: INFO level, timestamped "<time> - <level> - <message>" records.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Initialize the Earth Engine client for the given project.
# NOTE(review): this only calls ee.Initialize(); the EARTHENGINE_TOKEN Colab
# secret (and the imported `userdata` helper) is not read here, so credentials
# must already be available to the client -- confirm how they are provisioned.
# A failure is logged and execution continues (best effort).
try:
    ee.Initialize(project='your-earthengine-project') # Replace with your Earth Engine project ID
    logging.info('Earth Engine authentication successful.')
except Exception as e:
    logging.error(f'Earth Engine authentication failed: {e}')

# Configuration
# NOTE(review): the /content/drive/... paths presumably require Google Drive to
# be mounted before this cell runs -- confirm.
DATA_DIR = 'gs://your-bucket/esri_patches'  # Replace with your GCS bucket
STATS_PATH = '/content/drive/MyDrive/esri_normalization_stats.npy'  # Path to normalization stats
OUTPUT_DIR = '/content/drive/MyDrive/lulc_experiment'  # Save models and results
BATCH_SIZE = 16  # Samples per DataLoader batch
LEARNING_RATE = 0.001  # Learning rate passed to the SGD optimizer
EPOCHS = 20  # Default maximum number of training epochs
PATIENCE = 5  # Early stopping threshold: epochs without validation-loss improvement
NUM_CLASSES = 9  # Number of model outputs; equals len(CLASS_NAMES)
IMG_SIZE = 224  # Patch height and width in pixels
BANDS = 6  # Blue, Green, Red, NIR, SWIR1, SWIR2
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device: {DEVICE}')

# LULC class names, indexed by integer class label.
CLASS_NAMES = ['Water', 'Trees', 'Flooded Vegetation', 'Crops', 'Built Area',
               'Bare Ground', 'Snow/Ice', 'Clouds', 'Rangeland']

# 1. Data Loading and Preprocessing
class ESRIDataset(Dataset):
    """Map-style PyTorch dataset backed by a TFRecord stream of image patches.

    Each record holds a flat float32 ``patch`` of
    ``IMG_SIZE * IMG_SIZE * BANDS`` values and an int64 ``label``.

    Args:
        tfrecord_files: TFRecord path(s) handed to ``tf.data.TFRecordDataset``.
        transform: Optional callable applied to the patch *after* it has been
            converted to a float32 tensor laid out as (BANDS, IMG_SIZE,
            IMG_SIZE), which is the layout tensor transforms such as
            ``torchvision.transforms.Normalize`` operate on.
    """

    def __init__(self, tfrecord_files, transform=None):
        self.tfrecord_files = tfrecord_files
        self.transform = transform
        self.dataset = self._load_tfrecords()
        self._length = None  # Filled in lazily by __len__.

    def _load_tfrecords(self):
        """Build the parsed (patch, label) ``tf.data`` pipeline."""
        dataset = tf.data.TFRecordDataset(self.tfrecord_files)

        def parse_fn(example):
            feature_desc = {
                'patch': tf.io.FixedLenFeature([IMG_SIZE * IMG_SIZE * BANDS], tf.float32),
                'label': tf.io.FixedLenFeature([], tf.int64)
            }
            example = tf.io.parse_single_example(example, feature_desc)
            patch = tf.reshape(example['patch'], [IMG_SIZE, IMG_SIZE, BANDS])
            label = example['label']
            return patch, label

        # No tf-side shuffle: __getitem__ addresses records with skip(idx), so
        # the record order must be stable for an index to always mean the same
        # sample. Shuffling is left to the DataLoader sampler.
        return dataset.map(parse_fn)

    def __len__(self):
        """Return the number of records, counting them only on first use."""
        length = getattr(self, '_length', None)
        if length is None:
            # Counting requires a full pass over the records, so cache it.
            length = sum(1 for _ in self.dataset)
            self._length = length
        return length

    def __getitem__(self, idx):
        """Return ``(patch, label)`` for record ``idx``.

        The patch is a float32 tensor of shape (BANDS, IMG_SIZE, IMG_SIZE).
        Access is sequential under the hood (``skip(idx)``), so the cost grows
        with ``idx``.

        Raises:
            IndexError: If no record exists at ``idx``.
        """
        records = iter(self.dataset.skip(idx).take(1))
        try:
            patch, label = next(records)
        except StopIteration:
            # A bare StopIteration would be mistaken for normal loop exhaustion.
            raise IndexError(f'ESRIDataset index out of range: {idx}') from None

        # HWC numpy -> CHW tensor first; tensor transforms expect that layout.
        patch = torch.tensor(patch.numpy(), dtype=torch.float32).permute(2, 0, 1)
        label = label.numpy()
        if self.transform is not None:
            patch = self.transform(patch)
        return patch, label

# Load normalization stats.
# The .npy file stores a single pickled object; .item() unwraps it so it can be
# indexed by the 'mean' and 'std' keys.
# SECURITY: allow_pickle=True unpickles the file contents, which can execute
# arbitrary code -- only load a stats file you produced or otherwise trust.
# NOTE(review): mean/std are presumably per-band sequences of length BANDS -- confirm.
stats = np.load(STATS_PATH, allow_pickle=True).item()
mean, std = stats['mean'], stats['std']
logging.info(f'Normalization stats - Mean: {mean}, Std: {std}')

# Preprocessing transforms: per-channel standardization with the stats above.
# NOTE(review): torchvision's Normalize operates on tensors laid out as
# (..., C, H, W) -- confirm the dataset hands it that layout.
transform = transforms.Compose([
    transforms.Normalize(mean=mean, std=std)
])

# Load dataset splits
def get_tfrecord_files(data_dir, split='train'):
    """List the ``gs://`` URIs of the objects belonging to one dataset split.

    Args:
        data_dir: ``gs://<bucket>`` or ``gs://<bucket>/<prefix>``. When no
            object prefix is given, ``esri_patches`` is used.
        split: Split name appended to the prefix (``<prefix>/<split>``).

    Returns:
        A list of ``gs://<bucket>/<object name>`` strings for every object whose
        name starts with ``<prefix>/<split>``, excluding folder placeholders.

    Raises:
        ValueError: If ``data_dir`` is not a ``gs://`` URI with a bucket name.
    """
    scheme = 'gs://'
    if not data_dir.startswith(scheme):
        raise ValueError(f'Expected a gs:// URI, got: {data_dir!r}')

    # Split "bucket/some/prefix" into the bucket name and the object prefix;
    # a bucket name itself can never contain "/".
    bucket_name, _, base_prefix = data_dir[len(scheme):].partition('/')
    if not bucket_name:
        raise ValueError(f'No bucket name found in: {data_dir!r}')
    base_prefix = base_prefix.strip('/') or 'esri_patches'

    client = storage.Client()
    blobs = client.bucket(bucket_name).list_blobs(prefix=f'{base_prefix}/{split}')
    return [
        f'gs://{bucket_name}/{blob.name}'
        for blob in blobs
        if not blob.name.endswith('/')  # skip zero-byte "folder" placeholder objects
    ]

# Resolve the TFRecord files of each split from the GCS bucket.
train_files = get_tfrecord_files(DATA_DIR, 'train')  # ~18,900 patches
val_files = get_tfrecord_files(DATA_DIR, 'val')      # ~4,050 patches
test_files = get_tfrecord_files(DATA_DIR, 'test')    # ~4,050 patches

# Wrap each split in a dataset that applies the shared normalization transform.
train_dataset = ESRIDataset(train_files, transform=transform)
val_dataset = ESRIDataset(val_files, transform=transform)
test_dataset = ESRIDataset(test_files, transform=transform)

# Only the training loader shuffles; validation and test keep the dataset order.
# NOTE(review): num_workers=2 runs the dataset in worker processes; confirm that
# the TensorFlow pipeline held by ESRIDataset behaves correctly there.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=2)

# 2. Model Setup
def create_model(model_name, pretrained=True, num_classes=NUM_CLASSES, in_chans=BANDS):
    """Build a timm model for multispectral patches and move it to ``DEVICE``.

    Args:
        model_name: timm architecture name.
        pretrained: Whether to load pretrained weights.
        num_classes: Size of the classification head.
        in_chans: Number of input channels. Defaults to ``BANDS`` because the
            patches carry that many bands; timm's own default of 3 would make
            the first layer reject them.

    Returns:
        The model, placed on ``DEVICE``.
    """
    model = timm.create_model(
        model_name,
        pretrained=pretrained,
        num_classes=num_classes,
        in_chans=in_chans,
    )
    return model.to(DEVICE)

# Models to evaluate.
# Maps each architecture name to a zero-argument factory instead of a built
# model: constructing all five pretrained networks up front would hold them on
# the device at the same time, while the experiment loop only iterates the
# names and builds a fresh model per run.
models = {
    name: (lambda name=name: create_model(name, pretrained=True))  # bind name now, not at call time
    for name in (
        'resnet50',
        'vit_base_patch16_224',
        'vit_large_patch16_224',
        'swin_small_patch4_window7_224',
        'swin_large_patch4_window7_224',
    )
}

# 3. Training Function
def train_model(model, model_name, train_loader, val_loader, fine_tune='full', epochs=EPOCHS):
    """Fine-tune ``model`` with SGD, early stopping and best-checkpoint saving.

    Args:
        model: Network to train (already on ``DEVICE``).
        model_name: Architecture name; selects which layers are frozen for
            partial fine-tuning and names the checkpoint file.
        train_loader: Loader yielding ``(images, labels)`` training batches.
        val_loader: Loader yielding ``(images, labels)`` validation batches.
        fine_tune: ``'partial'`` freezes the early layers first; any other
            value trains every parameter.
        epochs: Maximum number of epochs.

    Returns:
        The same model object, carrying the weights of the epoch with the
        lowest validation loss whenever a checkpoint was written.
    """
    criterion = nn.CrossEntropyLoss()

    # Freeze before building the optimizer so it only tracks trainable weights.
    if fine_tune == 'partial':
        _freeze_early_layers(model, model_name)
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = optim.SGD(trainable_params, lr=LEARNING_RATE, momentum=0.9)

    checkpoint_path = f'{OUTPUT_DIR}/{model_name}_{fine_tune}.pth'
    checkpoint_saved = False
    best_val_loss = float('inf')
    patience_counter = 0
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

        # Report mean per-batch losses.
        val_loss /= len(val_loader)
        train_loss /= len(train_loader)
        logging.info(f'{model_name} ({fine_tune}) Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Early stopping: checkpoint on improvement, otherwise count towards PATIENCE.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                logging.info(f'Early stopping at epoch {epoch+1} for {model_name}_{fine_tune}')
                break

    # The in-memory weights are those of the last epoch, which may be worse than
    # the checkpointed ones; hand back the best checkpoint instead.
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))

    return model


def _freeze_early_layers(model, model_name):
    """Disable gradients on the early layers used for partial fine-tuning."""
    if 'vit' in model_name:
        frozen_modules = [model.blocks[:-2]]  # Freeze all but last 2 blocks
    elif 'swin' in model_name:
        frozen_modules = [model.layers[:-1]]  # Freeze all but last layer
    elif 'resnet' in model_name:
        frozen_modules = [model.layer1, model.layer2]  # Freeze the first two stages
    else:
        frozen_modules = []

    for module in frozen_modules:
        for param in module.parameters():
            param.requires_grad = False

# 4. Evaluation Function
def evaluate_model(model, test_loader):
    """Compute accuracy, weighted F1 and mean IoU of ``model`` on ``test_loader``.

    Args:
        model: Classifier producing per-class scores of shape (batch, classes).
        test_loader: Loader yielding ``(images, labels)`` batches.

    Returns:
        A dict with keys ``'accuracy'``, ``'f1_score'`` and ``'miou'``.
        ``'miou'`` averages the per-class IoU over the classes that occur in
        the labels or the predictions; it is 0.0 when no class occurs at all.
    """
    model.eval()
    preds, true_labels = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            predicted = model(images).argmax(dim=1)
            preds.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    # Convert once; the per-class masks below are vectorized numpy comparisons.
    y_true = np.asarray(true_labels)
    y_pred = np.asarray(preds)

    accuracy = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='weighted')

    # mIoU: Simplified for single-label classification
    iou_scores = []
    for cls in range(NUM_CLASSES):
        is_true = y_true == cls
        is_pred = y_pred == cls
        union = np.count_nonzero(is_true | is_pred)
        if union == 0:
            # IoU is undefined for a class that never appears; counting it as 0
            # would drag the mean down for reasons unrelated to the model.
            continue
        intersection = np.count_nonzero(is_true & is_pred)
        iou_scores.append(intersection / union)
    miou = float(np.mean(iou_scores)) if iou_scores else 0.0

    return {'accuracy': accuracy, 'f1_score': f1, 'miou': miou}

# 5. Explainability Analysis
def get_attention_maps(model, model_name, images):
    """Return one flattened attention map per image from the last attention layer.

    The attention module does not keep its softmax weights after a forward
    pass, so they are captured with a temporary forward hook on the module's
    ``attn_drop`` layer, which the non-fused attention path feeds them through.

    Args:
        model: Network to inspect; it is switched to eval mode.
        model_name: Architecture name. ``'vit'`` and ``'swin'`` models are
            supported; anything else yields ``None``.
        images: Input batch passed to ``model.forward_features``.

    Returns:
        A tensor of shape (batch, tokens) whose rows can be reshaped to a
        square grid, or ``None`` when the architecture is unsupported or no
        attention weights were captured.
        For ViT this is the attention the leading (class) token pays to each
        patch token; for Swin it is the attention each position receives,
        averaged over the query positions of the window.
        NOTE(review): for Swin the first dimension is batch * windows; it is
        assumed the last stage holds a single window per image -- confirm.
    """
    model.eval()
    if 'vit' in model_name:
        attn_module = model.blocks[-1].attn
    elif 'swin' in model_name:
        attn_module = model.layers[-1].blocks[-1].attn
    else:
        return None

    captured = []

    def _capture(_module, _inputs, output):
        captured.append(output.detach())

    # The fused attention kernel never materializes the weights and skips
    # attn_drop, so switch it off for this one forward pass where the flag exists.
    has_fused_flag = hasattr(attn_module, 'fused_attn')
    previous_fused = getattr(attn_module, 'fused_attn', False)
    if has_fused_flag:
        attn_module.fused_attn = False
    handle = attn_module.attn_drop.register_forward_hook(_capture)
    try:
        with torch.no_grad():
            model.forward_features(images)
    finally:
        # Always restore the module, even if the forward pass fails.
        handle.remove()
        if has_fused_flag:
            attn_module.fused_attn = previous_fused

    if not captured:
        return None

    attn = captured[-1].mean(dim=1)  # Average over heads -> (batch, tokens, tokens)
    if 'vit' in model_name:
        # Row 0 is the class token's view of the sequence; drop the non-patch
        # prefix tokens so only the patch grid remains.
        num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1)
        return attn[:, 0, num_prefix_tokens:]
    return attn.mean(dim=1)

def plot_attention_map(attn_weights, image, label, class_name):
    """Overlay the first sample's attention map on its RGB composite and save it.

    Args:
        attn_weights: Batch of flattened attention maps; only entry 0 is drawn
            and its element count must be a perfect square.
        image: Normalized image tensor laid out as (bands, H, W).
        label: Unused; kept for interface compatibility.
        class_name: Class name shown in the title and used in the file name.

    Raises:
        ValueError: If the attention map cannot be arranged as a square grid.
    """
    scores = attn_weights[0].detach().cpu().numpy().reshape(-1)
    # Derive the grid from the token count rather than assuming 16-pixel
    # patches, so architectures with a coarser final grid are handled too.
    side = int(round(scores.size ** 0.5))
    if side * side != scores.size:
        raise ValueError(f'Attention map with {scores.size} values is not a square grid')
    attn_map = scores.reshape(side, side)

    # Bands are stored Blue, Green, Red first (see BANDS), while imshow expects
    # Red, Green, Blue, so pick them in reverse order.
    rgb_bands = [2, 1, 0]
    band_mean = np.asarray(mean, dtype=np.float32)[rgb_bands]
    band_std = np.asarray(std, dtype=np.float32)[rgb_bands]
    rgb = image[rgb_bands].permute(1, 2, 0).detach().cpu().numpy() * band_std + band_mean
    # imshow clips float RGB data to [0, 1] anyway; doing it here avoids the warning.
    rgb = np.clip(rgb, 0.0, 1.0)

    # Class names such as 'Snow/Ice' would otherwise be read as a sub-directory.
    safe_name = class_name.replace('/', '_').replace(os.sep, '_')

    fig = plt.figure(figsize=(8, 8))
    plt.imshow(rgb, interpolation='nearest')
    plt.imshow(attn_map, cmap='jet', alpha=0.5, extent=(0, IMG_SIZE, IMG_SIZE, 0))
    plt.title(f'Attention Map: {class_name}')
    plt.axis('off')
    plt.savefig(f'{OUTPUT_DIR}/attention_map_{safe_name}.png')
    plt.close(fig)

def compute_shap_values(model, images, background_samples=50):
    """Explain ``images`` with SHAP's DeepExplainer.

    Args:
        model: Network to explain; it is switched to eval mode.
        images: Batch to explain. Its first ``background_samples`` entries also
            serve as the explainer's background set.
        background_samples: Upper bound on the size of the background set.

    Returns:
        Whatever ``DeepExplainer.shap_values`` produces for ``images``.
    """
    model.eval()
    # The background is sliced from the very batch being explained, so for a
    # batch no larger than background_samples the two sets coincide.
    background = images[:background_samples]
    return shap.DeepExplainer(model, background).shap_values(images)

# 6. Main Experiment
# Train and evaluate every architecture under both fine-tuning regimes, then
# produce attention and SHAP visualizations for one test batch per run.
os.makedirs(OUTPUT_DIR, exist_ok=True)
results = {}

# Per-band statistics shaped (1, C, 1, 1) so they broadcast over (B, C, H, W)
# batches; the raw (C,) vectors would be matched against the width axis.
band_mean = np.asarray(mean, dtype=np.float32).reshape(1, -1, 1, 1)
band_std = np.asarray(std, dtype=np.float32).reshape(1, -1, 1, 1)

for model_name in models:
    for fine_tune in ['full', 'partial']:
        logging.info(f'Training {model_name} with {fine_tune} fine-tuning...')
        # Start every run from fresh pretrained weights.
        model = create_model(model_name, pretrained=True)
        model = train_model(model, model_name, train_loader, val_loader, fine_tune)
        metrics = evaluate_model(model, test_loader)
        results[f'{model_name}_{fine_tune}'] = metrics
        logging.info(f'{model_name}_{fine_tune}: Accuracy={metrics["accuracy"]:.3f}, F1={metrics["f1_score"]:.3f}, mIoU={metrics["miou"]:.3f}')

        # Explainability: analyze one test batch. This is auxiliary output, so
        # a failure is logged instead of discarding the metrics gathered so far.
        try:
            images, labels = next(iter(test_loader))
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            attn_weights = get_attention_maps(model, model_name, images)
            if attn_weights is not None:
                plot_attention_map(attn_weights, images[0], labels[0], CLASS_NAMES[int(labels[0])])

            shap_values = compute_shap_values(model, images)
            # Normalize to a list with one (B, C, H, W) array per model output.
            # NOTE(review): the return layout depends on the installed shap
            # release (list of arrays vs. one array with a trailing output
            # axis) -- confirm against the version in use.
            if isinstance(shap_values, np.ndarray) and shap_values.ndim == 5:
                shap_values = [shap_values[..., k] for k in range(shap_values.shape[-1])]
            elif not isinstance(shap_values, list):
                shap_values = [shap_values]

            # image_plot works on channels-last arrays.
            shap_hwc = [np.transpose(np.asarray(sv), (0, 2, 3, 1)) for sv in shap_values]
            pixels_hwc = np.transpose(images.cpu().numpy() * band_std + band_mean, (0, 2, 3, 1))
            # Labels name the model outputs: one row per sample, one column per output.
            output_names = np.array([CLASS_NAMES[:len(shap_hwc)]] * len(pixels_hwc))
            shap.image_plot(shap_hwc, pixels_hwc, labels=output_names, show=False)
            plt.savefig(f'{OUTPUT_DIR}/shap_{model_name}_{fine_tune}.png')
        except Exception:
            logging.exception(f'Explainability analysis failed for {model_name}_{fine_tune}')
        finally:
            plt.close('all')

        # Drop the finished model before the next one is built so two large
        # networks never sit on the device at once.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# 7. Save and Visualize Results
# Tabulate the metrics with one row per "<model>_<fine_tune>" run and persist them.
results_df = pd.DataFrame(results).transpose()
results_df.to_csv(f'{OUTPUT_DIR}/experiment_results.csv')
logging.info(f'Results saved to {OUTPUT_DIR}/experiment_results.csv')

# Bar chart of test accuracy per run, drawn on an explicit figure/axes pair.
accuracy_fig = plt.figure(figsize=(12, 6))
accuracy_ax = sns.barplot(x=results_df.index, y=results_df['accuracy'])
plt.xticks(rotation=45)  # Run names are long; tilt them so they stay readable.
accuracy_ax.set_title('Model Accuracy Comparison on ESRI 10m LULC Dataset')
accuracy_ax.set_ylabel('Accuracy')
accuracy_fig.tight_layout()
accuracy_fig.savefig(f'{OUTPUT_DIR}/accuracy_comparison.png')
plt.close(accuracy_fig)
logging.info(f'Accuracy plot saved to {OUTPUT_DIR}/accuracy_comparison.png')