2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
@@ -256,6 +256,8 @@ jobs:
matrix:
python-version: [ "3.9","3.10" ]
runs-on: ubuntu-latest
env:
NLTK_DATA: ${{ github.workspace }}/nltk_data
needs: [ setup_ingest, lint ]
steps:
# actions/checkout MUST come before auth
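For context on the new `NLTK_DATA` variable: NLTK builds its data search path from this environment variable, so pointing it at a workspace-local directory lets the CI job find packages placed there by `download_nltk_packages`. A minimal sketch (illustrative only, not part of this diff; the path is a placeholder):

```python
import os

# Placeholder path; the CI job sets this to ${{ github.workspace }}/nltk_data.
os.environ["NLTK_DATA"] = "/tmp/nltk_data"

import nltk  # noqa: E402  NLTK reads NLTK_DATA when building nltk.data.path

# Entries from NLTK_DATA are placed at the front of the search path, so
# packages extracted there are found before any system-wide location.
print(nltk.data.path[0])  # /tmp/nltk_data
```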
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -1,4 +1,4 @@
## 0.14.10-dev13
## 0.14.10

### Enhancements

@@ -14,6 +14,7 @@

* **Fix counting false negatives and false positives in table structure evaluation**
* **Fix Slack CI test** Change channel that Slack test is pointing to because previous test bot expired
* **Remove NLTK download** Removes `nltk.download` in favor of downloading from an S3 bucket we host to mitigate CVE-2024-39705

## 0.14.9

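As a usage sketch of the call pattern the new changelog entry describes (only `download_nltk_packages` is introduced by this PR; the rest is illustrative):

```python
# Previously: nltk.download("punkt") and nltk.download("averaged_perceptron_tagger").
# Now a single call fetches the hosted nltk_data archive, verifies its SHA-256,
# and extracts it into a writable NLTK data directory.
from unstructured.nlp.tokenize import download_nltk_packages

download_nltk_packages()
```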
5 changes: 2 additions & 3 deletions Dockerfile
@@ -1,4 +1,4 @@
FROM quay.io/unstructured-io/base-images:wolfi-base-d46498e@sha256:3db0544df1d8d9989cd3c3b28670d8b81351dfdc1d9129004c71ff05996fd51e as base
FROM quay.io/unstructured-io/base-images:wolfi-base-e48da6b@sha256:8ad3479e5dc87a86e4794350cca6385c01c6d110902c5b292d1a62e231be711b as base

USER root

@@ -18,8 +18,7 @@ USER notebook-user

RUN find requirements/ -type f -name "*.txt" -exec pip3.11 install --no-cache-dir --user -r '{}' ';' && \
pip3.11 install unstructured.paddlepaddle && \
python3.11 -c "import nltk; nltk.download('punkt')" && \
python3.11 -c "import nltk; nltk.download('averaged_perceptron_tagger')" && \
python3.11 -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()" && \
python3.11 -c "from unstructured.partition.model_init import initialize; initialize()" && \
python3.11 -c "from unstructured_inference.models.tables import UnstructuredTableTransformerModel; model = UnstructuredTableTransformerModel(); model.initialize('microsoft/table-transformer-structure-recognition')"

14 changes: 10 additions & 4 deletions test_unstructured/nlp/test_tokenize.py
@@ -2,22 +2,28 @@
from unittest.mock import patch

import nltk
import pytest

from test_unstructured.nlp.mock_nltk import mock_sent_tokenize, mock_word_tokenize
from unstructured.nlp import tokenize


def test_error_raised_on_nltk_download():
with pytest.raises(ValueError):
tokenize.nltk.download("tokenizers/punkt")


def test_nltk_packages_download_if_not_present():
with patch.object(nltk, "find", side_effect=LookupError):
with patch.object(nltk, "download") as mock_download:
tokenize._download_nltk_package_if_not_present("fake_package", "tokenizers")
with patch.object(tokenize, "download_nltk_packages") as mock_download:
tokenize._download_nltk_packages_if_not_present()

mock_download.assert_called_with("fake_package")
mock_download.assert_called_once()


def test_nltk_packages_do_not_download_if():
with patch.object(nltk, "find"), patch.object(nltk, "download") as mock_download:
tokenize._download_nltk_package_if_not_present("fake_package", "tokenizers")
tokenize._download_nltk_packages_if_not_present()

mock_download.assert_not_called()

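The first test above exercises the module-level guard added in `unstructured/nlp/tokenize.py`; a rough sketch of the behavior it checks (illustrative):

```python
import nltk

from unstructured.nlp import tokenize  # noqa: F401  importing installs the guard

# After import, nltk.download has been replaced with a function that always
# raises, so accidental downloads of unverified packages fail loudly.
try:
    nltk.download("punkt")
except ValueError as exc:
    print(exc)  # NLTK download disabled. See CVE-2024-39705
```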
17 changes: 17 additions & 0 deletions typings/nltk/__init__.pyi
@@ -0,0 +1,17 @@
from __future__ import annotations

from nltk import data, internals
from nltk.data import find
from nltk.downloader import download
from nltk.tag import pos_tag
from nltk.tokenize import sent_tokenize, word_tokenize

__all__ = [
"data",
"download",
"find",
"internals",
"pos_tag",
"sent_tokenize",
"word_tokenize",
]
7 changes: 7 additions & 0 deletions typings/nltk/data.pyi
@@ -0,0 +1,7 @@
from __future__ import annotations

from typing import Sequence

path: list[str]

def find(resource_name: str, paths: Sequence[str] | None = None) -> str: ...
5 changes: 5 additions & 0 deletions typings/nltk/downloader.pyi
@@ -0,0 +1,5 @@
from __future__ import annotations

from typing import Callable

download: Callable[..., bool]
3 changes: 3 additions & 0 deletions typings/nltk/internals.pyi
@@ -0,0 +1,3 @@
from __future__ import annotations

def is_writable(path: str) -> bool: ...
5 changes: 5 additions & 0 deletions typings/nltk/tag.pyi
@@ -0,0 +1,5 @@
from __future__ import annotations

def pos_tag(
tokens: list[str], tagset: str | None = None, lang: str = "eng"
) -> list[tuple[str, str]]: ...
4 changes: 4 additions & 0 deletions typings/nltk/tokenize.pyi
@@ -0,0 +1,4 @@
from __future__ import annotations

def sent_tokenize(text: str, language: str = ...) -> list[str]: ...
def word_tokenize(text: str, language: str = ..., preserve_line: bool = ...) -> list[str]: ...
2 changes: 1 addition & 1 deletion unstructured/__version__.py
@@ -1 +1 @@
__version__ = "0.14.10-dev13" # pragma: no cover
__version__ = "0.14.10" # pragma: no cover
139 changes: 121 additions & 18 deletions unstructured/nlp/tokenize.py
@@ -1,11 +1,13 @@
from __future__ import annotations

import hashlib
import os
import sys
import tarfile
import tempfile
import urllib.request
from functools import lru_cache
from typing import List, Tuple

if sys.version_info < (3, 8):
from typing_extensions import Final # pragma: no cover
else:
from typing import Final
from typing import Any, Final, List, Tuple

import nltk
from nltk import pos_tag as _pos_tag
@@ -14,42 +16,143 @@

CACHE_MAX_SIZE: Final[int] = 128

NLTK_DATA_URL = "https://utic-public-cf.s3.amazonaws.com/nltk_data.tgz"
NLTK_DATA_SHA256 = "126faf671cd255a062c436b3d0f2d311dfeefcd92ffa43f7c3ab677309404d61"


def _raise_on_nltk_download(*args: Any, **kwargs: Any):
raise ValueError("NLTK download disabled. See CVE-2024-39705")


nltk.download = _raise_on_nltk_download


# NOTE(robinson) - mimic default dir logic from NLTK
# https://github.com/nltk/nltk/
# blob/8c233dc585b91c7a0c58f96a9d99244a379740d5/nltk/downloader.py#L1046
def get_nltk_data_dir() -> str | None:
"""Locates the directory the nltk data will be saved too. The directory
set by the NLTK environment variable takes highest precedence. Otherwise
the default is determined by the rules indicated below. Returns None when
the directory is not writable.

On Windows, the default download directory is
``PYTHONHOME/lib/nltk``, where *PYTHONHOME* is the
directory containing Python, e.g. ``C:\\Python311``.

On all other platforms, the default directory is the first of
the following which exists or which can be created with write
permission: ``/usr/share/nltk_data``, ``/usr/local/share/nltk_data``,
``/usr/lib/nltk_data``, ``/usr/local/lib/nltk_data``, ``~/nltk_data``.
"""
# Check if we are on GAE where we cannot write into filesystem.
if "APPENGINE_RUNTIME" in os.environ:
return

# Check if we have sufficient permissions to install in a
# variety of system-wide locations.
for nltkdir in nltk.data.path:
if os.path.exists(nltkdir) and nltk.internals.is_writable(nltkdir):
return nltkdir

# On Windows, use %APPDATA%
if sys.platform == "win32" and "APPDATA" in os.environ:
homedir = os.environ["APPDATA"]

# Otherwise, install in the user's home directory.
else:
homedir = os.path.expanduser("~/")
if homedir == "~/":
raise ValueError("Could not find a default download directory")

# NOTE(robinson) - NLTK appends nltk_data to the homedir. That's already
# present in the tar file so we don't have to do that here.
return homedir


def download_nltk_packages():
nltk_data_dir = get_nltk_data_dir()

if nltk_data_dir is None:
raise OSError("NLTK data directory does not exist or is not writable.")

def sha256_checksum(filename: str, block_size: int = 65536):
sha256 = hashlib.sha256()
with open(filename, "rb") as f:
for block in iter(lambda: f.read(block_size), b""):
sha256.update(block)
return sha256.hexdigest()

with tempfile.NamedTemporaryFile() as tmp_file:
tgz_file = tmp_file.name
urllib.request.urlretrieve(NLTK_DATA_URL, tgz_file)

file_hash = sha256_checksum(tgz_file)
if file_hash != NLTK_DATA_SHA256:
os.remove(tgz_file)
raise ValueError(f"SHA-256 mismatch: expected {NLTK_DATA_SHA256}, got {file_hash}")

# Extract the contents
if not os.path.exists(nltk_data_dir):
os.makedirs(nltk_data_dir)

with tarfile.open(tgz_file, "r:gz") as tar:
tar.extractall(path=nltk_data_dir)


def check_for_nltk_package(package_name: str, package_category: str) -> bool:
"""Checks to see if the specified NLTK package exists on the file system"""
paths: list[str] = []
for path in nltk.data.path:
if not path.endswith("nltk_data"):
path = os.path.join(path, "nltk_data")
paths.append(path)

def _download_nltk_package_if_not_present(package_name: str, package_category: str):
"""If the required nlt package is not present, download it."""
try:
nltk.find(f"{package_category}/{package_name}")
nltk.find(f"{package_category}/{package_name}", paths=paths)
return True
except LookupError:
nltk.download(package_name)
return False


def _download_nltk_packages_if_not_present():
"""If required NLTK packages are not available, download them."""

tagger_available = check_for_nltk_package(
package_category="taggers",
package_name="averaged_perceptron_tagger",
)
tokenizer_available = check_for_nltk_package(
package_category="tokenizers", package_name="punkt"
)

if not (tokenizer_available and tagger_available):
download_nltk_packages()


@lru_cache(maxsize=CACHE_MAX_SIZE)
def sent_tokenize(text: str) -> List[str]:
"""A wrapper around the NLTK sentence tokenizer with LRU caching enabled."""
_download_nltk_package_if_not_present(package_category="tokenizers", package_name="punkt")
_download_nltk_packages_if_not_present()
return _sent_tokenize(text)


@lru_cache(maxsize=CACHE_MAX_SIZE)
def word_tokenize(text: str) -> List[str]:
"""A wrapper around the NLTK word tokenizer with LRU caching enabled."""
_download_nltk_package_if_not_present(package_category="tokenizers", package_name="punkt")
_download_nltk_packages_if_not_present()
return _word_tokenize(text)


@lru_cache(maxsize=CACHE_MAX_SIZE)
def pos_tag(text: str) -> List[Tuple[str, str]]:
"""A wrapper around the NLTK POS tagger with LRU caching enabled."""
_download_nltk_package_if_not_present(package_category="tokenizers", package_name="punkt")
_download_nltk_package_if_not_present(
package_category="taggers",
package_name="averaged_perceptron_tagger",
)
_download_nltk_packages_if_not_present()
# NOTE(robinson) - Splitting into sentences before tokenizing. This helps with
# situations like "ITEM 1A. PROPERTIES" where "PROPERTIES" can be mistaken
# for a verb because it looks like it's in verb form and "ITEM 1A." looks like the subject.
sentences = _sent_tokenize(text)
parts_of_speech = []
parts_of_speech: list[tuple[str, str]] = []
for sentence in sentences:
tokens = _word_tokenize(sentence)
parts_of_speech.extend(_pos_tag(tokens))
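Putting the new tokenize.py pieces together, a hedged usage sketch (illustrative; only the imported names come from this diff):

```python
from unstructured.nlp.tokenize import pos_tag, sent_tokenize

# The first call checks for punkt and averaged_perceptron_tagger under an
# nltk_data directory and, if either is missing, downloads and verifies the
# hosted archive before tokenizing.
text = "ITEM 1A. PROPERTIES. The company owns several buildings."
print(sent_tokenize(text))
print(pos_tag(text))
```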