In [1]:
import asyncio
import random
import time
from functools import lru_cache, partial
from typing import List, Optional, Tuple

import aiohttp
import akshare as ak
import pandas as pd

In [2]:
# Memoize the industry names. lru_cache never expires an entry on its own:
# call get_industry_names.cache_clear() to force a fresh request.
@lru_cache(maxsize=1)
def get_industry_names():
    """Return the "板块名称" column of akshare's industry board listing.

    The first call prints a notice and queries the API; the result is then
    memoized, so later calls hand back the same object without a new request.
    """
    print("Fetching industry names from API...")
    board_listing = ak.stock_board_industry_name_em()
    return board_listing["板块名称"]

# Cache the reference data to avoid repeated API calls
@lru_cache(maxsize=1)
def get_reference_dates(days=29):
    """Return the dates bounding the most recent ``days`` rows of reference data.

    The fund-flow history of the "证券" sector serves as the reference table.
    Only the latest call is memoized (maxsize=1), so asking for a different
    ``days`` value triggers a new request.

    :param days: Window length, counted back from the last row.
    :return: Tuple ``(first_date, last_date)`` read from the "日期" column.
    """
    history = ak.stock_sector_fund_flow_hist(symbol="证券")
    window_start_row = history.iloc[-days]
    latest_row = history.iloc[-1]
    return window_start_row["日期"], latest_row["日期"]

async def async_api_call(func, *args, **kwargs):
    """Run a blocking callable in the default executor and await its result.

    :param func: Synchronous callable (for example an akshare function).
    :param args: Positional arguments forwarded to ``func``.
    :param kwargs: Keyword arguments forwarded to ``func``.
    :return: Whatever ``func(*args, **kwargs)`` returns.
    :raises Exception: Anything raised by ``func`` propagates to the awaiter.
    """
    # Inside a coroutine the running loop is the one to use; get_running_loop()
    # is the preferred accessor over get_event_loop().
    loop = asyncio.get_running_loop()
    # run_in_executor() forwards positional arguments only, so bind the keyword
    # arguments up front.
    bound_call = partial(func, *args, **kwargs)
    return await loop.run_in_executor(None, bound_call)

async def process_single_industry_async(semaphore: asyncio.Semaphore,
                                      industry_name: str,
                                      days: int,
                                      first_date_str: str,
                                      last_date_str: str,
                                      max_retries: int = 2) -> List:
    """Fetch the net fund flow and price change of one industry.

    Rate-limit errors (message containing "429" or "rate limit") are retried
    with an exponentially growing pause; any other error ends the attempt
    immediately. With ``max_retries=0`` a single attempt is made.

    :param semaphore: Limits how many industries are processed at once.
    :param industry_name: Board name passed to akshare as ``symbol``.
    :param days: Number of trailing rows of "主力净流入-净额" to sum.
    :param first_date_str: Start-of-window date as ``YYYYMMDD``.
    :param last_date_str: End-of-window date as ``YYYYMMDD``.
    :param max_retries: Extra attempts allowed after a rate-limit error.
    :return: ``[industry_name, main_net_flow, change_percentage]``; on failure
        the two numbers are ``0`` (the error is printed, not raised).
    """
    async with semaphore:
        for attempt in range(max(0, max_retries) + 1):
            # Random jitter so concurrent workers do not hit the API in lockstep
            await asyncio.sleep(random.uniform(0.1, 0.3))

            try:
                fund_flow_df = await async_api_call(
                    ak.stock_sector_fund_flow_hist,
                    symbol=industry_name
                )
                main_net_flow = fund_flow_df.iloc[-days:]["主力净流入-净额"].sum()

                # The two single-day price queries are independent: run them together
                hist_result, today_result = await asyncio.gather(
                    async_api_call(
                        ak.stock_board_industry_hist_em,
                        symbol=industry_name,
                        start_date=first_date_str,
                        end_date=first_date_str,
                        period="日k",
                        adjust=""
                    ),
                    async_api_call(
                        ak.stock_board_industry_hist_em,
                        symbol=industry_name,
                        start_date=last_date_str,
                        end_date=last_date_str,
                        period="日k",
                        adjust=""
                    ),
                )

                hist_index = hist_result.iloc[0]["收盘"]
                today_index = today_result.iloc[0]["收盘"]
                change_percentage = round((today_index - hist_index) / hist_index * 100, 2)

                print(f"{industry_name}: {main_net_flow}, {change_percentage}%")
                return [industry_name, main_net_flow, change_percentage]

            except Exception as e:
                error_text = str(e)
                print(f"Error processing {industry_name}: {error_text}")
                rate_limited = "429" in error_text or "rate limit" in error_text.lower()
                if not rate_limited:
                    # Retrying will not help for errors unrelated to throttling
                    break
                # Exponential backoff: 1-3s, then 2-6s, 4-12s, ... The semaphore
                # slot stays held so overall request pressure drops meanwhile.
                await asyncio.sleep(random.uniform(1, 3) * 2 ** attempt)

        # Best-effort placeholder row so the industry still appears in the output
        return [industry_name, 0, 0]

