# Imports

In [1]:
import numpy as np

import sys

import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib as mpl
from matplotlib.ticker import ScalarFormatter

from astropy import units as u
from astropy import constants as const
from astropy.table import QTable

from pathlib import Path
import os

sys.path.insert(0, '/home/emelie/Desktop/Codes/PETAR/PeTar-master/tools/')
from analysis import *

In [2]:
# Notebook-wide plot style: larger tick/axis/title/legend fonts and a grid on every Axes.
# (pyplot's rcParams is the same object as matplotlib.rcParams.)
mpl.rcParams.update({
    'xtick.labelsize': 13,
    'ytick.labelsize': 13,
    'axes.titlesize': 16,
    'axes.grid': True,
    'axes.labelsize': 15,
    'legend.fontsize': 13,
})

# Functions

## Plotting

### Plotting Means

In [3]:
def _plot_mean_panel(axis, time, values, ylabel, title, xlim, zero_line):
    """
    Draw one "mean quantity versus time" panel used by ``means_plots``.

    Parameters:
    -----------

    axis: matplotlib Axes
            Axes to draw on

    time: array
            Times for the x-axis [Myr]

    values: array
            Values plotted against ``time``

    ylabel: str
            Label of the y-axis

    title: str
            Title of the panel

    xlim: tuple
            (xmin, xmax) of the time axis

    zero_line: bool
            If a black horizontal line at 0 should be drawn behind the data
    """
    axis.plot(time, values, color='b', marker='o')
    axis.set_xlabel('Time [Myr]')
    axis.set_ylabel(ylabel)
    axis.set_title(title)
    if zero_line:
        axis.axhline(0, color='k', zorder=0)
    axis.set_xlim(*xlim)


def means_plots(which, header_array, mean, t_max, run, fig_width=14, fig_height=10, save=True):
    """
    Plot mean velocities ('vel') or mean positions ('pos') against time in a
    2x2 grid: ``mean`` in the top-left panel and the x, y and z components
    taken from ``header_array`` in the other three panels.

    Parameters:
    -----------

    which: str
            'vel' to plot velocities, 'pos' to plot positions

    header_array: 2d array
            Header values, one column per snapshot. Row 1 is used as the
            time [Myr], rows 2-4 as the mean x, y, z [kpc] and the last
            three rows as the mean vx, vy, vz [km/s]

    mean: 1d array
            Mean velocity [km/s] or mean r [kpc] per snapshot, plotted in
            the top-left panel

    t_max: int or float
            Maximum time of simulation [Myr]; the time axis spans
            [-t_max/10, 1.1*t_max]

    run: str
            Name of run, used in the output file name

    fig_width: int or float
            Width of figure

    fig_height: int or float
            Height of figure

    save: bool
            If the figure should be saved as Velocity_plots_{run}.png or
            Position_plots_{run}.png in the working directory

    Raises:
    -------

    ValueError
            If ``which`` is neither 'vel' nor 'pos'
    """
    # (values, ylabel, title, draw zero line) for panels [0,0], [0,1], [1,0], [1,1]
    if which == 'vel':
        panels = [
            (mean, r'$v_{mean}$ [km/s]', 'Mean velocity per timestep', False),
            (header_array[-3, :], r'$v_{mean, x}$ [km/s]', 'Mean velocity in x per timestep', True),
            (header_array[-2, :], r'$v_{mean, y}$ [km/s]', 'Mean velocity in y per timestep', True),
            (header_array[-1, :], r'$v_{mean, z}$ [km/s]', 'Mean velocity in z per timestep', True),
        ]
        filename = f'Velocity_plots_{run}.png'
    elif which == 'pos':
        panels = [
            (mean, r'$r_{mean}$ [kpc]', 'Mean r per timestep', False),
            (header_array[2, :], r'$x_{mean}$ [kpc]', 'Mean x per timestep', True),
            (header_array[3, :], r'$y_{mean}$ [kpc]', 'Mean y per timestep', True),
            (header_array[4, :], r'$z_{mean}$ [kpc]', 'Mean z per timestep', True),
        ]
        filename = f'Position_plots_{run}.png'
    else:
        # Previously an unknown value silently produced no plot at all
        raise ValueError(f"which must be 'vel' or 'pos', got {which!r}")

    # Pad the time axis by 10% of t_max on both sides so the end points are not drawn on the frame
    pad = t_max / 10
    xlim = (0 - pad, t_max + pad)
    time = header_array[1, :]

    fig, ax = plt.subplots(2, 2, figsize=(fig_width, fig_height))
    # ax.flat walks the grid row by row: [0,0], [0,1], [1,0], [1,1]
    for axis, (values, ylabel, title, zero_line) in zip(ax.flat, panels):
        _plot_mean_panel(axis, time, values, ylabel, title, xlim, zero_line)

    plt.tight_layout()

    if save:
        plt.savefig(filename, bbox_inches='tight')

    plt.show()
        

