In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the Kaggle input directory.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/dojo-dataset/earthquake_dataset.pkl
/kaggle/input/dojo-dataset/bridge_nodes.csv
/kaggle/input/dojo-dataset/Bridge_Sensors.csv
/kaggle/input/dojo-dataset/complete_simulation_dataset/damage_case_3_bearing_severe_widespread/2019-05-01.csv
/kaggle/input/dojo-dataset/complete_simulation_dataset/damage_case_3_bearing_severe_widespread/2018-07-17_1.csv
/kaggle/input/dojo-dataset/complete_simulation_dataset/damage_case_3_bearing_severe_widespread/2020-01-21.csv
/kaggle/input/dojo-dataset/complete_simulation_dataset/damage_case_3_bearing_severe_widespread/2018-09-05.csv
/kaggle/input/dojo-dataset/complete_simulation_dataset/damage_case_3_bearing_severe_widespread/2019-12-05_3.csv
/kaggle/input/dojo-dataset/complete_simulation_dataset/damage_case_3_bearing_severe_widespread/2019-07-29.csv
/kaggle/input/dojo-dataset/complete_simulation_dataset/damage_case_3_bearing_severe_widespread/2020-02-01_1.csv
/kaggle/input/dojo-dataset/complete_simulation_dataset/damage_case_3_bearing_severe_

In [2]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.7.0


In [3]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from scipy import signal
from scipy.fft import fft, fftfreq
import os
from glob import glob
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# FEATURE EXTRACTION FUNCTIONS
# ============================================================================

def calculate_rms_error(healthy_signal, damaged_signal):
    """Calculate the RMS error between a healthy and a damaged signal.

    Both inputs are truncated to their common length first, so recordings of
    slightly different duration do not fail with a broadcasting error.

    Args:
        healthy_signal: 1-D array-like of baseline samples.
        damaged_signal: 1-D array-like of samples compared against the baseline.

    Returns:
        Root-mean-square of the sample-wise difference over the overlapping
        samples (NaN if either signal is empty, as ``np.mean`` of an empty array).
    """
    healthy = np.asarray(healthy_signal, dtype=float)
    damaged = np.asarray(damaged_signal, dtype=float)
    n = min(len(healthy), len(damaged))
    return np.sqrt(np.mean((healthy[:n] - damaged[:n]) ** 2))

def extract_dominant_frequency(signal_data, sampling_rate=100):
    """Return the frequency (Hz) with the largest FFT magnitude, ignoring DC.

    Args:
        signal_data: 1-D array-like of samples.
        sampling_rate: Sampling rate in Hz used to convert FFT bins to frequencies.

    Returns:
        The strictly positive frequency with the largest spectral magnitude, or
        0.0 when the signal has fewer than 3 samples (``fftfreq`` then has no
        strictly positive bin and ``np.argmax`` would raise on an empty array).
    """
    samples = np.asarray(signal_data)
    n = len(samples)
    if n < 3:
        return 0.0

    yf = fft(samples)
    xf = fftfreq(n, 1 / sampling_rate)

    # Keep strictly positive frequencies: drops the DC bin and the negative half.
    positive_freq_idx = xf > 0
    xf_pos = xf[positive_freq_idx]
    yf_pos = np.abs(yf[positive_freq_idx])

    dominant_idx = np.argmax(yf_pos)
    return xf_pos[dominant_idx]

def calculate_frequency_shift(healthy_signal, damaged_signal, sampling_rate=100):
    """Absolute difference (Hz) between the dominant frequencies of two signals."""
    dominant = [
        extract_dominant_frequency(sig, sampling_rate)
        for sig in (healthy_signal, damaged_signal)
    ]
    return abs(dominant[1] - dominant[0])

def calculate_displacement_magnitude(acceleration_data, dt=1.0):
    """Peak absolute displacement obtained by double-integrating acceleration.

    Each integration step is a cumulative sum scaled by ``dt``. No detrending is
    applied, so any offset in the input accumulates as drift; treat the result
    as a relative feature unless the input is already zero-mean.

    Args:
        acceleration_data: 1-D array-like of acceleration samples.
        dt: Sample spacing used to scale each integration step. The default 1.0
            reproduces the unscaled cumulative sums; pass ``1 / sampling_rate``
            for displacement in physical units.

    Returns:
        ``max(|displacement|)``, or 0.0 for an empty input.
    """
    accel = np.asarray(acceleration_data, dtype=float)
    if accel.size == 0:
        return 0.0
    velocity = np.cumsum(accel) * dt
    displacement = np.cumsum(velocity) * dt
    return np.max(np.abs(displacement))

