<a href="https://colab.research.google.com/github/DanAkarca/AdaptiveStochasticity/blob/main/incremental_build_v9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Multi-network modular system with attention-based dropout and communication delays:
- v9: the task switches halfway through, so the dropout rate must adapt rapidly.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

class AttentionBasedAdaptiveDropout(nn.Module):
    """Dropout layer whose drop probability is adjusted from attention scores.

    While the module is in training mode and attention weights are supplied,
    the rate moves by ``adaptation_rate * (mean_attention - 0.5)``: attention
    above 0.5 lowers the rate, attention below 0.5 raises it. The rate is
    always kept inside ``[min_rate, max_rate]``.

    Args:
        initial_rate: Starting drop probability.
        min_rate: Lower bound for the adapted rate.
        max_rate: Upper bound for the adapted rate.
        adaptation_rate: Step size applied on every update.
    """

    def __init__(self, initial_rate, min_rate, max_rate, adaptation_rate):
        super().__init__()
        # float() keeps the tensor floating-point even for integer rates such
        # as 0; nn.Parameter rejects integer tensors because they cannot
        # require gradients.
        self.dropout_rate = nn.Parameter(torch.tensor(float(initial_rate)))
        self.min_rate = min_rate
        self.max_rate = max_rate
        self.adaptation_rate = adaptation_rate

    def forward(self, x, attention_weights=None):
        """Apply dropout to ``x``, first adapting the rate if possible.

        Args:
            x: Tensor to apply dropout to.
            attention_weights: Optional tensor; its overall mean drives the
                rate update. Ignored outside training mode.

        Returns:
            ``x`` after ``F.dropout`` with the current rate (a no-op in eval
            mode).
        """
        if self.training and attention_weights is not None:
            self.update_rate(attention_weights.mean().item())

        # .item() turns the rate into a plain float, so no gradient ever
        # flows into dropout_rate; it is only changed by update_rate().
        return F.dropout(x, p=self.dropout_rate.item(), training=self.training)

    def update_rate(self, attention_score):
        """Move the rate against ``attention_score - 0.5`` and clamp it."""
        # The update is a manual rule, not a gradient step, so keep autograd
        # out of it instead of building a graph and writing through .data.
        with torch.no_grad():
            new_rate = self.dropout_rate - self.adaptation_rate * (attention_score - 0.5)
            self.dropout_rate.copy_(new_rate.clamp(self.min_rate, self.max_rate))

    def get_current_rate(self):
        """Return the current drop probability as a Python float."""
        return self.dropout_rate.item()

class AttentionLayer(nn.Module):
    """Score inputs with a small MLP and normalise the scores over dim 1.

    Args:
        hidden_size: Size of the last dimension of the incoming tensor.
        attention_size: Width of the hidden layer of the scoring MLP.
    """

    def __init__(self, hidden_size, attention_size):
        super().__init__()
        scoring_layers = [
            nn.Linear(hidden_size, attention_size),
            nn.Tanh(),
            nn.Linear(attention_size, 1),
        ]
        self.attention = nn.Sequential(*scoring_layers)

    def forward(self, hidden_states):
        """Return weights with a trailing size-1 dimension that sum to 1 along dim 1."""
        scores = self.attention(hidden_states)
        return scores.softmax(dim=1)

