In [56]:
import yfinance as yf

import numpy as np 
import pandas as pd 



In [49]:
ticker = 'MSFT'

data = yf.download(ticker)

[*********************100%***********************]  1 of 1 completed


Log Returns

In [50]:
data

Price,Close,High,Low,Open,Volume
Ticker,MSFT,MSFT,MSFT,MSFT,MSFT
Date,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
1986-03-13,0.059707,0.062373,0.054376,0.054376,1031788800
1986-03-14,0.061839,0.062906,0.059707,0.059707,308160000
1986-03-17,0.062906,0.063439,0.061839,0.061839,133171200
1986-03-18,0.061306,0.063439,0.060773,0.062906,67766400
1986-03-19,0.060240,0.061839,0.059707,0.061306,47894400
...,...,...,...,...,...
2025-03-24,393.079987,395.399994,389.809998,395.399994,21004500
2025-03-25,395.160004,396.359985,392.640015,393.920013,15775000
2025-03-26,389.970001,395.309998,388.570007,395.000000,16108400
2025-03-27,390.579987,392.239990,387.399994,390.130005,13766800


In [51]:
data['Returns'] = data['Close'].pct_change()
data['log_return'] = np.log(1 + data['Returns'])   

In [52]:
data = data.dropna()
data

Price,Close,High,Low,Open,Volume,Returns,log_return
Ticker,MSFT,MSFT,MSFT,MSFT,MSFT,Unnamed: 6_level_1,Unnamed: 7_level_1
Date,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
1986-03-14,0.061839,0.062906,0.059707,0.059707,308160000,0.035712,0.035089
1986-03-17,0.062906,0.063439,0.061839,0.061839,133171200,0.017251,0.017104
1986-03-18,0.061306,0.063439,0.060773,0.062906,67766400,-0.025432,-0.025761
1986-03-19,0.060240,0.061839,0.059707,0.061306,47894400,-0.017390,-0.017543
1986-03-20,0.058641,0.060240,0.058108,0.060240,58435200,-0.026547,-0.026906
...,...,...,...,...,...,...,...
2025-03-24,393.079987,395.399994,389.809998,395.399994,21004500,0.004652,0.004641
2025-03-25,395.160004,396.359985,392.640015,393.920013,15775000,0.005292,0.005278
2025-03-26,389.970001,395.309998,388.570007,395.000000,16108400,-0.013134,-0.013221
2025-03-27,390.579987,392.239990,387.399994,390.130005,13766800,0.001564,0.001563


Lifting Transformation (Sliding Window)

In [53]:
h1 = 30
h2 = 5

windows = []
for i in range(0, len(data) - h1 + 1, h2):
    window = data['log_return'].iloc[i:i+h1].values.tolist()
    windows.append(window)

windows_df = pd.DataFrame(windows)

Empirical Scrutiny

In [54]:
sorted_windows = np.sort(windows_df.values, axis=1)

Implement Wasserstein kmeans

In [60]:
from sklearn.cluster import KMeans
X = sorted_windows

kmeans = KMeans(n_clusters=2, random_state=2) 
kmeans.fit(X)

cluster_labels = kmeans.labels_

Assign Labels to original data

In [66]:
data.loc[:, 'regime'] = np.nan

# ... previous code ...
for i, label in enumerate(cluster_labels):
    start = i * h2
    midpoint = start + h1//2
    if midpoint < len(data):
        data.iloc[midpoint, data.columns.get_loc('regime')] = label

data.loc[:, 'regime'] = data['regime'].bfill().ffill()

In [67]:
data

Price,Close,High,Low,Open,Volume,Returns,log_return,regime
Ticker,MSFT,MSFT,MSFT,MSFT,MSFT,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Date,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
1986-03-14,0.061839,0.062906,0.059707,0.059707,308160000,0.035712,0.035089,0.0
1986-03-17,0.062906,0.063439,0.061839,0.061839,133171200,0.017251,0.017104,0.0
1986-03-18,0.061306,0.063439,0.060773,0.062906,67766400,-0.025432,-0.025761,0.0
1986-03-19,0.060240,0.061839,0.059707,0.061306,47894400,-0.017390,-0.017543,0.0
1986-03-20,0.058641,0.060240,0.058108,0.060240,58435200,-0.026547,-0.026906,0.0
...,...,...,...,...,...,...,...,...
2025-03-24,393.079987,395.399994,389.809998,395.399994,21004500,0.004652,0.004641,1.0
2025-03-25,395.160004,396.359985,392.640015,393.920013,15775000,0.005292,0.005278,1.0
2025-03-26,389.970001,395.309998,388.570007,395.000000,16108400,-0.013134,-0.013221,1.0
2025-03-27,390.579987,392.239990,387.399994,390.130005,13766800,0.001564,0.001563,1.0


Visualize Regimes

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(12,6))
plt.plot(data['Close'], label='Price', alpha=0.5)
plt.scatter(data.index, data['Close'], c=data['regime'], cmap='viridis', label='Regime')