def extract_node_features(healthy_df, damaged_df, sensor_names):
    """Extract five features per channel for each sensor node.

    For every column of ``healthy_df`` whose name contains the sensor name, the
    features are, in order:
    - RMS error vs. the healthy signal;
    - dominant-frequency shift;
    - displacement magnitude, peak absolute value and standard deviation, all
      three computed from the damaged signal.

    Args:
        healthy_df: Baseline recording (one column per channel).
        damaged_df: Recording to compare. Columns are looked up by the same
            names, so a channel missing here raises ``KeyError``.
        sensor_names: Substrings identifying each sensor's columns.

    Returns:
        ``np.ndarray`` with one row per sensor that has at least one matching
        column in ``healthy_df``. Sensors without matches are skipped.
    """
    features = []

    for sensor in sensor_names:
        sensor_cols = [col for col in healthy_df.columns if sensor in col]
        if len(sensor_cols) == 0:
            continue

        node_features = []
        for col in sensor_cols:
            healthy_signal = healthy_df[col].values
            damaged_signal = damaged_df[col].values

            # Recordings can differ in length. Comparison features use only the
            # overlapping samples so the element-wise difference cannot fail to broadcast.
            n = min(len(healthy_signal), len(damaged_signal))
            healthy_common, damaged_common = healthy_signal[:n], damaged_signal[:n]

            rms_error = calculate_rms_error(healthy_common, damaged_common)
            freq_shift = calculate_frequency_shift(healthy_common, damaged_common)

            # Single-signal features use the full damaged recording.
            disp_magnitude = calculate_displacement_magnitude(damaged_signal)
            max_accel = np.max(np.abs(damaged_signal))
            std_dev = np.std(damaged_signal)

            node_features.extend([rms_error, freq_shift, disp_magnitude, max_accel, std_dev])

        features.append(node_features)

    return np.array(features)

# ============================================================================
# GRAPH CONSTRUCTION
# ============================================================================

def create_bridge_graph(num_nodes=3):
    """
    Create a bidirectional, fully connected ``edge_index`` over ``num_nodes`` sensor nodes.

    Pairs are emitted chain-first (i <-> i+1), then the remaining pairs, each as
    two directed edges. For the default ``num_nodes=3`` (sensors 24, 31, 32) this
    is exactly 24<->31, 31<->32, 24<->32.

    Args:
        num_nodes: Number of nodes in the graph.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_nodes * (num_nodes - 1))``,
        or ``(2, 0)`` when there are fewer than two nodes.
    """
    chain_pairs = [(i, i + 1) for i in range(num_nodes - 1)]
    other_pairs = [(i, j) for i in range(num_nodes) for j in range(i + 2, num_nodes)]

    edges = []
    for src, dst in chain_pairs + other_pairs:
        edges.extend([[src, dst], [dst, src]])

    if not edges:
        # torch.tensor([]).t() would be 1-D; keep the (2, 0) layout edge_index expects.
        return torch.empty((2, 0), dtype=torch.long)

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    return edge_index

# ============================================================================
# DATASET PREPARATION
# ============================================================================

def load_and_prepare_data(healthy_path, damage_paths, max_samples=None):
    """Turn one healthy recording and several damaged recordings into graphs.

    Returns:
        ``(graph_data_list, labels)``. The first graph is the healthy recording
        compared with itself (label 0). Each successfully processed damage file
        adds one graph with label 1. ``max_samples`` (when truthy) caps the total
        number of graphs, the healthy one included. Damage files that fail to
        process are reported and skipped.
    """
    healthy_df = pd.read_csv(healthy_path)

    sensor_names = ['Sensor_24', 'Sensor_31', 'Sensor_32']
    target_width = 15
    edge_index = create_bridge_graph(num_nodes=3)

    def to_graph(node_features, label):
        # Zero-pad the feature axis up to target_width so all graphs share one width.
        width = node_features.shape[1]
        if width < target_width:
            node_features = np.pad(node_features, ((0, 0), (0, target_width - width)), mode='constant')
        return Data(
            x=torch.tensor(node_features, dtype=torch.float),
            edge_index=edge_index,
            y=torch.tensor([label], dtype=torch.float),
        )

    graph_data_list = [to_graph(extract_node_features(healthy_df, healthy_df, sensor_names), 0)]
    labels = [0]

    for damage_path in damage_paths:
        # len(graph_data_list) counts the healthy graph plus every damage graph added so far.
        if max_samples and len(graph_data_list) >= max_samples:
            break

        try:
            damaged_df = pd.read_csv(damage_path)
            graph = to_graph(extract_node_features(healthy_df, damaged_df, sensor_names), 1)
        except Exception as e:
            print(f"Error processing {damage_path}: {e}")
            continue

        graph_data_list.append(graph)
        labels.append(1)

    return graph_data_list, labels

# ============================================================================
# GNN MODEL
# ============================================================================

