Train satellite network

In [None]:
# === Imports ===
import os
import random
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import pandas as pd
from collections import defaultdict

from src.dataloader import load_all_data
from src.graph_builder import build_hetero_graph
from src.model import SatGatewayCellGNN
from src.model2 import SatelliteRLHead, DQNAgent  

# === Paths ===
folder_path = r"C:\Users\aruna\Desktop\MS Thesis\Real Data\Final folder real data"
cell_file = r"C:\Users\aruna\Desktop\MS Thesis\Real Data\cells.csv"
gateway_file = r"C:\Users\aruna\Desktop\MS Thesis\Real Data\gateways.csv"

# === Normalize function ===
def normalize(x, eps=1e-6):
    """Z-score ``x`` column-wise (over dim 0).

    Args:
        x: Feature tensor; rows are normalized against each other along dim 0.
            Non-floating tensors are cast to float32 first, because
            ``mean``/``std`` are undefined for integer dtypes.
        eps: Added to the standard deviation to avoid division by zero for
            constant columns.

    Returns:
        ``(x - mean) / (std + eps)`` computed along dim 0. With fewer than two
        rows the sample std is undefined (NaN), so only the centered values
        are returned. For a single row that is all zeros, instead of the NaNs
        the plain formula would produce.
    """
    if not torch.is_floating_point(x):
        x = x.float()
    centered = x - x.mean(dim=0)
    if x.dim() > 0 and x.shape[0] < 2:
        # Unbiased std with N-1 == 0 is NaN; don't let it poison every feature.
        return centered
    return centered / (x.std(dim=0) + eps)

# === Reward Function ===
def calculate_rewards(actions, data, num_gateways, num_cells):
    rewards = []
    gateway_assignments = actions[actions < num_gateways]
    cell_assignments = actions[(actions >= num_gateways) & (actions < num_gateways + num_cells)]

    gateway_coverage = gateway_assignments.unique().numel() / num_gateways
    cell_coverage = cell_assignments.unique().numel() / num_cells

    for action in actions:
        if action < num_gateways:
            rewards.append(1.0 * gateway_coverage)  # reward proportional to gateway coverage
        elif action < num_gateways + num_cells:
            rewards.append(1.5 * cell_coverage)  # higher reward for cell coverage
        else:
            rewards.append(-1.0)  # invalid action
    return torch.tensor(rewards, dtype=torch.float32)

# === Setup ===
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# === Snapshot files ===
# Sorted by filename. NOTE(review): with zero-padded names such as
# file_data_00_00_20.csv this is presumably chronological order — confirm.
max_snapshots = 10  # Use the first 10 snapshots
all_snapshot_files = sorted(f for f in os.listdir(folder_path) if f.endswith('.csv'))
snapshot_files = all_snapshot_files[:max_snapshots]
if not snapshot_files:
    # Fail here with a clear message instead of an IndexError at snapshot_files[0].
    raise FileNotFoundError(f"No .csv snapshot files found in {folder_path!r}")

# === Model State ===
hidden_dim = 64
# NOTE(review): memory_dict and bin_size are not referenced in this cell;
# they are kept in case other cells rely on them.
memory_dict = {}
bin_size = 1.0

# === Initialize model (after first snapshot loaded) ===
# The first snapshot is loaded only to read feature widths and node counts,
# so its features are not normalized here.
first_snapshot = snapshot_files[0]
satellites, gateways, cells = load_all_data(
    folder_path, cell_file, gateway_file, snapshot_filename=first_snapshot
)
data, visibility_matrices = build_hetero_graph(
    satellites, gateways, cells, timestep=0, structured_neighbors=True
)

# Input feature width of each node type, keyed the way SatGatewayCellGNN expects.
input_dims = {
    node_kind: data[node_kind].x.shape[1] for node_kind in ('sat', 'gateway', 'cell')
}

# Action space: one index per gateway followed by one index per cell.
num_gateways, num_cells = data['gateway'].num_nodes, data['cell'].num_nodes
action_dim = num_gateways + num_cells

# === Build GNN + RL Head + DQN Agent ===
gnn_model = SatGatewayCellGNN(hidden_dim, num_gateways, num_cells, input_dims).to(device)
rl_head = SatelliteRLHead(hidden_dim, action_dim).to(device)
agent = DQNAgent(gnn_model, rl_head, action_dim, device)

# === Store for plotting ===
reward_per_snapshot = []
gateway_coverage_list = []
cell_coverage_list = []


