# In [None]:
from rich.progress import Progress
from datetime import datetime, timedelta 
import pandas as pd
import numpy as np
import torch

def get_time_list(df: pd.DataFrame, interval: int):
    """Return every ``interval``-th distinct ``trade_time`` in ascending order.

    Sampling is by position among the unique timestamps (every
    ``interval``-th distinct tick time), not by elapsed seconds.

    Args:
        df: Frame containing a ``trade_time`` column.
        interval: Step between sampled unique timestamps; must be positive.

    Returns:
        A list of timestamps starting with the earliest one.

    Raises:
        ValueError: If ``interval`` is not positive.
    """
    if interval <= 0:
        raise ValueError(f"interval must be positive, got {interval}")
    return sorted(df['trade_time'].unique())[::interval]

def get_x_y(instrument_id: str, timestamp: str, look_back: int, look_forward: int, df: pd.DataFrame,
            min_return_threshold: float = 0.0001, min_volume_change: float = 1000):
    """Build one (features, target) sample anchored at ``timestamp``.

    Features are the ``look_back`` most recent rows with
    ``trade_time <= timestamp``, restricted to 9 tick columns with NaN filled
    by 0. They are returned in chronological order and transposed to shape
    ``(9, look_back)``.

    The target is a return computed on the rows strictly after ``timestamp``,
    relative to the first of those rows. Horizons
    ``look_forward .. min(3 * look_forward, n_future) - 1`` are scanned in
    order. The first one qualifies if its absolute return reaches
    ``min_return_threshold`` or its ``cum_volume`` increase reaches
    ``min_volume_change``. If no horizon qualifies, or there is no horizon to
    scan, the return at index ``look_forward - 1`` is used.

    Args:
        instrument_id: Not used by this function; kept for interface compatibility.
        timestamp: Anchor time, compared against ``df['trade_time']``.
        look_back: Number of past rows used as features.
        look_forward: Minimum number of future rows required; also the base horizon.
        df: Tick data with ``trade_time`` and the feature columns.
        min_return_threshold: Absolute return treated as a significant move.
        min_volume_change: ``cum_volume`` increase treated as significant.

    Returns:
        ``(x, y)``: a NumPy array of shape ``(9, look_back)`` and the float
        target. Returns ``(None, None)`` when there is not enough past or
        future data.
    """
    past_data = df[df['trade_time'] <= timestamp].sort_values('trade_time', ascending=False)
    future_data = df[df['trade_time'] > timestamp].sort_values('trade_time')

    if look_forward < 1 or past_data.shape[0] < look_back or future_data.shape[0] < look_forward:
        return None, None

    # Extract tick data features (most recent row first at this point)
    features = past_data.iloc[0:look_back][[
        'last_price', 'highest_price', 'lowest_price',
        'cum_volume', 'cum_turnover',
        'bid_price1', 'bid_volume1', 'ask_price1', 'ask_volume1'
    ]].fillna(0)
    # Reverse to chronological order, then transpose to (n_features, look_back)
    x = features.iloc[::-1].T.values

    # Look for a significant move within up to 3x the base horizon
    max_window = min(look_forward * 3, len(future_data))
    future_prices = future_data['last_price'].iloc[0:max_window]
    future_volume = future_data['cum_volume'].iloc[0:max_window]
    base_price = future_prices.iloc[0]
    base_volume = future_volume.iloc[0]

    for i in range(look_forward, max_window):
        ret = future_prices.iloc[i] / base_price - 1
        vol_change = future_volume.iloc[i] - base_volume
        if abs(ret) >= min_return_threshold or vol_change >= min_volume_change:
            return x, ret

    # No significant move found, or only exactly `look_forward` future rows
    # exist (nothing to scan): fall back to the base horizon.
    ret = future_prices.iloc[look_forward - 1] / base_price - 1
    return x, ret

def get_dataset(interval: int, look_back: int, look_forward: int, df: pd.DataFrame):
    """Build parallel lists of feature arrays and targets from tick data.

    Anchor timestamps are every ``interval``-th distinct ``trade_time``
    (via ``get_time_list``). For each anchor:

    1. A significance threshold is derived from the standard deviation of
       ``last_price`` percentage changes over the last 100 rows up to the
       anchor: ``max(0.0001, 0.0001 * (1 + std * 100))``. If fewer than 100
       such rows exist, 0.0001 is used.
    2. ``get_x_y`` builds the sample.

    Anchors for which ``get_x_y`` returns ``(None, None)``, or whose features
    are not shaped ``(9, look_back)``, are skipped.

    Args:
        interval: Step between sampled distinct timestamps.
        look_back: Number of past rows per feature window.
        look_forward: Base forward horizon passed to ``get_x_y``.
        df: Tick data with ``trade_time``, ``instrument_id`` and feature columns.

    Returns:
        ``(X_train, y_train)``: a list of ``(9, look_back)`` arrays and a list
        of targets.
    """
    X_train = []
    y_train = []

    time_list = get_time_list(df, interval)
    # Loop-invariant: the first row's instrument is used for every anchor.
    instrument_id = df['instrument_id'].iloc[0] if len(df) else None

    with Progress() as progress:
        task = progress.add_task("[red]Processing...", total=len(time_list))

        for timestamp in time_list:
            progress.update(task, advance=1)

            # Volatility-adaptive threshold. tail() assumes df is ordered by
            # trade_time so these are the most recent rows — TODO confirm.
            current_data = df[df['trade_time'] <= timestamp].tail(100)
            if len(current_data) >= 100:
                current_vol = current_data['last_price'].pct_change().std()
                min_return_threshold = max(0.0001, 0.0001 * (1 + current_vol * 100))
            else:
                min_return_threshold = 0.0001

            x, y = get_x_y(
                instrument_id=instrument_id,
                timestamp=timestamp,
                look_back=look_back,
                look_forward=look_forward,
                df=df,
                min_return_threshold=min_return_threshold
            )

            # get_x_y signals insufficient data with (None, None); skip explicitly
            # instead of swallowing every exception.
            if x is None:
                continue
            if x.shape == (9, look_back):
                X_train.append(x)
                y_train.append(y)

    return X_train, y_train