### Plotting snapshots

In [4]:
def snapshots_plot(data, snapshots, xlimits, ylimits, fsize, marker_size, run, save=False):
    """
    Plot particle positions at selected snapshots, one row per snapshot:
    the (x, z) view in the left column and the (x, y) view in the right
    column, with the Milky Way bulge and disc drawn as shaded patches and
    the origin marked with a cross.

    Parameters:
    -----------

    data: 3d array
            Particle data indexed as data[particle, column, snapshot];
            columns 1, 2 and 3 are plotted as x, y and z [kpc]

    snapshots: sequence of int
            Snapshot indices (last axis of ``data``) to plot, one row each

    xlimits: int or float
            The x-axis of every panel spans [-xlimits, xlimits]

    ylimits: int or float
            The vertical axis (z on the left, y on the right) spans
            [-ylimits, ylimits]

    fsize: tuple
            Figure size (width, height)

    marker_size: int or float
            Marker size of the particles

    run: str
            Name of run, used in the output file name

    save: bool
            If the figure should be saved as ./Plots/Snapshots_{run}.png
            (the Plots folder is created if it does not exist)
    """
    def _use_sci_ticks(axis):
        # Tick labels in scientific notation with a shared x10^n offset
        for sub_axis in (axis.xaxis, axis.yaxis):
            sub_axis.set_major_formatter(ScalarFormatter(useMathText=True))
            sub_axis.get_major_formatter().set_powerlimits((0, 0))

    # squeeze=False keeps ``ax`` 2d; with the default a single snapshot gives
    # a 1d array and ax[i, 0] raises IndexError
    fig, ax = plt.subplots(len(snapshots), 2, figsize=fsize, squeeze=False)

    for i, t in enumerate(snapshots):
        x, y, z = data[:, 1, t], data[:, 2, t], data[:, 3, t]

        # New patches for every row: an artist can belong to only one Axes
        bulge_xz = mpl.patches.Circle((0, 0), 1.500, color='b', alpha=0.25)  # Bulge (x, z) view, radius 1.5 kpc
        disc_xz = mpl.patches.Ellipse((0, 0), width=30.000, height=0.300, color='b', alpha=0.25)  # Disc (x, z) view, 30 x 0.3 kpc
        disc_xy = mpl.patches.Circle((0, 0), 15.000, color='b', alpha=0.25)  # Disc (x, y) view, radius 15 kpc

        # Left column: (x, z) view
        ax_xz = ax[i, 0]
        ax_xz.scatter(x, z, s=marker_size)
        ax_xz.scatter(0, 0, s=20, marker='x', color='black')
        ax_xz.set_xlabel('x [kpc]')
        ax_xz.set_ylabel('z [kpc]')
        ax_xz.set_xlim(-xlimits, xlimits)
        ax_xz.set_ylim(-ylimits, ylimits)
        ax_xz.add_patch(bulge_xz)
        ax_xz.add_patch(disc_xz)
        _use_sci_ticks(ax_xz)

        # Right column: (x, y) view
        ax_xy = ax[i, 1]
        ax_xy.scatter(x, y, s=marker_size)
        ax_xy.scatter(0, 0, s=20, marker='x', color='black')
        ax_xy.set_xlabel('x [kpc]')
        ax_xy.set_ylabel('y [kpc]')
        ax_xy.set_xlim(-xlimits, xlimits)
        ax_xy.set_ylim(-ylimits, ylimits)
        ax_xy.add_patch(disc_xy)
        _use_sci_ticks(ax_xy)

    plt.tight_layout()
    if save:
        # Relative to the working directory, like the ./Animations output of
        # the animation helpers; '/Plots/' would target the filesystem root
        Path('./Plots').mkdir(exist_ok=True)
        plt.savefig(f'./Plots/Snapshots_{run}.png')
    plt.show()