class BridgeDamageGNN(nn.Module):
    """Three-layer GCN that maps a sensor graph to a damage probability.

    Args:
        num_node_features: Width of each node's input feature vector.
        hidden_channels: Width of the first two GCN layers (the third uses half).
    """

    def __init__(self, num_node_features, hidden_channels=64):
        super(BridgeDamageGNN, self).__init__()

        # Graph Convolutional Layers
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels // 2)

        # Fully connected layers for graph-level prediction
        self.fc1 = nn.Linear(hidden_channels // 2, 32)
        self.fc2 = nn.Linear(32, 1)

        self.dropout = nn.Dropout(0.3)

    def forward(self, x, edge_index, batch=None):
        """Return sigmoid probabilities of shape ``(num_graphs, 1)``.

        ``batch`` assigns each node to a graph. When omitted, all nodes are
        pooled as a single graph.
        """
        # Node-level embeddings
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.dropout(x)

        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.dropout(x)

        x = self.conv3(x, edge_index)
        x = F.relu(x)

        # Global pooling (graph-level representation)
        if batch is None:
            # Allocate on x's device; a CPU tensor here breaks pooling when running on CUDA.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        x = global_mean_pool(x, batch)

        # Classification
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)

        return torch.sigmoid(x)

# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================

def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Called as ``model(x, edge_index, batch)``; returns one score per graph.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y`` and ``.to(device)``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(predictions, targets)`` with 1-D inputs.
        device: Device each batch is moved to.

    Returns:
        Average of the per-batch loss values.
    """
    model.train()
    total_loss = 0.0

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        out = model(data.x, data.edge_index, data.batch)
        # view(-1) keeps shape (B,) even when B == 1. squeeze() would give a 0-d
        # tensor, and BCELoss rejects that size mismatch with data.y.
        loss = criterion(out.view(-1), data.y)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(loader)

def evaluate(model, loader, device):
    """Fraction of graphs whose thresholded prediction (p > 0.5) equals ``y``."""
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for batch_data in loader:
            batch_data = batch_data.to(device)
            probs = model(batch_data.x, batch_data.edge_index, batch_data.batch).squeeze()
            hits = (probs > 0.5).float() == batch_data.y
            n_correct += hits.sum().item()
            n_seen += batch_data.y.size(0)

    return n_correct / n_seen

# ============================================================================
# MAIN TRAINING SCRIPT
# ============================================================================

def main(seed=42):
    """Train the baseline GCN on one healthy recording plus same-date damage recordings.

    Args:
        seed: Seed for torch's RNGs (weight init, dropout, loader shuffling) so
            reruns produce the same model.
    """
    torch.manual_seed(seed)

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Paths (adjust these for your Kaggle environment)
    healthy_path = '/kaggle/input/dojo-dataset/complete_simulation_dataset/healthy/2017-11-03.csv'

    # Collect the damage-case files recorded on the same date as the healthy file
    damage_paths = []
    for i in range(1, 9):
        pattern = f'/kaggle/input/dojo-dataset/complete_simulation_dataset/damage_case_{i}_*/2017-11-03.csv'
        damage_paths.extend(glob(pattern))

    print(f"Found {len(damage_paths)} damage case files")

    # Load and prepare data
    print("Loading and preparing data...")
    graph_data_list, labels = load_and_prepare_data(healthy_path, damage_paths, max_samples=50)

    print(f"Total samples: {len(graph_data_list)}")
    print(f"Healthy samples: {labels.count(0)}")
    print(f"Damaged samples: {labels.count(1)}")

    # Stratify whenever each class has at least two members. With a single healthy
    # graph that is impossible. The test split may then hold only damaged graphs,
    # which makes test accuracy meaningless, so warn about it explicitly.
    can_stratify = min(labels.count(0), labels.count(1)) >= 2
    if not can_stratify:
        print("Warning: fewer than 2 samples in some class; the split is not stratified "
              "and the test set may not contain both classes.")
    train_data, test_data = train_test_split(
        graph_data_list,
        test_size=0.2,
        random_state=42,
        stratify=labels if can_stratify else None,
    )

    # Create data loaders
    train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=8, shuffle=False)

    # Initialize model
    num_features = train_data[0].x.shape[1]
    model = BridgeDamageGNN(num_node_features=num_features, hidden_channels=64).to(device)

    # Loss and optimizer
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

    # Training loop
    print("\nStarting training...")
    num_epochs = 100
    # Below any real accuracy, so the first epoch always writes the checkpoint
    # that the final message refers to.
    best_acc = -1.0

    for epoch in range(num_epochs):
        loss = train_epoch(model, train_loader, optimizer, criterion, device)
        train_acc = evaluate(model, train_loader, device)
        test_acc = evaluate(model, test_loader, device)

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'best_bridge_gnn_model.pth')

        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}, '
                  f'Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

    print(f'\nBest Test Accuracy: {best_acc:.4f}')
    print("Model saved as 'best_bridge_gnn_model.pth'")

# In a notebook cell __name__ is "__main__", so executing this cell trains immediately.
if __name__ == "__main__":
    main()

Using device: cuda
Found 8 damage case files
Loading and preparing data...
Total samples: 9
Healthy samples: 1
Damaged samples: 8

Starting training...
Epoch 10/100, Loss: 28.5714, Train Acc: 0.8571, Test Acc: 1.0000
Epoch 20/100, Loss: 14.2857, Train Acc: 0.8571, Test Acc: 1.0000
Epoch 30/100, Loss: 14.2857, Train Acc: 0.8571, Test Acc: 1.0000
Epoch 40/100, Loss: 14.2857, Train Acc: 0.8571, Test Acc: 1.0000
Epoch 50/100, Loss: 14.2857, Train Acc: 0.8571, Test Acc: 1.0000
Epoch 60/100, Loss: 14.2857, Train Acc: 0.8571, Test Acc: 1.0000
Epoch 70/100, Loss: 14.2857, Train Acc: 0.8571, Test Acc: 1.0000
Epoch 80/100, Loss: 14.2857, Train Acc: 0.8571, Test Acc: 1.0000
Epoch 90/100, Loss: 14.2857, Train Acc: 0.8571, Test Acc: 1.0000
Epoch 100/100, Loss: 14.2857, Train Acc: 0.8571, Test Acc: 1.0000

Best Test Accuracy: 1.0000
Model saved as 'best_bridge_gnn_model.pth'


In [4]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool, global_max_pool
from torch_geometric.data import Data, DataLoader
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix
from scipy import signal
from scipy.fft import fft, fftfreq
import os
from glob import glob
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# FEATURE EXTRACTION FUNCTIONS
# ============================================================================

def align_signals(signal1, signal2):
    """Cut both signals to the length of the shorter one and return them as a pair."""
    common = min(map(len, (signal1, signal2)))
    return tuple(sig[:common] for sig in (signal1, signal2))

def calculate_rms_error(healthy_signal, damaged_signal):
    """Root-mean-square of the sample-wise difference over the signals' common length."""
    healthy_part, damaged_part = align_signals(healthy_signal, damaged_signal)
    residual = healthy_part - damaged_part
    return np.sqrt(np.mean(np.square(residual)))

