In [1]:
import pandas as pd
pd.set_option('display.max_columns', None)
import ast
import numpy as np
from scipy.integrate import odeint
from scipy.optimize import minimize
import math
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from statsmodels.tsa.seasonal import seasonal_decompose

# data prep

In [2]:
# Load the PropagateSummary pickle and order its rows chronologically.
# NOTE(review): unpickling can execute arbitrary code — only load this file from a trusted source.
# NOTE(review): absolute Colab/Drive path; the notebook cannot run where this mount is absent.
df=pd.read_pickle(r'/content/drive/Shareddrives/STUDENT-Capstone SS23/BTS_data/PropagateSummary.pkl')
print(df.shape)
# Sort by adjusted date first, then by adjusted time within each date.
df_sorted=df.sort_values(['Adj_Date_Str','Adj_Time'])
df_sorted.head(2)

(21815, 16)


Unnamed: 0,FlightDate,CRSDepTime,Adj_Date_Str,Adj_Time,OrigList,DestList,Tail_Number,Flights,List_of_WDelay_Minutes,List_of_PDelay_Minutes,List_of_Dep_Times,PropagateCount,WFlagList,PFlagList,DelayTypeList,LateAircraftFlag
131406,2010-01-01,705,2010-01-01,205,"[BWI, DFW, PHX, DFW]","[DFW, PHX, DFW, DCA]",N457AA,4.0,"[23.0, nan, nan, nan]","[0.0, nan, nan, nan]","[205, 720, 1150, 1355]",0,"[1, 0, 0, 0]","[0, 0, 0, 0]","[W, na, na, na]",0
147008,2010-01-01,650,2010-01-01,250,"[DFW, DCA, DFW, SNA]","[DCA, DFW, SNA, DFW]",N3CTAA,4.0,"[16.0, nan, nan, nan]","[0.0, nan, nan, nan]","[250, 620, 1045, 1450]",0,"[1, 0, 0, 0]","[0, 0, 0, 0]","[W, na, na, na]",0


In [3]:
# Renumber the rows 0..n-1 in their sorted order, discarding the previous index labels.
df_sorted=df_sorted.reset_index(drop=True)
# (disabled) wrap values that do not start with '[' in brackets before parsing:
# df_sorted['List_of_WDelay_Minutes']=df_sorted.List_of_WDelay_Minutes.apply(lambda x: '['+x+']' if x[0]!='[' else x)
def change_str_to_lst(s):
  """Parse the string form of a delay-minutes cell into a Python object.

  The cell text is a Python-literal-like string such as ``"[23.0, nan, nan]"``
  or ``"52.0"``. ``ast.literal_eval`` cannot read the bare tokens ``nan`` and
  ``na``, so they are substituted first: ``nan`` becomes ``0`` and ``na``
  becomes ``None``.

  Parameters
  ----------
  s : str or any
      Text to parse. A value that is not a string is returned unchanged,
      which makes it safe to apply this function to a column that has
      already been parsed (e.g. when the cell is re-run).

  Returns
  -------
  object
      Whatever ``ast.literal_eval`` yields for the cleaned text (a list for
      bracketed input, a number for a bare scalar), or ``s`` itself when it
      was not a string.

  Raises
  ------
  ValueError, SyntaxError
      Propagated from ``ast.literal_eval`` when the cleaned text is still
      not a valid Python literal.
  """
  import re  # local import: only this helper needs it

  # Already parsed (or a missing value): nothing to convert.
  if not isinstance(s, str):
    return s
  # Replace whole tokens only, and handle both markers, so a string that
  # contains 'nan' as well as 'na' still parses.
  s = re.sub(r'\bnan\b', '0', s)  # need to replace nan, otherwise literal_eval fails
  s = re.sub(r'\bna\b', 'None', s)
  return ast.literal_eval(s)
# Turn the stringified delay-minute columns into Python lists / numbers.
df_sorted['List_of_WDelay_Minutes']=df_sorted.List_of_WDelay_Minutes.apply(change_str_to_lst)
df_sorted['List_of_PDelay_Minutes']=df_sorted.List_of_PDelay_Minutes.apply(change_str_to_lst)
# Make every DelayTypeList cell a list. A bare code string is wrapped as a
# one-element list; calling list() on it would split it into characters
# ('WP' -> ['W', 'P'], 'na' -> ['n', 'a']) and distort the I/R/S tallies.
# Other non-list iterables are still converted with list().
df_sorted['DelayTypeList']=df_sorted.DelayTypeList.apply(
    lambda x: x if isinstance(x, list) else ([x] if isinstance(x, str) else list(x))
)

