In [0]:
import pandas as pd
import os
import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import json

from datetime import datetime
from scipy.interpolate import CubicSpline
from scipy.integrate import simpson
from plotly.subplots import make_subplots
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from scipy.signal import savgol_filter
from scipy import stats

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import MinMaxScaler




In [0]:
# Input directory of raw CT CSV files and output directory for per-date peak files.
# NOTE(review): these are hardcoded absolute, user-specific workspace paths;
# consider making them configurable before sharing the notebook.
input_dir = '/Workspace/Users/amahmud1@networkrail.co.uk/CT/CT_raw_data/'
output_dir = '/Workspace/Users/amahmud1@networkrail.co.uk/CT/CT_processed_data/'

# Create the output directory up front (no error if it already exists).
os.makedirs(output_dir, exist_ok=True)

# Configuration variables
# peak_amp: a sample counts as a "peak" in find_peak_islands when its channel
# value is >= this threshold.
peak_amp = 3
# Columns requested from each raw CSV (passed as `usecols` in process_single_file).
cols_to_use = [
    'CT_Index', 'ELR', 'TrackId', '45_top_right', '45_top_left', '6_top_right',
    '6_top_left', '9_top_right', '9_top_left', '135_top_right', '135_top_left', 
    '18_top_right', '18_top_left', 'Location_Norm_m'
]

# Measurement channels scanned for peak islands, one pass per channel.
ct_channels = [
    '45_top_right', '45_top_left', '6_top_right', '6_top_left',
    '9_top_right', '9_top_left', '135_top_right', '135_top_left',
    '18_top_right', '18_top_left'
]

def sigmoid_value(val):
    """Stub for the sigmoid term used in triplet scoring.

    The argument is accepted but ignored; the result is always 0.0, so the
    term currently contributes nothing to any score it is added to.
    """
    placeholder_result = 0.0
    return placeholder_result

def extract_date_from_filename(filename):
    """Extract the trailing date from a name like 'OWW-2100-134404-222214-2021-08-12.csv'.

    The last three '-'-separated fields of the name (after removing a trailing
    '.csv' extension, matched case-insensitively) are read as year, month, day.

    Parameters:
    filename: File name (or path) whose final three dash-separated fields are the date

    Returns:
    The date as a zero-padded 'YYYY-MM-DD' string, or None (after printing a
    warning) when the name does not end in a valid calendar date.
    """
    try:
        # Strip only the trailing extension rather than every '.csv' occurrence.
        stem = filename[:-4] if filename.lower().endswith('.csv') else filename
        # Unpacking raises ValueError when fewer than three fields are present.
        year, month, day = stem.split('-')[-3:]
        parsed = datetime.strptime(f"{year}-{month}-{day}", '%Y-%m-%d')
        # Return the canonical form so '2021-8-2' and '2021-08-02' yield the
        # same output file name downstream.
        return parsed.date().isoformat()
    except (AttributeError, IndexError, TypeError, ValueError) as e:
        print(f"Warning: Could not extract valid date from filename: {filename}. Error: {str(e)}")
        return None

def find_peak_islands(df, col):
    """Locate runs ("islands") of consecutive peak samples and score them.

    A sample is a peak when ``df[col] >= peak_amp``. Only runs of at least
    three consecutive peaks are kept. Every sliding window of three samples
    inside a run gets a triplet score (sum of the three values plus their
    ``sigmoid_value`` terms). One row per (window, point) is collected, and
    finally a single row per CT_Index is kept: the one with the highest
    triplet score.

    Returns an empty DataFrame when no qualifying island exists; otherwise a
    DataFrame with columns CT_Index, Location_Norm_m, Value, triplet_score,
    island_id and island_size.
    """
    is_peak = df[col] >= peak_amp

    # Each change of the boolean mask starts a new run label; keep peak rows only.
    run_labels = (is_peak != is_peak.shift()).cumsum()[is_peak]

    rows = []
    next_island_id = 1  # islands are numbered in order of appearance

    for label in run_labels.unique():
        member_index = run_labels.index[(run_labels == label).values]
        island_size = len(member_index)

        # Runs shorter than three samples cannot form a triplet.
        if island_size < 3:
            continue

        island_df = df.loc[member_index]

        for offset in range(island_size - 2):
            window = island_df.index[offset:offset + 3]
            window_values = island_df.loc[window, col].values

            triplet_score = sum(
                window_values + [sigmoid_value(v) for v in window_values]
            )

            # One record per point of the window, all sharing the window score.
            rows.extend(
                {
                    'CT_Index': df.loc[point, 'CT_Index'],
                    'Location_Norm_m': df.loc[point, 'Location_Norm_m'],
                    'Value': df.loc[point, col],
                    'triplet_score': triplet_score,
                    'island_id': next_island_id,
                    'island_size': island_size,
                }
                for point in window
            )

        next_island_id += 1

    if not rows:
        return pd.DataFrame()

    peaks_df = pd.DataFrame(rows)

    # A point can belong to up to three windows; keep its best-scoring record.
    best_row_labels = peaks_df.groupby('CT_Index')['triplet_score'].idxmax()
    return peaks_df.loc[best_row_labels]

def process_single_file(file_path):
    """Run peak-island detection over every channel of one raw CT CSV.

    Reads only the columns in ``cols_to_use``, scans each channel listed in
    ``ct_channels`` with ``find_peak_islands`` and stacks the non-empty
    results, tagging each with a 'Channel' column. Any exception is reported
    with a printed message and an empty DataFrame is returned instead.
    """
    try:
        ct = pd.read_csv(file_path, usecols=cols_to_use)

        per_channel = []
        for ct_channel in ct_channels:
            # Rows without a reading for this channel are ignored.
            channel_frame = ct[['CT_Index', 'Location_Norm_m', ct_channel]].dropna(
                subset=[ct_channel]
            )

            peaks_df = find_peak_islands(channel_frame, ct_channel)
            if peaks_df.empty:
                continue

            peaks_df['Channel'] = ct_channel
            per_channel.append(peaks_df)

        if not per_channel:
            return pd.DataFrame()
        return pd.concat(per_channel, ignore_index=True)

    except Exception as e:
        print(f"Error processing file {file_path}: {str(e)}")
        return pd.DataFrame()

def main():
    """Process every '.csv' in ``input_dir`` and write one result file per date.

    Files whose name does not yield a valid date, or that produce no peaks,
    are skipped. Results get a 'Date' column and are written to
    ``output_dir/<date>.csv``.
    """
    csv_names = [name for name in os.listdir(input_dir) if name.endswith('.csv')]

    for filename in csv_names:
        print(f"Processing file: {filename}")

        date = extract_date_from_filename(filename)
        if not date:
            continue

        results_df = process_single_file(os.path.join(input_dir, filename))
        if results_df.empty:
            continue

        results_df['Date'] = date

        # NOTE(review): the output name depends on the date only, so two input
        # files sharing a date write to the same path.
        output_file = os.path.join(output_dir, f"{date}.csv")
        results_df.to_csv(output_file, index=False)
        print(f"Saved results to: {output_file}")

# Entry point: run the batch when this code executes as the main module
# (not when it is imported from elsewhere).
if __name__ == "__main__":
    main()

