In [None]:
# Print the version of the `python3` executable found on the shell PATH
# (this can differ from the interpreter the notebook kernel is running on).
!python3 --version

In [None]:
import torch
# Select the GPU only after confirming CUDA is usable: calling
# torch.cuda.set_device() on a machine without CUDA raises before the
# "No GPUs available." branch could ever be reached.
preferred_gpu = 1
if torch.cuda.is_available():
    # Fall back to device 0 when the preferred index does not exist
    # (e.g. on a single-GPU machine) instead of crashing.
    gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
    torch.cuda.set_device(gpu_index)
    current_gpu = torch.cuda.current_device()
    print(f"Current default GPU index: {current_gpu}")
    print(f"Current default GPU name: {torch.cuda.get_device_name(current_gpu)}")
else:
    print("No GPUs available.")

In [None]:
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
from diffusers.utils import load_image

import torch

# Load the "stabilityai/sdxl-turbo" checkpoint as an image-to-image pipeline,
# requesting the fp16 weight variant and float16 tensors.
pipe = AutoPipelineForImage2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
# Move the pipeline to CUDA. `pipe` is a notebook-level global that the
# generation functions and cells further down rely on.
pipe.to("cuda")


In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

def display_images_side_by_side(image_paths, captions, figsize=(4, 2)):
    """
    Show several images in a single row, each with a caption as its title.

    Args:
        image_paths: Sequence of image file paths readable by `matplotlib.image.imread`.
        captions: Sequence of caption strings, paired with `image_paths` by position.
            Pairing stops at the shorter of the two sequences.
        figsize: (width, height) of the whole figure in inches. Defaults to the
            previously hard-coded (4, 2).

    Returns:
        None. The figure is rendered with `plt.show()`.
    """
    n = len(image_paths)
    if n == 0:
        # plt.subplots(1, 0) raises; there is simply nothing to draw.
        return

    # squeeze=False always yields a 2-D array of Axes, so a single image (n == 1)
    # is handled the same way as several; without it a bare Axes object is returned
    # and cannot be iterated.
    fig, axes = plt.subplots(1, n, figsize=figsize, squeeze=False)

    for ax, img_path, caption in zip(axes[0], image_paths, captions):
        img = mpimg.imread(img_path)
        ax.imshow(img)
        ax.axis('off')  # hide ticks and frame
        ax.set_title(caption, fontsize=10, pad=10)

    plt.tight_layout()
    plt.show()

## image generation with sdxl_turbo

In [None]:
# dataset
import torchvision

from avalanche.benchmarks import SplitMNIST, SplitCIFAR100
from avalanche.benchmarks.classic import SplitCIFAR100
from avalanche.benchmarks.classic import SplitCIFAR10
from avalanche.benchmarks.utils.data_loader import GroupBalancedDataLoader, ReplayDataLoader
from avalanche.benchmarks.generators import nc_benchmark, ni_benchmark
from avalanche.benchmarks.generators import filelist_benchmark, dataset_benchmark, \
                                            tensors_benchmark, paths_benchmark

from avalanche.logging import InteractiveLogger, TensorboardLogger, \
    WandBLogger, TextLogger, TensorboardLogger

from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation.metrics import forgetting_metrics, accuracy_metrics, loss_metrics

from avalanche.training.plugins.checkpoint import CheckpointPlugin, \
    FileSystemCheckpointStorage
from avalanche.training.determinism.rng_manager import RNGManager
from avalanche.training import Naive, CWRStar, Replay, GDumb, \
    Cumulative, LwF, GEM, AGEM, EWC, AR1
from avalanche.models import SimpleMLP
from avalanche.training import Naive, CWRStar, Replay, GDumb, \
    Cumulative, LwF, GEM, AGEM, EWC, AR1
from avalanche.models import SimpleMLP
from avalanche.training.plugins import ReplayPlugin
from types import SimpleNamespace
from avalanche.training.storage_policy import ParametricBuffer, RandomExemplarsSelectionStrategy

# all imports

import torch
import os
from torch import cat, Tensor
from torch.nn import Module
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset, ConcatDataset, TensorDataset
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torchvision import datasets, transforms
import torch.optim.lr_scheduler # ?
from torchvision.transforms import Compose, ToTensor, Normalize, RandomCrop, CenterCrop, RandomHorizontalFlip, Resize
from torchvision.transforms.functional import center_crop
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.utils import save_image
from torchvision.transforms.functional import pil_to_tensor


In [None]:
from sdxl_main import *

In [None]:

def count_txt_files(directory):
    """
    Count the regular files directly inside `directory` whose name ends in '.txt'.

    Sub-directories are not descended into, and a directory whose own name
    ends in '.txt' is not counted.

    Args:
    directory (str): Folder to inspect.

    Returns:
    int: Number of matching files.
    """
    matching = [
        name
        for name in os.listdir(directory)
        if name.endswith('.txt') and os.path.isfile(os.path.join(directory, name))
    ]
    return len(matching)