def extract_dominant_frequency(signal_data, sampling_rate=100):
    """Return the frequency (Hz) with the largest FFT magnitude, ignoring DC.

    Args:
        signal_data: 1-D array-like of samples.
        sampling_rate: Sampling rate in Hz used to convert FFT bins to frequencies.

    Returns:
        The strictly positive frequency with the largest spectral magnitude, or
        0.0 when the signal has fewer than 3 samples (``fftfreq`` then has no
        strictly positive bin and ``np.argmax`` would raise on an empty array).
    """
    samples = np.asarray(signal_data)
    n = len(samples)
    if n < 3:
        return 0.0

    yf = fft(samples)
    xf = fftfreq(n, 1 / sampling_rate)

    # Keep strictly positive frequencies: drops the DC bin and the negative half.
    positive_freq_idx = xf > 0
    xf_pos = xf[positive_freq_idx]
    yf_pos = np.abs(yf[positive_freq_idx])

    dominant_idx = np.argmax(yf_pos)
    return xf_pos[dominant_idx]

def calculate_frequency_shift(healthy_signal, damaged_signal, sampling_rate=100):
    """Absolute difference (Hz) between the dominant frequencies of two signals.

    Both signals are first truncated to their common length via ``align_signals``.
    """
    aligned = align_signals(healthy_signal, damaged_signal)
    healthy_freq, damaged_freq = (extract_dominant_frequency(s, sampling_rate) for s in aligned)
    return abs(damaged_freq - healthy_freq)

def calculate_statistical_features(signal_data):
    """Summary statistics of a 1-D signal.

    Returns a dict with these keys:
    - ``mean``, ``std``;
    - ``max`` (peak absolute value) and ``min``;
    - ``kurtosis`` (fourth standardized moment, not excess) and ``skewness``;
    - ``rms`` and ``peak_to_peak``.

    For a constant signal (std == 0), kurtosis and skewness are undefined. They
    are reported as 0.0 rather than NaN so that a flat channel does not inject
    NaNs into the feature matrix.

    Raises:
        ValueError: For an empty signal (from ``np.max`` on a zero-size array).
    """
    x = np.asarray(signal_data, dtype=float)
    mean = np.mean(x)
    std = np.std(x)
    centered = x - mean

    if std > 0:
        kurtosis = np.mean(centered ** 4) / std ** 4
        skewness = np.mean(centered ** 3) / std ** 3
    else:
        kurtosis = 0.0
        skewness = 0.0

    return {
        'mean': mean,
        'std': std,
        'max': np.max(np.abs(x)),
        'min': np.min(x),
        'kurtosis': kurtosis,
        'skewness': skewness,
        'rms': np.sqrt(np.mean(x ** 2)),
        'peak_to_peak': np.ptp(x),
    }