In [0]:
class CTFeatureExtractor:
    """Fit a cubic spline to each peak island and derive shape features from it.

    Parameters:
    data: DataFrame holding one row per peak point
    location_col: Column with the point location (used as spline x)
    value_col: Column with the signal value (used as spline y)
    island_id_col: Column identifying which island a row belongs to
    interpolation_points: Number of evenly spaced points sampled from the spline
    """

    def __init__(self, data, location_col='Location_Norm_m', value_col='Value', 
                 island_id_col='island_id', interpolation_points=100):
        self.data = data
        self.location_col = location_col
        self.value_col = value_col
        self.island_id_col = island_id_col
        self.interpolation_points = interpolation_points
        # island_id -> feature dict, filled by extract_all_features()
        self.features = {}

    @staticmethod
    def _collapse_duplicate_locations(x, y):
        """Average y over repeated x values so x becomes strictly increasing.

        Input without repeated x is returned untouched.
        """
        unique_x, inverse = np.unique(x, return_inverse=True)
        if len(unique_x) == len(x):
            return x, y
        inverse = inverse.ravel()
        totals = np.bincount(inverse, weights=np.asarray(y, dtype=float))
        return unique_x, totals / np.bincount(inverse)

    def fit_spline(self, x, y):
        """Fit a cubic spline through (x, y) and sample it evenly over [x.min(), x.max()].

        Returns (x_new, y_new, spline) where x_new has ``interpolation_points`` entries.
        """
        cs = CubicSpline(x=x, y=y)
        x_new = np.linspace(start=x.min(), stop=x.max(), num=self.interpolation_points)
        y_new = cs(x=x_new)
        return x_new, y_new, cs

    def calculate_features(self, island_id):
        """Calculate features for a specific island.

        Returns a dict with scalar features (island_id, start/end location,
        peak_amplitude, average_y, pattern_width, center_location, mean_slope,
        max_slope, area_under_curve, distance_difference) plus the arrays
        x_interp, y_interp, dy_dx, x_raw and y_raw for plotting.

        Raises:
        ValueError: if the island has fewer than two distinct locations, since
            no spline can be fitted.
        """
        # Get data for this island, ordered by location
        island_data = self.data[self.data[self.island_id_col] == island_id].sort_values(self.location_col)
        x = island_data[self.location_col].values
        y = island_data[self.value_col].values

        # CubicSpline needs strictly increasing x: merge repeated locations by
        # averaging their values instead of letting the fit raise.
        x_fit, y_fit = self._collapse_duplicate_locations(x, y)
        if len(x_fit) < 2:
            raise ValueError(
                f"Island {island_id} has {len(x_fit)} distinct location(s); "
                "at least 2 are required to fit a spline"
            )

        # Store start and end locations
        start_location = x[0]
        end_location = x[-1]

        # Fit spline
        x_interp, y_interp, spline = self.fit_spline(x_fit, y_fit)

        # First derivative of the spline at the sampled points
        dy_dx = spline.derivative()(x_interp)

        features = {
            'island_id': island_id,
            'start_location': start_location,
            'end_location': end_location,
            'peak_amplitude': np.max(y_interp),
            'average_y': np.mean(y_interp),
            'pattern_width': x_interp[-1] - x_interp[0],
            'center_location': x_interp[np.argmax(y_interp)],
            'mean_slope': np.mean(np.abs(dy_dx)),
            'max_slope': np.max(np.abs(dy_dx)),
            'area_under_curve': simpson(y=y_interp, x=x_interp)
        }

        # distance_difference = length along the sampled curve minus the
        # straight-line distance between its end points
        scalar_distance = np.sqrt((x_interp[-1] - x_interp[0])**2 + 
                                (y_interp[-1] - y_interp[0])**2)
        vector_distance = np.sum(np.sqrt(np.diff(x_interp)**2 + 
                                       np.diff(y_interp)**2))
        features['distance_difference'] = vector_distance - scalar_distance

        # Store interpolated and raw values for plotting
        features['x_interp'] = x_interp
        features['y_interp'] = y_interp
        features['dy_dx'] = dy_dx
        features['x_raw'] = x
        features['y_raw'] = y

        return features

    def extract_all_features(self):
        """Extract features for all islands and return the island_id -> features dict."""
        for island_id in self.data[self.island_id_col].unique():
            self.features[island_id] = self.calculate_features(island_id)
        return self.features

    def plot_all_islands(self, date):
        """Create a single plot showing all islands for a given date.

        Each island is drawn as raw markers plus its spline, coloured by the
        island's 'risk_score' entry (0 when absent) on a green-to-red scale.
        Returns the plotly Figure.
        """
        fig = go.Figure()

        # 100-step colour scale: green (low risk) to red (high risk)
        colors = plt.cm.RdYlGn_r(np.linspace(0, 1, 100))

        for island_id, features in self.features.items():
            risk_score = features.get('risk_score', 0)

            # Clamp to [0, 1] (NaN/inf -> 0) so the palette lookup can never
            # fail or wrap around via a negative index.
            color_level = risk_score if np.isfinite(risk_score) else 0.0
            color_idx = int(min(max(color_level, 0.0), 1.0) * 99)  # Map 0-1 to 0-99
            red, green, blue = (int(round(c * 255)) for c in colors[color_idx][:3])
            color = f'rgb({red},{green},{blue})'

            # Plot original points
            fig.add_trace(
                go.Scatter(
                    x=features['x_raw'],
                    y=features['y_raw'],
                    mode='markers',
                    name=f'Island {island_id} (Risk: {risk_score:.2f})',
                    marker=dict(size=8, color=color)
                )
            )

            # Plot spline fit
            fig.add_trace(
                go.Scatter(
                    x=features['x_interp'],
                    y=features['y_interp'],
                    mode='lines',
                    name=f'Spline {island_id}',
                    line=dict(color=color),
                    showlegend=False
                )
            )

        fig.update_layout(
            title=f'CT Signal Analysis - {date}',
            xaxis_title='Location (m)',
            yaxis_title='Signal Value',
            showlegend=True,
            height=600,
            width=1200
        )

        return fig




In [0]:
def get_sorted_files(input_dir):
    """Return (date, filename) pairs for the 'YYYY-MM-DD.csv' files in input_dir.

    Pairs are ordered by ascending date. CSV files whose name is not a valid
    date are reported with a printed warning and left out; other files are
    ignored.
    """
    dated_files = []
    for filename in os.listdir(input_dir):
        if not filename.endswith('.csv'):
            continue

        # The whole name (minus '.csv') is expected to be the date.
        stem = filename.replace('.csv', '')
        try:
            dated_files.append((datetime.strptime(stem, '%Y-%m-%d'), filename))
        except ValueError:
            print(f"Warning: Couldn't parse date from filename: {filename}")

    return sorted(dated_files, key=lambda pair: pair[0])

