In [None]:
import os
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm

from constants import DatasetPath

In [None]:
#Utility constants

#The dataset path must have been specified in the 'constants.py' file
DATASET_PATH = DatasetPath.effectivePath

def DATASET_DIRS():
	"""
	Lists the names of the entries currently found in DATASET_PATH.

	The listing is read again at every call, so files created in the
	meantime are included.

	Returns
	-------
		list of str: the entry names, as returned by 'os.listdir'.
	"""
	entries = os.listdir(DATASET_PATH)
	return entries

#"cycle_gan" and "pro_gan" contain both real and fake images, so they are listed in both groups

REAL_DIRS = ["afhq", "celebahq", "coco", "ffhq", "imagenet", "landscape", "lsun", "metfaces", "cycle_gan", "pro_gan"]

FAKE_DIRS = ["big_gan", "cips", "cycle_gan", "ddpm", "denoising_diffusion_gan", "diffusion_gan", "face_synthetics", 
				 "gansformer", "gau_gan", "generative_inpainting", "glide", "lama", "latent_diffusion", "mat", "palette", 
				 "pro_gan", "projected_gan", "sfhq", "stable_diffusion", "star_gan", "stylegan1", "stylegan2", "stylegan3",
				 "taming_transformer", "vq_diffusion"]

#Column layout of the csv files generated by this notebook
csv_columns_name = ['filename', 'image_path', 'target', 'category']

#os.path.join puts the files inside DATASET_PATH whether or not it ends with a path separator;
#plain string concatenation only did so when the trailing separator was present
REAL_CSV_PATH = os.path.join(DATASET_PATH, "real.csv")
FAKE_CSV_PATH = os.path.join(DATASET_PATH, "fake.csv")

In [None]:
def checkDatasetSync():
	"""
	Checks whether data loss occurred during the download and unzip process.

	Every directory named in REAL_DIRS or FAKE_DIRS is expected to be found in
	DATASET_PATH, and each one that is missing is reported. Entries that are
	present without being expected (csv files, the 'fourier' folder, ...) are
	ignored, since they cannot be the result of a data loss.

	The outcome is only printed; nothing is returned.
	"""
	present = set(DATASET_DIRS())
	# Some directories belong to both groups: dict.fromkeys drops the repetitions and keeps the order
	expected = list(dict.fromkeys(REAL_DIRS + FAKE_DIRS))

	sync = True

	for folder in tqdm(expected):
		# The check goes from the expected names to the listing: iterating over the listing
		# can never reveal a directory that is absent from it
		if folder not in present:
			print("Folder " + folder + " does not exist.")
			sync = False

	if sync: print("Dataset correctly synchronized.")

In [None]:
# Runs the check against the current content of DATASET_PATH; the outcome is printed
checkDatasetSync()

In [None]:
def real_fake_csv_split():
	"""
	Builds the metadata split between real and fake images.

	The work is delegated to 'create_csv()', called once per image source:
	first for the real images, then for the fake ones.
	"""
	for source in ("real", "fake"):
		create_csv(source)

def create_csv(target):
	"""
	Performs a metadata split creating different csv file for both real and fake images.

	Nothing is done when the csv file is already listed in the dataset directory.

	Parameters
	----------
		target (str): 'real' or 'fake', used to identify the image source.
			Any value other than 'real' is handled as 'fake'.
	"""
	csv = target + ".csv"
	dir_group = REAL_DIRS if target == "real" else FAKE_DIRS
	csv_path = REAL_CSV_PATH if target == "real" else FAKE_CSV_PATH

	if csv in DATASET_DIRS():
		print(csv + " already exists.")

		return

	csv_df = pd.DataFrame(columns = csv_columns_name)

	# Images collection process
	# The target is passed explicitly: a directory listed in both REAL_DIRS and FAKE_DIRS
	# must contribute its fake rows when the fake csv is being built
	for directory in tqdm(dir_group, desc="Collecting " + target + " images"):
		csv_df = collect_metadatas(csv_df, directory, 0, target=target)

	# Update filenames: the path (directory included) with '/' turned into '+' avoids duplicates
	csv_df["filename"] = csv_df["image_path"].str.replace("/", "+", regex=False)

	# DataFrame-to-csv conversion process
	df_to_csv(csv_df, csv, csv_path)