def extract_node_features(healthy_df, damaged_df, sensor_names):
    """Extract ten comparison/statistical features per channel for each sensor node.

    Features are computed for every non-``Time`` column of ``healthy_df`` whose
    name contains the sensor name. In order, they are:
    - RMS error vs. healthy;
    - dominant-frequency shift;
    - the damaged signal's mean, std, peak absolute value, kurtosis, skewness,
      RMS and peak-to-peak;
    - the healthy/damaged Pearson correlation.

    Healthy and damaged signals are truncated to their common length first.

    Args:
        healthy_df: Baseline recording (one column per channel).
        damaged_df: Recording to compare; channels are matched by column name.
        sensor_names: Substrings identifying each sensor's columns.

    Returns:
        ``np.ndarray`` with one row per sensor that has a matching column in
        ``healthy_df``. A channel missing from ``damaged_df`` contributes ten
        zeros, so feature positions stay aligned across nodes. Shorter rows are
        zero-padded to the widest row so the result is always rectangular.
    """
    n_channel_features = 10
    features = []

    for sensor in sensor_names:
        sensor_cols = [col for col in healthy_df.columns if sensor in col and col != 'Time']
        if len(sensor_cols) == 0:
            continue

        node_features = []
        for col in sensor_cols:
            if col not in damaged_df.columns:
                # Zero-fill rather than skip: skipping would shift later channels'
                # features and make this row shorter than the others.
                node_features.extend([0.0] * n_channel_features)
                continue

            healthy_signal, damaged_signal = align_signals(
                healthy_df[col].values, damaged_df[col].values
            )

            # Feature 1: RMS Error
            rms_error = np.sqrt(np.mean((healthy_signal - damaged_signal) ** 2))

            # Feature 2: Frequency Shift
            try:
                healthy_freq = extract_dominant_frequency(healthy_signal)
                damaged_freq = extract_dominant_frequency(damaged_signal)
                freq_shift = abs(damaged_freq - healthy_freq)
            except Exception:  # unlike a bare except, lets KeyboardInterrupt/SystemExit through
                freq_shift = 0.0

            # Features 3-9: statistics of the damaged signal
            stats = calculate_statistical_features(damaged_signal)

            # Feature 10: correlation between healthy and damaged
            try:
                correlation = np.corrcoef(healthy_signal, damaged_signal)[0, 1]
            except Exception:
                correlation = np.nan
            if np.isnan(correlation):
                # Undefined (e.g. for a constant signal): use 0.0 instead of NaN.
                correlation = 0.0

            node_features.extend([
                rms_error,
                freq_shift,
                stats['mean'],
                stats['std'],
                stats['max'],
                stats['kurtosis'],
                stats['skewness'],
                stats['rms'],
                stats['peak_to_peak'],
                correlation,
            ])

        features.append(node_features)

    if not features:
        return np.array(features)

    # Sensors may have different channel counts; pad so np.array gets a rectangular list.
    width = max(len(row) for row in features)
    return np.array([row + [0.0] * (width - len(row)) for row in features])

# ============================================================================
# DATA AUGMENTATION
# ============================================================================

def augment_data(data, num_augmentations=5, noise_level=0.01):
    """Return ``num_augmentations`` copies of ``data`` with Gaussian noise added to ``x``.

    The noise has a fixed std of ``noise_level``; it is not scaled to the feature
    magnitudes. ``edge_index`` and ``y`` are the same tensor objects as in ``data``.
    """
    return [
        Data(
            x=data.x + torch.randn_like(data.x) * noise_level,
            edge_index=data.edge_index,
            y=data.y,
        )
        for _ in range(num_augmentations)
    ]

# ============================================================================
# GRAPH CONSTRUCTION
# ============================================================================

def create_bridge_graph(num_nodes=3):
    """Create a bidirectional, fully connected ``edge_index`` over ``num_nodes`` sensor nodes.

    Pairs are emitted chain-first (i <-> i+1), then the remaining pairs, each as
    two directed edges. For the default ``num_nodes=3`` (sensors 24, 31, 32) this
    is exactly 24<->31, 31<->32, 24<->32.

    Args:
        num_nodes: Number of nodes in the graph.

    Returns:
        ``torch.long`` tensor of shape ``(2, num_nodes * (num_nodes - 1))``,
        or ``(2, 0)`` when there are fewer than two nodes.
    """
    chain_pairs = [(i, i + 1) for i in range(num_nodes - 1)]
    other_pairs = [(i, j) for i in range(num_nodes) for j in range(i + 2, num_nodes)]

    edges = []
    for src, dst in chain_pairs + other_pairs:
        edges.extend([[src, dst], [dst, src]])

    if not edges:
        # torch.tensor([]).t() would be 1-D; keep the (2, 0) layout edge_index expects.
        return torch.empty((2, 0), dtype=torch.long)

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    return edge_index