In [0]:
class IslandIdentifier:
    """Assign stable global IDs to islands across dates by matching locations.

    The registry (``registered_islands``) holds one row per known island; the
    ``history`` list records every 'new', 'split' and 'merge' event.
    """

    def __init__(self, distance_threshold=5.0, overlap_threshold=0.3):
        """
        Initialize IslandIdentifier with configurable thresholds
        
        Parameters:
        distance_threshold: Maximum distance (meters) between island boundaries for matching
        overlap_threshold: Minimum overlap ratio required for potential splits/merges
        """
        self.distance_threshold = distance_threshold
        self.overlap_threshold = overlap_threshold
        self.registered_islands = pd.DataFrame(columns=[
            'global_id',
            'start_location',
            'end_location',
            'first_date',
            'last_date',
            'status',  # 'active', 'merged', 'split', 'disappeared'
            'related_islands',  # IDs of split/merged islands
            'confidence'  # confidence score for location matching
        ])
        self.next_id = 1
        self.history = []  # Track all changes to islands

    def calculate_overlap(self, island1, island2):
        """Return the overlap length divided by the shorter island's length (0.0 if disjoint)."""
        start = max(island1['start_location'], island2['start_location'])
        end = min(island1['end_location'], island2['end_location'])
        
        # Disjoint or merely touching intervals; this also covers zero-length
        # islands, so the division below never sees a zero denominator.
        if end <= start:
            return 0.0
            
        overlap = end - start
        length1 = island1['end_location'] - island1['start_location']
        length2 = island2['end_location'] - island2['start_location']
        
        return overlap / min(length1, length2)

    def find_matches(self, island, current_date):
        """
        Find potential matches for an island among the active registered islands.

        A registered island matches when both its start and its end lie within
        ``distance_threshold`` of the candidate's. Returns a list of dicts
        (global_id, confidence, overlap) sorted by descending confidence.
        ``current_date`` is accepted but not used.
        """
        active_islands = self.registered_islands[
            self.registered_islands['status'] == 'active'
        ]
        
        matches = []
        for _, registered in active_islands.iterrows():
            start_diff = abs(registered['start_location'] - island['start_location'])
            end_diff = abs(registered['end_location'] - island['end_location'])
            overlap = self.calculate_overlap(island, registered)
            
            if (start_diff <= self.distance_threshold and 
                end_diff <= self.distance_threshold):
                
                # Confidence is 1 minus the worst relative boundary offset,
                # discounted by how much the two islands overlap.
                confidence = 1.0 - max(
                    start_diff / self.distance_threshold,
                    end_diff / self.distance_threshold
                ) * (1 - overlap)
                
                matches.append({
                    'global_id': registered['global_id'],
                    'confidence': confidence,
                    'overlap': overlap
                })
        
        return sorted(matches, key=lambda x: x['confidence'], reverse=True)

    def register_new_island(self, island, date):
        """Register a new active island and return its global ID ('IS_0001', ...)."""
        global_id = f"IS_{self.next_id:04d}"
        self.next_id += 1
        
        new_island = pd.DataFrame({
            'global_id': [global_id],
            'start_location': [island['start_location']],
            'end_location': [island['end_location']],
            'first_date': [date],
            'last_date': [date],
            'status': ['active'],
            'related_islands': [[]],
            'confidence': [1.0]
        })
        
        if self.registered_islands.empty:
            # Concatenating with an empty frame is deprecated in pandas; the
            # first row simply becomes the registry.
            self.registered_islands = new_island
        else:
            self.registered_islands = pd.concat([self.registered_islands, new_island], 
                                              ignore_index=True)
        
        self.history.append({
            'date': date,
            'event': 'new',
            'global_id': global_id,
            'location': f"{island['start_location']:.1f}-{island['end_location']:.1f}m"
        })
        
        return global_id

    def update_island(self, global_id, island, date, confidence):
        """Update last_date and lower the stored confidence if the new one is smaller.

        ``island`` is accepted but not used; stored locations are not changed.
        """
        idx = self.registered_islands['global_id'] == global_id
        self.registered_islands.loc[idx, 'last_date'] = date
        self.registered_islands.loc[idx, 'confidence'] = min(
            self.registered_islands.loc[idx, 'confidence'].iloc[0],
            confidence
        )

    def _link_related(self, global_id, related_ids):
        """Append ``related_ids`` (without duplicates) to an island's related_islands list."""
        mask = (self.registered_islands['global_id'] == global_id).to_numpy()
        for row_label in self.registered_islands.index[mask]:
            current = self.registered_islands.at[row_label, 'related_islands']
            linked = list(current) if isinstance(current, list) else []
            linked.extend(rid for rid in related_ids if rid not in linked)
            self.registered_islands.at[row_label, 'related_islands'] = linked

    def check_splits_and_merges(self, current_islands, date):
        """Detect split/merged islands among ``current_islands`` and record them.

        Current islands are first grouped by mutual overlap above
        ``overlap_threshold``. A group of several islands that all match one
        single registered island is a split; a lone island matching more than
        one registered island is a merge.
        """
        # Group current islands by proximity
        grouped_current = []
        for island in current_islands:
            added = False
            for group in grouped_current:
                if any(self.calculate_overlap(island, existing) > self.overlap_threshold 
                      for existing in group):
                    group.append(island)
                    added = True
                    break
            if not added:
                grouped_current.append([island])
        
        # Check each group for splits/merges
        for group in grouped_current:
            if len(group) > 1:  # Potential split
                matches = []
                for island in group:
                    matches.extend(self.find_matches(island, date))
                
                if len({m['global_id'] for m in matches}) == 1:
                    self.mark_split(matches[0]['global_id'], group, date)
            else:
                # Potential merge; look the matches up once and reuse them
                matches = self.find_matches(group[0], date)
                if len(matches) > 1:
                    self.mark_merge(group[0], matches, date)

    def mark_split(self, parent_id, new_islands, date):
        """Mark an island as split, register its children and cross-link them."""
        idx = self.registered_islands['global_id'] == parent_id
        self.registered_islands.loc[idx, 'status'] = 'split'
        
        new_ids = [self.register_new_island(island, date) for island in new_islands]
        
        # Record the relationship in both directions in the registry.
        self._link_related(parent_id, new_ids)
        for new_id in new_ids:
            self._link_related(new_id, [parent_id])
        
        self.history.append({
            'date': date,
            'event': 'split',
            'parent_id': parent_id,
            'child_ids': new_ids
        })

    def mark_merge(self, new_island, parent_matches, date):
        """Mark the matched islands as merged, register the merged island and cross-link them."""
        parent_ids = [m['global_id'] for m in parent_matches]
        for parent_id in parent_ids:
            idx = self.registered_islands['global_id'] == parent_id
            self.registered_islands.loc[idx, 'status'] = 'merged'
        
        new_id = self.register_new_island(new_island, date)
        
        # Record the relationship in both directions in the registry.
        self._link_related(new_id, parent_ids)
        for parent_id in parent_ids:
            self._link_related(parent_id, [new_id])
        
        self.history.append({
            'date': date,
            'event': 'merge',
            'parent_ids': parent_ids,
            'new_id': new_id
        })

    def register_islands(self, df_date, date):
        """
        Register islands from a specific date and assign consistent global IDs.

        Returns a copy of ``df_date`` with 'global_id' and 'match_confidence'
        columns added. Islands without any match are registered as new; the
        others take the ID of their highest-confidence match.
        """
        df = df_date.copy()
        df['global_id'] = None
        df['match_confidence'] = 1.0
        
        current_islands = df.to_dict('records')
        
        # Splits/merges are registered first so the loop below can match the
        # freshly created islands.
        self.check_splits_and_merges(current_islands, date)
        
        for idx, island in df.iterrows():
            matches = self.find_matches(island, date)
            
            if not matches:
                # New island
                global_id = self.register_new_island(island, date)
                df.loc[idx, 'global_id'] = global_id
                df.loc[idx, 'match_confidence'] = 1.0
            else:
                # Use best match
                best_match = matches[0]
                df.loc[idx, 'global_id'] = best_match['global_id']
                df.loc[idx, 'match_confidence'] = best_match['confidence']
                self.update_island(best_match['global_id'], island, date, 
                                 best_match['confidence'])
        
        return df
    
    def get_history_summary(self):
        """Return the recorded island events as a DataFrame."""
        return pd.DataFrame(self.history)

In [0]:
class CTAnalyzer:
    """Cluster island features with KMeans and turn clusters into 0-1 risk scores.

    Features are standardised, multiplied by per-feature weights, clustered,
    and each cluster receives a risk derived from its centre. ``predict_risk``
    falls back to a simple rule-based score when the model cannot be used.
    """

    def __init__(self, n_clusters=5):
        self.n_clusters = n_clusters
        self.scaler = StandardScaler()
        self.kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        self.feature_cols = [
            'peak_amplitude', 'average_y', 'distance_difference',
            'pattern_width', 'area_under_curve', 'mean_slope', 'max_slope'
        ]
        # Add feature weights to emphasize important indicators
        self.feature_weights = {
            'peak_amplitude': 0.3,
            'average_y': 0.1,
            'distance_difference': 0.1,
            'pattern_width': 0.15,
            'area_under_curve': 0.15,
            'mean_slope': 0.1,
            'max_slope': 0.1
        }
        # One risk value per cluster; set by fit()
        self.cluster_risks = None

    def _weight_vector(self):
        """Return the feature weights as an array ordered like ``feature_cols``."""
        return np.array([self.feature_weights[col] for col in self.feature_cols])

    def fit(self, combined_summary):
        """Train the scaler and KMeans on historical data and derive cluster risks.

        On any failure an error is printed and evenly spaced default cluster
        risks are installed instead. Returns self.
        """
        try:
            X = combined_summary[self.feature_cols]
            X_scaled = self.scaler.fit_transform(X)
            X_weighted = X_scaled * self._weight_vector()
            
            self.kmeans.fit(X_weighted)
            self.cluster_risks = self._calculate_cluster_risks(X_weighted)
            
            return self
            
        except Exception as e:
            print(f"Error during fitting: {str(e)}")
            # Set default risk scores if fitting fails
            self.cluster_risks = np.linspace(0, 1, self.n_clusters)
            return self

    def _calculate_cluster_risks(self, X_scaled):
        """Calculate a 0-1 risk per cluster from its weighted distance to the origin.

        ``X_scaled`` is accepted but not used; only the fitted cluster centres
        are read. When every cluster is equally far from the origin (for
        example with a single cluster) the normalisation is undefined, so all
        clusters get the neutral value 0.5 instead of NaN.
        """
        cluster_centers = self.kmeans.cluster_centers_
        feature_importance = self._weight_vector()
        
        # NOTE(review): centres already live in weighted space (fit() clusters
        # X_weighted), so the weights are applied a second time here - confirm
        # this is intended.
        weighted_distances = np.sum(np.abs(cluster_centers) * feature_importance, axis=1)
        
        # Normalize to 0-1 range, guarding the zero-spread case
        spread = weighted_distances.max() - weighted_distances.min()
        if spread == 0:
            risks = np.full(weighted_distances.shape, 0.5)
        else:
            risks = (weighted_distances - weighted_distances.min()) / spread
        
        # Apply sigmoid transformation to spread out the middle range
        risks = 1 / (1 + np.exp(-5 * (risks - 0.5)))
        
        return risks

    def predict_risk(self, features_df):
        """Return one risk score in [0, 1] per row of ``features_df``.

        Each row gets its cluster's base risk plus an adjustment (at most 0.2)
        that grows with its distance from the cluster centre. From the second
        row on, the score is blended 70/30 with the previous row's score.
        Any exception triggers the rule-based fallback.
        """
        try:
            X = features_df[self.feature_cols]
            X_scaled = self.scaler.transform(X)
            X_weighted = X_scaled * self._weight_vector()
            
            clusters = self.kmeans.predict(X_weighted)
            
            risk_scores = []
            for i, (cluster, point) in enumerate(zip(clusters, X_weighted)):
                base_risk = self.cluster_risks[cluster]
                
                # Total absolute deviation of the point from its cluster centre
                feature_deviations = np.abs(point - self.kmeans.cluster_centers_[cluster])
                weighted_deviations = np.sum(feature_deviations)
                
                adjustment = self._sigmoid(weighted_deviations - 0.5) * 0.2  # Max 20% adjustment
                
                # NOTE(review): this smooths over consecutive ROWS of the input.
                # It is only "temporal" if the rows are one island's history in
                # date order - verify against the caller.
                risk = base_risk + adjustment
                if i > 0:
                    risk = 0.7 * risk + 0.3 * risk_scores[-1]  # Exponential smoothing
                
                risk_scores.append(min(max(risk, 0), 1))  # Ensure 0-1 bounds
            
            return np.array(risk_scores)
            
        except Exception as e:
            print(f"Error during risk prediction: {str(e)}")
            return self._fallback_risk_calculation(features_df)
    
    def _fallback_risk_calculation(self, features_df):
        """Rule-based risk from peak amplitude, pattern width and max slope.

        Each feature is scaled by a fixed divisor (10, 20 and 5) and capped at
        1, then combined with weights 0.5 / 0.3 / 0.2. Returns a float array
        with one entry per row.
        """
        peak_risk = np.minimum(features_df['peak_amplitude'].to_numpy(dtype=float) / 10, 1)
        width_risk = np.minimum(features_df['pattern_width'].to_numpy(dtype=float) / 20, 1)
        slope_risk = np.minimum(features_df['max_slope'].to_numpy(dtype=float) / 5, 1)
        
        return 0.5 * peak_risk + 0.3 * width_risk + 0.2 * slope_risk

    def _sigmoid(self, x):
        """Sigmoid function for smooth risk adjustments"""
        return 1 / (1 + np.exp(-x))

