In [None]:
import io
import os
from typing import Callable, Dict, List, Optional, Tuple

import pandas as pd

class DataFrameLoader:
    """Track which columns of one or more sheets are features and/or targets.

    Sheets are either DataFrames handed to the constructor (cached under the
    keys ``'0'``, ``'1'``, ... in positional order) or sheets of a
    ``pd.ExcelFile`` that are read on first use. The selection is stored in
    ``self._columns``, a DataFrame indexed by ``(sheet_name, column_name)``
    with the boolean columns ``is_feature`` and ``is_target``; a row only
    exists while at least one of the two flags is set.
    """

    _INDEX_LEVELS = ['sheet_name', 'column_name']
    _FLAG_COLUMNS = ['is_feature', 'is_target']

    def __init__(
        self,
        filename: str,
        *data: pd.DataFrame,
        features=None,
        targets=None,
        xl: Optional[pd.ExcelFile] = None,
        date_parser: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None
    ) -> None:
        """Create a loader over in-memory frames and/or an Excel workbook.

        Args:
            filename: Name of the originating file; stored as-is.
            *data: DataFrames cached as sheets ``'0'``, ``'1'``, ...
            features: Accepted for backward compatibility; currently ignored.
            targets: Accepted for backward compatibility; currently ignored.
            xl: Optional workbook whose sheets are read lazily.
            date_parser: Optional callable applied to every sheet that is
                lazily read from ``xl``.
        """
        # NOTE(review): `features` / `targets` were never consumed by the
        # original implementation; they are kept in the signature so existing
        # callers keep working. Use add_features() / add_targets() instead.
        self.filename = filename
        self._xl: Optional[pd.ExcelFile] = xl
        self._sheets_cache: Dict[str, pd.DataFrame] = {
            str(position): frame for position, frame in enumerate(data)
        }
        self._columns = self._empty_columns()
        self._date_parser = date_parser

    @staticmethod
    def _empty_columns() -> pd.DataFrame:
        """Return an empty selection table with the expected index and dtypes."""
        frame = pd.DataFrame(
            columns=DataFrameLoader._INDEX_LEVELS + DataFrameLoader._FLAG_COLUMNS
        )
        # Boolean dtype keeps the flag columns directly usable as masks.
        return frame.set_index(DataFrameLoader._INDEX_LEVELS).astype(bool)

    def reset(self) -> None:
        """Forget every feature/target selection (cached sheets are kept)."""
        self._columns = self._empty_columns()

    def remove_sheet(self, sheet_name: str = None) -> None:
        """Drop every selected column that belongs to ``sheet_name``.

        Does nothing when ``sheet_name`` is ``None`` or has no selected
        columns.
        """
        if sheet_name is None:
            return
        keep = self._columns.index.get_level_values(0) != sheet_name
        self._columns = self._columns[keep]

    def remove_features(self, *column_names: str, sheet_name: str = None) -> None:
        """Clear the feature flag of the given columns."""
        for column_name in column_names:
            self._set_column(column_name, sheet_name=sheet_name, is_feature=False)

    def remove_targets(self, *column_names: str, sheet_name: str = None) -> None:
        """Clear the target flag of the given columns."""
        for column_name in column_names:
            self._set_column(column_name, sheet_name=sheet_name, is_target=False)

    def add_features(self, *column_names: str, sheet_name: str = None) -> None:
        """Set the feature flag of the given columns."""
        for column_name in column_names:
            self._set_column(column_name, sheet_name=sheet_name, is_feature=True)

    def add_targets(self, *column_names: str, sheet_name: str = None) -> None:
        """Set the target flag of the given columns."""
        for column_name in column_names:
            self._set_column(column_name, sheet_name=sheet_name, is_target=True)

    def _set_column(
        self,
        column_name: str,
        sheet_name: str = None,
        is_feature: bool = None,
        is_target: bool = None
    ) -> None:
        """Update the flags of one column, creating or deleting its row.

        Args:
            column_name: Column inside the sheet.
            sheet_name: Sheet holding the column; may be omitted when the
                loader has exactly one sheet.
            is_feature: New feature flag, or ``None`` to leave it unchanged.
            is_target: New target flag, or ``None`` to leave it unchanged.

        Raises:
            AssertionError: If ``sheet_name`` is omitted while several sheets
                exist, or the sheet is neither cached nor in the workbook.
        """
        assert sheet_name is not None or len(self.sheet_names) == 1, 'Please provide a sheet_name'
        sheet_name = sheet_name or self.sheet_names[0]
        index = (sheet_name, column_name)
        # No-op for cached sheets; asserts availability otherwise.
        self._load_sheet(sheet_name)

        known = index in self._columns.index
        if known:
            current = self._columns.loc[index]
            feature_flag = bool(current['is_feature'])
            target_flag = bool(current['is_target'])
        else:
            feature_flag = target_flag = False

        if is_feature is not None:
            feature_flag = bool(is_feature)
        if is_target is not None:
            target_flag = bool(is_target)

        if not feature_flag and not target_flag:
            # A column that is neither feature nor target is not tracked.
            if known:
                self._columns = self._columns.drop(index)
            return

        if known:
            # Write through .loc on the frame itself so the update can never
            # land on a temporary copy of the row.
            self._columns.loc[index, 'is_feature'] = feature_flag
            self._columns.loc[index, 'is_target'] = target_flag
            return

        row = pd.DataFrame(
            {'is_feature': [feature_flag], 'is_target': [target_flag]},
            index=pd.MultiIndex.from_tuples([index], names=self._INDEX_LEVELS),
        )
        # Skip concatenating with an empty frame so dtypes stay boolean.
        self._columns = row if self._columns.empty else pd.concat([self._columns, row])

    def _flagged_index(self, flag: str) -> List[Tuple[str, str]]:
        """Return the ``(sheet, column)`` keys whose ``flag`` column is set."""
        mask = self._columns[flag].astype(bool)
        return list(self._columns[mask].index)

    @property
    def features(self) -> List[str]:
        """Feature names formatted as ``'<sheet>_<column>'``.

        NOTE(review): the sheet prefix is always added here, whereas ``df``
        only prefixes its columns when there is more than one sheet.
        """
        return [f'{sheet}_{column}' for sheet, column in self._flagged_index('is_feature')]

    @property
    def targets(self) -> List[str]:
        """Target names formatted as ``'<sheet>_<column>'``.

        NOTE(review): the sheet prefix is always added here, whereas ``df``
        only prefixes its columns when there is more than one sheet.
        """
        return [f'{sheet}_{column}' for sheet, column in self._flagged_index('is_target')]

    @property
    def features_index(self) -> List[Tuple[str, str]]:
        """``(sheet, column)`` keys of all feature columns."""
        return self._flagged_index('is_feature')

    @property
    def targets_index(self) -> List[Tuple[str, str]]:
        """``(sheet, column)`` keys of all target columns."""
        return self._flagged_index('is_target')

    @property
    def df(self) -> pd.DataFrame:
        """All selected columns, outer-joined across sheets on the row index.

        Columns are prefixed with ``'<sheet>_'`` only when the loader has more
        than one sheet. Returns an empty DataFrame when nothing is selected.
        """
        df = pd.DataFrame()
        if self._columns.empty:
            return df
        selected = self._columns.index
        sheet_level = selected.get_level_values(0)
        column_level = selected.get_level_values(1)
        for sheet_name in sheet_level.unique():
            column_names = list(column_level[sheet_level == sheet_name])
            sheet_df = self._sheets_cache[sheet_name][column_names]
            if len(self.sheet_names) > 1:
                sheet_df = sheet_df.add_prefix(f'{sheet_name}_')
            df = pd.merge(df, sheet_df, left_index=True, right_index=True, how='outer')
        return df

    @property
    def sheet_names(self) -> List[str]:
        """Sheet names of the workbook if one is attached, else the cache keys."""
        if self._xl is not None:
            return self._xl.sheet_names
        return list(self._sheets_cache)

    @property
    def sheets(self) -> 'SheetLoader':
        """Accessor so that ``loader.sheets[name]`` loads (if needed) and returns a sheet."""
        class SheetLoader:
            def __init__(self, loader: 'DataFrameLoader') -> None:
                self._loader = loader

            def __getitem__(self, sheet_name: str) -> pd.DataFrame:
                # _load_sheet is a no-op when the sheet is already cached.
                self._loader._load_sheet(sheet_name)
                return self._loader._sheets_cache[sheet_name]

        return SheetLoader(self)

    def _load_sheet(self, sheet_name: str) -> None:
        """Read ``sheet_name`` from the workbook into the cache if not cached yet.

        Raises:
            AssertionError: If the sheet is neither cached nor in the workbook.
        """
        assert self._is_sheet_available(sheet_name), f'Sheet {sheet_name} is not available'
        if self._is_sheet_cached(sheet_name):
            return
        sheet_df = pd.read_excel(self._xl, sheet_name=sheet_name)
        if self._date_parser is not None:
            sheet_df = self._date_parser(sheet_df)
        self._sheets_cache[sheet_name] = sheet_df

    def _is_sheet_cached(self, sheet_name: str) -> bool:
        """Whether the sheet's DataFrame is already held in memory."""
        return sheet_name in self._sheets_cache

    def _is_sheet_available(self, sheet_name: str) -> bool:
        """Whether the sheet is cached or can be read from the workbook."""
        if self._is_sheet_cached(sheet_name):
            return True
        return self._xl is not None and sheet_name in self._xl.sheet_names

    @staticmethod
    def from_file(filename: str, *args, **kwargs) -> 'DataFrameLoader':
        """Read ``filename`` from disk and build a loader via ``from_buffer``."""
        buffer = DataFrameLoader._read_file(filename)
        return DataFrameLoader.from_buffer(filename, buffer, *args, **kwargs)

    @staticmethod
    def from_buffer(filename: str, buffer: bytes, *args, **kwargs) -> 'DataFrameLoader':
        """Build a loader from raw file content.

        The format is chosen from the extension of ``filename`` (compared
        case-insensitively): ``csv`` is parsed eagerly into sheet ``'0'``,
        ``xlsx`` is opened as a workbook whose sheets are read lazily.
        Remaining positional and keyword arguments are forwarded to the
        constructor.

        Raises:
            ValueError: If the extension is neither ``csv`` nor ``xlsx``.
        """
        date_parser = kwargs.get('date_parser')
        ext = os.path.splitext(filename)[1][1:]
        kind = ext.lower()
        if kind == 'csv':
            csv = DataFrameLoader._parse_csv(buffer, date_parser)
            return DataFrameLoader(filename, csv, *args, **kwargs)
        if kind == 'xlsx':
            xl = DataFrameLoader._parse_xlsx(buffer)
            return DataFrameLoader(filename, *args, xl=xl, **kwargs)
        # ValueError is an Exception subclass, so existing handlers still match.
        raise ValueError(f'Unsupported file extension: {ext}')

    @staticmethod
    def _read_file(filename: str) -> bytes:
        """Return the full binary content of ``filename``."""
        with open(filename, 'rb') as file:
            return file.read()

    @staticmethod
    def _parse_csv(
        buffer: bytes,
        date_parser: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None
    ) -> pd.DataFrame:
        """Parse UTF-8 CSV bytes into a DataFrame, then apply ``date_parser`` if given."""
        # 'utf-8-sig' strips a leading byte-order mark so it cannot end up in
        # the first column name; plain UTF-8 input decodes identically.
        data = io.StringIO(buffer.decode('utf-8-sig'))
        sheet_df = pd.read_csv(data)
        if date_parser is not None:
            sheet_df = date_parser(sheet_df)
        return sheet_df

    @staticmethod
    def _parse_xlsx(buffer: bytes) -> pd.ExcelFile:
        """Open xlsx bytes as a ``pd.ExcelFile`` using the openpyxl engine."""
        return pd.ExcelFile(io.BytesIO(buffer), engine='openpyxl')