In [40]:
# Show the frame after parsing (pandas elides the middle rows of a long frame).
df_sorted

Unnamed: 0,FlightDate,CRSDepTime,Adj_Date_Str,Adj_Time,OrigList,DestList,Tail_Number,Flights,List_of_WDelay_Minutes,List_of_PDelay_Minutes,List_of_Dep_Times,PropagateCount,WFlagList,PFlagList,DelayTypeList,LateAircraftFlag
0,2010-01-01,705,2010-01-01,205,"[BWI, DFW, PHX, DFW]","[DFW, PHX, DFW, DCA]",N457AA,4.0,"[23.0, 0, 0, 0]","[0.0, 0, 0, 0]","[205, 720, 1150, 1355]",0,"[1, 0, 0, 0]","[0, 0, 0, 0]","[W, na, na, na]",0
1,2010-01-01,650,2010-01-01,250,"[DFW, DCA, DFW, SNA]","[DCA, DFW, SNA, DFW]",N3CTAA,4.0,"[16.0, 0, 0, 0]","[0.0, 0, 0, 0]","[250, 620, 1045, 1450]",0,"[1, 0, 0, 0]","[0, 0, 0, 0]","[W, na, na, na]",0
2,2010-01-01,650,2010-01-01,250,"[DFW, LAS, DFW, STL]","[LAS, DFW, STL, MIA]",N639AA,4.0,"[17.0, 0, 0.0, 0.0]","[0.0, 0, 0.0, 11.0]","[250, 635, 1015, 1240]",0,"[1, 0, 0, 0]","[0, 0, 0, 1]","[W, na, na, P]",0
3,2010-01-01,840,2010-01-01,340,"[BOS, DFW, IND]","[DFW, IND, DFW]",N4WBAA,3.0,"[15.0, 0.0, 0.0]","[0.0, 25.0, 23.0]","[340, 910, 1255]",2,"[1, 0, 0]","[0, 1, 1]","[W, P, P]",0
4,2010-01-02,600,2010-01-02,100,"[LGA, DFW, LGA]","[DFW, LGA, BNA]",N461AA,3.0,"[29.0, 0, 0]","[0.0, 0, 0]","[100, 605, 1035]",0,"[1, 0, 0]","[0, 0, 0]","[W, na, na]",0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21810,2023-03-31,1847,2023-03-31,1447,"[TYR, DFW]","[DFW, AEX]",N745SK,2.0,"[134.0, 0.0]","[0.0, 62.0]","[1447, 1724]",1,"[1, 0]","[0, 1]","[W, P]",0
21811,2023-03-31,1855,2023-03-31,1455,"[ACT, DFW]","[DFW, SJT]",N614SK,2.0,"[106.0, 0]","[0.0, 0]","[1455, 1839]",0,"[1, 0]","[0, 0]","[W, na]",0
21812,2023-03-31,1900,2023-03-31,1500,DFW,COS,N708SK,1.0,52.0,0.0,1500,0,1,0,[W],0
21813,2023-03-31,1925,2023-03-31,1525,DFW,IDA,N724EV,1.0,35.0,0.0,1525,0,1,0,[W],0


# create SIRS model

In [4]:
# Extract the Infected and Recovered counts for each timestep
def preprocess_delay_type_list(delay_list):
    """Tally one row's delay codes into infected / recovered / susceptible counts.

    Each string entry of ``delay_list`` is classified as follows:

    * ``'W'``, ``'P'`` or ``'WP'`` counts as infected (``'I'``);
    * ``'na'`` whose immediately preceding entry is one of those three codes
      counts as recovered (``'R'``);
    * any other ``'na'`` (first entry, or preceded by anything else, including
      another ``'na'``) counts as susceptible (``'S'``).

    Entries that are not strings, and strings other than the four codes above,
    are not counted at all.

    Parameters
    ----------
    delay_list : sequence
        Delay codes of one row, in order.

    Returns
    -------
    dict
        ``{'I': int, 'R': int, 'S': int}``, with the keys in that order.
    """
    infected_codes = ['W', 'P', 'WP']
    infected = recovered = susceptible = 0

    for position, code in enumerate(delay_list):
        if not isinstance(code, str):
            continue  # sanity check: only string codes are classified
        if code in infected_codes:
            infected += 1
            continue
        if code != 'na':
            continue  # unknown string: ignored
        # Only the entry directly before this one decides recovery.
        follows_infection = position > 0 and delay_list[position - 1] in infected_codes
        if follows_infection:
            recovered += 1
        else:
            susceptible += 1

    return {'I': infected, 'R': recovered, 'S': susceptible}

