In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from astropy.table import Table
import os
from astropy.stats import sigma_clip
from stella import YoungStars
from altaipony.flarelc import FlareLightCurve

from lightkurve.search import search_targetpixelfile
from lightkurve.targetpixelfile import TessTargetPixelFile as TTPF
from lightkurve.lightcurve import LightCurve as LC

from tess_stars2px import tess_stars2px_function_entry as tess_stars2px

# Folder whose contents are listed later and handed to YoungStars.
# Set the YOUNG_STARS_DIR environment variable to point the notebook at a
# different location; without it the original hard-coded path is used.
directory = os.environ.get('YOUNG_STARS_DIR',
                           '/Users/AdinaFeinstein/Documents/young_stars')

  from ._conv import register_converters as _register_converters


In [None]:
# Read the flare-parameter table and collect the distinct values of its
# 'TIC' column.
f = pd.read_csv('flare_parameters.csv')
tics = np.unique(f['TIC'])

# Search for TESS target pixel files of target 25118964.
# NOTE(review): the target is passed as a bare integer -- confirm that the
# installed lightkurve version resolves it as a TIC ID for mission='TESS'.
result = search_targetpixelfile(25118964, mission='TESS')

# Download every product returned by the search.
lk_collection = result.download_all()

In [None]:
# Stitch the light curves of every downloaded target pixel file into a single
# time-ordered series. Flux and flux error are divided by the median flux of
# their own light curve before being combined.
#
# Each list starts with an empty float array so that an empty collection still
# yields empty float arrays, and integer quality flags are promoted to float.
time_chunks = [np.array([])]
flux_chunks = [np.array([])]
err_chunks = [np.array([])]
quality_chunks = [np.array([])]

for i in range(len(lk_collection)):
    # A per-sector .flatten(window_length=15) was tried here and left disabled.
    sector_lc = lk_collection[i].to_lightcurve()

    if len(sector_lc.time) != len(sector_lc.flux):
        # Arrays of unequal length would shift every later sample out of
        # alignment, so the sector is reported and left out.
        print('bad sector')
        continue

    sector_median = np.nanmedian(sector_lc.flux)
    time_chunks.append(np.ravel(sector_lc.time))
    flux_chunks.append(np.ravel(sector_lc.flux / sector_median))
    err_chunks.append(np.ravel(sector_lc.flux_err / sector_median))
    quality_chunks.append(np.ravel(sector_lc.quality))

short_time = np.concatenate(time_chunks)
short_lc = np.concatenate(flux_chunks)
short_err = np.concatenate(err_chunks)
quality_flags = np.concatenate(quality_chunks)

# One permutation, applied to every array, keeps time, flux, error and quality
# flag of each cadence together. (Sorting (time, flux) and (time, err) pairs
# separately paired the already-sorted times with unsorted errors.)
order = np.argsort(short_time, kind='stable')
short_time = short_time[order]
short_lc = short_lc[order]
short_err = short_err[order]
quality_flags = quality_flags[order]

# `short_flux` is the name the later cells use for the normalised flux.
short_flux = short_lc.copy()

In [None]:
# Keep only the entries of `directory` whose file name contains the target ID,
# then run the Savitzky-Golay based flare search on them with YoungStars.
files = [name for name in os.listdir(directory) if str(25118964) in name]

ys = YoungStars(fn=files, fn_dir=directory)
ys.savitsky_golay(window_length=15)
ys.identify_flares(method="savitsky-golay")

In [None]:
# Overlay the YoungStars light curve and the stitched short-cadence light
# curve for a 15-day window.
plt.rcParams['font.size'] = 15

fig, ax = plt.subplots(figsize=(14, 8))

ax.plot(ys.time, ys.norm_flux, 'k', linewidth=1, alpha=0.8, label='You')
# The short-cadence curve is shifted up by 0.05 so the two traces do not overlap.
ax.plot(short_time, short_flux + 0.05, c='darkorange', linewidth=1,
        label='The guy she tells you not to worry about')

ax.set_xlabel('Time (BJD-2457000)')
ax.set_ylabel('Normalized Flux')
ax.set_ylim(0.9, 1.2)
ax.legend()
ax.set_xlim(1350, 1365);

In [None]:
# Flatten the stitched light curve (window of 255 cadences); its flux and
# flux error serve as the detrended series below.
lk = (
    LC(short_time, short_flux, flux_err=short_err)
    .flatten(window_length=255)
)

# Wrap the raw and detrended series in a FlareLightCurve and run its flare
# finder with N1=3, N2=1, N3=2.
flc = FlareLightCurve(
    time=short_time,
    flux=short_flux,
    flux_err=short_err,
    detrended_flux=lk.flux,
    detrended_flux_err=lk.flux_err,
)
flc_result = flc.find_flares(N1=3, N2=1, N3=2)

# Disabled follow-up step, kept for reference:
#result = result.characterize_flares(N1=3, N2=1, N3=2)
#result.flares

