In [None]:
%load_ext autoreload
%autoreload 2
%cd ..

In [None]:
import os

from tqdm import tqdm
import pandas as pd
import numpy as np
import torch

from src.dataset_utils import (
    read_satellite_image,
    get_bioclimatic_time_series_cube,
    get_satellite_time_series_landsat_cube,
    read_environmental_values,
    get_environmental_values_tensor
)

from transformers import CLIPProcessor, CLIPModel

train_data = pd.read_csv("data/GLC25_PA_metadata_train.csv")
test_data = pd.read_csv("data/GLC25_PA_metadata_test.csv")

def preprocessing(data):

    environmental_data = read_environmental_values()
    environmental_data = environmental_data.drop(columns=["surveyId"])
    # change NaN and inf to median of the column
    for column in environmental_data.columns:
        median = environmental_data[column].median()
        environmental_data.loc[environmental_data[column].isna(), column] = median
        environmental_data.loc[environmental_data[column] == np.inf, column] = median

    # normalize each column of the environmental_data dataframe
    for column in environmental_data.columns:
        environmental_data[column] = (environmental_data[column] - environmental_data[column].mean()) / environmental_data[column].std()

    table_data_meta_columns = ['lon', 'lat', 'year', 'geoUncertaintyInM', 'areaInM2', 'region', 'country']
    def get_text(row):
        text = ""
        for column in table_data_meta_columns:
            text += f"{column}: {row[column]}\n"
        return text

    data["text"] = data.apply(get_text, axis=1)
    return environmental_data, data

environmental_train_data, train_data = preprocessing(train_data)
environmental_test_data, test_data = preprocessing(test_data)
train_data["speciesId"] = train_data["speciesId"].astype(int)

In [None]:
train_unique_texts = list(train_data.text.unique())
test_unique_texts = list(test_data.text.unique())

device = "mps"

model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=True).to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14", local_files_only=True, use_fast=True)

# Get text embeddings via CLIP

