In [None]:
import os
from tqdm.notebook import tqdm
from collections import Counter

import json
import numpy as np
import pandas as pd

from PIL import Image
import seaborn as sns
from matplotlib import pyplot as plt

from sklearn.cluster import KMeans
from sklearn.metrics import precision_recall_fscore_support

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision.models as models
import torchvision.transforms as transforms

import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModel
from bitsandbytes.optim import Lion

In [None]:
# Root directory used to build every PA train/test path below.
# NOTE: the value ends with "/" and the f-strings add another "/", so the
# resulting paths contain a double slash.
DATA_PATH = "../kaggle_data/"

# Load Training metadata
# Directories of the per-survey Landsat / bioclimatic cubes and Sentinel RGB patches (PA train).
train_landsat_data_path = f"{DATA_PATH}/TimeSeries-Cubes/TimeSeries-Cubes/GLC24-PA-train-landsat_time_series/"
train_bioclim_data_path = f"{DATA_PATH}/TimeSeries-Cubes/TimeSeries-Cubes/GLC24-PA-train-bioclimatic_monthly/"
train_sentinel_data_path=f"{DATA_PATH}/PA_Train_SatellitePatches_RGB/pa_train_patches_rgb/"
train_metadata_path = f"{DATA_PATH}/GLC24_PA_metadata_train.csv"
train_metadata = pd.read_csv(train_metadata_path)
# speciesId is forced to int so it can be matched against the species vocabulary later.
train_metadata['speciesId'] = train_metadata['speciesId'].astype(int)

# Load Train PO metadata
# The PO cubes and patches are read from "../all_data/", not from DATA_PATH.
train_po_landsat_data_path = f"../all_data/SatelliteTimeSeries/cubes/GLC24-PO-train-landsat_time_series/"
train_po_bioclim_data_path = f"../all_data/EnvironmentalRasters/Climate/Climatic_Monthly_2000-2019_cubes/GLC24_timeseries_csvs/cubes/GLC24-PO-train-bioclimatic_monthly/"
train_po_sentinel_data_path = f"../all_data/SatellitePatches/po_train_patches_rgb"
# NOTE(review): this file name is spelled "P0" (digit zero) while the other PO
# assets use "PO" -- confirm it matches the file on disk.
train_po_metadata_path = f"{DATA_PATH}/GLC24_P0_metadata_train.csv"
train_po_metadata = pd.read_csv(train_po_metadata_path)
train_po_metadata['speciesId'] = train_po_metadata['speciesId'].astype(int)

# Load Test metadata
# The test metadata is loaded without touching a speciesId column.
test_landsat_data_path = f"{DATA_PATH}/TimeSeries-Cubes/TimeSeries-Cubes/GLC24-PA-test-landsat_time_series/"
test_bioclim_data_path = f"{DATA_PATH}/TimeSeries-Cubes/TimeSeries-Cubes/GLC24-PA-test-bioclimatic_monthly/"
test_sentinel_data_path = f"{DATA_PATH}/PA_Test_SatellitePatches_RGB/pa_test_patches_rgb/"
test_metadata_path = f"{DATA_PATH}/GLC24_PA_metadata_test.csv"
test_metadata = pd.read_csv(test_metadata_path)

In [None]:
# Tabular environmental rasters, one CSV per source (elevation, human footprint,
# land cover, soil grids) and per split. Later cells drop a 'surveyId' column
# from all but the elevation table before concatenating them.

# PA train rasters (under DATA_PATH).
train_pa_elevation = pd.read_csv(f'{DATA_PATH}/EnvironmentalRasters/EnvironmentalRasters/Elevation/GLC24-PA-train-elevation.csv')
train_pa_footprint = pd.read_csv(f'{DATA_PATH}/EnvironmentalRasters/EnvironmentalRasters/Human Footprint/GLC24-PA-train-human_footprint.csv')
train_pa_landcover = pd.read_csv(f'{DATA_PATH}/EnvironmentalRasters/EnvironmentalRasters/LandCover/GLC24-PA-train-landcover.csv')
train_pa_soilgrid = pd.read_csv(f'{DATA_PATH}/EnvironmentalRasters/EnvironmentalRasters/SoilGrids/GLC24-PA-train-soilgrids.csv')

# PO train rasters (under "../all_data/"; note the different folder and file spellings).
train_po_elevation = pd.read_csv('../all_data/EnvironmentalRasters/Elevation/GLC24-PO-train-elevation.csv')
train_po_footprint = pd.read_csv('../all_data/EnvironmentalRasters/HumanFootprint/GLC24-PO-train-human-footprint.csv')
train_po_landcover = pd.read_csv('../all_data/EnvironmentalRasters/LandCover/GLC24-PO-train-landcover.csv')
train_po_soilgrid = pd.read_csv('../all_data/EnvironmentalRasters/Soilgrids/GLC24-PO-train-soilgrids.csv')

# PA test rasters (under DATA_PATH).
test_pa_elevation = pd.read_csv(f'{DATA_PATH}/EnvironmentalRasters/EnvironmentalRasters/Elevation/GLC24-PA-test-elevation.csv')
test_pa_footprint = pd.read_csv(f'{DATA_PATH}/EnvironmentalRasters/EnvironmentalRasters/Human Footprint/GLC24-PA-test-human_footprint.csv')
test_pa_landcover = pd.read_csv(f'{DATA_PATH}/EnvironmentalRasters/EnvironmentalRasters/LandCover/GLC24-PA-test-landcover.csv')
test_pa_soilgrid = pd.read_csv(f'{DATA_PATH}/EnvironmentalRasters/EnvironmentalRasters/SoilGrids/GLC24-PA-test-soilgrids.csv')

In [None]:
# The species vocabulary is stored as JSON despite the ".txt" extension.
with open("../final_run/valid_species.txt", "r") as f:
    valid_species = json.load(f)

# Map each raw species id to its position in `valid_species`; that position is
# the class index used for labels and model outputs.
species_id_to_index = {species_id: i for i, species_id in enumerate(valid_species)}

