diff --git a/lms/extractors/imagefile.py b/lms/extractors/imagefile.py index 0fe38096..f7a15e73 100644 --- a/lms/extractors/imagefile.py +++ b/lms/extractors/imagefile.py @@ -15,7 +15,7 @@ def __init__(self, **kwargs): ) def can_extract(self) -> bool: - return self.ext in ALLOWED_IMAGES_EXTENSIONS + return self.ext.lower() in ALLOWED_IMAGES_EXTENSIONS def get_exercise(self, to_extract: bytes) -> Tuple[int, List[File]]: exercise_id = 0 @@ -25,7 +25,7 @@ def get_exercise(self, to_extract: bytes) -> Tuple[int, List[File]]: raise BadUploadFile("Can't resolve exercise id.", self.filename) decoded = base64.b64encode(to_extract) - return (exercise_id, [File(f'/main.{self.ext}', decoded)]) + return (exercise_id, [File(f'/main.{self.ext.lower()}', decoded)]) def get_exercises(self) -> Iterator[Tuple[int, List[File]]]: exercise_id, files = self.get_exercise(self.file_content) diff --git a/lms/extractors/textfile.py b/lms/extractors/textfile.py index 5a2b1e23..227965e6 100644 --- a/lms/extractors/textfile.py +++ b/lms/extractors/textfile.py @@ -20,7 +20,7 @@ def __init__(self, **kwargs): ) def can_extract(self) -> bool: - if self.ext not in ALLOWED_EXTENSIONS: + if self.ext.lower() not in ALLOWED_EXTENSIONS: return False if isinstance(self.file_content, str): return True @@ -34,7 +34,7 @@ def get_exercise(self, to_extract: str) -> Tuple[int, List[File]]: if not exercise_id: raise BadUploadFile("Can't resolve exercise id.", self.filename) - return (exercise_id, [File(f'/main.{self.ext}', content)]) + return (exercise_id, [File(f'/main.{self.ext.lower()}', content)]) def get_exercises(self) -> Iterator[Tuple[int, List[File]]]: exercise_id, files = self.get_exercise(self.file_content) diff --git a/lms/extractors/ziparchive.py b/lms/extractors/ziparchive.py index 89507d3a..de1ff1ac 100644 --- a/lms/extractors/ziparchive.py +++ b/lms/extractors/ziparchive.py @@ -43,14 +43,14 @@ def _extract(archive: ZipFile, filename: str, dirname: str = '') -> File: with archive.open(filename) as current_file: log.debug(f'Extracting from archive: {filename}') code = current_file.read() - if filename.rpartition('.')[-1] in ALLOWED_IMAGES_EXTENSIONS: + if filename.rpartition('.')[-1].lower() in ALLOWED_IMAGES_EXTENSIONS: decoded = base64.b64encode(code) else: decoded = code.decode( 'utf-8', errors='replace', ).replace('\x00', '') filename = filename[len(dirname):] - return File(path=f'/{filename}', code=decoded) + return File(path=f'/{filename.lower()}', code=decoded) def get_files( self, archive: ZipFile, filenames: List[Text], dirname: str = '',