TCN+Transformer Fusion

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split, GroupKFold
from sklearn.metrics import precision_score, recall_score, f1_score
from torch.utils.tensorboard import SummaryWriter

# Shared compute device: the GPU when CUDA is available, otherwise the CPU.
# The model, the training/validation tensors and the test tensors below are
# all moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data Augmentation Function
def augment_data(X, noise_std=0.01):
    """Return a jittered copy of ``X`` with zero-mean Gaussian noise added.

    Args:
        X: Array-like batch of samples (anything ``np.asarray`` accepts).
        noise_std: Standard deviation of the additive noise. Defaults to
            0.01, the value that used to be hard-coded.

    Returns:
        A new ``np.ndarray`` with the same shape as ``X``. ``X`` itself is
        left untouched.
    """
    X = np.asarray(X)
    # One vectorised draw for the whole batch instead of one draw per sample.
    noise = np.random.normal(0, noise_std, X.shape)
    return X + noise

# TCN Block Definition
class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` steps along the final axis of a 3-D tensor.

    ``TemporalBlock`` places this after each padded ``nn.Conv1d`` to remove
    the trailing outputs that the padding introduced.

    Args:
        chomp_size: Number of trailing steps to remove. ``0`` leaves the
            input unchanged.
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # ``x[:, :, :-0]`` is an *empty* slice, so a zero chomp has to be
        # handled explicitly instead of going through the negative index.
        if self.chomp_size == 0:
            return x
        return x[:, :, :-self.chomp_size].contiguous()

class TemporalBlock(nn.Module):
    """Residual block made of two dilated 1-D convolution stages.

    Each stage is ``Conv1d -> Chomp1d(padding) -> ReLU -> Dropout``. The block
    input is added to the output of the second stage (after a 1x1 convolution
    when ``n_inputs != n_outputs``) and the sum goes through a final ReLU.

    Args:
        n_inputs: Channels of the incoming tensor.
        n_outputs: Channels produced by both convolutions.
        kernel_size: Kernel size of both convolutions.
        stride: Stride of both convolutions.
        dilation: Dilation of both convolutions.
        padding: Padding of both convolutions; the same amount is chomped
            off the end of each convolution's output.
        dropout: Dropout probability used after each stage.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        conv_options = dict(stride=stride, padding=padding, dilation=dilation)

        # Stage 1
        self.conv1 = nn.Conv1d(n_inputs, n_outputs, kernel_size, **conv_options)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # Stage 2
        self.conv2 = nn.Conv1d(n_outputs, n_outputs, kernel_size, **conv_options)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(
            self.conv1, self.chomp1, self.relu1, self.dropout1,
            self.conv2, self.chomp2, self.relu2, self.dropout2,
        )

        # A 1x1 convolution matches channel counts for the skip connection;
        # it is only needed when the block changes the number of channels.
        if n_inputs != n_outputs:
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        else:
            self.downsample = None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Re-draw every convolution weight from N(0, 0.01); biases are untouched."""
        convs = [self.conv1, self.conv2]
        if self.downsample is not None:
            convs.append(self.downsample)
        for conv in convs:
            conv.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        if self.downsample is None:
            res = x
        else:
            res = self.downsample(x)
        return self.relu(out + res)

class TemporalConvNet(nn.Module):
    """Sequential stack of ``TemporalBlock`` levels with doubling dilation.

    Level ``i`` uses dilation ``2 ** i`` and padding
    ``(kernel_size - 1) * 2 ** i``; every level uses stride 1.

    Args:
        num_inputs: Channels of the tensor fed to the first level.
        num_channels: Output channels of each level, in order; its length is
            the number of levels.
        kernel_size: Kernel size shared by all levels.
        dropout: Dropout probability shared by all levels.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super().__init__()
        # Each level consumes what the previous level produced.
        level_inputs = [num_inputs] + list(num_channels[:-1])

        blocks = []
        for level, (n_in, n_out) in enumerate(zip(level_inputs, num_channels)):
            dilation = 2 ** level
            blocks.append(
                TemporalBlock(
                    n_in,
                    n_out,
                    kernel_size,
                    stride=1,
                    dilation=dilation,
                    padding=(kernel_size - 1) * dilation,
                    dropout=dropout,
                )
            )

        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        return self.network(x)

# Define the TCN block for each sensor
class TCNBlock(nn.Module):
    """Front end for one sensor channel: a ``TemporalConvNet`` with its defaults.

    Args:
        input_size: Channels of the incoming tensor.
        num_channels: Output channels of each TCN level.
    """

    def __init__(self, input_size, num_channels):
        super().__init__()
        self.tcn = TemporalConvNet(num_inputs=input_size, num_channels=num_channels)

    def forward(self, x):
        features = self.tcn(x)
        return features

# Define the fusion and Transformer model
class FusionTransformerModel(nn.Module):
    """Per-channel TCN front ends fused by an MLP, a Transformer encoder and a classifier.

    Pipeline of ``forward``:
      1. every input channel goes through its own ``TCNBlock``;
      2. the last time step of each front end is concatenated;
      3. ``fc1 -> ReLU -> fc2`` projects the fused vector to ``d_model``;
      4. the vector is fed to the Transformer encoder as a length-1 sequence;
      5. dropout and ``fc3`` produce the class logits.

    Args:
        input_dim: Number of input channels (one TCN front end is built per
            channel).
        tcn_channels: Output channels of each TCN level; the last entry is
            the feature size of every front end.
        transformer_hidden_dim: ``dim_feedforward`` of each encoder layer.
        num_layers: Number of Transformer encoder layers.
        num_heads: Number of attention heads. The embedding size is
            ``tcn_channels[-1] * input_dim`` rounded down to a multiple of
            this value.
        num_classes: Number of output logits.

    Raises:
        ValueError: If the rounded embedding size would be zero.
    """

    def __init__(self, input_dim, tcn_channels, transformer_hidden_dim, num_layers, num_heads, num_classes=2):
        super().__init__()
        self.frontends = nn.ModuleList([TCNBlock(input_size=1, num_channels=tcn_channels) for _ in range(input_dim)])

        tcn_output_dim = tcn_channels[-1]  # The output dimension of the last TCN layer
        fused_dim = tcn_output_dim * input_dim
        # Multi-head attention needs an embedding size divisible by num_heads.
        d_model = fused_dim // num_heads * num_heads
        if d_model == 0:
            raise ValueError(
                f"tcn_channels[-1] * input_dim ({fused_dim}) must be at least num_heads ({num_heads})"
            )

        self.fc1 = nn.Linear(fused_dim, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, d_model)

        # batch_first=True is required: forward() passes tensors shaped
        # (batch, seq=1, d_model). With the default (seq, batch, d_model)
        # layout the encoder would treat the batch axis as the sequence and
        # let the samples of a mini-batch attend to each other.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=num_heads,
                dim_feedforward=transformer_hidden_dim,
                batch_first=True,
            ),
            num_layers=num_layers,
        )
        self.dropout = nn.Dropout(0.3)
        self.fc3 = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` is expected as ``(batch, time, input_dim)``: axes 1 and 2 are
        swapped first and channel ``i`` is then read from axis 1.
        """
        # Follow the model's own device rather than a module-level global so
        # the forward pass also works when the model lives elsewhere.
        x = x.permute(0, 2, 1).to(self.fc1.weight.device)

        # One single-channel slice (batch, 1, time) per front end.
        front_end_outputs = [front_end(x[:, i:i+1, :]) for i, front_end in enumerate(self.frontends)]
        # Keep only the final time step of every front end and fuse them.
        combined = torch.cat([output[:, :, -1] for output in front_end_outputs], dim=1)
        combined = self.fc1(combined)
        combined = self.relu(combined)
        combined = self.fc2(combined)

        combined = combined.unsqueeze(1)  # Add sequence dimension -> (batch, 1, d_model)
        combined = self.transformer_encoder(combined)
        combined = self.dropout(combined)
        combined = combined[:, -1, :]  # Take the last (only) sequence position

        output = self.fc3(combined)
        return output

def train_model(X_train, y_train, X_val, y_val, input_dim, tcn_channels, transformer_hidden_dim, num_layers, num_heads, num_classes=2, epochs=100, batch_size=32, learning_rate=0.001, patience=10, fold_index=0):
    """Train a ``FusionTransformerModel`` with LR scheduling and early stopping.

    The training and validation arrays are converted to tensors and moved to
    the module-level ``device`` in full. Per-epoch loss and accuracy are
    printed and written to TensorBoard under ``runs/fold_<fold_index>``.

    Args:
        X_train, y_train: Training samples and integer class labels.
        X_val, y_val: Validation samples and labels; the validation loss
            drives both the LR scheduler and early stopping.
        input_dim, tcn_channels, transformer_hidden_dim, num_layers,
        num_heads, num_classes: Forwarded to ``FusionTransformerModel``.
        epochs: Maximum number of epochs.
        batch_size: Mini-batch size for both loaders.
        learning_rate: Initial Adam learning rate.
        patience: Epochs without validation-loss improvement before stopping.
        fold_index: Used only to name the TensorBoard run directory.

    Returns:
        ``(model, epoch_loss)``: the model as it is after the last epoch that
        ran (the best-validation weights are NOT restored) and that epoch's
        mean training loss (NaN if ``epochs`` is 0).
    """
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
    y_train_tensor = torch.tensor(y_train, dtype=torch.long).to(device)
    X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
    y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(device)

    train_loader = DataLoader(TensorDataset(X_train_tensor, y_train_tensor), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(TensorDataset(X_val_tensor, y_val_tensor), batch_size=batch_size, shuffle=False)

    model = FusionTransformerModel(input_dim, tcn_channels, transformer_hidden_dim, num_layers, num_heads, num_classes).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    # `verbose` is deprecated on LR schedulers; LR changes are reported
    # explicitly in the loop below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    best_val_loss = float('inf')
    patience_counter = 0
    epoch_loss = float('nan')  # keeps the return value defined when epochs == 0

    writer = SummaryWriter(f'runs/fold_{fold_index}')  # TensorBoard writer
    try:
        for epoch in range(epochs):
            epoch_loss, epoch_accuracy = _run_epoch(model, train_loader, criterion, optimizer)
            val_loss, val_accuracy = _run_epoch(model, val_loader, criterion)

            writer.add_scalar('Loss/train', epoch_loss, epoch)
            writer.add_scalar('Loss/val', val_loss, epoch)
            writer.add_scalar('Accuracy/train', epoch_accuracy, epoch)
            writer.add_scalar('Accuracy/val', val_accuracy, epoch)

            print(f'Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')

            # Step the scheduler on the validation loss and report any change.
            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(val_loss)
            lr_after = optimizer.param_groups[0]['lr']
            if lr_after != lr_before:
                print(f'Epoch {epoch+1}: reducing learning rate to {lr_after:.4e}')

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f'Early stopping at epoch {epoch+1}')
                break
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()

    # Calculate number of parameters
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params}")

    # Complexity report. torchprofile is an optional extra, so it is imported
    # here and the report is skipped when it is missing rather than failing
    # after a complete training run.
    try:
        import torchprofile
    except ImportError:
        print("FLOPs: skipped (torchprofile is not installed)")
    else:
        # Batch of one with the same per-sample layout the model was trained on.
        dummy_input = torch.randn(1, *X_train_tensor.shape[1:], device=device)
        # NOTE: profile_macs counts multiply-accumulate operations.
        flops = torchprofile.profile_macs(model, dummy_input)
        print(f"FLOPs: {flops}")

    return model, epoch_loss


def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean loss per sample, accuracy)``.

    The model is trained when ``optimizer`` is given; otherwise it is put in
    eval mode and evaluated without gradient tracking.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            # criterion returns the batch mean; weight it by the batch size so
            # a smaller final batch does not skew the epoch average.
            loss_sum += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    return loss_sum / n_seen, n_correct / n_seen

def evaluate_model(models, X_test, y_test, batch_size=32):
    """Evaluate an ensemble by averaging the models' raw outputs.

    Args:
        models: Non-empty collection of trained models; each is switched to
            eval mode.
        X_test: Test samples, converted to a float32 tensor on ``device``.
        y_test: Integer class labels, converted to a long tensor on ``device``.
        batch_size: Mini-batch size used for inference.

    Returns:
        ``(accuracy, mean_loss, precision, recall, f1)``. Accuracy and the
        weighted precision/recall/F1 are computed from the arg-max of the
        averaged outputs. ``mean_loss`` is the average of the individual
        models' cross-entropy losses, not the loss of the averaged output.

    Raises:
        ValueError: If ``models`` is empty.
    """
    models = list(models)
    if not models:
        raise ValueError("evaluate_model requires at least one model")

    X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
    y_test_tensor = torch.tensor(y_test, dtype=torch.long).to(device)

    criterion = nn.CrossEntropyLoss()

    # The data is the same for every model, so build the loader only once.
    test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor), batch_size=batch_size, shuffle=False)

    test_preds = []
    test_losses = []
    with torch.no_grad():  # Inference only: no gradients needed
        for model in models:
            model.eval()  # Set the model to evaluation mode

            batch_outputs = []
            running_loss = 0.0
            total = 0
            for data, labels in test_loader:
                outputs = model(data)
                batch_outputs.append(outputs)
                # Weight the batch-mean loss by the batch size.
                running_loss += criterion(outputs, labels).item() * data.size(0)
                total += labels.size(0)

            test_losses.append(running_loss / total)
            test_preds.append(torch.cat(batch_outputs, dim=0))

    # Average predictions across the ensemble, then pick the top class.
    mean_preds = torch.mean(torch.stack(test_preds), dim=0)
    _, predicted = torch.max(mean_preds, 1)

    # Calculate metrics
    accuracy = (predicted == y_test_tensor).sum().item() / y_test_tensor.size(0)
    # Move labels and predictions to host memory once for scikit-learn.
    y_true = y_test_tensor.cpu().numpy()
    y_pred = predicted.cpu().numpy()
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    mean_loss = np.mean(test_losses)

    return accuracy, mean_loss, precision, recall, f1

# Perform Group K-Fold Evaluation

NUM_MODELS = 5  # Number of models to train in the ensemble
DNNS = []  # Filled below with one list of trained ensemble members per fold

input_dim = X_res.shape[-1]  # Number of input channels (updated to include accelerometer data)
tcn_channels = [25, 25, 25]  # Output channels of each TCN level (three levels)
transformer_hidden_dim = 64  # dim_feedforward of each Transformer encoder layer
num_layers = 2  # Number of Transformer encoder layers
num_heads = 4  # Attention heads; the model rounds its embedding size down to a multiple of this value
num_classes = 2  # Number of output classes (size of the final linear layer)

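# X_res, y_simple, I_TRAINS and I_TESTS are assumed to be defined earlier (e.g. in a previous cell);
# I_TRAINS and I_TESTS hold the per-fold train and test indices for Group K-Fold.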

for fold_index, I_train in enumerate(I_TRAINS):
    # Split this fold's training data into training and validation sets.
    # random_state is fixed, so the split is the same for every ensemble
    # member and only needs to be computed once per fold.
    X_train, X_val, y_train, y_val = train_test_split(X_res[I_train], y_simple[I_train], test_size=0.2, random_state=42)

    # NOTE: every member of a fold logs to the same TensorBoard directory
    # (runs/fold_<fold_index>), so their curves share one run.
    models = []
    for _ in range(NUM_MODELS):
        # Fresh noise for each ensemble member; the clean samples are kept too.
        X_train_augmented = augment_data(X_train)
        X_train_combined = np.concatenate((X_train, X_train_augmented))
        y_train_combined = np.concatenate((y_train, y_train))  # Duplicate labels for augmented data

        model, epoch_loss = train_model(X_train_combined, y_train_combined, X_val, y_val, input_dim, tcn_channels, transformer_hidden_dim, num_layers, num_heads, num_classes, epochs=150, batch_size=32, learning_rate=0.0005, patience=10, fold_index=fold_index)  # Adjusted learning rate and added patience
        models.append(model)
        print(f"Fold {fold_index + 1}, Model Loss: {epoch_loss:.4f}")
    DNNS.append(models)

# Evaluate the ensemble of models for each fold
accuracies, mean_losses, precisions, recalls, f1_scores = [], [], [], [], []

# DNNS[k] was trained on fold k, so pair it with fold k's held-out indices.
for models, I_test in zip(DNNS, I_TESTS):
    X_test = X_res[I_test]
    y_test = y_simple[I_test]

    accuracy, mean_loss, precision, recall, f1 = evaluate_model(models, X_test, y_test)

    accuracies.append(accuracy)
    mean_losses.append(mean_loss)
    precisions.append(precision)
    recalls.append(recall)
    f1_scores.append(f1)

# Calculate and print the mean and standard deviation of accuracies and other metrics
# Each pair is (mean, standard deviation) of one metric across the folds;
# np.std is called with its default settings.
mean_accuracy, std_dev_accuracy = np.mean(accuracies), np.std(accuracies)
mean_loss, std_dev_loss = np.mean(mean_losses), np.std(mean_losses)
mean_precision, std_dev_precision = np.mean(precisions), np.std(precisions)
mean_recall, std_dev_recall = np.mean(recalls), np.std(recalls)
mean_f1, std_dev_f1 = np.mean(f1_scores), np.std(f1_scores)

# Accuracy is reported as a percentage; its SD and all other values are raw.
print(f"Mean Accuracy: {mean_accuracy * 100:.2f}%, (SD={std_dev_accuracy})")
print(f"Mean Loss: {mean_loss:.4f}, (SD={std_dev_loss})")
print(f"Mean Precision: {mean_precision:.4f}, (SD={std_dev_precision})")
print(f"Mean Recall: {mean_recall:.4f}, (SD={std_dev_recall})")
print(f"Mean F1 Score: {mean_f1:.4f}, (SD={std_dev_f1})")