In [5]:
def sirs_model(y, t, beta, gamma, N):
    """Right-hand side of the compartment ODE system, in the form ``odeint`` expects.

    Despite the name, the flow from recovered back to susceptible is not part
    of these equations, so the system integrated here is the basic SIR form:

        dS/dt = -beta * I * S / N
        dI/dt =  beta * I * S / N - gamma * I
        dR/dt =  gamma * I

    Parameters
    ----------
    y : sequence of three numbers
        Current ``(S, I, R)`` values.
    t : float
        Time; not used by the equations, present because ``odeint`` passes it.
    beta : float
        Infection rate.
    gamma : float
        Recovery rate.
    N : float
        Total population size used to normalise the infection term.

    Returns
    -------
    list
        ``[dS/dt, dI/dt, dR/dt]``.
    """
    susceptible, infected, _recovered = y
    # The flow leaving S is exactly the flow entering I.
    new_infections = beta * infected * susceptible / N
    # The flow leaving I is exactly the flow entering R.
    recoveries = gamma * infected
    return [-new_infections, new_infections - recoveries, recoveries]

def objective(params, N, infected_counts, recovered_counts):
    """Sum of squared errors between observed and simulated I + R.

    The model is integrated with ``odeint`` from the first observation, over
    the time grid ``0, 1, ..., len(infected_counts) - 1``, and the simulated
    ``I + R`` is compared with the observed ``I + R`` at every step.

    Parameters
    ----------
    params : sequence of two floats
        ``(beta, gamma)`` candidate rates.
    N : number
        Total population size.
    infected_counts, recovered_counts : sequence of numbers
        Observed counts per timestep; both must have the same length.

    Returns
    -------
    float
        The squared-error total to be minimised.
    """
    beta, gamma = params

    # Every observation is shifted by a tiny positive constant.
    epsilon = 1e-6
    infected_obs = np.asarray(infected_counts, dtype=float) + epsilon
    recovered_obs = np.asarray(recovered_counts, dtype=float) + epsilon

    # Start the integration from the first (shifted) observation.
    i0, r0 = infected_obs[0], recovered_obs[0]
    y0 = np.array([N - i0 - r0, i0, r0])

    timesteps = np.arange(len(infected_counts))
    trajectory = odeint(sirs_model, y0, timesteps, args=(beta, gamma, N), rtol=1e-6, atol=1e-6)

    simulated = trajectory[:, 1] + trajectory[:, 2]  # Infected + Recovered
    residuals = infected_obs + recovered_obs - simulated
    return np.sum(residuals ** 2)

def fit_sirs_model(infected_counts, recovered_counts, N):
    """Estimate ``beta`` and ``gamma`` by minimising ``objective``.

    The search starts from ``beta = gamma = 0.1`` (a placeholder starting
    point) and uses L-BFGS-B with both parameters confined to ``[0, 1]``.

    Parameters
    ----------
    infected_counts, recovered_counts : sequence of numbers
        Observed counts per timestep.
    N : number
        Total population size.

    Returns
    -------
    numpy.ndarray
        The optimiser's final point, ``[beta, gamma]``.
    """
    start_point = [0.1, 0.1]
    unit_interval = (0, 1)
    fit = minimize(
        objective,
        start_point,
        args=(N, infected_counts, recovered_counts),
        method='L-BFGS-B',
        bounds=[unit_interval, unit_interval],
    )
    return fit.x

def calculate_sir_rates(infected_counts, recovered_counts, N, beta, gamma):
    """Simulate the S, I and R trajectories for given rates.

    Integration starts from the first observed counts and is evaluated on the
    time grid ``0, 1, ..., len(infected_counts) - 1``.

    Parameters
    ----------
    infected_counts, recovered_counts : sequence of numbers
        Observed counts; only their first entries and the length of
        ``infected_counts`` are used.
    N : number
        Total population size.
    beta, gamma : float
        Infection and recovery rates.

    Returns
    -------
    tuple of numpy.ndarray
        ``(S, I, R)``, one value per timestep each.
    """
    i0 = infected_counts[0]
    r0 = recovered_counts[0]
    y0 = np.array([N - i0 - r0, i0, r0])
    timesteps = np.arange(len(infected_counts))
    trajectory = odeint(sirs_model, y0, timesteps, args=(beta, gamma, N))
    # Columns of the solution are the three compartments, in S, I, R order.
    return trajectory[:, 0], trajectory[:, 1], trajectory[:, 2]


