# Autoencoders for Climate Embeddings to Find Future Climate Analogs
In the second part of this week we will be looking at the outputs of an autoencoder model for climate zones around the world.

![climatezones](https://climate-box.com/wp-content/uploads/2021/01/Earths-Climate.png)

## Autoencoders
Autoencoders are a deep learning method for compressing data into a lower-dimensional representation from which the original can be reconstructed. An autoencoder is an artificial neural network with 2 parts, an encoder and a decoder. The encoder compresses high-dimensional data into a lower-dimensional form; this encoding learns and describes the latent attributes of the input data (the embeddings). The decoder is the encoder in reverse: it reconstructs the original data from the latent attributes. Training minimizes a reconstruction loss, which measures the error between the input and its reconstruction.
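
To make the encoder/decoder idea concrete, here is a minimal sketch of a *linear* autoencoder written in plain NumPy. It is not the convolutional model used in this notebook: the toy data, the layer sizes (6 → 2 → 6), the learning rate, and all variable names are assumptions chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data (assumption): 200 samples in 6 dimensions that lie on a 2-D subspace
latent_true = rng.normal(size=(200, 2))
mixing = rng.normal(size=(2, 6))
X = latent_true @ mixing

# Linear encoder (6 -> 2) and decoder (2 -> 6), initialised with small weights
W_enc = rng.normal(scale=0.1, size=(6, 2))
W_dec = rng.normal(scale=0.1, size=(2, 6))

def reconstruction_loss(X, W_enc, W_dec):
    Z = X @ W_enc        # embeddings (latent representation)
    X_hat = Z @ W_dec    # reconstruction from the embeddings
    return np.mean((X_hat - X) ** 2)

initial_loss = reconstruction_loss(X, W_enc, W_dec)

learning_rate = 0.01
for _ in range(2000):
    Z = X @ W_enc
    err = Z @ W_dec - X                          # reconstruction error
    # Gradients of the squared error, up to a constant factor
    grad_dec = Z.T @ err / len(X)
    grad_enc = X.T @ (err @ W_dec.T) / len(X)
    W_dec -= learning_rate * grad_dec
    W_enc -= learning_rate * grad_enc

final_loss = reconstruction_loss(X, W_enc, W_dec)
embeddings = X @ W_enc                           # one 2-D embedding per sample
print(initial_loss, final_loss, embeddings.shape)
```

The same pattern (encode, decode, compare with the input, update the weights) carries over to the convolutional autoencoder defined later, where PyTorch computes the gradients automatically.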

![autoencoder](https://www.compthree.com/images/blog/ae/ae.png)

## Climate Analogs
For this analysis, we are interested in current and future climate analogs: which locations around the globe have the most similar climates today, and which may under different future scenarios. The work was inspired by [this 2024 paper](https://journals.ametsoc.org/view/journals/aies/3/3/AIES-D-23-0035.1.xml), though we have expanded the analysis to cover the entire globe. For our historical data, we used [WorldClim v2.1](https://www.worldclim.org/data/worldclim21.html) for the years 1970-2000.

We use just 3 variables to describe the climate of each region: mean annual temperature, isothermality index, and mean annual precipitation. These were chosen from the [19 Bioclimatic Variables](https://luizfesser.wordpress.com/2021/03/08/the-19-bioclimatic-variables/). The annual mean temperature is useful for understanding the approximate energy inputs into a system. The isothermality index is the mean diurnal range divided by the annual temperature range (multiplied by 100 in WorldClim). It quantifies how large day-to-night temperature swings (the mean over all months of monthly max temp - monthly min temp) are relative to the summer-to-winter swing (max temp of the warmest month - min temp of the coldest month). If you want to nerd out about bioclimatic variables, check out [this resource from USGS](https://pubs.usgs.gov/ds/691/ds691.pdf).

The three variables are mapped to the three visible light color channels: red for mean temperature, green for isothermality, and blue for precipitation. The 3 layers of the global climate map were overlaid, and the projected globe (8640px x 4320px) was split into 30,000+ image chips of dimension 32x32 with 3 channels, one for each climate variable. Each chip is saved as a PNG.
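
The chipping step described above can be sketched with a NumPy reshape. This is only an illustration of the idea, not the script that produced the dataset: the function name `chip_image` and the small stand-in array are assumptions.

```python
import numpy as np

def chip_image(image, chip_size=32):
    """Split an (H, W, C) array into non-overlapping (chip_size, chip_size, C) chips."""
    height, width, channels = image.shape
    if height % chip_size or width % chip_size:
        raise ValueError("image dimensions must be multiples of chip_size")
    rows, cols = height // chip_size, width // chip_size
    chips = image.reshape(rows, chip_size, cols, chip_size, channels)
    # Reorder to (rows, cols, chip_size, chip_size, C) so chips[r, c] is one chip
    return chips.swapaxes(1, 2)

# Small stand-in (assumption) for the 3-channel global raster
image = np.arange(64 * 128 * 3).reshape(64, 128, 3)
chips = chip_image(image)
print(chips.shape)  # (2, 4, 32, 32, 3)

# Number of 32x32 chips in the full 8640 x 4320 px map
n_chips = (4320 // 32) * (8640 // 32)
print(n_chips)  # 36450
```

Here `chips[r, c]` is the chip in row `r` and column `c` of the grid, which is the same row/column indexing idea used in the chip file names.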

## Pretraining
To train the autoencoder, these images were split into training, validation, and test datasets and fed through the autoencoder. Since the data is not "labeled", the task for the autoencoder is to featurize the input images into a latent embedding space of dimension d with the encoder, and then reconstruct the original image with the decoder. The loss function compares the reconstructed image with the original image. After pretraining the model for 100+ epochs, we saved the autoencoder weights. This was done ahead of time.
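
As a rough illustration of that comparison, the sketch below mirrors the reduction used in the `Autoencoder` class later in this notebook: squared error summed over channels and pixels for each image, then averaged over the batch. It uses NumPy rather than PyTorch, and the toy arrays are made up for the example.

```python
import numpy as np

def reconstruction_loss(x, x_hat):
    """Per-image summed squared error, averaged over the batch."""
    squared_error = (x - x_hat) ** 2
    per_image = squared_error.sum(axis=(1, 2, 3))  # sum over channels, height, width
    return per_image.mean()                        # average over the batch

# Toy batch (assumption): 2 "images" of shape 3 x 32 x 32, each pixel off by 0.1
x = np.zeros((2, 3, 32, 32))
x_hat = np.full((2, 3, 32, 32), 0.1)
loss = reconstruction_loss(x, x_hat)
print(loss)  # approximately 30.72 (= 3 * 32 * 32 * 0.1**2)
```

Because the error is summed rather than averaged over pixels, the loss value scales with image size, which is worth keeping in mind when comparing losses across different chip sizes.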

## In This Notebook
In the notebook, we will do a few things:
1. Import the dataset and define a custom dataloader
2. Reconstruct the encoder, the decoder, and put them together into an autoencoder
3. Run the image patches through the encoder to get the embeddings
4. Cluster the embeddings to find similar areas
5. Reduce the dimensionality of the embeddings to visualize them in 2 dimensions

Then, we introduce image patches from a future climate scenario from [CMIP6](https://www.carbonbrief.org/cmip6-the-next-generation-of-climate-models-explained/), the climate model ensemble used in the IPCC's Sixth Assessment Report. Future climate projections must account for uncertainty in atmospheric conditions as well as uncertainty in future emissions; the latter is described by scenarios called [Shared Socioeconomic Pathways (SSPs)](https://en.wikipedia.org/wiki/Shared_Socioeconomic_Pathways). Here, we use a future snapshot from 2080-2100 under SSP3-7.0, a medium-to-high emissions scenario. We run these image patches through the encoder as well to extract their embeddings and compare them with the historical data.
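
One simple way to compare future and historical embeddings is a nearest-neighbor lookup: for each future embedding, find the historical embedding closest to it in Euclidean distance. The sketch below is one possible approach under that assumption, not necessarily how the analogs in the referenced paper were computed; the function name and the tiny 2-D arrays are hypothetical.

```python
import numpy as np

def nearest_historical(future_z, hist_z):
    """For each future embedding, return the index of and distance to the closest historical embedding."""
    diffs = future_z[:, None, :] - hist_z[None, :, :]  # (n_future, n_hist, dim)
    dists = np.linalg.norm(diffs, axis=-1)             # (n_future, n_hist)
    idx = dists.argmin(axis=1)
    return idx, dists[np.arange(len(future_z)), idx]

# Toy 2-D embeddings (assumption); the real ones have 384 dimensions
hist_z = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
future_z = np.array([[0.9, 1.2], [4.0, 4.5]])
idx, dist = nearest_historical(future_z, hist_z)
print(idx)  # [1 2]
```

The broadcasted difference array grows with `n_future * n_hist * dim`, so for the full set of chips a pairwise-distance routine such as `torch.cdist`, which is used later in this notebook, is likely a better fit.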

## Learning Objectives
1. Understand how an autoencoder works
2. See how to create a custom dataset in PyTorch
3. Understand how images can be represented as embedding vectors
4. Understand how to reduce the dimensionality of embeddings
5. Visualize and compare embeddings


## Setup Environment

First, let's switch to using the GPU for this notebook. I used the P100, but you can decide for yourself. Next, ensure that you have the [corresponding dataset](https://www.kaggle.com/datasets/isaiahlg/climate-autoencoder) pulled into this notebook. 

In [None]:
# Reading Data
from PIL import Image
import glob

## Data Manip
import os
import numpy as np

## Plotting
import matplotlib
import matplotlib_inline.backend_inline
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb
import seaborn as sns
from tqdm.notebook import tqdm # progress bar

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torchvision
from torchvision import transforms
import pytorch_lightning as pl

print("imported")

In [None]:
# Configure plotting
sns.reset_orig()
sns.set()
matplotlib.rcParams['lines.linewidth'] = 2.0
matplotlib_inline.backend_inline.set_matplotlib_formats('svg', 'pdf') # for exporting

# Set the random seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Set which device to use
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

## Load Data

Add the dataset to your inputs.

In [None]:
# see what's in the input directory
print(os.listdir("/kaggle/input/climate-autoencoder/"))

In [None]:
# point to the data
data_path = '/kaggle/input/climate-autoencoder/bc_chips25/bc_chips25'
image_chip_paths = glob.glob(data_path + '/*.png') # all files that match this string

# visualize an image chip
index = 12001
image = Image.open(image_chip_paths[index])
display(image)

Go through a few different images above. What do you think the black chips mean?

In [None]:
# define a pytorch dataset class for our image chips
# these are the three required methods for a pytorch dataset
class MyDataset(Dataset):
    def __init__(self, image_chip_paths, transform):
        self.image_paths = image_chip_paths
        self.transform = transform

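    def __getitem__(self, index):
        image = Image.open(self.image_paths[index]) # get an image
        to_tensor = transforms.ToTensor()  # convert to tensor with values in [0, 1]
        # note: self.transform is stored but not applied here, so chips are not normalized
        # input and target are the same image, since the autoencoder reconstructs its input
        x, y = to_tensor(image), to_tensor(image)
        return x, y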

    def __len__(self):
        return len(self.image_paths)

# define a transform that normalizes the images
transform = transforms.Compose(
    [transforms.Normalize((0.5,),(0.5,))]
)

# instantiate the dataset
dataset = MyDataset(image_chip_paths, transform)

# split train, val, test
train_dataset, test_dataset, val_dataset = torch.utils.data.random_split(dataset, [0.7, 0.15, 0.15])

# load dataset splits into iterable dataloaders
batch_size = 256 # number of image chips loaded at once
num_workers = 4 # number of worker processes used to load data
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers)

print('data loaded')

## Define Autoencoder Model

In [None]:
# set training hyperparameter
lr = 1e-3
print(lr)

# define encoder
class Encoder(nn.Module):

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):
        """
        Inputs:
            - num_input_channels : Number of input channels of the image. For us, this parameter is 3
            - base_channel_size : Number of channels we use in the first convolutional layers. Deeper layers use a multiple of it.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the encoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            nn.Conv2d(num_input_channels, c_hid, kernel_size=3, padding=1, stride=2), # 32x32 => 16x16
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 16x16 => 8x8
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 8x8 => 4x4
            act_fn(),
            nn.Flatten(), # Image grid to single feature vector
            nn.Linear(2*16*c_hid, latent_dim)
        )

    def forward(self, x):
        return self.net(x)

print("encoder defined")

In [None]:
class Decoder(nn.Module):

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):
        """
        Inputs:
            - num_input_channels : Number of channels of the image to reconstruct. For us, this parameter is 3
            - base_channel_size : Number of channels we use in the last convolutional layers. Earlier layers use a multiple of it.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the decoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.linear = nn.Sequential(
            nn.Linear(latent_dim, 2*16*c_hid),
            act_fn()
        )
        self.net = nn.Sequential(
            nn.ConvTranspose2d(2*c_hid, 2*c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # 4x4 => 8x8
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(2*c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # 8x8 => 16x16
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(c_hid, num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=2), # 16x16 => 32x32
            nn.Tanh() # Bounds the output to [-1, 1], matching images scaled to that range
        )

    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(x.shape[0], -1, 4, 4)
        x = self.net(x)
        return x

print("decoder defined")

In [None]:
class Autoencoder(pl.LightningModule):

    def __init__(
            self,
            base_channel_size: int = 32, # add default param
            latent_dim: int = 384, # add default param
            encoder_class : object = Encoder,
            decoder_class : object = Decoder,
            num_input_channels: int = 3,
            width: int = 32,
            height: int = 32
        ):
        super().__init__()
        # Saving hyperparameters of autoencoder
        self.save_hyperparameters()
        # Creating encoder and decoder
        self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)
        self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)
        # Example input array needed for visualizing the graph of the network
        self.example_input_array = torch.zeros(2, num_input_channels, width, height)
 

    def forward(self, x):
        """
        The forward function takes in an image and returns the reconstructed image
        """
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

    def _get_reconstruction_loss(self, batch):
        """
        Given a batch of images, this function returns the reconstruction loss (MSE in our case)
        """
        x, _ = batch # We do not need the labels
        x_hat = self.forward(x)
        loss = F.mse_loss(x, x_hat, reduction="none")
        loss = loss.sum(dim=[1,2,3]).mean(dim=[0])
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=lr)
        # Using a scheduler is optional but can be helpful.
        # The scheduler reduces the LR if the validation performance hasn't improved for the last N epochs
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.2,
            patience=20,
            min_lr=5e-5
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def training_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log('test_loss', loss)

print("autoencoder defined")

## Load a Pretrained Autoencoder & Extract Embeddings

In [None]:
model_path = "/kaggle/input/climate-autoencoder/model_dict.pt"
# Load the model architecture
model = Autoencoder(base_channel_size=32, latent_dim=384)
# Load the model’s state dictionary
state_dict = torch.load(model_path, weights_only=False, map_location=device) # map_location lets this also load on a CPU-only session
# Extract the actual state_dict from your saved dictionary
state_dict = state_dict[384]['model'].state_dict()
# Load the state dictionary into the model instance
model.load_state_dict(state_dict)

In [None]:
# Extract the embeddings from the model
def embed_imgs(model, data_loader):
    # Encode all images in the data_loader using model, and return both images and encodings
    img_list, embed_list = [], []
    model.eval()
    for imgs, _ in tqdm(data_loader, desc="Encoding images", leave=False):
        with torch.no_grad():
            z = model.encoder(imgs.to(model.device))
        img_list.append(imgs)
        embed_list.append(z)
    return (torch.cat(img_list, dim=0), torch.cat(embed_list, dim=0))
print("function defined")

In [None]:
# configure model
model.to(device)
model.eval()

# extract embeddings
train_img_embeds = embed_imgs(model, train_loader)
test_img_embeds = embed_imgs(model, test_loader)
val_img_embeds = embed_imgs(model, val_loader)

print('done')

## Examine the Embeddings
Let's look at the images whose embeddings are closest to a few query images.

In [None]:
def find_similar_images(query_img, query_z, key_embeds, K=8):
    # Compute Euclidean distances from the query embedding to every key embedding
    dist = torch.cdist(query_z[None, :], key_embeds[1], p=2)
    dist = dist.squeeze(dim=0)
    dist, indices = torch.sort(dist)

    # Ensure indices are on the same device as key_embeds[0]
    indices = indices.to(key_embeds[0].device)

    # Select closest images
    imgs_to_display = torch.cat([query_img[None], key_embeds[0][indices[:K]]], dim=0)

    # Create and display grid
    grid = torchvision.utils.make_grid(imgs_to_display, nrow=K+1, normalize=True)
    grid = grid.permute(1, 2, 0)
    plt.figure(figsize=(12, 3))
    plt.imshow(grid)
    plt.axis('off')
    plt.show()

print('defined')

In [None]:
# Plot the closest images for the first N test images as example
for i in range(10):
    find_similar_images(test_img_embeds[0][i], test_img_embeds[1][i], key_embeds=test_img_embeds)

Nice! For land-based chips, we can see that the closest matches are very similar to the query.

## Visualizing Embeddings through Dimensionality Reduction

Our embeddings live in a high-dimensional space, so they are difficult to visualize directly in 2D or 3D. There are three common ways to reduce their dimensionality so that they can be visualized, each with pros and cons: t-SNE, UMAP, and PCA. [Here's an article](https://aurigait.com/blog/blog-easy-explanation-of-dimensionality-reduction-and-techniques/) that gives an overview of all three. Now let's use t-SNE to visualize the test image embeddings.
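
For contrast with t-SNE, PCA is simple enough to write out by hand: center the data, take a singular value decomposition, and project onto the leading directions. The sketch below is a NumPy illustration on random stand-in data, not on the climate embeddings; `pca_project` is a hypothetical helper, and in practice `sklearn.decomposition.PCA` does the same job.

```python
import numpy as np

def pca_project(embeddings, n_components=2):
    """Project embeddings onto their first principal components using an SVD."""
    centered = embeddings - embeddings.mean(axis=0)
    # Rows of vt are the principal directions, ordered by decreasing singular value
    _, singular_values, vt = np.linalg.svd(centered, full_matrices=False)
    projected = centered @ vt[:n_components].T
    # Fraction of total variance captured by each component
    explained = singular_values ** 2 / np.sum(singular_values ** 2)
    return projected, explained[:n_components]

rng = np.random.default_rng(0)
fake_embeddings = rng.normal(size=(500, 384))  # random stand-in for real embeddings
proj, explained = pca_project(fake_embeddings)
print(proj.shape)  # (500, 2)
```

Unlike t-SNE, PCA is a linear projection, so distances along the plotted axes have a direct interpretation, but nonlinear structure in the embedding space may be flattened out.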

In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
print('imported')

In [None]:
tsne = TSNE(n_components=2, verbose=1)

# Move tensor to CPU and convert to NumPy
tsne_proj = tsne.fit_transform(test_img_embeds[1].cpu().numpy())

In [None]:
# plot the embeddings
plt.figure(figsize=(10, 6))
plt.scatter(tsne_proj[:, 0], tsne_proj[:, 1], c=range(len(tsne_proj)), cmap='viridis')
plt.xlabel('t-SNE 1')
plt.ylabel('t-SNE 2')
plt.colorbar(label="Index")
plt.show()

What patterns do you see in the embeddings above? Are nearby areas more similar to each other?

## Compare with Future Climate Data
Now let's add the future climate data (images) to see how they compare. These are from the GFDL model from CMIP6 under SSP3-7.0.

In [None]:
paths2 = '/kaggle/input/climate-autoencoder/gfdl_chips/gfdl_chips'
image_paths2 = glob.glob(paths2 + '/*.png')

dataset2 = MyDataset(image_paths2, transform)

train_dataset2, test_dataset2, valid_dataset2 = torch.utils.data.random_split(dataset2, [0.7, 0.15, 0.15])

batch_size=256
train_loader2 = DataLoader(train_dataset2, batch_size=batch_size, num_workers=2, shuffle=True, drop_last=True)
test_loader2 = DataLoader(test_dataset2, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=2)
val_loader2 = DataLoader(valid_dataset2, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=2)


print('data loaded')

In [None]:
def get_train_images2(num):
    return torch.stack([train_dataset2[i][0] for i in range(num)], dim=0)

def embed_imgs(model, data_loader):
    # Encode all images in the data_loader using model, and return both images and encodings
    img_list, embed_list = [], []
    model.eval()
    for imgs, _ in tqdm(data_loader, desc="Encoding images", leave=False):
        with torch.no_grad():
            z = model.encoder(imgs.to(model.device))
        img_list.append(imgs)
        embed_list.append(z)
    return (torch.cat(img_list, dim=0), torch.cat(embed_list, dim=0))

print('functions defined')

In [None]:
# get embeddings from new images
train_img_embeds2 = embed_imgs(model, train_loader2)
test_img_embeds2 = embed_imgs(model, test_loader2)
val_img_embeds2 = embed_imgs(model,val_loader2)

print('embeddings extracted')

In [None]:
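# reduce dimensions with one t-SNE fit on historical + future embeddings together
# (coordinates from separate t-SNE fits are not comparable on a shared plot)
hist_z, future_z = test_img_embeds[1].cpu().numpy(), test_img_embeds2[1].cpu().numpy()
tsne = TSNE(n_components=2, verbose=1)
tsne_joint = tsne.fit_transform(np.concatenate([hist_z, future_z], axis=0))
tsne_proj, tsne_proj3 = tsne_joint[:len(hist_z)], tsne_joint[len(hist_z):]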

In [None]:
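# plot embeddings: historical in blue, future in red
plt.figure(figsize=(10, 6))
plt.scatter(tsne_proj[:, 0], tsne_proj[:, 1], color="blue", s=5, alpha=0.5, label="historical (1970-2000)")
plt.scatter(tsne_proj3[:, 0], tsne_proj3[:, 1], color="red", s=5, alpha=0.5, label="future (2080-2100)")
plt.xlabel('t-SNE 1')
plt.ylabel('t-SNE 2')
plt.legend()
plt.show()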

## Acknowledgement
Thank you very much to [Ginni Braich](https://www.linkedin.com/in/ginni-braich-41672337/?originalSubdomain=ca) of the [Better Planet Lab](https://betterplanetlab.com/about) for conceptualizing and putting together much of this notebook.

## Assignment
1. Knowing that the image patches are named `Image_Row#_Col#`, reconstruct a complete image of Earth from the image patches. Alternatively, pull the historical climate data directly from WorldClim and visualize it. Does it match your intuition of climate across Earth? What does it mean for a region to appear red, green, or blue? What about white?
2. What can we learn by comparing the results of future climate scenarios to the historical ones? Can you name one region of the world that currently has the climate conditions that somewhere else will have in the future?
3. Why is it useful to convert images into embeddings? How could you find climate analogs without doing any deep learning? Would it be possible?

## Bonus Challenges
1. Use another dimensionality reduction tool (UMAP or PCA) to visualize the embeddings and compare it with t-SNE. How is it different? Which do you think is better and why?
2. What if we re-ran the pre-training, but reduced the dimensionality of the embedding space from 384 down to 64 or 16 dimensions? How do you think this would affect the accuracy of the decoder? How would it change the utility of the embeddings?

## Deliverables
Write out your answers to the above questions. Make your notebook public with Save Version > Share > Public > Copy Link and share it in Discord.