NUM_CLASSES = len(valid_species)
NUM_CLASSES

In [None]:
# Keep only occurrences whose species is in the vocabulary (isin() on a dict tests its keys).
train_metadata = train_metadata[train_metadata['speciesId'].isin(species_id_to_index)].reset_index(drop=True)
train_po_metadata = train_po_metadata[train_po_metadata['speciesId'].isin(species_id_to_index)].reset_index(drop=True)

# Replace raw species ids by class indices in [0, NUM_CLASSES).
# NOTE: this cell is not idempotent -- running it twice maps already-mapped indices again.
train_metadata['speciesId'] = train_metadata['speciesId'].map(species_id_to_index)
train_po_metadata['speciesId'] = train_po_metadata['speciesId'].map(species_id_to_index)

In [None]:
# Sanity check: the largest class index in each table should be below NUM_CLASSES.
train_metadata['speciesId'].max(), train_po_metadata['speciesId'].max()

In [None]:
# Build one wide raster table per split by placing the four source tables side
# by side. Only the elevation table keeps its 'surveyId' column.
# NOTE(review): concat(axis=1) aligns on the row index, not on surveyId, so this
# assumes all four CSVs of a split list the surveys in the same order -- confirm.
train_pa_rasters = pd.concat([
    train_pa_elevation,
    train_pa_footprint.drop(columns=['surveyId']),
    train_pa_landcover.drop(columns=['surveyId']),
    train_pa_soilgrid.drop(columns=['surveyId'])
], axis=1)

train_po_rasters = pd.concat([
    train_po_elevation,
    train_po_footprint.drop(columns=['surveyId']),
    train_po_landcover.drop(columns=['surveyId']),
    train_po_soilgrid.drop(columns=['surveyId'])
], axis=1)

test_pa_rasters = pd.concat([
    test_pa_elevation,
    test_pa_footprint.drop(columns=['surveyId']),
    test_pa_landcover.drop(columns=['surveyId']),
    test_pa_soilgrid.drop(columns=['surveyId'])
], axis=1)

In [None]:
# train_pa_rasters = train_pa_rasters.fillna(train_pa_rasters.mean())
# train_po_rasters = train_po_rasters.fillna(train_pa_rasters.mean())
# test_pa_rasters = test_pa_rasters.fillna(train_pa_rasters.mean())

In [None]:
# One row per PA survey: keep the first occurrence row of each surveyId and
# drop its single speciesId ...
train_pa_dd = train_metadata.drop_duplicates(subset=['surveyId'], ignore_index=True)
train_pa_dd = train_pa_dd.drop(columns=['speciesId'])

# ... then attach the list of all class indices observed in that survey.
train_pa_species = train_metadata.groupby('surveyId')['speciesId'].apply(list).reset_index()
train_pa_dd = pd.merge(train_pa_dd, train_pa_species, on='surveyId', how='left')
train_pa_dd.head()

In [None]:
# Human-footprint columns that are range-checked: values outside [0, 10] are
# treated as invalid and replaced by NaN (here and wherever err_cols is reused).
err_cols = [
    'HumanFootprint-Built1994',
    'HumanFootprint-Built2009',
    'HumanFootprint-croplands1992',
    'HumanFootprint-croplands2005',
    'HumanFootprint-Lights1994',
    'HumanFootprint-Lights2009',
    'HumanFootprint-NavWater1994',
    'HumanFootprint-NavWater2009',
    'HumanFootprint-Pasture1993',
    'HumanFootprint-Pasture2009',
    'HumanFootprint-Popdensity1990',
    'HumanFootprint-Popdensity2010',
    'HumanFootprint-Railways',
    'HumanFootprint-Roads',
]

# Infinite raster values become NaN.
train_pa_rasters = train_pa_rasters.replace([np.inf, -np.inf], np.nan)
# Vectorised range check: keep values inside [0, 10], everything else (including
# values that are already NaN) ends up NaN. Equivalent to the element-wise
# lambda, without a Python call per cell and without requiring DataFrame.map.
train_pa_rasters[err_cols] = train_pa_rasters[err_cols].where(lambda v: (v >= 0) & (v <= 10))
# NOTE: not idempotent -- re-running this cell merges the raster columns a second time.
train_pa_dd = pd.merge(train_pa_dd, train_pa_rasters, on='surveyId', how='left')

In [None]:
# Bounding box of the PA surveys.
lat_min, lat_max = train_pa_dd['lat'].min(), train_pa_dd['lat'].max()
lon_min, lon_max = train_pa_dd['lon'].min(), train_pa_dd['lon'].max()

# Keep only PO occurrences inside that bounding box (bounds inclusive), then
# remove repeated (surveyId, speciesId) pairs.
train_po_metadata = train_po_metadata[(train_po_metadata['lat'] >= lat_min) & (train_po_metadata['lat'] <= lat_max)]
train_po_metadata = train_po_metadata[(train_po_metadata['lon'] >= lon_min) & (train_po_metadata['lon'] <= lon_max)]
train_po_metadata = train_po_metadata.drop_duplicates(subset=['surveyId', 'speciesId'], ignore_index=True)
train_po_metadata = train_po_metadata.reset_index(drop=True)

In [None]:
# One row per PO survey, built the same way as the PA table: first occurrence
# row per surveyId without its speciesId ...
train_po_dd = train_po_metadata.drop_duplicates(subset=['surveyId'], ignore_index=True)
train_po_dd = train_po_dd.drop(columns=['speciesId'])

# ... plus the list of class indices recorded for that survey.
train_po_species = train_po_metadata.groupby('surveyId')['speciesId'].apply(lambda x: list(x)).reset_index()
train_po_dd = pd.merge(train_po_dd, train_po_species, on='surveyId', how='left')
train_po_dd.head()