## Animations

### Disc animations

In [5]:
def anim_disc(data_input, data_header, tstep, nsteps, lims, fsize, tx, ty, format_type, fps, run):
    """
    Animate the (x, y) view of the particles with the Milky Way disc drawn
    as a shaded circle, and save it to ./Animations/Anim_{run}.{format_type}
    (the Animations folder is created if it does not exist).

    Parameters:
    -----------

    data_input: 3d array
            Particle data indexed as data_input[particle, column, snapshot];
            columns 1-3 are used as x, y, z [kpc] and columns 4-6 as
            vx, vy, vz [km/s]. The first particle's initial position and
            velocity are shown in the title

    data_header: array
            Not used by this function (anim_3d takes the same argument)

    tstep: int or float
            Time between snapshots, used for the time label t = frame*tstep

    nsteps: int
            Number of frames; snapshots 0 .. nsteps-1 are animated

    lims: int or float
            Both axes span [-lims, lims] [kpc]

    fsize: tuple
            Figure size (width, height)

    tx, ty: int or float
            Position of the time label in data coordinates

    format_type: str
            Format of the output animation, 'mp4' or 'gif'

    fps: int
            Frames per second in animation

    run: str
            Name of run, used in the output file name

    Raises:
    -------

    ValueError
            If format_type is neither 'mp4' nor 'gif'
    """
    writers = {'mp4': animation.FFMpegWriter, 'gif': animation.PillowWriter}
    # Fail before the (possibly long) rendering instead of silently saving nothing
    if format_type not in writers:
        raise ValueError(f"format_type must be 'mp4' or 'gif', got {format_type!r}")

    len_data = len(data_input)
    f_anim = np.arange(0, nsteps, 1)

    # Initial position and velocity of the first particle, shown in the title
    pos_in = np.round(data_input[0, 1:4, 0], 3)
    vel_in = np.round(data_input[0, 4:7, 0], 3)

    fig_anim, ax_anim = plt.subplots(figsize=fsize)

    scatter1, = ax_anim.plot([], [], c='b', marker='.', linestyle=' ')

    circle = mpl.patches.Circle((0, 0), 15.000, color='b', alpha=0.25)  # Disc, radius 15 kpc

    my_text = ax_anim.text(tx, ty, '', fontsize=20)

    def init_anim():
        # Static elements: origin marker, title, labels, limits and disc
        ax_anim.scatter(0, 0, s=30, marker='x', color='black')

        # Positions share the kpc units of the axes
        ax_anim.set_title(f'N:{len_data} Pos: {pos_in} kpc,' + '\n' + f' Vel: {vel_in} km/s')
        ax_anim.set_xlabel('x [kpc]')
        ax_anim.set_ylabel('y [kpc]')

        ax_anim.set_xlim(-lims, lims)
        ax_anim.set_ylim(-lims, lims)

        ax_anim.add_patch(circle)

        # NOTE(review): plt.show() here blocks with GUI backends during save;
        # kept so the notebook still displays the figure
        plt.show()

        return scatter1, my_text

    def update_anim(frame):
        my_text.set_text(f't={frame*tstep}')

        scatter1.set_data(data_input[:, 1, frame], data_input[:, 2, frame])

        # Return a flat tuple of artists, as FuncAnimation expects
        return scatter1, my_text

    animation_plots = animation.FuncAnimation(fig_anim, update_anim,
                                              frames=f_anim, init_func=init_anim)

    Path('./Animations').mkdir(exist_ok=True)
    writervideo = writers[format_type](fps=fps)
    animation_plots.save(f'./Animations/Anim_{run}.{format_type}', writer=writervideo)

