In [None]:
#default_exp datasets

In [None]:
#export
import os
import shutil
import requests
import tarfile
import torch
from dl4to.datasets import TopoDataset, CSVConverter

# CSV dataset

In [None]:
#export

class CSVDataset(TopoDataset):
    """
    A class for downloading, generating and importing datasets from CSV files. An inheriting class needs to override the method `_get_gz_file_paths_dict` for correct paths to the `.tar.gz` files that contain the `csv` files (see e.g. SELTO datasets).

    On first use the `.tar.gz` archive is downloaded (or taken from the target directory if already present). It is then extracted into a temporary `csv` folder and converted into `.pt` files by `CSVConverter`. Later instantiations load the existing `.pt` files directly.
    """
    # (connect, read) timeouts in seconds for the archive download. The read timeout
    # applies between received chunks, not to the whole transfer. Subclasses may override.
    _DOWNLOAD_TIMEOUT = (30, 300)

    def __init__(
        self,
        root:str, # The root directory in which the datasets should be downloaded, generated and accessed.
        name:str, # The name of the dataset that should be downloaded.
        train:bool=True, # Whether the training or validation dataset should be generated.
        size:int=-1, # The size of the dataset. If `size=-1`, then the whole dataset is imported. Useful if only subsets of the original dataset are needed.
        download:bool=True, # Whether the dataset should be downloaded, if needed.
        verbose:bool=True, # Whether to give the user feedback on the progress.
        dtype:torch.dtype=torch.float32, # The datatype into which the values from the csv files are converted.
        pde_solver:"dl4to.pde.PDESolver"=None, # The PDE solver that is used to solve the PDE for linear elasticity. Only has an effect if either `solve_pde_for_trivial_solution=True` or `solve_pde_for_gt_solution=True`.
        solve_pde_for_trivial_solution:bool=False, # Whether to solve the PDE for each trivial solution and save the displacements in the solution object. These can later be accessed via `problem.trivial_solution.u`. This is useful if PDE preprocessing is used. Requires a PDE solver.
        solve_pde_for_gt_solution:bool=False # Whether to solve the PDE for each ground truth and save the displacements in the solution object. These can later be accessed via `gt_solution.u`. Requires a PDE solver.
    ):

        dataset_name = self._get_dataset_name(name, train)
        self._dtype = dtype
        super().__init__(
            name=dataset_name,
            verbose=verbose
        )
        self._size = size
        # `_pt_dir_path` must be set before `_create_dirs`, which creates it.
        self._pt_dir_path = self._get_pt_dir_path(train, root, name)
        self._create_dirs(root, name)
        self.pt_file_paths = self._get_pt_file_paths()

        # No converted `.pt` files yet: download/extract/convert, then rescan.
        if len(self.pt_file_paths) == 0:
            self._generate_dataset(
                dataset_name=dataset_name,
                download=download,
                dtype=dtype,
                verbose=verbose,
                pde_solver=pde_solver,
                solve_pde_for_trivial_solution=solve_pde_for_trivial_solution,
                solve_pde_for_gt_solution=solve_pde_for_gt_solution
            )
            self.pt_file_paths = self._get_pt_file_paths()

        self._load_dataset()


    @property
    def pt_dir_path(self):
        """Directory holding the `.pt` files (and, transiently, the archive and `csv` folder)."""
        return self._pt_dir_path


    @property
    def dtype(self):
        """The datatype passed to the CSV conversion."""
        return self._dtype


    def _generate_dataset(self, dataset_name, download, dtype, verbose, pde_solver,
                          solve_pde_for_trivial_solution, solve_pde_for_gt_solution):
        """
        Produce the `.pt` files in `pt_dir_path` from the dataset's `.tar.gz` archive.

        Uses the single `.gz` file already present in `pt_dir_path`. If there is none, it is downloaded when `download=True`.

        Raises `AttributeError` if no archive exists and `download=False`, or if more than one `.gz` file is present.
        """
        gz_file_paths = [
            f'{self.pt_dir_path}/{file_name}'
            for file_name in sorted(os.listdir(self.pt_dir_path))
            if file_name.endswith('.gz')
        ]

        if len(gz_file_paths) == 0:
            if download:
                gz_file_path = self._download_gz_file(dataset_name)
            else:
                raise AttributeError('Dataset cannot be constructed with `download=False`.')
        else:
            if len(gz_file_paths) != 1:
                raise AttributeError('Directory contains more than one `.gz` file.')
            gz_file_path = gz_file_paths[0]

        csv_dir_path = self._extract_gz_file(gz_file_path, dataset_name)

        try:
            csv_converter = CSVConverter(
                csv_dir_path=csv_dir_path,
                dtype=dtype,
                verbose=verbose,
                pde_solver=pde_solver,
                solve_pde_for_trivial_solution=solve_pde_for_trivial_solution,
                solve_pde_for_gt_solution=solve_pde_for_gt_solution
            )
            csv_converter(self.pt_dir_path)
        finally:
            # The extracted CSVs are only an intermediate; remove them even if conversion fails.
            shutil.rmtree(csv_dir_path, ignore_errors=True)


    def _get_pt_file_paths(self):
        """
        Return the paths of the `.pt` files in `pt_dir_path`, truncated to `self.size`.

        The paths are sorted by file name, so the subset selected by `size` is reproducible. If `size` is -1 or 0, it is replaced by the number of files found.
        """
        pt_file_paths = [
            f"{self.pt_dir_path}/{file_name}"
            for file_name in sorted(os.listdir(self.pt_dir_path))
            if file_name.endswith('.pt')
        ]

        # -1 means "everything". 0 also appears here after a first scan of an
        # empty directory, so that the rescan after generation takes all files.
        if self.size == -1 or self.size == 0:
            self._size = len(pt_file_paths)

        pt_file_paths = pt_file_paths[:self.size]

        if self.verbose:
            print(f"Found {len(pt_file_paths)} files.")

        return pt_file_paths


    def _load_dataset(self):
        """Load every file in `self.pt_file_paths` with `torch.load` into the list `self.dataset`."""
        self.dataset = []
        # Bind unconditionally; previously this name only existed when `verbose=True`.
        pt_file_paths = self.pt_file_paths

        if self.verbose:
            print('importing dataset...')
            # NOTE(review): `tqdm` is not imported in this notebook cell. It is assumed
            # to be available in the exported module's namespace — confirm.
            pt_file_paths = tqdm(pt_file_paths)

        for pt_file_path in pt_file_paths:
            # NOTE(review): torch>=2.6 defaults `torch.load(weights_only=True)`, which may
            # reject pickled non-tensor objects written by `CSVConverter` — confirm.
            self.dataset.append(torch.load(pt_file_path))

        if self.verbose:
            print('done!')


    def _get_pt_dir_path(self, train, root, name):
        """Return `{root}/{name}/train` or `{root}/{name}/test`, depending on `train`."""
        if train:
            return f'{root}/{name}/train'
        return f'{root}/{name}/test'


    def _get_dataset_name(self, name, train):
        """Return `{name}_train` or `{name}_test`; also used as the key into `_get_gz_file_paths_dict()`."""
        if train:
            return f'{name}_train'
        return f'{name}_test'


    def _create_dirs(self, root, name):
        """Create `root`, `{root}/{name}` and `pt_dir_path` if they do not exist yet."""
        os.makedirs(root, exist_ok=True)
        os.makedirs(f'{root}/{name}', exist_ok=True)
        os.makedirs(self.pt_dir_path, exist_ok=True)


    def _get_gz_file_paths_dict(self):
        """
        Return a mapping from dataset name (as built by `_get_dataset_name`, e.g. `'<name>_train'`) to the URL of its `.tar.gz` archive.

        Must be overridden by subclasses.
        """
        raise NotImplementedError("Must be overridden.")


    def _download_gz_file(self, dataset_name):
        """
        Download the archive for `dataset_name` to `{pt_dir_path}/{dataset_name}.tar.gz` and return that path.

        Raises `KeyError` if `dataset_name` is not in `_get_gz_file_paths_dict()`. Raises `requests.HTTPError` on a non-success HTTP status.
        """
        gz_file_path = f'{self.pt_dir_path}/{dataset_name}.tar.gz'
        resources = self._get_gz_file_paths_dict()
        url = resources[dataset_name]

        if self.verbose:
            print(f"Downloading {dataset_name}...")

        # Stream into a temporary name that `_generate_dataset` ignores (it does not end
        # in '.gz'). Only rename it once the transfer completes, so a failed or
        # interrupted download never leaves a truncated archive that would be reused later.
        part_file_path = f'{gz_file_path}.part'
        try:
            with requests.get(url, stream=True, timeout=self._DOWNLOAD_TIMEOUT) as response:
                response.raise_for_status()
                with open(part_file_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            os.replace(part_file_path, gz_file_path)
        finally:
            # After a successful `os.replace` the temporary file no longer exists.
            if os.path.exists(part_file_path):
                os.remove(part_file_path)

        if self.verbose:
            print("Done!")
        return gz_file_path


    def _extract_gz_file(self, gz_file_path, dataset_name):
        """
        Extract the `.tar.gz` archive into a fresh `{pt_dir_path}/csv` directory and return its path.

        Any existing `csv` directory is removed first, so leftovers from an aborted run are not mixed in.
        """
        with tarfile.open(gz_file_path, 'r:gz') as tar:
            csv_dir_path = f'{self.pt_dir_path}/csv'

            if os.path.exists(csv_dir_path):
                shutil.rmtree(csv_dir_path)

            os.mkdir(csv_dir_path)

            if self.verbose:
                print(f"Extracting {dataset_name}...")

            # The archive comes from the network. Where supported, reject members with
            # absolute paths, `..` components or links escaping the target directory.
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(path=csv_dir_path, filter='data')
            else:
                tar.extractall(path=csv_dir_path)

        if self.verbose:
            print("Done!")
        return csv_dir_path