1 change: 1 addition & 0 deletions airflow/hooks/__init__.py
@@ -21,6 +21,7 @@
'http_hook': ['HttpHook'],
'druid_hook': ['DruidHook'],
'jdbc_hook': ['JdbcHook'],
'dbapi_hook': ['DbApiHook'],
}

_import_module_attrs(globals(), _hooks)
139 changes: 139 additions & 0 deletions airflow/hooks/dbapi_hook.py
@@ -0,0 +1,139 @@
from datetime import datetime
import numpy
import logging

from airflow.hooks.base_hook import BaseHook
from airflow.utils import AirflowException


class DbApiHook(BaseHook):
"""
Abstract base class for sql hooks.
"""

"""
    Override to provide the name of the kwarg that carries the connection id.
"""
conn_name_attr = None

"""
    Override to provide a default connection id for a particular hook.
"""
default_conn_name = 'default_conn_id'

"""
Override if this db supports autocommit.
"""
supports_autocommit = False

    def __init__(self, **kwargs):
        if not self.conn_name_attr:
            raise AirflowException("conn_name_attr is not defined")
        try:
            self.conn_id_name = kwargs[self.conn_name_attr]
        except KeyError:
            raise AirflowException(
                self.conn_name_attr + " was not passed in the kwargs")

def get_pandas_df(self, sql, parameters=None):
'''
Executes the sql and returns a pandas dataframe
'''
import pandas.io.sql as psql
conn = self.get_conn()
df = psql.read_sql(sql, con=conn)
conn.close()
return df

def get_records(self, sql, parameters=None):
'''
Executes the sql and returns a set of records.
'''
conn = self.get_conn()
        cur = conn.cursor()
cur.execute(sql)
rows = cur.fetchall()
cur.close()
conn.close()
return rows

def get_first(self, sql, parameters=None):
'''
        Executes the sql and returns the first resulting row.
'''
conn = self.get_conn()
cur = conn.cursor()
cur.execute(sql)
rows = cur.fetchone()
cur.close()
conn.close()
return rows

def run(self, sql, autocommit=False, parameters=None):
"""
Runs a command
"""
conn = self.get_conn()
if self.supports_autocommit:
conn.autocommit = autocommit
cur = conn.cursor()
cur.execute(sql)
conn.commit()
cur.close()
conn.close()

def get_cursor(self):
"""Returns a cursor"""
return self.get_conn().cursor()

def insert_rows(self, table, rows, target_fields=None, commit_every=1000):
"""
A generic way to insert a set of tuples into a table,
the whole set of inserts is treated as one transaction
"""
if target_fields:
target_fields = ", ".join(target_fields)
target_fields = "({})".format(target_fields)
else:
target_fields = ''
conn = self.get_conn()
cur = conn.cursor()
        if self.supports_autocommit:
            conn.autocommit = False
i = 0
for row in rows:
i += 1
l = []
for cell in row:
if isinstance(cell, basestring):
l.append("'" + str(cell).replace("'", "''") + "'")
elif cell is None:
l.append('NULL')
elif isinstance(cell, numpy.datetime64):
l.append("'" + str(cell) + "'")
elif isinstance(cell, datetime):
l.append("'" + cell.isoformat() + "'")
else:
l.append(str(cell))
values = tuple(l)
sql = "INSERT INTO {0} {1} VALUES ({2});".format(
table,
target_fields,
",".join(values))
cur.execute(sql)
if i % commit_every == 0:
conn.commit()
logging.info(
"Loaded {i} into {table} rows so far".format(**locals()))
conn.commit()
cur.close()
conn.close()
logging.info(
"Done loading. Loaded a total of {i} rows".format(**locals()))

def get_conn(self):
"""
        Returns a sql connection that can be used to retrieve a cursor.
        """
        raise NotImplementedError()
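
Note on the pattern this base class establishes: a concrete hook only declares the connection attributes and implements get_conn(); get_records, get_pandas_df, run and insert_rows are inherited. Below is a minimal, hypothetical sketch of such a subclass (the SqliteExampleHook name, the sqlite_conn_id kwarg and the use of the standard-library sqlite3 driver are illustrative assumptions, not part of this change):

from airflow.hooks.dbapi_hook import DbApiHook


class SqliteExampleHook(DbApiHook):
    """Hypothetical hook showing the minimum a DbApiHook subclass must define."""

    conn_name_attr = 'sqlite_conn_id'      # name of the kwarg carrying the connection id
    default_conn_name = 'sqlite_default'   # fallback Airflow connection id
    supports_autocommit = False            # assume the driver manages transactions itself

    def get_conn(self):
        # Resolve the Airflow Connection and open a DBAPI-compatible connection from it.
        import sqlite3
        conn = self.get_connection(self.conn_id_name)
        return sqlite3.connect(conn.host)

hook = SqliteExampleHook(sqlite_conn_id='sqlite_default')  # kwarg name must match conn_name_attr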
91 changes: 6 additions & 85 deletions airflow/hooks/mysql_hook.py
@@ -1,104 +1,25 @@
from datetime import datetime
import numpy
import logging

import MySQLdb

from airflow.hooks.base_hook import BaseHook
from airflow.hooks.dbapi_hook import DbApiHook


class MySqlHook(BaseHook):
class MySqlHook(DbApiHook):
'''
Interact with MySQL.
'''

def __init__(
self, mysql_conn_id='mysql_default'):
self.mysql_conn_id = mysql_conn_id
conn_name_attr = 'mysql_conn_id'
default_conn_name = 'mysql_default'
supports_autocommit = True

def get_conn(self):
"""
Returns a mysql connection object
"""
conn = self.get_connection(self.mysql_conn_id)
conn = self.get_connection(self.conn_id_name)
conn = MySQLdb.connect(
conn.host,
conn.login,
conn.password,
conn.schema)
return conn

def get_records(self, sql):
'''
Executes the sql and returns a set of records.
'''
conn = self.get_conn()
cur = conn.cursor()
cur.execute(sql)
rows = cur.fetchall()
cur.close()
conn.close()
return rows

def get_pandas_df(self, sql):
'''
Executes the sql and returns a pandas dataframe
'''
import pandas.io.sql as psql
conn = self.get_conn()
df = psql.read_sql(sql, con=conn)
conn.close()
return df

def run(self, sql):
conn = self.get_conn()
cur = conn.cursor()
cur.execute(sql)
conn.commit()
cur.close()
conn.close()

def insert_rows(self, table, rows, target_fields=None, commit_every=1000):
"""
A generic way to insert a set of tuples into a table,
the whole set of inserts is treated as one transaction
"""
if target_fields:
target_fields = ", ".join(target_fields)
target_fields = "({})".format(target_fields)
else:
target_fields = ''
conn = self.get_conn()
cur = conn.cursor()
cur.execute('SET autocommit = 0')
conn.commit()
i = 0
for row in rows:
i += 1
l = []
for cell in row:
if isinstance(cell, basestring):
l.append("'" + str(cell).replace("'", "''") + "'")
elif cell is None:
l.append('NULL')
elif isinstance(cell, numpy.datetime64):
l.append("'" + str(cell) + "'")
elif isinstance(cell, datetime):
l.append("'" + cell.isoformat() + "'")
else:
l.append(str(cell))
values = tuple(l)
sql = "INSERT INTO {0} {1} VALUES ({2});".format(
table,
target_fields,
",".join(values))
cur.execute(sql)
if i % commit_every == 0:
conn.commit()
logging.info(
"Loaded {i} into {table} rows so far".format(**locals()))
conn.commit()
cur.close()
conn.close()
logging.info(
"Done loading. Loaded a total of {i} rows".format(**locals()))
79 changes: 12 additions & 67 deletions airflow/hooks/postgres_hook.py
@@ -1,77 +1,22 @@
import psycopg2

from airflow import settings
from airflow.utils import AirflowException
from airflow.models import Connection
from airflow.hooks.dbapi_hook import DbApiHook


class PostgresHook(object):
class PostgresHook(DbApiHook):
'''
Interact with Postgres.
'''

def __init__(
self, host=None, login=None,
psw=None, db=None, port=None, postgres_conn_id=None):
if not postgres_conn_id:
self.host = host
self.login = login
self.psw = psw
self.db = db
self.port = port
else:
session = settings.Session()
db = session.query(
Connection).filter(
Connection.conn_id == postgres_conn_id)
if db.count() == 0:
raise AirflowException("The postgres_dbid you provided isn't defined")
else:
db = db.all()[0]
self.host = db.host
self.login = db.login
self.psw = db.password
self.db = db.schema
self.port = db.port
session.commit()
session.close()
conn_name_attr = 'postgres_conn_id'
default_conn_name = 'postgres_default'
supports_autocommit = True

def get_conn(self):
conn = psycopg2.connect(
host=self.host,
user=self.login,
password=self.psw,
dbname=self.db,
port=self.port)
return conn

def get_records(self, sql):
'''
Executes the sql and returns a set of records.
'''
conn = self.get_conn()
cur = conn.cursor()
cur.execute(sql)
rows = cur.fetchall()
cur.close()
conn.close()
return rows

def get_pandas_df(self, sql):
'''
Executes the sql and returns a pandas dataframe
'''
import pandas.io.sql as psql
conn = self.get_conn()
df = psql.read_sql(sql, con=conn)
conn.close()
return df

def run(self, sql, autocommit=False):
conn = self.get_conn()
conn.autocommit = autocommit
cur = conn.cursor()
cur.execute(sql)
conn.commit()
cur.close()
conn.close()
conn = self.get_connection(self.conn_id_name)
return psycopg2.connect(
host=conn.host,
user=conn.login,
            password=conn.password,
            dbname=conn.schema,
port=conn.port)
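
Usage note: PostgresHook follows the same pattern, and because it declares supports_autocommit = True, the inherited run() can enable autocommit on the psycopg2 connection, which is needed for statements like VACUUM that cannot run inside a transaction block. A minimal sketch, assuming a configured postgres_default connection (the users table is made up):

hook = PostgresHook(postgres_conn_id='postgres_default')

# VACUUM must run outside a transaction, so autocommit is enabled for this call.
hook.run("VACUUM users", autocommit=True)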