In [None]:
def get_embeddings(unique_texts, split: str, batch_inference_size = 128):

    fp_save_path = f"data/{split}_text_embeddings.pt"

    if not os.path.exists(fp_save_path):
        text_embeddings = []
        for i in tqdm(range(0, len(unique_texts), batch_inference_size)):
            batch_texts = unique_texts[i:i+batch_inference_size]
            inputs = processor(text=batch_texts, return_tensors="pt", padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            text_features = model.get_text_features(**inputs)
            text_embeddings.append(text_features.detach().cpu())

        embeddings_torch = torch.cat(text_embeddings, dim=0)
        torch.save(embeddings_torch, fp_save_path)

    else:
        embeddings_torch = torch.load(fp_save_path)
    return embeddings_torch

train_text_embeddings = get_embeddings(train_unique_texts, "train")
test_text_embeddings = get_embeddings(test_unique_texts, "test")

In [None]:
# For each surveyId get a list of speciesIds
survey_id2species_ids = train_data.groupby('surveyId')['speciesId'].apply(list).to_dict()

species_ids = train_data.speciesId.unique()
species_id2label = {species_id.item(): i for i, species_id in enumerate(species_ids)}
label2species_id = {i: species_id.item() for i, species_id in enumerate(species_ids)}

n_classes = len(species_ids)
n_classes  # 5016

In [None]:
len(train_data.speciesId.unique())

# Get image embeddings via CLIP

In [None]:
# image_embeddings = []
# for survey_id in tqdm(train_data.surveyId.unique()): # TEST MODE
#     image = read_satellite_image(survey_id) 
#     inputs = processor(images=[image[:3, :, :]], return_tensors="pt")
#     inputs = {k: v.to(device) for k, v in inputs.items()}
#     outputs = model.get_image_features(**inputs)
#     image_embeddings.append(outputs.detach().cpu())

In [46]:
# Task: multi-label prediction
# Input:  (4, 64, 64) + (4, 19, 12) + (6, 4, 21) + (65,)
# Output: (n_classes,)
# Loss: BCEwithLogitsLoss + assymetric loss

# CLIP   -> v1 (512, )  -| cross_attention(CLIP, LSTM) -> (512, )  Q=CLIP, K=LSTM, V=LSTM
# LSTM 1 -> v2 (512, )  -| cross attention(CLIP, LSTM) -> (512, )  Q=LSTM, K=CLIP, V=CLIP
# LSTM 2 -> v3 (512, )
# MLP    -> v4 (512, )

# v1 | v2 | v3 | v4 -> (2048, )
# MLP 2048 -> (num_classes,)

def get_X_y(data, environmental_values, text_embeddings, split_folder):
    unique_survey_ids = data.surveyId.unique()

    # survey_id -> geo -> [labels]
    X_all, y_all = [], []

    for idx, survey_id in tqdm(enumerate(unique_survey_ids), total=len(unique_survey_ids)): # TEST MODE

        X, y = [], []
        # IMAGE --> CNN / ViT on the first 3 channels + custom CNN on the last channel (NIR)
        image = read_satellite_image(survey_id, split_folder)                                       # shape: (4, 64, 64)  [torch.int16]

        # TIME SERIES --> LSTM
        # 4 variables
        # 19 years, 12 months
        # 4 times series: 19 * 12 = 228
        bioclimatic_time_series_cube = get_bioclimatic_time_series_cube(survey_id, split_folder)                   # shape: (4, 19, 12)  [torch.float32]

        # TIME SERIES --> LSTM  (but there are some RGB channels involved -- maybe it should be treated as a picture)
        # 6 variables
        # 4 = four times per year (after each season)
        # 21 years
        # 6 time series: 4 * 21 = 84
        satellite_time_series_landsat_cube = get_satellite_time_series_landsat_cube(survey_id, split_folder)       # shape: (6, 4, 21)   [torch.float32]
        # set all nan and inf to 0
        satellite_time_series_landsat_cube[satellite_time_series_landsat_cube != satellite_time_series_landsat_cube] = 0

        # MLP ???
        # 65 variables
        # 1 variable
        # 65 time series: 1 * 65 = 65
        # environmental_values = get_environmental_values_tensor(survey_id, environmental_data)  # shape: (1, 65)      [torch.float64]

        X = {
            "satellite_image_RGB": image[:3, :, :],
            "satellite_image_NIR": torch.tensor(image[3, :, :]),  # shape (1, 64, 64)

            "bioclimatic_time_series_cube": bioclimatic_time_series_cube.reshape(4 * 19, -1),
            "satellite_time_series_landsat_cube": satellite_time_series_landsat_cube.reshape(6 * 4, -1),
            "environmental_values": torch.tensor(environmental_values[idx], dtype=torch.float32),
            "text_embedding": text_embeddings[idx]
        }
        X_all.append(X)

        if split_folder == "PA-train":
            y = survey_id2species_ids[survey_id]
            y_all.append(y)

    return X_all, y_all

X_test, _ = get_X_y(test_data, environmental_test_data.values, test_text_embeddings, "PA-test")
X_train, y_train = get_X_y(train_data, environmental_train_data.values, train_text_embeddings, "PA-train")

100%|██████████| 14784/14784 [00:51<00:00, 284.56it/s]
100%|██████████| 88987/88987 [04:08<00:00, 358.80it/s]


In [47]:
# Normalize bioclimatic_time_series_cube: this is 4 time series each of 19 * 12 = 228 values
def normalize_bioclimatic_time_series(X_all):
    # Stack all cubes for joint normalization
    cubes = [x["bioclimatic_time_series_cube"] for x in X_all]
    cubes_stacked = torch.stack(cubes)  # shape: (num_samples, 4, 228)
    # Compute mean and std over all samples, per-channel (over axis 0)
    mean = cubes_stacked.mean(dim=0, keepdim=True)  # shape: (1, 4, 228)
    std = cubes_stacked.std(dim=0, keepdim=True) + 1e-8  # to avoid div by zero
    # Apply normalization
    for idx, x in enumerate(X_all):
        x["bioclimatic_time_series_cube"] = (x["bioclimatic_time_series_cube"] - mean.squeeze(0)) / std.squeeze(0)
    return X_all

X_train = normalize_bioclimatic_time_series(X_train)
X_test = normalize_bioclimatic_time_series(X_test)

def normalize_satellite_time_series(X_all):
    cubes = [x["satellite_time_series_landsat_cube"] for x in X_all]
    cubes_stacked = torch.stack(cubes)  # shape: (num_samples, 6 * 4, 21)

    mean = cubes_stacked.mean(dim=0, keepdim=True)  # shape: (1, 6 * 4, 21)
    std = cubes_stacked.std(dim=0, keepdim=True) + 1e-8  # to avoid div by zero
    for idx, x in enumerate(X_all):
        x["satellite_time_series_landsat_cube"] = (x["satellite_time_series_landsat_cube"] - mean.squeeze(0)) / std.squeeze(0)
    return X_all

X_train = normalize_satellite_time_series(X_train)
X_test = normalize_satellite_time_series(X_test)

In [48]:
y_train_multi_hot_encoded = []

for y_sample in tqdm(y_train):
    y_sample_class_labels = [
        species_id2label[l] for l in y_sample
    ]
    multi_hot_encoded = [0 for _ in range(n_classes)]
    for l in y_sample_class_labels:
        multi_hot_encoded[l] = 1
    
    y_train_multi_hot_encoded.append(multi_hot_encoded)

100%|██████████| 88987/88987 [00:07<00:00, 11822.56it/s]


In [71]:
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self, in_channels=1, out_features=32):
        super(SimpleCNN, self).__init__()
        # Accept variable input channels for CNN
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=3, stride=2, padding=1),   # (batch, 8, 32, 32)
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),  # (batch, 16, 16, 16)
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # (batch, 32, 8, 8)
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))                           # (batch, 32, 1, 1)
        )
        self.out_features = out_features
        self.projector = nn.Linear(32, out_features)

    def forward(self, x):
        x = x.to(dtype=torch.float32)
        # Accept only (batch, C, H, W), do not squeeze channels unless necessary.
        # If x is (batch, H, W), unsqueeze to get 1 channel
        if x.dim() == 3:
            x = x.unsqueeze(1)
        x = self.cnn(x)
        x = x.view(x.size(0), -1)
        x = self.projector(x)
        return x