# ============================================================================
# DATASET PREPARATION
# ============================================================================

def load_all_csv_files(healthy_dir, damage_case_dirs):
    """Collect CSV paths from the healthy directory and each damage directory.

    Healthy paths are sorted. Damage directories are visited in the given order
    and each one's CSVs are sorted, so the damage list is sorted per directory,
    not globally.
    """
    def csvs_in(directory):
        return sorted(glob(os.path.join(directory, '*.csv')))

    healthy_files = csvs_in(healthy_dir)
    damage_files = [path for damage_dir in damage_case_dirs for path in csvs_in(damage_dir)]

    print(f"Found {len(healthy_files)} healthy files")
    print(f"Found {len(damage_files)} damage files")

    return healthy_files, damage_files

def _pad_feature_width(features, width=30):
    """Right-pad a ``(num_nodes, k)`` feature matrix with zeros to ``width`` columns (unchanged if k >= width)."""
    if features.shape[1] < width:
        features = np.pad(features, ((0, 0), (0, width - features.shape[1])), mode='constant')
    return features


def load_and_prepare_data(healthy_files, damage_files, augment=True, num_augmentations=5,
                          noise_level=0.01):
    """Build labelled graphs (healthy = 0, damaged = 1) from CSV recordings.

    Every recording, healthy or damaged, is compared against the same baseline:
    the first healthy file. Comparing a healthy file with itself would give every
    healthy sample an RMS error of 0, a frequency shift of 0 and a correlation of
    1. A classifier could then separate the classes from that artefact alone.

    Args:
        healthy_files: Paths to healthy CSVs; ``healthy_files[0]`` is the reference.
        damage_files: Paths to damaged CSVs.
        augment: If True, add ``num_augmentations`` noisy copies per recording.
            Copies of one recording are near-duplicates, so augmenting before a
            train/test split leaks them across splits. Prefer augmenting only
            the training split.
        num_augmentations: Noisy copies per recording when ``augment`` is True.
        noise_level: Std of the Gaussian feature noise. It is applied to both
            classes alike so the noise scale carries no label information.

    Returns:
        ``(graph_data_list, labels)``. Files that fail to process are reported and skipped.

    Raises:
        ValueError: If ``healthy_files`` is empty (no reference recording).
    """
    if not healthy_files:
        raise ValueError("load_and_prepare_data needs at least one healthy file as reference")

    sensor_names = ['Sensor_24', 'Sensor_31', 'Sensor_32']
    edge_index = create_bridge_graph(num_nodes=3)
    reference_healthy_df = pd.read_csv(healthy_files[0])

    graph_data_list = []
    labels = []

    def _process(paths, label, kind):
        # Convert each CSV in `paths` into one graph (plus optional noisy copies) labelled `label`.
        print(f"\nProcessing {kind} files...")
        for i, path in enumerate(paths):
            try:
                df = pd.read_csv(path)
                features = extract_node_features(reference_healthy_df, df, sensor_names)
                features = _pad_feature_width(features, width=30)

                data = Data(
                    x=torch.tensor(features, dtype=torch.float),
                    edge_index=edge_index,
                    y=torch.tensor([label], dtype=torch.float),
                )
                samples = [data]
                if augment:
                    samples.extend(augment_data(data, num_augmentations, noise_level=noise_level))
                graph_data_list.extend(samples)
                labels.extend([label] * len(samples))

                if (i + 1) % 10 == 0:
                    print(f"Processed {i+1}/{len(paths)} {kind} files")
            except Exception as e:
                print(f"Error processing {path}: {e}")

    _process(healthy_files, 0, 'healthy')
    _process(damage_files, 1, 'damage')

    return graph_data_list, labels

# ============================================================================
# IMPROVED GNN MODEL
# ============================================================================

