Skip to content

Commit

Permalink
Add SMTP timeout and retry limit for SMTP email backend. (#12801)
Browse files Browse the repository at this point in the history
  • Loading branch information
siddartha-ravichandran authored Dec 4, 2020
1 parent 1bd98cd commit 88aa174
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 28 deletions.
12 changes: 12 additions & 0 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,18 @@
type: string
example: ~
default: "airflow@example.com"
- name: smtp_timeout
description: ~
version_added: ~
type: int
example: ~
default: "30"
- name: smtp_retry_limit
description: ~
version_added: ~
type: int
example: ~
default: "5"
- name: sentry
description: |
Sentry (https://docs.sentry.io) integration. Here you can supply
Expand Down
2 changes: 2 additions & 0 deletions airflow/config_templates/default_airflow.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,8 @@ smtp_ssl = False
# smtp_password =
smtp_port = 25
smtp_mail_from = airflow@example.com
smtp_timeout = 30
smtp_retry_limit = 5

[sentry]

Expand Down
2 changes: 2 additions & 0 deletions airflow/config_templates/default_test.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ smtp_user = airflow
smtp_port = 25
smtp_password = airflow
smtp_mail_from = airflow@example.com
smtp_retry_limit = 5
smtp_timeout = 30

[celery]
celery_app_name = airflow.executors.celery_executor
Expand Down
35 changes: 27 additions & 8 deletions airflow/utils/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def send_mime_email(e_from: str, e_to: List[str], mime_msg: MIMEMultipart, dryru
smtp_port = conf.getint('smtp', 'SMTP_PORT')
smtp_starttls = conf.getboolean('smtp', 'SMTP_STARTTLS')
smtp_ssl = conf.getboolean('smtp', 'SMTP_SSL')
smtp_retry_limit = conf.getint('smtp', 'SMTP_RETRY_LIMIT')
smtp_timeout = conf.getint('smtp', 'SMTP_TIMEOUT')
smtp_user = None
smtp_password = None

Expand All @@ -178,14 +180,23 @@ def send_mime_email(e_from: str, e_to: List[str], mime_msg: MIMEMultipart, dryru
log.debug("No user/password found for SMTP, so logging in with no authentication.")

if not dryrun:
conn = smtplib.SMTP_SSL(smtp_host, smtp_port) if smtp_ssl else smtplib.SMTP(smtp_host, smtp_port)
if smtp_starttls:
conn.starttls()
if smtp_user and smtp_password:
conn.login(smtp_user, smtp_password)
log.info("Sent an alert email to %s", e_to)
conn.sendmail(e_from, e_to, mime_msg.as_string())
conn.quit()
for attempt in range(1, smtp_retry_limit + 1):
log.info("Email alerting: attempt %s", str(attempt))
try:
conn = _get_smtp_connection(smtp_host, smtp_port, smtp_timeout, smtp_ssl)
except smtplib.SMTPServerDisconnected:
if attempt < smtp_retry_limit:
continue
raise

if smtp_starttls:
conn.starttls()
if smtp_user and smtp_password:
conn.login(smtp_user, smtp_password)
log.info("Sent an alert email to %s", e_to)
conn.sendmail(e_from, e_to, mime_msg.as_string())
conn.quit()
break


def get_email_address_list(addresses: Union[str, Iterable[str]]) -> List[str]:
Expand All @@ -202,6 +213,14 @@ def get_email_address_list(addresses: Union[str, Iterable[str]]) -> List[str]:
raise TypeError(f"Unexpected argument type: Received '{received_type}'.")


def _get_smtp_connection(host: str, port: int, timeout: int, with_ssl: bool) -> smtplib.SMTP:
return (
smtplib.SMTP_SSL(host=host, port=port, timeout=timeout)
if with_ssl
else smtplib.SMTP(host=host, port=port, timeout=timeout)
)


def _get_email_list_from_str(addresses: str) -> List[str]:
delimiters = [",", ";"]
for delimiter in delimiters:
Expand Down
127 changes: 107 additions & 20 deletions tests/utils/test_email.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from email.mime.application import MIMEApplication
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from smtplib import SMTPServerDisconnected
from unittest import mock

from airflow import utils
Expand Down Expand Up @@ -118,7 +119,6 @@ def test_build_mime_message(self):
self.assertEqual(msg['To'], ','.join(recipients))


@conf_vars({('smtp', 'SMTP_SSL'): 'False'})
class TestEmailSmtp(unittest.TestCase):
@mock.patch('airflow.utils.email.send_mime_email')
def test_send_smtp(self, mock_send_mime):
Expand All @@ -127,10 +127,10 @@ def test_send_smtp(self, mock_send_mime):
attachment.seek(0)
utils.email.send_email_smtp('to', 'subject', 'content', files=[attachment.name])
self.assertTrue(mock_send_mime.called)
call_args = mock_send_mime.call_args[0]
self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), call_args[0])
self.assertEqual(['to'], call_args[1])
msg = call_args[2]
_, call_args = mock_send_mime.call_args
self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), call_args['e_from'])
self.assertEqual(['to'], call_args['e_to'])
msg = call_args['mime_msg']
self.assertEqual('subject', msg['Subject'])
self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), msg['From'])
self.assertEqual(2, len(msg.get_payload()))
Expand All @@ -143,8 +143,8 @@ def test_send_smtp(self, mock_send_mime):
def test_send_smtp_with_multibyte_content(self, mock_send_mime):
utils.email.send_email_smtp('to', 'subject', '🔥', mime_charset='utf-8')
self.assertTrue(mock_send_mime.called)
call_args = mock_send_mime.call_args[0]
msg = call_args[2]
_, call_args = mock_send_mime.call_args
msg = call_args['mime_msg']
mimetext = MIMEText('🔥', 'mixed', 'utf-8')
self.assertEqual(mimetext.get_payload(), msg.get_payload()[0].get_payload())