def collect_metadatas(df, dir, mode, size=None, target=None): 
	"""
	Collects metadatas from a given directory into a Pandas DataFrame.

	The 'image_path' values read from the metadata file are prefixed with
	'<dir>/' before being collected.

	Parameters
	----------
		df (pd.DataFrame): the DataFrame to collect metadatas into.
		dir (str): the directory where the metadata file is stored.
		mode (int): how the DataFrame is built
			> 0: for real-fake split.
			> 1: for balanced dataset partitioning (contains both fake and real images).
		size (int, optional): number of tuples to sample, required within mode 1. Default: None.
		target (str, optional): used within mode 0. 'real' collects the tuples whose
			'target' column is 0, any other value collects the tuples whose 'target'
			column is not 0. When None, real tuples are collected if 'dir' is listed
			in REAL_DIRS and fake tuples otherwise. Default: None.

	Returns
	-------
		mode 0 > pd.DataFrame: the updated Dataframe.
		mode 1 > (pd.DataFrame, int): the updated Dataframe and the number of sampled tuples.

	Raises
	------
		ValueError: if 'mode' is not a recognised mode.
		TypeError: if 'size' is missing within mode 1.
	"""
	# Arguments are validated before touching the file system
	if mode not in (0, 1):
		raise ValueError("Unkwon mode. Consult function doc for recognised modes.")
	if mode == 1 and size is None:
		raise TypeError("'size' is required within mode 1.")

	# os.path.join keeps the lookup portable instead of hard-coding the Windows separator
	with open(os.path.join(DATASET_PATH, dir, "metadata.csv"), mode='r', newline='') as metadata_file:
		current_csv = pd.read_csv(metadata_file)

	# Paths keep '/' as separator: the callers derive 'filename' by replacing '/'.
	# (The previous Series.replace("/", "\\") call only matched whole values, so it never changed them.)
	current_csv["image_path"] = dir + "/" + current_csv["image_path"]

	if mode == 0:
		if target is None:
			# Legacy selection, ambiguous for directories listed in both groups
			collect_real = dir in REAL_DIRS
		else:
			collect_real = target == "real"

		# In the metadata files a 'target' equal to 0 marks a real image
		if collect_real:
			selected = current_csv[current_csv['target'] == 0]
		else:
			selected = current_csv[current_csv['target'] != 0]

		return pd.concat([df, selected], ignore_index=True)

	# mode 1: the number of sampled images is also returned
	length = len(current_csv)
	if size > length:
		print("Error: sampling size cannot exceed the number of tuples in the dataframe.")
		print(f"Only {length} tuples were sampled.")
		size = length

	df = pd.concat([df, current_csv.sample(size)], ignore_index=True)
	return df, size

def df_to_csv(df, filename, path):
	"""
	Splits the DataFrame in chunks to enable tqdm progress visualization while converting the DataFrame into a '.csv' file.

	The file at 'path' is overwritten. The index is not written.

	Parameters
	----------
		df (pd.DataFrame): the DataFrame to convert.
		filename (str): the desired file name (comprehensive of '.csv' extension), only used in the printed messages.
		path (str): the path where the '.csv' will be stored.
	"""
	# Chunks are built on row positions rather than on index labels: selecting by label
	# would write a row more than once whenever the index contains repeated labels
	chunks = np.array_split(np.arange(len(df)), 100)
	for chunk_number, rows in enumerate(tqdm(chunks, desc="Creating \'" + filename + "\' file")):
		first_chunk = chunk_number == 0
		# Only the first chunk creates the file and writes the header, the others are appended
		df.iloc[rows].to_csv(path, mode='w' if first_chunk else 'a', header=first_chunk, index=False)

	print("\'" + filename + "\' has been successfully created.")

