<a href="https://colab.research.google.com/github/Sanford-Lab/satellite_cnns/blob/main/benin_predictions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

This notebook trains a model on a dataset of Benin to predict which pixels belong to a village. The prediction target comes from the feature collection 'projects/satellite-cnns/assets/benin/voronoi_villages', which contains circles around villages in Benin. The inputs are Landsat 7 imagery of Benin.

The notebook has the sections:

1.   Setup & General settings
2.   Read data
3.   Helper functions




**NOTES:**

*   The Benin prediction problem is a segmentation task.


> YB: Classification: Currently, the model predicts a pixel-wise classification. We wondered whether this is really necessary since our ground truth is captured in larger patches. Therefore, it might be sufficient to predict one value for a rectangle and take care of the border areas by combining the values of all relevant rectangles. If we use pixel-wise predictions, the model might simply learn how to draw borders. Could you explain why pixel-wise classification makes sense in the case of Benin? Maybe we are overlooking an important aspect also for the Kenyan case…

> DD: [... F]or Benin we need it because we need to both predict treatment status and outcomes, and use residuals from both of those models. For each even a relatively small tile size is much larger than the spatial scale of interest.





# 1. Setup & General settings

## 1.1 Setup

In [1]:
%rm -r /content/satellite_cnns
%cd /content

rm: cannot remove '/content/satellite_cnns': No such file or directory
/content


In [2]:
branch = "main"

In [None]:
# Clone from SPIRES Repo
!git clone --branch {branch} https://github.com/Sanford-Lab/satellite_cnns.git
%cd /content/satellite_cnns

**New**: The notebook now sources files from the Sanford-Lab/satellite_cnns repo. This notebook uses the `benin-data` package, which is built on the patterns of the weather-forecasting notebook and lets project-specific packages be imported *plug and play* for data creation. The new patterns should make the workflow much more modular: a new project only needs to define three main functions, `get_inputs_image`, `get_labels_image`, and `sample_points`, and the `create_dataset.py` script should then be able to synthesize the dataset through abstraction. I've kept most of the demonstration functionality the same to show how using the package works.
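As a rough illustration of the plug-and-play pattern described above, the sketch below shows a driver that depends only on three project-supplied callables. Everything in it is hypothetical: the real `benin-data` functions work with Earth Engine images and their signatures are not shown in this notebook, so the NumPy arrays, the argument lists, and the `build_dataset` helper are assumptions made purely for illustration.

```python
from __future__ import annotations

from typing import Callable, Iterator

import numpy as np

IMAGE_SIZE = 32  # toy image is IMAGE_SIZE x IMAGE_SIZE pixels (assumption)


def get_inputs_image() -> np.ndarray:
    """Hypothetical stand-in: return a toy (height, width, bands) input image."""
    rng = np.random.default_rng(0)
    return rng.random((IMAGE_SIZE, IMAGE_SIZE, 3)).astype(np.float32)


def get_labels_image() -> np.ndarray:
    """Hypothetical stand-in: label 1 inside a square 'village', 0 elsewhere."""
    labels = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.uint8)
    labels[8:16, 8:16, 0] = 1
    return labels


def sample_points(n: int = 4, patch_size: int = 8, seed: int = 0) -> Iterator[tuple[int, int]]:
    """Hypothetical stand-in: yield top-left corners of patches that fit in the image."""
    rng = np.random.default_rng(seed)
    for _ in range(n):
        row, col = rng.integers(0, IMAGE_SIZE - patch_size + 1, size=2)
        yield int(row), int(col)


def build_dataset(
    inputs_fn: Callable[[], np.ndarray],
    labels_fn: Callable[[], np.ndarray],
    points_fn: Callable[[], Iterator[tuple[int, int]]],
    patch_size: int = 8,
) -> list[dict[str, np.ndarray]]:
    """Cut one (inputs, labels) patch per sampled point, using only the three callables."""
    inputs, labels = inputs_fn(), labels_fn()
    examples = []
    for row, col in points_fn():
        window = (slice(row, row + patch_size), slice(col, col + patch_size))
        examples.append({"inputs": inputs[window], "labels": labels[window]})
    return examples


examples = build_dataset(get_inputs_image, get_labels_image, sample_points)
print(len(examples), examples[0]["inputs"].shape, examples[0]["labels"].shape)
```

The point of the design is that `build_dataset` never needs to know which project it is serving; swapping in another project's three functions would leave the driver unchanged.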

In [None]:
!pip install --quiet --upgrade pip

# We need `build` and `virtualenv` to build the local packages.
!pip install --quiet build virtualenv

# Install Apache Beam and the `benin-data` local package.
!pip install apache-beam[gcp] src/benin-data

In [5]:
# run to manually restart runtime by ending process
# exit()

## 1.2 General settings

At this point the runtime has been restarted. Navigate back to the working directory.

In [1]:
%cd /content/satellite_cnns

/content/satellite_cnns


In [2]:
#@title Project settings
from __future__ import annotations

import os
from google.colab import auth

# Please fill in these values.
project = "satellite-cnns" #@param {type:"string"}
bucket = "beninbucket" #@param {type:"string"}
location = "us-central1" #@param {type:"string"}

# Quick input validations.
assert project, "⚠️ Please provide a Google Cloud project ID"
assert bucket, "⚠️ Please provide a Cloud Storage bucket name"
assert not bucket.startswith('gs://'), f"⚠️ Please remove the gs:// prefix from the bucket name: {bucket}"
assert location, "⚠️ Please provide a Google Cloud location"

# Authenticate to Colab.
auth.authenticate_user()

# Set GOOGLE_CLOUD_PROJECT for google.auth.default().
os.environ['GOOGLE_CLOUD_PROJECT'] = project

# Set the gcloud project for other gcloud commands.
!gcloud config set project {project}

Updated property [core/project].


In [3]:
import ee
import google.auth

credentials, _ = google.auth.default()
ee.Initialize(
    credentials.with_quota_project(None),
    project=project,
    opt_url="https://earthengine-highvolume.googleapis.com",
)

# 2. Read data


## 2.1 Google Cloud Storage

Let's list the dataset files stored in the Google Cloud Storage bucket.

In [None]:
bucket = 'beninbucket'
folder = 'yb_test'
data_path=f"gs://{bucket}/{folder}/data"
print(data_path)
!gsutil ls -lh {data_path}

Next, let's copy the files to a local directory so we can inspect them.

In [None]:
!mkdir -p data-training
!gsutil -m cp {data_path}/* data-training

In [6]:
# Use this to wipe the folder if needed
#%rm -r data-training

## 2.2 Look at dataset

In [None]:
import torch
from read_data import DatasetFromPath, test_train_split

dataset = DatasetFromPath('data-training')

**Note**: To avoid using [Hugging Face 🤗 Datasets](https://huggingface.co/docs/datasets/main/en/index) (as in the weather forecasting sample), we use a custom subclass of PyTorch's `torch.utils.data.Dataset` (`DatasetFromPath`). Hugging Face Datasets offers a convenient high-level interface and could be adopted in the future, but as of this writing (7/4/2023), Vertex AI (which we use for cloud training) has an issue with its Hugging Face Trainer API (see [weather sample issue](https://github.com/GoogleCloudPlatform/python-docs-samples/issues/9272)).
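For readers unfamiliar with the pattern, a map-style dataset only needs `__len__` and `__getitem__`. The sketch below is not the repository's `DatasetFromPath`; it is a minimal NumPy-only stand-in (the class name `ToyPatchDataset` and its constructor are hypothetical) that mimics the two access modes used in this notebook: integer indexing returns one example as a dictionary, and indexing with `"inputs"` or `"labels"` returns the full array. The real class subclasses `torch.utils.data.Dataset` and reads its data from a path.

```python
from __future__ import annotations

import numpy as np


class ToyPatchDataset:
    """Hypothetical in-memory dataset with dict-style examples."""

    def __init__(self, inputs: np.ndarray, labels: np.ndarray) -> None:
        if len(inputs) != len(labels):
            raise ValueError("inputs and labels must contain the same number of examples")
        self._arrays = {"inputs": inputs, "labels": labels}

    def __len__(self) -> int:
        return len(self._arrays["inputs"])

    def __getitem__(self, key: int | str):
        # String keys return the whole array; integer keys return one example.
        if isinstance(key, str):
            return self._arrays[key]
        return {name: array[key] for name, array in self._arrays.items()}


toy = ToyPatchDataset(
    inputs=np.zeros((10, 8, 8, 3), dtype=np.float32),
    labels=np.zeros((10, 8, 8, 1), dtype=np.uint8),
)
print(len(toy), toy[0]["inputs"].shape, toy["labels"].shape)
```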

### Visualize

Let's take the dataset (`dataset`) loaded from our path (`data-training`) and pull the first element as `example`. In `DatasetFromPath`, the custom getter returns a dictionary with the keys `inputs` and `labels`, so to get the inputs of the example we use `example['inputs']`.

Check what was run through the pipeline. You should expect to see:
- Dataset size of 2 * `POINTS_PER_CLASS`
- inputs size of (`PATCH_SIZE`, `PATCH_SIZE`, number of input bands)
- labels size of (`PATCH_SIZE`, `PATCH_SIZE`, number of label bands)
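The expectations above can also be written as an explicit check. The sketch below uses made-up values for `POINTS_PER_CLASS`, `PATCH_SIZE`, and the band counts, and synthetic arrays in place of the real dataset; the `find_shape_problems` helper is hypothetical and not part of the repository. Substitute the values that were used when the dataset was created.

```python
import numpy as np


def find_shape_problems(inputs, labels, points_per_class, patch_size,
                        num_input_bands, num_label_bands):
    """Return a list of human-readable mismatches (empty if everything matches)."""
    expected_examples = 2 * points_per_class
    expected = {
        "inputs": (expected_examples, patch_size, patch_size, num_input_bands),
        "labels": (expected_examples, patch_size, patch_size, num_label_bands),
    }
    actual = {"inputs": inputs.shape, "labels": labels.shape}
    return [
        f"{name}: expected {expected[name]}, got {actual[name]}"
        for name in expected
        if expected[name] != actual[name]
    ]


# Assumed settings, for illustration only.
POINTS_PER_CLASS, PATCH_SIZE = 5, 8
fake_inputs = np.zeros((10, 8, 8, 3), dtype=np.float32)
fake_labels = np.zeros((10, 8, 8, 1), dtype=np.uint8)

problems = find_shape_problems(fake_inputs, fake_labels, POINTS_PER_CLASS, PATCH_SIZE, 3, 1)
print(problems or "all shapes match")
```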

In [None]:
print(f"Dataset size: {len(dataset)}")
example = dataset[0]  # random access the first element

print(f"inputs: {example['inputs'].shape}")
print(f"labels: {example['labels'].shape}")

The `DatasetFromPath` class also allows you to retrieve all of the inputs or labels in a dataset in their raw NumPy array form by indexing with "inputs" or "labels":

In [None]:
inputs = dataset['inputs']
labels = dataset['labels']

print(f"All inputs: {inputs.shape}")
print(f"All labels: {labels.shape}")

Let's view our example using our visualization functionality. For Benin, blue = inside village, red = outside village.

In [None]:
from visualize import show_patch


inputs = example['inputs']
labels = example['labels']

show_patch(inputs, labels)

Now let's split the dataset into train and test subsets using the `test_train_split` function. Try different ratios to see how the dataset splits, and view the first example of each subset.

In [None]:
TEST_TRAIN_RATIO = 0.8
train, test = test_train_split(dataset, ratio=TEST_TRAIN_RATIO)

print(f"Train size: {len(train)}")
train_example = train[0]  # random access the first element
print(f"inputs: {train_example['inputs'].shape}")
print(f"labels: {train_example['labels'].shape}")
print(f"Test size: {len(test)}")
test_example = test[0]  # random access the first element
print(f"inputs: {test_example['inputs'].shape}")
print(f"labels: {test_example['labels'].shape}\n")

train_inputs = train_example['inputs']
train_labels = train_example['labels']
test_inputs = test_example['inputs']
test_labels = test_example['labels']
print('Train[0]:')
show_patch(train_inputs, train_labels)
print('Test[0]:')
show_patch(test_inputs, test_labels)
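To make the effect of the ratio concrete, here is a minimal sketch of a ratio-based split over example indices. It is not the implementation of `test_train_split`, which is not shown in this notebook; it assumes that the ratio is the fraction of examples assigned to training and that examples are shuffled before splitting. The `split_indices` helper and its `seed` argument are hypothetical.

```python
import numpy as np


def split_indices(num_examples: int, ratio: float, seed: int = 0):
    """Shuffle indices 0..num_examples-1 and split them into (train, test)."""
    if not 0.0 < ratio < 1.0:
        raise ValueError("ratio must be strictly between 0 and 1")
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(num_examples)
    cut = int(round(num_examples * ratio))  # number of training examples
    return shuffled[:cut], shuffled[cut:]


train_idx, test_idx = split_indices(10, ratio=0.8)
print(f"train: {len(train_idx)} examples, test: {len(test_idx)} examples")
```

Fixing the seed keeps the split reproducible between runs, which matters when normalization statistics are computed from the training subset.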

# 3. Helper functions

## 3.1 Transformers

In [None]:
def prep_normalize(dataset):
  # Access inputs in dataset
  inputs = dataset['inputs']

  # Reshape to (num_inputs * num_pixels, num_bands)
  reshaped_inputs = inputs.reshape(-1, inputs.shape[-1])

  # Calculate mean and std for each band across all inputs
  means = reshaped_inputs.mean(axis=0)
  stds = reshaped_inputs.std(axis=0)

  return {'mean': means.tolist(), 'std': stds.tolist()}

NORM_TRAIN = prep_normalize(train)
print("Train:", NORM_TRAIN)
NORM_TEST = prep_normalize(test)
print("Test:", NORM_TEST)

In [None]:
from torchvision.transforms import v2

# Apply image augmentation and adjust labels (e.g. after flipping image)

# Specify transforms for training
transform_train = v2.Compose([
  v2.ToTensor(),              # convert first: Normalize expects a tensor
  v2.RandomHorizontalFlip(),  # default value is p=0.5
  v2.RandomVerticalFlip(),    # default value is p=0.5
  v2.Normalize(
      mean=NORM_TRAIN['mean'],
      std=NORM_TRAIN['std'])
  ])

# Specify transforms for testing
transform_test = v2.Compose([
  v2.ToTensor(),
  v2.Normalize(
      mean=NORM_TEST['mean'],
      std=NORM_TEST['std'])
  ])
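Because this is a segmentation task, any geometric augmentation has to be applied identically to the inputs and the labels, otherwise the label pixels no longer line up with the image. The sketch below shows that idea with plain NumPy rather than torchvision; the `random_flip_pair` helper is hypothetical and assumes arrays laid out as (height, width, bands).

```python
import numpy as np


def random_flip_pair(inputs, labels, rng, p=0.5):
    """Flip inputs and labels together so pixels stay aligned."""
    if rng.random() < p:
        # Horizontal flip: reverse the width axis of both arrays.
        inputs, labels = inputs[:, ::-1], labels[:, ::-1]
    if rng.random() < p:
        # Vertical flip: reverse the height axis of both arrays.
        inputs, labels = inputs[::-1], labels[::-1]
    # Copy into contiguous memory so downstream code does not see negative strides.
    return np.ascontiguousarray(inputs), np.ascontiguousarray(labels)


# Toy example where each label pixel is a function of its input pixel,
# so alignment is easy to verify after flipping.
inputs = np.arange(4 * 4 * 3, dtype=np.float32).reshape(4, 4, 3)
labels = (inputs[..., :1] % 2).astype(np.uint8)

flipped_inputs, flipped_labels = random_flip_pair(inputs, labels, np.random.default_rng(0))
print(np.array_equal(flipped_labels, (flipped_inputs[..., :1] % 2).astype(np.uint8)))
```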

## 3.2 Data loaders

In [22]:
from torch.utils.data import DataLoader
from torchvision.transforms import v2

def get_loaders(
    path,
    batch_size,
    ratio = TEST_TRAIN_RATIO,
    num_workers = 2,
    pin_memory = True
):

  # Load dataset
  dataset = DatasetFromPath(path)

  # Split into training and testing dataset
  # TO DO: Include Transformers
  train, test = test_train_split(dataset, ratio=ratio)
  print(f"Train inputs: {train['inputs'].shape}")
  print(f"Train labels: {train['labels'].shape}")
  print(f"Test inputs: {test['inputs'].shape}")
  print(f"Test labels: {test['labels'].shape}")

  # Initialize data loaders
  train_loader = DataLoader(
      train,
      batch_size = batch_size,
      num_workers = num_workers,
      pin_memory = pin_memory,
      shuffle=True)

  test_loader = DataLoader(
      test,
      batch_size = batch_size,
      num_workers = num_workers,
      pin_memory = pin_memory,
      shuffle=True)

  return train_loader, test_loader