def count_png_files(folder_path):
    """
    Count the entries of `folder_path` whose name ends with '.png'.

    The match is case-sensitive and only looks at names directly inside the
    folder (no recursion).

    Args:
        folder_path (str): Folder to inspect.

    Returns:
        int: Number of matching entries, or 0 (after printing a message) when
        `folder_path` is not an existing directory.
    """
    # isdir rather than exists: a path pointing at a regular file would pass an
    # existence check and then make os.listdir raise.
    if not os.path.isdir(folder_path):
        print("The specified folder does not exist.")
        return 0

    return sum(1 for name in os.listdir(folder_path) if name.endswith('.png'))

def count_jpeg_files(folder_path):
    """
    Count the entries of `folder_path` whose name ends with '.JPEG'.

    The match is case-sensitive (so '.jpeg' and '.jpg' are not counted) and
    only looks at names directly inside the folder (no recursion).

    Args:
        folder_path (str): Folder to inspect.

    Returns:
        int: Number of matching entries, or 0 (after printing a message) when
        `folder_path` is not an existing directory.
    """
    # isdir rather than exists: a path pointing at a regular file would pass an
    # existence check and then make os.listdir raise.
    if not os.path.isdir(folder_path):
        print("The specified folder does not exist.")
        return 0

    return sum(1 for name in os.listdir(folder_path) if name.endswith('.JPEG'))

# count_png_files('/storage3/enbo/saved_data/imageget_sdxl_llava_i2i_synfromreal_s8g2')

In [None]:
# Build the SplitCIFAR100 benchmark with 20 experiences and a fixed seed.
benchmark = SplitCIFAR100(n_experiences=20,
                          seed = 41,             
                          )

# Cut the benchmark's class order into consecutive chunks of 5 entries
# (presumably one chunk per experience: 100 classes / 20 experiences - confirm).
orders = benchmark.classes_order
order_list = [orders[x:x+5] for x in range(0, len(orders), 5)]
print(order_list)

# order_sample = [order[3:] for order in order_list]


In [None]:
import json
import os

def load_json(file_path):
    """
    Read the JSON document stored at `file_path`.

    :param file_path: Location of the JSON file on disk.
    :return: The deserialized content.
    """
    with open(file_path, 'r') as handle:
        return json.load(handle)

def process_labels(json_data):
    """
    Print one "ID / Class Name" line per entry of the mapping.

    :param json_data: Mapping from identifier to class name.
    """
    for identifier in json_data:
        print(f"ID: {identifier}, Class Name: {json_data[identifier]}")

def replace_spaces_with_underscores(class_names):
    """Return a new list in which every space of every name is turned into '_'."""
    return [name.replace(' ', '_') for name in class_names]



In [None]:

from avalanche.benchmarks import SplitMNIST, SplitCIFAR100

# Per-channel (mean, std) pair; in this cell it is referenced only by the
# commented-out Normalize step below.
stats = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Both names are bound to the same Compose object; with the other steps
# commented out it only converts images to tensors.
transform = transform_train = Compose([
    # Resize(224),
    # Resize(384),
    # RandomHorizontalFlip(),
    ToTensor(),
    # Normalize(*stats,inplace=True)
])

# Load the CIFAR-100 training set (downloaded into ./data if missing)
trainset = torchvision.datasets.CIFAR100(root='data', train=True,download=True, transform=transform)

# Class names as exposed by the dataset, and the index -> name lookup for
# indices 0..99.
name_list = trainset.classes
integer_to_name = {i: name_list[i] for i in range(100)}
import json

def read_json(json_path):
    """Parse the JSON file at `json_path` and return the decoded content."""
    with open(json_path, 'r') as source:
        return json.load(source)


# Reverse lookup: class name -> index, for indices 0..99.
name_to_integer = {name_list[i]: i for i in range(100)}

# synthesis classes
# Rebuild the benchmark with the same arguments as the earlier cell so this
# cell does not depend on that cell having been run.
benchmark = SplitCIFAR100(n_experiences=20,
                          seed = 41,             
                          )

# Consecutive chunks of 5 entries of the class order
# (presumably one chunk per experience - confirm).
orders = benchmark.classes_order
order_list = [orders[x:x+5] for x in range(0, len(orders), 5)]

# From each chunk, the entries from position 3 onward are treated as the
# "synthetic" classes; `syn_classes` is the flattened list of them.
syn_class_list = [order[3:] for order in order_list]
syn_classes = [item for lists in syn_class_list for item in lists]

# Every index in 0..99 that is not a synthetic class is a "real" class.
# Going through a set means the resulting order is not tied to the class order.
real_classes = list(set([i for i in range(100)]) - set(syn_classes))
print(len(real_classes))



# order_sample = [order[3:] for order in order_list]
# Per-chunk synthetic labels and their class names. `label_list` ends up holding
# the same sub-lists as `syn_class_list`.
classname_list = []
label_list = []
classname_list_sep = []
for order_l in syn_class_list:
    label_list.append(order_l)
    cur_classname = [integer_to_name[i] for i in order_l]
    classname_list.append(cur_classname)
# Flattened versions: one entry per synthetic class.
classname_list_sep = [item for lists in classname_list for item in lists]
label_list_sep = [item for lists in label_list for item in lists]
print(label_list_sep)

# label_list_sep = [order[3] for order in order_list]
# print(label_list_sep)

In [None]:

# Load the two pre-computed class-pairing dictionaries (ids and names).
# NOTE(review): `real_to_syn_id` is later passed as `dict_syn_to_real_id` and
# indexed with a synthetic class id inside the generation functions, so the key
# direction implied by the variable name should be confirmed against the file.
real_to_syn_id = read_json('/storage3/enbo/saved_data/cifar100_dict_40synfrom60real_id.json')
real_to_syn_name = read_json('/storage3/enbo/saved_data/cifar100_dict_40synfrom60real_name.json')


In [None]:
# JSON object keys are always strings: convert keys and values to int so the
# dictionary can be indexed with integer class ids.
real_to_syn_id = {int(key): int(item) for key, item in real_to_syn_id.items()}

In [None]:
# Restrict the index -> name lookup to the ids in `label_list_sep`
# (the flattened synthetic-class labels).
specific_integer_to_name = {key: integer_to_name[key] for key in label_list_sep if key in integer_to_name}

In [None]:
# Inspect the synthetic-class id -> name mapping.
print(specific_integer_to_name)

In [None]:
# Ids in 0..99 that are not in `label_list_sep`, and their index -> name lookup.
# The ids go through a set, so their order is not tied to the class order.
real_list = list(set([i for i in range(100)]) - set(label_list_sep))
real_integer_to_name = {key: integer_to_name[key] for key in real_list if key in integer_to_name}

In [None]:
# Inspect the real-class id -> name mapping.
print(real_integer_to_name)

In [None]:
import re

def split_words(word):
    """
    Expand a class name into the list of strings that should be searched for.

    A name without '_' yields a one-element list. A name containing '_' yields
    its underscore-separated parts followed by the full original name.
    """
    if '_' not in word:
        return [word]
    return word.split('_') + [word]


def replace_words(text, word, new_word):
    """
    Replace mentions of a class name in `text` with `new_word`.

    `word` is a class name whose parts are joined by underscores
    (e.g. "chambered_nautilus"). The following are replaced, matched on word
    boundaries and ignoring case:

    * the whole compound, with its parts separated by underscores, whitespace
      or hyphens ("chambered_nautilus", "chambered nautilus") - replaced ONCE;
    * each individual part on its own ("nautilus").

    Args:
        text: The text to rewrite.
        word: Class name to look for; parts are separated by '_'.
        new_word: Replacement text, inserted literally.

    Returns:
        The rewritten text. `text` is returned unchanged when `word` has no
        non-empty part.
    """
    parts = [part for part in word.split('_') if part]
    if not parts:
        # An empty alternative would match at every word boundary.
        return text

    alternatives = []
    if len(parts) > 1:
        # The full phrase goes first: regex alternation is ordered, so this keeps
        # "chambered nautilus" from being replaced part by part (twice).
        alternatives.append(r'[\s_-]+'.join(re.escape(part) for part in parts))
    alternatives.extend(re.escape(part) for part in parts)

    pattern = re.compile(r'\b(?:' + '|'.join(alternatives) + r')\b', re.IGNORECASE)

    # A callable replacement inserts new_word verbatim; a plain string would have
    # backslashes / group references interpreted by the regex engine.
    return pattern.sub(lambda _match: new_word, text)

# Example usage
# Demonstration of replace_words on a sample caption: every mention of the class
# "chambered_nautilus" (or of one of its parts) is swapped for `new_word`.
text = "The image features a chambered nautilus, a spiral-shaped shell, with a silver centerpiece. The nautilus is prominently displayed, occupying a significant portion of the image. The overall mood of the image is serene and captivating, as the nautilus's unique shape and the contrasting colors of the shell and"
new_word = "new_word"
word = "chambered_nautilus"
# Replace the words
result_text = replace_words(text, word, new_word)
print(result_text)


In [None]:
# NOTE: scratch fragment, kept only as a commented-out reference.
# As a stand-alone cell it cannot run: the lines after the first were indented
# (IndentationError), and `current_prompt_list`, `startimage_file_paths` and
# `num_image_replay` are not defined at notebook level. The prompt-sampling part
# also appears inside `sdxl_img2img_moreimage_lessprompt_real2syn`.
#
# random.seed(42)
# indices1 = [random.randint(0, len(current_prompt_list)-1) for _ in range(num_image_replay)]
# current_prompt_list = [current_prompt_list[item] for item in indices1]
#
# random.seed(41)
# indices2 = [random.randint(0, len(startimage_file_paths)-1) for _ in range(num_image_replay)]
# current_init_image_list = [startimage_file_paths[item] for item in indices2]

In [None]:
# !rm -rf '/storage3/enbo/saved_data/imageget_sdxl_llava_i2i_synfromreal_s8g2'

In [None]:
                