class SwappableModel(nn.Module):
    """Sequence model with a selectable backbone, adaptive dropout and a linear head.

    Args:
        input_size: Number of input features per time step.
        hidden_size: Hidden size of the LSTM/GRU backbone (not used by the
            transformer backbone, whose width is derived from ``input_size``).
        output_size: Number of output features per time step.
        model_type: One of ``'lstm'``, ``'gru'`` or ``'transformer'``.
        dropout_params: Keyword arguments for ``AttentionBasedAdaptiveDropout``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported backbones.
    """

    def __init__(self, input_size, hidden_size, output_size, model_type, dropout_params):
        super().__init__()
        self.model_type = model_type
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        if model_type == 'lstm':
            self.model = nn.LSTM(input_size, hidden_size, batch_first=True)
            feature_size = hidden_size
        elif model_type == 'gru':
            self.model = nn.GRU(input_size, hidden_size, batch_first=True)
            feature_size = hidden_size
        elif model_type == 'transformer':
            # Round input_size up to the next multiple of 8 so the embedding
            # width is divisible by the 8 attention heads used below.
            self.d_model = ((input_size - 1) // 8 + 1) * 8
            self.embedding = nn.Linear(input_size, self.d_model)
            self.model = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=self.d_model, nhead=8, batch_first=True), num_layers=2
            )
            feature_size = self.d_model
        else:
            # Fail here with a clear message rather than later with an
            # AttributeError on the missing d_model attribute.
            raise ValueError(
                f"Unsupported model_type {model_type!r}; expected 'lstm', 'gru' or 'transformer'"
            )
        self.fc = nn.Linear(feature_size, output_size)

        self.adaptive_dropout = AttentionBasedAdaptiveDropout(**dropout_params)

    def forward(self, x, attention_weights=None):
        """Run the backbone, then dropout and the linear head on every time step.

        Args:
            x: Input batch; the backbones are built with ``batch_first=True``,
                so the expected layout is (batch, time, input_size).
            attention_weights: Optional tensor forwarded to the adaptive
                dropout layer to drive its rate update.

        Returns:
            Tensor with one ``output_size`` vector per time step.
        """
        if self.model_type in ['lstm', 'gru']:
            out, _ = self.model(x)
        else:  # transformer
            out = self.model(self.embedding(x))

        # Apply dropout to each time step
        out = self.adaptive_dropout(out, attention_weights)

        # Apply final linear layer to each time step
        return self.fc(out)

class EnhancedSwappableModelEnsemble(nn.Module):
    """Ensemble of delayed secondary models feeding a primary model via attention.

    Every secondary model processes the input, its output is shifted forward
    in time by that model's delay, and an attention layer weights the
    secondary outputs against each other. The attention-weighted sum is
    concatenated to the original input and passed to the primary model.

    Args:
        input_size: Number of input features per time step.
        hidden_size: Hidden size of the sub-models and of the attention MLP.
        output_size: Number of output features per time step.
        num_secondary: Stored for reference only; the number of secondary
            models actually built is ``len(secondary_model_types)``.
        delays: One non-negative integer delay (in time steps) per secondary
            model.
        secondary_model_types: Backbone type for each secondary model.
        primary_model_type: Backbone type for the primary model.
        dropout_params: Keyword arguments for every adaptive dropout layer.
    """

    def __init__(self, input_size, hidden_size, output_size, num_secondary, delays, secondary_model_types, primary_model_type, dropout_params):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_secondary = num_secondary
        self.delays = delays

        # Secondary models
        self.secondary_models = nn.ModuleList([
            SwappableModel(input_size, hidden_size, output_size, model_type, dropout_params)
            for model_type in secondary_model_types
        ])

        # Attention mechanism
        self.attention = AttentionLayer(output_size, hidden_size)

        # Primary model (swappable); it sees the raw input plus the attended
        # secondary output, hence input_size + output_size features.
        self.primary_model = SwappableModel(input_size + output_size, hidden_size, output_size, primary_model_type, dropout_params)

    def forward(self, x):
        """Return the primary prediction for the last time step and the attention weights.

        Args:
            x: Input of shape (batch, time, input_size).

        Returns:
            Tuple ``(prediction, attention_weights)``: ``prediction`` is the
            primary model's output at the final time step, and
            ``attention_weights`` is the attention layer's output for the
            stacked secondary outputs.
        """
        batch_size, seq_len, _ = x.size()

        # Process through secondary models
        delayed_outputs = []
        for model, delay in zip(self.secondary_models, self.delays):
            out = model(x)
            # Shift the output forward by `delay` steps, zero-filling the start.
            # Clamping to seq_len and slicing with an explicit end index keeps
            # the sequence length intact for delay == 0 (where `:-0` would be
            # empty) and for delays longer than the sequence.
            shift = min(delay, seq_len)
            padding = torch.zeros(batch_size, shift, self.output_size, device=x.device, dtype=out.dtype)
            delayed_outputs.append(torch.cat([padding, out[:, :seq_len - shift, :]], dim=1))

        # Stack to (batch, num_models, time, output_size)
        secondary_outputs = torch.stack(delayed_outputs, dim=1)

        # Apply attention to secondary outputs
        attention_weights = self.attention(secondary_outputs)

        # Apply adaptive dropout to each secondary output individually.
        # NOTE(review): each secondary model already applied this same dropout
        # layer inside its own forward(), so its output is dropped twice.
        dropped_outputs = [
            model.adaptive_dropout(
                secondary_outputs[:, i, :, :],
                attention_weights[:, i, :].unsqueeze(-1)
            )
            for i, model in enumerate(self.secondary_models)
        ]
        dropped_secondary_outputs = torch.stack(dropped_outputs, dim=1)

        attended_output = torch.sum(attention_weights * dropped_secondary_outputs, dim=1)

        # Combine original input with attended secondary output
        combined_input = torch.cat([x, attended_output], dim=2)

        # Primary model.
        # NOTE(review): the attention layer applies softmax over dim 1, so the
        # mean over dim 1 is always 1 / num_models; the primary dropout rate
        # therefore receives a constant signal. Confirm this is intended.
        out_primary = self.primary_model(combined_input, attention_weights.mean(dim=1))

        return out_primary[:, -1, :], attention_weights

    def get_dropout_rates(self):
        """Return the current dropout rates: secondary models in order, then the primary model."""
        rates = [model.adaptive_dropout.get_current_rate() for model in self.secondary_models]
        rates.append(self.primary_model.adaptive_dropout.get_current_rate())
        return rates