### 3D animations

In [6]:
def anim_3d(data_input, data_header, tstep, nsteps, lims, fsize, marker_size, tx, ty, tz, format_type, fps, run):
    """
    Animate the particles in 3D around a grey Milky Way (a 1.5 kpc sphere
    for the bulge and a 14.49 x 14.49 x 0.3 kpc ellipsoid for the disc), and
    save it to ./Animations/Anim_{run}_3d.{format_type} (the Animations
    folder is created if it does not exist).

    Parameters:
    -----------

    data_input: 3d array
            Particle data indexed as data_input[particle, column, snapshot];
            columns 1-3 are used as x, y, z [kpc]

    data_header: 2d array
            Header values, one column per snapshot; rows 2-4 (position, kpc)
            and rows 5 onward (velocity, km/s) of the first column are shown
            in the title

    tstep: int or float
            Time between snapshots [Myr], used for the time label

    nsteps: int
            Number of frames; snapshots 0 .. nsteps-1 are animated

    lims: int or float
            All three axes span [-lims, lims] [kpc]

    fsize: tuple
            Figure size (width, height)

    marker_size: int or float
            Marker size of the particles

    tx, ty, tz: int or float
            Position of the time label in data coordinates

    format_type: str
            Format of the output animation, 'mp4' or 'gif'

    fps: int
            Frames per second in animation

    run: str
            Name of run, used in the output file name

    Raises:
    -------

    ValueError
            If format_type is neither 'mp4' nor 'gif'
    """
    writers = {'mp4': animation.FFMpegWriter, 'gif': animation.PillowWriter}
    # Fail before the (possibly long) rendering instead of silently saving nothing
    if format_type not in writers:
        raise ValueError(f"format_type must be 'mp4' or 'gif', got {format_type!r}")

    len_data = len(data_input)
    f_anim_3d = np.arange(0, nsteps, 1)

    # Initial position / velocity from the first header column, shown in the title
    pos_in = np.round(data_header[2:5, 0], 3)
    vel_in = np.round(data_header[5:, 0], 3)

    fig_anim_3d = plt.figure(figsize=fsize)
    ax_anim_3d = fig_anim_3d.add_subplot(projection='3d')

    scatter1 = ax_anim_3d.scatter([], [], [], c='b', s=marker_size)

    my_text = ax_anim_3d.text(tx, ty, tz, '', fontsize=20)

    # Spherical angles for the Milky Way surfaces; named phi/theta so the
    # module-level astropy ``u`` (units) is not shadowed
    phi = np.linspace(0, 2 * np.pi, 100)
    theta = np.linspace(0, np.pi, 100)

    x_bulge = 1.500 * np.outer(np.cos(phi), np.sin(theta))
    y_bulge = 1.500 * np.outer(np.sin(phi), np.sin(theta))
    z_bulge = 1.500 * np.outer(np.ones(np.size(phi)), np.cos(theta))

    x_disc = 14.490 * np.outer(np.cos(phi), np.sin(theta))
    y_disc = 14.490 * np.outer(np.sin(phi), np.sin(theta))
    z_disc = 0.300 * np.outer(np.ones(np.size(phi)), np.cos(theta))

    min_lim = -lims
    max_lim = lims

    def init_anim_3d():
        # Positions share the kpc units of the axes
        ax_anim_3d.set_title(f'N: {len_data}, Pos: {pos_in} kpc,' + '\n' + f' Vel: {vel_in} km/s', fontsize=20, y=0.85)
        ax_anim_3d.set_xlabel('x [kpc]', labelpad=15)
        ax_anim_3d.set_ylabel('y [kpc]', labelpad=15)
        ax_anim_3d.set_zlabel('z [kpc]', labelpad=5)

        ax_anim_3d.set_xlim(min_lim, max_lim)
        ax_anim_3d.set_ylim(min_lim, max_lim)
        ax_anim_3d.set_zlim(min_lim, max_lim)

        # Tick labels in scientific notation with a shared x10^n offset
        for axis in (ax_anim_3d.xaxis, ax_anim_3d.yaxis, ax_anim_3d.zaxis):
            axis.set_major_formatter(ScalarFormatter(useMathText=True))
            axis.get_major_formatter().set_powerlimits((0, 0))

        # Plots Milky Way
        ax_anim_3d.plot_surface(x_bulge, y_bulge, z_bulge, color='grey', alpha=0.5)
        ax_anim_3d.plot_surface(x_disc, y_disc, z_disc, color='grey', alpha=0.5)

        ax_anim_3d.view_init(4, 45)
        # NOTE(review): Axes3D.dist is deprecated in newer Matplotlib versions
        # (set_box_aspect(..., zoom=...) replaces it); kept for older versions
        ax_anim_3d.dist = 12

        plt.tight_layout()
        plt.show()

        return scatter1, my_text

    def update_anim_3d(frame):
        my_text.set_text(f't={frame*tstep} Myr')

        # _offsets3d is the (private) way to move the points of a 3D scatter
        scatter1._offsets3d = (data_input[:, 1, frame], data_input[:, 2, frame],
                               data_input[:, 3, frame])

        return scatter1, my_text

    animation_plots_3d = animation.FuncAnimation(fig_anim_3d, update_anim_3d,
                                                 frames=f_anim_3d, init_func=init_anim_3d)

    Path('./Animations').mkdir(exist_ok=True)
    writervideo = writers[format_type](fps=fps)
    animation_plots_3d.save(f'./Animations/Anim_{run}_3d.{format_type}', writer=writervideo)