class ImprovedBridgeDamageGNN(nn.Module):
    """Two GAT layers plus one GCN layer, with mean+max pooling, scoring graph-level damage.

    Args:
        num_node_features: Width of each node's input feature vector.
        hidden_channels: Per-head width of the GAT layers and output width of the
            GCN layer.

    ``forward`` returns sigmoid probabilities of shape ``(num_graphs, 1)``.
    NOTE(review): returning probabilities and training with BCELoss is less
    numerically stable than returning logits with BCEWithLogitsLoss. Changing
    that would alter the interface callers and the training code rely on.
    """
    def __init__(self, num_node_features, hidden_channels=64):
        super(ImprovedBridgeDamageGNN, self).__init__()
        
        # Graph Attention layers for better feature learning.
        # heads=4 with concat=True concatenates the heads, so each GAT output is
        # hidden_channels * 4 wide (matching conv2/conv3 inputs and bn1/bn2 sizes).
        self.conv1 = GATConv(num_node_features, hidden_channels, heads=4, concat=True)
        self.conv2 = GATConv(hidden_channels * 4, hidden_channels, heads=4, concat=True)
        self.conv3 = GCNConv(hidden_channels * 4, hidden_channels)
        
        # Batch normalization over node embeddings after each graph layer
        self.bn1 = nn.BatchNorm1d(hidden_channels * 4)
        self.bn2 = nn.BatchNorm1d(hidden_channels * 4)
        self.bn3 = nn.BatchNorm1d(hidden_channels)
        
        # Fully connected layers
        self.fc1 = nn.Linear(hidden_channels * 2, 64)  # *2 for mean + max pooling
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)
        
        # Shared dropout module, applied after the first two graph layers and fc1/fc2
        self.dropout = nn.Dropout(0.4)
        
    def forward(self, x, edge_index, batch=None):
        """Score each graph in the batch.

        Args:
            x: Node feature matrix (presumably ``(num_nodes, num_node_features)``; TODO confirm).
            edge_index: Graph connectivity passed to every conv layer.
            batch: Node-to-graph assignment. When None, all nodes are pooled as one graph.

        Returns:
            Sigmoid probabilities, one row per graph (``fc3`` outputs a single unit).
        """
        # Graph convolutions with attention
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout(x)
        
        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.dropout(x)
        
        x = self.conv3(x, edge_index)
        x = self.bn3(x)
        x = F.relu(x)
        
        # Dual pooling (mean + max)
        if batch is None:
            # Created on x's device so pooling works on GPU as well as CPU.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        
        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        # Concatenating both poolings doubles the width to hidden_channels * 2 (fc1's input).
        x = torch.cat([x_mean, x_max], dim=1)
        
        # Classification layers
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        
        x = self.fc2(x)
        x = F.relu(x)
        x = self.dropout(x)
        
        x = self.fc3(x)
        
        return torch.sigmoid(x)

# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================

