Skip to content

Commit

Permalink
feat: type hints in utils
Browse files Browse the repository at this point in the history
  • Loading branch information
bhosale2 committed Dec 6, 2022
1 parent ba70e21 commit 16b2e35
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 41 deletions.
55 changes: 29 additions & 26 deletions sopht/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import h5py
import numpy as np
from elastica.rod.cosserat_rod import CosseratRod
from typing import Dict, Type, List


class IO:
Expand Down Expand Up @@ -29,7 +30,7 @@ class IO:
(Fields) (Fields) (Fields) (Fields)
"""

def __init__(self, dim, real_dtype=np.float64):
def __init__(self, dim: int, real_dtype: Type = np.float64) -> None:
"""Class initializer."""
self.dim = dim
assert self.dim == 2 or self.dim == 3, "Invalid dimension (only 2D and 3D)"
Expand All @@ -39,23 +40,23 @@ def __init__(self, dim, real_dtype=np.float64):
# Initialize dictionaries for fields for IO and their
# corresponding field_type ('Scalar' or 'Vector') Eulerian grid
self.eulerian_grid_defined = False
self.eulerian_fields = {}
self.eulerian_fields_type = {}
self.eulerian_fields: Dict = {}
self.eulerian_fields_type: Dict = {}

# Lagrangian grid
self.lagrangian_fields = {}
self.lagrangian_fields_type = {}
self.lagrangian_grids = {}
self.lagrangian_fields_with_grid_name = {}
self.lagrangian_fields: Dict = {}
self.lagrangian_fields_type: Dict = {}
self.lagrangian_grids: Dict = {}
self.lagrangian_fields_with_grid_name: Dict = {}
self.lagrangian_grid_count = 0
self.lagrangian_grid_connection = {}
self.lagrangian_grid_connection: Dict = {}

def define_eulerian_grid(
self,
origin,
dx,
grid_size,
):
origin: np.ndarray,
dx: np.ndarray,
grid_size: np.ndarray,
) -> None:
"""
Define the Eulerian grid mesh.
Expand All @@ -79,7 +80,7 @@ def define_eulerian_grid(
self.eulerian_grid_size = grid_size # z,y,x
self.eulerian_grid_defined = True

def add_as_eulerian_fields_for_io(self, **fields_for_io):
def add_as_eulerian_fields_for_io(self, **fields_for_io) -> None:
"""Add Eulerian fields to be saved/loaded.
Eulerian grid needs to be defined first using `define_eulerian_grid(...)` call.
Expand Down Expand Up @@ -112,11 +113,11 @@ def add_as_eulerian_fields_for_io(self, **fields_for_io):

def add_as_lagrangian_fields_for_io(
self,
lagrangian_grid,
lagrangian_grid_name=None,
lagrangian_grid_connect=False,
lagrangian_grid: np.ndarray,
lagrangian_grid_name: str = None,
lagrangian_grid_connect: bool = False,
**fields_for_io,
):
) -> None:
"""
Add lagrangian fields to be saved/loaded.
Expand Down Expand Up @@ -175,7 +176,7 @@ def add_as_lagrangian_fields_for_io(
f"(scalar / vector) based on field dimension {field.shape}"
)

def save(self, h5_file_name, time=0.0): # noqa: C901
def save(self, h5_file_name: str, time: float = 0.0) -> None: # noqa: C901
"""
This is a wrapper function to call _save function.
Expand All @@ -189,7 +190,7 @@ def save(self, h5_file_name, time=0.0): # noqa: C901

self._save(h5_file_name, time)

def _save(self, h5_file_name, time=0.0): # noqa: C901
def _save(self, h5_file_name: str, time: float = 0.0) -> None: # noqa: C901
"""
Save added fields to hdf5 file.
Expand Down Expand Up @@ -301,7 +302,7 @@ def _save(self, h5_file_name, time=0.0): # noqa: C901
if self.lagrangian_fields:
self.generate_xdmf_lagrangian(h5_file_name=h5_file_name, time=time)

def load(self, h5_file_name): # noqa: C901
def load(self, h5_file_name: str) -> None: # noqa: C901
"""Load fields from hdf5 file.
Field arrays need to be allocated and added to `eulerian_fields` and/or
Expand All @@ -313,7 +314,7 @@ def load(self, h5_file_name): # noqa: C901
String containing name of the hdf5 file.
"""
with h5py.File(h5_file_name, "r") as f:
keys = []
keys: List = []
f.visit(keys.append)

# Load time
Expand Down Expand Up @@ -419,7 +420,7 @@ def load(self, h5_file_name): # noqa: C901

return time

def generate_xdmf_eulerian(self, h5_file_name, time=0.0):
def generate_xdmf_eulerian(self, h5_file_name: str, time: float = 0.0) -> None:
"""Generate XDMF description file for Eulerian fields.
Currently, the XDMF file is generated for Paraview visualization only.
Expand Down Expand Up @@ -515,7 +516,7 @@ def generate_field_entry(file_name, field_name, field_type):
with open(h5_file_name.replace(".h5", "_eulerian.xmf"), "w") as f:
f.write(xdmffile)

def generate_xdmf_lagrangian(self, h5_file_name, time):
def generate_xdmf_lagrangian(self, h5_file_name: str, time: float) -> None:
"""Generate XDMF description file for Lagrangian fields.
Currently, the XDMF file is generated for Paraview visualization only.
Expand Down Expand Up @@ -620,7 +621,9 @@ class CosseratRodIO(IO):
Derived IO class for Cosserat rod IO.
"""

def __init__(self, cosserat_rod: CosseratRod, dim, real_dtype=np.float64):
def __init__(
self, cosserat_rod: CosseratRod, dim: int, real_dtype: Type = np.float64
) -> None:
super().__init__(dim, real_dtype)
self.cosserat_rod = cosserat_rod

Expand All @@ -636,11 +639,11 @@ def __init__(self, cosserat_rod: CosseratRod, dim, real_dtype=np.float64):
lagrangian_grid_connect=True,
)

def save(self, h5_file_name, time=0.0):
def save(self, h5_file_name: str, time: float = 0.0) -> None:
self._update_rod_element_position()
self._save(h5_file_name=h5_file_name, time=time)

def _update_rod_element_position(self):
def _update_rod_element_position(self) -> None:
self.rod_element_position[...] = 0.5 * (
self.cosserat_rod.position_collection[: self.dim, 1:]
+ self.cosserat_rod.position_collection[: self.dim, :-1]
Expand Down
2 changes: 1 addition & 1 deletion sopht/utils/lab_cmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from matplotlib import cm


def get_lab_cmap(res: int = 256):
def get_lab_cmap(res: int = 256) -> ListedColormap:
"""Returns Custom map resembling Orange-Blue scheme"""
top = cm.get_cmap("Oranges", res)
bottom = cm.get_cmap("Blues", res)
Expand Down
13 changes: 10 additions & 3 deletions sopht/utils/plot_field.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
def create_figure_and_axes(fig_aspect_ratio=1.0):
from typing import Tuple
import matplotlib.pyplot as plt


def create_figure_and_axes(
fig_aspect_ratio: float = 1.0,
) -> Tuple[plt.Figure, plt.Axes]:
"""Creates figure and axes for plotting contour fields"""
import matplotlib.pyplot as plt

plt.style.use("seaborn")
fig = plt.figure(frameon=True, dpi=150)
Expand All @@ -12,7 +17,9 @@ def create_figure_and_axes(fig_aspect_ratio=1.0):
return fig, ax


def save_and_clear_fig(fig, ax, cbar=None, file_name=""):
def save_and_clear_fig(
fig: plt.Figure, ax: plt.Axes, cbar: plt.Figure.colorbar = None, file_name: str = ""
) -> None:
"""Save figure and clear for next iteration"""
fig.savefig(
file_name,
Expand Down
4 changes: 2 additions & 2 deletions sopht/utils/post_process.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
def make_video_from_image_series(
video_name: str, image_series_name: str, frame_rate: int
):
) -> None:
"""Makes a video using ffmpeg from series of images"""
import os

Expand All @@ -17,7 +17,7 @@ def make_video_from_image_series(
os.system(f"rm -f {image_series_name}*.png")


def make_dir_and_transfer_h5_data(dir_name: str, clean_dir: bool = True):
def make_dir_and_transfer_h5_data(dir_name: str, clean_dir: bool = True) -> None:
"""Makes a new directory and transfers h5 flow data files to the directory"""
import os

Expand Down
2 changes: 1 addition & 1 deletion sopht/utils/pyst_kernel_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_pyst_dtype(real_t: Type) -> str:

def get_pyst_kernel_config(
real_t: Type, num_threads: int, iteration_slice: Tuple = None
):
) -> ps.CreateKernelConfig:
"""Returns the pystencils kernel config based on the data
dtype and number of threads"""
pyst_dtype = get_pyst_dtype(real_t)
Expand Down
18 changes: 10 additions & 8 deletions sopht/utils/rod_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
from tqdm import tqdm
from matplotlib.patches import Circle
import matplotlib.animation as animation
from typing import Dict, Sequence
from typing import Dict, Sequence, Tuple
from sopht.utils.field import VectorField


def plot_video_of_rod_surface( # noqa C901
rods_history: Sequence[Dict],
video_name="video.mp4",
fps=60,
step=1,
fps: int = 60,
step: int = 1,
**kwargs,
):
) -> None:
plt.rcParams.update({"font.size": 22})
folder_name = kwargs.get("folder_name", "")
# 2d case <always 2d case for now>
Expand All @@ -24,14 +24,14 @@ def plot_video_of_rod_surface( # noqa C901
n_visualized_rods = len(rods_history) # should be one for now

# Rod info
def rod_history_unpacker(rod_idx, t_idx):
def rod_history_unpacker(rod_idx: int, t_idx: int) -> Tuple[np.ndarray, np.ndarray]:
return (
rods_history[rod_idx]["position"][t_idx],
rods_history[rod_idx]["radius"][t_idx],
)

# Rod center of mass
def com_history_unpacker(rod_idx):
def com_history_unpacker(rod_idx: int) -> np.ndarray:
return rods_history[rod_idx]["com"][time_idx]

# Generate target sphere data
Expand All @@ -41,7 +41,9 @@ def com_history_unpacker(rod_idx):
sphere_history = kwargs["sphere_history"]
n_visualized_spheres = len(sphere_history) # should be one for now

def sphere_history_unpacker(sph_idx, t_idx):
def sphere_history_unpacker(
sph_idx: int, t_idx: int
) -> Tuple[np.ndarray, np.ndarray]:
return (
sphere_history[sph_idx]["position"][t_idx],
sphere_history[sph_idx]["radius"][t_idx],
Expand All @@ -60,7 +62,7 @@ def sphere_history_unpacker(sph_idx, t_idx):
ylim = kwargs.get("y_limits", (-1.0, 1.0))
zlim = kwargs.get("z_limits", (-0.05, 1.0))

def difference(x):
def difference(x: Tuple[float, float]) -> float:
return x[1] - x[0]

max_axis_length = max(difference(xlim), difference(ylim))
Expand Down

0 comments on commit 16b2e35

Please sign in to comment.