# **Mounting the google drive**

---



In [None]:
# Run this cell when executing the notebook on Google Colab and loading your data from Google Drive
from google.colab import drive
# Mount Google Drive as a local file system under /content/drive
drive.mount('/content/drive')

# **Installing the needed libraries**

---



In [None]:
!pip install diffusers torch accelerate transformers

# **Importing the needed libraries**

---



In [None]:
import torch
from diffusers import AmusedPipeline
import pandas as pd
from google.colab import drive
import os
import re
import sys
from tqdm import tqdm
import concurrent.futures

# **Double check: upgrade accelerate and make sure Google Drive is mounted**

In [None]:
# Upgrade the accelerate library to the latest version
!pip install --upgrade accelerate

# Import the accelerate library
from accelerate import cpu_offload

# Mount Google Drive; if the plain mount raises, report the error and retry with a forced remount
try:
    drive.mount('/content/drive')
except Exception as mount_error:
    print(f"Error mounting Google Drive: {mount_error}")
    print("Attempting to forcibly remount...")
    drive.mount('/content/drive', force_remount=True)

# **Set up the environment**

---



In [None]:
# Set the PyTorch CUDA allocator configuration to 'expandable_segments:True'
# NOTE(review): presumably meant to reduce GPU memory fragmentation — confirm against the PyTorch CUDA memory-management docs
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Global input/output locations on the mounted Google Drive; change these to point at your own dataset.
# INPUT_FILE_PATH is the text file read by DataManager; OUTPUT_DIR is where generated PNG images are written.
INPUT_FILE_PATH = '/content/drive/MyDrive/datacolab_dataset/booksummaries.txt'
OUTPUT_DIR = '/content/drive/MyDrive/datacolab_dataset/image_outputs'

# **Generate the images using parallelization**

---



In [None]:
class ImageGenerator:
    """
    Generate images from text prompts with the aMUSEd pipeline and save them as PNG files.

    The pipeline is loaded once in the constructor and reused for every prompt.
    Images are written to OUTPUT_DIR as "<cleaned freebase id>.png"; IDs that
    already have a PNG in that directory are skipped.
    """
    def __init__(self, device: torch.device):
        """
        Load the pipeline onto ``device`` and index the images already on disk.

        Args:
            device: Target device for the pipeline (a torch.device, or a device
                string such as 'cuda' / 'cpu', which ``pipe.to`` also accepts).
        """
        self.device = device
        self.pipe = self.load_amused_pipeline()
        self.existing_images = self.get_existing_image_filenames()

    def load_amused_pipeline(self) -> AmusedPipeline:
        """
        Load the pre-trained "amused/amused-256" pipeline and move it to ``self.device``.

        Returns:
            AmusedPipeline: The pipeline placed on the requested device.
        """
        # No device_map here: diffusers does not allow explicit placement with
        # .to() on a pipeline loaded with a device-mapping strategy, and the
        # device is chosen explicitly by the caller anyway.
        pipe = AmusedPipeline.from_pretrained("amused/amused-256", low_cpu_mem_usage=True)
        return pipe.to(self.device)

    def get_existing_image_filenames(self) -> set:
        """
        Collect the names (without extension) of the PNG files already in OUTPUT_DIR.

        The output directory is created if it does not exist yet.

        Returns:
            set: File stems of every '*.png' file found in OUTPUT_DIR.
        """
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        # splitext strips only the final extension, so stems that contain dots
        # still match the cleaned ID used when the file was written.
        return {
            os.path.splitext(filename)[0]
            for filename in os.listdir(OUTPUT_DIR)
            if filename.endswith('.png')
        }

    def generate_image(self, freebase_id: str, prompt: str) -> None:
        """
        Generate an image from the given text prompt and save it with the freebaseID as the filename.

        If an image for the freebaseID already exists, or the prompt is missing
        or blank, a message is printed and nothing is generated. Errors raised
        while generating or saving are reported and swallowed so that one bad
        row does not stop the rest of the run.

        Args:
            freebase_id: Identifier used for the output filename; forward
                slashes are replaced with underscores.
            prompt: Text prompt passed to the pipeline.
        """
        # Forward slashes are path separators, so they cannot appear in the filename
        cleaned_freebase_id = freebase_id.replace('/', '_')

        # Check if the image with the current freebaseID already exists
        if cleaned_freebase_id in self.existing_images:
            print(f"Image for {cleaned_freebase_id} already exists, skipping...")
            return

        # Rows with a missing summary arrive from pandas as NaN (a float), not a string
        if not isinstance(prompt, str) or not prompt.strip():
            print(f"No usable prompt for {cleaned_freebase_id}, skipping...")
            return

        try:
            # A private generator seeded with 0 gives every image the same fixed seed
            # without reseeding the process-wide RNG or sharing generator state
            # between worker threads.
            seeded_generator = torch.Generator().manual_seed(0)
            image = self.pipe(prompt, negative_prompt="low quality, ugly", generator=seeded_generator).images[0]
            image_path = os.path.join(OUTPUT_DIR, f"{cleaned_freebase_id}.png")
            image.save(image_path)
            # Remember the new file so a repeated ID later in the same run is skipped
            self.existing_images.add(cleaned_freebase_id)
            print(f"Image saved: {image_path}")
            sys.stdout.flush()
        except Exception as e:
            print(f"Error generating image for {cleaned_freebase_id}: {e}")

