Fix open encoding #22

Merged 2 commits on Dec 21, 2022

Changes from all commits
48 changes: 48 additions & 0 deletions litstudy/common.py
@@ -1,4 +1,7 @@
import re
import io
import locale
from codecs import BOM_UTF8, BOM_UTF16_BE, BOM_UTF16_LE
from unidecode import unidecode

try:
@@ -85,3 +88,48 @@ def get(self, name):
        self.mapping[key] = key
        self.unmapping[key] = nice_name
        return nice_name


def robust_open(path, errors="replace"):
    """ Drop-in replacement for `with open(path) as f:` when reading a text
    file of unknown encoding. The built-in `open` is fragile here: it decodes
    using the default system encoding and fails as soon as a byte cannot be
    decoded. This function instead tries to detect the file's encoding and
    handles undecodable bytes according to `errors` (by default they are
    replaced with U+FFFD). Besides a file path, it also accepts a `bytes`
    object or an already-open file-like object, which is returned as-is.
    """
    if hasattr(path, "read"):
        return path
    elif isinstance(path, bytes):
        content = path
    else:
        with open(path, "rb") as f:
            content = f.read()

    # Decoding strategy:
    # - UTF-8 BOM: decode as UTF-8
    # - UTF-16 BE BOM: decode as UTF-16-BE
    # - UTF-16 LE BOM: decode as UTF-16-LE
    # - otherwise, try strict UTF-8
    # - if that fails, try the locale's default charset
    # - if that also fails, decode as UTF-8 with the `errors` handler
    #   (replacing invalid bytes by default)
    if content.startswith(BOM_UTF8):
        n = len(BOM_UTF8)
        result = content[n:].decode(errors=errors)
    elif content.startswith(BOM_UTF16_BE):
        n = len(BOM_UTF16_BE)
        result = content[n:].decode("utf_16_be", errors=errors)
    elif content.startswith(BOM_UTF16_LE):
        n = len(BOM_UTF16_LE)
        result = content[n:].decode("utf_16_le", errors=errors)
    else:
        try:
            result = content.decode("utf-8", errors="strict")
        except UnicodeError:
            try:
                default_charset = locale.getpreferredencoding()
                result = content.decode(default_charset, errors=errors)
            except UnicodeError:
                result = content.decode("utf-8", errors=errors)

    return io.StringIO(result)
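
For context beyond the diff: a minimal usage sketch of the new helper. `robust_open` returns an in-memory `StringIO` (or passes an already-open file object straight through), so it drops in wherever a readable text file is expected. The file name below is a hypothetical example.

import codecs
from litstudy.common import robust_open

# Read a text file of unknown encoding; undecodable bytes are replaced
# with U+FFFD instead of raising UnicodeDecodeError.
with robust_open("references.bib") as f:  # hypothetical example file
    text = f.read()

# Raw bytes are accepted directly; a UTF-16 BOM selects the matching decoder.
buf = robust_open(codecs.BOM_UTF16_LE + "café".encode("utf_16_le"))
assert buf.read() == "café"
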
3 changes: 2 additions & 1 deletion litstudy/sources/bibtex.py
@@ -1,4 +1,5 @@
from ..types import Document, DocumentSet, DocumentIdentifier, Author
from ..common import robust_open
# from bibtexparser.customization import convert_to_unicode
from bibtexparser.latexenc import latex_to_unicode
import bibtexparser
@@ -188,7 +189,7 @@ def decode(entry):
parser = bibtexparser.bparser.BibTexParser(common_strings=True)
parser.customization = decode

with open(path) as f:
with robust_open(path) as f:
data = bibtexparser.load(f, parser=parser)

docs = [BibDocument(e) for e in data.entries if e.get('title')]
3 changes: 2 additions & 1 deletion litstudy/sources/ieee.py
@@ -1,5 +1,6 @@
from ..types import Document, Author, DocumentSet, DocumentIdentifier, \
Affiliation
from ..common import robust_open
import csv


@@ -99,7 +100,7 @@ def load_ieee_csv(path: str) -> DocumentSet:
""" Import CSV file exported from
`IEEE Xplore <https://ieeexplore.ieee.org/search/searchresult.jsp>`_.
"""
with open(path, newline='') as f:
with robust_open(path) as f:
lines = csv.DictReader(f)
docs = [IEEEDocument(line) for line in lines]
return DocumentSet(docs)
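
For reference (not part of the change): a call sketch for the IEEE loader, with a hypothetical path standing in for an IEEE Xplore export. Because the file is now read through `robust_open`, an export saved as UTF-16 or containing stray non-UTF-8 bytes no longer aborts with a UnicodeDecodeError.

from litstudy.sources.ieee import load_ieee_csv

docs = load_ieee_csv("ieee_export.csv")  # hypothetical export file
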
3 changes: 2 additions & 1 deletion litstudy/sources/ris.py
@@ -1,4 +1,5 @@
from ..types import Document, Author, DocumentSet, DocumentIdentifier
from ..common import robust_open
import logging


@@ -70,7 +71,7 @@ def load_ris_file(path: str) -> DocumentSet:
""" Load the RIS file at the given `path` as a `DocumentSet`. """
docs = []

with open(path, newline='') as f:
with robust_open(path) as f:
authors = []
keywords = []
attr = dict()
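
Similarly for RIS files, a hypothetical call sketch; in addition, an already-open file object can now be passed in, since `robust_open` returns any object with a `read` method unchanged.

from litstudy.sources.ris import load_ris_file

docs = load_ris_file("references.ris")  # hypothetical RIS export
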
5 changes: 2 additions & 3 deletions litstudy/sources/springer.py
@@ -1,7 +1,6 @@
import csv

from ..types import Document, DocumentSet, DocumentIdentifier

from ..common import robust_open

class SpringerDocument(Document):
def __init__(self, entry):
@@ -42,7 +41,7 @@ def load_springer_csv(path: str) -> DocumentSet:
""" Load CSV file exported from
`Springer Link <https://link.springer.com/>`_.
"""
with open(path, newline='') as f:
with robust_open(path) as f:
lines = csv.DictReader(f)
docs = [SpringerDocument(line) for line in lines]
return DocumentSet(docs)
24 changes: 24 additions & 0 deletions tests/test_common.py
@@ -0,0 +1,24 @@
import litstudy
import codecs

def test_robust_open():
    f = litstudy.common.robust_open
    expected = "ABC \U0001F600"

    assert f(b'').read() == ""

    content = expected.encode("utf8")
    assert f(content).read() == expected

    content = codecs.BOM_UTF8 + expected.encode("utf8")
    assert f(content).read() == expected

    content = codecs.BOM_UTF16_BE + expected.encode("utf_16_be")
    assert f(content).read() == expected

    content = codecs.BOM_UTF16_LE + expected.encode("utf_16_le")
    assert f(content).read() == expected
    # Contains invalid UTF-8 bytes; each undecodable byte should become U+FFFD
    content = b'ABC \x9f\x98\x80 DEF'
    assert f(content).read() == "ABC \ufffd\ufffd\ufffd DEF"