## Coding Graph

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import defaultdict
import random

class Node:
    def __init__(self, node_id, features, timestamp):
        self.id = node_id
        self.features = np.array(features, dtype=np.float32)
        self.timestamp = timestamp

class Edge:
    def __init__(self, source_id, target_id, features, timestamp):
        self.source_id = source_id
        self.target_id = target_id
        self.features = np.array(features, dtype=np.float32)
        self.timestamp = timestamp

class Event:
    def __init__(self, timestamp, event_type, obj_id, old_features, new_features):
        self.timestamp = timestamp
        self.event_type = event_type
        self.obj_id = obj_id
        self.old_features = old_features
        self.new_features = new_features

class TemporalGraph:
    def __init__(self):
        self.nodes = {}
        self.edges = {}
        self.event_stream = []
        self.current_time = 0

    def add_node(self, node_id, features, timestamp):
        if node_id in self.nodes:
            return False
        node = Node(node_id, features, timestamp)
        self.nodes[node_id] = node
        event = Event(timestamp, 'NODE_JOIN', node_id, None, features)
        self.event_stream.append(event)
        return True

    def add_edge(self, source_id, target_id, features, timestamp):
        if source_id not in self.nodes or target_id not in self.nodes:
            return False
        edge_key = (source_id, target_id)
        edge = Edge(source_id, target_id, features, timestamp)
        self.edges[edge_key] = edge
        self.edges[(target_id, source_id)] = edge
        event = Event(timestamp, 'EDGE_CREATED', edge_key, None, features)
        self.event_stream.append(event)
        return True

    def update_node_features(self, node_id, new_features, timestamp):
        if node_id not in self.nodes:
            return False
        old_features = self.nodes[node_id].features.copy()
        self.nodes[node_id].features = np.array(new_features, dtype=np.float32)
        self.nodes[node_id].timestamp = timestamp
        event = Event(timestamp, 'NODE_UPDATE', node_id, old_features, new_features)
        self.event_stream.append(event)
        return True

    def update_edge_features(self, source_id, target_id, new_features, timestamp):
        edge_key = (source_id, target_id)
        if edge_key not in self.edges:
            return False
        old_features = self.edges[edge_key].features.copy()
        self.edges[edge_key].features = np.array(new_features, dtype=np.float32)
        self.edges[edge_key].timestamp = timestamp
        self.edges[(target_id, source_id)].features = np.array(new_features, dtype=np.float32)
        event = Event(timestamp, 'EDGE_UPDATE', edge_key, old_features, new_features)
        self.event_stream.append(event)
        return True

    def remove_edge(self, source_id, target_id, timestamp):
        edge_key = (source_id, target_id)
        if edge_key not in self.edges:
            return False
        old_features = self.edges[edge_key].features.copy()
        del self.edges[edge_key]
        del self.edges[(target_id, source_id)]
        event = Event(timestamp, 'EDGE_DELETED', edge_key, old_features, None)
        self.event_stream.append(event)
        return True

    def get_neighbors(self, node_id):
        neighbors = []
        for (src, tgt) in self.edges.keys():
            if src == node_id:
                neighbors.append(tgt)
        return list(set(neighbors))

class TimeEncoder(nn.Module):
    def __init__(self, dimension):
        super(TimeEncoder, self).__init__()
        self.dimension = dimension
        self.w = nn.Linear(1, dimension)

    def forward(self, t):
        t = t.unsqueeze(dim=1)
        output = torch.cos(self.w(t))
        return output

