# <center>MySQL to Cloud Spanner Migration (or Bulk Load)</center>

In [None]:
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#### References

- [DataprocPySparkBatchOp reference](https://google-cloud-pipeline-components.readthedocs.io/en/google-cloud-pipeline-components-1.0.0/google_cloud_pipeline_components.experimental.dataproc.html)
- [Kubeflow SDK Overview](https://www.kubeflow.org/docs/components/pipelines/sdk/sdk-overview/)
- [Dataproc Serverless in Vertex AI Pipelines tutorial](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/ml_ops/stage3/get_started_with_dataproc_serverless_pipeline_components.ipynb)
- [Build a Vertex AI Pipeline](https://cloud.google.com/vertex-ai/docs/pipelines/build-pipeline)

This notebook is built to run in a Vertex AI User-Managed Notebook using the default Compute Engine Service Account.  
Check the Dataproc Serverless in Vertex AI Pipelines tutorial linked above to learn how to set up a different Service Account.  

#### Permissions

Make sure that the service account used to run the notebook has the following roles:

- roles/aiplatform.serviceAgent
- roles/aiplatform.customCodeServiceAgent
- roles/storage.objectCreator
- roles/storage.objectViewer
- roles/dataproc.editor
- roles/dataproc.worker

## Step 1: Install Libraries
#### Run Step 1 one time for each new notebook instance

In [None]:
# MySQL client libraries used by this notebook to read table metadata from the source database.
!pip3 install pymysql SQLAlchemy
# Pipeline SDKs (Kubeflow Pipelines + Google Cloud pipeline components); upgraded quietly into the user site.
!pip3 install --upgrade google-cloud-pipeline-components kfp --user -q

In [None]:
# Refresh the apt package index, then install a JDK and Maven.
# Both are needed for the `mvn` build of the template JAR in Step 8.
!sudo apt-get update -y
!sudo apt-get install default-jdk -y
!sudo apt-get install maven -y

#### Once you've installed the additional packages, you may need to restart the notebook kernel so it can find the packages.

Uncomment & Run this cell if you have installed anything from above commands

In [None]:
# import os
# import IPython
# if not os.getenv("IS_TESTING"):
#     app = IPython.Application.instance()
#     app.kernel.do_shutdown(True)

## Step 2: Import Libraries

In [None]:
import sqlalchemy
import math
import pymysql
import google.cloud.aiplatform as aiplatform
from kfp import dsl
from kfp.v2 import compiler
from datetime import datetime
import time
import copy
import json
import pandas as pd
from google_cloud_pipeline_components.experimental.dataproc import DataprocSparkBatchOp
from pathlib import Path
import os

## Step 3: Assign Parameters

### Step 3.1 Common Parameters
 
- PROJECT : GCP project-id
- REGION : GCP region (us-central1)
- GCS_STAGING_LOCATION : GCS staging location to be used for this notebook to store artifacts 
- SUBNET : VPC subnet
- JARS : list of jars. For this notebook the MySQL connector and Avro jars are required in addition to the Dataproc template jars
- MAX_PARALLELISM : maximum number of tables migrated in parallel per pipeline run (set to 5 in the cell below)

In [None]:
# User-supplied settings shared by every step of the notebook.
PROJECT = ""
REGION = "" # eg: us-central1 (any valid GCP region)
# No trailing slash: other paths are built as GCS_STAGING_LOCATION + "/..."
GCS_STAGING_LOCATION = "" # eg: gs://my-staging-bucket/sub-folder
# Replace {project}, {region} and {subnet} by hand; this string is passed to the batch as-is.
SUBNET = "projects/{project}/regions/{region}/subnetworks/{subnet}"
MAX_PARALLELISM = 5 # maximum number of tables migrated in parallel (size of each batch of tables)

# Do not change this parameter unless you want to reference the JARS below from a new location
JARS = [GCS_STAGING_LOCATION + "/jars/mysql-connector-java-8.0.29.jar","file:///usr/lib/spark/external/spark-avro.jar"]

### Step 3.2 MYSQL to Spanner Parameters
- MYSQL_HOST : MYSQL instance ip address
- MYSQL_PORT : MySQL instance port
- MYSQL_USERNAME : MYSQL username
- MYSQL_PASSWORD : MYSQL password
- MYSQL_DATABASE : name of database that you want to migrate
- MYSQLTABLE_LIST : list of tables you want to migrate, e.g. ['table1','table2']; provide an empty list ([]) to migrate the whole database
- MYSQL_OUTPUT_SPANNER_MODE : output mode for MYSQL data one of (overwrite|append). Use append if schema already exists in Spanner
- SPANNER_INSTANCE : cloud spanner instance name
- SPANNER_DATABASE : cloud spanner database name

Spanner requires primary key for each table
- SPANNER_TABLE_PRIMARY_KEYS : provide dictionary of format {"table_name":"primary_key"} for tables which do not have primary key in MYSQL

In [None]:
# Source (MySQL) and target (Cloud Spanner) settings supplied by the user.
# NOTE(review): avoid saving a real password in the notebook; consider reading it
# with getpass or from an environment variable instead.
MYSQL_HOST = ""
MYSQL_PORT = "3306"
MYSQL_USERNAME = ""
MYSQL_PASSWORD = ""
MYSQL_DATABASE = ""
MYSQLTABLE_LIST = [] # leave list empty for migrating complete database else provide tables as ['table1','table2']
MYSQL_OUTPUT_SPANNER_MODE = "overwrite" # one of overwrite|append (Use append when schema already exists in Spanner)
SPANNER_INSTANCE = ""
SPANNER_DATABASE = ""
# Keys for multi-column primary keys are comma separated (Step 6 splits this value on ",").
SPANNER_TABLE_PRIMARY_KEYS = {} # provide table & pk column for tables which do not have a PK in MYSQL {"table_name":"primary_key"}

### Step 3.3 Notebook Configuration Parameters
The variables below should not be changed unless required

In [None]:
# Locate the `java` folder of the cloned dataproc-templates repo. Step 8 changes
# into this directory to run the Maven build and then copies target/<jar> and
# src/test/resources/<log4j properties> from it, so it must be the Java project.
# NOTE(review): assumes the notebook sits two directory levels below the repo root.
cur_path = Path(os.getcwd())
WORKING_DIRECTORY = os.path.join(cur_path.parent.parent, 'java')

# If the above code doesn't resolve the correct path, please
# provide the complete path to the java folder in the
# dataproc-templates repo which you cloned

# WORKING_DIRECTORY = "/home/jupyter/dataproc-templates/java/"
print(WORKING_DIRECTORY)

In [None]:
# Derived settings: driver names, the JDBC URL handed to the Spark template, and the
# names/locations of the artifacts built and uploaded in Step 8.
PYMYSQL_DRIVER = "mysql+pymysql"
JDBC_DRIVER = "com.mysql.cj.jdbc.Driver"
# NOTE(review): the username and password are embedded in this URL, which is later
# passed as a Spark argument to each batch.
JDBC_URL = "jdbc:mysql://{}:{}/{}?user={}&password={}".format(MYSQL_HOST,MYSQL_PORT,MYSQL_DATABASE,MYSQL_USERNAME,MYSQL_PASSWORD)
MAIN_CLASS = "com.google.cloud.dataproc.templates.main.DataProcTemplate"
JAR_FILE = "dataproc-templates-1.0-SNAPSHOT.jar"
GRPC_JAR_PATH = "./grpc_lb/io/grpc/grpc-grpclb/1.40.1"
GRPC_JAR = "grpc-grpclb-1.40.1.jar"
LOG4J_PROPERTIES_PATH = "./src/test/resources"
LOG4J_PROPERTIES = "log4j-spark-driver-template.properties"
PIPELINE_ROOT = GCS_STAGING_LOCATION + "/pipeline_root/dataproc_pyspark"

# adding dataproc template JAR and grpc jar; entries already present are skipped so
# that re-running this cell does not append duplicate JAR URIs
for extra_jar in (GCS_STAGING_LOCATION + "/" + GRPC_JAR, GCS_STAGING_LOCATION + "/" + JAR_FILE):
    if extra_jar not in JARS:
        JARS.append(extra_jar)

## Step 4: Generate MySQL Table List
This step creates list of tables for migration. If MYSQLTABLE_LIST is kept empty all the tables in the MYSQL_DATABASE are listed for migration otherwise the provided list is used

In [None]:
# When no tables were listed explicitly, discover every table in MYSQL_DATABASE.
if not MYSQLTABLE_LIST:
    DB = sqlalchemy.create_engine(
            sqlalchemy.engine.url.URL.create(
                drivername=PYMYSQL_DRIVER,
                username=MYSQL_USERNAME,
                password=MYSQL_PASSWORD,
                database=MYSQL_DATABASE,
                host=MYSQL_HOST,
                port=MYSQL_PORT
              )
            )
    with DB.connect() as conn:
        print("connected to database")
        # Run the statement on the opened connection (not the engine) and wrap it in
        # text(): Engine.execute() and raw-string SQL were removed in SQLAlchemy 2.0.
        results = conn.execute(sqlalchemy.text('show tables;')).fetchall()
        print("Total Tables = ", len(results))
        # The first column of each row holds the table name.
        MYSQLTABLE_LIST.extend(row[0] for row in results)

print("list of tables for migration :")
print(MYSQLTABLE_LIST)

## Step 5: Get Primary Keys for tables not present in SPANNER_TABLE_PRIMARY_KEYS
For tables which do not have primary key provided in dictionary SPANNER_TABLE_PRIMARY_KEYS this step fetches primary key from MYSQL_DATABASE

In [None]:
# Engine for the source database; later cells (Step 6) reuse this `DB` object.
DB = sqlalchemy.create_engine(
            sqlalchemy.engine.url.URL.create(
                drivername=PYMYSQL_DRIVER,
                username=MYSQL_USERNAME,
                password=MYSQL_PASSWORD,
                database=MYSQL_DATABASE,
                host=MYSQL_HOST,
                port=MYSQL_PORT
              )
            )
with DB.connect() as conn:
    for table in MYSQLTABLE_LIST:
        # Keys supplied by the user in SPANNER_TABLE_PRIMARY_KEYS take precedence.
        if table in SPANNER_TABLE_PRIMARY_KEYS:
            continue
        # Execute on the opened connection with text(): Engine.execute() and raw-string
        # SQL were removed in SQLAlchemy 2.0.
        results = conn.execute(
            sqlalchemy.text("SHOW KEYS FROM {} WHERE Key_name = 'PRIMARY'".format(table))
        ).fetchall()
        # Index 4 of each SHOW KEYS row is read as the key's column name.
        # Joining an empty list yields "", which marks "no primary key found".
        SPANNER_TABLE_PRIMARY_KEYS[table] = ",".join(row[4] for row in results)

In [None]:
# Look up each table's key by name rather than relying on dict ordering: when
# SPANNER_TABLE_PRIMARY_KEYS was pre-filled by the user (in a different order, or with
# extra tables) list(dict.values()) would be misaligned with MYSQLTABLE_LIST or have a
# different length.
pkDF = pd.DataFrame({
    "table": MYSQLTABLE_LIST,
    "primary_keys": [SPANNER_TABLE_PRIMARY_KEYS.get(table_name, "") for table_name in MYSQLTABLE_LIST],
})
print("Below are identified primary keys for migrating mysql table to spanner:")
pkDF

## Step 6 Get Row Count of Tables and identify read partition column
This step uses the PARTITION_THRESHOLD parameter (set to 200,000 in the cell below); any table with more rows than PARTITION_THRESHOLD will be read in partitions based on its Primary Key

Get Primary keys for all tables to be migrated and find an integer column to partition on

In [None]:
# Tables with more rows than this are read in partitions; the same value is used as the
# target key range per partition when the number of partitions is computed.
PARTITION_THRESHOLD = 200000 #Number of rows fetched per spark executor
# Filled in by the next cell: table name -> [partition column, min value, max value, number of partitions]
CHECK_PARTITION_COLUMN_LIST={}

In [None]:
# For every table whose estimated row count exceeds PARTITION_THRESHOLD and whose
# primary key is a single integer column, record [column, min, max, numPartitions] in
# CHECK_PARTITION_COLUMN_LIST. Tables that do not qualify are read without partitioning.
PARTITIONABLE_TYPES = {"int", "bigint", "mediumint"}

with DB.connect() as conn:
    for table in MYSQLTABLE_LIST:
        # Statements run on the opened connection with text() and bound parameters:
        # Engine.execute() and raw-string SQL were removed in SQLAlchemy 2.0.
        row_count_rows = conn.execute(
            sqlalchemy.text(
                "SELECT TABLE_ROWS FROM INFORMATION_SCHEMA.TABLES "
                "WHERE table_schema = :schema AND TABLE_NAME = :table"
            ),
            {"schema": MYSQL_DATABASE, "table": table},
        ).fetchall()
        # No metadata row, or a NULL TABLE_ROWS, cannot be compared with the threshold.
        if not row_count_rows or row_count_rows[0][0] is None:
            continue
        if row_count_rows[0][0] <= int(PARTITION_THRESHOLD):
            continue

        primary_key = SPANNER_TABLE_PRIMARY_KEYS.get(table) or ""
        column_list = [name.strip() for name in primary_key.split(",") if name.strip()]
        # Only a single-column key can serve as the partition column; this also skips
        # tables for which no primary key was found (empty string).
        if len(column_list) != 1:
            continue
        column = column_list[0]

        # Filter on the schema as well, so a table with the same name in another
        # database cannot supply the data type.
        datatype_rows = conn.execute(
            sqlalchemy.text(
                "SELECT DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS "
                "WHERE TABLE_SCHEMA = :schema AND TABLE_NAME = :table AND COLUMN_NAME = :column"
            ),
            {"schema": MYSQL_DATABASE, "table": table, "column": column},
        ).fetchall()
        if not datatype_rows or datatype_rows[0][0] not in PARTITIONABLE_TYPES:
            continue

        bounds = conn.execute(
            sqlalchemy.text("SELECT min({0}), max({0}) from {1}".format(column, table))
        ).fetchall()
        lowerbound, upperbound = bounds[0][0], bounds[0][1]
        # min/max are NULL when the table holds no rows.
        if lowerbound is None or upperbound is None:
            continue
        # Always keep at least one partition, even when min == max.
        numberPartitions = max(1, math.ceil((upperbound - lowerbound) / PARTITION_THRESHOLD))
        CHECK_PARTITION_COLUMN_LIST[table] = [column, lowerbound, upperbound, numberPartitions]


prtDF = pd.DataFrame.from_dict(CHECK_PARTITION_COLUMN_LIST, orient='index', columns=['PK', 'Min Val', 'Max Val', 'Num Partitions'])
print("Below are identified partitioning scheme, which will be used to read tables exceeding PARTITION_THRESHOLD of {}".format(PARTITION_THRESHOLD))
prtDF

## Step 7: Calculate Parallel Jobs for MySQL to Cloud Spanner
This step uses MAX_PARALLELISM parameter to calculate number of parallel jobs to run

In [None]:
# calculate parallel jobs:
COMPLETE_LIST = copy.deepcopy(MYSQLTABLE_LIST)
PARALLEL_JOBS = len(MYSQLTABLE_LIST)//MAX_PARALLELISM
JOB_LIST = []
while len(COMPLETE_LIST) > 0:
    SUB_LIST = []
    for i in range(MAX_PARALLELISM):
        if len(COMPLETE_LIST)>0 :
            SUB_LIST.append(COMPLETE_LIST[0])
            COMPLETE_LIST.pop(0)
        else:
            break
    JOB_LIST.append(SUB_LIST)
print("list of tables for execution : ")
print(JOB_LIST)

## Step 8: Create JAR files and Upload to GCS
#### Run Step 8 one time for each new notebook instance

In [None]:
# Switch the notebook's working directory so the relative paths used by the build and
# copy commands below resolve inside WORKING_DIRECTORY.
%cd $WORKING_DIRECTORY

#### Downloading the MySQL connector, executing the Maven build and fetching the gRPC JAR

In [None]:
# Download and unpack the MySQL JDBC connector (its JAR is uploaded in the next cell).
!wget https://downloads.mysql.com/archives/get/p/3/file/mysql-connector-java-8.0.29.tar.gz
!tar -xf mysql-connector-java-8.0.29.tar.gz
# Build the template JAR, skipping tests.
!mvn clean spotless:apply install -DskipTests 
# Fetch the grpc-grpclb artifact into ./grpc_lb, which is where GRPC_JAR_PATH points.
!mvn dependency:get -Dartifact=io.grpc:grpc-grpclb:1.40.1 -Dmaven.repo.local=./grpc_lb 

#### copying JARS files to GCS_STAGING_LOCATION

In [None]:
# Upload the built template JAR, the gRPC JAR, the log4j properties file and the MySQL
# connector to the staging bucket, at the locations referenced by JARS and the pipeline.
!gsutil cp target/$JAR_FILE $GCS_STAGING_LOCATION/$JAR_FILE
!gsutil cp $GRPC_JAR_PATH/$GRPC_JAR $GCS_STAGING_LOCATION/$GRPC_JAR
!gsutil cp $LOG4J_PROPERTIES_PATH/$LOG4J_PROPERTIES $GCS_STAGING_LOCATION/$LOG4J_PROPERTIES
!gsutil cp mysql-connector-java-8.0.29/mysql-connector-java-8.0.29.jar $GCS_STAGING_LOCATION/jars/mysql-connector-java-8.0.29.jar

## Step 9: Execute Pipeline to Migrate tables from MySQL to Spanner

In [None]:
# Batch ids of the submitted Dataproc jobs, appended by migrate_mysql_to_spanner and
# read back in Step 10 to query each job's status.
mysql_to_spanner_jobs = []

In [None]:
def migrate_mysql_to_spanner(EXECUTION_LIST):
    """Compile and run one Vertex AI pipeline that migrates a batch of tables.

    One DataprocSparkBatchOp running the JDBCTOSPANNER template is added to the
    pipeline for every table in the batch. Each generated batch id is appended to the
    module-level ``mysql_to_spanner_jobs`` list so its status can be looked up later.

    Args:
        EXECUTION_LIST: list of MySQL table names to migrate in this pipeline run.

    Side effects:
        Initialises the aiplatform SDK, writes ``pipeline.json`` to the current
        directory, appends to ``mysql_to_spanner_jobs`` and calls ``run()`` on the
        created PipelineJob.
    """
    aiplatform.init(project=PROJECT,staging_bucket=GCS_STAGING_LOCATION)

    @dsl.pipeline(
        name="java-mysql-to-spanner-spark",
        description="Pipeline to get data from mysql to spanner",
    )
    def pipeline(
        PROJECT_ID: str = PROJECT,
        LOCATION: str = REGION,
        MAIN_CLASS: str = MAIN_CLASS,
        JAR_FILE_URIS: list = JARS,
        SUBNETWORK_URI: str = SUBNET,
        FILE_URIS: list = [GCS_STAGING_LOCATION + "/" + LOG4J_PROPERTIES]
    ):
        for table in EXECUTION_LIST:
            # Epoch seconds via timestamp(): strftime("%s") is a platform-specific
            # extension and is not available everywhere.
            epoch_seconds = int(datetime.now().timestamp())
            BATCH_ID = "mysql2spanner-{}-{}".format(table, epoch_seconds).replace('_','-').lower()
            mysql_to_spanner_jobs.append(BATCH_ID)

            # Arguments shared by every table.
            TEMPLATE_SPARK_ARGS = [
                "--template=JDBCTOSPANNER",
                "--templateProperty", "project.id={}".format(PROJECT),
                "--templateProperty", "jdbctospanner.jdbc.url={}".format(JDBC_URL),
                "--templateProperty", "jdbctospanner.jdbc.driver.class.name={}".format(JDBC_DRIVER),
                "--templateProperty", "jdbctospanner.sql=select * from {}".format(table),
                "--templateProperty", "jdbctospanner.output.instance={}".format(SPANNER_INSTANCE),
                "--templateProperty", "jdbctospanner.output.database={}".format(SPANNER_DATABASE),
                "--templateProperty", "jdbctospanner.output.table={}".format(table),
                "--templateProperty", "jdbctospanner.output.saveMode={}".format(MYSQL_OUTPUT_SPANNER_MODE.capitalize()),
                "--templateProperty", "jdbctospanner.output.primaryKey={}".format(SPANNER_TABLE_PRIMARY_KEYS[table]),
                "--templateProperty", "jdbctospanner.output.batchInsertSize=200",
            ]
            # Tables selected for partitioned reads additionally get the partition
            # column, its bounds and the partition count.
            if table in CHECK_PARTITION_COLUMN_LIST:
                partition_column, lower_bound, upper_bound, num_partitions = CHECK_PARTITION_COLUMN_LIST[table]
                TEMPLATE_SPARK_ARGS.extend([
                    "--templateProperty", "jdbctospanner.sql.partitionColumn={}".format(partition_column),
                    "--templateProperty", "jdbctospanner.sql.lowerBound={}".format(lower_bound),
                    "--templateProperty", "jdbctospanner.sql.upperBound={}".format(upper_bound),
                    "--templateProperty", "jdbctospanner.sql.numPartitions={}".format(num_partitions),
                ])

            _ = DataprocSparkBatchOp(
                project=PROJECT_ID,
                location=LOCATION,
                batch_id=BATCH_ID,
                main_class=MAIN_CLASS,
                jar_file_uris=JAR_FILE_URIS,
                file_uris=FILE_URIS,
                subnetwork_uri=SUBNETWORK_URI,
                runtime_config_version="1.1", # issue 665
                args=TEMPLATE_SPARK_ARGS
            )
            # NOTE(review): presumably spaces out the timestamp suffix of consecutive
            # batch ids; kept from the original implementation.
            time.sleep(1)

    compiler.Compiler().compile(pipeline_func=pipeline, package_path="pipeline.json")

    # Separate name for the job so the pipeline function above is not shadowed.
    pipeline_job = aiplatform.PipelineJob(
        display_name="pipeline",
        template_path="pipeline.json",
        pipeline_root=PIPELINE_ROOT,
        enable_caching=False,
    )
    pipeline_job.run()

In [None]:
# Start from an empty job list so re-running this cell does not leave batch ids from a
# previous run behind; Step 10 pairs mysql_to_spanner_jobs one-to-one with MYSQLTABLE_LIST.
mysql_to_spanner_jobs.clear()
# Batches are submitted one after another.
for execution_list in JOB_LIST:
    print(execution_list)
    migrate_mysql_to_spanner(execution_list)

## Step 10: Get status for tables migrated from MySql to Spanner

In [None]:
def get_bearer_token():
    """Obtain an OAuth2 access token from the application default credentials.

    Returns:
        A ``(value, status)`` tuple on every path: ``(token, 200)`` on success,
        otherwise ``(error_message, 500)``.
    """
    try:
        # Imported locally so the function does not depend on cell execution order or
        # on the module-level name `requests`, which a later cell rebinds to the HTTP
        # `requests` library.
        import google.auth
        from google.auth.transport.requests import Request as AuthRequest

        #Defining Scope
        CREDENTIAL_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]

        #Assigning credentials (the detected project id is not needed here)
        credentials, _ = google.auth.default(scopes=CREDENTIAL_SCOPES)

        #Refreshing credentials data
        credentials.refresh(AuthRequest())

        #Get refreshed token
        token = credentials.token
        if token:
            return (token, 200)
        # Keep the (message, status) shape so callers can always index the result.
        return ("Bearer token not generated", 500)
    except Exception as error:
        return ("Bearer token not generated. Error : {}".format(error), 500)

In [None]:
# `requests` here is google-auth's transport module (not the HTTP `requests` library).
from google.auth.transport import requests
import google
# get_bearer_token() returns a tuple; index 1 is compared with 200 to detect success.
token = get_bearer_token()
if token[1] == 200:
    print("Bearer token generated")
else:
    # Show the returned value so the failure reason is visible.
    print(token)

In [None]:
# Aliased so the module-level name `requests` keeps pointing at
# google.auth.transport.requests, which the token cell above imports under that name.
import requests as http_requests

# Query the Dataproc batches endpoint for the state of every submitted job.
mysql_to_spanner_status = []
job_status_url = "https://dataproc.googleapis.com/v1/projects/{}/locations/{}/batches/{}"
# The headers are the same for every request, so build them once.
headers = {
  'Content-Type': 'application/json; charset=UTF-8',
  'Authorization': "Bearer " + token[0]
}
for job in mysql_to_spanner_jobs:
    url = job_status_url.format(PROJECT,REGION,job)
    # A timeout keeps a stalled connection from hanging the notebook indefinitely.
    response = http_requests.get(url, headers=headers, timeout=60)
    payload = response.json()
    # A response without 'state' (e.g. an error payload) is recorded with its HTTP
    # status instead of raising KeyError, so this list stays aligned with the job list.
    mysql_to_spanner_status.append(payload.get('state', "HTTP {}".format(response.status_code)))

In [None]:
# One row per table: its Dataproc batch id and the state reported for that batch.
status_columns = {
    "table": MYSQLTABLE_LIST,
    "mysql_to_spanner_job": mysql_to_spanner_jobs,
    "mysql_to_spanner_status": mysql_to_spanner_status,
}
statusDF = pd.DataFrame(status_columns)
statusDF

## Step 11: Validate row counts of migrated tables from MySQL to Cloud Spanner

In [None]:
# get mysql table counts
mysql_row_count = []
DB = sqlalchemy.create_engine(
            sqlalchemy.engine.url.URL.create(
                drivername=PYMYSQL_DRIVER,
                username=MYSQL_USERNAME,
                password=MYSQL_PASSWORD,
                database=MYSQL_DATABASE,
                host=MYSQL_HOST,
                port=MYSQL_PORT
              )
            )
with DB.connect() as conn:
    for table in MYSQLTABLE_LIST:
        # Execute on the opened connection with text(): Engine.execute() and raw-string
        # SQL were removed in SQLAlchemy 2.0. count(1) yields exactly one row/column,
        # which scalar() returns directly.
        count_stmt = sqlalchemy.text("select count(1) from {}".format(table))
        mysql_row_count.append(conn.execute(count_stmt).scalar())

In [None]:
# get spanner table counts
spanner_row_count = []
from google.cloud import spanner

spanner_client = spanner.Client()
instance = spanner_client.instance(SPANNER_INSTANCE)
database = instance.database(SPANNER_DATABASE)

for table in MYSQLTABLE_LIST:
    with database.snapshot() as snapshot:
        qry = "@{{USE_ADDITIONAL_PARALLELISM=true}} select count(1) from {}".format(table)
        results = snapshot.execute_sql(qry)
        for row in results:
            spanner_row_count.append(row[0])

In [None]:
# Put the source and target row counts side by side with each job's status.
statusDF = statusDF.assign(
    mysql_row_count=mysql_row_count,
    spanner_row_count=spanner_row_count,
)
statusDF

## Post data loading activities
- You may create relationships (FKs), constraints and indexes (as needed).
- You may configure continuous replication with [DataStream](https://cloud.google.com/datastream/docs/configure-your-source-mysql-database) or any other 3rd party tools.