In [None]:
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.patches as mpatches
import matplotlib.path as mpath
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.art3d as art3d
import numpy as np
import pandas as pd
from matplotlib.patches import Rectangle, Polygon, PathPatch
from mpl_toolkits.mplot3d import Axes3D
%matplotlib widget
    
    
    
"""
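Usage example (`test_positions_gt` and `filepath_to_3D_csv` are defined outside this cell):
marker_ids_to_connect = test_positions_gt.marker_ids_to_connect_in_3D_plot
Visualization(frame_idx = 0, filepath_to_3D_csv = filepath_to_3D_csv, marker_ids_to_connect = marker_ids_to_connect)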
"""
 
class Visualization():
    """Plot one frame of 3D marker positions read from a CSV file.

    On construction the CSV is loaded, the marker (bodypart) names are derived
    from the column names, and a matplotlib 3D figure is drawn containing the
    markers, their labels, the requested connecting lines and a simple maze
    outline. The figure is either shown or rendered into an RGB array.

    Attributes:
        filepath_to_3D_csv: path of the CSV file that was read.
        df: the CSV content as a pandas DataFrame.
        bodyparts: marker names, in order of first appearance in the columns.
        frame: RGB image (height, width, 3) of the plot when ``return_frame``
            was True, otherwise None.
    """

    # Column-name prefixes that do not denote a bodypart.
    _NON_BODYPART_PREFIXES = frozenset(['M', 'center', 'fnum'])
    # NOTE(review): assumes coordinate columns are named '<bodypart>_x',
    # '<bodypart>_y' and '<bodypart>_z' -- confirm against the CSV writer.
    _COORDINATE_SUFFIXES = ('x', 'y', 'z')
    _SUPPORTED_PARADIGMS = ('OTR', 'OTT', 'OTE')

    def __init__(self, frame_idx: int, filepath_to_3D_csv: Path, marker_ids_to_connect: List[Tuple[str]], paradigm: Optional[str] = None, return_frame: bool = False):
        """Load the CSV and draw the requested frame.

        Args:
            frame_idx: position of the row (frame) to plot.
            filepath_to_3D_csv: CSV file holding the 3D marker coordinates.
            marker_ids_to_connect: groups of marker names; the markers of each
                group are connected by one line.
            paradigm: 'OTR', 'OTT', 'OTE' or None; selects the side walls that
                are added to the maze outline.
            return_frame: if True, render the plot into ``self.frame`` instead
                of showing it.
        """
        self.filepath_to_3D_csv = filepath_to_3D_csv
        self._read_csv()
        # __init__ cannot return the rendered image, so it is kept on the instance.
        self.frame = self._show_3D_plot(frame_idx=frame_idx, marker_ids_to_connect=marker_ids_to_connect, paradigm=paradigm, return_frame=return_frame)

    def _read_csv(self) -> None:
        """Read the CSV into ``self.df`` and collect the bodypart names.

        A bodypart name is the part of a column name before the first
        underscore; the prefixes 'M', 'center' and 'fnum' are skipped.
        """
        self.df = pd.read_csv(self.filepath_to_3D_csv)
        self.bodyparts = []
        for key in self.df.keys():
            bodypart = key.split('_')[0]
            if bodypart in self._NON_BODYPART_PREFIXES or bodypart in self.bodyparts:
                continue
            self.bodyparts.append(bodypart)

    def _get_frame_coordinates(self, frame_idx: int) -> np.ndarray:
        """Return the coordinates of all bodyparts in one frame.

        Args:
            frame_idx: row position in ``self.df``.
                NOTE(review): treated as a positional index, not as a value
                of an 'fnum' column -- confirm this is what callers expect.

        Returns:
            Float array of shape (len(self.bodyparts), 3); row ``i`` belongs to
            ``self.bodyparts[i]`` and the columns are x, y, z.

        Raises:
            KeyError: if a coordinate column of a bodypart is missing.
            IndexError: if ``frame_idx`` is outside the table.
        """
        columns = [f'{bodypart}_{suffix}' for bodypart in self.bodyparts for suffix in self._COORDINATE_SUFFIXES]
        missing = [column for column in columns if column not in self.df.columns]
        if missing:
            raise KeyError(f'Missing coordinate columns in {self.filepath_to_3D_csv}: {missing}')
        # Select the columns first so the row keeps a numeric dtype.
        coordinates = self.df[columns].iloc[frame_idx].to_numpy(dtype=float)
        return coordinates.reshape(len(self.bodyparts), len(self._COORDINATE_SUFFIXES))

    def _show_3D_plot(self, frame_idx: int, marker_ids_to_connect: List[Tuple[str]], paradigm: Optional[str] = None, return_frame: bool = False) -> Optional[np.ndarray]:
        """Draw markers, labels, connections and the maze outline for one frame.

        Args:
            frame_idx: row position of the frame to plot.
            marker_ids_to_connect: groups of marker names to connect by lines.
            paradigm: forwarded to ``_add_maze_shape``.
            return_frame: if True, return the rendered image and close the
                figure; otherwise show the figure.

        Returns:
            uint8 RGB array of shape (height, width, 3) if ``return_frame`` is
            True, otherwise None.
        """
        p3d = self._get_frame_coordinates(frame_idx)
        fig = plt.figure(figsize=(8, 6))
        ax = fig.add_subplot(111, projection='3d')
        # Fully transparent points on the corners of a cube fix the axis limits.
        ax.scatter([-25, -25, 55, 55, -25, -25, 55, 55], [-25, 55, 55, -25, -25, 55, 55, -25], [-25, -25, -25, -25, 55, 55, 55, 55], s=100, c='white', alpha=0)
        ax.scatter(p3d[:, 0], p3d[:, 1], p3d[:, 2], c='black', s=100)
        self._connect_all_marker_ids(ax=ax, points=p3d, scheme=marker_ids_to_connect)
        for (x, y, z), bodypart in zip(p3d, self.bodyparts):
            # A marker without a finite position has no point to attach a label to.
            if np.isfinite([x, y, z]).all():
                ax.text(x, y + 0.01, z, bodypart, size=9)
        self._add_maze_shape(paradigm=paradigm, ax=ax)

        if not return_frame:
            plt.show()
            return None

        # https://stackoverflow.com/questions/35355930/matplotlib-figure-to-image-as-a-numpy-array
        ax.axis('off')
        fig.tight_layout(pad=0)
        ax.margins(0)
        fig.canvas.draw()
        # buffer_rgba() replaces the removed tostring_rgb(); drop the alpha
        # channel and copy so the array outlives the canvas.
        frame = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
        # The figure is not reachable by the caller; close it so that rendering
        # many frames does not accumulate open figures.
        plt.close(fig)
        return frame

    def _connect_all_marker_ids(self, ax: "Axes3D", points: np.ndarray, scheme: List[Tuple[str]]) -> List[List["art3d.Line3D"]]:
        """Draw one line per marker group of ``scheme``.

        Args:
            ax: 3D axes to draw on.
            points: (n_bodyparts, 3) coordinates ordered like ``self.bodyparts``.
            scheme: groups of marker names; each group becomes one line.

        Returns:
            For every group, the list of line artists returned by ``ax.plot``.
        """
        cmap = plt.get_cmap('tab10')
        bp_dict = {bodypart: idx for idx, bodypart in enumerate(self.bodyparts)}
        return [
            # Cycle through the colormap instead of clamping to its last colour.
            self._connect_one_set_of_marker_ids(ax=ax, points=points, bps=bps, bp_dict=bp_dict, color=cmap(i % cmap.N)[:3])
            for i, bps in enumerate(scheme)
        ]

    def _connect_one_set_of_marker_ids(self, ax: "Axes3D", points: np.ndarray, bps: List[str], bp_dict: Dict[str, int], color: Tuple[float, ...]) -> List["art3d.Line3D"]:
        """Draw a single line through the markers named in ``bps``.

        Args:
            ax: 3D axes to draw on.
            points: (n_bodyparts, 3) coordinates.
            bps: marker names in the order in which they are connected.
            bp_dict: maps a marker name to its row in ``points``.
            color: RGB colour of the line.

        Returns:
            The list of line artists returned by ``ax.plot``.

        Raises:
            KeyError: if a name in ``bps`` is not a key of ``bp_dict``.
        """
        ixs = [bp_dict[bp] for bp in bps]
        return ax.plot(points[ixs, 0], points[ixs, 1], points[ixs, 2], color=color)

    def _make_side_patch(self, paradigm: str) -> "mpatches.Patch":
        """Create one red, semi-transparent 2D side-wall patch for ``paradigm``.

        Raises:
            ValueError: if ``paradigm`` is not 'OTR', 'OTT' or 'OTE'.
        """
        if paradigm == 'OTR':
            return Rectangle((0, 0), 35, 30, color='red', alpha=0.4)
        if paradigm == 'OTT':
            return Polygon(np.array([[0, 0], [0, 30], [30, 0]]), closed=True, color='red', alpha=0.4)
        if paradigm == 'OTE':
            MPath = mpath.Path
            path_data = [
                (MPath.MOVETO, (0, 0)),
                (MPath.LINETO, (0, 30)),
                # One cubic Bezier segment: two control points plus end point.
                (MPath.CURVE4, (13, 11.0)),
                (MPath.CURVE4, (33.8, 2.1)),
                (MPath.CURVE4, (35, 1)),
                (MPath.LINETO, (35, 0)),
                (MPath.LINETO, (0, 0)),
            ]
            codes, verts = zip(*path_data)
            return mpatches.PathPatch(MPath(verts, codes), color='red', fill=True, alpha=0.4)
        raise ValueError(f'Unknown paradigm {paradigm!r}; expected one of {self._SUPPORTED_PARADIGMS} or None')

    def _add_maze_shape(self, paradigm: Optional[str] = None, ax: Optional["Axes3D"] = None) -> None:
        """Add the maze outline to the 3D axes.

        A gray base (in the z=0 plane) and a gray back wall (in the x=0 plane)
        are always drawn. If ``paradigm`` is given, two identical side walls
        are added in the planes y=0 and y=5.

        Args:
            paradigm: 'OTR', 'OTT', 'OTE' or None (no side walls).
            ax: 3D axes to draw on; defaults to the current axes.

        Raises:
            ValueError: if ``paradigm`` is an unsupported value.
        """
        side_patches = []
        if paradigm is not None:
            # Built before anything is drawn so an unsupported paradigm fails early.
            # A patch can only live on one plane, hence two separate instances.
            sideright = self._make_side_patch(paradigm)
            sideleft = self._make_side_patch(paradigm)
            side_patches = [(sideright, 0), (sideleft, 5)]

        if ax is None:
            ax = plt.gca()

        for patch, y_position in side_patches:
            ax.add_patch(patch)
            art3d.pathpatch_2d_to_3d(patch, z=y_position, zdir='y')

        base = Rectangle((0, 0), 50, 5, color='gray', alpha=0.1)
        ax.add_patch(base)
        art3d.pathpatch_2d_to_3d(base, z=0, zdir='z')
        sideback = Rectangle((0, 0), 5, 30, color='gray', alpha=1)
        ax.add_patch(sideback)
        art3d.pathpatch_2d_to_3d(sideback, z=0, zdir='x')
   


    