def generate_segmented_time_series(n_samples, seq_length, n_features, change_points):
    """Generate a multi-feature series whose generating process changes at given indices.

    The series has ``n_samples + seq_length`` time steps. The segments
    delimited by ``change_points`` cycle through four regimes: linear trend,
    sine wave, Gaussian noise (std 0.5) and a sum of two sines. Gaussian
    noise with std 0.05 is added to every segment.

    Args:
        n_samples: Number of samples; together with ``seq_length`` it sets
            the series length.
        seq_length: Extra time steps appended to ``n_samples``.
        n_features: Number of feature columns to generate.
        change_points: Ascending time indices at which the regime switches.
            Indices outside the series are clipped to its bounds, so a change
            point past the end just extends the previous regime to the end.

    Returns:
        Array of shape ``(n_samples + seq_length, n_features)``.
    """
    total_length = n_samples + seq_length
    time = np.arange(total_length) / 100.0

    # Segment boundaries, clipped to the series. Without clipping, a change
    # point beyond the end makes the slice shorter than `end - start`, and the
    # noise arrays below no longer match its length.
    edges = [0] + [min(max(cp, 0), total_length) for cp in change_points] + [total_length]

    series = []
    for i in range(n_features):
        feature = np.zeros(total_length)
        for j in range(len(edges) - 1):
            start, end = edges[j], edges[j + 1]
            length = end - start
            if length <= 0:
                # Empty segment (duplicate or out-of-range change point).
                continue

            t = time[start:end]
            regime = j % 4
            if regime == 0:  # Linear trend, offset per feature
                values = 0.05 * t + 0.1 * i
            elif regime == 1:  # Sine wave
                values = np.sin(2 * np.pi * 0.05 * t)
            elif regime == 2:  # Increased noise
                values = np.random.normal(0, 0.5, size=length)
            else:  # Complex sine
                values = (
                    0.5 * np.sin(2 * np.pi * 0.03 * t) +
                    0.3 * np.sin(2 * np.pi * 0.07 * t)
                )

            # Add some noise to all regimes
            feature[start:end] = values + np.random.normal(0, 0.05, size=length)

        series.append(feature)

    return np.array(series).T

# Set hyperparameters
# -- model dimensions (the model predicts every input feature) --
input_size = 5
output_size = input_size
hidden_size = 64
# -- ensemble layout: one delay and one backbone type per secondary model --
num_secondary = 3
secondary_model_types = ['lstm', 'gru', 'lstm']
delays = [1, 5, 3]
primary_model_type = 'lstm'
# -- data and optimisation --
seq_length = 50
n_samples = 1000
batch_size = 32
num_epochs = 60
learning_rate = 0.001