In [0]:
def process_file(input_file, output_dir, channel, date_str, analyzer=None, identifier=None):
    """
    Process a single file and save results
    
    Extracts spline features for every island of ``channel``, optionally
    assigns global IDs and risk scores, then writes a summary CSV to
    ``output_dir/summaries`` and an HTML plot to ``output_dir/plots``
    (both sub-directories are created when missing).
    
    Parameters:
    -----------
    input_file : str
        Path to input CSV file
    output_dir : str
        Path to output directory
    channel : str
        Channel to process
    date_str : str
        Date string for the file
    analyzer : CTAnalyzer, optional
        Analyzer for risk scoring
    identifier : IslandIdentifier, optional
        Island identifier for consistent IDs across dates
    
    Returns:
    --------
    pandas.DataFrame or None
        The per-island summary, or None when the channel has no rows or any
        step fails (the error is printed).
    """
    # Scalar features copied from each island's feature dict into the summary.
    summary_keys = ['island_id', 'start_location', 'end_location',
                    'peak_amplitude', 'average_y', 'pattern_width',
                    'center_location', 'mean_slope', 'max_slope',
                    'area_under_curve', 'distance_difference']
    try:
        # Read and process data
        data = pd.read_csv(input_file)
        channel_data = data[data['Channel'] == channel].copy()
        
        if channel_data.empty:
            print(f"No data found for channel {channel} in file {input_file}")
            return None
        
        # Extract features
        extractor = CTFeatureExtractor(channel_data)
        features = extractor.extract_all_features()
        
        # Create summary DataFrame with location bounds
        summary_data = []
        for island_id, feature_dict in features.items():
            feature_summary = {k: v for k, v in feature_dict.items() 
                             if not isinstance(v, np.ndarray) and 
                             k in summary_keys}
            summary_data.append(feature_summary)
        
        summary_df = pd.DataFrame(summary_data)
        
        # Add date column
        summary_df['date'] = date_str
        
        # Apply island identification if provided
        if identifier is not None:
            # Register islands and get consistent global IDs
            summary_df = identifier.register_islands(summary_df, date_str)
            
            # Update features with global IDs
            for idx, row in summary_df.iterrows():
                old_id = row['island_id']
                features[old_id]['global_id'] = row['global_id']
                features[old_id]['match_confidence'] = row['match_confidence']
        
        # Calculate risk scores if analyzer is provided
        if analyzer is not None:
            summary_df['risk_score'] = analyzer.predict_risk(summary_df)
            
            # Update features with risk scores (the plot reads them from there)
            for idx, row in summary_df.iterrows():
                features[row['island_id']]['risk_score'] = row['risk_score']
        
        # Save summary data; create the sub-directory first, otherwise the
        # write fails on a fresh output_dir and the whole call returns None.
        # NOTE(review): the file name carries the date only, so processing
        # several channels for one date writes to the same path.
        summaries_dir = os.path.join(output_dir, 'summaries')
        os.makedirs(summaries_dir, exist_ok=True)
        output_file = os.path.join(summaries_dir, f'{date_str}_summary.csv')
        summary_df.to_csv(output_file, index=False)
        
        # Generate and save plot
        if hasattr(extractor, 'plot_all_islands'):
            fig = extractor.plot_all_islands(date_str)
            plots_dir = os.path.join(output_dir, 'plots')
            os.makedirs(plots_dir, exist_ok=True)
            plot_file = os.path.join(plots_dir, f'{date_str}_analysis.html')
            fig.write_html(plot_file)
        
        return summary_df
        
    except Exception as e:
        print(f"Error processing file {input_file}: {str(e)}")
        return None

In [0]:
class IslandSequenceDataset(Dataset):
    """Sliding-window sequences of island features for risk forecasting.

    For every ``global_id`` with at least ``sequence_length + 1`` rows, the
    rows are ordered by 'date', min-max scaled per island, and cut into
    windows of ``sequence_length`` steps. The target of each window is the
    scaled 'risk_score' of the step that follows it. The per-island scalers
    are kept in ``self.scalers``.
    """

    def __init__(self, data, sequence_length=5):
        self.sequence_length = sequence_length
        self.scalers = {}

        # Model inputs; risk_score must stay last because targets read column -1.
        self.features = [
            'peak_amplitude', 'pattern_width', 'area_under_curve',
            'mean_slope', 'max_slope', 'risk_score'
        ]

        windows = []
        next_step_risks = []
        minimum_rows = sequence_length + 1  # one full window plus its target

        for island_id, history in data.groupby('global_id'):
            if len(history) < minimum_rows:
                continue

            ordered = history.sort_values('date')

            # Each island is scaled independently with its own scaler.
            island_scaler = MinMaxScaler()
            scaled = island_scaler.fit_transform(ordered[self.features])
            self.scalers[island_id] = island_scaler

            for target_pos in range(sequence_length, len(scaled)):
                windows.append(scaled[target_pos - sequence_length:target_pos])
                next_step_risks.append(scaled[target_pos, -1])

        self.sequences = torch.FloatTensor(np.array(windows))
        self.targets = torch.FloatTensor(np.array(next_step_risks))

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.targets[idx]