def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass (with gradient clipping) and return the mean batch loss.

    Args:
        model: Called as ``model(x, edge_index, batch)``; returns one score per graph.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y`` and ``.to(device)``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(predictions, targets)`` with 1-D inputs.
        device: Device each batch is moved to.

    Returns:
        Average of the per-batch loss values.
    """
    model.train()
    total_loss = 0.0

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        out = model(data.x, data.edge_index, data.batch)
        # view(-1) keeps shape (B,) even when B == 1. squeeze() would give a 0-d
        # tensor, and BCELoss rejects that size mismatch with data.y.
        loss = criterion(out.view(-1), data.y)

        loss.backward()
        # Cap the global gradient norm at 1.0 to damp unstable updates.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(loader)

def evaluate(model, loader, device):
    """Compute accuracy at a 0.5 threshold and collect per-graph predictions.

    Returns:
        ``(accuracy, all_preds, all_labels)``. The two lists hold one float
        prediction (0.0/1.0) and one target per graph, in loader order.
    """
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            # view(-1) keeps a 1-D tensor even for single-graph batches. squeeze()
            # would give a 0-d tensor, which list.extend() cannot iterate.
            pred = (out.view(-1) > 0.5).float()
            correct += (pred == data.y).sum().item()
            total += data.y.size(0)

            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(data.y.cpu().numpy())

    return correct / total, all_preds, all_labels

# ============================================================================
# MAIN TRAINING SCRIPT
# ============================================================================

def main(seed=42):
    """Train ImprovedBridgeDamageGNN on every healthy/damage CSV and report test metrics.

    Pipeline:
    1. Load the un-augmented graphs.
    2. Make a stratified train/validation/test split.
    3. Augment the training split only.
    4. Train with class-weighted BCE, selecting the checkpoint and early-stopping
       on validation accuracy.
    5. Report metrics on the untouched test split.

    Args:
        seed: Seed for torch's RNGs (init, dropout, shuffling, augmentation noise)
            and the ``random_state`` of both splits.
    """
    torch.manual_seed(seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}\n")

    # Define paths
    base_path = '/kaggle/input/dojo-dataset/complete_simulation_dataset'
    healthy_dir = os.path.join(base_path, 'healthy')

    # Collect all damage case directories
    damage_dirs = []
    for i in range(1, 9):
        damage_pattern = os.path.join(base_path, f'damage_case_{i}_*')
        damage_dirs.extend(glob(damage_pattern))

    print(f"Found {len(damage_dirs)} damage case directories")

    healthy_files, damage_files = load_all_csv_files(healthy_dir, damage_dirs)

    # Load WITHOUT augmentation. Noisy copies of one recording are near-duplicates;
    # creating them before the split lets copies land in both train and test and
    # inflates the test score.
    print("\nLoading and preparing data...")
    graph_data_list, labels = load_and_prepare_data(healthy_files, damage_files, augment=False)

    print(f"\nTotal recordings: {len(graph_data_list)}")
    print(f"Healthy samples: {labels.count(0)}")
    print(f"Damaged samples: {labels.count(1)}")

    # Hold out the test split first, then carve a validation split used for
    # checkpoint selection and early stopping, so the test set never picks the model.
    train_val_data, test_data, train_val_labels, _ = train_test_split(
        graph_data_list, labels, test_size=0.2, random_state=seed, stratify=labels
    )
    train_data, val_data, train_labels, _ = train_test_split(
        train_val_data, train_val_labels, test_size=0.2, random_state=seed,
        stratify=train_val_labels
    )

    # Augment the training split only, with the same noise scale for both classes.
    num_augmentations = 5
    augmented_data, augmented_labels = [], []
    for data, label in zip(train_data, train_labels):
        copies = augment_data(data, num_augmentations, noise_level=0.01)
        augmented_data.extend(copies)
        augmented_labels.extend([label] * len(copies))
    train_data = list(train_data) + augmented_data
    train_labels = list(train_labels) + augmented_labels

    print(f"\nTrain samples (after augmentation): {len(train_data)}")
    print(f"Validation samples: {len(val_data)}")
    print(f"Test samples: {len(test_data)}")

    # Create data loaders
    train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=16, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=16, shuffle=False)

    # Initialize model
    num_features = train_data[0].x.shape[1]
    print(f"\nNode feature dimension: {num_features}")

    model = ImprovedBridgeDamageGNN(
        num_node_features=num_features,
        hidden_channels=64
    ).to(device)

    # Class weights from the training split: with these, both classes contribute
    # equally to the summed loss despite the healthy/damaged imbalance.
    n_healthy = train_labels.count(0)
    n_damaged = train_labels.count(1)
    n_total = n_healthy + n_damaged
    weight_for_0 = n_damaged / n_total
    weight_for_1 = n_healthy / n_total

    def criterion(pred, target):
        # Per-sample class weights fed to the standard BCE on probabilities.
        sample_weight = torch.where(
            target > 0.5,
            torch.full_like(target, weight_for_1),
            torch.full_like(target, weight_for_0),
        )
        return F.binary_cross_entropy(pred, target, weight=sample_weight)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    # `verbose` is deprecated in recent PyTorch releases, so it is not passed.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=10
    )

    # Training loop
    print("\nStarting training...")
    num_epochs = 150
    checkpoint_path = 'best_bridge_gnn_model.pth'
    best_val_acc = -1.0  # below any real accuracy, so epoch 1 always writes a checkpoint
    patience = 20
    patience_counter = 0

    for epoch in range(num_epochs):
        loss = train_epoch(model, train_loader, optimizer, criterion, device)
        train_acc, _, _ = evaluate(model, train_loader, device)
        val_acc, _, _ = evaluate(model, val_loader, device)

        scheduler.step(loss)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            patience_counter = 0
        else:
            patience_counter += 1

        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}, '
                  f'Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')

        # Early stopping
        if patience_counter >= patience:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

    # Final evaluation on the held-out test split with the best validation checkpoint
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    test_acc, test_preds, test_labels = evaluate(model, test_loader, device)

    print(f'\n{"="*60}')
    print(f'Best Validation Accuracy: {best_val_acc:.4f}')
    print(f'Test Accuracy: {test_acc:.4f}')
    print(f'{"="*60}')

    # Explicit labels keep the report well-formed even if a class is absent from the predictions.
    print("\nClassification Report:")
    print(classification_report(test_labels, test_preds, labels=[0.0, 1.0],
                                target_names=['Healthy', 'Damaged'], zero_division=0))

    print("\nConfusion Matrix:")
    print(confusion_matrix(test_labels, test_preds, labels=[0.0, 1.0]))
    print("\nModel saved as 'best_bridge_gnn_model.pth'")

# In a notebook cell __name__ is "__main__", so executing this cell trains immediately.
if __name__ == "__main__":
    main()

Using device: cuda

Found 8 damage case directories
Found 38 healthy files
Found 304 damage files

Loading and preparing data with augmentation...

Processing healthy files...
Processed 10/38 healthy files
Processed 20/38 healthy files
Processed 30/38 healthy files

Processing damage files...
Processed 10/304 damage files
Processed 20/304 damage files
Processed 30/304 damage files
Processed 40/304 damage files
Processed 50/304 damage files
Processed 60/304 damage files
Processed 70/304 damage files
Processed 80/304 damage files
Processed 90/304 damage files
Processed 100/304 damage files
Processed 110/304 damage files
Processed 120/304 damage files
Processed 130/304 damage files
Processed 140/304 damage files
Processed 150/304 damage files
Processed 160/304 damage files
Processed 170/304 damage files
Processed 180/304 damage files
Processed 190/304 damage files
Processed 200/304 damage files
Processed 210/304 damage files
Processed 220/304 damage files
Processed 230/304 damage files
Pr