# Dropout parameters (keyword arguments of AttentionBasedAdaptiveDropout)
dropout_params = dict(
    initial_rate=0.25,
    min_rate=0.1,
    max_rate=0.9,
    adaptation_rate=0.001,
)

# Define change points for the features
change_points = [250, 500, 750]  # Example change points

# Generate data
data = generate_segmented_time_series(n_samples + seq_length, seq_length, input_size, change_points)

# Prepare input sequences and target values: window number k covers
# data[k:k+seq_length] and its target is the row right after the window.
X = torch.FloatTensor(
    np.stack([data[start:start + seq_length] for start in range(n_samples)])
)
y = torch.FloatTensor(data[seq_length:seq_length + n_samples])

# Split data into train and test sets (chronological 80/20 split)
train_size = int(0.8 * len(X))
X_train, y_train = X[:train_size], y[:train_size]
X_test, y_test = X[train_size:], y[train_size:]

# Create data loaders (only the training batches are shuffled)
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = torch.utils.data.TensorDataset(X_test, y_test)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Initialize model, loss function, and optimizer
model = EnhancedSwappableModelEnsemble(
    input_size,
    hidden_size,
    output_size,
    num_secondary,
    delays,
    secondary_model_types,
    primary_model_type,
    dropout_params,
)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model, recording the mean batch loss and the dropout rates
# once per epoch
losses = []
dropout_rates_history = []
model.train()
for epoch in range(num_epochs):
    batch_losses = []
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        outputs, attention_weights = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    avg_loss = sum(batch_losses) / len(train_loader)
    losses.append(avg_loss)
    dropout_rates_history.append(model.get_dropout_rates())

    # Report progress every tenth epoch
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

# Evaluate the model on the entire dataset (training and test windows alike),
# one window at a time
model.eval()
all_predictions = []
all_actuals = []
attention_weights_list = []

with torch.no_grad():
    for i in range(len(X)):
        input_seq = X[i].unsqueeze(0)  # Add batch dimension
        output, attention_weights = model(input_seq)
        # Drop ONLY the batch dimension. A bare squeeze() would also remove
        # the trailing size-1 attention axis (and a size-1 feature axis),
        # which the indexing in the later analysis code depends on.
        all_predictions.append(output.squeeze(0).numpy())
        all_actuals.append(y[i].numpy())
        attention_weights_list.append(attention_weights.squeeze(0).numpy())

all_predictions = np.array(all_predictions)
all_actuals = np.array(all_actuals)
# Shape: (samples, secondary models, time steps, 1)
attention_weights = np.array(attention_weights_list)

# Visualize the full time series with predictions.
# The plotted x axis is shifted by seq_length relative to `data`, because the
# first target is data[seq_length].
segment_names = ['Linear', 'Sine', 'Noise', 'Complex Sine']
plt.figure(figsize=(20, 15))
for i in range(input_size):
    plt.subplot(input_size, 1, i+1)

    # Plot full data
    plt.plot(data[seq_length:, i], label='Full Data', alpha=0.5)

    # Plot actual and predicted values for the entire series
    plt.plot(all_actuals[:, i], label='Actual', linewidth=2)
    plt.plot(all_predictions[:, i], label='Predicted', linewidth=2)

    plt.title(f'Feature {i+1}')
    plt.legend()

    # Add vertical lines for change points
    for change_point in change_points:
        if change_point >= seq_length:
            plt.axvline(x=change_point - seq_length, color='r', linestyle='--')

    # Add labels for each segment
    for j, change_point in enumerate(change_points + [len(data)]):
        if change_point >= seq_length:
            prev_change_point = seq_length if j == 0 else max(change_points[j-1], seq_length)
            # Take the midpoint in data coordinates, then shift it into plot
            # coordinates; both segment ends must be shifted, not just one.
            mid_point = (prev_change_point + change_point) // 2 - seq_length
            segment_type = segment_names[j % 4]
            plt.text(mid_point, plt.ylim()[1], segment_type, horizontalalignment='center', verticalalignment='bottom')

