In [None]:
import numpy as np
import pandas as pd

In [None]:
from google.colab import files
# Colab-only step: ask the user to upload a file. The return value is kept in
# `uploaded` but is not referenced again in the visible code.
# NOTE(review): presumably this is what places "Britannia_training.xlsx" in the
# working directory for the read_excel call in the next cell - confirm.
uploaded = files.upload()

In [None]:
# Load the training workbook into the module-level DataFrame `df`.
# The path is relative, so the file must exist in the current working directory.
df = pd.read_excel("Britannia_training.xlsx")

In [None]:
# Import necessary libraries
import pandas as pd
import numpy as np
from google.colab import files  # Specific to Google Colab environment
# Colab-only step: ask the user to upload a file. `uploaded` is not referenced
# again in the visible code.
# NOTE(review): presumably the upload is what makes the workbook below
# available in the working directory - confirm.
uploaded = files.upload()      # For file upload in Colab

# Read Excel data file into the module-level `df`; the __main__ block at the
# bottom of this file reads this same global.
df = pd.read_excel("Britannia_training.xlsx")

# Standardize column names to lowercase so later code can use keys such as
# 'close' and 'volume' regardless of the header casing in the workbook.
df.columns = df.columns.str.lower()
print(df.columns)  # Verify column names

# Machine Learning imports
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import seaborn as sns
import matplotlib.pyplot as plt

