# In [21]:
import os
import pandas as pd
import numpy as np
from PIL import Image

# Human-readable class names. The order follows the label convention noted in
# Waterbirds.add_captions_to_text_files (0 = Landbird, otherwise Waterbird).
# NOTE(review): this constant is not referenced by the code visible in this file.
LABELS = ["Landbird", "Waterbird"]

class Waterbirds:
    """Append label-based text prompts to per-image caption files.

    The dataset root must contain a ``metadata.csv`` with at least the columns
    ``y`` (integer label) and ``img_filename`` (image path relative to the
    dataset). For every row, a caption file with the same relative name but a
    ``.txt`` extension is looked up under ``captions_root``.
    """

    # Prompt sets are loop-invariant, so they are defined once here instead of
    # being rebuilt for every image.
    _LANDBIRD_PROMPTS = (
        "an image of a bird",
        "a photo of a bird",
        "an image of a landbird",
        "a photo of a landbird",
    )
    _WATERBIRD_PROMPTS = (
        "an image of a bird",
        "a photo of a bird",
        "an image of a waterbird",
        "a photo of a waterbird",
    )

    def __init__(self, root, captions_root, split='train'):
        """Load the dataset metadata.

        Args:
            root: Dataset directory containing ``metadata.csv``; a leading
                ``~`` is expanded.
            captions_root: Directory holding the caption ``.txt`` files; a
                leading ``~`` is expanded, consistent with ``root``.
            split: Stored on the instance as ``self.split``.
                NOTE(review): it is not used to filter the metadata, so every
                row of ``metadata.csv`` is processed regardless of its value.
        """
        self.root = os.path.expanduser(root)
        self.split = split
        self.captions_root = os.path.expanduser(captions_root)

        # Load metadata
        self.metadata_df = pd.read_csv(os.path.join(self.root, 'metadata.csv'))

        # Labels and the image filenames they belong to (row-aligned).
        self.labels = self.metadata_df['y'].values
        self.filenames = self.metadata_df['img_filename'].values

    def add_captions_to_text_files(self):
        """Append the label-specific prompts to each existing caption file.

        For every (filename, label) pair:

        * the caption path is ``captions_root/<filename without extension>.txt``;
        * label ``0`` gets the landbird prompts, any other label the waterbird
          prompts;
        * if the caption file does not exist, a message is printed and the
          file is skipped (it is never created);
        * if any of the prompts is already present as a full line, the file is
          left untouched, so repeated calls do not duplicate prompts;
        * otherwise the prompts are appended, one per line.

        Returns:
            None. Progress is reported with ``print``.
        """
        text_folder = self.captions_root

        for filename, label in zip(self.filenames, self.labels):
            # Swap only the real extension. A plain str.replace('.jpg', ...)
            # would also rewrite '.jpg' in the middle of a name and would leave
            # other extensions (.png, .JPG) untouched, pointing at a non-text file.
            stem, _ = os.path.splitext(filename)
            text_file_path = os.path.join(text_folder, stem + '.txt')

            if label == 0:  # Landbird
                prompts = self._LANDBIRD_PROMPTS
            else:  # Waterbird
                prompts = self._WATERBIRD_PROMPTS

            # EAFP: read directly instead of checking existence first.
            try:
                with open(text_file_path, 'r', encoding='utf-8') as f:
                    existing_text = f.read()
            except FileNotFoundError:
                print(f"Text file not found: {text_file_path}")
                continue

            existing_prompts = set(existing_text.splitlines())
            if any(prompt in existing_prompts for prompt in prompts):
                print(f"Prompts already exist in {text_file_path}")
                continue

            # If the file does not end with a newline, start on a fresh line so
            # the first prompt is not glued onto the last existing caption.
            needs_newline = bool(existing_text) and not existing_text.endswith('\n')
            separator = '\n' if needs_newline else ''
            with open(text_file_path, 'a', encoding='utf-8') as f:
                f.write(separator + '\n'.join(prompts) + '\n')
            print(f"Prompts added to {text_file_path}")
            

# Usage: build the helper from the dataset directory and the caption
# directory, then append the prompts to the caption files found there.
# Specify the paths to your own dataset and caption folders here.
dataset = Waterbirds(
    '/mimer/NOBACKUP/groups/ulio_inverse/ds-project/GALS/data/waterbird_1.0_forest2water2',
    '/mimer/NOBACKUP/groups/ulio_inverse/ds-project/ProbVLM/Datasets/text_c10',
    split='train',
)
dataset.add_captions_to_text_files()