# Fit the model separately for every date and record [beta, gamma, R0] per date.
all_dates = list(df_sorted.Adj_Date_Str.unique())
SIRS_rates={}
for date in all_dates:
    df_day = df_sorted[df_sorted.Adj_Date_Str == date]
    # One {'I', 'R', 'S'} tally per row of the day; each row acts as one
    # timestep of the series handed to the fit.
    data = df_day.DelayTypeList.apply(preprocess_delay_type_list)
    data_list = data.tolist()
    susceptible_counts_list = [d['S'] for d in data_list]
    infected_counts_list = [d['I'] for d in data_list]
    recovered_counts_list = [d['R'] for d in data_list]
    # Total population size (N) for the day: every tallied entry of every row.
    N = sum(susceptible_counts_list) + sum(infected_counts_list) + sum(recovered_counts_list)
    # Fit the model to the day's series.
    best_fit_params = fit_sirs_model(infected_counts_list, recovered_counts_list, N)
    beta, gamma = best_fit_params
    # gamma may land exactly on its lower bound of 0. Dividing by it raised a
    # RuntimeWarning; the guard below yields the same values without it:
    # inf when beta > 0, nan when beta is 0 as well.
    if gamma > 0:
        R0 = beta / gamma
    else:
        R0 = np.inf if beta > 0 else np.nan
    SIRS_rates[date]=[beta, gamma,R0]

  R0=beta/gamma


In [6]:
# Tabulate the per-date fit results: one row per date with its Beta, Gamma and R0.
df_sir = (
    pd.DataFrame.from_dict(SIRS_rates, orient='index', columns=['Beta', 'Gamma', 'R0'])
    .rename_axis('Date')
    .reset_index()
)
# Separate copy ranked by infection rate, highest Beta first.
df_sir_sorted = df_sir.sort_values(by='Beta', ascending=False)

In [96]:
display(df_sorted.shape)
# Attach each date's fitted rates to every row of that date (left join keeps all rows).
df_merged=pd.merge(df_sorted,df_sir_sorted,left_on='Adj_Date_Str',right_on='Date',how='left')
# Per-row I/R/S tallies, kept both as the raw dict and as three numeric columns.
df_merged['DelayTypeCount']=df_merged.DelayTypeList.apply(preprocess_delay_type_list)
# Build the three columns by key name, so the result does not depend on the
# order in which the tally dict happens to list its keys.
df_merged[['I','R','S']]=pd.DataFrame(
    df_merged.DelayTypeCount.tolist(), index=df_merged.index, columns=['I','R','S']
)
df_merged['Adj_Date_Str']=pd.to_datetime(df_merged.Adj_Date_Str)
df_merged.head(2)

(21815, 16)

Unnamed: 0,FlightDate,CRSDepTime,Adj_Date_Str,Adj_Time,OrigList,DestList,Tail_Number,Flights,List_of_WDelay_Minutes,List_of_PDelay_Minutes,List_of_Dep_Times,PropagateCount,WFlagList,PFlagList,DelayTypeList,LateAircraftFlag,Date,Beta,Gamma,R0,DelayTypeCount,I,R,S
0,2010-01-01,705,2010-01-01,205,"[BWI, DFW, PHX, DFW]","[DFW, PHX, DFW, DCA]",N457AA,4.0,"[23.0, 0, 0, 0]","[0.0, 0, 0, 0]","[205, 720, 1150, 1355]",0,"[1, 0, 0, 0]","[0, 0, 0, 0]","[W, na, na, na]",0,2010-01-01,0.355747,0.161944,2.196729,"{'I': 1, 'R': 1, 'S': 2}",1,1,2
1,2010-01-01,650,2010-01-01,250,"[DFW, DCA, DFW, SNA]","[DCA, DFW, SNA, DFW]",N3CTAA,4.0,"[16.0, 0, 0, 0]","[0.0, 0, 0, 0]","[250, 620, 1045, 1450]",0,"[1, 0, 0, 0]","[0, 0, 0, 0]","[W, na, na, na]",0,2010-01-01,0.355747,0.161944,2.196729,"{'I': 1, 'R': 1, 'S': 2}",1,1,2


In [54]:
# Calendar parts of the adjusted date, used for the roll-ups.
df_merged['Month'] = df_merged['Adj_Date_Str'].dt.month
df_merged['Year'] = df_merged['Adj_Date_Str'].dt.year

# Total I/R/S tallies per calendar month (1-12), pooled across all years.
df_grouped = df_merged.groupby('Month')[['I', 'R', 'S']].sum()

# One line per compartment; S is aggregated above but not drawn.
trace_I, trace_R = (
    go.Scatter(x=df_grouped.index, y=df_grouped[compartment], mode='lines+markers', name=compartment)
    for compartment in ('I', 'R')
)

