# 01 – Data Exploration
Download price data, examine returns, and visualise the correlation structure.

In [None]:
import sys; sys.path.insert(0, '..')
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
from src.data_handler import load_data

# --- Universe and sample window -------------------------------------------
TICKERS = ['AAPL', 'MSFT', 'GOOGL', 'AMZN', 'JPM', 'SPY']
START, END = '2015-01-01', '2024-12-31'
TRAIN_END = '2023-12-31'  # passed to load_data as the train_end cut-off

# load_data returns a mapping; pull out the entries used by the cells below.
data = load_data(TICKERS, start=START, end=END, train_end=TRAIN_END)
prices, log_returns, train_ret = (
    data[key] for key in ('prices', 'log_returns', 'train_returns')
)
mu = data['mu']
cov = data['cov']

print('Price shape:', prices.shape)
# Last expression of the cell: rendered as the cell output.
prices.tail()

In [None]:
# Rebase every series to 1.0 at the first observation so the tickers can be
# compared on a common growth scale.
norm = prices.div(prices.iloc[0], axis='columns')
fig = px.line(
    norm,
    title='Normalised Price Performance (base=1)',
    labels=dict(value='Growth', index='Date'),
)
fig.update_layout(hovermode='x unified')
fig.show()

In [None]:
# Daily log-return distributions, one histogram panel per ticker.
# The grid shape is derived from len(TICKERS): a fixed 2x3 grid zipped against
# TICKERS would silently drop every ticker past the sixth and leave empty
# framed axes when fewer than six are configured.
n_cols = 3
n_rows = max(1, int(np.ceil(len(TICKERS) / n_cols)))
fig, axes = plt.subplots(
    n_rows, n_cols,
    figsize=(14, 3.5 * n_rows),  # 3.5in per row -> (14, 7) for the six-ticker default
    tight_layout=True,
    squeeze=False,  # always a 2-D array of axes, even when there is a single row
)
panels = axes.ravel()
for ax, col in zip(panels, TICKERS):
    log_returns[col].plot.hist(bins=80, ax=ax, color='steelblue', edgecolor='white', alpha=0.8)
    ax.set_title(col)
    ax.set_xlabel('Log Return')
# Hide trailing panels that have no ticker assigned to them.
for ax in panels[len(TICKERS):]:
    ax.set_visible(False)
plt.suptitle('Daily Log-Return Distributions', y=1.01, fontsize=14)
plt.show()

In [None]:
# Summary table built from the mean vector `mu` and covariance matrix `cov`
# returned by load_data (presumably already annualised -- confirm in
# src.data_handler). Volatility is the square root of the covariance diagonal.
ann_vol = np.sqrt(np.diag(cov.values))
excess_ret = mu - 0.04  # return in excess of the 4% risk-free rate

stats = pd.DataFrame(
    {
        'Ann. Return (%)': (mu * 100).round(2),
        'Ann. Vol (%)': (ann_vol * 100).round(2),
        'Sharpe (rf=4%)': (excess_ret / ann_vol).round(3),
    }
)
# Last expression of the cell: rendered as the cell output.
stats

In [None]:
# Pairwise correlation of the in-sample (training) returns, drawn as a
# lower-triangle heatmap.
corr = train_ret.corr()

# Boolean mask that is True on and above the diagonal; seaborn leaves masked
# cells blank, so only the strictly-lower triangle is shown.
all_true = np.ones_like(corr, dtype=bool)
mask = np.triu(all_true)

heatmap_opts = dict(
    mask=mask,
    annot=True,
    fmt='.2f',
    cmap='RdBu_r',
    vmin=-1,
    vmax=1,
    linewidths=0.5,
    square=True,
)
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(corr, ax=ax, **heatmap_opts)
ax.set_title('Pairwise Correlation (in-sample)')
plt.tight_layout()
plt.show()

In [None]:
# Rolling volatility: standard deviation of daily log returns over a
# 126-observation window (~6 months of trading days), scaled by sqrt(252)
# to express it per year.
roll_vol = log_returns.rolling(window=126).std().mul(np.sqrt(252))
fig = px.line(
    roll_vol,
    title='Rolling 6-Month Annualised Volatility',
    labels=dict(value='Volatility', index='Date'),
)
fig.update_layout(hovermode='x unified')
fig.show()