class MultiModalModel(nn.Module):
    def __init__(self, num_classes):
        super(MultiModalModel, self).__init__()
        self.num_classes = num_classes

        self.bioclim_lstm_input_size = 12  # Each time series is of length 12
        self.bioclim_lstm_hidden_size = 32  # Reduced hidden size
        self.bioclim_lstm_num_layers = 4

        # Smaller LSTM for the bioclimatic time series
        self.bioclim_lstm = nn.LSTM(
            input_size=self.bioclim_lstm_input_size,
            hidden_size=self.bioclim_lstm_hidden_size,
            num_layers=self.bioclim_lstm_num_layers,
            batch_first=True
        )

        self.satellite_lstm_input_size = 21 
        self.satellite_lstm_hidden_size = 32
        self.satellite_lstm_num_layers = 2

        self.satellite_lstm = nn.LSTM(
            input_size=self.satellite_lstm_input_size,
            hidden_size=self.satellite_lstm_hidden_size,
            num_layers=self.satellite_lstm_num_layers,
            batch_first=True
        )

        self.text_embeddings_projector = nn.Linear(768, 32)
        self.environmental_values_projector = nn.Linear(64, 32)

        # Now we will use CNN output in combined features too
        self.combined_features = nn.Linear(32 + 32 + self.bioclim_lstm_hidden_size + self.satellite_lstm_hidden_size + 32, 64)
        self.classifier = nn.Linear(64, num_classes)

        # Determine the correct number of channels for satellite image input at runtime in forward
        self.satellite_image_cnn = None  # Lazy initialization in forward

    def forward(self, x):
        text_embeddings = torch.stack([
            x_i["text_embedding"] for x_i in x
        ]).to(device, dtype=torch.float32)
        text_embeddings_projected = self.text_embeddings_projector(text_embeddings)

        environmental_values = torch.stack([
            x_i["environmental_values"] for x_i in x
        ]).to(device, dtype=torch.float32)
        environmental_values_projected = self.environmental_values_projector(environmental_values)

        bioclimatic_time_series_cube = torch.stack([
            x_i["bioclimatic_time_series_cube"]
            for x_i in x
        ]).to(device, dtype=torch.float32)  # shape: (batch_size, 4, 228)

        satellite_time_series_landsat_cube = torch.stack([
            x_i["satellite_time_series_landsat_cube"]
            for x_i in x
        ]).to(device, dtype=torch.float32)  # shape: (batch_size, 6 * 4, 21)

        # Pass through LSTM
        lstm_out, (h_n, c_n) = self.satellite_lstm(satellite_time_series_landsat_cube)
        satellite_lstm_features = lstm_out[:, -1, :]  # shape (batch, satellite_lstm_hidden_size)

        # Pass through LSTM
        lstm_out, (h_n, c_n) = self.bioclim_lstm(bioclimatic_time_series_cube)
        bioclim_lstm_features = h_n[-1]  # shape (batch, bioclim_lstm_hidden_size)

        # Stack all satellite images. Accept multiple channels if present.
        satellite_images = torch.stack([
            x_i["satellite_image_NIR"]
            for x_i in x
        ]).to(device, dtype=torch.float32)  # shape: (batch_size, C, 64, 64) or (batch_size, 1, 64, 64)
        # If satellite_images is (batch, 64, 64), add channel dim
        if satellite_images.dim() == 3:
            satellite_images = satellite_images.unsqueeze(1)
        in_channels = satellite_images.size(1)

        # Lazy initialize CNN to handle correct input channels
        if (self.satellite_image_cnn is None) or \
           (hasattr(self.satellite_image_cnn, 'cnn') and getattr(self.satellite_image_cnn.cnn[0], 'in_channels', 1) != in_channels):
            self.satellite_image_cnn = SimpleCNN(in_channels=in_channels, out_features=32).to(device)

        satellite_image_features = self.satellite_image_cnn(satellite_images)  # shape: (batch_size, 32)

        # Concatenate features with CNN features included
        concat = torch.cat([
                text_embeddings_projected, 
                environmental_values_projected,
                satellite_lstm_features, 
                bioclim_lstm_features,
                satellite_image_features
            ], dim=1)
        features = self.combined_features(concat)
        logits = self.classifier(features)

        return logits