layout = go.Layout(
    title='Infected and Recovered counts by month',
    xaxis_title='Month',
    yaxis_title='Count',
)

fig = go.Figure(data=[trace_I, trace_R], layout=layout)
fig.show()


In [56]:
# Calendar parts of the adjusted date, used for the roll-ups.
df_merged['Month']=df_merged.Adj_Date_Str.dt.month
df_merged['Year']=df_merged.Adj_Date_Str.dt.year
# Total I/R/S tallies per calendar year.
df_grouped=df_merged.groupby('Year')[['I','R','S']].sum()
trace_I = go.Scatter(x=df_grouped.index, y=df_grouped['I'], mode='lines+markers', name='I')
trace_R = go.Scatter(x=df_grouped.index, y=df_grouped['R'], mode='lines+markers', name='R')

# The x axis carries years here, so it is labelled 'Year' (it previously said 'Month').
layout = go.Layout(title='Infected and Recovered counts by year', xaxis_title='Year', yaxis_title='Count')

fig = go.Figure(data=[trace_R,trace_I], layout=layout)

fig.show()


In [97]:
# Ensure the date column is datetime so the .dt accessor works, then derive the
# calendar month and year of each row.
# NOTE(review): the same three assignments already appear in earlier cells; this
# cell only matters if it is run without them.
df_merged['Adj_Date_Str']=pd.to_datetime(df_merged.Adj_Date_Str)
df_merged['Month']=df_merged.Adj_Date_Str.dt.month
df_merged['Year']=df_merged.Adj_Date_Str.dt.year
df_merged.head()

Unnamed: 0,FlightDate,CRSDepTime,Adj_Date_Str,Adj_Time,OrigList,DestList,Tail_Number,Flights,List_of_WDelay_Minutes,List_of_PDelay_Minutes,List_of_Dep_Times,PropagateCount,WFlagList,PFlagList,DelayTypeList,LateAircraftFlag,Date,Beta,Gamma,R0,DelayTypeCount,I,R,S,Month,Year
0,2010-01-01,705,2010-01-01,205,"[BWI, DFW, PHX, DFW]","[DFW, PHX, DFW, DCA]",N457AA,4.0,"[23.0, 0, 0, 0]","[0.0, 0, 0, 0]","[205, 720, 1150, 1355]",0,"[1, 0, 0, 0]","[0, 0, 0, 0]","[W, na, na, na]",0,2010-01-01,0.355747,0.161944,2.196729,"{'I': 1, 'R': 1, 'S': 2}",1,1,2,1,2010
1,2010-01-01,650,2010-01-01,250,"[DFW, DCA, DFW, SNA]","[DCA, DFW, SNA, DFW]",N3CTAA,4.0,"[16.0, 0, 0, 0]","[0.0, 0, 0, 0]","[250, 620, 1045, 1450]",0,"[1, 0, 0, 0]","[0, 0, 0, 0]","[W, na, na, na]",0,2010-01-01,0.355747,0.161944,2.196729,"{'I': 1, 'R': 1, 'S': 2}",1,1,2,1,2010
2,2010-01-01,650,2010-01-01,250,"[DFW, LAS, DFW, STL]","[LAS, DFW, STL, MIA]",N639AA,4.0,"[17.0, 0, 0.0, 0.0]","[0.0, 0, 0.0, 11.0]","[250, 635, 1015, 1240]",0,"[1, 0, 0, 0]","[0, 0, 0, 1]","[W, na, na, P]",0,2010-01-01,0.355747,0.161944,2.196729,"{'I': 2, 'R': 1, 'S': 1}",2,1,1,1,2010
3,2010-01-01,840,2010-01-01,340,"[BOS, DFW, IND]","[DFW, IND, DFW]",N4WBAA,3.0,"[15.0, 0.0, 0.0]","[0.0, 25.0, 23.0]","[340, 910, 1255]",2,"[1, 0, 0]","[0, 1, 1]","[W, P, P]",0,2010-01-01,0.355747,0.161944,2.196729,"{'I': 3, 'R': 0, 'S': 0}",3,0,0,1,2010
4,2010-01-02,600,2010-01-02,100,"[LGA, DFW, LGA]","[DFW, LGA, BNA]",N461AA,3.0,"[29.0, 0, 0]","[0.0, 0, 0]","[100, 605, 1035]",0,"[1, 0, 0]","[0, 0, 0]","[W, na, na]",0,2010-01-02,0.0,1.0,0.0,"{'I': 1, 'R': 1, 'S': 1}",1,1,1,1,2010