# === Training across snapshots ===
for idx, file in enumerate(snapshot_files):
    print(f"\nTraining on snapshot {idx+1}/{len(snapshot_files)}: {file}")

    satellites, gateways, cells = load_all_data(folder_path, cell_file, gateway_file, snapshot_filename=file)
    data, visibility_matrices = build_hetero_graph(
        satellites, gateways, cells, timestep=idx, structured_neighbors=True
    )

    # Column-wise z-score of each node type's features, per snapshot.
    for node_type in ['sat', 'gateway', 'cell']:
        data[node_type].x = normalize(data[node_type].x)

    data = data.to(device)

    # === Check action-space layout ===
    # The GNN and RL head were sized from the first snapshot (num_gateways,
    # num_cells, action_dim). Actions are decoded as gateway indices followed
    # by cell indices, so a change in either count would silently mis-decode
    # every action. Fail loudly instead.
    snap_gateways = data['gateway'].num_nodes
    snap_cells = data['cell'].num_nodes
    if (snap_gateways, snap_cells) != (num_gateways, num_cells):
        raise ValueError(
            f"Snapshot {file} has {snap_gateways} gateways / {snap_cells} cells, "
            f"but the model was built for {num_gateways} / {num_cells} "
            f"(action_dim={action_dim})."
        )

    # === Forward pass ===
    outputs = agent.gnn_model(data, visibility_matrices)
    satellite_embeddings = outputs['sat_memory_out']

    # === Action Selection ===
    actions = agent.select_action(satellite_embeddings)

    # === Calculate Coverage ===
    # Only non-negative, in-range indices count; guard against empty node sets.
    gateway_assignments = actions[(actions >= 0) & (actions < num_gateways)]
    cell_assignments = actions[(actions >= num_gateways) & (actions < num_gateways + num_cells)]

    # fraction of gateways connected / fraction of cells served
    gateway_coverage = gateway_assignments.unique().numel() / num_gateways if num_gateways else 0.0
    cell_coverage = cell_assignments.unique().numel() / num_cells if num_cells else 0.0

    gateway_coverage_list.append(gateway_coverage)
    cell_coverage_list.append(cell_coverage)

    # === Calculate Rewards ===
    rewards = calculate_rewards(actions, data, num_gateways, num_cells)

    # === Save to Replay Buffer ===
    # NOTE(review): the same embeddings are pushed as both state and next
    # state; confirm this matches what DQNAgent.optimize_model expects.
    agent.replay_buffer.push(
        satellite_embeddings.detach(), actions, rewards, satellite_embeddings.detach()
    )

    # === Optimization Step ===
    agent.optimize_model()

    # === Logging ===
    avg_reward = rewards.mean().item()
    reward_per_snapshot.append(avg_reward)
    print(f"Snapshot {idx+1} - Avg Reward: {avg_reward:.2f}")
    print(f"Gateway Coverage: {gateway_coverage:.2%}, Cell Coverage: {cell_coverage:.2%}")

# === Save final model ===
# Checkpoint filename -> module whose state_dict is written there (saved in this order).
_checkpoints = {
    "trained_gnn_model_dqn.pt": agent.gnn_model,
    "trained_rl_head_dqn.pt": agent.rl_head,
}
for _ckpt_path, _module in _checkpoints.items():
    torch.save(_module.state_dict(), _ckpt_path)
del _checkpoints, _ckpt_path, _module  # keep the notebook namespace unchanged
print("\n Models saved: 'trained_gnn_model_dqn.pt' and 'trained_rl_head_dqn.pt'")

# === Plot Rewards and Coverage Across Snapshots ===
# x-axis is 1-based to match the "Snapshot k" log lines printed during training.
fig, (ax_reward, ax_cov) = plt.subplots(1, 2, figsize=(12, 5))

ax_reward.plot(range(1, len(reward_per_snapshot) + 1), reward_per_snapshot, marker='o')
ax_reward.set_xlabel('Snapshot Index')
ax_reward.set_ylabel('Average Reward')
ax_reward.set_title('Average Reward per Snapshot')
ax_reward.grid(True)

# Coverage fractions were collected every snapshot but previously never shown.
ax_cov.plot(range(1, len(gateway_coverage_list) + 1), gateway_coverage_list,
            marker='o', label='Gateway coverage')
ax_cov.plot(range(1, len(cell_coverage_list) + 1), cell_coverage_list,
            marker='s', label='Cell coverage')
ax_cov.set_xlabel('Snapshot Index')
ax_cov.set_ylabel('Fraction covered')
ax_cov.set_title('Coverage per Snapshot')
ax_cov.grid(True)
ax_cov.legend()

fig.tight_layout()
plt.show()




Training on snapshot 1/10: file_data_00_00_00.csv
Snapshot 1 - Avg Reward: 0.01
Gateway Coverage: 0.00%, Cell Coverage: 0.35%

Training on snapshot 2/10: file_data_00_00_20.csv
