In [None]:
%pip install wilds

Collecting wilds
  Downloading wilds-2.0.0-py3-none-any.whl.metadata (22 kB)
Collecting ogb>=1.2.6 (from wilds)
  Downloading ogb-1.3.6-py3-none-any.whl.metadata (6.2 kB)
Collecting outdated>=0.2.0 (from wilds)
  Downloading outdated-0.2.2-py2.py3-none-any.whl.metadata (4.7 kB)
Collecting littleutils (from outdated>=0.2.0->wilds)
  Downloading littleutils-0.2.4-py3-none-any.whl.metadata (679 bytes)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.7.0->wilds)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.7.0->wilds)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.7.0->wilds)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.7.0->wilds)
  Using cached nvidia_cudnn

In [None]:
from wilds.datasets.wilds_dataset import WILDSDataset
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader

In [None]:
import tensorflow as tf

In [None]:
from tensorflow.keras.metrics import RootMeanSquaredError

In [None]:
from functools import partial
import numpy as np

In [None]:
# Fetch the WILDS 'poverty' dataset, downloading it on first use.
dataset = get_dataset(
    dataset='poverty',
    download=True,
)

Downloading dataset to data/poverty_v1.1...
You can also download the dataset manually at https://wilds.stanford.edu/downloads.
Downloading https://worksheets.codalab.org/rest/bundles/0xfc0aa86ad9af4eb08c42dfc40eacf094/contents/blob/ to data/poverty_v1.1/archive.tar.gz


13091954688Byte [12:06, 18030250.55Byte/s]                               


Extracting data/poverty_v1.1/archive.tar.gz to data/poverty_v1.1

It took 14.56 minutes to download and uncompress the dataset.



In [None]:
# Restrict training to the official 'train' split. Passing the full dataset
# object to the loader would also iterate over the validation and test
# examples, leaking held-out data into training.
train_data = dataset.get_subset('train')

# Standard (shuffled) loader over the training split only.
train_loader = get_train_loader('standard', train_data, batch_size=32)

In [None]:
# Sanity check: look at the first batch only and show the tensor shapes
# of the inputs and the targets, then stop iterating.
for batch in train_loader:
  x, y, metadata = batch
  print(x.shape, y.shape, sep="\n")
  break

torch.Size([32, 8, 224, 224])
torch.Size([32, 1])


In [None]:
# Conv2D factory with the defaults shared by every convolution below:
# 3x3 kernel, "same" padding, ReLU activation, He-normal initialisation.
DefaultConv2D = partial(
    tf.keras.layers.Conv2D,
    kernel_size=3,
    padding="same",
    activation="relu",
    kernel_initializer="he_normal",
)

# The network is assembled layer by layer, in the same order as before.
model = tf.keras.Sequential()

# Stem: one wide 7x7 convolution over the 8-channel, channels-last input.
model.add(DefaultConv2D(filters=64, kernel_size=7, input_shape=(224, 224, 8)))
model.add(tf.keras.layers.MaxPool2D())

# Two convolutional stages (128 then 256 filters), each followed by pooling.
model.add(DefaultConv2D(filters=128))
model.add(DefaultConv2D(filters=128))
model.add(tf.keras.layers.MaxPool2D())

model.add(DefaultConv2D(filters=256))
model.add(DefaultConv2D(filters=256))
model.add(tf.keras.layers.MaxPool2D())

model.add(tf.keras.layers.Flatten())

# First hidden block: bias-free Dense -> BatchNorm -> ReLU -> Dropout.
# The Dense bias is disabled because BatchNormalization follows immediately.
model.add(tf.keras.layers.Dense(units=100, kernel_initializer="he_normal", use_bias=False))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Activation("relu"))
model.add(tf.keras.layers.Dropout(0.5))

# Second hidden block, same layout as the first.
model.add(tf.keras.layers.Dense(units=100, kernel_initializer="he_normal", use_bias=False))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Activation("relu"))
model.add(tf.keras.layers.Dropout(0.5))

# Single linear output unit for the scalar regression target.
model.add(tf.keras.layers.Dense(units=1))