class CWRV_ML_Model:
    """Decision-tree classifier for the direction of the next candle.

    Each sample is built from a sliding window of ``_WINDOW`` consecutive
    candles (columns ``open``, ``high``, ``low``, ``close``, ``volume``).
    The label is 1 ("BUY") when the candle following the window closes above
    the window's last close, otherwise 0 ("SELL").
    """

    _WINDOW = 4  # Number of candles that make up one feature window

    def __init__(self):
        """Initialize ML model and scaler"""
        self.model = DecisionTreeClassifier()  # Decision Tree classifier
        self.scaler = StandardScaler()        # Feature scaler
        self.feature_names = []               # Store feature names for interpretation

    def create_features(self, df):
        """Build the labelled training matrix from raw candle data.

        Args:
            df: DataFrame with lowercase ``open``, ``high``, ``low``,
                ``close`` and ``volume`` columns, one row per candle.

        Returns:
            ``(X, y)`` where ``X`` is a DataFrame with one row per window and
            ``y`` is a numpy array of 0/1 labels. When ``df`` has fewer than
            ``_WINDOW + 1`` rows an empty DataFrame and empty array are
            returned.
        """
        size = self._WINDOW
        if len(df) <= size:  # Need the full window plus one target candle
            return pd.DataFrame(), np.array([])

        rows = []
        labels = []
        for start in range(len(df) - size):
            window = df.iloc[start:start + size]
            current_close = window['close'].iloc[-1]
            next_close = df['close'].iloc[start + size]

            # Features come from the window only; the target candle is used
            # exclusively for the label so nothing about it leaks into X.
            rows.append(self._window_features(window))
            labels.append(1 if next_close > current_close else 0)

        self.feature_names = list(rows[0].keys())
        return pd.DataFrame(rows), np.array(labels)

    def _window_features(self, window):
        """Return the feature dict for a single window of candles."""
        opens = window['open']
        closes = window['close']
        body_top = window[['close', 'open']].max(axis=1)
        body_bottom = window[['close', 'open']].min(axis=1)
        bodies = closes - opens
        upper_wicks = window['high'] - body_top
        lower_wicks = body_bottom - window['low']

        row = {
            # Bullish / bearish candle counts inside the window (doji candles
            # with close == open count towards neither side).
            'buyer_count': int((closes > opens).sum()),
            'seller_count': int((closes < opens).sum()),

            # Volume trend analysis (0=down, 1=neutral, 2=up)
            'volume_trend': self._get_volume_trend(window['volume']),

            # Average price movement magnitude
            'avg_candle_size': bodies.abs().mean(),

            # Mean ratio of upper wick to body size; the epsilon avoids a
            # division by zero on doji candles.
            'wick_ratio': (upper_wicks / (body_top - body_bottom + 1e-8)).mean(),

            # Correlation between close-to-close changes and volume. This is
            # NaN when either side has zero variance within the window.
            'body_volume_corr': closes.diff().corr(
                self._clean_volume(window['volume']), min_periods=1),

            # Momentum of last candle in window
            'last_momentum': 1 if closes.iloc[-1] > opens.iloc[-1] else 0,
        }

        # Per-candle metrics: signed body, upper wick and lower wick
        for pos in range(len(window)):
            prefix = f'candle{pos + 1}'
            row[f'{prefix}_body'] = bodies.iloc[pos]
            row[f'{prefix}_upper_wick'] = upper_wicks.iloc[pos]
            row[f'{prefix}_lower_wick'] = lower_wicks.iloc[pos]

        return row

    @staticmethod
    def _clean_volume(volumes):
        """Convert a volume Series to numbers, stripping thousands separators.

        ``Series.str.replace`` is required here: plain ``Series.replace``
        only matches whole cell values and would leave "1,234" untouched.
        Unparseable entries become NaN.
        """
        return pd.to_numeric(
            volumes.astype(str).str.replace(',', '', regex=False),
            errors='coerce',
        )

    def _get_volume_trend(self, volumes):
        """Classify the volume direction across a window.

        Returns:
            2 if every step increases, 0 if every step decreases, and 1 for
            a mixed sequence or when there are fewer than two values.
        """
        diffs = self._clean_volume(volumes).fillna(0).diff().dropna()
        if len(diffs) < 1:
            return 1  # Neutral if insufficient data
        if (diffs > 0).all():
            return 2  # Strong upward trend
        if (diffs < 0).all():
            return 0  # Strong downward trend
        return 1  # Mixed trend

    def train(self, data_file):
        """Run the full training pipeline and print/plot diagnostics.

        Args:
            data_file: Path to an Excel workbook readable by
                ``pandas.read_excel``.

        Raises:
            ValueError: If no usable training rows can be built.
        """
        df = pd.read_excel(data_file)
        df.columns = df.columns.str.lower()
        X, y = self.create_features(df)

        # Check before indexing y: an empty frame has nothing to align.
        if X.empty:
            raise ValueError("Not enough valid data to train the model.")

        # Drop rows containing NaN or +/-inf and keep the labels aligned
        X = X.replace([np.inf, -np.inf], np.nan)
        valid = X.notna().all(axis=1).to_numpy()
        X, y = X[valid], y[valid]

        if X.empty:
            raise ValueError("Not enough valid data to train the model.")

        # Split first, then fit the scaler on the training rows only so that
        # test-set statistics do not leak into preprocessing.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=90)
        X_train = self.scaler.fit_transform(X_train)
        X_test = self.scaler.transform(X_test)

        # Model training
        self.model.fit(X_train, y_train)

        # Model evaluation; explicit labels keep the report and matrix
        # well-formed even when the test split holds a single class.
        preds = self.model.predict(X_test)
        accuracy = accuracy_score(y_test, preds)
        cm = confusion_matrix(y_test, preds, labels=[0, 1])
        cr = classification_report(
            y_test, preds, labels=[0, 1], target_names=['SELL', 'BUY'],
            zero_division=0)

        print(f"Model Training Complete!\nAccuracy: {accuracy:.2%}")
        print("\nConfusion Matrix:\n", cm)
        print("\nClassification Report:\n", cr)

        # Additional diagnostics
        self._plot_feature_importance()
        self._plot_confusion_matrix(y_test, preds)

    def predict_next(self, recent_candles):
        """Predict the direction of the candle that follows the given data.

        Features are built from the most recent ``_WINDOW`` candles, so the
        prediction refers to the not-yet-seen candle after the last row.

        Args:
            recent_candles: Anything accepted by ``pandas.DataFrame`` holding
                the candle columns, oldest row first.

        Returns:
            ``"BUY"`` or ``"SELL"``.

        Raises:
            ValueError: If fewer than ``_WINDOW`` candles are supplied or the
                latest window yields NaN/inf features.
        """
        df = pd.DataFrame(recent_candles)
        df.columns = df.columns.str.lower()

        if len(df) < self._WINDOW:
            raise ValueError(
                f"Insufficient data. Provide at least {self._WINDOW} candles.")

        latest = pd.DataFrame([self._window_features(df.iloc[-self._WINDOW:])])
        latest = latest.replace([np.inf, -np.inf], np.nan)
        if latest.isna().to_numpy().any():
            # Training drops such rows, so the model has no basis for them.
            raise ValueError(
                "Latest candle window produced non-finite features; cannot predict.")

        scaled = self.scaler.transform(latest)
        prediction = self.model.predict(scaled)[0]
        return "BUY" if prediction == 1 else "SELL"

    # Visualization methods
    def _plot_feature_importance(self):
        """Show a horizontal bar chart of the fitted tree's feature importances."""
        importances = self.model.feature_importances_
        indices = np.argsort(importances)[::-1]  # Most important first
        plt.figure(figsize=(12, 6))
        plt.barh(range(len(indices)), importances[indices], align='center')
        plt.yticks(range(len(indices)), [self.feature_names[i] for i in indices])
        plt.title('Feature Importance Analysis')
        plt.xlabel('Relative Importance')
        plt.show()

    def _plot_confusion_matrix(self, y_true, y_pred):
        """Show the SELL/BUY confusion matrix as an annotated heatmap."""
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        plt.figure(figsize=(6, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
                    xticklabels=['SELL', 'BUY'], yticklabels=['SELL', 'BUY'])
        plt.title('Confusion Matrix: Actual vs Predicted')
        plt.xlabel('Predicted Labels')
        plt.ylabel('True Labels')
        plt.show()

# Script entry point: train the model, print a signal and chart it
if __name__ == "__main__":
    # Build the classifier and fit it on the training workbook
    ml_model = CWRV_ML_Model()
    ml_model.train('Britannia_training.xlsx')

    # Request a signal, reusing the module-level `df` as the sample input
    prediction = ml_model.predict_next(df)
    print(f"\nPredicted action: {prediction}")

    # Marker position and colour for the final candle (red only for SELL)
    last_idx = len(df) - 1
    last_close = df.iloc[-1]['close']
    if prediction == 'SELL':
        marker_color = 'red'
    else:
        marker_color = 'green'

    # Draw the close-price series and highlight the last candle
    plt.figure(figsize=(12, 6))
    plt.plot(df['close'], marker='o', linestyle='-', label='Price Trend')
    plt.scatter(
        last_idx,
        last_close,
        color=marker_color,
        s=200,
        label=f'Prediction: {prediction}',
    )
    plt.title('Price Trend with Prediction Marker')
    plt.xlabel('Candle Index')
    plt.ylabel('Close Price')
    plt.legend()
    plt.grid(True)
    plt.show()