In [None]:
# Clean the PO raster table: infinities become NaN, and err_cols values outside
# [0, 10] become NaN as well.
train_po_rasters = train_po_rasters.replace([np.inf, -np.inf], np.nan)
# Vectorised range check; values that fail it (or are already NaN) end up NaN.
# Equivalent to the element-wise lambda, without one Python call per cell.
train_po_rasters[err_cols] = train_po_rasters[err_cols].where(lambda v: (v >= 0) & (v <= 10))
# NOTE: not idempotent -- re-running this cell merges the raster columns a second time.
train_po_dd = pd.merge(train_po_dd, train_po_rasters, on='surveyId', how='left')

In [None]:
# Number of PA surveys and PO surveys after aggregation.
len(train_pa_dd), len(train_po_dd)

In [None]:
from sklearn.cluster import KMeans

# Survey coordinates (the 'lat' and 'lon' columns) are the clustering features.
coordinates = train_pa_dd[['lat', 'lon']]

# Number of clusters
N = 6000  # adjust as needed

# Run K-means clustering (max_iter=1, fixed random_state)
cluster = KMeans(n_clusters=N, max_iter=1, random_state=0).fit(coordinates)
# cluster = ward_tree(coordinates, n_clusters=N)
# cluster = OPTICS(min_samples=50, xi=.05, min_cluster_size=.05).fit(coordinates)

# Fold id = cluster label modulo 5, so all surveys of one cluster share a fold.
train_pa_dd['i_fold'] = cluster.labels_ % 5

# Preview part of the result
train_pa_dd.head()

In [None]:
# Replace infinite plot areas by NaN. 'areaInM2' is not listed in VALUE_COLUMNS,
# so it is not fed to the model by the cells visible here.
train_pa_dd['areaInM2'] = train_pa_dd['areaInM2'].replace([np.inf, -np.inf], np.nan)

In [None]:
# Tabular feature columns fed to the model's value MLP. Their number fixes the
# MLP input width, and their order fixes the layout of the feature vector
# returned by TrainDataset.
VALUE_COLUMNS = [
    'lon', 'lat', 'year', 'geoUncertaintyInM', 'Elevation',

    'HumanFootprint-Built1994',
    'HumanFootprint-Built2009',
    'HumanFootprint-Popdensity2010',
    'HumanFootprint-Roads',

    'HumanFootprint-croplands1992', 'HumanFootprint-croplands2005',
    'HumanFootprint-Lights1994', 'HumanFootprint-Lights2009',
    'HumanFootprint-NavWater1994', 'HumanFootprint-NavWater2009',
    'HumanFootprint-Pasture1993', 'HumanFootprint-Pasture2009',
    'HumanFootprint-Popdensity1990',

    'HumanFootprint-Railways',
    'HumanFootprint-HFP1993', 'HumanFootprint-HFP2009', 'LandCover',
    'Soilgrid-bdod', 'Soilgrid-cec', 'Soilgrid-cfvo', 'Soilgrid-clay',
    'Soilgrid-nitrogen', 'Soilgrid-phh2o', 'Soilgrid-sand', 'Soilgrid-silt',
    'Soilgrid-soc',
]

In [None]:
# Cast the tabular features to float32 in both training tables.
train_pa_dd[VALUE_COLUMNS] = train_pa_dd[VALUE_COLUMNS].astype(np.float32)
train_po_dd[VALUE_COLUMNS] = train_po_dd[VALUE_COLUMNS].astype(np.float32)

In [None]:
# Standardisation statistics come from the PA training surveys only and are
# reused for the PO table (and later for the test table).
features_mean = train_pa_dd[VALUE_COLUMNS].mean()
features_std = train_pa_dd[VALUE_COLUMNS].std()

# z-score, then fill missing values with 0 (i.e. the PA mean after scaling).
# NOTE: not idempotent -- re-running this cell standardises the columns again.
train_pa_dd[VALUE_COLUMNS] = (train_pa_dd[VALUE_COLUMNS] - features_mean) / features_std
train_pa_dd[VALUE_COLUMNS] = train_pa_dd[VALUE_COLUMNS].fillna(0)

train_po_dd[VALUE_COLUMNS] = (train_po_dd[VALUE_COLUMNS] - features_mean) / features_std
train_po_dd[VALUE_COLUMNS] = train_po_dd[VALUE_COLUMNS].fillna(0)

In [None]:
def construct_patch_path(data_path, survey_id):
    """Return the JPEG patch location of ``survey_id`` below ``data_path``.

    The layout is ``<data_path>/CD/AB/XXXXABCD.jpeg``: the first sub-directory
    is the last two characters of the id, the second one the two characters
    before them, and the file is named after the full id.
    """
    id_text = str(survey_id)
    last_two, preceding_two = id_text[-2:], id_text[-4:-2]
    return os.path.join(data_path, last_two, preceding_two, f"{survey_id}.jpeg")
    
