In [None]:
!pip install --quiet --upgrade pip
!pip install fiona
!pip install rasterio --quiet
!pip install "apache-beam[gcp]>=2.50.0"
!pip install "earthengine-api>=1.5.9"
!pip install "folium>=0.19.5"
!pip install "google-cloud-aiplatform>=1.47.0"
!pip install "imageio>=2.36.1"
!pip install "plotly>=5.15.0"
!pip install "tensorflow>=2.16.1"

In [2]:
from google.colab import auth
import os
import ee
import google.auth
import pandas as pd
import numpy as np
import rasterio
from google.cloud import storage

# AUTHENTICATE
auth.authenticate_user()
project = "saadwetlands"
bucket = "saadwetlands-bucket"
location = "northamerica-northeast1"
os.environ["GOOGLE_CLOUD_PROJECT"] = project
!gcloud config set project {project}
credentials, _ = google.auth.default()

ee.Initialize(
    credentials.with_quota_project(None),
    project=project,
    opt_url="https://earthengine-highvolume.googleapis.com",
)

# DIRECTORIES
!mkdir -p /content/data/
source_path = f"gs://{bucket}/pixel_data/"
local_path = "/content/data/"
print(f"Downloading all .tif files from: {source_path}")

# DOWNLOAD FILES
!gsutil -m cp "{source_path}*.tif" {local_path}
print("\nDownloaded files:")
!ls -lh /content/data/*.tif


INFORMATION: Project 'saadwetlands' has no 'environment' tag set. Use either 'Production', 'Development', 'Test', or 'Staging'. Add an 'environment' tag using `gcloud resource-manager tags bindings create`.
Updated property [core/project].
Bucket: gs://saadwetlands-bucket/
Downloading all .tif files from: gs://saadwetlands-bucket/pixel_data/
This may take a while depending on file sizes...
Copying gs://saadwetlands-bucket/pixel_data/ON.tif...
Copying gs://saadwetlands-bucket/pixel_data/MB.tif...
Copying gs://saadwetlands-bucket/pixel_data/NSBOX_LABELS.tif...
Copying gs://saadwetlands-bucket/pixel_data/NSBOX_64.tif...
Copying gs://saadwetlands-bucket/pixel_data/BC.tif...
Copying gs://saadwetlands-bucket/pixel_data/ONBOX_64.tif...
Copying gs://saadwetlands-bucket/pixel_data/ONBOX_LABELS.tif...
| [7/7 files][  1.2 GiB/  1.2 GiB] 100% Done  76.7 MiB/s ETA 00:00:00           
Operation completed over 7 objects/1.2 GiB.                                      

Download complete!

Downloaded fi

# GET RANDOM POINTS


In [41]:
import rasterio
import numpy as np

region = "ON"
tiff = f'/content/data/{region}.tif'

# GET ALL LABELS WITHIN a .TIF FILE
# Prints the raster's metadata, then a per-label pixel count of band 1.
print(f"{region} data:")
with rasterio.open(tiff) as src:
    width = src.width
    height = src.height
    bands = src.count
    dtype = src.profile['dtype']
    bounds = src.bounds
    nodata = src.nodata

    print(f"  Width: {width} pixels")
    print(f"  Height: {height} pixels")
    print(f"  Bands: {bands}")
    print(f"  Data Type: {dtype}")
    print(f"  Geographical Bounding Box: {bounds}")
    print(f"  CRS: {src.crs}")
    print(f"  NoData value: {nodata}")

    labels_array = src.read(1)
    # np.isfinite only removes NaN/inf, so on an integer raster it keeps every
    # pixel. Also drop the declared nodata value, otherwise fill pixels are
    # reported as a label class.
    valid_mask = np.isfinite(labels_array)
    if nodata is not None:
        valid_mask &= labels_array != nodata
    unique_labels, counts = np.unique(labels_array[valid_mask], return_counts=True)

    print(f"\n Labels and counts in {region}.tif:")
    for label, count in zip(unique_labels, counts):
        print(f"  Label {int(label)}: {count:,} pixels")
    total_pixels = np.sum(counts)
    print(f"Total labeled pixels: {total_pixels:,}")


ON data:
  Width: 6034 pixels
  Height: 16300 pixels
  Bands: 1
  Data Type: int32
  Geographical Bounding Box: BoundingBox(left=-84.10009723621279, bottom=47.210330936202574, right=-82.47396690889963, top=51.60309267554703)
  CRS: EPSG:4326

 Labels and counts in ON.tif:
  Label 0: 16,099,851 pixels
  Label 1: 12,364,188 pixels
  Label 2: 1,149,204 pixels
  Label 3: 21,506,132 pixels
  Label 4: 2,669 pixels
  Label 5: 3,569,203 pixels
  Label 6: 42,325,005 pixels
  Label 7: 195,949 pixels
  Label 8: 11,600 pixels
  Label 9: 17,329 pixels
  Label 10: 581,764 pixels
  Label 11: 31,916 pixels
  Label 12: 4,680 pixels
  Label 13: 494,710 pixels
Total labeled pixels: 98,354,200


In [49]:
import rasterio
import random
import math
from shapely.geometry import Point, Polygon
from collections import defaultdict
import numpy as np
import pandas as pd
import os


def get_points(polygon, tiff_path, cdict, hsize=64, seed=42, attempts=10000):
    """Randomly sample labelled pixel locations from a single-band label GeoTIFF.

    Pixels are drawn uniformly from the part of ``polygon``'s bounding box that
    lies at least ``hsize`` pixels inside both that box and the raster edge, so
    a ``2 * hsize`` pixel patch centred on each sample fits inside the raster.
    A draw is kept when the pixel is finite and not the raster's nodata value,
    its pixel centre lies inside ``polygon``, and its class still has quota.

    Args:
        polygon: shapely geometry, in the raster's CRS, bounding the sampling area.
        tiff_path: path of the label GeoTIFF; band 1 is read fully into memory.
        cdict: ``{label: number_of_points}``. Non-positive quotas and labels
            that never occur (as valid pixels) in the raster are ignored.
        hsize: half patch size, in pixels.
        seed: seed for a private ``random.Random``; the global ``random``
            module state is left untouched.
        attempts: maximum number of random draws before giving up.

    Returns:
        List of ``(x, y, label, crs, tiff_path)`` tuples, where ``x``/``y`` are
        the world coordinates of the sampled pixel's centre. Returns ``[]`` if
        the sampling window is empty or no quota applies; unmet quotas are
        reported with ``print``.
    """
    # Private RNG: yields the same draws as seeding the global module would,
    # without the hidden side effect on other cells.
    rng = random.Random(seed)
    sampled_points = []

    with rasterio.open(tiff_path) as src:
        arr = src.read(1)
        height, width = arr.shape
        transform = src.transform
        crs = src.crs
        nodata = src.nodata
        minx, miny, maxx, maxy = polygon.bounds

        # Bounding-box corners in fractional pixel coordinates. For a north-up
        # raster the top edge (maxy) maps to the *smallest* row; sorting keeps the
        # bounds right regardless of axis direction.
        col_a, row_a = ~transform * (minx, maxy)
        col_b, row_b = ~transform * (maxx, miny)
        col_lo, col_hi = sorted((col_a, col_b))
        row_lo, row_hi = sorted((row_a, row_b))

        # Safe pixel bounds for patch extraction: hsize pixels inside both the
        # polygon's bounding box and the raster itself.
        px_min = max(int(math.ceil(col_lo + hsize)), hsize)
        px_max = min(int(math.floor(col_hi - hsize)), width - 1 - hsize)
        py_min = max(int(math.ceil(row_lo + hsize)), hsize)
        py_max = min(int(math.floor(row_hi - hsize)), height - 1 - hsize)

        if px_min > px_max or py_min > py_max:
            print(f"[WARN] Polygon too small for {2*hsize}px patches in {tiff_path}")
            return []

        # Usable pixels: finite (guards int() below on float rasters) and not
        # the declared nodata value.
        valid_mask = np.isfinite(arr)
        if nodata is not None:
            valid_mask &= arr != nodata

        # Filter valid classes
        valid_labels = set(np.unique(arr[valid_mask]).tolist())
        quotas = {k: int(v) for k, v in cdict.items() if v > 0 and k in valid_labels}

        if not quotas:
            print(f"[WARN] No valid class quotas for {tiff_path}")
            return []

        # Random sampling
        for _ in range(attempts):
            if all(v <= 0 for v in quotas.values()):
                break

            px = rng.randint(px_min, px_max)
            py = rng.randint(py_min, py_max)
            if not valid_mask[py, px]:
                continue

            # Use the pixel centre: transform * (px, py) is the pixel's corner,
            # which may land in a neighbouring pixel (floating-point error) when
            # mapped back to (row, col) with a floor-based inverse.
            x, y = transform * (px + 0.5, py + 0.5)
            if not polygon.contains(Point(x, y)):
                continue

            label = int(arr[py, px])
            if quotas.get(label, 0) > 0:
                sampled_points.append((x, y, label, crs, tiff_path))
                quotas[label] -= 1

        # Report unmet quotas
        unmet = {k: v for k, v in quotas.items() if v > 0}
        if unmet:
            print(f"[WARN] Unmet quotas in {tiff_path}: {unmet}")

    return sampled_points


# Per-raster class quotas: {label: number of points to sample}.
tiff_info = [
    ("/content/data/BC.tif", {0: 0, 1: 0, 2: 0, 3: 0, 4: 250}),
    ("/content/data/MB.tif", {0: 0, 1: 0, 2: 500, 3: 250, 4: 250}),
    ("/content/data/ON.tif", {0: 500, 1: 500, 2: 0, 3: 250, 4: 0}),
]

hsize = 64
attempts = 200000
output_dir = "/content/"

# Sample each raster over its full extent, collecting every point in one list.
all_points = []
for tif_path, class_quotas in tiff_info:
    with rasterio.open(tif_path) as src:
        extent = Polygon.from_bounds(*src.bounds)
    all_points += get_points(extent, tif_path, class_quotas, hsize, seed=42, attempts=attempts)

df = pd.DataFrame(
    all_points,
    columns=["world_x", "world_y", "label", "src_crs", "source_file"],
)
print(f"\nTotal points sampled: {len(df)}")

# One CSV per source raster, named "<raster file name>_points.csv".
for source_file, file_points in df.groupby("source_file"):
    out_csv = os.path.join(output_dir, os.path.basename(source_file) + "_points.csv")
    file_points.loc[:, ["world_x", "world_y", "label"]].to_csv(out_csv, index=False)
    print(f"{out_csv} : {len(file_points)} points")




Total points sampled: 2250
/content/BC.tif_points.csv : 250 points
/content/MB.tif_points.csv : 750 points
/content/ON.tif_points.csv : 1250 points


In [6]:
!gsutil cp /content/ON.tif_points.csv gs://saadwetlands-bucket/pixel_data
!gsutil cp /content/MB.tif_points.csv gs://saadwetlands-bucket/pixel_data
!gsutil cp /content/BC.tif_points.csv gs://saadwetlands-bucket/pixel_data

Copying file:///content/ON.tif_points.csv [Content-Type=text/csv]...
/ [1 files][ 47.7 KiB/ 47.7 KiB]                                                
Operation completed over 1 objects/47.7 KiB.                                     
Copying file:///content/MB.tif_points.csv [Content-Type=text/csv]...
/ [1 files][ 28.8 KiB/ 28.8 KiB]                                                
Operation completed over 1 objects/28.8 KiB.                                     
Copying file:///content/BC.tif_points.csv [Content-Type=text/csv]...
/ [1 files][  5.2 KiB/  5.2 KiB]                                                
Operation completed over 1 objects/5.2 KiB.                                      


# SERIALIZE PATCHES TO TENSORFLOW TFRECORDS


In [38]:
import apache_beam as beam
from apache_beam.options.pipeline_options import (
    PipelineOptions, StandardOptions, SetupOptions, WorkerOptions
)
import tensorflow as tf
import ee
import numpy as np
import rasterio
from rasterio.windows import Window
import tempfile
import urllib.request
import pandas as pd
import os
import subprocess
import logging


# Shared GCP identifiers.
project = "saadwetlands"
bucket = "saadwetlands-bucket"

# Patch-extraction parameters.
YEAR = 2024        # embedding year to mosaic
PATCH_SIZE = 128   # patch edge length, in pixels
PIXEL_SIZE = 30    # used (presumably as metres) to size the Earth Engine request region

# Per-region GCS locations.
REGION = "BC"
label_tif = "gs://{}/pixel_data/{}.tif".format(bucket, REGION)
csv_path = "gs://{}/pixel_data/{}.tif_points.csv".format(bucket, REGION)
TFRECORD_OUTPUT_PREFIX = "gs://{}/output/patches/{}/".format(bucket, REGION)

# Earth Engine service-account key location (remote and worker-local).
KEY_GCS_PATH = "gs://{}/keys/ee-dataflow-worker-key.json".format(bucket)
LOCAL_KEY_PATH = "/tmp/ee-dataflow-worker-key.json"

logging.basicConfig(level=logging.INFO)
_LOGGER = logging.getLogger(__name__)

def get_label_patch(lon, lat):
    """Read a PATCH_SIZE x PATCH_SIZE window of band 1 of ``label_tif``.

    The window's top-left corner sits PATCH_SIZE // 2 pixels above and to the
    left of the pixel containing (lon, lat). Any failure is logged and
    ``None`` is returned instead of raising.
    """
    try:
        half = PATCH_SIZE // 2
        with rasterio.Env(), rasterio.open(label_tif) as src:
            row, col = src.index(lon, lat)
            window = Window(
                col_off=col - half,
                row_off=row - half,
                width=PATCH_SIZE,
                height=PATCH_SIZE,
            )
            patch = src.read(1, window=window)
    except Exception as e:
        _LOGGER.error(f"Failed to get label patch for ({lon},{lat}): {e}")
        return None
    return patch

def fetch_embedding_patch(embedding_image, lon, lat):
    """Download an embedding patch from Earth Engine around (lon, lat).

    The request region is the bounding box of a ``PIXEL_SIZE * PATCH_SIZE / 2``
    buffer around the point, which Earth Engine resamples to
    ``PATCH_SIZE`` x ``PATCH_SIZE`` pixels and returns as a GeoTIFF.

    Returns:
        Array of shape (rows, cols, bands) with bands last, or ``None`` if any
        step fails (the error is logged). The temporary GeoTIFF is always
        removed, including on failure.
    """
    tmp_path = None
    try:
        point = ee.Geometry.Point([lon, lat])
        # NOTE(review): this region is built from a buffer distance, while
        # get_label_patch takes PATCH_SIZE pixels of the label raster's own grid
        # (EPSG:4326 degrees for the rasters inspected above). The two footprints
        # may differ, especially in longitude at high latitudes; confirm alignment.
        region = point.buffer(PIXEL_SIZE * PATCH_SIZE / 2).bounds()
        url = embedding_image.getDownloadURL({
            "region": region.getInfo()['coordinates'],
            "dimensions": [PATCH_SIZE, PATCH_SIZE],
            "format": "GEO_TIFF"
        })

        fd, tmp_path = tempfile.mkstemp(suffix=".tif")
        os.close(fd)
        urllib.request.urlretrieve(url, tmp_path)
        with rasterio.open(tmp_path) as src:
            # rasterio reads (bands, rows, cols); move bands to the last axis.
            return np.transpose(src.read(), (1, 2, 0))
    except Exception as e:
        _LOGGER.error(f"Failed to fetch embedding for ({lon},{lat}): {e}")
        return None
    finally:
        # Previously a failed download/read skipped os.remove and leaked the file.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)

def _bytes_feature(value):
    """Wrap one bytes value as a tf.train.Feature with a single-element BytesList."""
    single_bytes = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=single_bytes)

def _int64_feature(value):
    """Wrap one integer as a tf.train.Feature with a single-element Int64List."""
    single_int = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=single_int)

def serialize_example(label_patch, embedding_patch):
    """Encode a (label, embedding) patch pair as a serialized tf.train.Example.

    Features:
        embedding: raw float32 bytes of ``embedding_patch``.
        label: raw int32 bytes of ``label_patch``.
        height, width: first two dimensions of ``label_patch``.
        bands: third dimension of ``embedding_patch``.
    """
    embedding_bytes = embedding_patch.astype(np.float32).tobytes()
    label_bytes = label_patch.astype(np.int32).tobytes()
    n_rows, n_cols = label_patch.shape[0], label_patch.shape[1]
    n_bands = embedding_patch.shape[2]

    features = tf.train.Features(feature={
        'embedding': _bytes_feature(embedding_bytes),
        'label': _bytes_feature(label_bytes),
        'height': _int64_feature(n_rows),
        'width': _int64_feature(n_cols),
        'bands': _int64_feature(n_bands),
    })
    return tf.train.Example(features=features).SerializeToString()

class GeneratePatchDoFn(beam.DoFn):
    """Turn one sampled-point row into a serialized (label, embedding) tf.train.Example.

    Each element is a dict carrying ``world_x`` / ``world_y`` (the point's lon/lat).
    Points whose label or embedding patch cannot be fetched, or whose patches are
    not PATCH_SIZE x PATCH_SIZE, are logged and dropped.
    """

    def setup(self):
        """Authenticate Earth Engine on the worker and build the annual embedding mosaic."""
        import ee
        from google.cloud import storage

        # Read the service-account key straight into memory. Writing it to one
        # fixed /tmp path let concurrent DoFn instances on the same worker race on
        # the file, and left the secret on disk.
        # NOTE(review): running Dataflow as this service account (worker
        # service_account_email) would avoid distributing a key file at all.
        client = storage.Client()
        key_blob = storage.Blob.from_string(KEY_GCS_PATH, client=client)
        key_json = key_blob.download_as_text()

        credentials = ee.ServiceAccountCredentials(
            f"ee-dataflow-worker@{project}.iam.gserviceaccount.com",
            key_data=key_json,
        )
        # Same project and high-volume endpoint as the notebook's own
        # ee.Initialize; workers issue many parallel download requests.
        ee.Initialize(
            credentials,
            project=project,
            opt_url="https://earthengine-highvolume.googleapis.com",
        )

        # filterDate's end bound is exclusive: use Jan 1 of the next year so the
        # whole of YEAR (including Dec 31) is covered.
        self.embedding_image = (
            ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
            .filterDate(f"{YEAR}-01-01", f"{YEAR + 1}-01-01")
            .mosaic()
        )

    def process(self, row):
        """Yield one serialized example for ``row``, or nothing if the point is skipped."""
        lon, lat = float(row['world_x']), float(row['world_y'])
        _LOGGER.info(f"Processing point: ({lon}, {lat})")

        # Read the (cheap, local-ish) label first so a failure there does not
        # spend an Earth Engine request.
        label_patch = get_label_patch(lon, lat)
        if label_patch is None:
            _LOGGER.warning(f"Skipping point ({lon},{lat}) due to missing data.")
            return

        embedding_patch = fetch_embedding_patch(self.embedding_image, lon, lat)
        if embedding_patch is None:
            _LOGGER.warning(f"Skipping point ({lon},{lat}) due to missing data.")
            return

        # serialize_example records a single height/width (taken from the label
        # patch) for both arrays, so mismatched sizes would produce records that
        # cannot be reshaped downstream.
        expected = (PATCH_SIZE, PATCH_SIZE)
        if label_patch.shape != expected or embedding_patch.shape[:2] != expected:
            _LOGGER.warning(
                f"Skipping point ({lon},{lat}): label shape {label_patch.shape}, "
                f"embedding shape {embedding_patch.shape}, expected {expected}."
            )
            return

        _LOGGER.info(f"Generated patches for ({lon}, {lat})")
        yield serialize_example(label_patch, embedding_patch)

import re
import time

# Dataflow rejects a new job whose name matches a job that is still running, so
# make the name unique per region and launch time. Job names may only contain
# lowercase letters, digits and hyphens.
job_name = "wetland-patch-{}-{}".format(
    re.sub(r"[^a-z0-9-]+", "-", REGION.lower()),
    time.strftime("%Y%m%d-%H%M%S"),
)

# NOTE(review): workers run in us-central1, while the setup cell names
# northamerica-northeast1 as `location`; confirm the bucket's region to avoid
# cross-region reads. requirements.txt must exist in the working directory.
options = PipelineOptions(
    runner='DataflowRunner',
    project=project,
    job_name=job_name,
    staging_location=f'gs://{bucket}/staging',
    temp_location=f'gs://{bucket}/temp',
    region='us-central1',
    requirements_file='requirements.txt',
)

# Worker configuration
worker_options = options.view_as(WorkerOptions)
worker_options.machine_type = 'n1-standard-4'
worker_options.num_workers = 5
worker_options.max_num_workers = 20

# Pickle the notebook's globals (constants, helper functions) so the DoFn can use
# them on the workers.
options.view_as(SetupOptions).save_main_session = True

# Read CSV from GCS
df = pd.read_csv(csv_path)
rows_list = df.to_dict(orient='records')

# Run Beam Pipeline. Leaving the `with` block runs the job and waits for it to
# finish (raising on failure), so the message below means the job completed.
with beam.Pipeline(options=options) as pipeline:
    (
        pipeline
        | "Create Rows" >> beam.Create(rows_list)
        | "Generate Patches" >> beam.ParDo(GeneratePatchDoFn())
        | "Write TFRecords" >> beam.io.WriteToTFRecord(
            file_path_prefix=TFRECORD_OUTPUT_PREFIX,
            file_name_suffix=".tfrecord",
            coder=beam.coders.BytesCoder()
        )
    )

print(f"Pipeline {job_name} finished. TFRecords written under {TFRECORD_OUTPUT_PREFIX}")






Pipeline submitted! Check the Dataflow console for progress.
