# In [None]:  (Jupyter cell prompt left over from the notebook export)
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import load_dataset
from torchvision import transforms
from PIL import Image
from bioclip import TreeOfLifeClassifier


# Load the Hugging Face dataset in streaming mode
hf_dataset = load_dataset("imageomics/TreeOfLife-10M", split="train", streaming=True)

# Get the catalog.csv file URL directly from the dataset info
catalog_url = hf_dataset.info.splits["train"].data_files["catalog.csv"]

# Load catalog.csv into a pandas DataFrame
metadata = pd.read_csv(catalog_url)

# Filter the train split from the metadata
train_metadata = metadata[metadata['split'] == 'train']

# Create custom train, validation, and test splits
train_subset, test_subset = train_test_split(train_metadata, test_size=0.1, random_state=42)
train_subset, val_subset = train_test_split(train_subset, test_size=0.1, random_state=42)

# Save splits for reproducibility
train_subset.to_csv("train_subset.csv", index=False)
val_subset.to_csv("val_subset.csv", index=False)
test_subset.to_csv("test_subset.csv", index=False)

print(f"Training subset: {len(train_subset)} samples")
print(f"Validation subset: {len(val_subset)} samples")
print(f"Testing subset: {len(test_subset)} samples")

classifier = TreeOfLifeClassifier()


augmentation_pipeline = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.Resize((224, 224)),  # Resize to match model input size
    transforms.ToTensor(),  # Convert image to tensor
])


def filter_dataset(dataset, subset_metadata):
    """
    Filters the Hugging Face dataset using metadata identifiers.
    """
    subset_ids = set(subset_metadata['image_id'])  # Assuming 'image_id' is a unique key in metadata
    for sample in dataset:
        if sample['image_id'] in subset_ids:  # Match image IDs
            yield sample


def process_and_generate_soft_targets(dataset_split, batch_size=32):
    """
    Processes the dataset in batches, applies augmentations, and generates predictions using the BioCLIP classifier.
    """
    batch = []
    for sample in dataset_split:
        # Load the image
        image = Image.open(sample["image"])  # Load image from URL or local path

        # Apply augmentations
        augmented_image = augmentation_pipeline(image)

        # Add the augmented image to the batch
        batch.append(augmented_image.numpy())  # Convert back to numpy array if needed for BioCLIP

        if len(batch) == batch_size:
            # Pass the batch to the classifier for prediction
            batch_predictions = classifier.predict(batch)
            yield batch, batch_predictions
            batch = []  # Reset the batch

    # Handle remaining samples in the last batch
    if batch:
        batch_predictions = classifier.predict(batch)
        yield batch, batch_predictions


# Filter the Hugging Face dataset for each subset
train_dataset = filter_dataset(hf_dataset, train_subset)
val_dataset = filter_dataset(hf_dataset, val_subset)
test_dataset = filter_dataset(hf_dataset, test_subset)

# Process the training subset with augmentation
print("Processing Training Data with Augmentations:")
for i, (batch_images, batch_predictions) in enumerate(process_and_generate_soft_targets(train_dataset, batch_size=32)):
    print(f"Processed Training Batch {i + 1}")

# Process the validation subset without augmentation
print("Processing Validation Data:")
for i, (batch_images, batch_predictions) in enumerate(process_and_generate_soft_targets(val_dataset, batch_size=32)):
    print(f"Processed Validation Batch {i + 1}")

# Process the test subset without augmentation
print("Processing Testing Data:")
for i, (batch_images, batch_predictions) in enumerate(process_and_generate_soft_targets(test_dataset, batch_size=32)):
    print(f"Processed Testing Batch {i + 1}")