class TrainDataset(Dataset):
    """Per-survey multimodal dataset (tabular features, Landsat cube, bioclimatic cube, Sentinel patch).

    Args:
        metadata: DataFrame with one row per survey. Must contain ``surveyId``
            and every column of ``VALUE_COLUMNS``; may contain ``speciesId``
            holding the list of class indices of the survey.
        bioclimate_path: Directory of the bioclimatic ``*_cube.pt`` files.
        landsat_path: Directory of the Landsat ``*_cube.pt`` files.
        sentinel_path: Directory of the RGB patches. The NIR directory is
            derived from it by replacing "rgb"/"RGB" with "nir"/"NIR".
        split: ``'pa_train'``, ``'pa_test'`` or ``'po_train'``. Selects the
            cube file names and the label encoding.

    Each item is the tuple
    ``(value_features, landsat_sample, bioclim_sample, sentinel_sample, label, survey_id)``.
    ``label`` is a ``NUM_CLASSES`` vector when ``metadata`` has a ``speciesId``
    column (1 for listed species; other entries are 0 for ``pa_train`` and
    -100 for ``po_train``), and an empty list otherwise.
    """

    # (landsat, bioclimatic) cube file-name templates per split.
    # NOTE(review): the PA-train Landsat name uses hyphens ("landsat-time-series")
    # while the other splits use underscores -- kept exactly as it was; confirm
    # against the files on disk.
    _CUBE_TEMPLATES = {
        'pa_train': (
            "GLC24-PA-train-landsat-time-series_{survey_id}_cube.pt",
            "GLC24-PA-train-bioclimatic_monthly_{survey_id}_cube.pt",
        ),
        'pa_test': (
            "GLC24-PA-test-landsat_time_series_{survey_id}_cube.pt",
            "GLC24-PA-test-bioclimatic_monthly_{survey_id}_cube.pt",
        ),
        'po_train': (
            "GLC24-PO-train-landsat_time_series_{survey_id}_cube.pt",
            "GLC24-PO-train-bioclimatic_monthly_{survey_id}_cube.pt",
        ),
    }

    def __init__(self, metadata, bioclimate_path, landsat_path, sentinel_path, split='pa_train'):
        self.metadata = metadata
        self.bioclimate_path = bioclimate_path
        self.landsat_path = landsat_path
        self.sentinel_path = sentinel_path

        # Cubes are only converted to tensors.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        # Sentinel patches (RGB + NIR = 4 channels) are resized and normalised;
        # the 4th mean/std repeats the first channel's value.
        self.sentinel_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((128, 128)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406, 0.485], std=[0.229, 0.224, 0.225, 0.229])
        ])

        self.split = split

    def __len__(self):
        return len(self.metadata)

    @staticmethod
    def _load_image(path):
        """Read an image file into a numpy array and close the file handle."""
        with Image.open(path) as image:
            return np.array(image)

    def __getitem__(self, idx):
        """Return the sample at *position* ``idx`` of ``metadata``.

        Raises:
            ValueError: if ``split`` is unknown, or if ``metadata`` carries a
                ``speciesId`` column for a split that has no label encoding.
        """
        # Every lookup is positional (.iloc), so the row used for surveyId,
        # the features and the labels is the same one even when `metadata`
        # does not carry a 0..N-1 index.
        survey_id = self.metadata['surveyId'].iloc[idx]

        templates = self._CUBE_TEMPLATES.get(self.split)
        if templates is None:
            raise ValueError(f"Invalid split: {self.split}")
        landsat_name, bioclim_name = (template.format(survey_id=survey_id) for template in templates)
        landsat_sample = torch.nan_to_num(torch.load(os.path.join(self.landsat_path, landsat_name)))
        bioclim_sample = torch.nan_to_num(torch.load(os.path.join(self.bioclimate_path, bioclim_name)))

        rgb_sample = self._load_image(construct_patch_path(self.sentinel_path, survey_id))
        nir_path = self.sentinel_path.replace("rgb", "nir").replace("RGB", "NIR")
        nir_sample = self._load_image(construct_patch_path(nir_path, survey_id))
        # Append NIR as a 4th channel behind RGB (H, W, 4).
        sentinel_sample = np.concatenate((rgb_sample, nir_sample[...,None]), axis=2)

        # ToTensor expects (H, W, C) numpy input, so the cubes are permuted first.
        if isinstance(landsat_sample, torch.Tensor):
            landsat_sample = landsat_sample.permute(1, 2, 0).numpy()

        if isinstance(bioclim_sample, torch.Tensor):
            bioclim_sample = bioclim_sample.permute(1, 2, 0).numpy()

        value_features = self.metadata.iloc[idx][VALUE_COLUMNS].values.astype(np.float32)

        if self.transform:
            landsat_sample = self.transform(landsat_sample)
            bioclim_sample = self.transform(bioclim_sample)

        if self.sentinel_transform:
            sentinel_sample = self.sentinel_transform(sentinel_sample)

        label = []
        if 'speciesId' in self.metadata.columns:
            species_ids = self.metadata['speciesId'].iloc[idx]

            if self.split == 'pa_train':
                # Species that are not listed count as absent (0).
                label = torch.zeros(NUM_CLASSES)
            elif self.split == 'po_train':
                # Species that are not listed are marked -100; the training
                # loss masks out entries equal to -100.
                label = torch.full((NUM_CLASSES,), -100.0)
            else:
                raise ValueError(f"Invalid split: {self.split}")

            for species_id in species_ids:
                label[species_id] = 1

        return value_features, landsat_sample, bioclim_sample, sentinel_sample, label, survey_id

In [None]:
# Full PA and PO datasets (the PA one uses the default split 'pa_train').
train_pa_dataset = TrainDataset(train_pa_dd, train_bioclim_data_path, train_landsat_data_path, train_sentinel_data_path)
train_po_dataset = TrainDataset(train_po_dd, train_po_bioclim_data_path, train_po_landsat_data_path, train_po_sentinel_data_path, split='po_train')

In [None]:
# Smoke test: load the first PO sample and show the shape of each returned element.
[x.shape for x in train_po_dataset[0]]

In [None]:
import torch.nn.functional as F
from torchgeo.models import ResNet18_Weights, ResNet50_Weights