def sdxl_img2img_matching_image_prompt_real2syn(pipe,
                                                class_dict, 
                                                image_size, 
                                                prompt_file_dict, 
                                                startimage_file_path, 
                                                integer_to_name, 
                                                dict_syn_to_real_id,
                                                generator_seed, 
                                                num_image_replay=50,
                                                folder_name="/content/sd_images"):
    """
    Generate images for synthetic classes from the matched (prompt, start image)
    pairs of a paired real class.

    For every (synthetic id, synthetic class name) in `class_dict`:

    1. the paired real class id is looked up in `dict_syn_to_real_id`;
    2. `<prompt_file_dict>/class<real_id>.txt` is parsed with
       `extract_prompts_fromtxts_real2syn`, which returns image paths and prompts;
    3. for prompt j, mentions of the real class name are swapped for the synthetic
       class name, image j is loaded and resized to 512x512, and `pipe` is run in
       img2img mode;
    4. the result is resized to `image_size` and saved in `folder_name` under the
       start image's file name with the real class name replaced by the synthetic one;
    5. a line "<saved path> <synthetic id>" is appended to
       `<folder_name>/class<synthetic id>.txt` (the file is rewritten on each run).

    Args:
        pipe: Image-to-image pipeline; called with prompt/image/strength/
            num_inference_steps/generator/guidance_scale.
        class_dict: Mapping synthetic class id -> synthetic class name.
        image_size: Size passed to the generated image's `resize`, e.g. (32, 32).
        prompt_file_dict: Directory holding the `class<real_id>.txt` prompt files.
        startimage_file_path: Unused; start image paths are the ones returned by
            `extract_prompts_fromtxts_real2syn`.
        integer_to_name: Mapping class id -> class name (used for the real class).
        dict_syn_to_real_id: Mapping synthetic class id -> paired real class id.
        generator_seed: Generator object forwarded to every `pipe` call.
        num_image_replay: Unused; one image is generated per prompt found.
        folder_name: Output directory, created if missing.

    Raises:
        KeyError: If a class id is missing from `dict_syn_to_real_id` or
            `integer_to_name`.
        IndexError: If the prompt file yields fewer image paths than prompts.
    """
    # exist_ok avoids the check-then-create race of a separate existence test.
    os.makedirs(folder_name, exist_ok=True)

    for class_id, class_name in class_dict.items():
        # class_id / class_name describe the synthetic class
        print("Generating images for class " + str(class_id) + ": " + class_name)

        class_file_path = os.path.join(folder_name, f"class{class_id}.txt")

        # Real class paired with this synthetic class (per the original note, the
        # real class judged most similar by BERT).
        real_id = dict_syn_to_real_id[class_id]
        real_class_name = integer_to_name[real_id]
        prompt_file_path = os.path.join(prompt_file_dict, f"class{real_id}.txt")
        current_path_list, current_prompt_list = extract_prompts_fromtxts_real2syn(prompt_file_path)

        with open(class_file_path, "w") as file:
            for j, prompt in enumerate(current_prompt_list):
                # Turn the real-class description into a synthetic-class prompt.
                syn_prompt = replace_words(prompt, real_class_name, class_name)
                print(syn_prompt)

                starting_image_path = current_path_list[j]
                real_image_name = os.path.basename(starting_image_path)
                syn_image_name = real_image_name.replace(real_class_name, class_name)

                init_image = load_image(starting_image_path).resize((512, 512))

                # The generated image is stored under the synthetic name.
                generated_image_path = os.path.join(folder_name, syn_image_name)
                # Original author's note: int(num_inference_steps * strength)
                new_image = pipe(prompt=syn_prompt, image=init_image,
                                 strength=0.8,
                                 num_inference_steps=25,
                                 generator=generator_seed, guidance_scale=2).images[0]
                new_image.resize(image_size).save(generated_image_path)

                file.write(f"{generated_image_path} {class_id}\n")

                print(f"Generated image {syn_image_name} for class {class_name}")
                
def sdxl_img2img_matching_image_prompt(class_dict, image_size, prompt_file_dict, startimage_file_path,
                                                          generator_seed, num_image_replay=50, folder_name="/content/sd_images"):
    """
    Regenerate images for each class from its own matched (prompt, start image) pairs.

    For every (id, class name) in `class_dict`, `<prompt_file_dict>/class<id>.txt`
    is parsed with `extract_prompts_fromtxts_real2syn`. A class with fewer than
    `num_image_replay` prompts is put back at the end of the queue and retried
    after the other classes. Once a whole pass over the remaining classes makes
    no progress, the function stops and reports them; previously it retried
    forever.
    NOTE(review): if prompt files are expected to grow while this runs, re-invoke
    the function later for the reported classes.

    For an accepted class, every prompt j is run through the notebook-level `pipe`
    (this function has no pipeline parameter) with image j as 512x512 init image;
    the output is resized to `image_size`, saved in `folder_name` under the start
    image's file name, and listed as "<saved path> <id>" in
    `<folder_name>/class<id>.txt`.

    Args:
        class_dict: Mapping class id -> class name.
        image_size: Size passed to the generated image's `resize`.
        prompt_file_dict: Directory holding the `class<id>.txt` prompt files.
        startimage_file_path: Unused; start image paths are the ones returned by
            `extract_prompts_fromtxts_real2syn`.
        generator_seed: Generator object forwarded to every `pipe` call.
        num_image_replay: Minimum number of prompts a class needs to be processed.
            All prompts found are used, not only this many.
        folder_name: Output directory, created if missing.
    """
    # Local import: the notebook's import cells never import `random` explicitly.
    import random

    # Kept from the original: reseeds Python's global RNG as a side effect.
    random.seed(42)
    os.makedirs(folder_name, exist_ok=True)

    pending = list(class_dict.items())
    consecutive_skips = 0  # classes skipped in a row since the last processed one

    while pending:
        class_id, class_name = pending.pop(0)
        print("Generating images for class " + str(class_id) + ": " + class_name)

        class_file_path = os.path.join(folder_name, f"class{class_id}.txt")

        prompt_file_path = os.path.join(prompt_file_dict, f"class{class_id}.txt")
        current_path_list, current_prompt_list = extract_prompts_fromtxts_real2syn(prompt_file_path)

        if len(current_prompt_list) < num_image_replay:
            print(f"Not enough prompts for class {class_id}. Skipping and will revisit later.")
            pending.append((class_id, class_name))
            consecutive_skips += 1
            if consecutive_skips >= len(pending):
                # Every remaining class was skipped back-to-back: looping again
                # would spin forever without generating anything.
                remaining_ids = [pending_id for pending_id, _ in pending]
                print(f"Stopping: still not enough prompts for classes {remaining_ids}.")
                break
            continue
        consecutive_skips = 0

        # Generate the new images
        with open(class_file_path, "w") as file:
            for j, prompt in enumerate(current_prompt_list):
                starting_image_path = current_path_list[j]
                image_name = os.path.basename(starting_image_path)

                init_image = load_image(starting_image_path).resize((512, 512))

                image_path = os.path.join(folder_name, image_name)
                new_image = pipe(prompt=prompt, image=init_image,
                                 strength=0.8,
                                 num_inference_steps=20,
                                 generator=generator_seed, guidance_scale=2).images[0]
                new_image.resize(image_size).save(image_path)

                file.write(f"{image_path} {class_id}\n")

                print(f"Generated image {image_name} for class {class_name}")