### Compare to trace-back animation

In [7]:
def anim_ComparetoTraceback(sim_data, trace_back_data, lims, run_file, run_title, fsize=(10, 10), fps=2,
                            format_type='mp4', text_pos=(-1.0e5, 2.0e5, 2.0e5), n_frames=8,
                            frame_time=128, traceback_stride=16):
    """
    Summary:
    ------------------------------------------------------------------------------------------
    Animate a simulation in 3D together with a trace-back run, with the
    simulation particles in blue and the trace-back particles in red around
    a grey Milky Way (bulge and disc). The animation is saved to
    ./Animations/Anim_{run_file}_3d.{format_type}; the Animations folder is
    created if it does not exist.

    Simulation snapshot ``i`` (i = 0 .. n_frames-1) is shown together with
    trace-back snapshot ``(n_frames-1-i) * traceback_stride``, i.e. the
    trace-back snapshots are played in reverse order.

    Parameters:
    -----------
    sim_data: nd array
        Data from simulation, indexed as sim_data[particle, column, snapshot];
        columns 1-3 are used as x, y, z [kpc]

    trace_back_data: nd array
        Data from trace-back run, same layout as sim_data

    lims: int or float
        All three axes span [-lims, lims] [kpc]

    run_file: str
        Name of the run to save in file

    run_title: str
        Name of the run to have in the title

    fsize: tuple
        Size of figure

    fps: int
        Frames per second in animation

    format_type: str
        Format of animation, 'gif' or 'mp4'.

    text_pos: tuple of 3 floats
        (x, y, z) data coordinates of the time label.
        NOTE(review): the default looks like a leftover from pc units and
        may lie far outside kpc axis limits; confirm

    n_frames: int
        Number of simulation snapshots (0 .. n_frames-1) to animate

    frame_time: int or float
        Time between simulation snapshots [Myr], used for the time label

    traceback_stride: int
        Number of trace-back snapshots between two consecutive frames

    Raises:
    -------
    ValueError
        If format_type is neither 'mp4' nor 'gif'
    """
    writers = {'mp4': animation.FFMpegWriter, 'gif': animation.PillowWriter}
    # Fail before the (possibly long) rendering instead of silently saving nothing
    if format_type not in writers:
        raise ValueError(f"format_type must be 'mp4' or 'gif', got {format_type!r}")

    # Frames to use from simulations
    f_anim_3d = np.arange(0, n_frames, 1)
    # Which snapshots from trace-back data to use (0, 16, ..., 112 with the
    # defaults), reversed once so simulation frame i pairs with traceback_order[i]
    trace_back_snapshots = np.arange(n_frames) * traceback_stride
    traceback_order = trace_back_snapshots[::-1]

    fig_anim_3d = plt.figure(figsize=fsize)
    ax_anim_3d = fig_anim_3d.add_subplot(projection='3d')

    scatter1 = ax_anim_3d.scatter([], [], [], c='b', s=1)  # Simulation
    scatter2 = ax_anim_3d.scatter([], [], [], c='r', s=5)  # Trace-back

    # Printing time of snapshot in plot
    my_text = ax_anim_3d.text(*text_pos, '', fontsize=20)

    # Setting up MW in figure; phi/theta instead of u/v so the module-level
    # astropy ``u`` (units) is not shadowed
    phi = np.linspace(0, 2 * np.pi, 100)
    theta = np.linspace(0, np.pi, 100)

    x_bulge = 1.500 * np.outer(np.cos(phi), np.sin(theta))
    y_bulge = 1.500 * np.outer(np.sin(phi), np.sin(theta))
    z_bulge = 1.500 * np.outer(np.ones(np.size(phi)), np.cos(theta))

    x_disc = 14.490 * np.outer(np.cos(phi), np.sin(theta))
    y_disc = 14.490 * np.outer(np.sin(phi), np.sin(theta))
    z_disc = 0.300 * np.outer(np.ones(np.size(phi)), np.cos(theta))

    # Limits to axes
    min_lim = -lims
    max_lim = lims

    def init_anim_3d():
        ax_anim_3d.set_title(run_title, fontsize=20, y=0.85)
        ax_anim_3d.set_xlabel('x [kpc]', labelpad=15)
        ax_anim_3d.set_ylabel('y [kpc]', labelpad=15)
        ax_anim_3d.set_zlabel('z [kpc]', labelpad=5)

        ax_anim_3d.set_xlim(min_lim, max_lim)
        ax_anim_3d.set_ylim(min_lim, max_lim)
        ax_anim_3d.set_zlim(min_lim, max_lim)

        # Tick labels in scientific notation with a shared x10^n offset
        for axis in (ax_anim_3d.xaxis, ax_anim_3d.yaxis, ax_anim_3d.zaxis):
            axis.set_major_formatter(ScalarFormatter(useMathText=True))
            axis.get_major_formatter().set_powerlimits((0, 0))

        # Plots Milky Way
        ax_anim_3d.plot_surface(x_bulge, y_bulge, z_bulge, color='grey', alpha=0.5)
        ax_anim_3d.plot_surface(x_disc, y_disc, z_disc, color='grey', alpha=0.5)

        # Alignment of plot
        ax_anim_3d.view_init(4, 45)
        # NOTE(review): Axes3D.dist is deprecated in newer Matplotlib versions
        # (set_box_aspect(..., zoom=...) replaces it); kept for older versions
        ax_anim_3d.dist = 12

        plt.tight_layout()
        plt.show()

        return scatter1, scatter2, my_text

    def update_anim_3d(frame):
        # Time of snapshot
        my_text.set_text(f't={frame*frame_time} Myr')

        # Simulation data (_offsets3d is the private way to move 3D scatter points)
        scatter1._offsets3d = (sim_data[:, 1, frame], sim_data[:, 2, frame],
                               sim_data[:, 3, frame])

        # Frame and data for trace-back simulation
        frame2 = traceback_order[frame]
        scatter2._offsets3d = (trace_back_data[:, 1, frame2], trace_back_data[:, 2, frame2],
                               trace_back_data[:, 3, frame2])

        return scatter1, scatter2, my_text

    animation_plots_3d = animation.FuncAnimation(fig_anim_3d, update_anim_3d, frames=f_anim_3d,
                                                 init_func=init_anim_3d)

    Path('./Animations').mkdir(exist_ok=True)
    writervideo = writers[format_type](fps=fps)
    animation_plots_3d.save(f'./Animations/Anim_{run_file}_3d.{format_type}', writer=writervideo)

# Testing