class DataManager:
    """
    Load the tab-separated input file that supplies the freebase IDs and text prompts.
    """
    def __init__(self, input_file_path: str):
        """
        Args:
            input_file_path: Path of the tab-separated, header-less input file.
        """
        self.input_file_path = input_file_path

    def load_input_data(self) -> pd.DataFrame:
        """
        Load the input data from the specified file path.

        The file is read as raw tab-separated text with no header row. Quote
        characters are treated as ordinary text: the free-text summaries may
        contain unbalanced '"' characters, which the default CSV quoting rules
        would interpret as field delimiters and thereby merge or drop rows.

        Returns:
            pd.DataFrame: One row per input line with the columns "length",
            "freebase_id", "book_name", "author_name", "date",
            "freebase_id_json" and "summary".
        """
        # Local import keeps this block self-contained; csv is part of the standard library
        import csv

        column_names = ["length", "freebase_id", "book_name", "author_name", "date", "freebase_id_json", "summary"]
        # read_csv already returns a DataFrame, so no extra wrapping is needed
        return pd.read_csv(
            self.input_file_path,
            sep="\t",
            header=None,
            names=column_names,
            quoting=csv.QUOTE_NONE,
        )

def main():
    """
    The main entry point of the application.

    1. Create a DataManager instance and load the input data
    2. Create an ImageGenerator on the GPU when CUDA is available, otherwise on the CPU
    3. Submit one image-generation task per input row to a thread pool and
       show a progress bar that advances as tasks finish

    Any error raised while loading the data or setting up the generator is
    printed and the process exits with status 1.
    """
    try:
        manager = DataManager(INPUT_FILE_PATH)
        df = manager.load_input_data()

        # Pass a real torch.device, matching ImageGenerator's annotated parameter type
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        generator = ImageGenerator(device=device)

        # NOTE(review): every worker thread shares the single pipeline instance;
        # consider bounding max_workers if GPU memory becomes a problem.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(generator.generate_image, freebase_id, summary)
                for freebase_id, summary in zip(df['freebase_id'], df['summary'])
            ]
            # Track completion rather than submission: submitting is nearly instant,
            # so wrapping it in tqdm would say nothing about actual progress.
            for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Generating images"):
                # Report failures that escaped the task's own handling instead of dropping them
                task_error = future.exception()
                if task_error is not None:
                    print(f"Image generation task failed: {task_error}")
    except Exception as e:
        print(f"An error occurred: {e}")
        sys.exit(1)

# Start the image generation run when this code is executed directly (not when imported as a module)
if __name__ == '__main__':
    main()