In [84]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import matplotlib.pyplot as plt
import time
from data_generation import generate_data
from visualizations import plot_tour_graph

class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, action_dim, hidden_dim=128):
        super(PolicyNetwork, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, x):
        return torch.softmax(self.network(x), dim=-1)

class ConstrainedRLAgent:
    def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99, beta=0.1, alpha=0.01):
        self.policy_net = PolicyNetwork(state_dim, action_dim)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.gamma = gamma
        self.beta = beta
        self.alpha = alpha  # Learning rate for Lagrange multiplier
        self.lambda_penalty = 0.0  # Initialize Lagrange multiplier

    def select_action(self, state, mask, epsilon):
        state_t = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        raw_probs = self.policy_net(state_t).squeeze(0)
        mask_t = torch.tensor(mask, dtype=torch.float32)
        masked_probs = raw_probs * mask_t

        if masked_probs.sum().item() <= 1e-12:  # Handle zero-probability case
            available_actions = torch.nonzero(mask_t, as_tuple=False).flatten()
            uniform_probs = torch.zeros_like(mask_t)
            uniform_probs[available_actions] = 1.0 / len(available_actions)
            masked_probs = uniform_probs
        else:
            masked_probs /= masked_probs.sum()

        if np.random.rand() < epsilon:
            available_actions = np.nonzero(mask)[0]
            action = np.random.choice(available_actions)
            log_prob = torch.log(masked_probs[action] + 1e-8)
        else:
            m = Categorical(masked_probs)
            action = m.sample()
            log_prob = m.log_prob(action)

        return action.item(), log_prob

    def compute_loss(self, rewards, log_probs, constraints):
        discounted_rewards = []
        cumulative = 0.0
        for r in reversed(rewards):
            cumulative = r + self.gamma * cumulative
            discounted_rewards.insert(0, cumulative)
        discounted_rewards = torch.tensor(discounted_rewards, dtype=torch.float32)

        log_probs = torch.stack(log_probs)
        constraints = torch.stack(constraints)

        policy_loss = -(log_probs * discounted_rewards).mean()
        constraint_loss = constraints.mean()
        loss = policy_loss + self.lambda_penalty * constraint_loss
        return loss, constraint_loss.item()

    def update_policy(self, rewards, log_probs, constraints, actual_violation):
        loss, constraint_loss = self.compute_loss(rewards, log_probs, constraints)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update Lagrange multiplier dynamically
        self.lambda_penalty = max(0.0, self.lambda_penalty + self.alpha * (actual_violation))

def environment_step(state, action, travel_time, time_windows, visited, penalty_scale=2.0):
    current_city = state['current_city']
    current_time = state['current_time']

    travel_cost = travel_time[current_city, action]
    arrival_time = current_time + travel_cost
    earliest, latest = time_windows[action]

    violation = 0
    early_violation = 0
    late_violation = 0

    if arrival_time < earliest:
        violation = penalty_scale * (earliest - arrival_time)
        early_violation = 1
        arrival_time = earliest
    elif arrival_time > latest:
        violation = penalty_scale * (arrival_time - latest)
        late_violation = 1

    next_state = {
        'current_city': action,
        'current_time': arrival_time
    }
    visited[action] = True
    done = all(visited)

    reward = -(travel_cost + violation)
    return next_state, reward, violation, early_violation, late_violation, done

def compute_route_cost(route, travel_time, time_windows):
    current_time = 0
    total_cost = 0
    for i in range(len(route) - 1):
        current_city = route[i]
        next_city = route[i + 1]
        t_time = travel_time[current_city, next_city]
        current_time += t_time

        earliest, latest = time_windows[next_city]
        penalty = 0
        if current_time < earliest:
            penalty = earliest - current_time
            current_time = earliest
        elif current_time > latest:
            penalty = current_time - latest

        total_cost += t_time + penalty
    return total_cost

def main():
    num_cases = 5
    num_cities = 10

    total_costs = []
    total_penalties = []
    total_early_violations = []
    total_late_violations = []
    solve_times = []
    training_times = []  # New list to track training times

    for test_id in range(num_cases):
        coords, time_windows, travel_time = generate_data(num_cities, seed=test_id)

        state_dim = 2  # current_city, current_time
        action_dim = num_cities

        agent = ConstrainedRLAgent(state_dim, action_dim)
        best_route = None
        best_total_reward = float('-inf')

        # Track training time
        training_start_time = time.time()
        for episode in range(5000):
            visited = [False] * num_cities
            visited[0] = True
            state = {'current_city': 0, 'current_time': 0}
            route = [0]
            rewards, log_probs, violations = [], [], []
            early_violations, late_violations = 0, 0
            done = False

            while not done:
                mask = [0 if v else 1 for v in visited]
                s = np.array([state['current_city'], state['current_time']], dtype=np.float32)

                epsilon = max(0.1, 1.0 - episode / 5000)
                action, log_prob = agent.select_action(s, mask, epsilon)

                next_state, reward, penalty, early, late, done = environment_step(
                    state, action, travel_time, time_windows, visited
                )

                rewards.append(reward)
                log_probs.append(log_prob)
                violations.append(torch.tensor(penalty, dtype=torch.float32))
                early_violations += early
                late_violations += late

                route.append(action)
                state = next_state

            actual_violation = sum(v.item() for v in violations)  # Total violations
            agent.update_policy(rewards, log_probs, violations, actual_violation)
            total_reward = sum(rewards)

            if total_reward > best_total_reward:
                best_route = route
                best_total_reward = total_reward

        training_time = time.time() - training_start_time  # End training time
        training_times.append(training_time)

        # Track solve time
        solve_start_time = time.time()
        if best_route[-1] != 0:
            best_route.append(0)

        total_cost = compute_route_cost(best_route, travel_time, time_windows)
        solve_time = time.time() - solve_start_time  # End solve time

        total_penalty = sum(v.item() for v in violations)

        total_costs.append(total_cost)
        total_penalties.append(total_penalty)
        total_early_violations.append(early_violations)
        total_late_violations.append(late_violations)
        solve_times.append(solve_time)

    print(f"Average Total Cost: {np.mean(total_costs):.2f}")
    print(f"Average Training Time: {np.mean(training_times):.2f}s")  # Output training time
    print(f"Average Solve Time: {np.mean(solve_times):.2f}s")
    print(f"Average Total Penalty: {np.mean(total_penalties):.2f}")
    print(f"Average Early Violations: {np.mean(total_early_violations):.2f}")
    print(f"Average Late Violations: {np.mean(total_late_violations):.2f}")

if __name__ == "__main__":
    main()


Average Total Cost: 1378.63
Average Training Time: 37.29s
Average Solve Time: 0.00s
Average Total Penalty: 2978.61
Average Early Violations: 0.80
Average Late Violations: 8.00