def balance_samples_optimized(X_samples, y_samples, min_samples_per_class: int = None, random_state: int = 42,
                              threshold: float = 0.002):
    """Balance a continuous-target dataset across three return classes.

    Each target is bucketed into a class:

    - 0 (down) when ``y < -threshold``
    - 2 (up) when ``y > threshold``
    - 1 (flat) otherwise

    Every non-empty class is resampled to exactly ``min_samples_per_class``
    rows. Classes with at least that many rows are sampled without
    replacement; smaller classes are sampled with replacement. The combined
    result is shuffled. Empty classes are skipped with a warning, since
    nothing can be sampled from them.

    Args:
        X_samples: Feature samples; the first axis indexes samples.
        y_samples: Continuous targets aligned with ``X_samples``.
        min_samples_per_class: Rows per class in the output. Defaults to the
            size of the smallest non-empty class.
        random_state: Seed for a local RNG; the global NumPy RNG is not touched.
        threshold: Absolute target value separating up/down from flat.

    Returns:
        ``(balanced_X, balanced_y)`` as float64 arrays. ``balanced_y`` holds
        the original continuous targets, not class labels.
    """
    rng = np.random.RandomState(random_state)

    X_samples = np.asarray(X_samples)
    y_samples = np.asarray(y_samples)

    def _classify(y):
        # Flat must default to 1; defaulting to 0 would merge flat moves into
        # the "down" class and leave class 1 permanently empty.
        labels = np.ones(y.shape, dtype=int)
        labels[y > threshold] = 2
        labels[y < -threshold] = 0
        return labels

    y_classes = _classify(y_samples)
    # minlength keeps all three classes visible even when some are empty
    class_counts = np.bincount(y_classes, minlength=3)
    print("Original class distribution:", class_counts)

    if min_samples_per_class is None:
        min_samples_per_class = min((c for c in class_counts if c > 0), default=0)

    chunks = []
    for class_label in range(len(class_counts)):
        class_indices = np.flatnonzero(y_classes == class_label)
        if len(class_indices) == 0:
            if min_samples_per_class > 0:
                print(f"Warning: class {class_label} has no samples; skipping it")
            continue
        # Without replacement whenever the class has enough rows (including
        # exactly enough); only genuine oversampling needs replacement.
        replace = len(class_indices) < min_samples_per_class
        chunks.append(rng.choice(class_indices, size=min_samples_per_class, replace=replace))

    selected = np.concatenate(chunks) if chunks else np.empty(0, dtype=int)
    selected = selected[rng.permutation(len(selected))]

    balanced_X = X_samples[selected].astype(np.float64)
    balanced_y = y_samples[selected].astype(np.float64)

    # Report with the same thresholds used for classification above
    print("Balanced class distribution:", np.bincount(_classify(balanced_y), minlength=3))

    return balanced_X, balanced_y

# Read tick data
df = pd.read_csv('real_tick.csv')

# Convert timestamps
df['trading_day'] = pd.to_datetime(df['trading_day'])
df['trade_time'] = pd.to_datetime(df['trade_time'])

# Sampling / window configuration shared by the train and test sets
INTERVAL = 10       # sample every 10th distinct trade_time
LOOK_BACK = 30      # number of past ticks used as features
LOOK_FORWARD = 100  # base forward horizon (in ticks) for the target return
N_FEATURES = 9      # tick columns extracted per row by get_x_y

# 80/20 split by row position; assumes CSV rows are in chronological order — TODO confirm
split_point = int(len(df) * 0.8)
df_train = df.iloc[:split_point]
df_test = df.iloc[split_point:]

# Get training and test datasets
X_train, y_train = get_dataset(
    interval=INTERVAL,
    look_back=LOOK_BACK,
    look_forward=LOOK_FORWARD,
    df=df_train)

X_test, y_test = get_dataset(
    interval=INTERVAL,
    look_back=LOOK_BACK,
    look_forward=LOOK_FORWARD,
    df=df_test)

X_train_balanced, y_train_balanced = balance_samples_optimized(
    np.array(X_train),
    np.array(y_train),
    min_samples_per_class=1000  # samples per class; adjustable
)

# NOTE(review): balancing the test set changes its class distribution, so
# metrics on it will not reflect the real-world distribution.
X_test_balanced, y_test_balanced = balance_samples_optimized(
    np.array(X_test),
    np.array(y_test),
    min_samples_per_class=500  # the test set can use fewer samples
)

# Convert the (unbalanced) training set to tensors
trainx = torch.from_numpy(np.array(X_train)).reshape(
    len(X_train), 1, N_FEATURES, LOOK_BACK)  # N x 1 x N_FEATURES x LOOK_BACK
trainy = torch.from_numpy(np.array(y_train)).reshape(
    len(y_train), 1)  # N x 1

print("there are in total", len(X_train), "training samples")
print("there are in total", len(X_test), "testing samples")

# Save arrays
if len(X_train_balanced) > 0 and len(X_test_balanced) > 0:
    np.save('./X_train_balanced.npy', X_train_balanced)
    np.save('./y_train_balanced.npy', y_train_balanced)
    np.save('./X_test_balanced.npy', X_test_balanced)
    np.save('./y_test_balanced.npy', y_test_balanced)
    print("\nBalanced datasets saved successfully")
else:
    print("No valid samples were generated")