## Loading all CSV files from a data store into MySQL

In [90]:
from pyspark.sql import *
from os import listdir, environ, path

### Function to get all the files to be loaded

In [91]:
def get_files(folder_path:str)->list:
    '''
    Lists every entry of a folder as a "file:///" URI.

    :param folder_path - folder whose entries are listed

    returns list of strings, one per entry returned by listdir, each being
    "file:///" followed by the entry joined onto folder_path
    '''
    uris = []
    for entry in listdir(folder_path):
        full_path = path.join(folder_path, entry)
        uris.append("file:///" + full_path)
    return uris

### Function to read a file into a Spark DataFrame

In [92]:
def get_df(file:str , format:str=None)->DataFrame:
    '''
    Reads one file into a dataframe through the module-level spark session,
    with the reader's "header" option set to "true".

    :param file - path or URI of the file to read
    :param format - data source format passed to the reader (e.g. "csv");
                    when omitted, the text after the last "." of the file's
                    base name is used

    returns dataframe object
    raises ValueError if format is omitted and the file name has no extension
    '''
    if not format:
        # rpartition yields an empty separator when the name has no "." and an
        # empty suffix when it ends with "."; neither gives a usable format.
        _, dot, extension = path.basename(file).rpartition(".")
        if not dot or not extension:
            raise ValueError(
                f"Cannot infer the format of {file!r}: the file name has no "
                "extension; pass format explicitly"
            )
        format = extension

    return spark.read \
            .format(format) \
            .option("path", file) \
            .option("header", "true").load()

### Function to write the DataFrame into a MySQL table

In [93]:
def load_df_mysql(df:DataFrame,
                  host:str,
                  port:int,
                  db:str,
                  table:str,
                  user:str,
                  pwd:str,
                  mode:str="append"
                  )->None:
    '''
    Writes the dataframe to a mysql table over jdbc.

    :param df : dataframe to write
    :param host : mysql server host, used in the jdbc url
    :param port : mysql server port, used in the jdbc url
    :param db : database name, used in the jdbc url
    :param table : mysql target table
    :param user : user name
    :param pwd : password
    :param mode : write method <append, overwrite, ignore>
    '''
    # Collect the connection settings first so the write chain stays short.
    jdbc_options = {
        "url": f"jdbc:mysql://{host}:{port}/{db}",
        "driver": "com.mysql.jdbc.Driver",
        "dbtable": table,
        "user": user,
        "password": pwd,
    }

    writer = df.write.format('jdbc').options(**jdbc_options)
    writer.mode(mode).save()

### Main method

In [94]:
def main():
    '''
    Loads every entry of the module-level files list into its own mysql table.

    Each file is read with get_df and written with load_df_mysql. The target
    table is the file's base name up to its first ".". Connection settings and
    the write mode are read from the module-level variables dbname, host,
    port, user, password and write_mode.
    '''
    for file_uri in files:
        frame = get_df(file_uri)
        table_name = path.basename(file_uri).rsplit(".")[0]
        load_df_mysql(
            df=frame,
            db=dbname,
            table=table_name,
            host=host,
            port=port,
            user=user,
            pwd=password,
            mode=write_mode,
        )

    # The loop has no break, so this runs whenever every file was processed.
    print("Successfully loaded all the tables")

### Create spark session and required variables

In [95]:
# Session used by get_df for every read.
spark = SparkSession.builder.appName("csv_loader").getOrCreate()

# Source files live in the "sql_files" folder under the DATA_LAKE directory;
# DATA_LAKE must be set in the environment (KeyError otherwise).
data_folder = path.join(environ['DATA_LAKE'],"sql_files")
files = get_files(data_folder)

# MySQL target.
host = 'localhost'
port = 3306
dbname = 'spark_tables'

# Credentials can be overridden through the environment so real secrets do not
# have to be stored in the notebook; the defaults keep the previous values.
user = environ.get('MYSQL_USER', 'debezium')
password = environ.get('MYSQL_PASSWORD', 'debezium')

# Save mode passed to the dataframe writer for every table.
write_mode = "overwrite"

#### Run main

In [96]:
if __name__ == '__main__':
    try:
        main()
    finally:
        # Stop the session even when a load fails part-way through, so a
        # failed run does not leave it running.
        spark.stop()

Successfully loaded all the tables