Expand All @@ -155,10 +155,10 @@ def test_send_bcc_smtp(self, mock_send_mime):
attachment.seek(0)
utils.email.send_email_smtp('to', 'subject', 'content', files=[attachment.name], cc='cc', bcc='bcc')
self.assertTrue(mock_send_mime.called)
call_args = mock_send_mime.call_args[0]
self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), call_args[0])
self.assertEqual(['to', 'cc', 'bcc'], call_args[1])
msg = call_args[2]
_, call_args = mock_send_mime.call_args
self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), call_args['e_from'])
self.assertEqual(['to', 'cc', 'bcc'], call_args['e_to'])
msg = call_args['mime_msg']
self.assertEqual('subject', msg['Subject'])
self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), msg['From'])
self.assertEqual(2, len(msg.get_payload()))
Expand All @@ -173,13 +173,14 @@ def test_send_bcc_smtp(self, mock_send_mime):
@mock.patch('smtplib.SMTP')
def test_send_mime(self, mock_smtp, mock_smtp_ssl):
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
msg = MIMEMultipart()
utils.email.send_mime_email('from', 'to', msg, dryrun=False)
mock_smtp.assert_called_once_with(
conf.get('smtp', 'SMTP_HOST'),
conf.getint('smtp', 'SMTP_PORT'),
host=conf.get('smtp', 'SMTP_HOST'),
port=conf.getint('smtp', 'SMTP_PORT'),
timeout=conf.getint('smtp', 'SMTP_TIMEOUT'),
)
self.assertFalse(mock_smtp_ssl.called)
self.assertTrue(mock_smtp.return_value.starttls.called)
mock_smtp.return_value.login.assert_called_once_with(
conf.get('smtp', 'SMTP_USER'),
Expand All @@ -191,21 +192,20 @@ def test_send_mime(self, mock_smtp, mock_smtp_ssl):
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime_ssl(self, mock_smtp, mock_smtp_ssl):
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
with conf_vars({('smtp', 'smtp_ssl'): 'True'}):
utils.email.send_mime_email('from', 'to', MIMEMultipart(), dryrun=False)
self.assertFalse(mock_smtp.called)
mock_smtp_ssl.assert_called_once_with(
conf.get('smtp', 'SMTP_HOST'),
conf.getint('smtp', 'SMTP_PORT'),
host=conf.get('smtp', 'SMTP_HOST'),
port=conf.getint('smtp', 'SMTP_PORT'),
timeout=conf.getint('smtp', 'SMTP_TIMEOUT'),
)

@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime_noauth(self, mock_smtp, mock_smtp_ssl):
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
with conf_vars(
{
('smtp', 'smtp_user'): None,
Expand All @@ -215,8 +215,9 @@ def test_send_mime_noauth(self, mock_smtp, mock_smtp_ssl):
utils.email.send_mime_email('from', 'to', MIMEMultipart(), dryrun=False)
self.assertFalse(mock_smtp_ssl.called)
mock_smtp.assert_called_once_with(
conf.get('smtp', 'SMTP_HOST'),
conf.getint('smtp', 'SMTP_PORT'),
host=conf.get('smtp', 'SMTP_HOST'),
port=conf.getint('smtp', 'SMTP_PORT'),
timeout=conf.getint('smtp', 'SMTP_TIMEOUT'),
)
self.assertFalse(mock_smtp.login.called)

Expand All @@ -226,3 +227,89 @@ def test_send_mime_dryrun(self, mock_smtp, mock_smtp_ssl):
utils.email.send_mime_email('from', 'to', MIMEMultipart(), dryrun=True)
self.assertFalse(mock_smtp.called)
self.assertFalse(mock_smtp_ssl.called)

@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime_complete_failure(self, mock_smtp: mock, mock_smtp_ssl):
mock_smtp.side_effect = SMTPServerDisconnected()
msg = MIMEMultipart()
with self.assertRaises(SMTPServerDisconnected):
utils.email.send_mime_email('from', 'to', msg, dryrun=False)

mock_smtp.assert_any_call(
host=conf.get('smtp', 'SMTP_HOST'),
port=conf.getint('smtp', 'SMTP_PORT'),
timeout=conf.getint('smtp', 'SMTP_TIMEOUT'),
)
self.assertEqual(mock_smtp.call_count, conf.getint('smtp', 'SMTP_RETRY_LIMIT'))
self.assertFalse(mock_smtp_ssl.called)
self.assertFalse(mock_smtp.return_value.starttls.called)
self.assertFalse(mock_smtp.return_value.login.called)
self.assertFalse(mock_smtp.return_value.sendmail.called)
self.assertFalse(mock_smtp.return_value.quit.called)

@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime_ssl_complete_failure(self, mock_smtp, mock_smtp_ssl):
mock_smtp_ssl.side_effect = SMTPServerDisconnected()
msg = MIMEMultipart()
with conf_vars({('smtp', 'smtp_ssl'): 'True'}):
with self.assertRaises(SMTPServerDisconnected):
utils.email.send_mime_email('from', 'to', msg, dryrun=False)

mock_smtp_ssl.assert_any_call(
host=conf.get('smtp', 'SMTP_HOST'),
port=conf.getint('smtp', 'SMTP_PORT'),
timeout=conf.getint('smtp', 'SMTP_TIMEOUT'),
)
self.assertEqual(mock_smtp_ssl.call_count, conf.getint('smtp', 'SMTP_RETRY_LIMIT'))
self.assertFalse(mock_smtp.called)
self.assertFalse(mock_smtp_ssl.return_value.starttls.called)
self.assertFalse(mock_smtp_ssl.return_value.login.called)
self.assertFalse(mock_smtp_ssl.return_value.sendmail.called)
self.assertFalse(mock_smtp_ssl.return_value.quit.called)

@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime_custom_timeout_retrylimit(self, mock_smtp, mock_smtp_ssl):
mock_smtp.side_effect = SMTPServerDisconnected()
msg = MIMEMultipart()

custom_retry_limit = 10
custom_timeout = 60

with conf_vars(
{
('smtp', 'smtp_retry_limit'): str(custom_retry_limit),
('smtp', 'smtp_timeout'): str(custom_timeout),
}
):
with self.assertRaises(SMTPServerDisconnected):
utils.email.send_mime_email('from', 'to', msg, dryrun=False)

mock_smtp.assert_any_call(
host=conf.get('smtp', 'SMTP_HOST'), port=conf.getint('smtp', 'SMTP_PORT'), timeout=custom_timeout
)
self.assertFalse(mock_smtp_ssl.called)
self.assertEqual(mock_smtp.call_count, 10)

@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime_partial_failure(self, mock_smtp, mock_smtp_ssl):
final_mock = mock.Mock()
side_effects = [SMTPServerDisconnected(), SMTPServerDisconnected(), final_mock]
mock_smtp.side_effect = side_effects
msg = MIMEMultipart()

utils.email.send_mime_email('from', 'to', msg, dryrun=False)

mock_smtp.assert_any_call(
host=conf.get('smtp', 'SMTP_HOST'),
port=conf.getint('smtp', 'SMTP_PORT'),
timeout=conf.getint('smtp', 'SMTP_TIMEOUT'),
)
self.assertEqual(mock_smtp.call_count, side_effects.index(final_mock) + 1)
self.assertFalse(mock_smtp_ssl.called)
self.assertTrue(final_mock.starttls.called)
final_mock.sendmail.assert_called_once_with('from', 'to', msg.as_string())
self.assertTrue(final_mock.quit.called)

0 comments on commit 88aa174

Please sign in to comment.