In [72]:
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam

criterion = nn.BCEWithLogitsLoss()
# criterion = AsymmetricLoss()
mmm = MultiModalModel(num_classes=len(species_ids)).to(device)
optimizer = Adam(mmm.parameters(), lr=0.001)

In [73]:
X_train_split, y_train_multi_hot_encoded_split, y_train_split = X_train[:88_000], y_train_multi_hot_encoded[:88_000], y_train[:88_000]
X_val_split, y_val_multi_hot_encoded_split, y_val_split = X_train[88_000:], y_train_multi_hot_encoded[88_000:], y_train[88_000:]

In [74]:
# EVAL
def make_predictions(torch_model, subset):

    test_inference_batch_size = 256

    predictions = []
    with torch.no_grad():
        n_test = len(subset)
        for step in range(0, n_test, test_inference_batch_size):
            batch_X = subset[step : step + test_inference_batch_size]
            batch_logits = torch_model(batch_X)
            batch_probs = torch.sigmoid(batch_logits)
            # take top 25 by probability
            # batch_preds = [
            #     torch.topk(probs, 25).indices.tolist()
            #     for probs in batch_probs
            # ]
            batch_preds = [
                torch.where(probs > 0.1)[0].tolist() or torch.topk(probs, 25).indices.tolist()
                for probs in batch_probs
            ]
            predictions.extend(batch_preds)

    for idx, p in enumerate(predictions):
        predictions[idx] = [
            label2species_id[label] for label in p
        ]

    return predictions

def eval_haccard(predictions, targets):
    score = 0
    for pred, ground_truth in zip(predictions, targets):
        score += len(set(pred) & set(ground_truth)) / len(set(pred) | set(ground_truth))
    score /= len(predictions)
    return score