In [None]:
def print_status_bar(step, total, loss, metrics=None):
    """Print a one-line, in-place progress report for the current step.

    Args:
        step: Number of the step just completed.
        total: Total number of steps; once ``step`` reaches it the line is
            terminated with a newline instead of being left open.
        loss: Object exposing ``name`` and ``result()``; reported first.
        metrics: Optional list of further objects with the same interface.

    The line starts with a carriage return so successive calls overwrite
    each other, and every value is formatted with four decimals.
    """
    tracked = [loss] + (metrics or [])

    parts = []
    for tracker in tracked:
        parts.append(f"{tracker.name}: {tracker.result():.4f}")
    summary = " - ".join(parts)

    if step < total:
        terminator = ""
    else:
        terminator = "\n"

    print(f"\r{step}/{total} - " + summary, end=terminator)

In [None]:
# Schedule: number of passes over the loader and steps (batches) per pass.
num_epochs = 10
num_steps = len(train_loader)
# NOTE(review): batch_size is not read by any cell visible here; the loader
# cell sets its own batch size — confirm whether the two should be linked.
batch_size = 32

# Objective and the optimizer that applies the gradient updates.
loss_function = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Nadam()

# Running aggregates displayed by the status bar during training.
mean_loss = tf.keras.metrics.Mean()
metrics = [RootMeanSquaredError()]

In [None]:
# Start training from clean accumulators. The training loop only resets them
# at the end of a completed epoch, so an interrupted run would otherwise carry
# stale loss/metric state into the next run.
mean_loss.reset_state()
for metric in metrics:
    metric.reset_state()

In [None]:
# Custom training loop: for every epoch, run one optimisation step per batch,
# keep running aggregates of the loss and metrics, and show them in place.
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    for step, batch in enumerate(train_loader):
        x, y, metadata = batch
        # The loader yields channels-first torch tensors (batch, 8, 224, 224),
        # while the model was declared with channels-last input (224, 224, 8).
        x = np.moveaxis(x.numpy(), 1, -1)
        y = y.numpy()

        with tf.GradientTape() as tape:
            # training=True is required in a hand-written loop: without it
            # Dropout is a no-op and BatchNormalization runs in inference
            # mode, never updating its moving statistics.
            predictions = model(x, training=True)
            loss = loss_function(y, predictions)

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Accumulate the running loss and metrics for the status bar.
        mean_loss.update_state(loss)
        for metric in metrics:
            metric.update_state(y, predictions)

        print_status_bar(step + 1, num_steps, mean_loss, metrics)

    # Clear the accumulators so each epoch reports its own averages.
    for metric in metrics:
        metric.reset_state()

    mean_loss.reset_state()

Epoch 1/10
615/615 - mean: 487.3397 - root_mean_squared_error: 22.0819
Epoch 2/10
615/615 - mean: 0.2849 - root_mean_squared_error: 0.5338
Epoch 3/10
615/615 - mean: 0.2417 - root_mean_squared_error: 0.4916
Epoch 4/10
615/615 - mean: 0.2086 - root_mean_squared_error: 0.4567
Epoch 5/10
615/615 - mean: 0.1878 - root_mean_squared_error: 0.4334
Epoch 6/10
615/615 - mean: 0.1804 - root_mean_squared_error: 0.4246
Epoch 7/10
615/615 - mean: 0.1723 - root_mean_squared_error: 0.4151
Epoch 8/10
615/615 - mean: 0.1392 - root_mean_squared_error: 0.3732
Epoch 9/10
615/615 - mean: 0.1293 - root_mean_squared_error: 0.3596
Epoch 10/10
615/615 - mean: 0.1167 - root_mean_squared_error: 0.3416


In [None]:
# Export the trained network to the 'satellite_imagery_model_tf' directory
# (path is relative to the current working directory).
tf.saved_model.save(
    obj=model,
    export_dir='satellite_imagery_model_tf',
)

In [None]:
# Bundle the exported model directory into a single zip archive.
# NOTE(review): /content is an absolute, environment-specific path (presumably
# the Colab working directory — confirm). zip is given a relative directory
# name, so it only finds the model if it was saved under /content.
%cd /content/
!zip -r satellite_imagery_model_tf.zip satellite_imagery_model_tf

/content
  adding: satellite_imagery_model_tf/ (stored 0%)
  adding: satellite_imagery_model_tf/saved_model.pb (deflated 88%)
  adding: satellite_imagery_model_tf/variables/ (stored 0%)
  adding: satellite_imagery_model_tf/variables/variables.index (deflated 62%)
  adding: satellite_imagery_model_tf/variables/variables.data-00000-of-00001 (deflated 7%)
  adding: satellite_imagery_model_tf/assets/ (stored 0%)
  adding: satellite_imagery_model_tf/fingerprint.pb (stored 0%)
