diff --git a/providers/imap/src/airflow/providers/imap/hooks/imap.py b/providers/imap/src/airflow/providers/imap/hooks/imap.py index ecb99eb2061b8..c3124f21bee0c 100644 --- a/providers/imap/src/airflow/providers/imap/hooks/imap.py +++ b/providers/imap/src/airflow/providers/imap/hooks/imap.py @@ -192,6 +192,7 @@ def download_mail_attachments( max_mails: int | None = None, mail_folder: str = "INBOX", mail_filter: str = "All", + overwrite_file: bool = True, not_found_mode: str = "raise", ) -> None: """ @@ -206,6 +207,12 @@ def download_mail_attachments( :param mail_folder: The mail folder where to look at. :param mail_filter: If set other than 'All' only specific mails will be checked. See :py:meth:`imaplib.IMAP4.search` for details. + :param overwrite_file: Specify what should happen if file already exists on disk. + If set to True - file is overwritten. + If set to False - new file with suffix _1, _2, etc. is created. + Suffix is inserted before the last extension, so files with multiple + extensions, like .tar.gz, will be transformed into .tar_1.gz. + Defaults to True to preserve existing behavior. :param not_found_mode: Specify what should happen if no attachment has been found. Supported values are 'raise', 'warn' and 'ignore'. 
def _create_file(
    self, name: str, payload: Any, local_output_directory: str, overwrite_file: bool = True
) -> None:
    """
    Write an attachment payload to a file inside the local output directory.

    :param name: The attachment's file name, resolved against the output
        directory via ``self._correct_path``.
    :param payload: The bytes to write.
    :param local_output_directory: The directory the file is created in.
    :param overwrite_file: If True (the default, matching the behavior before
        this option existed) an existing file with the same name is overwritten.
        If False the existing file is kept and the payload is written to the
        first free name of the form ``<stem>_1<ext>``, ``<stem>_2<ext>``, ...
        The suffix goes before the last extension, so ``a.tar.gz`` becomes
        ``a.tar_1.gz``.
    """
    if overwrite_file:
        # Plain truncate-and-write; no retry loop is needed on this path.
        with open(self._correct_path(name, local_output_directory), "wb") as file:
            file.write(payload)
        return

    stem, extension = os.path.splitext(name)
    candidate = name
    counter = 0
    while True:
        file_path = self._correct_path(candidate, local_output_directory)
        try:
            # "x" (exclusive creation) makes the existence check and the
            # creation a single step, so a concurrent writer cannot be
            # clobbered between a check and the open.
            with open(file_path, "xb") as file:
                file.write(payload)
        except FileExistsError:
            counter += 1
            candidate = f"{stem}_{counter}{extension}"
        else:
            return
"unexpected_copies"), + [ + ("test1.csv", ["test1_1.csv", "test1_2.csv"]), + ("README", ["README_1", "README_2"]), + ("test.tar.gz", ["test.tar_1.gz", "test.tar_2.gz"]), + ], + ) + @patch(imaplib_string) + def test_download_mail_attachments_with_overwrite( + self, mock_imaplib, tmp_path, attachment_name, unexpected_copies + ): + payloads = ["test1", "test2", "test3"] + for payload in payloads: + _create_fake_imap(mock_imaplib, with_mail=True, attachment_name=attachment_name, payload=payload) + with ImapHook() as imap_hook: + imap_hook.download_mail_attachments(attachment_name, str(tmp_path), overwrite_file=True) + + assert (tmp_path / attachment_name).exists() + for copy_name in unexpected_copies: + assert not (tmp_path / copy_name).exists() + assert (tmp_path / attachment_name).read_bytes() == b"test3" + assert len(list(tmp_path.iterdir())) == 1 + + @pytest.mark.parametrize( + ("attachment_name", "expected_copies"), + [ + ("test1.csv", ["test1.csv", "test1_1.csv", "test1_2.csv"]), + ("README", ["README", "README_1", "README_2"]), + ("test.tar.gz", ["test.tar.gz", "test.tar_1.gz", "test.tar_2.gz"]), + ], + ) + @patch(imaplib_string) + def test_download_mail_attachments_without_overwrite_creates_copy( + self, mock_imaplib, tmp_path, attachment_name, expected_copies + ): + _create_fake_imap(mock_imaplib, with_mail=True, attachment_name=attachment_name) + payload = b"SWQsTmFtZQoxLEZlbGl4" + + with ImapHook() as imap_hook: + for _ in range(3): + imap_hook.download_mail_attachments(attachment_name, str(tmp_path), overwrite_file=False) + + for copy_name in expected_copies: + assert (tmp_path / copy_name).exists() + assert (tmp_path / copy_name).read_bytes() == payload + assert len(list(tmp_path.iterdir())) == 3