In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from glob import glob
import numpy as np
from numpy import matlib
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from src import plotting
from scipy import ndimage
from tqdm import tqdm
import warnings
import nibabel as nib

In [3]:
process = 'PlotWholeBrain'
# Project root. Override with the SIEEG_ANALYSIS_DIR environment variable so the
# notebook can run on a machine other than the original author's.
top_path = os.environ.get('SIEEG_ANALYSIS_DIR',
                          '/Users/emcmaho7/Dropbox/projects/SI_EEG/SIEEG_analysis')
input_path = f'{top_path}/data/interim'
raw_path = f'{top_path}/data/raw'
out_path = f'{input_path}/{process}'
figure_path = f'{top_path}/reports/figures/{process}'
Path(out_path).mkdir(parents=True, exist_ok=True)
Path(figure_path).mkdir(parents=True, exist_ok=True)

plotting_subj = 2       # subject whose whole-brain maps are rendered below
start, end = -100, 750  # inclusive time window (ms) of frames to render

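## fMRI Whole Brain

Project each time-resolved whole-brain map (`data/interim/fMRIWholeBrain`) onto the subject's inflated cortical surface. Render ventral, medial and lateral views of each hemisphere, then stitch the frames into one MP4 per hemisphere and view.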

In [4]:
import nibabel as nib
import numpy as np
from nilearn import plotting, surface
import nibabel as nib
from src.tools import camera_switcher
import cv2
import os
import re
import warnings

# FreeSurfer resolves subject names (e.g. `--regheader sub-XX` in mri_vol2surf) under
# SUBJECTS_DIR. Point it at the project's FreeSurfer tree, which is the same
# `{raw_path}/freesurfer` directory that load_surf_mesh(raw_path, ...) reads from.
os.environ["SUBJECTS_DIR"] = f"{raw_path}/freesurfer"
# Keep a FreeSurfer install location the user already configured; otherwise use the default.
os.environ.setdefault("FREESURFER_HOME", "/Applications/freesurfer")


def compute_surf_stats(prefix, sub, hemi):
    """Project a volumetric statistic map onto one hemisphere's surface and load it.

    FreeSurfer's ``mri_vol2surf`` resamples ``{prefix}.nii.gz`` into
    ``{prefix}_hemi-{hemi}.mgz``. The conversion is skipped when that file already
    exists, so repeated runs reuse earlier results.

    Parameters
    ----------
    prefix : str
        Path of the input volume without its ``.nii.gz`` extension.
    sub : str
        Subject label; ``sub-{sub}`` is passed to ``--regheader`` and must exist
        under ``$SUBJECTS_DIR``.
    hemi : str
        Hemisphere passed to ``--hemi`` (``'lh'`` or ``'rh'``).

    Returns
    -------
    Per-vertex data as returned by ``surface.load_surf_data``.

    Raises
    ------
    FileNotFoundError
        If the ``mri_vol2surf`` executable cannot be found.
    RuntimeError
        If ``mri_vol2surf`` exits non-zero or does not produce the output file.
    """
    import subprocess

    file = f'{prefix}_hemi-{hemi}.mgz'
    if not os.path.exists(file):
        fs_home = os.environ.get('FREESURFER_HOME', '/Applications/freesurfer')
        # Argument list, no shell: paths with spaces or shell metacharacters are safe.
        cmd = [os.path.join(fs_home, 'bin', 'mri_vol2surf'),
               '--src', f'{prefix}.nii.gz',
               '--out', file,
               '--regheader', f'sub-{sub}',
               '--hemi', hemi,
               '--projfrac', '1']
        # Capture (rather than discard) output so a failed conversion is reported here
        # instead of surfacing later as a confusing "file not found" on load.
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0 or not os.path.exists(file):
            raise RuntimeError(
                f'mri_vol2surf failed (exit code {result.returncode}) for '
                f'{prefix}.nii.gz:\n{result.stderr or result.stdout}')
    return surface.load_surf_data(file)


def load_surf_mesh(path, sub, hemi):
    """Build the FreeSurfer surface file paths used to plot one hemisphere.

    Nothing is read from disk; only path strings are assembled.

    Returns
    -------
    (inflated, sulc) : tuple of str
        ``{path}/freesurfer/sub-{sub}/surf/{hemi}.inflated`` and
        ``{path}/freesurfer/sub-{sub}/surf/{hemi}.sulc``.
    """
    surf_dir = f'{path}/freesurfer/sub-{sub}/surf'
    inflated_file = f'{surf_dir}/{hemi}.inflated'
    sulc_file = f'{surf_dir}/{hemi}.sulc'
    return inflated_file, sulc_file


def plot_stats(surf_mesh, bg_map, surf_map, hemi_, figure_prefix,
               vmax=0.3, threshold=1e-6,
               title=None, cmap_name='magma',
               views=('ventral', 'medial', 'lateral')):
    """Render a per-vertex statistic map in several views and save each view as a PNG.

    Uses ``plotting.plot_surf_roi`` with the plotly engine. ``plotting`` must be
    ``nilearn.plotting`` (in this notebook it is imported after, and shadows,
    ``src.plotting``).

    Parameters
    ----------
    surf_mesh : str
        Surface mesh file (e.g. the inflated surface from ``load_surf_mesh``).
    bg_map : str
        Background map drawn beneath the statistics (e.g. the ``.sulc`` file).
    surf_map : array-like
        Per-vertex values to colour.
    hemi_ : str
        ``'lh'`` for left; any other value is plotted as the right hemisphere.
    figure_prefix : str
        Output prefix; images go to ``{figure_prefix}_view-{view}_hemi-{hemi_}.png``.
    vmax : float
        Upper colour limit; the lower limit is fixed at 0.
    threshold : float
        Forwarded to ``plot_surf_roi``, which (per the nilearn docs) leaves values
        below it uncoloured.
    title : str or None
        Figure title.
    cmap_name : str
        Colormap name resolved through ``sns.color_palette``.
    views : iterable of str
        Views to render; each must be understood by nilearn and ``camera_switcher``.
    """
    cmap = sns.color_palette(cmap_name, as_cmap=True)
    hemi_name = 'left' if hemi_ == 'lh' else 'right'

    # Tick positions and labels must have equal length. Previously 2 positions were
    # paired with 4 labels, so the tick at vmax was labelled vmax/3.
    tick_vals = np.linspace(0, vmax, 4)
    colorbar = dict(title='Explained variance ($r^2$)',
                    tickvals=tick_vals.tolist(),
                    ticktext=[f'{v:.2f}' for v in tick_vals],
                    thickness=25,
                    len=0.75,
                    x=1.02)

    for view in views:
        with warnings.catch_warnings():
            # Passing both vmin and threshold makes nilearn warn and force vmin to 0,
            # which is the intended lower limit anyway.
            warnings.filterwarnings("ignore", message="choosing both vmin and a threshold is not allowed; setting vmin to 0")
            fig = plotting.plot_surf_roi(surf_mesh=surf_mesh,
                                         roi_map=surf_map,
                                         bg_map=bg_map,
                                         vmax=vmax,
                                         vmin=0.,
                                         threshold=threshold,
                                         engine='plotly',
                                         colorbar=True,
                                         view=view,
                                         cmap=cmap,
                                         title=title,
                                         title_font_size=30,
                                         hemi=hemi_name,
                                         # NOTE(review): this forwards one keyword literally named
                                         # `kwargs`, not `symmetric_cmap`; it may have no effect.
                                         # Confirm against the installed nilearn version.
                                         kwargs={'symmetric_cmap': False})
        fig.figure.update_layout(coloraxis_colorbar=colorbar,
                                 scene_camera=camera_switcher(hemi_, view))
        fig.figure.write_image(f'{figure_prefix}_view-{view}_hemi-{hemi_}.png')


def _frame_timepoint(image_path):
    """Return N from a ``timepoint-N`` tag in the file name, or None if absent."""
    match = re.search(r"timepoint-(\d+)", os.path.basename(image_path))
    return int(match.group(1)) if match else None


def pngs_to_mp4(input_folder, output_file,
                file_pattern, fps=9, frame_size=None,
                start_frame=None, end_frame=None):
    """
    Converts a series of PNG images in a folder into an MP4 video file.

    Parameters:
    - input_folder: Path to the folder containing PNG images.
    - output_file: Path to save the MP4 video file.
    - file_pattern: Glob pattern, relative to input_folder, selecting the frames.
    - fps: Frames per second in the output video.
    - frame_size: Tuple of (width, height) for the video frame size. If None, the size of the first image is used.
    - start_frame, end_frame: Optional exclusive bounds on the ``timepoint-N`` number in
      each file name; only frames with start_frame < N < end_frame are kept. Either
      bound may be given on its own.

    Frames are ordered numerically by their ``timepoint-N`` number (so timepoint-10
    follows timepoint-9 even without zero padding). Files without a tag come last,
    in name order.

    Raises:
    - FileNotFoundError: no images remain after globbing and filtering.
    - ValueError: a bound is given but a file name has no ``timepoint-N`` tag, or an
      image cannot be read.
    - RuntimeError: the video writer cannot be opened for output_file.
    """
    images = glob(os.path.join(input_folder, file_pattern))

    if start_frame is not None or end_frame is not None:
        kept = []
        for image in images:
            timepoint = _frame_timepoint(image)
            if timepoint is None:
                raise ValueError(f'No "timepoint-N" tag in {image}; cannot filter by frame.')
            if start_frame is not None and timepoint <= start_frame:
                continue
            if end_frame is not None and timepoint >= end_frame:
                continue
            kept.append(image)
        images = kept

    if not images:
        raise FileNotFoundError(f'No images matching {file_pattern!r} in {input_folder}')

    def sort_key(path):
        timepoint = _frame_timepoint(path)
        return (timepoint is None, -1 if timepoint is None else timepoint, path)

    images.sort(key=sort_key)

    # Define the codec and create VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec used to compress the frames
    if not frame_size:
        # If frame size is not specified, use the first image to determine the size
        first_image = cv2.imread(images[0])
        if first_image is None:
            raise ValueError(f'Could not read image {images[0]}')
        frame_size = (first_image.shape[1], first_image.shape[0])
    out = cv2.VideoWriter(output_file, fourcc, fps, frame_size)
    if not out.isOpened():
        out.release()
        raise RuntimeError(f'Could not open video writer for {output_file}')

    try:
        for image_path in images:
            img = cv2.imread(image_path)
            if img is None:  # cv2.imread signals failure by returning None
                raise ValueError(f'Could not read image {image_path}')
            # Resize the image to match the frame size, if necessary
            if img.shape[1] != frame_size[0] or img.shape[0] != frame_size[1]:
                img = cv2.resize(img, frame_size)
            out.write(img)
    finally:
        # Finalize the container even if a frame fails, so the file handle is released.
        out.release()

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
hemis = ['lh', 'rh']
views = ['ventral', 'medial', 'lateral']
times = np.arange(-200, 1000, 10)  # ms; paired positionally with the sorted map files

subj_id = str(plotting_subj).zfill(2)


def _timepoint_key(path):
    """Sort key: the numeric ``timepoint-N`` tag when present, else the file name."""
    match = re.search(r'timepoint-(\d+)', os.path.basename(path))
    return (match is None, int(match.group(1)) if match else -1, path)


# Sort numerically so timepoint-10 follows timepoint-9 even without zero padding.
# Files are paired with `times` by position, so order matters.
files = sorted(glob(f'{input_path}/fMRIWholeBrain/sub-{subj_id}_time*.nii.gz'),
               key=_timepoint_key)
if len(files) != len(times):
    warnings.warn(f'Found {len(files)} whole-brain maps but {len(times)} time labels; '
                  'zip() pairs them by position and drops the surplus.')

for file, time in zip(files, times):
    if not start <= time <= end:
        continue
    # Strip the full '.nii.gz' extension (the glob guarantees it). split('.') would
    # also cut at any dot elsewhere in the path.
    stat_file = file[:-len('.nii.gz')]
    plot_file_prefix = os.path.basename(stat_file)
    # NOTE: plot_stats appends '_view-{view}_hemi-{hemi}.png' to this prefix, so saved
    # names contain '.png' mid-name. This is kept so frames already rendered in
    # out_path are not duplicated under a second naming scheme.
    plot_file = f'{out_path}/{plot_file_prefix}.png'
    for hemi in hemis:
        surf = compute_surf_stats(stat_file, subj_id, hemi)
        # NOTE(review): negative values are set to 0.01 rather than 0. Presumably this
        # keeps them visible instead of being treated as empty vertices; confirm intent.
        surf[surf < 0] = 0.01
        inflated, sulcus = load_surf_mesh(raw_path, subj_id, hemi)
        plot_stats(inflated, sulcus, surf, hemi, plot_file,
                   title=f'{time:.0f} ms', cmap_name='rocket',
                   views=views, vmax=.5)

KeyboardInterrupt: 

In [None]:
subj_tag = str(plotting_subj).zfill(2)
# Stitch the per-timepoint PNGs rendered above into one MP4 per hemisphere x view.
# NOTE(review): start_frame/end_frame are exclusive bounds on the file's timepoint-N
# number and are hard-coded here; confirm they agree with the start/end window (ms)
# set in the config cell.
for hemi, view in ((h, v) for h in hemis for v in views):
    output_file = f'{out_path}/sub-{subj_tag}_view-{view}_hemi-{hemi}.mp4'
    file_pattern = f'sub-{subj_tag}*view-{view}*hemi-{hemi}*.png'
    pngs_to_mp4(input_folder=out_path,
                output_file=output_file,
                file_pattern=file_pattern,
                fps=6,
                start_frame=6,
                end_frame=100)