async def industry_info_async(days: int = 29, max_concurrent: int = 5) -> pd.DataFrame:
    """
    Async version: Extracts the main net flow and change percentage of each industry.

    :param days: Number of days to consider for net flow calculation.
    :param max_concurrent: Maximum number of concurrent operations (at least 1).
    :return: DataFrame with industry names, their main net flow, and change percentage.
    :raises ValueError: If ``max_concurrent`` is smaller than 1.
    """
    if max_concurrent < 1:
        # Semaphore(0) would leave every task waiting forever
        raise ValueError("max_concurrent must be at least 1")

    # Both lookups are memoized, so repeated runs skip these two requests
    industry_names = get_industry_names()
    first_date, last_date = get_reference_dates(days)
    first_date_str = first_date.strftime("%Y%m%d")
    last_date_str = last_date.strftime("%Y%m%d")

    # One shared semaphore caps how many industries are in flight at once
    semaphore = asyncio.Semaphore(max_concurrent)

    tasks = [
        process_single_industry_async(
            semaphore, industry_name, days, first_date_str, last_date_str
        )
        for industry_name in industry_names
    ]

    print(f"Processing {len(tasks)} industries with {max_concurrent} concurrent operations...")
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Split rows from failures in one pass. BaseException (not Exception) is
    # checked because a cancelled task is reported as asyncio.CancelledError,
    # which is not an Exception subclass on Python 3.8+.
    valid_results = []
    failure_count = 0
    for outcome in results:
        if isinstance(outcome, BaseException):
            failure_count += 1
        else:
            valid_results.append(outcome)

    if failure_count:
        print(f"Encountered {failure_count} exceptions during processing")

    return pd.DataFrame(valid_results, columns=["行业名称", "主力净流入-总净额", "行业涨跌幅"])

# Batch processing version with cached industry names
async def process_industry_batch(semaphore: asyncio.Semaphore,
                               industries: List[str],
                               days: int,
                               first_date_str: str,
                               last_date_str: str) -> List[List]:
    """Handle one batch of industries sequentially under a single semaphore slot.

    :param semaphore: Limits how many batches run at the same time.
    :param industries: Board names in this batch.
    :param days: Number of trailing rows of "主力净流入-净额" to sum.
    :param first_date_str: Start-of-window date as ``YYYYMMDD``.
    :param last_date_str: End-of-window date as ``YYYYMMDD``.
    :return: One ``[name, main_net_flow, change_percentage]`` row per industry;
        an industry whose processing raised gets ``[name, 0, 0]``.
    """

    async def fetch_daily_bars(symbol: str, day: str):
        """Request the daily bars of ``symbol`` for the single date ``day``."""
        return await async_api_call(
            ak.stock_board_industry_hist_em,
            symbol=symbol,
            start_date=day,
            end_date=day,
            period="日k",
            adjust="",
        )

    rows = []
    async with semaphore:
        for name in industries:
            try:
                # Brief pause so consecutive items do not fire back to back
                await asyncio.sleep(0.05)

                flow_history = await async_api_call(
                    ak.stock_sector_fund_flow_hist, symbol=name
                )
                net_inflow = flow_history.iloc[-days:]["主力净流入-净额"].sum()

                # Both price frames are fetched before either one is indexed
                window_start_bars = await fetch_daily_bars(name, first_date_str)
                window_end_bars = await fetch_daily_bars(name, last_date_str)

                base_close = window_start_bars.iloc[0]["收盘"]
                latest_close = window_end_bars.iloc[0]["收盘"]
                pct_change = round((latest_close - base_close) / base_close * 100, 2)

                print(f"{name}: {net_inflow}, {pct_change}%")
                rows.append([name, net_inflow, pct_change])
            except Exception as exc:
                print(f"Error processing {name}: {str(exc)}")
                rows.append([name, 0, 0])
    return rows