In [None]:
def sdxl_img2img_moreimage_lessprompt_real2syn(pipe,
                                                class_dict, 
                                                image_size, 
                                                prompt_file_dict, 
                                                startimage_file_path, 
                                                integer_to_name, 
                                                dict_syn_to_real_id,
                                                generator_seed, 
                                                num_image_replay=50,
                                                folder_name="/content/sd_images",
                                                prompt_fraction=0.4):
    """
    Generate synthetic-class images using many start images but only a few prompts.

    For every (synthetic id, synthetic class name) in `class_dict`, the paired real
    class's prompt file (`<prompt_file_dict>/class<real_id>.txt`) is parsed with
    `extract_prompts_fromtxts_real2syn`. Only the first
    `int(num_image_replay * prompt_fraction)` prompts are kept as a pool, and
    `num_image_replay` prompts are drawn from that pool with replacement. The j-th
    drawn prompt is combined with the j-th start image of the file, so images are
    used in order while prompts repeat.

    Each prompt has the real class name swapped for the synthetic one, the start
    image is resized to 512x512, `pipe` is run in img2img mode, and the result is
    resized to `image_size` and saved in `folder_name` under the start image's file
    name with the class name replaced. `<folder_name>/class<synthetic id>.txt`
    lists "<saved path> <synthetic id>" per image (rewritten on each run).

    Args:
        pipe: Image-to-image pipeline.
        class_dict: Mapping synthetic class id -> synthetic class name.
        image_size: Size passed to the generated image's `resize`.
        prompt_file_dict: Directory holding the `class<real_id>.txt` prompt files.
        startimage_file_path: Unused; start image paths are the ones returned by
            `extract_prompts_fromtxts_real2syn`.
        integer_to_name: Mapping class id -> class name (used for the real class).
        dict_syn_to_real_id: Mapping synthetic class id -> paired real class id.
        generator_seed: Generator object forwarded to every `pipe` call.
        num_image_replay: Number of images to generate per class.
        folder_name: Output directory, created if missing.
        prompt_fraction: Share of `num_image_replay` that sizes the prompt pool.
            Defaults to the previously hard-coded 0.4.

    Raises:
        ValueError: If the prompt pool of a class is empty.
        IndexError: If a class has fewer start images than `num_image_replay`.
    """
    # Local import: the notebook's import cells never import `random` explicitly.
    import random

    os.makedirs(folder_name, exist_ok=True)

    for class_id, class_name in class_dict.items():
        # class_id / class_name describe the synthetic class
        print("Generating images for class " + str(class_id) + ": " + class_name)

        class_file_path = os.path.join(folder_name, f"class{class_id}.txt")

        # Real class paired with this synthetic class (per the original note, the
        # real class judged most similar by BERT).
        real_id = dict_syn_to_real_id[class_id]
        real_class_name = integer_to_name[real_id]
        prompt_file_path = os.path.join(prompt_file_dict, f"class{real_id}.txt")
        current_path_list, current_prompt_list = extract_prompts_fromtxts_real2syn(prompt_file_path)

        prompt_pool = current_prompt_list[:int(num_image_replay * prompt_fraction)]
        print('num of prompts: ', len(prompt_pool))
        if not prompt_pool:
            # randint(0, -1) would fail with an opaque "empty range" message.
            raise ValueError(f"No prompts available for class {class_id} (read from {prompt_file_path}).")

        # Reseeding per class means every class draws the same index sequence
        # whenever the pool sizes are equal.
        random.seed(42)
        indices1 = [random.randint(0, len(prompt_pool) - 1) for _ in range(num_image_replay)]
        current_prompt_list = [prompt_pool[item] for item in indices1]
        print('num of prompts: ', len(current_prompt_list))

        # Generate the new images
        with open(class_file_path, "w") as file:
            for j, prompt in enumerate(current_prompt_list):
                syn_prompt = replace_words(prompt, real_class_name, class_name)
                print(syn_prompt)

                # Start images are taken in file order, independent of the drawn prompt.
                starting_image_path = current_path_list[j]
                real_image_name = os.path.basename(starting_image_path)
                syn_image_name = real_image_name.replace(real_class_name, class_name)

                init_image = load_image(starting_image_path).resize((512, 512))

                # The generated image is stored under the synthetic name.
                generated_image_path = os.path.join(folder_name, syn_image_name)
                # Original author's note: int(num_inference_steps * strength)
                new_image = pipe(prompt=syn_prompt, image=init_image,
                                 strength=0.8,
                                 num_inference_steps=25,
                                 generator=generator_seed, guidance_scale=2).images[0]
                new_image.resize(image_size).save(generated_image_path)

                file.write(f"{generated_image_path} {class_id}\n")

                print(f"Generated image {syn_image_name} for class {class_name}")
                                    
                    