def eval_f1_micro(predictions, targets):
    # predictions[i] is a list of labels
    # targets[i] is a list of labels
    # we need to compute the TP, FP, FN for each pair (pred, target) and then compute the f1 micro score
    TP = []
    FP = []
    FN = []
    for pred, target in zip(predictions, targets):
        TP.append(len(set(pred) & set(target)))
        FP.append(len(set(pred) - set(target)))
        FN.append(len(set(target) - set(pred)))
    return 1 / len(predictions) * sum(TP[i] / (TP[i] + (FP[i] + FN[i]) / 2) for i in range(len(predictions)))

In [75]:
n_epochs = 1
batch_size = 128
eval_each_n_steps = 600
for epoch in range(n_epochs):
    for step in range(len(X_train_split) // batch_size):
        batch_X = X_train_split[step * batch_size:(step + 1) * batch_size]
        batch_y = torch.tensor(y_train_multi_hot_encoded_split[step * batch_size:(step + 1) * batch_size], dtype=torch.float32).to(device)

        logits = mmm(batch_X)
        loss = criterion(logits, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch + 1} | Step {step + 1}/{len(X_train) // batch_size} | Loss: {loss.item()}")
        if step % eval_each_n_steps == 0:
            mmm.eval()
            train_preds = make_predictions(mmm, X_train_split[:2000])
            train_haccard = eval_haccard(train_preds, y_train_split[:2000])
            train_f1_micro = eval_f1_micro(train_preds, y_train_split[:2000])

            val_preds = make_predictions(mmm, X_val_split)
            val_haccard = eval_haccard(val_preds, y_val_split)
            val_f1_micro = eval_f1_micro(val_preds, y_val_split)
            print(f"Train HACCARD: {train_haccard} | Train F1 Micro: {train_f1_micro} | Val HACCARD: {val_haccard} | Val F1 Micro: {val_f1_micro}")
            mmm.train()

Epoch 1 | Step 1/695 | Loss: 0.6948239207267761
Train HACCARD: 0.0031851076555024565 | Train F1 Micro: 0.006342824775877274 | Val HACCARD: 0.0032164550641594716 | Val F1 Micro: 0.0064050267009411535
Epoch 1 | Step 2/695 | Loss: 0.6913779377937317
Epoch 1 | Step 3/695 | Loss: 0.6859808564186096
Epoch 1 | Step 4/695 | Loss: 0.678429365158081
Epoch 1 | Step 5/695 | Loss: 0.6673648953437805
Epoch 1 | Step 6/695 | Loss: 0.652824878692627
Epoch 1 | Step 7/695 | Loss: 0.634341299533844
Epoch 1 | Step 8/695 | Loss: 0.6099648475646973
Epoch 1 | Step 9/695 | Loss: 0.5807234048843384
Epoch 1 | Step 10/695 | Loss: 0.5469526052474976
Epoch 1 | Step 11/695 | Loss: 0.5094784498214722
Epoch 1 | Step 12/695 | Loss: 0.467572420835495
Epoch 1 | Step 13/695 | Loss: 0.42099037766456604
Epoch 1 | Step 14/695 | Loss: 0.37070780992507935
Epoch 1 | Step 15/695 | Loss: 0.32141199707984924
Epoch 1 | Step 16/695 | Loss: 0.2745509743690491
Epoch 1 | Step 17/695 | Loss: 0.22781406342983246
Epoch 1 | Step 18/695 | L

# Inference

In [None]:
test_predictions = make_predictions(mmm, X_test)

lengths = [
    len(xx) for xx in test_predictions
]
print("Median:", np.median(lengths))
print("Mean:", np.mean(lengths))
print("Max:", np.max(lengths))
print("Min:", np.min(lengths))


x_test_predictions_list_str = [
    " ".join(map(str, x))
    for x in test_predictions
]

submission_dataframe = pd.DataFrame({
    "surveyId": test_data.surveyId,
    "predictions": x_test_predictions_list_str
})

submission_dataframe.to_csv("submission.csv", index=False)