In [None]:
# Creates the csv file of the real images and the one of the fake images (each is skipped if it already exists)
real_fake_csv_split()

In [None]:
def _collect_balanced_samples(df, dirs, total, desc):
	"""
	Samples 'total' tuples spread as evenly as possible over the given directories.

	Parameters
	----------
		df (pd.DataFrame): the DataFrame to collect the sampled tuples into.
		dirs (list of str): names of the directories to sample from.
		total (int): the overall number of tuples to sample.
		desc (str): the description shown by the progress bar.

	Returns
	-------
		(pd.DataFrame, int): the updated DataFrame and the number of tuples that could not be sampled.
	"""
	# 'carry' starts as the remainder of the division, so that the requests add up exactly to 'total'
	quota, carry = divmod(total, len(dirs))

	for directory in tqdm(dirs, desc=desc):
		requested = quota + carry
		# sampled_size is the number of tuples actually sampled from the metadata.csv file of the directory
		df, sampled_size = collect_metadatas(df, directory, 1, requested)
		# Whatever is still missing (carried amount included) is requested from the next directory
		carry = requested - sampled_size

	return df, carry

def create_dataset_partition(size, real_dirs, fake_dirs):
	"""
	Creates a dataset partition of the given size with a 1:1 ratio between Real and Fake images taken from the given directories

	Nothing is done when 'dataset_partition.csv' is already listed in the dataset directory.
	When the real directories cannot provide half of the requested size, the number of
	tuples sampled from the fake directories is reduced accordingly to keep the ratio.

	Parameters
	----------
		size (int): the requested number of tuples in the partition.
		real_dirs (list of str): names of the directories containing real images
		fake_dirs (list of str): names of the directories containing fake images

	Raises
	------
		ValueError: if 'real_dirs' or 'fake_dirs' is empty.
	"""
	# Checks if a dataset partition has already been created
	if "dataset_partition.csv" in DATASET_DIRS():
		print("A Dataset Partition already exists, if you want to create a new one make sure to delete the old one first.")

		return

	if len(real_dirs) == 0 or len(fake_dirs) == 0:
		raise ValueError("Both real_dirs and fake_dirs must contain at least one directory.")

	df = pd.DataFrame(columns = csv_columns_name)

	max_r_size = max_f_size = round(size / 2)

	df, missing = _collect_balanced_samples(df, real_dirs, max_r_size, "Collecting metadatas from Real Directories")

	if missing > 0:
		print("Correct sampling size could not be reached from the given real_dirs")
		print(f"Actual sampled size: {max_r_size - missing}")
		# Fewer fake tuples are requested, to match the number of real tuples actually sampled
		max_f_size = max_f_size - missing

	df, missing = _collect_balanced_samples(df, fake_dirs, max_f_size, "Collecting_metadatas from Fake Directories")

	if missing > 0:
		print("Correct sampling size could not be reached from the given fake_dirs")
		print(f"Actual sampled size: {max_f_size - missing}")

	# Update filenames: the path (directory included) with '/' turned into '+' avoids duplicates
	df["filename"] = df["image_path"].str.replace("/", "+", regex=False)

	df_to_csv(df, "dataset_partition.csv", os.path.join(DATASET_PATH, "dataset_partition.csv"))


In [None]:
# NOTE: this cell is only a convenience, it prints the available directory names

print(REAL_DIRS)
print(FAKE_DIRS)

# "cycle_gan" and "pro_gan" contain both real and fake images

In [None]:
# The directory names are chosen among the ones printed above
real_dirs = ['coco', 'lsun', 'imagenet']
fake_dirs = ['big_gan', 'latent_diffusion', 'taming_transformer']

# Requests a partition of 60000 tuples, with a 1:1 ratio between real and fake images
create_dataset_partition(60000, real_dirs, fake_dirs)