async def industry_info_batched(days: int = 29,
                              max_concurrent: int = 3,
                              batch_size: int = 5) -> pd.DataFrame:
    """
    Batched async version: industries are grouped and each group is processed
    sequentially, with a limited number of groups running at once.

    :param days: Number of days to consider for net flow calculation.
    :param max_concurrent: Maximum number of concurrent batches (at least 1).
    :param batch_size: Number of industries per batch (at least 1).
    :return: DataFrame with industry names, their main net flow, and change percentage.
    :raises ValueError: If ``max_concurrent`` or ``batch_size`` is smaller than 1.
    """
    if max_concurrent < 1:
        # Semaphore(0) would leave every batch waiting forever
        raise ValueError("max_concurrent must be at least 1")
    if batch_size < 1:
        # A negative step would silently produce no batches at all
        raise ValueError("batch_size must be at least 1")

    # Both lookups are memoized, so repeated runs skip these two requests
    industry_names = get_industry_names()
    first_date, last_date = get_reference_dates(days)
    first_date_str = first_date.strftime("%Y%m%d")
    last_date_str = last_date.strftime("%Y%m%d")

    # Slice the names into consecutive groups of at most batch_size items
    industries = list(industry_names)
    batches = [industries[i:i + batch_size] for i in range(0, len(industries), batch_size)]

    semaphore = asyncio.Semaphore(max_concurrent)

    print(f"Processing {len(industries)} industries in {len(batches)} batches...")
    batch_tasks = [
        process_industry_batch(semaphore, batch, days, first_date_str, last_date_str)
        for batch in batches
    ]

    batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)

    # Flatten the per-batch row lists. BaseException (not Exception) is checked
    # because a cancelled batch is reported as asyncio.CancelledError, which is
    # not an Exception subclass on Python 3.8+.
    all_results = []
    for batch_result in batch_results:
        if isinstance(batch_result, BaseException):
            print(f"Batch failed: {batch_result}")
        else:
            all_results.extend(batch_result)

    return pd.DataFrame(all_results, columns=["行业名称", "主力净流入-总净额", "行业涨跌幅"])

# Utility functions for cache management
def clear_all_caches():
    """Reset every cache used by these helpers.

    Clears the two lru_cache-backed functions and, if a module-level
    ``industry_cache`` object exists, calls its ``invalidate()`` method as well.
    """
    get_industry_names.cache_clear()
    get_reference_dates.cache_clear()
    # `industry_cache` is not defined alongside these helpers, so look it up
    # defensively instead of raising NameError after the caches above were
    # already cleared.
    manual_cache = globals().get("industry_cache")
    if manual_cache is not None:
        manual_cache.invalidate()
    print("All caches cleared")