plt.tight_layout()
plt.show()

# Calculate and print the Mean Squared Error for each feature
for i in range(output_size):
    # The evaluation results are stored in all_actuals / all_predictions.
    mse = np.mean((all_actuals[:, i] - all_predictions[:, i])**2)
    print(f"MSE for feature {i+1}: {mse:.4f}")

# Visualize attention weights
# The indexing below assumes attention_weights is 4-D with a trailing size-1
# axis; axis 2 (the one plotted as 'Time Step' further down) is averaged out.
plt.figure(figsize=(10, 5))
mean_attention = attention_weights.mean(axis=2)
num_models = mean_attention.shape[1]
mean_attention_list = [mean_attention[:, i, 0] for i in range(num_models)]
plt.boxplot(mean_attention_list)
plt.title('Distribution of Mean Attention Weights for Each Secondary Model')
plt.xlabel('Secondary Model (by delay)')
plt.ylabel('Mean Attention Weight')
# One tick per box, labelled with the corresponding delay
plt.xticks(range(1, len(delays) + 1), delays)
plt.show()

# Print average attention weights (averaged over samples as well)
avg_attention = mean_attention.mean(axis=0)
for i, delay in enumerate(delays):
    print(f"Average attention weight for model with delay {delay}: {avg_attention[i][0]:.4f}")

# Visualize attention weights over time: one heatmap per secondary model
num_panels = len(delays)
plt.figure(figsize=(20, 5))
for i in range(num_panels):
    plt.subplot(1, num_panels, i+1)
    # Transpose so samples run along x and time steps along y
    heatmap = attention_weights[:, i, :].squeeze().T
    plt.imshow(heatmap, aspect='auto', cmap='viridis')
    plt.title(f'Attention Weights for {secondary_model_types[i].upper()} (Delay {delays[i]})')
    plt.xlabel('Sample')
    plt.ylabel('Time Step')
    plt.colorbar()
plt.tight_layout()
plt.show()

# Visualize dropout rates during training
plt.figure(figsize=(12, 6))
# Rows are epochs; columns are the secondary models followed by the primary model
dropout_rates_history = np.array(dropout_rates_history)
for i, secondary_rates in enumerate(dropout_rates_history[:, :-1].T):
    plt.plot(secondary_rates, label=f'{secondary_model_types[i].upper()} (Delay {delays[i]})')
# The last column belongs to the primary model; draw it dashed
plt.plot(dropout_rates_history[:, -1], label=f'Primary {primary_model_type.upper()}', linestyle='--')
plt.title('Dropout Rates During Training for Each Model')
plt.xlabel('Epoch')
plt.ylabel('Dropout Rate')
plt.legend()
plt.grid(True)
plt.show()

# Dropout rates only adapt in training mode, so during the evaluation pass
# every model kept the rate it had at the end of training. Those are the
# rates that belong to the attention weights collected during evaluation.
final_dropout_rates = model.get_dropout_rates()

# Visualize dropout rate vs attention weight
plt.figure(figsize=(20, 5))
for i in range(len(delays)):
    plt.subplot(1, len(delays), i+1)
    model_attention = mean_attention[:, i, 0]
    # The rate is constant per model, so it is repeated for every sample
    plt.scatter(model_attention, np.full_like(model_attention, final_dropout_rates[i]), alpha=0.5)
    plt.title(f'Dropout Rate vs Attention Weight\n{secondary_model_types[i].upper()} (Delay {delays[i]})')
    plt.xlabel('Mean Attention Weight')
    plt.ylabel('Dropout Rate')
plt.tight_layout()
plt.show()

# Print final dropout rates (secondary models first, primary model last)
for i, rate in enumerate(final_dropout_rates):
    print(f"Final dropout rate for {'Primary' if i == len(final_dropout_rates) - 1 else 'Secondary'} Model {i+1}: {rate:.4f}")

Epoch [10/60], Loss: 0.0890
Epoch [20/60], Loss: 0.0874