# More images, fewer prompts: generate the 40 synthetic classes, 500 images each

In [None]:
# Show the synthetic-class id -> name mapping that the generation cells below receive.
print(specific_integer_to_name)

In [None]:
# CUDA random generator with a fixed seed (42), handed to every pipeline call.
generator1 = torch.Generator(device="cuda").manual_seed(42)
# create_sdxl_data_fixed_prompts_randommultiple(integer_to_name, image_size = (32,32), prompt_file_dict = 'saved_data/llava_saved_data_310/llava_prompt_50/', generator_seed = generator1, num_image_replay=50, folder_name='saved_data/sd_turbo_500images_llava_firstthreeclasses/')

# Generate 500 images per synthetic class at 32x32, using all start images but a
# reduced prompt pool.
# NOTE(review): `real_to_syn_id` is passed as `dict_syn_to_real_id` and indexed by
# synthetic class id inside the function - confirm the dictionary's key direction.
sdxl_img2img_moreimage_lessprompt_real2syn(pipe, 
                                            specific_integer_to_name, image_size = (32,32),
                                            prompt_file_dict = 'saved_data/llava_cifar100_real60_500/',
                                startimage_file_path = 'saved_data/cifar_train_all_fortest',
                                integer_to_name = integer_to_name,
                                dict_syn_to_real_id = real_to_syn_id,
                                generator_seed = generator1, num_image_replay=500, 
                                folder_name='/storage3/enbo/saved_data/cifar100_sdxl_llava_i2i_40synfrom60real_i2i_step20_40percentpromptallimg_500') # starting image and prompt from the similar class within the same experience


In [None]:
# CUDA random generator with a fixed seed (42), handed to every pipeline call.
generator1 = torch.Generator(device="cuda").manual_seed(42)
# create_sdxl_data_fixed_prompts_randommultiple(integer_to_name, image_size = (32,32), prompt_file_dict = 'saved_data/llava_saved_data_310/llava_prompt_50/', generator_seed = generator1, num_image_replay=50, folder_name='saved_data/sd_turbo_500images_llava_firstthreeclasses/')

# Generate one 32x32 image per matched (prompt, start image) pair for each
# synthetic class. `num_image_replay` is accepted by this function but not used by it.
# NOTE(review): `real_to_syn_id` is passed as `dict_syn_to_real_id` and indexed by
# synthetic class id inside the function - confirm the dictionary's key direction.
sdxl_img2img_matching_image_prompt_real2syn(pipe, 
                                            specific_integer_to_name, image_size = (32,32),
                                            prompt_file_dict = 'saved_data/llava_cifar100_real60_500/',
                                startimage_file_path = 'saved_data/cifar_train_all_fortest',
                                integer_to_name = integer_to_name,
                                dict_syn_to_real_id = real_to_syn_id,
                                generator_seed = generator1, num_image_replay=500, 
                                folder_name='/storage3/enbo/saved_data/cifar100_sdxl_llava_i2i_40synfrom60real_i2i_step20') # starting image and prompt from the similar class within the same experience

# sdxl_img2img_matching_image_prompt(specific_integer_to_name, 
#                                    image_size = (32,32), 
#                                    prompt_file_dict = '/storage3/enbo/saved_data/llava_cifar/cifar100_long_syn20', 
#                                    startimage_file_path = 'saved_data/cifar_train_all_fortest',
#                                     generator_seed = generator1, 
#                                    num_image_replay=500, 
#                                 folder_name='/storage3/enbo/saved_data/cifar100_sdxl_llava_i2i_20moresyntheticreal') # starting image and prompt from the similar class within the same experience


# Generate the 60 real classes from real images and prompts

In [None]:
# CUDA random generator with a fixed seed (42).
generator1 = torch.Generator(device="cuda").manual_seed(42)
# `create_sdxl_data_img2img_matching_image_prompt` is not defined in this notebook;
# presumably it comes from `from sdxl_main import *` - confirm.
# NOTE(review): `real_integer_to_name` is built above from the CIFAR-100 class
# names, while the paths and the 224x224 size here point at ImageNet data -
# confirm the class mapping is the intended one.
create_sdxl_data_img2img_matching_image_prompt(real_integer_to_name, image_size = (224,224),
                                               prompt_file_dict = '/scratch/local/ssd/enbo/saved_data/llava_imagenet/imagenet_long_60real',
                                               startimage_file_path = '/scratch/local/ssd/enbo/saved_data/imagenet_train_data_allimages',
                                               generator_seed = generator1, num_image_replay=1300, 
                                               folder_name='/storage3/enbo/saved_data/imageget_sdxl_llava_i2i_allimageprompt_s8g2')
    

