In [None]:
import asyncio
import pandas as pd
import ccxt.async_support as ccxt
import aiohttp
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional, Union, Tuple, Set
from dataclasses import dataclass
import nest_asyncio
# Patch asyncio (via nest_asyncio) so event loops can be re-entered; this lets the
# async helpers below be awaited/run from inside the notebook's already-running loop.
nest_asyncio.apply()

In [None]:
# ==================== CONSTANTS ====================
# Public CoinGecko REST root (keyless / demo-key access).
COINGECKO_API_URL = "https://api.coingecko.com/api/v3"

# Upper-cased CoinGecko symbols of USD-pegged tokens, removed from pair selection when
# remove_stablecoins=True. Every entry needs a trailing comma: a missing comma makes Python
# silently concatenate adjacent literals (e.g. 'USDD' 'FDUSD' -> 'USDDFDUSD').
STABLECOINS = {
    'USDT', 'USDC', 'BUSD', 'DAI', 'TUSD', 'USDP', 'USDD', 'SUSD0', 'PYUSD', 'USDX',
    'USR', 'HONEY', 'FDUSD', 'USDJ', 'GUSD', 'FRAX', 'USDS', 'USDE', 'SUSDS',
    'USDC.E', 'USDB', 'BUIDL', 'USDY', 'USDA', 'GHO',
}
# Wrapped / liquid-staking derivative tokens (WBTC, STETH, ...) removed from pair
# selection when remove_clones=True.
CLONECOINS = {
    'STETH', 'WBTC', 'WSTETH', 'WETH', 'WEETH', 'CBBTC', 'LBTC', 'RSETH', 'SOLVBTC', 'RETH', 'METH', 'BNSOL',
    'SOLVBTC.BBN', 'EZETH', 'WBNB', 'MSOL', 'CMETH', 'JUPSOL',
}

# ==================== RATE LIMITER ====================
class RateLimiter:
    """Async throttle that releases successive calls at least ``1000 / rate_limit`` ms apart.

    ``wait()`` is serialised with an ``asyncio.Lock``: several tasks awaiting the same
    limiter at once (as the OHLCV downloader does through ``asyncio.gather``) are released
    one interval apart instead of all reading the same timestamp and firing together.

    Args:
        rate_limit: Maximum number of released calls per second; must be positive.

    Raises:
        ValueError: If ``rate_limit`` is not positive.
    """

    def __init__(self, rate_limit: int = 100):
        if rate_limit <= 0:
            raise ValueError(f"rate_limit must be positive, got {rate_limit!r}")
        self.rate_limit = rate_limit
        # Monotonic-clock time (ms) of the last released call. -inf means "never called",
        # so the first wait() never sleeps whatever the clock's reference point is.
        self.last_request_time = float("-inf")
        # asyncio.Lock, created lazily so it is bound to the loop that first calls wait().
        self._lock = None

    async def wait(self) -> None:
        """Sleep until one full interval has passed since the previously released call."""
        if self._lock is None:
            self._lock = asyncio.Lock()
        async with self._lock:
            min_interval_ms = 1000 / self.rate_limit
            # time.monotonic() cannot jump backwards like time.time() can after a clock change.
            elapsed_ms = time.monotonic() * 1000 - self.last_request_time
            wait_ms = max(0.0, min_interval_ms - elapsed_ms)
            if wait_ms > 0:
                await asyncio.sleep(wait_ms / 1000)
            self.last_request_time = time.monotonic() * 1000