def get_cache_info():
    """Report the state of the caches used by these helpers.

    :return: Dict with the ``cache_info()`` of both lru_cache-backed functions
        and, under "manual_cache_last_updated", the ``_last_updated`` attribute
        of a module-level ``industry_cache`` object (``None`` when no such
        object or attribute exists).
    """
    # `industry_cache` is not defined alongside these helpers; treat it as
    # optional rather than failing the whole report with NameError.
    manual_cache = globals().get("industry_cache")
    return {
        "industry_names_cache": get_industry_names.cache_info(),
        "reference_dates_cache": get_reference_dates.cache_info(),
        "manual_cache_last_updated": getattr(manual_cache, "_last_updated", None),
    }

# Usage examples
async def main():
    """Run the async collection twice, time each run, then print cache statistics.

    :return: The result of the first ``industry_info_async`` run.
    """
    print("=== Testing Cached Industry Names ===")

    # Run 1: nothing is memoized yet
    print("\n1. First call (fetching from API):")
    started = time.time()
    first_frame = await industry_info_async(days=29, max_concurrent=3)
    elapsed = time.time() - started
    print(f"First call - Time: {elapsed:.2f}s, Results: {len(first_frame)}")

    # Run 2: identical arguments, for comparison against run 1
    print("\n2. Second call (using cached industry names):")
    started = time.time()
    second_frame = await industry_info_async(days=29, max_concurrent=3)
    elapsed = time.time() - started
    print(f"Second call - Time: {elapsed:.2f}s, Results: {len(second_frame)}")

    # Dump every entry of the cache report, one per line
    print("\n3. Cache information:")
    for label, details in get_cache_info().items():
        print(f"{label}: {details}")

    return first_frame

In [None]:
# In a Jupyter cell: the kernel already runs an event loop, so the coroutine is
# awaited directly (top-level await) instead of going through asyncio.run().
result = await main()

=== Testing Cached Industry Names ===

1. First call (fetching from API):
Fetching industry names from API...


0it [00:00, ?it/s]

0it [00:00, ?it/s]

Processing 86 industries with 3 concurrent operations...


0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

证券: 5136258208.0, 13.09%
船舶制造: -6317680448.0, 17.97%
小金属: -6950252432.0, 18.94%
多元金融: -271065360.0, 25.99%
软件开发: -29420723504.0, 11.39%
贸易行业: -2648144433.0, 8.69%
互联网服务: -23248636080.0, 10.44%
有色金属: -8363104864.0, 11.6%
医疗服务: -5325389968.0, 8.26%
生物制品: -6716129936.0, 5.76%
汽车服务: -526485768.0, 4.92%
非金属材料: -824826147.0, 13.05%
贵金属: -4094884096.0, 5.33%
电机: -3298242879.0, 5.08%
化学制药: -21141113888.0, 7.53%
半导体: -22800459584.0, 6.65%
能源金属: -555705113.0, 14.03%
钢铁行业: 514943376.0, 13.08%
计算机设备: -12874643680.0, 8.8%
石油行业: -2427497440.0, 7.76%
专业服务: -1914674013.0, 10.5%
房地产开发: -2327435424.0, 9.88%
通用设备: -22690558416.0, 5.52%
专用设备: -26213457264.0, 7.63%
航天航空: -12842400384.0, 9.39%
综合行业: -1612794895.0, 10.26%
工程机械: -3586548080.0, 4.22%
汽车零部件: -30338997296.0, 3.58%
农牧饲渔: -6738783760.0, 3.11%
航空机场: -1062698812.0, -0.4%
通信设备: -20115830784.0, 11.51%
酿酒行业: -12738338288.0, -0.31%
仪器仪表: -3323753166.0, 6.98%
医疗器械: -4931589808.0, 5.07%
中药: -6477581744.0, 2.2%
化肥行业: -3269532793.0, 8.45%
工程建设: -3603747888.

In [None]:
# Bare last expression: shows the DataFrame returned by main() as the cell output.
result

In [None]:
# For regular Python script usage. A notebook kernel also sets __name__ to
# "__main__" but already runs an event loop, where asyncio.run() raises
# RuntimeError; in that case the `await main()` cell above is the way to run it.
if __name__ == "__main__":
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running: plain script execution.
        result = asyncio.run(main())
    else:
        print("An event loop is already running; use `result = await main()` instead.")