## Model validation: correlation between beta/gamma and the Infected, Recovered, and Susceptible counts

In [98]:
# One row per date: the date's fitted rates (identical on every row of that
# date, so the mean simply recovers them) next to the date's total I/R/S tallies.
df_merged_grouped = df_merged.groupby('Adj_Date_Str', as_index=False).agg(
    Beta=('Beta', 'mean'),
    Gamma=('Gamma', 'mean'),
    I=('I', 'sum'),
    R=('R', 'sum'),
    S=('S', 'sum'),
)
df_merged_grouped.head()

Unnamed: 0,Adj_Date_Str,Beta,Gamma,I,R,S
0,2010-01-01,0.355747,0.161944,7,3,5
1,2010-01-02,0.0,1.0,11,2,1
2,2010-01-03,0.099003,1.0,23,5,6
3,2010-01-04,0.0,0.34422,3,1,0
4,2010-01-05,0.0,0.173712,2,1,4


In [29]:
# Pairwise correlation of the daily rates and the daily compartment totals.
correlation_matrix = df_merged_grouped[['Beta', 'Gamma', 'I', 'R', 'S']].corr()

# Heatmap of the matrix; both axes list the same five variables.
fig = go.Figure(
    data=go.Heatmap(
        z=correlation_matrix,
        x=['Beta', 'Gamma', 'I', 'R', 'S'],
        y=['Beta', 'Gamma', 'I', 'R', 'S'],
        colorscale='Viridis',
    )
)
fig.update_layout(
    title='Correlation between Beta/Gamma and I, R, S Counts',
    xaxis_title='Parameter',
    yaxis_title='Parameter',
)
fig.show()


In [118]:
# NOTE(review): `df_corr` is not defined in any visible cell, so this cell raised
# NameError on a fresh kernel. It is rebuilt here from the daily aggregates using
# the same five variables as the axis labels — confirm this is the matrix that
# was originally plotted.
df_corr = df_merged_grouped[['Beta', 'Gamma', 'I', 'R', 'S']].corr()

fig = go.Figure(data=go.Heatmap(
    z=df_corr.values,
    x=list(df_corr.columns),
    y=list(df_corr.index),
    colorscale='Viridis',
    showscale=False  # Turn off the color scale; the values are printed in the cells
))

# Print each coefficient inside its cell. Heatmap row r is drawn at y[r] and
# column c at x[c], so the label placed at (x=columns[c], y=index[r]) is z[r, c].
for r in range(df_corr.shape[0]):
    for c in range(df_corr.shape[1]):
        fig.add_annotation(
            x=df_corr.columns[c],
            y=df_corr.index[r],
            text=f"{df_corr.iloc[r, c]:.2f}",
            showarrow=False,
            font=dict(color='white')
        )

# Square canvas so the cells are square.
fig.update_layout(
    height=700,
    width=700
    # title='Correlation between Beta/Gamma and I, R, S Counts',
    # xaxis_title='Parameter',
    # yaxis_title='Parameter',
)

# Show the plot
fig.show()

In [4]:
data = {
    'Name': ['Alice', 'Bob', 'Charlie'],
    'Age': [25, 30, 28],
    'Country': ['USA', 'Canada', 'UK']
}

df = pd.DataFrame(data)

# Apply styling using Pandas Styler
styled_df = df.style

# Render the styled DataFrame
styled_df.render()

# You can also chain style methods for more customization
styled_df = df.style\
    .set_properties(**{'text-align': 'center'})\
    .set_table_styles([{'selector': 'th', 'props': [('text-align', 'center')]}])

# Render the styled DataFrame
styled_df.render()

  styled_df.render()
  styled_df.render()


'<style type="text/css">\n#T_386fb th {\n  text-align: center;\n}\n#T_386fb_row0_col0, #T_386fb_row0_col1, #T_386fb_row0_col2, #T_386fb_row1_col0, #T_386fb_row1_col1, #T_386fb_row1_col2, #T_386fb_row2_col0, #T_386fb_row2_col1, #T_386fb_row2_col2 {\n  text-align: center;\n}\n</style>\n<table id="T_386fb">\n  <thead>\n    <tr>\n      <th class="blank level0" >&nbsp;</th>\n      <th id="T_386fb_level0_col0" class="col_heading level0 col0" >Name</th>\n      <th id="T_386fb_level0_col1" class="col_heading level0 col1" >Age</th>\n      <th id="T_386fb_level0_col2" class="col_heading level0 col2" >Country</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th id="T_386fb_level0_row0" class="row_heading level0 row0" >0</th>\n      <td id="T_386fb_row0_col0" class="data row0 col0" >Alice</td>\n      <td id="T_386fb_row0_col1" class="data row0 col1" >25</td>\n      <td id="T_386fb_row0_col2" class="data row0 col2" >USA</td>\n    </tr>\n    <tr>\n      <th id="T_386fb_level0_row1" class="row