class LSTMPredictor(nn.Module):
    """Stacked LSTM followed by a small MLP head that outputs one value per sequence."""

    def __init__(self, input_size, hidden_size=64, num_layers=2, dropout=0.2):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Dropout between LSTM layers is only applied when there is more than one layer.
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )

        # Head: hidden_size -> 32 -> 1
        head_layers = [
            nn.Linear(hidden_size, 32),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(32, 1),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map x of shape (batch_size, sequence_length, input_size) to (batch_size, 1)."""
        sequence_out, _state = self.lstm(x)

        # Only the final time step feeds the prediction head.
        return self.fc(sequence_out[:, -1, :])

class RiskTrajectoryPredictor:
    """Manager class for risk trajectory prediction.

    Holds the sequence dataset and the LSTM model, runs the training loop and
    rolls the trained model forward to forecast an island's risk score.
    """
    def __init__(self, sequence_length=5, hidden_size=64, num_layers=2, learning_rate=0.001):
        self.sequence_length = sequence_length
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Both stay None until prepare_data() has been called
        self.dataset = None
        self.model = None
        self.learning_rate = learning_rate
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def prepare_data(self, combined_summary_df):
        """Build the sequence dataset from the combined summary and create the model."""
        self.dataset = IslandSequenceDataset(
            combined_summary_df,
            sequence_length=self.sequence_length
        )

        # The model input width follows the dataset's feature list
        input_size = len(self.dataset.features)
        self.model = LSTMPredictor(
            input_size=input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers
        ).to(self.device)

    def train(self, num_epochs=50, batch_size=32):
        """Train the LSTM model with Adam and MSE loss.

        Raises
        ------
        ValueError
            If prepare_data() has not been called, or if the dataset holds no
            training sequences.
        """
        if self.dataset is None or self.model is None:
            raise ValueError("Call prepare_data first")
        if len(self.dataset) == 0:
            # Fail with a clear message instead of an obscure sampler error
            raise ValueError("Dataset contains no training sequences")

        dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True
        )

        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

        self.model.train()
        for epoch in range(num_epochs):
            total_loss = 0
            for sequences, targets in dataloader:
                sequences = sequences.to(self.device)
                targets = targets.to(self.device)

                optimizer.zero_grad()
                outputs = self.model(sequences)
                # Drop only the trailing size-1 dimension: a bare squeeze() would also
                # collapse a final batch of one sample to a 0-d tensor and make the
                # loss broadcast against a (1,)-shaped target.
                loss = criterion(outputs.squeeze(-1), targets)

                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            if (epoch + 1) % 10 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(dataloader):.4f}')

    def predict_trajectory(self, island_data, future_steps=90):
        """
        Predict future risk trajectory for an island

        Parameters:
        -----------
        island_data : pd.DataFrame
            Historical data for the island; needs a 'date' column and the
            dataset's feature columns.
        future_steps : int, optional (default=90)
            Number of autoregressive steps to roll forward (callers in this
            notebook treat each step as one day).

        Returns:
        --------
        np.ndarray
            float32 array of predicted risk scores clipped to [0, 1]; an empty
            array if anything fails (the error is printed, not raised).
        """
        try:
            self.model.eval()

            # Sort by date and get recent sequence; copy so the numeric coercion
            # below never writes into a slice of the caller's frame
            island_data = island_data.sort_values('date')
            recent_data = island_data.iloc[-self.sequence_length:].copy()

            # Ensure numerical values
            for col in self.dataset.features:
                if col in recent_data.columns:
                    recent_data[col] = pd.to_numeric(recent_data[col], errors='coerce')

            # Scale the data
            # NOTE(review): the scaler is refit on this short window rather than reusing
            # the per-island scaler stored by the dataset during training — confirm
            # whether that mismatch is intended.
            scaler = MinMaxScaler()
            feature_data = recent_data[self.dataset.features].astype(float)
            scaled_data = scaler.fit_transform(feature_data)

            # Convert to tensor with a leading batch dimension
            sequence = torch.FloatTensor(scaled_data).unsqueeze(0).to(self.device)

            predictions = []
            with torch.no_grad():
                current_sequence = sequence

                for _ in range(future_steps):
                    # Predict next risk score
                    output = self.model(current_sequence)
                    pred_risk = float(output.item())
                    predictions.append(pred_risk)

                    # Slide the window: drop the oldest timestep
                    new_sequence = current_sequence[:, 1:, :].clone()

                    # New timestep = last known feature values with the predicted
                    # risk written into the final feature column
                    new_features = new_sequence[:, -1, :].clone()
                    new_features[:, -1] = pred_risk

                    current_sequence = torch.cat([
                        new_sequence,
                        new_features.unsqueeze(1)
                    ], dim=1)

            predictions = np.array(predictions, dtype=np.float32)

            # Inverse transform: pad the other feature columns with zeros so the
            # scaler accepts the array, then keep only the last (risk) column
            predictions_2d = predictions.reshape(-1, 1)
            padding = np.zeros((len(predictions), len(self.dataset.features)-1))
            padded_predictions = np.hstack([padding, predictions_2d])
            inverse_transformed = scaler.inverse_transform(padded_predictions)[:, -1]

            # Ensure predictions are within valid range [0, 1] and are float type
            return np.clip(inverse_transformed, 0, 1).astype(np.float32)

        except Exception as e:
            # Best effort: callers receive an empty array instead of an exception
            print(f"Error in prediction: {str(e)}")
            return np.array([], dtype=np.float32)

In [0]:
class TemporalEvolutionAnalyzer:
    """Quantify and plot how per-island features change across observation dates.

    Growth metrics come from a log-linear (exponential) fit per feature; an
    optional LSTM risk predictor adds forecast risk trajectories.
    """

    def __init__(self):
        # Columns that are fitted and plotted for every island
        self.features_to_track = [
            'peak_amplitude',
            'pattern_width',
            'area_under_curve',
            'mean_slope',
            'max_slope',
            'risk_score'
        ]
        # Stays None until initialize_risk_predictor() is called
        self.risk_predictor = None

    # New LSTM addition
    def initialize_risk_predictor(self, combined_summary_df):
        """Initialize and train the LSTM risk predictor"""
        self.risk_predictor = RiskTrajectoryPredictor()
        self.risk_predictor.prepare_data(combined_summary_df)
        self.risk_predictor.train()

    @staticmethod
    def _fit_exponential_growth(feature, days, values):
        """Fit y = a * exp(b * day) to one feature and return its four metrics.

        Falls back to a zero growth rate, the first observed value, R² of 0 and
        an infinite time-to-critical when the model cannot be fitted: fewer
        than two points, non-numeric, non-finite or non-positive values (the
        logarithm is undefined), or a regression failure.
        """
        fallback = {
            f'{feature}_growth_rate': 0,
            f'{feature}_initial_value': values[0] if len(values) else np.nan,
            f'{feature}_r_squared': 0,
            f'{feature}_days_to_critical': np.inf,
        }

        try:
            numeric = np.asarray(values, dtype=float)
        except (TypeError, ValueError):
            return fallback

        if numeric.size < 2 or not np.all(np.isfinite(numeric)) or np.any(numeric <= 0):
            return fallback

        try:
            # log(y) = log(a) + b*x  <=>  y = a*e^(b*x)
            slope, intercept, r_value, _, _ = stats.linregress(days, np.log(numeric))
        except ValueError:
            # e.g. every observation falls on the same day
            return fallback

        if not (np.isfinite(slope) and np.isfinite(intercept)):
            return fallback

        if slope > 0:
            # "Critical" is assumed to be twice the largest observed value; the
            # result is counted in days from the first observation (day 0 of the fit).
            critical_value = 2 * numeric.max()
            days_to_critical = (np.log(critical_value) - intercept) / slope
        else:
            days_to_critical = np.inf

        return {
            f'{feature}_growth_rate': slope,
            f'{feature}_initial_value': np.exp(intercept),
            f'{feature}_r_squared': r_value**2,
            f'{feature}_days_to_critical': days_to_critical,
        }

    def analyze_island_evolution(self, island_data):
        """Analyze temporal evolution of a single island.

        Parameters
        ----------
        island_data : pd.DataFrame
            Rows for one island with a 'date' column and every column listed
            in ``features_to_track``. The frame is not modified.

        Returns
        -------
        dict
            For each tracked feature: ``<feature>_growth_rate``,
            ``<feature>_initial_value``, ``<feature>_r_squared`` and
            ``<feature>_days_to_critical``. When a risk predictor is set and the
            island has at least ``sequence_length`` rows, also
            ``predicted_risk_trajectory``, ``predicted_max_risk`` and
            ``predicted_risk_trend`` (empty list / NaN if the prediction failed).
        """
        island_data = island_data.sort_values('date')

        # Days since first observation, kept as a local array so the input
        # frame does not gain an extra column
        timestamps = pd.to_datetime(island_data['date'])
        days = ((timestamps - timestamps.min()).dt.total_seconds() / (24 * 3600)).values

        evolution_metrics = {}
        for feature in self.features_to_track:
            values = island_data[feature].values
            evolution_metrics.update(self._fit_exponential_growth(feature, days, values))

        # Add LSTM predictions if predictor is initialized
        predictor = self.risk_predictor
        if predictor is not None and len(island_data) >= predictor.sequence_length:
            future_predictions = np.asarray(predictor.predict_trajectory(island_data))
            if future_predictions.size > 0:
                evolution_metrics['predicted_risk_trajectory'] = future_predictions.tolist()
                evolution_metrics['predicted_max_risk'] = future_predictions.max()
                evolution_metrics['predicted_risk_trend'] = (
                    future_predictions[-1] - future_predictions[0]
                ) / len(future_predictions)
            else:
                # predict_trajectory signals failure with an empty array
                evolution_metrics['predicted_risk_trajectory'] = []
                evolution_metrics['predicted_max_risk'] = np.nan
                evolution_metrics['predicted_risk_trend'] = np.nan

        return evolution_metrics

    def plot_spectral_heatmap(self, data, feature, output_dir):
        """Create a heatmap showing feature evolution with Island IDs.

        Writes ``spectral_evolution_<feature>.html`` into ``output_dir``.
        Rows are dates, columns are islands ordered by mean start location.
        """
        # Prepare data: one row per date, one column per island
        pivot_data = data.pivot(
            index='date',
            columns='global_id',
            values=feature
        )

        # Get location information for each island
        island_info = data.groupby('global_id').agg({
            'start_location': 'mean',
            'end_location': 'mean'
        }).sort_values('start_location')

        # Sort islands by location
        pivot_data = pivot_data.reindex(columns=island_info.index)

        # Create x-axis labels with Island IDs and locations
        x_labels = [
            f"IS:{island_id}<br>{start:.0f}-{end:.0f}m"
            for island_id, (start, end) in zip(
                island_info.index,
                island_info[['start_location', 'end_location']].values
            )
        ]

        # Create heatmap
        fig = go.Figure(data=go.Heatmap(
            z=pivot_data.values,
            x=x_labels,
            y=pivot_data.index,
            colorscale='Viridis',
            colorbar=dict(title=feature)
        ))

        # Update layout for visual adjustments
        fig.update_layout(
            title=f'Temporal Evolution of {feature} Across Track',
            xaxis_title='Island ID and Location',
            yaxis_title='Date',
            height=800,
            xaxis=dict(
                tickangle=90,
                tickmode='array',
                ticktext=x_labels,
                tickvals=list(range(len(x_labels))),
                showgrid=False  # Disable x-axis gridlines
            ),
            yaxis=dict(
                showgrid=False  # Disable y-axis gridlines
            ),
            paper_bgcolor='black',  # Set the overall background color
            plot_bgcolor='black',  # Set the plot's background color
            font=dict(color='white')  # Adjust font color for visibility on black
        )

        fig.write_html(os.path.join(output_dir, f'spectral_evolution_{feature}.html'))

    def plot_island_evolution(self, island_data, global_id, output_dir):
        """Create a per-island evolution plot with optional LSTM risk predictions.

        One subplot per tracked feature shows the observed values over time;
        the risk_score subplot also gets the LSTM forecast when a predictor is
        available. Writes ``evolution_island_<global_id>.html`` into
        ``output_dir``. Errors are printed rather than raised so one bad
        island does not stop a batch run.
        """
        try:
            island_data = island_data.copy()  # Make a copy to prevent modifications
            island_data = island_data.sort_values('date')
            dates = pd.to_datetime(island_data['date'])

            # Create subplots for each feature
            fig = make_subplots(
                rows=len(self.features_to_track),
                cols=1,
                subplot_titles=[f'{feature.replace("_", " ").title()} Evolution'
                            for feature in self.features_to_track],
                vertical_spacing=0.08
            )

            start_loc = float(island_data['start_location'].mean())
            end_loc = float(island_data['end_location'].mean())

            for i, feature in enumerate(self.features_to_track, 1):
                try:
                    # Convert values to float explicitly
                    values = island_data[feature].astype(float).values

                    # Plot actual values
                    fig.add_trace(
                        go.Scatter(
                            x=dates,
                            y=values,
                            mode='markers+lines',
                            name=f'Actual {feature.replace("_", " ").title()}',
                            marker=dict(size=8),
                            line=dict(width=2)
                        ),
                        row=i, col=1
                    )

                    # No trend regression is run here: its result was never drawn,
                    # and a failed fit used to skip the prediction trace and the
                    # axis labels below.

                    # Add LSTM predictions specifically for risk_score
                    if (feature == 'risk_score' and
                        self.risk_predictor is not None and
                        len(island_data) >= self.risk_predictor.sequence_length):
                        try:
                            predictions = self.risk_predictor.predict_trajectory(island_data)
                            if len(predictions) > 0:
                                # Generate future dates for predictions
                                last_date = dates.max()
                                future_dates = pd.date_range(
                                    start=last_date + pd.Timedelta(days=1),
                                    periods=len(predictions),
                                    freq='D'
                                )

                                # Add prediction line
                                fig.add_trace(
                                    go.Scatter(
                                        x=future_dates,
                                        y=predictions,
                                        mode='lines',
                                        name='LSTM Predictions',
                                        line=dict(
                                            dash='dash',
                                            color='red',
                                            width=2
                                        )
                                    ),
                                    row=i, col=1
                                )

                                # Add uncertainty ribbon only if the predictor offers it
                                if hasattr(self.risk_predictor, 'get_prediction_uncertainty'):
                                    lower_bound, upper_bound = self.risk_predictor.get_prediction_uncertainty(predictions)
                                    fig.add_trace(
                                        go.Scatter(
                                            x=future_dates,
                                            y=upper_bound,
                                            mode='lines',
                                            line=dict(width=0),
                                            showlegend=False
                                        ),
                                        row=i, col=1
                                    )
                                    fig.add_trace(
                                        go.Scatter(
                                            x=future_dates,
                                            y=lower_bound,
                                            mode='lines',
                                            line=dict(width=0),
                                            fillcolor='rgba(255, 0, 0, 0.2)',
                                            fill='tonexty',
                                            name='95% Confidence'
                                        ),
                                        row=i, col=1
                                    )
                        except Exception as e:
                            print(f"Error adding predictions for island {global_id}: {str(e)}")

                    # Update axes labels
                    fig.update_xaxes(
                        title_text="Date",
                        tickformat="%Y-%m-%d",
                        tickangle=45,
                        row=i, col=1
                    )
                    fig.update_yaxes(
                        title_text=feature.replace("_", " ").title(),
                        row=i, col=1
                    )

                except Exception as e:
                    print(f"Error processing {feature} for island {global_id}: {str(e)}")
                    continue

            # Update layout
            fig.update_layout(
                height=250 * len(self.features_to_track),
                title=f'Evolution of Island {global_id}<br>(Location: {start_loc:.1f}m - {end_loc:.1f}m)',
                showlegend=True,
                legend=dict(
                    yanchor="top",
                    y=0.99,
                    xanchor="left",
                    x=1.05
                ),
                margin=dict(r=250, b=50)
            )

            # Save plot
            try:
                fig.write_html(os.path.join(output_dir, f'evolution_island_{global_id}.html'))
            except Exception as e:
                print(f"Error saving plot for island {global_id}: {str(e)}")

        except Exception as e:
            print(f"Error in plot_island_evolution for island {global_id}: {str(e)}")

In [0]:
def analyze_temporal_evolution(combined_summary_file, output_dir):
    """Run the temporal evolution analysis for every island in a summary CSV.

    Reads ``combined_summary_file``, computes evolution metrics and an
    evolution plot for each island with at least three rows, writes one
    spectral heatmap per tracked feature, saves ``evolution_metrics.csv`` and
    the summary visualisations into ``output_dir``, and returns the metrics
    as a DataFrame.
    """
    os.makedirs(output_dir, exist_ok=True)

    print("Reading data...")
    data = pd.read_csv(combined_summary_file)

    analyzer = TemporalEvolutionAnalyzer()

    print("Analyzing island evolution patterns...")
    min_observations = 3  # minimum points for trend analysis
    evolution_results = []

    for global_id in data['global_id'].unique():
        island_data = data[data['global_id'] == global_id]
        if len(island_data) < min_observations:
            continue

        metrics = analyzer.analyze_island_evolution(island_data)
        # Identify the island and attach its mean location
        metrics.update(
            global_id=global_id,
            start_location=island_data['start_location'].mean(),
            end_location=island_data['end_location'].mean(),
        )
        evolution_results.append(metrics)

        # Individual evolution plot for this island
        analyzer.plot_island_evolution(island_data, global_id, output_dir)

    evolution_df = pd.DataFrame(evolution_results)

    print("Creating spectral evolution plots...")
    for feature in analyzer.features_to_track:
        analyzer.plot_spectral_heatmap(data, feature, output_dir)

    # Persist the metrics, then build the cross-island summary figure
    metrics_path = os.path.join(output_dir, 'evolution_metrics.csv')
    evolution_df.to_csv(metrics_path, index=False)

    create_summary_visualizations(evolution_df, output_dir)

    return evolution_df


def create_summary_visualizations(evolution_df, output_dir):
    """Create the growth-rate summary figure with one shared R² colorbar.

    Draws a 2x3 grid of scatter plots (growth rate vs. initial value, one per
    feature) coloured by the fit's R², and writes ``growth_rate_analysis.html``
    into ``output_dir``.

    Parameters
    ----------
    evolution_df : pd.DataFrame
        One row per island with ``<feature>_initial_value``,
        ``<feature>_growth_rate`` and ``<feature>_r_squared`` columns plus
        ``global_id``, ``start_location`` and ``end_location``.
    output_dir : str
        Directory the HTML file is written to.

    Returns
    -------
    None
        Nothing is written when ``evolution_df`` is empty.
    """
    features = [
        'peak_amplitude', 'pattern_width', 'area_under_curve',
        'mean_slope', 'max_slope', 'risk_score'
    ]

    if evolution_df.empty:
        # No island qualified for trend analysis, so there is nothing to plot
        print("No evolution metrics available; skipping growth rate summary plot")
        return

    # Create one main colorbar for R² values
    colorbar_settings = dict(
        title=dict(
            text='R²',
            side='right'
        ),
        tickformat='.2f',
        len=0.5,  # Length of colorbar
        yanchor='middle',
        y=0.5,    # Center vertically
        xanchor='left',
        x=1.05    # Position to the right of plots
    )

    n_cols = 3
    fig = make_subplots(
        rows=2, cols=n_cols,
        subplot_titles=[
            f"{feature.replace('_', ' ').title()} Growth Rate vs Initial Value"
            for feature in features
        ],
        vertical_spacing=0.2,
        horizontal_spacing=0.15
    )

    for position, feature in enumerate(features):
        grid_row, grid_col = divmod(position, n_cols)
        grid_row += 1
        grid_col += 1

        r_squared = evolution_df[f'{feature}_r_squared']
        # Only the top-right subplot draws the colorbar
        owns_colorbar = (grid_row == 1 and grid_col == n_cols)

        scatter = go.Scatter(
            x=evolution_df[f'{feature}_initial_value'],
            y=evolution_df[f'{feature}_growth_rate'],
            mode='markers',
            marker=dict(
                size=8,
                color=r_squared,
                colorscale='Viridis',
                # Pin every subplot to the same [0, 1] range: without this each
                # trace auto-scales its colours and the single colorbar would
                # only describe the subplot that owns it.
                cmin=0,
                cmax=1,
                showscale=owns_colorbar,
                colorbar=colorbar_settings if owns_colorbar else None
            ),
            text=[
                f"Location: {start:.1f}m - {end:.1f}m<br>" +
                f"Global ID: {island_id}<br>" +
                f"R²: {r2:.3f}"
                for start, end, island_id, r2 in zip(
                    evolution_df['start_location'],
                    evolution_df['end_location'],
                    evolution_df['global_id'],
                    r_squared
                )
            ],
            hovertemplate="Initial Value: %{x:.2f}<br>" +
                         "Growth Rate: %{y:.2e}/day<br>" +
                         "%{text}<br>" +
                         "<extra></extra>"
        )

        fig.add_trace(scatter, row=grid_row, col=grid_col)

        # Update axes labels
        fig.update_xaxes(
            title_text=f"Initial {feature.replace('_', ' ').title()}",
            row=grid_row, col=grid_col
        )
        fig.update_yaxes(
            title_text="Growth Rate (per day)",
            row=grid_row, col=grid_col
        )

    fig.update_layout(
        height=900,
        width=1500,
        title='Growth Rate Analysis of Track Features',
        showlegend=False,
        template='plotly_white',
        margin=dict(r=150)  # Make room for colorbar
    )

    fig.write_html(os.path.join(output_dir, 'growth_rate_analysis.html'))

In [0]:
def predict_threshold_crossing(risk_scores, dates, threshold, max_forecast_days=365):
    """
    Predict when risk score will cross a threshold based on historical trend.

    A polynomial (degree 2, or degree 1 when only two distinct days are
    available) is fitted to the history and evaluated for each of the next
    ``max_forecast_days`` days; the first day whose fitted value reaches the
    threshold is returned.

    Parameters:
        risk_scores: sequence (list, array or Series) of historical risk scores
        dates: sequence of corresponding dates, in any order
        threshold: risk score threshold to predict crossing
        max_forecast_days: maximum number of days to forecast into future
            (inclusive: the day exactly ``max_forecast_days`` ahead is checked)

    Returns:
        predicted_date: pandas Timestamp, or None if the threshold won't be
        crossed within max_forecast_days, if fewer than two distinct days are
        available, or if the fit fails (the error is printed).
    """
    if len(risk_scores) < 2:
        return None

    try:
        # Normalise to positional containers so a Series with a non-default
        # index works the same as a list or array
        timestamps = pd.to_datetime(list(dates))
        scores = np.asarray(list(risk_scores), dtype=float)

        # Sort chronologically so "first" and "last" observation are well defined
        order = np.argsort(timestamps.values, kind='stable')
        timestamps = timestamps[order]
        scores = scores[order]

        # Convert dates to numerical days since first observation
        dates_numeric = np.asarray((timestamps - timestamps[0]).days, dtype=float)

        distinct_days = np.unique(dates_numeric).size
        if distinct_days < 2:
            # All observations fall on one day: no trend can be estimated
            return None

        # A quadratic needs three distinct x values; fall back to a line otherwise
        degree = min(2, distinct_days - 1)
        poly = np.poly1d(np.polyfit(dates_numeric, scores, deg=degree))

        # Evaluate the fit for days 1..max_forecast_days after the last observation
        days_ahead = np.arange(1, int(max_forecast_days) + 1)
        future_risks = poly(dates_numeric[-1] + days_ahead)

        # Find first crossing of threshold
        crossing_idx = np.flatnonzero(future_risks >= threshold)
        if crossing_idx.size == 0:
            return None

        days_until_threshold = int(days_ahead[crossing_idx[0]])
        return timestamps[-1] + pd.Timedelta(days=days_until_threshold)

    except Exception as e:
        print(f"Error in threshold prediction: {str(e)}")
        return None

In [0]:
def analyze_threshold_crossings(combined_summary_df, evolution_df, output_dir, thresholds=(0.9, 1.0), prediction_days=180):
    """
    Analyze when each island is predicted to cross specific risk thresholds.

    A fresh RiskTrajectoryPredictor is trained on ``combined_summary_df`` on
    every call. For each island listed in ``evolution_df`` that has at least
    ``sequence_length`` observations, the predicted trajectory is scanned for
    the first step at or above each threshold.

    Parameters:
        combined_summary_df: per-observation summary rows with 'global_id',
            'date', 'risk_score', 'start_location' and 'end_location'
        evolution_df: frame whose 'global_id' column selects the islands
        output_dir: directory for 'threshold_crossing_predictions.csv' and
            'threshold_crossing_summary.json'
        thresholds: risk thresholds to test
        prediction_days: number of prediction steps requested from the model

    Returns:
        (results_df, summary_stats): per-island results sorted by
        'min_days_to_threshold' (the smallest days-until value over all
        thresholds, inf when none is crossed) and a dict of summary counts.
    """
    thresholds = list(thresholds)

    # Train a new LSTM predictor for this call
    risk_predictor = RiskTrajectoryPredictor()
    risk_predictor.prepare_data(combined_summary_df)
    risk_predictor.train()

    threshold_results = []

    # Iterate the id column directly: iterrows() can upcast integer ids to float
    global_ids = evolution_df['global_id'] if 'global_id' in evolution_df.columns else []
    for global_id in global_ids:
        # Get historical data for this island
        island_data = combined_summary_df[
            combined_summary_df['global_id'] == global_id
        ].sort_values('date')

        if len(island_data) < risk_predictor.sequence_length:
            continue

        # Get the last known date and risk score
        last_date = pd.to_datetime(island_data['date'].max())
        last_risk = island_data['risk_score'].iloc[-1]

        # Get predictions (an empty array if the predictor failed)
        predictions = risk_predictor.predict_trajectory(island_data, future_steps=prediction_days)

        # One future date per prediction, starting the day after the last measurement
        future_dates = pd.date_range(
            start=last_date,
            periods=len(predictions) + 1,
            freq='D'
        )[1:]

        result = {
            'global_id': global_id,
            'start_location': island_data['start_location'].mean(),
            'end_location': island_data['end_location'].mean(),
            'last_measurement_date': last_date.strftime('%Y-%m-%d'),
            'current_risk_score': last_risk
        }

        # Find crossing dates for each threshold
        for threshold in thresholds:
            if last_risk >= threshold:
                # Already at or above the threshold today
                result[f'threshold_{threshold}_crossed'] = True
                result[f'threshold_{threshold}_crossing_date'] = 'Already Crossed'
                result[f'days_until_{threshold}'] = 0
                continue

            crossing_indices = np.where(predictions >= threshold)[0]
            if len(crossing_indices) > 0:
                crossing_date = future_dates[crossing_indices[0]]
                result[f'threshold_{threshold}_crossed'] = True
                result[f'threshold_{threshold}_crossing_date'] = crossing_date.strftime('%Y-%m-%d')
                result[f'days_until_{threshold}'] = (crossing_date - last_date).days
            else:
                result[f'threshold_{threshold}_crossed'] = False
                result[f'threshold_{threshold}_crossing_date'] = 'Not Within Prediction Window'
                result[f'days_until_{threshold}'] = None

        # First and second differences at the start of the predicted trajectory
        if len(predictions) > 1:
            result['risk_velocity'] = (predictions[1] - predictions[0])  # risk/day
            if len(predictions) > 2:
                result['risk_acceleration'] = (
                    predictions[2] - 2*predictions[1] + predictions[0]
                )  # risk/day²
            else:
                result['risk_acceleration'] = 0
        else:
            result['risk_velocity'] = 0
            result['risk_acceleration'] = 0

        threshold_results.append(result)

    # Explicit columns keep the frame usable even when no island qualified
    expected_columns = [
        'global_id', 'start_location', 'end_location',
        'last_measurement_date', 'current_risk_score'
    ]
    for threshold in thresholds:
        expected_columns += [
            f'threshold_{threshold}_crossed',
            f'threshold_{threshold}_crossing_date',
            f'days_until_{threshold}'
        ]
    expected_columns += ['risk_velocity', 'risk_acceleration']
    results_df = pd.DataFrame(threshold_results, columns=expected_columns)

    # Urgency = smallest days-until value across all thresholds. Islands that
    # already crossed have 0 and sort first; islands that never cross get inf.
    days_numeric = pd.DataFrame(
        {
            f'days_until_{threshold}': pd.to_numeric(
                results_df[f'days_until_{threshold}'], errors='coerce'
            )
            for threshold in thresholds
        },
        index=results_df.index
    )
    results_df['min_days_to_threshold'] = days_numeric.min(axis=1).fillna(float('inf'))
    results_df = results_df.sort_values('min_days_to_threshold')

    # Save results
    output_file = os.path.join(output_dir, 'threshold_crossing_predictions.csv')
    results_df.to_csv(output_file, index=False)

    # Create summary statistics
    crossing_counts = {}
    average_days = {}
    for threshold in thresholds:
        crossed = results_df[f'threshold_{threshold}_crossed'].astype(bool)
        already = results_df[f'threshold_{threshold}_crossing_date'] == 'Already Crossed'
        crossing_counts[str(threshold)] = {
            'already_crossed': int(already.sum()),
            'will_cross': int((crossed & ~already).sum()),
            'wont_cross': int((~crossed).sum())
        }
        # Mean over islands with a known days-until value (NaN when there are none)
        average_days[str(threshold)] = float(days_numeric[f'days_until_{threshold}'].mean())

    summary_stats = {
        'total_islands': len(results_df),
        'islands_crossing_thresholds': crossing_counts,
        'average_days_to_threshold': average_days
    }

    # Save summary statistics
    with open(os.path.join(output_dir, 'threshold_crossing_summary.json'), 'w') as f:
        json.dump(summary_stats, f, indent=4)

    return results_df, summary_stats

In [0]:
if __name__ == "__main__":
    # Configuration
    CONFIG = {
        'CHANNEL': '18_top_right',
        'INPUT_DIR': '/Workspace/Users/amahmud1@networkrail.co.uk/CT/CT_processed_data/',
        'OUTPUT_BASE_DIR': '/Workspace/Users/amahmud1@networkrail.co.uk/CT',
        'PHASE1_DIR': '/Workspace/Users/amahmud1@networkrail.co.uk/CT/CT_phase1',
        'PHASE3_DIR': '/Workspace/Users/amahmud1@networkrail.co.uk/CT/CT_phase3',
        'ISLAND_DISTANCE_THRESHOLD': 5.0,  # meters
        'ISLAND_OVERLAP_THRESHOLD': 0.3,   # 30% overlap required
        'NUM_CLUSTERS': 5,                 # for risk analysis
        'INTERPOLATION_POINTS': 100,       # for spline fitting
        'FEATURE_WEIGHTS': {
            'peak_amplitude': 0.3,
            'average_y': 0.1,
            'distance_difference': 0.1,
            'pattern_width': 0.15,
            'area_under_curve': 0.15,
            'mean_slope': 0.1,
            'max_slope': 0.1
        },
        'RISK_SMOOTHING_FACTOR': 0.7,      # Temporal smoothing factor (0-1)
        'RISK_ADJUSTMENT_MAX': 0.2         # Maximum risk adjustment (20%)
    }
    
    try:
        print("\n=== Starting C-PRICS Analysis Pipeline ===")
        
        # Create all necessary directories
        for dir_path in [
            os.path.join(CONFIG['PHASE1_DIR'], 'summaries'),
            os.path.join(CONFIG['PHASE1_DIR'], 'plots'),
            CONFIG['PHASE3_DIR']
        ]:
            os.makedirs(dir_path, exist_ok=True)
        
        # Phase 1: Initial Processing and Risk Analysis
        print("\n=== Phase 1: Signal Processing and Risk Analysis ===")
        print(f"Channel: {CONFIG['CHANNEL']}")
        print(f"Input Directory: {CONFIG['INPUT_DIR']}")
        
        # Get sorted files
        sorted_files = get_sorted_files(CONFIG['INPUT_DIR'])
        print(f"\nFound {len(sorted_files)} files to process")
        
        # Initialize identifier
        identifier = IslandIdentifier(
            distance_threshold=CONFIG['ISLAND_DISTANCE_THRESHOLD'],
            overlap_threshold=CONFIG['ISLAND_OVERLAP_THRESHOLD']
        )
        
        # First pass: collect all data for ML training and assign global IDs
        print("\nFirst pass: Processing files for ML training and initial island identification...")
        all_summaries = []
        for date, filename in sorted_files:
            print(f"Processing {filename}...")
            input_file = os.path.join(CONFIG['INPUT_DIR'], filename)
            date_str = date.strftime('%Y-%m-%d')
            
            summary_df = process_file(
                input_file=input_file,
                output_dir=CONFIG['PHASE1_DIR'],
                channel=CONFIG['CHANNEL'],
                date_str=date_str,
                identifier=identifier
            )
            
            if summary_df is not None:
                all_summaries.append(summary_df)
        
        # Create and train risk analyzer
        print("\nTraining risk analyzer...")
        combined_summary = pd.concat(all_summaries, ignore_index=True)
        analyzer = CTAnalyzer(n_clusters=CONFIG['NUM_CLUSTERS'])
        analyzer.feature_weights = CONFIG['FEATURE_WEIGHTS']
        analyzer.fit(combined_summary)
        
        # Second pass: process files with risk scoring
        print("\nSecond pass: Processing files with risk scoring...")
        all_summaries = []
        sorted_files.sort(key=lambda x: x[0])
        
        for date, filename in sorted_files:
            print(f"Processing {filename}...")
            input_file = os.path.join(CONFIG['INPUT_DIR'], filename)
            date_str = date.strftime('%Y-%m-%d')
            
            summary_df = process_file(
                input_file=input_file,
                output_dir=CONFIG['PHASE1_DIR'],
                channel=CONFIG['CHANNEL'],
                date_str=date_str,
                analyzer=analyzer,
                identifier=identifier
            )
            
            if summary_df is not None:
                all_summaries.append(summary_df)
        
        # Save final combined summary
        final_summary = pd.concat(all_summaries, ignore_index=True)
        final_summary.sort_values(['date', 'global_id'], inplace=True)
        combined_summary_path = os.path.join(CONFIG['PHASE1_DIR'], f'combined_summary_{CONFIG["CHANNEL"]}.csv')
        final_summary.to_csv(combined_summary_path, index=False)
        
        # Save evolution history
        history_df = identifier.get_history_summary()
        history_df.to_csv(os.path.join(CONFIG['PHASE1_DIR'], 'island_history.csv'), index=False)
        
        # Create initial evolution summary
        print("\nGenerating evolution summary...")
        evolution_summary = []
        for global_id in final_summary['global_id'].unique():
            island_data = final_summary[final_summary['global_id'] == global_id].sort_values('date')
            risk_volatility = island_data['risk_score'].std() if len(island_data) > 1 else 0
            
            summary = {
                'global_id': global_id,
                'first_appearance': island_data['date'].min(),
                'last_appearance': island_data['date'].max(),
                'num_observations': len(island_data),
                'avg_start_location': island_data['start_location'].mean(),
                'avg_end_location': island_data['end_location'].mean(),
                'location_drift': max(
                    island_data['start_location'].max() - island_data['start_location'].min(),
                    island_data['end_location'].max() - island_data['end_location'].min()
                ),
                'avg_risk_score': island_data['risk_score'].mean(),
                'max_risk_score': island_data['risk_score'].max(),
                'risk_score_trend': np.polyfit(
                    range(len(island_data)), 
                    island_data['risk_score'], 
                    1
                )[0] if len(island_data) > 1 else 0,
                'risk_volatility': risk_volatility,
                'risk_acceleration': np.diff(island_data['risk_score']).mean() if len(island_data) > 2 else 0
            }
            evolution_summary.append(summary)
        
        evolution_df = pd.DataFrame(evolution_summary)
        evolution_df.to_csv(os.path.join(CONFIG['PHASE1_DIR'], 'island_evolution_summary.csv'), index=False)
        
        # Phase 3: Temporal Evolution Analysis
        print("\n=== Phase 3: Temporal Evolution Analysis ===")
        
        # Initialize and train LSTM predictor
        print("\nInitializing LSTM predictor...")
        temporal_analyzer = TemporalEvolutionAnalyzer()
        temporal_analyzer.initialize_risk_predictor(final_summary)
        
        # Perform temporal evolution analysis
        print("\nPerforming temporal evolution analysis...")
        temporal_evolution_df = analyze_temporal_evolution(combined_summary_path, CONFIG['PHASE3_DIR'])
        
        # Perform threshold crossing analysis
        print("\nAnalyzing threshold crossings...")
        threshold_results, threshold_summary = analyze_threshold_crossings(
            combined_summary_df=final_summary,
            evolution_df=temporal_evolution_df,
            output_dir=CONFIG['PHASE3_DIR']
        )
        
        # Save final metadata
        metadata = {
            'channel': CONFIG['CHANNEL'],
            'files_processed': len(sorted_files),
            'total_islands': len(final_summary),
            'unique_islands': len(evolution_df),
            'number_of_clusters': CONFIG['NUM_CLUSTERS'],
            'cluster_risks': analyzer.cluster_risks.tolist(),
            'feature_weights': CONFIG['FEATURE_WEIGHTS'],
            'risk_smoothing_factor': CONFIG['RISK_SMOOTHING_FACTOR'],
            'risk_adjustment_max': CONFIG['RISK_ADJUSTMENT_MAX'],
            'date_range': [str(final_summary['date'].min()), str(final_summary['date'].max())],
            'processing_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'configuration': CONFIG,
            'risk_score_statistics': {
                'mean': float(final_summary['risk_score'].mean()),
                'std': float(final_summary['risk_score'].std()),
                'min': float(final_summary['risk_score'].min()),
                'max': float(final_summary['risk_score'].max()),
                'volatility': float(final_summary.groupby('global_id')['risk_score'].std().mean())
            },
            'threshold_analysis': threshold_summary,
            'temporal_analysis_completed': True
        }
        
        with open(os.path.join(CONFIG['OUTPUT_BASE_DIR'], 'full_analysis_metadata.json'), 'w') as f:
            json.dump(metadata, f, indent=4)
        
        print("\n=== Analysis Pipeline Complete ===")
        print(f"Processed {len(sorted_files)} files")
        print(f"Found {len(evolution_df)} unique islands")
        print(f"Results saved to {CONFIG['OUTPUT_BASE_DIR']}")
        print(f"Phase 1 results: {CONFIG['PHASE1_DIR']}")
        print(f"Phase 3 results: {CONFIG['PHASE3_DIR']}")
        
    except Exception as e:
        print(f"\nError in processing pipeline: {str(e)}")
        raise