## syn-real from all real images, 10% long prompts

In [None]:
# generator1 = torch.Generator(device="cuda").manual_seed(42)
# create_sdxl_data_moreimages_lessprompts_img2img(real_integer_to_name, image_size = (32,32),
#                                                prompt_file_dict = '/storage3/enbo/saved_data/realimages_10persent_longprompts/',
#                                                startimage_file_path = 'saved_data/cifar_train_all_fortest/',
#                                                generator_seed = generator1, num_image_replay=500, 
#                                                folder_name='/storage3/enbo/saved_data/sdxl_llava_i2i_allimage10percentprompt_60real')

# CUDA random generator with a fixed seed (42).
generator1 = torch.Generator(device="cuda").manual_seed(42)
# imagenet 224
# `create_sdxl_data_moreimages_lessprompts_img2img` is not defined in this notebook;
# presumably it comes from `from sdxl_main import *` - confirm.
# NOTE(review): `real_integer_to_name` is built above from the CIFAR-100 class
# names, while the paths here point at ImageNet data - confirm.
create_sdxl_data_moreimages_lessprompts_img2img(real_integer_to_name, image_size = (224,224),
                                               prompt_file_dict = '/scratch/local/ssd/enbo/saved_data/llava_imagenet/imagenet_long_60real_10percent',
                                               startimage_file_path = '/scratch/local/ssd/enbo/saved_data/imagenet_train_data_allimages',
                                               generator_seed = generator1, num_image_replay=1300, 
                                               folder_name='/storage3/enbo/saved_data/imageget_sdxl_llava_i2i_allimage10percentprompt_60real')

## syn-real from 10% real images, prompts long

In [None]:
# CUDA random generator with a fixed seed (42).
generator1 = torch.Generator(device="cuda").manual_seed(42)

# `create_sdxl_data_fixed_prompts_randommultiple_img2img_update` is not defined in
# this notebook; presumably it comes from `from sdxl_main import *` - confirm.
# Called for the real classes at 32x32 with 500 images requested per class.
create_sdxl_data_fixed_prompts_randommultiple_img2img_update(real_integer_to_name, image_size = (32,32),
                                               prompt_file_dict = '/storage3/enbo/saved_data/realimages_10persent_longprompts/',
                                               startimage_file_path = 'saved_data/llava_saved_data_310/cifar_50/',
                                               generator_seed = generator1, num_image_replay=500, 
                                               folder_name='/storage3/enbo/saved_data/sdxl_llava_i2i_10percentimageprompt_60real')

In [None]:
# CUDA random generator with a fixed seed (42).
generator1 = torch.Generator(device="cuda").manual_seed(42)

# `create_sdxl_data_fixed_prompts_randommultiple_img2img` is not defined in this
# notebook; presumably it comes from `from sdxl_main import *` - confirm.
# Called for the synthetic classes at 32x32 with 500 images requested per class.
create_sdxl_data_fixed_prompts_randommultiple_img2img(specific_integer_to_name, image_size = (32,32),
                                               prompt_file_dict = 'saved_data/llava_saved_data_310/llava_prompt_50/',
                                               startimage_file_path = 'saved_data/llava_saved_data_310/cifar_50/',
                                               generator_seed = generator1, num_image_replay=500, 
                                               folder_name='/scratch/local/ssd/enbo/saved_data/sdxl_llava_i2i_10percentimageprompt')

## mixed data llava multiple generation image2image 500 real2syn

In [None]:
# CUDA random generator with a fixed seed (42).
generator1 = torch.Generator(device="cuda").manual_seed(42)
# create_sdxl_data_fixed_prompts_randommultiple(integer_to_name, image_size = (32,32), prompt_file_dict = 'saved_data/llava_saved_data_310/llava_prompt_50/', generator_seed = generator1, num_image_replay=50, folder_name='saved_data/sd_turbo_500images_llava_firstthreeclasses/')

# `create_sdxl_data_img2img_matching_image_prompt_real2syn` is not defined in this
# notebook; presumably it comes from `from sdxl_main import *` - confirm.
# NOTE(review): `real_to_syn_id` is passed as `dict_syn_to_real_id` - confirm the
# dictionary's key direction.
create_sdxl_data_img2img_matching_image_prompt_real2syn(specific_integer_to_name, image_size = (32,32),
                                               prompt_file_dict = 'saved_data/llava_cifar100_real60_500/',
                                startimage_file_path = 'saved_data/cifar_train_all_fortest',
                                integer_to_name = integer_to_name,
                                dict_syn_to_real_id = real_to_syn_id,
                                generator_seed = generator1, num_image_replay=500, 
                                folder_name='saved_data/sdxl_llava_synfromreal_s8g2') # starting image and prompt from the similar class within the same experience

## llava prompt multiple generation image2image imagenet

In [None]:
# CUDA random generator with a fixed seed (42).
generator1 = torch.Generator(device="cuda").manual_seed(42)

