In [None]:
from itertools import islice
from pathlib import Path

import numpy as np
import plotly.graph_objects as go
import torch
from torch.utils.data import DataLoader

from bfm_finetune.dataloaders.dataloader_utils import custom_collate_fn
from bfm_finetune.dataloaders.geolifeclef_species import utils as geolife_utils
from bfm_finetune.dataloaders.geolifeclef_species.dataloader import (
    GeoLifeCLEFSpeciesDataset,
)
from bfm_finetune.plots_v2 import plot_eval

In [None]:
# Serialized species file "yearly_species_2017-2018.pt" inside the "train"
# sub-directory of the configured species location.
fpath = (
    geolife_utils.aurorashape_species_location
    / "train"
    / "yearly_species_2017-2018.pt"
)
# Deserialize onto the CPU. weights_only=True keeps torch.load on its restricted
# unpickler instead of executing arbitrary pickled objects.
data = torch.load(fpath, weights_only=True, map_location="cpu")

In [None]:
# Convert the stored species tensor to a NumPy array; later cells index into it.
species_distribution = data["species_distribution"].numpy()

# Show the array shape followed by the "lat" and "lon" metadata entries,
# one item per output line.
print(
    species_distribution.shape,
    data["metadata"]["lat"],
    data["metadata"]["lon"],
    sep="\n",
)

In [None]:
# Select a single 2-D slice: index 1 on the first axis and index 0 on the
# second (one timestep and one species, per the layout this notebook assumes).
matrix = species_distribution[1, 0, :, :]
# Raw heatmap: no longitude roll is applied. The rows are only reversed so
# that north is drawn at the top of the figure.
matrix = matrix[::-1]
fig = go.Figure(data=[go.Heatmap(z=matrix)])
fig.show()

# negative_lon_mode: comparing the "ignore", "exclude", "roll" and "translate" options

In [None]:
# Build the training dataset with negative_lon_mode="ignore" and plot one batch.
train_dataset = GeoLifeCLEFSpeciesDataset(
    num_species=500,
    mode="train",
    negative_lon_mode="ignore",
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,  # sequential order, so batch index 2 is reproducible
    collate_fn=custom_collate_fn,
    num_workers=1,
)
# Fetch only the third batch (index 2). islice stops iterating once that batch
# has been produced, rather than loading and collating the whole dataset the
# way list(train_dataloader)[2] would.
batch = next(islice(train_dataloader, 2, 3))["batch"]
# None is passed as the second argument and saving is disabled.
# NOTE(review): Path(".") is presumably the output directory, unused when
# save=False — confirm against plot_eval.
plot_eval(batch, None, Path("."), n_species_to_plot=1, save=False)

In [None]:
# Build the training dataset with negative_lon_mode="exclude" and plot one batch.
train_dataset = GeoLifeCLEFSpeciesDataset(
    num_species=500,
    mode="train",
    negative_lon_mode="exclude",
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,  # sequential order, so batch index 2 is reproducible
    collate_fn=custom_collate_fn,
    num_workers=1,
)
# Fetch only the third batch (index 2). islice stops iterating once that batch
# has been produced, rather than loading and collating the whole dataset the
# way list(train_dataloader)[2] would.
batch = next(islice(train_dataloader, 2, 3))["batch"]
# None is passed as the second argument and saving is disabled.
# NOTE(review): Path(".") is presumably the output directory, unused when
# save=False — confirm against plot_eval.
plot_eval(batch, None, Path("."), n_species_to_plot=1, save=False)

In [None]:
# Build the training dataset with negative_lon_mode="roll" and plot one batch.
train_dataset = GeoLifeCLEFSpeciesDataset(
    num_species=500,
    mode="train",
    negative_lon_mode="roll",
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,  # sequential order, so batch index 2 is reproducible
    collate_fn=custom_collate_fn,
    num_workers=1,
)
# Fetch only the third batch (index 2). islice stops iterating once that batch
# has been produced, rather than loading and collating the whole dataset the
# way list(train_dataloader)[2] would.
batch = next(islice(train_dataloader, 2, 3))["batch"]
# None is passed as the second argument and saving is disabled.
# NOTE(review): Path(".") is presumably the output directory, unused when
# save=False — confirm against plot_eval.
plot_eval(batch, None, Path("."), n_species_to_plot=1, save=False)

In [None]:
# Build the training dataset with negative_lon_mode="translate" and plot one batch.
train_dataset = GeoLifeCLEFSpeciesDataset(
    num_species=500,
    mode="train",
    negative_lon_mode="translate",
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,  # sequential order, so batch index 2 is reproducible
    collate_fn=custom_collate_fn,
    num_workers=1,
)
# Fetch only the third batch (index 2). islice stops iterating once that batch
# has been produced, rather than loading and collating the whole dataset the
# way list(train_dataloader)[2] would.
batch = next(islice(train_dataloader, 2, 3))["batch"]
# None is passed as the second argument and saving is disabled.
# NOTE(review): Path(".") is presumably the output directory, unused when
# save=False — confirm against plot_eval.
plot_eval(batch, None, Path("."), n_species_to_plot=1, save=False)