In [None]:
def plot_flares(time, flux, flares, mask=None, flare_mask=None,
                pad=0.2, ylim=(0.98, 1.06)):
    """Plot a light curve and mark each flare with a vertical line.

    A green line is drawn at the mean time of the plotted samples that fall
    inside ``[tstart, tstop]``. If no plotted sample falls inside that
    interval, the interval is widened by ``pad`` on both sides and a lavender
    line is drawn at the mean time of the samples found there. A flare with no
    samples even in the widened interval is skipped.

    Parameters
    ----------
    time, flux : array-like
        Light-curve samples; both are indexed with ``mask``.
    flares : table-like
        Must support boolean-mask indexing and expose ``tstart`` and ``tstop``
        columns as attributes (e.g. a pandas DataFrame).
    mask : boolean array, optional
        Selects the samples to plot. Defaults to all samples.
    flare_mask : boolean array, optional
        Selects the flares to mark. Defaults to all flares.
    pad : float, optional
        Amount, in the units of ``time``, by which an empty flare interval is
        widened on each side.
    ylim : tuple of float, optional
        Lower and upper limit of the y axis.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``.
    """
    time = np.asarray(time)
    flux = np.asarray(flux)

    if mask is None:
        mask = np.ones(len(flux), dtype=bool)
    if flare_mask is None:
        flare_mask = np.ones(len(flares), dtype=bool)

    subflares = flares[flare_mask]
    # Plain arrays give positional access. A boolean-filtered DataFrame keeps
    # its original index labels, so `subflares.tstart[t]` would look up the
    # label t rather than the t-th remaining row.
    tstarts = np.asarray(subflares.tstart)
    tstops = np.asarray(subflares.tstop)

    shown_time = time[mask]

    plt.figure(figsize=(20,8))
    plt.plot(shown_time, flux[mask], 'k', alpha=0.8, linewidth=3)

    # Marker lines run from y=0 to y=2, which spans the default y limits.
    y_span = [0, 2]

    for tstart, tstop in zip(tstarts, tstops):
        inside = (shown_time >= tstart) & (shown_time <= tstop)

        if inside.any():
            colour, alpha = 'xkcd:forest green', 0.3
        else:
            inside = (shown_time >= tstart - pad) & (shown_time <= tstop + pad)
            if not inside.any():
                # Nothing to average: avoid a 0/0 midpoint.
                continue
            colour, alpha = 'xkcd:lavender', 0.5

        midpoint = np.mean(shown_time[inside])
        plt.plot([midpoint, midpoint], y_span, c=colour, alpha=alpha, linewidth=3)

    plt.ylim(*ylim)
    plt.show()

In [None]:
def my_flares(flux, median, error, sigma=2.5, N3=2):
    """Return the indices of samples that look like flare points.

    A sample is selected when all of the following hold:

    * it is masked by ``sigma_clip(flux, sigma=sigma)``;
    * it lies above ``median``;
    * it exceeds ``median`` by more than the standard deviation of ``flux``.

    The standard deviation ignores NaN samples, so a single NaN in ``flux``
    no longer turns the threshold into NaN and suppresses every detection.

    Parameters
    ----------
    flux : array-like
        Flux values to search.
    median : float or array-like
        Baseline level, scalar or broadcastable against ``flux``.
    error : array-like
        Accepted for interface compatibility; not used by this function.
    sigma : float, optional
        Clipping threshold passed to ``sigma_clip``.
    N3 : int, optional
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    numpy.ndarray
        Integer indices (along the first axis) of the selected samples.
    """
    # getmaskarray always yields a full boolean array, even if nothing was clipped.
    clipped = np.ma.getmaskarray(sigma_clip(flux, sigma=sigma))
    scatter = np.nanstd(flux)

    # Comparisons against NaN samples are simply False; silence the warning
    # older numpy versions emit for them.
    with np.errstate(invalid='ignore'):
        above_baseline = (flux - median) > 0.
        above_scatter = flux > (scatter + median)

    return np.where(clipped & above_baseline & above_scatter)[0]

In [None]:
# Run my_flares on the detrended flux between t = 1340 and t = 1350 and plot
# the selected points on top of the light curve.
time_mask = ((flc_result.time > 1340.) & (flc_result.time < 1350.))

sigma_cut = sigma_clip(flc_result.detrended_flux[time_mask], sigma=2.5, maxiters=5)
# True marks a sample that was clipped as an outlier. getmaskarray guarantees
# a full boolean array even when nothing was clipped.
sigma_mask = np.ma.getmaskarray(sigma_cut)

time = flc_result.time[time_mask]
flux = flc_result.detrended_flux[time_mask]
# NOTE(review): assumes `lk` and `flc_result` have the same number of samples,
# since `time_mask` was built from `flc_result.time` -- confirm.
flux_err = lk.flux_err[time_mask]

# The baseline is the median of the samples that survived clipping, hence the
# inverted mask; indexing with `sigma_mask` itself would take the median of
# the outliers.
median = np.full(len(time), np.nanmedian(flux[~sigma_mask]))

myflur = my_flares(flux, median, flux_err)

plt.figure(figsize=(20,8))
plt.plot(time, flux, 'k')

# Red dots mark the samples selected by my_flares.
plt.plot(time[myflur], flux[myflur], 'r.', ms=15)
plt.show()

In [None]:
# TODO: follow-up checks on the detected flares:
#   - look around breaks in the data
#   - look around momentum dumps
#

# Display the `flares` attribute of the find_flares result.
flc_result.flares