# Disabled call. Only its first line had been commented out, which left the
# indented argument lines behind and made the cell fail with an IndentationError.
# The whole call is commented out so the cell runs; uncomment every line to
# re-enable it.
# create_sdxl_data_fixed_prompts_randommultiple_img2img(specific_integer_to_name, image_size = (224,224),
#                                                prompt_file_dict = 'saved_data/ImageNet/prompt_130_1/',
#                                                startimage_file_path = 'saved_data/ImageNet/ImageNet_train_renamed_10percent/',
#                                                generator_seed = generator1, num_image_replay=1300, 
#                                                folder_name='saved_data/ImageNet/ImageNet_sdxl_llavaprompt_3real_i2i')

## mixed data llava multiple generation image2image 

In [None]:
# CUDA random generator with a fixed seed (42).
generator1 = torch.Generator(device="cuda").manual_seed(42)
# create_sdxl_data_fixed_prompts_randommultiple(integer_to_name, image_size = (32,32), prompt_file_dict = 'saved_data/llava_saved_data_310/llava_prompt_50/', generator_seed = generator1, num_image_replay=50, folder_name='saved_data/sd_turbo_500images_llava_firstthreeclasses/')

# `create_sdxl_data_img2img_matching_image_prompt` is not defined in this notebook;
# presumably it comes from `from sdxl_main import *` - confirm.
# Called with the full id -> name mapping (all classes) and 50 images requested per class.
create_sdxl_data_img2img_matching_image_prompt(integer_to_name, image_size = (32,32),
                                               prompt_file_dict = 'saved_data/llava_saved_data_310/llava_prompt_50/',
                                               startimage_file_path = 'saved_data/llava_saved_data_310/cifar_50/',
                                               generator_seed = generator1, num_image_replay=50, 
                                               folder_name='saved_data/sd_turbo_i2i_50all')

## mixed data llava multiple generation

In [None]:
# CUDA random generator with a fixed seed (42).
generator1 = torch.Generator(device="cuda").manual_seed(42)
# `create_sdxl_data_fixed_prompts_randommultiple` is not defined in this notebook;
# presumably it comes from `from sdxl_main import *` - confirm.
create_sdxl_data_fixed_prompts_randommultiple(specific_integer_to_name, image_size = (32,32), prompt_file_dict = 'saved_data/llava_saved_data_310/llava_prompt_50/', generator_seed = generator1, num_image_replay=500, folder_name='saved_data/sd_turbo_500images_llava_firstthreeclasses/')

## diverse prompts

In [None]:
# CUDA random generator; note the seed here is 41, not the 42 used in other cells.
generator1 = torch.Generator(device="cuda").manual_seed(41)
# `create_sdxl_data` is not defined in this notebook; presumably it comes from
# `from sdxl_main import *` - confirm.
create_sdxl_data(specific_integer_to_name, image_size = (32,32), generator_seed = generator1, num_image_replay=500, folder_name='saved_data/sd_turbo_500images/')


## base prompt

In [None]:
# CUDA random generator; note the seed here is 41, not the 42 used in other cells.
generator1 = torch.Generator(device="cuda").manual_seed(41)
# `create_sdxl_data_baseprompt` is not defined in this notebook; presumably it
# comes from `from sdxl_main import *` - confirm.
create_sdxl_data_baseprompt(specific_integer_to_name, image_size = (32,32), generator_seed = generator1, num_image_replay=500, folder_name='saved_data/sd_turbo_500images_baseprompt/')


In [None]:
import os

def count_txt_files(folder_path):
    """
    Count the entries of `folder_path` whose name ends with '.txt'.

    NOTE: this redefinition replaces the `count_txt_files` defined earlier in the
    notebook and behaves differently: it returns 0 for a missing folder instead
    of raising, and it does not check that an entry is a regular file.

    Args:
        folder_path (str): Folder to inspect.

    Returns:
        int: Number of matching entries, or 0 (after printing a message) when
        `folder_path` is not an existing directory.
    """
    # isdir rather than exists: a path pointing at a regular file would pass an
    # existence check and then make os.listdir raise.
    if not os.path.isdir(folder_path):
        print("The specified folder does not exist.")
        return 0

    return sum(1 for name in os.listdir(folder_path) if name.endswith('.txt'))

# Folder to inspect (relative to the notebook's working directory)
folder_path = 'saved_data/cifar_train500_2syn_i2i_step10/'

# Count the .txt entries in that folder and report the result
count = count_txt_files(folder_path)
print(f"There are {count} .txt files in the folder.")


In [None]:
def read_and_print_file(file_path):
    """
    Print the number of lines of a text file, then each line stripped of
    surrounding whitespace.

    A missing file, or any other error raised while reading or printing, is
    reported with a printed message instead of being propagated.
    """
    try:
        with open(file_path, 'r') as handle:
            rows = list(handle)  # one entry per line of the file
        print(len(rows))
        for row in rows:
            print(row.strip())
    except FileNotFoundError:
        print(f"Error: The file at {file_path} does not exist.")
    except Exception as e:
        print(f"An error occurred: {e}")

# Example usage: print the line count and the lines of one per-class index file
file_path = 'saved_data/cifar_train500/class36.txt'
read_and_print_file(file_path)