# ==================== COINGECKO API ====================
class CoinGeckoAPI:
    """Minimal async client for CoinGecko's ``/coins/markets`` endpoint.

    Must be used as ``async with CoinGeckoAPI(...) as client:`` so that the underlying
    ``aiohttp.ClientSession`` is created and closed.

    Args:
        api_key: Optional CoinGecko key. Without a key, requests go to the public host.
        demo: Selects how ``api_key`` is sent. False (default): Pro key, sent to the
            ``pro-api`` host as ``x-cg-pro-api-key``. True: Demo key, sent to the public
            host as ``x-cg-demo-api-key``. CoinGecko rejects a key sent to the wrong host.
    """

    PRO_API_URL = "https://pro-api.coingecko.com/api/v3"

    def __init__(self, api_key: Optional[str] = None, demo: bool = False):
        self.api_key = api_key
        self.demo = demo
        # Pro keys are only accepted on the pro-api host; keyless and demo requests use the public one.
        self.api_url = self.PRO_API_URL if (api_key and not demo) else COINGECKO_API_URL
        self.rate_limiter = RateLimiter(10)
        self.session = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session is not None:
            await self.session.close()
            self.session = None

    async def get_coins_markets(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Fetch one page of coins ordered by market cap.

        Args:
            params: Optional keys ``per_page`` (default 250), ``page`` (default 1),
                ``order`` (default ``'market_cap_desc'``) and ``vs_currency`` (default ``'usd'``).

        Returns:
            The decoded JSON list returned by CoinGecko.

        Raises:
            RuntimeError: If called outside ``async with``.
            Exception: Wrapping any HTTP or decoding failure (original error chained).
        """
        if self.session is None:
            raise RuntimeError("CoinGeckoAPI must be used as 'async with CoinGeckoAPI(...) as client'")

        await self.rate_limiter.wait()

        url = f"{self.api_url}/coins/markets"
        query_params = {
            'vs_currency': params.get('vs_currency', 'usd'),
            'order': params.get('order', 'market_cap_desc'),
            'per_page': params.get('per_page', 250),
            'page': params.get('page', 1),
            'sparkline': 'false',
        }

        headers = {}
        if self.api_key:
            header_name = 'x-cg-demo-api-key' if self.demo else 'x-cg-pro-api-key'
            headers[header_name] = self.api_key

        try:
            async with self.session.get(url, params=query_params, headers=headers) as response:
                response.raise_for_status()
                return await response.json()
        except Exception as e:
            raise Exception(f"Failed to fetch data from CoinGecko: {str(e)}") from e

# ==================== EXCHANGE WRAPPER ====================
class CcxtExchange:
    """Thin async wrapper around a ccxt exchange client.

    Loads markets once, checks that a pair is listed before fetching candles, and
    throttles requests with a ``RateLimiter`` on top of ccxt's own ``enableRateLimit``.
    Use it as ``async with CcxtExchange("binance") as ex:`` so the client is always closed.

    Recognised ``configs`` keys: ``api_key``, ``api_secret``, ``options`` (merged at the
    top level of the ccxt config dict), ``limit_size`` (candles per request, default 1000)
    and ``rate_limit`` (requests per second for the local limiter, default 100).
    """

    def __init__(self, exchange_name: str, **configs: Any):
        exchange_class = getattr(ccxt, exchange_name)

        self._client = exchange_class({
            'apiKey': configs.get("api_key"),
            'secret': configs.get("api_secret"),
            'enableRateLimit': True,
            **(configs.get("options", {}))
        })

        self._exchange_name = exchange_name
        self._limit_size = configs.get("limit_size", 1000)
        self._rate_limiter = RateLimiter(configs.get("rate_limit", 100))
        self._markets = None
        self._supported_pairs = set()
        # Fast path for common timeframes; anything else is parsed by ccxt on demand.
        self._timeframe_ms = {
            "1m": 60 * 1000,
            "5m": 5 * 60 * 1000,
            "15m": 15 * 60 * 1000,
            "30m": 30 * 60 * 1000,
            "1h": 60 * 60 * 1000,
            "4h": 4 * 60 * 60 * 1000,
            "1d": 24 * 60 * 60 * 1000,
            "1w": 7 * 24 * 60 * 60 * 1000,
        }

    @property
    def exchange_name(self) -> str:
        return self._exchange_name

    async def __aenter__(self):
        try:
            await self._ensure_markets_loaded()
        except Exception:
            # __aexit__ is not called when __aenter__ raises, so release the client here.
            await self.close()
            raise
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def _ensure_markets_loaded(self) -> None:
        """Load and cache the exchange's markets on first use."""
        try:
            if not self._markets:
                await self._rate_limiter.wait()
                self._markets = await self._client.load_markets()
                self._supported_pairs = set(self._markets.keys())
        except Exception as e:
            raise Exception(f"Error loading markets for {self._exchange_name}: {str(e)}") from e

    async def fetch_ohlcv(self, pair: str, timeframe: str, since: int, limit: int) -> List[List[Any]]:
        """Fetch up to ``limit`` raw candles for ``pair`` starting at ``since`` (epoch ms).

        Raises:
            Exception: If the pair is not listed or the request fails (original error chained).
        """
        try:
            await self._ensure_markets_loaded()

            if pair not in self._supported_pairs:
                raise Exception(f"Pair {pair} not available on {self._exchange_name}")

            await self._rate_limiter.wait()
            return await self._client.fetch_ohlcv(pair, timeframe, since, limit)
        except Exception as e:
            raise Exception(f"Error fetching OHLCV data for {pair} - {timeframe}: {str(e)}") from e

    def get_supported_pairs(self) -> set:
        return self._supported_pairs

    def get_timeframe_milliseconds(self, timeframe: str) -> int:
        """Return the duration of one ``timeframe`` candle in milliseconds.

        Timeframes missing from the built-in table (e.g. "2h", "12h", "3d") are parsed
        with ccxt's ``parse_timeframe``. Silently assuming 1d here would make the
        downloader step through time with the wrong chunk size and leave gaps.

        Raises:
            ValueError: If the timeframe cannot be parsed.
        """
        known = self._timeframe_ms.get(timeframe)
        if known is not None:
            return known
        try:
            return int(self._client.parse_timeframe(timeframe) * 1000)
        except Exception as e:
            raise ValueError(f"Unsupported timeframe {timeframe!r} for {self._exchange_name}") from e

    def get_ohlcv_request_limit(self) -> int:
        return self._limit_size

    async def close(self) -> None:
        await self._client.close()

# ==================== OHLCV DOWNLOADER ====================
@dataclass(eq=True, frozen=True)
class OhlcvDataPoint:
    """Immutable, hashable description of a single OHLCV download request.

    Attributes (documented as comments on the fields below):
        pair / timeframe identify what to fetch; the optional dates bound the window.
    """

    pair: str  # market symbol in exchange notation, e.g. "BTC/USDT"
    timeframe: str  # candle interval string, e.g. "1d"
    start_date: Optional[datetime] = None  # None -> the downloader picks a default window
    end_date: Optional[datetime] = None  # None -> up to the current time

class OhlcvDownloader:
    """Download OHLCV candles for one pair/timeframe by fanning out paginated requests.

    The resulting index is tz-naive UTC, built from the exchange's epoch-ms timestamps.
    Date bounds follow the same convention: naive ``start_date``/``end_date`` are read
    as UTC, and tz-aware ones are converted to UTC. Both bounds are inclusive.
    """

    def __init__(self, exchange: "CcxtExchange"):
        self._exchange = exchange

    @staticmethod
    def _to_epoch_ms(value: datetime) -> int:
        """Convert a datetime to epoch milliseconds, treating naive values as UTC."""
        ts = pd.Timestamp(value)
        if ts.tzinfo is None:
            ts = ts.tz_localize("UTC")
        return int(ts.timestamp() * 1000)

    @staticmethod
    def _to_naive_utc(value: datetime) -> pd.Timestamp:
        """Return ``value`` as a tz-naive UTC Timestamp comparable with the candle index."""
        ts = pd.Timestamp(value)
        if ts.tzinfo is not None:
            ts = ts.tz_convert("UTC").tz_localize(None)
        return ts

    async def download_pair(self, request: "OhlcvDataPoint") -> pd.DataFrame:
        """Download all candles for ``request``.

        Without ``start_date`` the window is the 30 days before ``end_date`` (or now).
        Errors are printed and an empty DataFrame is returned instead of raising.
        """
        try:
            interval_ms = self._exchange.get_timeframe_milliseconds(request.timeframe)
            chunk_size = self._exchange.get_ohlcv_request_limit()
            if interval_ms <= 0 or chunk_size <= 0:
                raise ValueError(f"invalid interval_ms={interval_ms} / chunk_size={chunk_size}")

            if request.end_date is None:
                end_ts = int(datetime.now().timestamp() * 1000)
            else:
                end_ts = self._to_epoch_ms(request.end_date)

            if request.start_date is None:
                start_ts = end_ts - (30 * 24 * 60 * 60 * 1000)
            else:
                start_ts = self._to_epoch_ms(request.start_date)

            print(f"Downloading {request.pair} - {request.timeframe}...")
            return await self._download_data(request, start_ts, end_ts, chunk_size, interval_ms)
        except Exception as e:
            print(f"Error downloading {request.pair} - {request.timeframe}: {str(e)}")
            return pd.DataFrame()

    async def _download_data(self, request: "OhlcvDataPoint", start_ts: int, end_ts: int,
                             chunk_size: int, interval_ms: int) -> pd.DataFrame:
        """Fetch every chunk in [start_ts, end_ts] concurrently, then merge, dedupe and trim.

        Chunk starts step by ``interval_ms * chunk_size``, which assumes each request returns
        ``chunk_size`` consecutive candles. NOTE(review): if the exchange caps the page below
        ``chunk_size``, gaps will appear between chunks.
        """
        step_ms = interval_ms * chunk_size
        # end_ts + 1 makes the bound inclusive: a candle opening exactly at end_ts survives the
        # `<= end_date` filter below, so it must also be requested when it starts a new chunk.
        tasks = [
            asyncio.create_task(self._fetch_chunk(request, chunk_start, chunk_size))
            for chunk_start in range(start_ts, end_ts + 1, step_ms)
        ]

        chunks = await asyncio.gather(*tasks, return_exceptions=True)
        valid_chunks = [chunk for chunk in chunks if isinstance(chunk, pd.DataFrame) and not chunk.empty]

        if not valid_chunks:
            return pd.DataFrame()

        data = pd.concat(valid_chunks)
        data['date'] = pd.to_datetime(data['timestamp'], unit='ms')
        data = data.set_index('date')
        data = data.drop(columns=['timestamp'])
        # Overlapping chunks can return the same candle twice; keep the latest copy.
        data = data[~data.index.duplicated(keep='last')].sort_index()

        # Compare in tz-naive UTC; comparing a tz-aware bound with the naive index raises TypeError.
        if request.start_date:
            data = data[data.index >= self._to_naive_utc(request.start_date)]
        if request.end_date:
            data = data[data.index <= self._to_naive_utc(request.end_date)]

        return data

    async def _fetch_chunk(self, request: "OhlcvDataPoint", timestamp: int, chunk_size: int) -> pd.DataFrame:
        """Fetch one page of candles starting at ``timestamp``; print and return empty on error."""
        try:
            data = await self._exchange.fetch_ohlcv(
                pair=request.pair,
                timeframe=request.timeframe,
                since=timestamp,
                limit=chunk_size,
            )
            return pd.DataFrame(
                data, columns=["timestamp", "open", "high", "low", "close", "volume"]
            )
        except Exception as e:
            print(f"Error downloading {request.pair} - {request.timeframe} at {timestamp}: {e}")
            return pd.DataFrame()

# ==================== HELPER FUNCTIONS ====================
def calculate_pct_returns(df: pd.DataFrame, price_columns: List[str]) -> pd.DataFrame:
    """Simple period-over-period returns for the requested price columns.

    Args:
        df: Price table (rows = timestamps).
        price_columns: Columns to convert. Names missing from ``df`` are ignored and
            duplicates are collapsed; output columns follow the order given here.

    Returns:
        A DataFrame of ``pct_change`` values with every row containing a NaN dropped,
        so that all columns are aligned on common dates.
    """
    columns = [col for col in dict.fromkeys(price_columns) if col in df.columns]
    # Forward-fill explicitly and disable pct_change's own filling. This keeps the historical
    # fill_method='pad' result while avoiding the pandas>=2.1 deprecation of that default.
    returns_df = df[columns].ffill().pct_change(fill_method=None)
    return returns_df.dropna()

def compute_correlation_matrix(df: pd.DataFrame) -> pd.DataFrame:
    """Pearson correlation matrix of ``df``'s columns, skipping columns that are entirely NaN."""
    has_values = df.notna().any(axis=0).to_numpy()
    populated = df.loc[:, has_values]
    return populated.corr(method='pearson')

def simplify_column_names(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` whose column labels are shortened to the text before the first '/'.

    For example "BTC/USDT_close" becomes "BTC". Labels without a '/' and non-string labels
    (which would make ``re.match`` raise TypeError) are left unchanged. Only columns are
    renamed; the index is untouched.
    """
    pattern = re.compile(r'([^/]+)/')
    rename_map = {}
    for col in df.columns:
        if not isinstance(col, str):
            continue
        match = pattern.match(col)
        if match:
            rename_map[col] = match.group(1)

    # rename() returns a new frame, so the caller's df is never modified.
    return df.rename(columns=rename_map)

def plot_heatmap(corr_matrix: pd.DataFrame, title: str = "Correlation Matrix", 
                 figsize: Tuple[int, int] = (12, 10), cmap: str = "coolwarm",
                 annot: bool = True, mask_upper: bool = True,
                 simplify_names: bool = True) -> None:
    """Draw a correlation heatmap fixed to the [-1, 1] colour range and show it.

    Args:
        corr_matrix: Square correlation matrix to draw.
        title, figsize, cmap, annot: Passed through to the figure / ``sns.heatmap``.
        mask_upper: Hide the upper triangle, including the diagonal.
        simplify_names: Shorten row and column labels to the text before the first '/'.
    """
    matrix = corr_matrix
    if simplify_names:
        matches = ((label, re.match(r'([^/]+)/', label)) for label in matrix.columns)
        short_names = {label: found.group(1) for label, found in matches if found}
        matrix = matrix.rename(columns=short_names, index=short_names)

    mask = np.triu(np.ones_like(matrix, dtype=bool)) if mask_upper else None

    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(matrix, mask=mask, cmap=cmap, vmax=1.0, vmin=-1.0,
                center=0, annot=annot, fmt=".2f", square=True, linewidths=.5, ax=ax)
    ax.set_title(title, fontsize=16)
    fig.tight_layout()
    plt.show()

def get_sorted_correlations(corr_matrix: pd.DataFrame, threshold: float = 0.0,
                           direction: str = "above", exclude_self: bool = True,
                           simplify_names: bool = True) -> pd.DataFrame:
    """Flatten a correlation matrix into a sorted one-column DataFrame of labelled pairs.

    Each cell becomes a row labelled ``"<column> - <row>"``. Symmetric pairs therefore
    appear twice, as "A - B" and "B - A".

    Args:
        corr_matrix: Square correlation matrix.
        threshold: Filter cut-off. 0.0 (the default) disables filtering. Negative values
            are honoured, e.g. ``threshold=-0.5, direction="below"``.
        direction: "above" keeps corr >= threshold and sorts descending. "below" keeps
            corr <= threshold and sorts ascending. Any other value keeps
            |corr| >= threshold and sorts by |corr| descending.
        exclude_self: Drop rows where |corr| is not < 1.0. This removes the diagonal, NaNs,
            and any exact ±1.0 off-diagonal pair.
        simplify_names: If True, shorten both labels of each pair to the text before the
            first '/' (e.g. "BTC/USDT_close" -> "BTC"). If False, keep labels unchanged.

    Returns:
        DataFrame with a single ``correlation`` column, indexed by pair label.
    """
    def _short(label):
        # Mirror simplify_column_names: only string labels containing '/' are shortened.
        if isinstance(label, str):
            match = re.match(r'([^/]+)/', label)
            if match:
                return match.group(1)
        return label

    corr_pairs = corr_matrix.unstack()
    if simplify_names:
        labels = [f"{_short(col)} - {_short(row)}" for col, row in corr_pairs.index]
    else:
        labels = [f"{col} - {row}" for col, row in corr_pairs.index]
    corr_df = pd.DataFrame({"correlation": corr_pairs.to_numpy()}, index=labels)

    mode = direction.lower()
    if threshold != 0:
        if mode == "above":
            corr_df = corr_df[corr_df["correlation"] >= threshold]
        elif mode == "below":
            corr_df = corr_df[corr_df["correlation"] <= threshold]
        else:
            corr_df = corr_df[abs(corr_df["correlation"]) >= threshold]

    if exclude_self:
        corr_df = corr_df[abs(corr_df["correlation"]) < 1.0]

    if mode == "above":
        return corr_df.sort_values(by="correlation", ascending=False)
    if mode == "below":
        return corr_df.sort_values(by="correlation", ascending=True)
    return corr_df.sort_values(by="correlation", key=abs, ascending=False)

# ==================== CRYPTO ANALYSIS MODULE ====================
class CryptoCorrelationAnalysis:
    """End-to-end helper: pick top CoinGecko coins, download exchange OHLCV data, correlate.

    Typical flow: ``get_data()`` (select_pairs -> download_ohlcv -> load_ohlcv), then
    ``analyse()`` on the returned prices or returns.

    Date arguments may be ``datetime`` objects or ISO-8601 strings (a trailing "Z" is
    accepted). They are normalised to tz-naive UTC, the convention of the candle index
    built from exchange millisecond timestamps, so tz-aware inputs can be compared with
    the downloaded data.
    """

    def __init__(self, exchange_name: str = "binance", api_key: Optional[str] = None):
        self.exchange_name = exchange_name
        self.api_key = api_key
        self.selected_pairs = []
        # timeframe -> {pair -> OHLCV DataFrame}; populated by download_ohlcv().
        self.data = {}

    @staticmethod
    def _normalize_date(value: Optional[Union[str, datetime]]) -> Optional[datetime]:
        """Parse ``value`` (str or datetime) into a tz-naive UTC datetime; None passes through."""
        if value is None:
            return None
        if isinstance(value, str):
            value = datetime.fromisoformat(value.replace('Z', '+00:00'))
        if getattr(value, "tzinfo", None) is not None:
            value = pd.Timestamp(value).tz_convert("UTC").tz_localize(None).to_pydatetime()
        return value

    async def select_pairs(self, top_n: int = 20, quote_currency: str = "USDT",
                         remove_stablecoins: bool = True, remove_clones: bool = True,
                         add_pairs: Optional[List[str]] = None,
                         remove_pairs: Optional[List[str]] = None) -> List[str]:
        """Build ``SYMBOL/quote_currency`` pairs for the top ``top_n`` coins by market cap.

        Only the first CoinGecko page (250 coins) is queried. Stablecoins and clone tokens
        are filtered before the ``top_n`` cut. ``add_pairs`` and ``remove_pairs`` are
        applied after it, so the result can hold fewer or more than ``top_n`` pairs.
        Listing on the exchange is not checked here.

        Returns:
            The sorted pair list, also stored in ``self.selected_pairs``.
        """
        async with CoinGeckoAPI(api_key=self.api_key) as client:
            params = {
                'per_page': 250,
                'page': 1,
                'order': 'market_cap_desc'
            }
            markets = await client.get_coins_markets(params)

        # Explicit columns keep the filters below working when the API returns no coins.
        coins = pd.DataFrame(
            [
                {
                    'id': coin['id'],
                    'name': coin['name'],
                    'symbol': coin['symbol'].upper(),
                    'market_cap': coin['market_cap'],
                    'market_cap_rank': coin['market_cap_rank'],
                }
                for coin in markets
            ],
            columns=['id', 'name', 'symbol', 'market_cap', 'market_cap_rank'],
        )

        if remove_stablecoins:
            coins = coins[~coins['symbol'].isin(STABLECOINS)]
        if remove_clones:
            coins = coins[~coins['symbol'].isin(CLONECOINS)]

        top = coins.sort_values('market_cap_rank', ascending=True).head(top_n)

        pairs_set = {f"{symbol.upper()}/{quote_currency}" for symbol in top['symbol']}
        pairs_set.update(add_pairs or [])
        # Removal runs after addition, so a pair listed in both ends up excluded.
        pairs_set.difference_update(remove_pairs or [])

        pairs = sorted(pairs_set)
        self.selected_pairs = pairs
        return pairs

    async def download_ohlcv(self, pairs: Optional[List[str]] = None,
                            timeframes: Union[str, List[str]] = "1d",
                            start_date: Optional[Union[str, datetime]] = None,
                            end_date: Optional[Union[str, datetime]] = None) -> Dict[str, Dict[str, pd.DataFrame]]:
        """Download OHLCV data for each pair and timeframe, one pair at a time.

        Pairs whose download fails or returns no rows are omitted. For each timeframe that
        yields data, ``self.data[timeframe]`` is replaced with the new results.

        Returns:
            ``{timeframe: {pair: DataFrame}}`` for this call.

        Raises:
            ValueError: If no pairs are given and none were selected.
        """
        if pairs is None:
            if not self.selected_pairs:
                raise ValueError("No pairs selected. Call select_pairs() first or provide pairs explicitly.")
            pairs = self.selected_pairs

        if isinstance(timeframes, str):
            timeframes = [timeframes]

        start_date = self._normalize_date(start_date)
        end_date = self._normalize_date(end_date)

        exchange = CcxtExchange(self.exchange_name)
        downloader = OhlcvDownloader(exchange)

        result_data = {tf: {} for tf in timeframes}

        async with exchange:
            for timeframe in timeframes:
                for pair in pairs:
                    request = OhlcvDataPoint(
                        pair=pair,
                        timeframe=timeframe,
                        start_date=start_date,
                        end_date=end_date
                    )

                    df = await downloader.download_pair(request)

                    if not df.empty:
                        result_data[timeframe][pair] = df

                if result_data[timeframe]:
                    self.data[timeframe] = result_data[timeframe]

        return result_data

    async def load_ohlcv(self, pairs: Optional[List[str]] = None, timeframe: str = "1d",
                       start_date: Optional[Union[str, datetime]] = None,
                       end_date: Optional[Union[str, datetime]] = None,
                       price_type: Union[str, List[str]] = "close",
                       calculate_returns: bool = False) -> Dict[str, pd.DataFrame]:
        """Assemble downloaded data into a wide price table (and optionally returns).

        Columns are named ``"<pair>_<price_type>"``. Pairs without data for ``timeframe``
        are skipped. Both date bounds are inclusive.

        Returns:
            ``{"prices": DataFrame}``, plus ``"returns"`` when ``calculate_returns`` is True.

        Raises:
            ValueError: If no pairs are available or ``timeframe`` was never downloaded.
        """
        if pairs is None:
            if not self.selected_pairs:
                raise ValueError("No pairs selected. Call select_pairs() first or provide pairs explicitly.")
            pairs = self.selected_pairs

        if timeframe not in self.data:
            raise ValueError(f"No data available for timeframe {timeframe}. Call download_ohlcv() first.")

        if isinstance(price_type, str):
            price_type = [price_type]

        # Naive-UTC bounds: comparing a tz-aware Timestamp with the naive index raises TypeError.
        start_date = self._normalize_date(start_date)
        end_date = self._normalize_date(end_date)

        all_data = {}
        price_columns = []

        for pair in pairs:
            if pair not in self.data[timeframe]:
                continue
            df = self.data[timeframe][pair]

            if start_date:
                df = df[df.index >= pd.Timestamp(start_date)]
            if end_date:
                df = df[df.index <= pd.Timestamp(end_date)]

            for col in price_type:
                if col in df.columns:
                    col_name = f"{pair}_{col}"
                    all_data[col_name] = df[col]
                    price_columns.append(col_name)

        result = {}
        if all_data:
            prices_df = pd.DataFrame(all_data)
            result["prices"] = prices_df

            if calculate_returns and not prices_df.empty:
                result["returns"] = calculate_pct_returns(prices_df, price_columns)
        else:
            result["prices"] = pd.DataFrame()
            if calculate_returns:
                result["returns"] = pd.DataFrame()

        return result

    async def get_data(self, top_n: int = 20, quote_currency: str = "USDT",
                     timeframe: str = "1d", start_date: Optional[Union[str, datetime]] = None,
                     end_date: Optional[Union[str, datetime]] = None,
                     remove_stablecoins: bool = True, remove_clones: bool = True,
                     price_type: Union[str, List[str]] = "close",
                     calculate_returns: bool = True,
                     custom_pairs: Optional[List[str]] = None,
                     add_pairs: Optional[List[str]] = None,
                     remove_pairs: Optional[List[str]] = None) -> Dict[str, pd.DataFrame]:
        """Select pairs, download them for ``timeframe`` and return ``load_ohlcv``'s result.

        With ``custom_pairs``, CoinGecko is skipped. ``add_pairs`` and ``remove_pairs`` are
        still honoured (removal wins), and the caller's list is copied rather than aliased.
        """
        if custom_pairs:
            excluded = set(remove_pairs or [])
            merged = dict.fromkeys([*custom_pairs, *(add_pairs or [])])
            self.selected_pairs = [pair for pair in merged if pair not in excluded]
        else:
            await self.select_pairs(
                top_n=top_n,
                quote_currency=quote_currency,
                remove_stablecoins=remove_stablecoins,
                remove_clones=remove_clones,
                add_pairs=add_pairs,
                remove_pairs=remove_pairs
            )

        await self.download_ohlcv(
            timeframes=timeframe,
            start_date=start_date,
            end_date=end_date
        )

        return await self.load_ohlcv(
            timeframe=timeframe,
            start_date=start_date,
            end_date=end_date,
            price_type=price_type,
            calculate_returns=calculate_returns
        )

    def analyse(
        self, 
        data: pd.DataFrame, 
        threshold: float = 0.0,
        direction: str = "above",
        plot: bool = True,
        title: str = "Correlation Matrix",
        figsize: Tuple[int, int] = (12, 10),
        simplify_names: bool = True
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Correlate ``data``'s columns, rank the pairs and optionally plot a heatmap.

        Returns:
            ``(corr_matrix, sorted_corrs)``. See ``get_sorted_correlations`` for
            ``threshold``, ``direction`` and ``simplify_names``.
        """
        corr_matrix = compute_correlation_matrix(data)
        sorted_corrs = get_sorted_correlations(
            corr_matrix, 
            threshold=threshold,
            direction=direction,
            simplify_names=simplify_names
        )

        if plot:
            plot_heatmap(
                corr_matrix,
                title=title, 
                figsize=figsize, 
                simplify_names=simplify_names
            )

        return corr_matrix, sorted_corrs

In [None]:
analysis = CryptoCorrelationAnalysis(exchange_name="binance") # L'exchange sur lequel on va récupérer les données
result = await analysis.get_data(
    top_n=15,                                      # Nombre de cryptomonnaies à sélectionner par capitalisation boursière
    timeframe="1d",                                # Intervalle de temps pour les données (1 heure)
    start_date="2025-01-01",                       # Date de début pour la collecte des données
    end_date="2025-03-01",                         # Date de fin pour la collecte des données
    add_pairs=[],                                  # Paires supplémentaires à ajouter (liste vide ici)
    remove_pairs=["BGB/USDT", "HYPE/USDT", "LEO/USDT", "PI/USDT", "ONDO/USDT"],  # Paires à exclure
)

print(f"Nombre de paires sélectionnées: {len(analysis.selected_pairs)}")
print(f"Paires sélectionnées: {analysis.selected_pairs}")
print(f"\nAperçu des prix:")
display(result["prices"].head())
print(f"\nAperçu des rendements:")
display(result["returns"].head())

In [None]:
# Correlate the daily returns; analyse() also draws the heatmap because plot=True by default.
corr_matrix, sorted_corrs = analysis.analyse(data=result["returns"])