class MultimodalEnsemble(nn.Module):
    """Late-fusion network over four inputs: tabular features, Landsat cube,
    bioclimatic cube and Sentinel patch.

    Each branch produces a 512-dimensional embedding; the four embeddings are
    concatenated and passed through a two-layer MLP that outputs one logit per
    class (no sigmoid is applied inside the model).

    NOTE: attribute names and the layout of the ``nn.Sequential`` blocks define
    the ``state_dict`` keys of saved checkpoints -- renaming or reordering them
    breaks loading of existing checkpoint files.
    """

    def __init__(self, num_classes):
        """Build all branches.

        Args:
            num_classes: Size of the output logit vector.
        """
        super(MultimodalEnsemble, self).__init__()

        # Tabular branch: len(VALUE_COLUMNS) -> 1000 -> 1000 -> 1000 -> 512.
        # BatchNorm1d needs more than one sample per batch in training mode.
        self.value_mlp = nn.Sequential(
            nn.Linear(len(VALUE_COLUMNS), 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(),
            nn.Linear(1000, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(),
            nn.Linear(1000, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(),
            nn.Linear(1000, 512),
        )
        
        # Landsat branch: LayerNorm over the trailing (6, 4, 21) dims, then a
        # randomly initialised torchvision ResNet-18 adapted to small inputs
        # (3x3 stride-1 stem, no max-pool, no classification head).
        self.landsat_norm = nn.LayerNorm([6,4,21])
        self.landsat_model = models.resnet18(weights=None, num_classes=0)
        # Modify the first convolutional layer to accept 6 channels instead of 3
        self.landsat_model.conv1 = nn.Conv2d(6, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.landsat_model.maxpool = nn.Identity()
        self.landsat_model.fc = nn.Identity()
        
        # Bioclimatic branch: same construction for inputs with trailing dims (4, 19, 12).
        self.bioclim_norm = nn.LayerNorm([4,19,12])
        self.bioclim_model = models.resnet18(weights=None, num_classes=0)  
        # Modify the first convolutional layer to accept 4 channels instead of 3
        self.bioclim_model.conv1 = nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bioclim_model.maxpool = nn.Identity()
        self.bioclim_model.fc = nn.Identity()

        # Sentinel branch: timm ResNet-18 initialised from the torchgeo
        # SENTINEL2_RGB_MOCO weights (strict=False, so keys that do not match
        # are skipped without error).
        sentinel_weights = ResNet18_Weights.SENTINEL2_RGB_MOCO
        self.sentinel_model = timm.create_model("resnet18", in_chans=sentinel_weights.meta["in_chans"], num_classes=0)
        self.sentinel_model.load_state_dict(sentinel_weights.get_state_dict(progress=True), strict=False)
        with torch.no_grad():
            # Widen the stem to 4 input channels: keep the loaded filters and
            # initialise the extra channel with a copy of input channel 0
            # (presumably for the NIR band appended by the dataset -- confirm).
            in_conv_w = self.sentinel_model.conv1.weight
            in_conv_w_new = torch.cat([in_conv_w, in_conv_w[:, [0]]], dim=1)
            self.sentinel_model.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            self.sentinel_model.conv1.weight = nn.Parameter(in_conv_w_new)

        # sentinel_weights = ResNet50_Weights.SENTINEL2_RGB_SECO
        # self.sentinel_model = timm.create_model("resnet50", in_chans=sentinel_weights.meta["in_chans"], num_classes=0)
        # self.sentinel_model.load_state_dict(sentinel_weights.get_state_dict(progress=True), strict=False)
        
        # LayerNorm on the 512-d Landsat and bioclimatic embeddings before fusion.
        self.ln_landsat = nn.LayerNorm(512)
        self.ln_bioclim = nn.LayerNorm(512)

        # Fusion head over the concatenated 4 x 512 embeddings.
        self.mlp = nn.Sequential(
            nn.Linear(512+512+512+512, 1024),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(1024, num_classes)
        )
        
    def forward(self, value_features, landsat, bioclim, sentinel):
        """Return class logits of shape (batch, num_classes).

        Args:
            value_features: (batch, len(VALUE_COLUMNS)) tabular features.
            landsat: batch whose trailing dims are (6, 4, 21).
            bioclim: batch whose trailing dims are (4, 19, 12).
            sentinel: batch of 4-channel image patches.
        """
        
        value_features = self.value_mlp(value_features)
        
        landsat = self.landsat_norm(landsat)
        landsat = self.landsat_model(landsat)
        landsat = self.ln_landsat(landsat)
        
        bioclim = self.bioclim_norm(bioclim)
        bioclim = self.bioclim_model(bioclim)
        bioclim = self.ln_bioclim(bioclim)
        
        sentinel = self.sentinel_model(sentinel)
        
        # Concatenate along the feature dimension: 4 x 512 = 2048 features.
        x = torch.cat([value_features, landsat, bioclim, sentinel], dim=1)
        x = self.mlp(x)
        
        return x

In [None]:
# Loader over the full PA dataset, used only by the forward-pass smoke test below.
train_loader = DataLoader(train_pa_dataset, batch_size=32, shuffle=True, num_workers=8, pin_memory=True)

In [None]:
# Throw-away model instance (not moved to any device here); the training loop builds its own per fold.
model = MultimodalEnsemble(NUM_CLASSES)

In [None]:
# Display the module tree.
model

In [None]:
# Smoke test: run a single batch through the model without gradients, then stop.
# model.eval() is not called here, so the model is still in its default training mode.
for batch in train_loader:
    value_features, landsat, bioclim, sentinel, label, survey_id = batch
    with torch.no_grad():
        output = model(value_features, landsat, bioclim, sentinel)
    break

In [None]:
# One Accelerator shared by all later cells; bf16 mixed precision is requested.
accelerator = Accelerator(
    mixed_precision='bf16',
)

In [None]:
def f1_score_multilabel(preds, targets):
    """Sample-averaged multi-label F1 score, together with recall and precision.

    Args:
        preds (torch.Tensor): Binary predictions of shape
            (batch_size, num_classes); bool, integer or float 0/1 values.
        targets (torch.Tensor): Binary ground truth of the same shape.

    Returns:
        tuple: ``(f1, recall, precision)`` as Python floats, each computed per
        sample and then averaged over the batch. This is the order the callers
        unpack (``f1, r, p``); previously precision and recall were returned
        swapped, so they were reported under each other's name.
    """
    preds = preds.float()
    targets = targets.float()

    # Per-sample true positives, false positives and false negatives.
    tp = (preds * targets).sum(dim=1)
    fp = (preds * (1 - targets)).sum(dim=1)
    fn = ((1 - preds) * targets).sum(dim=1)

    # The small epsilon keeps samples without predictions or positives at 0
    # instead of dividing by zero.
    precision = tp / (tp + fp + 1e-7)
    recall = tp / (tp + fn + 1e-7)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-7)

    average_f1 = f1.mean().item()
    average_recall = recall.mean().item()
    average_precision = precision.mean().item()

    return average_f1, average_recall, average_precision

In [None]:
# Species with fewer than 20 PA occurrences are considered rare.
# Species with zero PA occurrences do not appear in value_counts() and are
# therefore not part of `pa_rare_species`.
pa_species_counts = train_metadata['speciesId'].value_counts()
pa_rare_species = pa_species_counts[pa_species_counts < 20].index

# Flag PO occurrences of rare species, then mark every PO survey that has at least one.
train_po_metadata['is_rare'] = train_po_metadata['speciesId'].isin(pa_rare_species)
train_po_sid_have_rare = train_po_metadata.groupby('surveyId')['is_rare'].any()

In [None]:
# for batch in fold_train_loader:
#     break

In [None]:
# 5-fold training on the PA surveys, using the spatial folds in train_pa_dd['i_fold'].
# Sigmoid cut-offs swept when reporting F1; the best one is printed every epoch.
thresholds = np.arange(0.1, 0.99, 0.05)
batch_size = 64

for i_fold in range(5):
    print(f"Training fold {i_fold+1}")

    # The shared `accelerator` keeps references to every object passed to
    # prepare(). Release the previous fold's model / optimizer / loaders before
    # building new ones, otherwise they stay alive for the whole run.
    accelerator.free_memory()

    fold_train_metadata = train_pa_dd[train_pa_dd['i_fold'] != i_fold].reset_index(drop=True)
    fold_train_dataset = TrainDataset(fold_train_metadata, train_bioclim_data_path, train_landsat_data_path, train_sentinel_data_path)
    print(f'PA Training data size: {len(fold_train_dataset)}')

    # Optional: mix in PO surveys that contain rare species (disabled).
    # fold_po_metadata = train_po_dd.sample(100000, random_state=i_fold).reset_index(drop=True)
    # fold_po_metadata = pd.merge(fold_po_metadata, train_po_sid_have_rare, on='surveyId', how='left')
    # fold_po_metadata = fold_po_metadata[fold_po_metadata['is_rare']].reset_index(drop=True)
    # train_po_dataset = TrainDataset(fold_po_metadata, train_po_bioclim_data_path, train_po_landsat_data_path, train_po_sentinel_data_path, split='po_train')
    # print(f'PO Training data size: {len(train_po_dataset)}')

    # fold_train_dataset = torch.utils.data.ConcatDataset([fold_train_dataset, train_po_dataset])
    # print(f'Training data size: {len(fold_train_dataset)}')

    # drop_last=True avoids a final batch of size 1, which BatchNorm cannot handle in training mode.
    fold_train_loader = DataLoader(fold_train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, drop_last=True)

    fold_val_metadata = train_pa_dd[train_pa_dd['i_fold'] == i_fold].reset_index(drop=True)
    fold_val_dataset = TrainDataset(fold_val_metadata, train_bioclim_data_path, train_landsat_data_path, train_sentinel_data_path)
    fold_val_loader = DataLoader(fold_val_dataset, batch_size=batch_size*4, shuffle=False, num_workers=8, pin_memory=True)

    # Check if cuda is available
    device = torch.device("cpu")

    if torch.cuda.is_available():
        device = torch.device("cuda")
        tqdm.write("DEVICE = CUDA")

    num_classes = NUM_CLASSES # Number of all unique classes within the PO and PA data.
    model = MultimodalEnsemble(num_classes).to(device)

    # for name, param in model.named_parameters():
    #     if "sentinel" in name:
    #         param.requires_grad = False

    # Hyperparameters
    learning_rate = 0.00025
    num_epochs = 10
    positive_weigh_factor = 1.0  # only used by the disabled pos_weight variant of the loss
    fold_best_f1 = 0.0           # only used by the disabled checkpoint block below

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, verbose=True)

    model, optimizer, scheduler = accelerator.prepare(
        model, optimizer, scheduler
    )

    fold_train_loader, fold_val_loader = accelerator.prepare(
        fold_train_loader, fold_val_loader
    )

    tqdm.write(f"Training for {num_epochs} epochs started.")

    for epoch in tqdm(range(num_epochs)):
        model.train()
        train_f1_scores = {th:[] for th in thresholds}

        for batch_idx, (value_features, landsat, bioclim, sentinel, targets, _) in enumerate(tqdm(fold_train_loader, leave=False)):

            value_features = value_features.to(device)
            landsat = landsat.to(device)
            bioclim = bioclim.to(device)
            sentinel = sentinel.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(value_features, landsat, bioclim, sentinel)

            # Binary cross-entropy over all (sample, class) pairs, skipping
            # entries labelled -100 (the "unknown" marker used for PO labels).
            output = outputs.reshape(-1)
            target = targets.reshape(-1)
            labelled = target != -100
            loss = F.binary_cross_entropy_with_logits(
                output[labelled], target[labelled],
                # pos_weight=torch.tensor(3.0).to(device)
            )

            # Let Accelerate drive the backward pass so it stays consistent
            # with the mixed-precision setup of the prepared model/optimizer.
            accelerator.backward(loss)
            optimizer.step()

            if batch_idx % 500 == 0:
                tqdm.write(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(fold_train_loader)}, Loss: {loss.item()}")

            # Training-set F1 per threshold. The probabilities are computed once
            # and outside autograd; only the comparison changes per threshold.
            with torch.no_grad():
                probabilities = outputs.sigmoid()
                positives = (targets == 1)
                for threshold in thresholds:
                    f1, _, _ = f1_score_multilabel(probabilities > threshold, positives)
                    train_f1_scores[threshold].append(f1)

        best_threshold, best_f1 = max([(k, np.mean(v)) for k, v in train_f1_scores.items()], key=lambda x: x[1])
        tqdm.write(f"Epoch {epoch+1}. Training best F1 score: {best_f1:.5f} with threshold: {best_threshold:.3f}")

        scheduler.step()
        print("Scheduler:",scheduler.state_dict())

        # evaluation
        model.eval()
        val_f1_scores = {th:[] for th in thresholds}

        with torch.no_grad():
            for value_features, landsat, bioclim, sentinel, targets, _ in tqdm(fold_val_loader, leave=False):
                value_features = value_features.to(device)
                landsat = landsat.to(device)
                bioclim = bioclim.to(device)
                sentinel = sentinel.to(device)
                targets = targets.to(device)

                outputs = model(value_features, landsat, bioclim, sentinel)

                probabilities = outputs.sigmoid()
                for threshold in thresholds:
                    f1, _, _ = f1_score_multilabel(probabilities > threshold, targets)
                    val_f1_scores[threshold].append(f1)

        # Mean of per-batch F1 values (the last, smaller batch weighs as much as a full one).
        best_threshold, best_f1 = max([(k, np.mean(v)) for k, v in val_f1_scores.items()], key=lambda x: x[1])
        tqdm.write(f"Epoch {epoch+1}. Validation best F1 score: {best_f1:.5f} with threshold: {best_threshold:.3f}")

        # NOTE(review): checkpoint saving is disabled, but the PO inference cell
        # loads "all_new_v2_dl/model_fold_{i_fold}.pt". Those files must already
        # exist from an earlier run, or this block has to be re-enabled.
        # if best_f1 > fold_best_f1:
        #     model.eval()
        #     if not os.path.exists("all_new_v2_dl"):
        #         os.makedirs("all_new_v2_dl", exist_ok=True)
        #     torch.save(model.state_dict(), f"all_new_v2_dl/model_fold_{i_fold}.pt")
        #     tqdm.write(f"Saved model with F1 score: {best_f1:.5f}")

        #     fold_best_f1 = best_f1

In [None]:
# PA occurrence count per class index, ordered by class index.
# NOTE(review): classes without any PA occurrence are absent from value_counts(),
# so this array is shorter than NUM_CLASSES if such classes exist -- confirm the
# lengths match before broadcasting it against model outputs.
class_freq = train_metadata['speciesId'].value_counts().sort_index().values


In [None]:
# Collect logits and targets over the validation loader left behind by the
# training cell (i.e. the last fold's `model`, `fold_val_loader` and `device`).
# Evaluation mode is set explicitly so BatchNorm/Dropout behave deterministically
# regardless of the state in which an earlier cell left the model.
model.eval()

all_targets = []
all_outputs = []
with torch.no_grad():
    for value_features, landsat, bioclim, sentinel, targets, _ in tqdm(fold_val_loader, leave=False):
        value_features = value_features.to(device)
        landsat = landsat.to(device)
        bioclim = bioclim.to(device)
        sentinel = sentinel.to(device)
        targets = targets.to(device)

        outputs = model(value_features, landsat, bioclim, sentinel)

        all_targets.append(targets)
        all_outputs.append(outputs)

# (num_validation_surveys, NUM_CLASSES) tensors, kept on `device`.
all_targets = torch.cat(all_targets, dim=0)
all_outputs = torch.cat(all_outputs, dim=0)

In [None]:
# Largest threshold offset a frequency coefficient of 5e-6 would add (for the most frequent class).
class_freq.max() * 0.000005

In [None]:
# Smallest log-frequency over the classes (0.0 when the rarest class has a single occurrence).
np.log(class_freq).min()

In [None]:
# Grid search over a per-class threshold of the form
#     threshold + freq_coef * log(class_freq)
# evaluated on the collected validation outputs.
best_f1 = 0.0
best_threshold = 0.0
best_freq_coef = 0.0
for threshold in np.arange(0.0, 0.3, 0.01):
    for freq_coef in np.arange(0.0, 0.3, 0.01):
        # One threshold per class, broadcast over the samples.
        real_th = torch.tensor(threshold + freq_coef * np.log(class_freq)).to(device)
        predictions = (all_outputs.sigmoid() > real_th)
        # The 2nd and 3rd return values are unpacked here as (recall, precision).
        f1, r, p = f1_score_multilabel(predictions, all_targets)
        print(f"Threshold: {threshold:.3f}, Freq Coef: {freq_coef:.6f}, F1: {f1:.5f}, Recall: {r:.5f}, Precision: {p:.5f}")

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold
            best_freq_coef = freq_coef
    print()

print(f"Best F1: {best_f1:.5f}, Best Threshold: {best_threshold:.3f}, Best Freq Coef: {best_freq_coef:.6f}")

In [None]:
# Same grid search with a threshold that grows linearly with the class frequency:
#     threshold + freq_coef * class_freq
# NOTE: this overwrites best_f1 / best_threshold / best_freq_coef of the previous search.
best_f1 = 0.0
best_threshold = 0.0
best_freq_coef = 0.0
for threshold in np.arange(0.0, 0.3, 0.02):
    for freq_coef in [0, 0.000001, 0.000002, 0.000003, 0.000004, 0.000005, 0.000007, 0.00001, 0.000015, 0.00002, 0.00003]:
        # One threshold per class, broadcast over the samples.
        real_th = torch.tensor(threshold + freq_coef * class_freq).to(device)
        predictions = (all_outputs.sigmoid() > real_th)
        # The 2nd and 3rd return values are unpacked here as (recall, precision).
        f1, r, p = f1_score_multilabel(predictions, all_targets)
        print(f"Threshold: {threshold:.3f}, Freq Coef: {freq_coef:.6f}, F1: {f1:.5f}, Recall: {r:.5f}, Precision: {p:.5f}")

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold
            best_freq_coef = freq_coef
    print()

print(f"Best F1: {best_f1:.5f}, Best Threshold: {best_threshold:.3f}, Best Freq Coef: {best_freq_coef:.6f}")

In [None]:
# Apply the training-time preprocessing to the test split.
# Infinite raster values become NaN.
test_pa_rasters = test_pa_rasters.replace([np.inf, -np.inf], np.nan)
# Vectorised range check on the human-footprint columns: values outside [0, 10]
# (or already NaN) end up NaN. Equivalent to the element-wise lambda, without
# one Python call per cell.
test_pa_rasters[err_cols] = test_pa_rasters[err_cols].where(lambda v: (v >= 0) & (v <= 10))
test_pa_metadata = pd.merge(test_metadata, test_pa_rasters, on='surveyId', how='left')
test_pa_metadata['areaInM2'] = test_pa_metadata['areaInM2'].replace([np.inf, -np.inf], np.nan)
# Standardise with the PA *training* statistics, then fill missing values with 0
# (the training mean after scaling).
test_pa_metadata[VALUE_COLUMNS] = (test_pa_metadata[VALUE_COLUMNS] - features_mean) / features_std
test_pa_metadata[VALUE_COLUMNS] = test_pa_metadata[VALUE_COLUMNS].fillna(0)

In [None]:
# Test dataset and loader. With split='pa_test' the dataset reads the test cube
# file names; it returns an empty label list when the metadata has no speciesId column.
test_dataset = TrainDataset(
    test_pa_metadata, test_bioclim_data_path,
    test_landsat_data_path, test_sentinel_data_path, split='pa_test',
)

# shuffle=False keeps the batches in metadata row order.
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=8, pin_memory=True)

In [None]:
# NOTE(review): hard-coded CUDA device -- unlike the training cell there is no CPU fallback here.
device = torch.device("cuda")

In [None]:
# Score every PO survey with each fold's checkpoint and write the sigmoid
# probabilities to disk in chunks of 200 batches.
train_po_dataset = TrainDataset(train_po_dd, train_po_bioclim_data_path, train_po_landsat_data_path, train_po_sentinel_data_path, split='po_train')
# shuffle=False: prediction rows follow the row order of train_po_dd.
train_po_loader = DataLoader(train_po_dataset, batch_size=512, shuffle=False, num_workers=24, pin_memory=True)
model = MultimodalEnsemble(NUM_CLASSES).to(device)

model, train_po_loader = accelerator.prepare(model, train_po_loader)

for i_fold in tqdm(range(5)):
    # Requires the per-fold checkpoints to exist on disk already.
    model.load_state_dict(torch.load(f"all_new_v2_dl/model_fold_{i_fold}.pt"))
    model.eval()
    with torch.no_grad():
        fold_predictions = []
        for batch_idx, (value_features, landsat, bioclim, sentinel, _, _) in enumerate(tqdm(train_po_loader)):
            value_features = value_features.to(device)
            landsat = landsat.to(device)
            bioclim = bioclim.to(device)
            sentinel = sentinel.to(device)

            outputs = model(value_features, landsat, bioclim, sentinel)
            predictions = outputs.sigmoid()
            fold_predictions.append(predictions.cpu())

            # Flush a full chunk; the file suffix is the 1-based number of the
            # last batch in the chunk (200, 400, ...).
            if len(fold_predictions) == 200:
                fold_predictions = torch.cat(fold_predictions)
                torch.save(fold_predictions, f"all_new_v2_dl/po_predictions_fold_{i_fold}_batch_{batch_idx+1}.pt")
                fold_predictions = []

                print(f"Saved fold {i_fold} batch {batch_idx+1} predictions")

        # Flush the final partial chunk. Its suffix is the 0-based index of the
        # last batch, which is the name the pseudo-label cell reads. Nothing is
        # written when the batch count is an exact multiple of 200 (torch.cat
        # would fail on an empty list).
        if fold_predictions:
            fold_predictions = torch.cat(fold_predictions)
            torch.save(fold_predictions, f"all_new_v2_dl/po_predictions_fold_{i_fold}_batch_{batch_idx}.pt")
            print(f"Saved fold {i_fold} batch {batch_idx} predictions")

In [None]:
# File suffixes of the saved PO prediction chunks: one per 200 batches
# (200, 400, ..., 6000) plus the final partial chunk.
# NOTE(review): 6000 and 6090 are hard-coded for one specific number of PO
# batches -- they must be updated if the PO table or the batch size changes.
all_batch_indices = list(range(200, 6001, 200)) + [6090]
all_pseudo_labels = []

for i_batch in tqdm(all_batch_indices):
    # Average the five fold models' probabilities for this chunk.
    batch_preds = []
    for i_fold in range(5):
        pred = torch.load(f"all_new_v2_dl/po_predictions_fold_{i_fold}_batch_{i_batch}.pt")
        batch_preds.append(pred)
    mean_batch_preds = torch.stack(batch_preds, dim=0).mean(dim=0)
    
    # Per survey: classes with mean probability > 0.4 become pseudo-positives,
    # classes in (0.05, 0.4] are recorded separately as "ignore" indices.
    for probs in mean_batch_preds:
        pos_indices = torch.where(probs > 0.4)[0].tolist()
        ignore_indices = torch.where((probs > 0.05) & (probs <= 0.4))[0].tolist()
        all_pseudo_labels.append((pos_indices, ignore_indices))

In [None]:
# One row per PO survey: surveyId plus its pseudo-positive and "ignore" class indices.
# The pseudo labels are matched to surveys purely by position, so a length
# mismatch (e.g. a missing prediction chunk) must fail loudly instead of being
# silently padded with NaN or truncated by index alignment.
if len(all_pseudo_labels) != len(train_po_dd):
    raise ValueError(
        f"Got {len(all_pseudo_labels)} pseudo labels for {len(train_po_dd)} PO surveys"
    )

pseudo_label_df = pd.DataFrame(all_pseudo_labels, columns=['pos_indices', 'ignore_indices'])
# to_numpy() attaches the ids by position, independent of train_po_dd's index.
pseudo_label_df.insert(0, 'surveyId', train_po_dd['surveyId'].to_numpy())

In [None]:
# Persist the pseudo labels; the list-valued columns are written as their string representation.
pseudo_label_df.to_csv("all_new_v2_dl/po_pseudo_labels.csv", index=False)