Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add the gdrive/gee methods #539

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions sepal_ui/scripts/gdrive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""
Google Drive object providing a simple interface to interact with files from Gdrive.
"""

from typing import Optional, Union
import json
import io
from pathlib import Path

import ee
from apiclient import discovery
from google.oauth2.credentials import Credentials
from googleapiclient.http import MediaIoBaseDownload
from osgeo import gdal

import sepal_ui.scripts.decorator as sd

# NOTE(review): runs at import time; presumably sets up the Earth Engine
# session before GDrive is used — confirm against
# sepal_ui.scripts.decorator.init_ee
sd.init_ee()
class GDrive:
    """Small helper around the Google Drive v3 API.

    It lists, downloads and deletes the files of the Drive account bound to
    the access token found in the Earth Engine credentials file of the
    current user.
    """

    def __init__(self) -> None:
        """Initialize Earth Engine and build the Drive service."""

        # called for its side effect; the returned value is stored as-is to
        # keep the attribute available to existing callers
        self.initialize = ee.Initialize()

        # read the access token from the earthengine credentials file
        # (None when the file has no "access_token" key)
        self.access_token = json.loads(
            (Path.home() / ".config/earthengine/credentials").read_text()
        ).get("access_token")

        self.service = discovery.build(
            serviceName="drive",
            version="v3",
            cache_discovery=False,
            credentials=Credentials(self.access_token),
        )

    def get_all_items(self, mime_type: Optional[str] = "image/tiff") -> list:
        """Get all the items of the Gdrive matching a mime type.

        Files that are in the trashbin are excluded. Every result page is
        fetched, so the output is not capped to the first 1000 files.

        Args:
            mime_type (str, optional): the mime type to look for, by default Tif images

        Return:
            (list): the found items as dicts with 2 keys ('id' and 'name')
        """

        query = f"mimeType='{mime_type}' and trashed = false"

        items = []
        page_token = None
        while True:
            response = (
                self.service.files()
                .list(
                    q=query,
                    pageSize=1000,
                    fields="nextPageToken, files(id, name)",
                    pageToken=page_token,
                )
                .execute()
            )
            items.extend(response.get("files", []))

            # no token in the response means this was the last page
            page_token = response.get("nextPageToken")
            if not page_token:
                break

        return items

    def get_items(
        self, file_name: Union[str, Path], mime_type: str = "image/tiff"
    ) -> list:
        """Look for the file_name pattern in the user Gdrive files.

        Usually gee exports files using a tiling system, so the file name
        provided needs to be the one from the export description: every file
        whose name starts with it is returned.

        Args:
            file_name (str | pathlib.Path): the file name used during the exportation step
            mime_type (str, optional): the mime type to look for, by default Tif images

        Return:
            (list): the items ('id' and 'name') whose name starts with file_name
        """

        # str.startswith only accepts str prefixes, so cast pathlike inputs
        prefix = str(file_name)

        return [i for i in self.get_all_items(mime_type) if i["name"].startswith(prefix)]

    def delete_items(self, items: list) -> None:
        """Delete the items from Gdrive.

        Args:
            items (list): the list of items to delete, as returned by get_items
        """

        for i in items:
            self.service.files().delete(fileId=i["id"]).execute()

    def _download_file(self, file_id: str, destination: Path) -> None:
        """Stream one Gdrive file to destination, dropping partial files on failure."""

        request = self.service.files().get_media(fileId=file_id)

        try:
            # write chunks straight to disk instead of buffering the whole
            # file in memory
            with destination.open("wb") as fo:
                downloader = MediaIoBaseDownload(fo, request)
                done = False
                while not done:
                    _, done = downloader.next_chunk()
        except Exception:
            # don't leave a truncated file behind
            if destination.is_file():
                destination.unlink()
            raise

    def download_items(
        self,
        file_name: Union[str, Path],
        local_path: Union[str, Path],
        mime_type: str = "image/tiff",
        delete: Optional[bool] = False,
    ) -> Path:
        """Download from Gdrive all the files matching an equivalent get_items request.

        If the mime_type is "image/tiff" a vrt file gathering every downloaded
        file is created. The delete option removes the files from Gdrive once
        they are all downloaded.

        Args:
            file_name (str | pathlib.Path): the file name used during the exportation step
            local_path (pathlike object): the destination folder of the files, created if missing
            mime_type (str, optional): the mime type to look for, by default Tif images
            delete (bool, optional): whether the files need to be deleted from Gdrive once the download is finished. default to :code:`False`

        Return:
            (pathlib.Path): the path to the vrt for Tif images, else the path to the download folder

        Raises:
            Exception: if gdal returns no dataset for the vrt or if the vrt file is missing afterwards
        """

        # cast as path and make sure the destination exists
        local_path = Path(local_path)
        local_path.mkdir(parents=True, exist_ok=True)

        # get the items
        items = self.get_items(file_name, mime_type)

        # load them to the user workspace
        local_files = []
        for i in items:
            local_file = local_path / i["name"]
            self._download_file(i["id"], local_file)
            local_files.append(local_file)

        # delete the items ?
        if delete:
            self.delete_items(items)

        # nothing more to do when no vrt is requested
        if mime_type != "image/tiff":
            return local_path

        vrt_file = local_path / f"{file_name}.vrt"
        ds = gdal.BuildVRT(str(vrt_file), [str(f) for f in local_files])

        # gdal hands back None instead of raising when it cannot build the vrt
        if ds is None:
            raise Exception("one of the dataset was empty")

        ds.FlushCache()
        # release the handle so gdal finalizes the file
        del ds

        # check that the file was effectively created (gdal doesn't raise errors)
        if not vrt_file.is_file():
            raise Exception(f"the vrt {vrt_file} was not created")

        return vrt_file
160 changes: 160 additions & 0 deletions tests/test_gdrive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import tempfile
from itertools import product
from pathlib import Path

import pytest
import rasterio as rio
from google_drive_downloader import GoogleDriveDownloader as gdd
from googleapiclient.http import MediaFileUpload
from rasterio import windows

from sepal_ui.scripts import gdrive
from sepal_ui.scripts import utils as su


class TestGdrive:
    """Integration tests of the GDrive helper.

    They run against the Drive account bound to the current credentials: a
    DEM is downloaded, cut into tiles and the tiles are uploaded to a
    temporary Drive folder that the tests then query.
    """

    def test_get_all_items(self, drive, tmp_dem, gdrive_folder):

        list_items = drive.get_all_items()

        # at least the one I added manually
        assert len(list_items) >= 9

    def test_get_items(self, drive, tmp_dem, gdrive_folder):

        # extract name and folder
        _, test_file = tmp_dem

        list_items = drive.get_items(test_file.stem)

        assert len(list_items) == 9

    def test_download_items(self, drive, tmp_dem, gdrive_folder):

        # extract name and folder
        _, test_file = tmp_dem

        # extract all the files from the folder
        with tempfile.TemporaryDirectory() as loc_tmp_dir:

            drive.download_items(test_file.stem, loc_tmp_dir)

            loc_tmp_dir = Path(loc_tmp_dir)
            assert len(list(loc_tmp_dir.glob("*.tif"))) == 9
            assert len(list(loc_tmp_dir.glob("*.vrt"))) == 1

    def test_delete_items(self, drive, tmp_dem, gdrive_folder):

        # extract name and folder
        _, test_file = tmp_dem

        drive.delete_items(drive.get_items(test_file.stem))

        # assert
        assert drive.get_items(test_file.stem) == []

    @pytest.fixture(scope="class")
    def drive(self):
        """the GDrive instance shared by every test of the class"""

        # the module exposes a class, the Drive service lives on the instance
        return gdrive.GDrive()

    @pytest.fixture(scope="class")
    def gdrive_folder(self, drive, tmp_dem):
        """create a fake folder in my gdrive and run the test over it"""

        # extract name and folder
        tmp_dir, _ = tmp_dem

        # create a gdrive folder
        body = {
            "name": "test_sepal_ui",
            "mimeType": "application/vnd.google-apps.folder",
        }
        gdrive_folder = drive.service.files().create(body=body).execute()

        # send all the tile files to the gdrive folder
        # (the source DEM itself is the only file ending with "dem.tif")
        files = [f for f in tmp_dir.glob("*.tif") if not f.name.endswith("dem.tif")]
        for f in files:
            file_metadata = {"name": f.name, "parents": [gdrive_folder["id"]]}
            media = MediaFileUpload(str(f), mimetype="image/tiff")
            (
                drive.service.files()
                .create(body=file_metadata, media_body=media)
                .execute()
            )

        yield gdrive_folder

        # delete the folder
        drive.service.files().delete(fileId=gdrive_folder["id"]).execute()

    @pytest.fixture(scope="class")
    def tmp_dem(self):
        """the tmp dir containing the dem"""

        # create a tmp directory and save the DEM file inside
        with tempfile.TemporaryDirectory() as tmp_dir:

            tmp_dir = Path(tmp_dir)

            # save the file
            test_file = tmp_dir / f"{su.random_string(8)}_dem.tif"
            test_id = "1vRkAWQYsLWCi6vcTMk8vLxoXMFbdMFn8"
            gdd.download_file_from_google_drive(test_id, test_file, True, True)

            # cut the image in pieces
            with rio.open(test_file) as src:

                tile_width = int(src.meta["width"] / 2)
                tile_height = int(src.meta["height"] / 2)
                meta = src.meta.copy()

                for window, transform in self.get_tiles(src, tile_width, tile_height):

                    meta["transform"] = transform
                    meta["width"], meta["height"] = window.width, window.height
                    outpath = (
                        tmp_dir
                        / f"{test_file.stem}_{window.col_off}_{window.row_off}.tif"
                    )
                    with rio.open(outpath, "w", **meta) as dst:
                        dst.write(src.read(window=window))

            # the directory is removed when the fixture is torn down and the
            # with block exits
            yield tmp_dir, test_file

    @staticmethod
    def get_tiles(ds, width, height):
        """
        Cut an image in pieces according to the specified width and height

        Tiles on the right and bottom edges are clipped to the image extent.

        Args:
            ds: dataset
            width: the width of the tile
            height: the height of the tile

        Yield:
            (window, transform): the tuple of the window characteristics corresponding to each tile
        """
        ncols, nrows = ds.meta["width"], ds.meta["height"]

        offsets = product(range(0, ncols, width), range(0, nrows, height))
        big_window = windows.Window(col_off=0, row_off=0, width=ncols, height=nrows)
        for col_off, row_off in offsets:
            window = windows.Window(
                col_off=col_off, row_off=row_off, width=width, height=height
            ).intersection(big_window)
            transform = windows.transform(window, ds.transform)
            yield window, transform