class MessageFunction(nn.Module):
    def __init__(self, node_feat_dim, edge_feat_dim, memory_dim, time_dim, message_dim):
        super(MessageFunction, self).__init__()
        input_dim = 2 * node_feat_dim + edge_feat_dim + 2 * memory_dim + time_dim
        self.fc1 = nn.Linear(input_dim, message_dim)
        self.fc2 = nn.Linear(message_dim, message_dim)
        self.relu = nn.ReLU()

    def forward(self, source_feat, target_feat, edge_feat, source_mem, target_mem, time_enc):
        x = torch.cat([source_feat, target_feat, edge_feat, source_mem, target_mem, time_enc], dim=-1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class MemoryUpdater(nn.Module):
    def __init__(self, memory_dim, message_dim):
        super(MemoryUpdater, self).__init__()
        self.memory_dim = memory_dim
        self.gru = nn.GRUCell(message_dim, memory_dim)

    def forward(self, memory, message):
        new_memory = self.gru(message, memory)
        return new_memory

class Predictor(nn.Module):
    def __init__(self, memory_dim, edge_feat_dim, time_dim):
        super(Predictor, self).__init__()
        input_dim = 2 * memory_dim + edge_feat_dim + time_dim
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, source_mem, target_mem, edge_feat, time_enc):
        x = torch.cat([source_mem, target_mem, edge_feat, time_enc], dim=-1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.sigmoid(self.fc3(x))
        return x

class TGN(nn.Module):
    def __init__(self, num_nodes, node_feat_dim, edge_feat_dim, memory_dim, time_dim, message_dim):
        super(TGN, self).__init__()
        self.num_nodes = num_nodes
        self.memory_dim = memory_dim
        self.memory = torch.zeros(num_nodes, memory_dim)

        self.time_encoder = TimeEncoder(time_dim)
        self.message_function = MessageFunction(node_feat_dim, edge_feat_dim, memory_dim, time_dim, message_dim)
        self.memory_updater = MemoryUpdater(memory_dim, message_dim)
        self.predictor = Predictor(memory_dim, edge_feat_dim, time_dim)

    def reset_memory(self):
        self.memory = torch.zeros(self.num_nodes, self.memory_dim)

    def forward(self, source_ids, target_ids, source_feats, target_feats, edge_feats, timestamps):
        time_enc = self.time_encoder(timestamps)

        source_mems = self.memory[source_ids]
        target_mems = self.memory[target_ids]

        messages_source = self.message_function(source_feats, target_feats, edge_feats,
                                                source_mems, target_mems, time_enc)
        messages_target = self.message_function(target_feats, source_feats, edge_feats,
                                                target_mems, source_mems, time_enc)

        new_source_mems = self.memory_updater(source_mems, messages_source)
        new_target_mems = self.memory_updater(target_mems, messages_target)

        self.memory[source_ids] = new_source_mems.detach()
        self.memory[target_ids] = new_target_mems.detach()

        predictions = self.predictor(new_source_mems, new_target_mems, edge_feats, time_enc)

        return predictions

def generate_disaster_dataset(duration, num_nodes):
    graph = TemporalGraph()

    for i in range(num_nodes):
        features = [
            random.uniform(50, 100),
            random.choice([0, 1, 2, 3]),
            random.uniform(0, 500),
            random.uniform(0, 500),
            0
        ]
        graph.add_node(i, features, timestamp=0)

    for i in range(num_nodes):
        for j in range(i+1, num_nodes):
            if random.random() < 0.3:
                features = [
                    random.uniform(-80, -50),
                    random.uniform(0, 0.1),
                    random.uniform(10, 50)
                ]
                graph.add_edge(i, j, features, timestamp=0)

    for t in range(1, duration):
        for node_id in list(graph.nodes.keys()):
            node = graph.nodes[node_id]
            new_battery = max(0, node.features[0] - random.uniform(0, 2))
            new_features = node.features.copy()
            new_features[0] = new_battery
            if new_battery < 20:
                new_features[4] = 1
            graph.update_node_features(node_id, new_features, timestamp=t)

        for edge_key in list(graph.edges.keys()):
            if edge_key[0] < edge_key[1]:
                edge = graph.edges[edge_key]
                new_signal = edge.features[0] - random.uniform(0, 3)
                new_features = edge.features.copy()
                new_features[0] = new_signal
                graph.update_edge_features(edge_key[0], edge_key[1], new_features, timestamp=t)

                if new_signal < -90:
                    graph.remove_edge(edge_key[0], edge_key[1], timestamp=t)

        if random.random() < 0.05:
            active_nodes = [nid for nid, n in graph.nodes.items() if n.features[4] == 0]
            if active_nodes:
                failed_node = random.choice(active_nodes)
                new_features = graph.nodes[failed_node].features.copy()
                new_features[4] = 2
                graph.update_node_features(failed_node, new_features, timestamp=t)

    return graph

def prepare_training_data(graph, prediction_horizon=10):
    data = []

    for i, event in enumerate(graph.event_stream):
        if event.event_type in ['EDGE_UPDATE', 'EDGE_CREATED']:
            source_id, target_id = event.obj_id

            if source_id not in graph.nodes or target_id not in graph.nodes:
                continue

            source_node = graph.nodes[source_id]
            target_node = graph.nodes[target_id]

            future_time = event.timestamp + prediction_horizon
            label = 0

            for future_event in graph.event_stream[i:]:
                if future_event.timestamp > future_time:
                    break
                if future_event.event_type == 'EDGE_DELETED' and future_event.obj_id == event.obj_id:
                    label = 1
                    break
                if future_event.event_type == 'EDGE_UPDATE' and future_event.obj_id == event.obj_id:
                    if future_event.new_features[0] < -90:
                        label = 1
                        break

            data.append({
                'source_id': source_id,
                'target_id': target_id,
                'source_features': source_node.features,
                'target_features': target_node.features,
                'edge_features': event.new_features,
                'timestamp': event.timestamp,
                'label': label
            })

    return data

def train_tgn(graph, num_epochs=50, learning_rate=0.001, batch_size=32):
    num_nodes = len(graph.nodes)
    node_feat_dim = 5
    edge_feat_dim = 3
    memory_dim = 64
    time_dim = 32
    message_dim = 32

    model = TGN(num_nodes, node_feat_dim, edge_feat_dim, memory_dim, time_dim, message_dim)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.BCELoss()

    training_data = prepare_training_data(graph, prediction_horizon=10)

    for epoch in range(num_epochs):
        model.reset_memory()
        random.shuffle(training_data)

        total_loss = 0
        num_batches = 0

        for i in range(0, len(training_data), batch_size):
            batch = training_data[i:i+batch_size]

            source_ids = torch.tensor([d['source_id'] for d in batch], dtype=torch.long)
            target_ids = torch.tensor([d['target_id'] for d in batch], dtype=torch.long)
            source_feats = torch.tensor(np.array([d['source_features'] for d in batch]), dtype=torch.float32)
            target_feats = torch.tensor(np.array([d['target_features'] for d in batch]), dtype=torch.float32)
            edge_feats = torch.tensor(np.array([d['edge_features'] for d in batch]), dtype=torch.float32)
            timestamps = torch.tensor([d['timestamp'] for d in batch], dtype=torch.float32)
            labels = torch.tensor([[d['label']] for d in batch], dtype=torch.float32)

            optimizer.zero_grad()
            predictions = model(source_ids, target_ids, source_feats, target_feats, edge_feats, timestamps)
            loss = criterion(predictions, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

    return model

def predict_link_failures(model, graph, current_time, threshold=0.7):
    at_risk_links = []

    model.eval()
    with torch.no_grad():
        for edge_key in graph.edges.keys():
            if edge_key[0] < edge_key[1]:
                source_id, target_id = edge_key
                edge = graph.edges[edge_key]
                source_node = graph.nodes[source_id]
                target_node = graph.nodes[target_id]

                source_ids = torch.tensor([source_id], dtype=torch.long)
                target_ids = torch.tensor([target_id], dtype=torch.long)
                source_feats = torch.tensor([source_node.features], dtype=torch.float32)
                target_feats = torch.tensor([target_node.features], dtype=torch.float32)
                edge_feats = torch.tensor([edge.features], dtype=torch.float32)
                timestamps = torch.tensor([current_time], dtype=torch.float32)

                prediction = model(source_ids, target_ids, source_feats, target_feats, edge_feats, timestamps)

                if prediction.item() > threshold:
                    at_risk_links.append({
                        'edge': edge_key,
                        'probability': prediction.item(),
                        'source_battery': source_node.features[0],
                        'target_battery': target_node.features[0],
                        'signal_strength': edge.features[0]
                    })

    at_risk_links.sort(key=lambda x: x['probability'], reverse=True)
    return at_risk_links

def find_alternative_path(graph, source, target, avoid_edge):
    distances = {node_id: float('inf') for node_id in graph.nodes.keys()}
    distances[source] = 0
    previous = {node_id: None for node_id in graph.nodes.keys()}
    unvisited = set(graph.nodes.keys())

    while unvisited:
        current = min(unvisited, key=lambda node: distances[node])

        if distances[current] == float('inf'):
            break

        if current == target:
            break

        unvisited.remove(current)

        for neighbor in graph.get_neighbors(current):
            if neighbor not in graph.nodes:
                continue

            edge_key = (current, neighbor)
            if edge_key == avoid_edge or (neighbor, current) == avoid_edge:
                continue

            if edge_key not in graph.edges:
                continue

            edge = graph.edges[edge_key]
            weight = 1.0 / (abs(edge.features[0]) + 1)

            alt_distance = distances[current] + weight

            if alt_distance < distances[neighbor]:
                distances[neighbor] = alt_distance
                previous[neighbor] = current

    if distances[target] == float('inf'):
        return None

    path = []
    current = target
    while current is not None:
        path.append(current)
        current = previous[current]
    path.reverse()

    return path

def execute_self_healing(graph, at_risk_links, current_time):
    actions = []

    for risky_link in at_risk_links:
        source, target = risky_link['edge']

        alt_path = find_alternative_path(graph, source, target, risky_link['edge'])

        if alt_path and len(alt_path) > 2:
            action = {
                'type': 'REROUTE',
                'from_edge': (source, target),
                'to_path': alt_path,
                'timestamp': current_time,
                'risk': risky_link['probability']
            }
            actions.append(action)
        else:
            action = {
                'type': 'INCREASE_POWER',
                'edge': (source, target),
                'timestamp': current_time,
                'risk': risky_link['probability']
            }
            actions.append(action)

    return actions

if __name__ == "__main__":
    print("Generating disaster network dataset...")
    graph = generate_disaster_dataset(duration=100, num_nodes=10)
    print(f"Generated {len(graph.event_stream)} events")
    print(f"Current nodes: {len(graph.nodes)}")
    print(f"Current edges: {len(graph.edges)//2}")

    print("\nTraining TGN model...")
    model = train_tgn(graph, num_epochs=20, learning_rate=0.001, batch_size=16)

    print("\nPredicting link failures...")
    current_time = 95
    at_risk = predict_link_failures(model, graph, current_time, threshold=0.6)

    print(f"\nFound {len(at_risk)} at-risk links:")
    for link in at_risk[:5]:
        print(f"Edge {link['edge']}: {link['probability']:.3f} failure probability")
        print(f"  Signal: {link['signal_strength']:.1f} dBm")
        print(f"  Battery: {link['source_battery']:.1f}% / {link['target_battery']:.1f}%")

    print("\nExecuting self-healing actions...")
    actions = execute_self_healing(graph, at_risk[:3], current_time)

    for action in actions:
        print(f"\nAction: {action['type']}")
        if action['type'] == 'REROUTE':
            print(f"  From: {action['from_edge']}")
            print(f"  Path: {' -> '.join(map(str, action['to_path']))}")
        else:
            print(f"  Edge: {action['edge']}")
        print(f"  Risk: {action['risk']:.3f}")

Generating disaster network dataset...
Generated 1283 events
Current nodes: 10
Current edges: 0

Training TGN model...
Epoch 1/20, Loss: 0.6955
Epoch 2/20, Loss: 0.5652
Epoch 3/20, Loss: 0.4541
Epoch 4/20, Loss: 0.4172
Epoch 5/20, Loss: 0.3263
Epoch 6/20, Loss: 0.2790
Epoch 7/20, Loss: 0.2059
Epoch 8/20, Loss: 0.1854
Epoch 9/20, Loss: 0.1356
Epoch 10/20, Loss: 0.1335
Epoch 11/20, Loss: 0.1183
Epoch 12/20, Loss: 0.1174
Epoch 13/20, Loss: 0.1333
Epoch 14/20, Loss: 0.1105
Epoch 15/20, Loss: 0.0994
Epoch 16/20, Loss: 0.0967
Epoch 17/20, Loss: 0.0779
Epoch 18/20, Loss: 0.1211
Epoch 19/20, Loss: 0.1219
Epoch 20/20, Loss: 0.0958

Predicting link failures...

Found 0 at-risk links:

Executing self-healing actions...