| Name    |   Age | Country   |
|:--------|------:|:----------|
| Alice   |    25 | USA       |
| Bob     |    30 | Canada    |
| Charlie |    28 | UK        |

The low correlation between the SIR counts and beta/gamma indicates that the model is not capturing the dynamics of infection and recovery accurately. The SIRS model assumes that the rates of infection and recovery are constant over time; if the actual data show little correlation between these rates and the counts of infected, recovered, and susceptible individuals, the model may not be adequately representing the real-world dynamics of the infection.

## visualize these parameters

In [None]:
def find_inf(df, column):
  """Return the rows of ``df`` whose value in ``column`` is not infinite.

  Rows holding ``+inf`` or ``-inf`` in ``column`` are dropped; every other
  row, including rows holding NaN, is kept with its original index label.
  """
  return df.loc[~np.isinf(df[column])]

# First rows of the Beta-ranked table after dropping the dates whose R0 is infinite.
find_inf(df_sir_sorted,'R0').head()

In [7]:
# Index the per-date rates by a DatetimeIndex so they can be resampled.
# Guarded so the cell can be re-run: once 'Date' has become the index there is
# no 'Date' column left, and the unguarded version raised KeyError.
if 'Date' in df_sir.columns:
    df_sir['Date'] = pd.to_datetime(df_sir['Date'])
    df_sir = df_sir.set_index('Date')

In [39]:
# Calendar-month means of the per-date rates; months with no data are set to 0 instead of NaN.
monthly_avg = df_sir.resample('M').mean()
monthly_avg.fillna(0, inplace=True)
# monthly_avg_smoothed=monthly_avg
window_size = 60
# NOTE(review): the next two lines assign the same name. The first (rolling mean
# over the monthly averages) is immediately overwritten by the second, which
# rolls over the per-date rows of df_sir. Everything below therefore shows a
# centred 60-row mean of per-date values, although the title, axis label and
# variable name speak of months. Confirm which one is intended and drop the other.
monthly_avg_smoothed = monthly_avg.rolling(window=window_size, min_periods=12, center=True).mean()
monthly_avg_smoothed = df_sir.rolling(window=window_size, min_periods=12, center=True).mean()

fig = go.Figure()

# One line each for the smoothed Beta and Gamma; R0 is not drawn.
fig.add_trace(go.Scatter(x=monthly_avg_smoothed.index, y=monthly_avg_smoothed['Beta'], mode='lines+markers', name='Average Beta'))
fig.add_trace(go.Scatter(x=monthly_avg_smoothed.index, y=monthly_avg_smoothed['Gamma'], mode='lines+markers', name='Average Gamma'))

fig.update_layout(title='Moving Average of SIR Model Parameters',
                  xaxis_title='Month',
                  yaxis_title='Average Value')
# Put one x tick at the start of every year in the plotted range, labelled with the year.
years = pd.date_range(start=monthly_avg_smoothed.index[0], end=monthly_avg_smoothed.index[-1], freq='YS')
fig.update_layout(xaxis=dict(tickvals=years, ticktext=years.strftime("%Y")))


fig.show()

In [None]:
# Write the three rate tables to CSV in the current working directory, with the
# date index turned back into an ordinary column and no row-number column.
df_sir.reset_index().to_csv(r'SIR_rates_from_flights.csv',index=False)
monthly_avg.reset_index().to_csv(r'SIR_rates_perMonth_from_flights.csv',index=False)
# NOTE(review): monthly_avg_smoothed is whatever the plotting cell last assigned to it — verify it is the monthly series before relying on this file name.
monthly_avg_smoothed.reset_index().to_csv(r'SIR_rates_perMonth_rolling_from_flights.csv',index=False)

In [34]:
from statsmodels.tsa.seasonal import seasonal_decompose

# Rebuild the per-date rates table from SIRS_rates, indexed directly by a
# DatetimeIndex named 'Date' so it is ready for resampling.
df_sir = pd.DataFrame.from_dict(SIRS_rates, orient='index', columns=['Beta', 'Gamma', 'R0'])
df_sir.index = pd.to_datetime(df_sir.index).rename('Date')

## seasonal decomp on beta and gamma

**Trend:**
The trend component represents the long-term movement or behavior of the data.

**Seasonal:**
The seasonal component represents the periodic fluctuations or patterns that repeat at fixed intervals, such as daily, weekly, monthly, or yearly cycles.
 Seasonality often occurs due to external factors like weather, holidays, or business cycles. In seasonal decomposition, the seasonal component isolates these repeating patterns, allowing us to understand the regular variations within the data. Seasonal patterns can help identify specific time frames when certain events or phenomena tend to occur.

**Residual (or Error):**
The residual component represents the noise or random fluctuations that cannot be explained by the trend or seasonal patterns. It is the difference between the observed data and the predicted values based on the trend and seasonal components. The residual component is essentially the "leftover" variation after accounting for the trend and seasonal patterns. Ideally, the residuals should be random and have no specific pattern. Large and systematic residuals may indicate that the trend or seasonal decomposition is not capturing all the underlying patterns in the data.

In [80]:
# Calendar-month means of the per-date rates; months with no data become 0.
df_sir_monthly = df_sir.resample('M').mean().fillna(0)

# Additive seasonal decomposition with a 12-month period, once per rate.
decomposition_beta, decomposition_gamma = (
    seasonal_decompose(df_sir_monthly[rate], model='additive', period=12)
    for rate in ('Beta', 'Gamma')
)

# Keep the trend and seasonal components; the residual is not used below.
trend_beta, seasonal_beta = decomposition_beta.trend, decomposition_beta.seasonal
trend_gamma, seasonal_gamma = decomposition_gamma.trend, decomposition_gamma.seasonal

In [94]:
def _trend_seasonality_figure(label, trend, seasonal, color, tickvals, ticktext):
    """Build a two-row figure: `label` trend on top, `label` seasonality below.

    Both rows share the monthly index of ``df_sir_monthly`` as x values, use the
    same line colour, hide the y axis, and put x ticks only at ``tickvals``.
    """
    fig = make_subplots(
        rows=2, cols=1, shared_xaxes=False,
        subplot_titles=[f'{label} Trend', f'{label} Seasonality'],
        vertical_spacing=0.15, horizontal_spacing=0.01,
    )
    for row, (name, series) in enumerate((('Trend', trend), ('Seasonal', seasonal)), start=1):
        fig.add_trace(
            go.Scatter(x=df_sir_monthly.index, y=series, mode='lines', name=name, line=dict(color=color)),
            row=row, col=1,
        )
        fig.update_xaxes(tickvals=tickvals, ticktext=ticktext, showgrid=False, row=row, col=1, showline=False)
        fig.update_yaxes(visible=False, showgrid=False, row=row, col=1, showline=False)
    fig.update_layout(height=500, showlegend=False)
    return fig


# One x tick per January, labelled with the four-digit year.
january_ticks = df_sir_monthly.index[df_sir_monthly.index.month == 1]
ticktext = [str(year)[:4] for year in january_ticks]

# plotting the beta trend and seasonality
beta_color = 'rgb(214, 39, 40)'
fig_beta = _trend_seasonality_figure('Beta', trend_beta, seasonal_beta, beta_color, january_ticks, ticktext)
fig_beta.show()

# plotting the gamma trend and seasonality
gamma_color = 'rgb(44, 160, 44)'
fig_gamma = _trend_seasonality_figure('Gamma', trend_gamma, seasonal_gamma, gamma_color, january_ticks, ticktext)
fig_gamma.show()

In [83]:
# NOTE(review): fig_gamma already received its trend and seasonal traces in the
# cell that created it. Running this cell appends a second copy of each (without
# the explicit line colour) and shows the figure again. It looks like a leftover
# from an earlier iteration — confirm and remove.
fig_gamma.add_trace(go.Scatter(x=df_sir_monthly.index, y=trend_gamma, mode='lines', name='Trend'), row=1, col=1)
fig_gamma.add_trace(go.Scatter(x=df_sir_monthly.index, y=seasonal_gamma, mode='lines', name='Seasonal'), row=2, col=1)

# Same axis and layout settings as when the figure was created.
fig_gamma.update_xaxes(tickvals=df_sir_monthly.index[df_sir_monthly.index.month == 1], ticktext=ticktext, showgrid=False, row=1, col=1, showline=False)
fig_gamma.update_yaxes(visible=False, showgrid=False, row=1, col=1, showline=False)

fig_gamma.update_xaxes(tickvals=df_sir_monthly.index[df_sir_monthly.index.month == 1], ticktext=ticktext, showgrid=False, row=2, col=1, showline=False)
fig_gamma.update_yaxes(visible=False, showgrid=False, row=2, col=1, showline=False)

fig_gamma.update_layout(height=500,showlegend=False)
fig_gamma.show()