In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx
import json
import numpy as np
import os
import random


class GCNLayer(nn.Module):
    """One graph-convolution layer: normalise, aggregate, then project.

    The adjacency matrix is scaled on both sides by the inverse square root
    of its row sums, multiplied with the node features, and the result is
    passed through a learned linear map.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, node_features, adj_matrix):
        """Return ``linear(norm_adj @ node_features)``.

        Args:
            node_features: feature tensor, one row per node.
            adj_matrix: square adjacency tensor over the same nodes.
        """
        # Row sums serve as node degrees; the epsilon keeps the power finite
        # for nodes whose row is entirely zero.
        row_degree = adj_matrix.sum(dim=1)
        scale = (row_degree + 1e-6).pow(-0.5)

        # Entry (i, j) of the adjacency is scaled by scale[i] * scale[j].
        scaled_adj = (adj_matrix * scale.unsqueeze(1)) * scale.unsqueeze(0)

        # Aggregate neighbour features, then apply the linear projection.
        return self.linear(scaled_adj @ node_features)



class MyGCN(torch.nn.Module):
    """Two stacked ``GCNLayer`` modules with ReLU and dropout in between."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.gcn1 = GCNLayer(input_dim, hidden_dim)
        self.gcn2 = GCNLayer(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, adj_matrix, node_features):
        """Compute node embeddings.

        Note the argument order here is ``(adj_matrix, node_features)``,
        the reverse of ``GCNLayer.forward``.
        """
        # Layer 1 -> ReLU -> dropout
        hidden = self.gcn1(node_features, adj_matrix)
        hidden = self.dropout(F.relu(hidden))
        # Layer 2 produces the output embeddings (no activation applied)
        return self.gcn2(hidden, adj_matrix)


class MyHeuristic:
    """Dispatch rule that selects the available job with the highest priority."""

    def __init__(self, params):
        self.params = params

    def __call__(self, gnn_embeddings, raw_features, available_jobs):
        """Pick the next job to schedule.

        Args:
            gnn_embeddings: accepted for interface compatibility; not used.
            raw_features: mapping whose ``"jobs"`` entry maps ``"job_<id>"``
                to a feature dict containing ``"priority"``.
            available_jobs: iterable of job ids that may be scheduled.

        Returns:
            The ``"job_<id>"`` key with the largest priority (the first one
            encountered wins ties), or ``None`` if no job is available.
        """
        scores = {
            f"job_{job_id}": raw_features["jobs"][f"job_{job_id}"]["priority"]
            for job_id in available_jobs
        }
        if not scores:
            return None
        return max(scores, key=scores.get)


class FMSDatasetGenerator:
    """Generates datasets for an FMS scheduling problem.

    Constructing an instance immediately draws random machines, jobs (with
    their operations) and the shared-resource assignment. The results are
    exposed as ``self.machines``, ``self.jobs`` and ``self.resources`` and can
    be written to JSON with :meth:`save_dataset`.
    """

    def __init__(self, num_jobs, num_machines, num_ops_per_job_range, processing_time_range,
                 failure_rate, setup_time_range, resource_capacity, load_factor,
                 dynamic_arrival=True, fixed_routing=False):
        """
        Initializes the generator.
        Args:
            num_jobs (int or tuple): The number of jobs, or an inclusive
                (min, max) range from which the number of jobs is drawn.
            num_machines (int): The number of machines.
            num_ops_per_job_range (tuple): Inclusive (min, max) number of
                operations a job can have.
            processing_time_range (tuple): Processing time range (min, max).
            failure_rate (float): Probability of machine failure (per
                simulation step). Stored only; the generator does not use it.
            setup_time_range (tuple): Range for setup times (min, max).
            resource_capacity (int): Number of resources shared by machines.
            load_factor (float): Fraction of machines that need shared
                resources.
            dynamic_arrival (bool): If True each job gets a random arrival
                time in [0, 50]; otherwise every job arrives at time 0.
            fixed_routing (bool): If True each job visits distinct machines.
        """
        self.num_jobs = num_jobs if isinstance(num_jobs, int) else random.randint(*num_jobs)
        self.num_machines = num_machines
        self.num_ops_per_job_range = num_ops_per_job_range
        self.processing_time_range = processing_time_range
        self.failure_rate = failure_rate
        self.setup_time_range = setup_time_range
        self.resource_capacity = resource_capacity
        self.load_factor = load_factor
        self.dynamic_arrival = dynamic_arrival
        self.fixed_routing = fixed_routing
        # Order matters: jobs reference machine ids and the resource
        # assignment samples from the machine list.
        self.machines = self._generate_machines()
        self.jobs = self._generate_jobs()
        self.resources = self._generate_resources()

    def _generate_resources(self):
        """Pick the machines that share the limited resource pool."""
        num_sharing = int(len(self.machines) * self.load_factor)
        return {
            "resource_capacity": self.resource_capacity,
            "machines": random.sample(self.machines, num_sharing),
        }

    def _generate_machines(self):
        """Generates machine data (all machines start idle at time 0)."""
        return [
            {
                "machine_id": machine_id,
                "status": "available",  # initially available
                "current_job": None,
                "next_available_time": 0,
                "type": random.choice(["general", "special"]),  # example machine types
            }
            for machine_id in range(self.num_machines)
        ]

    def _generate_jobs(self):
        """Generates job data, including each job's operation list."""
        jobs = []
        for job_id in range(self.num_jobs):
            num_ops = random.randint(*self.num_ops_per_job_range)
            ops = self._generate_job_ops(job_id, num_ops)
            jobs.append({
                "job_id": job_id,
                "operations": ops,
                "arrival_time": random.randint(0, 50) if self.dynamic_arrival else 0,
                "completed_operations": 0,
                "priority": random.randint(1, 5),
                "machine_seq": [op["machine"] for op in ops],
                "current_operation": 0,
            })
        return jobs

    def _generate_job_ops(self, job_id, num_ops):
        """Generates operations for a single job.

        With ``fixed_routing`` the machines are sampled without replacement,
        so ``random.sample`` raises ``ValueError`` when ``num_ops`` exceeds
        the number of machines. Otherwise each operation picks a machine
        independently.
        """
        available_machines = self._get_available_machines()
        machine_seq = random.sample(available_machines, num_ops) if self.fixed_routing else None

        ops = []
        for op_id in range(num_ops):
            if machine_seq is not None:
                machine = machine_seq[op_id]
            else:
                machine = random.choice(available_machines)
            proc_time = self._triangular_distribution(*self.processing_time_range)
            # The first operation of a job carries no setup time.
            setup_time = random.uniform(*self.setup_time_range) if op_id > 0 else 0.0
            ops.append({
                "job_id": job_id,
                "op_id": op_id,
                "machine": machine,
                "processing_time": proc_time,
                "setup_time": setup_time,
                "status": "pending",  # initially pending
            })
        return ops

    def _triangular_distribution(self, min, max):
        """Sample a symmetric triangular distribution on [min, max].

        Uses inverse-transform sampling with the mode at the midpoint. A
        zero-width range (``min == max``) returns ``float(min)`` instead of
        dividing by zero.
        """
        # NOTE: the parameter names shadow the builtins; they are kept so the
        # method's signature stays unchanged.
        if max == min:
            return float(min)
        span = max - min
        mode = (min + max) / 2
        u = random.uniform(0, 1)
        if u <= (mode - min) / span:
            return min + np.sqrt(u * span * (mode - min))
        return max - np.sqrt((1 - u) * span * (max - mode))

    def _get_available_machines(self):
        """Return the ids of all generated machines."""
        return [m["machine_id"] for m in self.machines]

    def save_dataset(self, filename, save_dir):
        """Saves dataset to a JSON file (``save_dir`` is created if missing)."""
        os.makedirs(save_dir, exist_ok=True)
        dataset = {
            "jobs": self.jobs,
            "machines": self.machines,
            "resources": self.resources,
        }

        with open(os.path.join(save_dir, filename), 'w') as f:
            json.dump(dataset, f, indent=4)

    @classmethod
    def load_dataset(cls, filename):
        """Loads dataset from a JSON file.

        Returns the parsed dict (keys as written by :meth:`save_dataset`),
        not a generator instance.
        """
        with open(filename, 'r') as f:
            return json.load(f)

def create_graph_and_features(dataset, current_time = 0):
    """
    Creates graph data and feature vectors from current dataset status.

    Args:
        dataset: Either the dict loaded from JSON (keys ``"jobs"`` and
            ``"machines"``) or an object exposing ``.jobs`` and ``.machines``.
        current_time (int): current time of the simulation
    Returns:
        tuple: (graph (nx.DiGraph), heuristic_data (dict)) where
        ``heuristic_data`` has the sub-dicts ``"jobs"``, ``"machines"`` and
        ``"operations"`` keyed by node name.
    """
    def _field(name):
        # Accept both the raw JSON dict and an attribute-style container.
        return dataset[name] if isinstance(dataset, dict) else getattr(dataset, name)

    jobs = _field("jobs")
    machines = _field("machines")
    graph = nx.DiGraph()
    heuristic_data = {"jobs": {}, "machines": {}, "operations": {}}

    # 1. Node generation (jobs, then machines, then operations; this order
    #    fixes the node order seen by downstream consumers).
    for job in jobs:
        job_key = f"job_{job['job_id']}"
        graph.add_node(job_key, type="job", **job)
        heuristic_data["jobs"][job_key] = _extract_job_features(job, current_time)

    for machine in machines:
        machine_key = f"machine_{machine['machine_id']}"
        # "type" would clash with the node-kind attribute, so it is dropped
        # together with "current_job".
        machine_node_data = {k: v for k, v in machine.items() if k not in ["type", "current_job"]}
        graph.add_node(machine_key, type="machine", **machine_node_data)
        heuristic_data["machines"][machine_key] = _extract_machine_features(machine, current_time)

    op_keys_per_job = []   # operation node names of each job, in sequence order
    op_machine_pairs = []  # (operation node name, machine id)
    ops_by_machine = {}    # machine id -> operation node names, in job order
    for job in jobs:
        job_op_keys = []
        for op in job["operations"]:
            op_key = f"op_{job['job_id']}_{op['op_id']}"
            graph.add_node(op_key, type="operation", **op)
            heuristic_data["operations"][op_key] = _extract_op_features(op, job, current_time)
            job_op_keys.append(op_key)
            op_machine_pairs.append((op_key, op["machine"]))
            ops_by_machine.setdefault(op["machine"], []).append(op_key)
        op_keys_per_job.append(job_op_keys)

    # 2. Edge generation
    # Precedence edges between consecutive operations of a job
    for job_op_keys in op_keys_per_job:
        for prev_key, next_key in zip(job_op_keys, job_op_keys[1:]):
            graph.add_edge(prev_key, next_key, type="op_seq")

    # Edges from each operation to the machine it runs on
    for op_key, machine_id in op_machine_pairs:
        graph.add_edge(op_key, f"machine_{machine_id}", type="machine_link")

    # Disjunctive edges between every pair of operations sharing a machine.
    # NOTE(review): add_edge on an existing pair updates its attributes, so
    # consecutive operations of one job on the same machine end up labelled
    # "disjunctive_link" rather than "op_seq".
    for machine in machines:
        sharing = ops_by_machine.get(machine["machine_id"], [])
        for idx, first_key in enumerate(sharing):
            for second_key in sharing[idx + 1:]:
                graph.add_edge(first_key, second_key, type="disjunctive_link")

    return graph, heuristic_data


def _extract_job_features(job, current_time):
    """
    Build the numeric feature dict for one job.

    Args:
        job (dict): job record with ``job_id``, ``priority``,
            ``arrival_time``, ``operations`` and ``completed_operations``.
        current_time (int): current simulation time
    Returns:
        dict: float-valued features, in the order job_id, priority,
        time_since_arrival, remaining_operations, completed_operations.
    """
    waited = current_time - job["arrival_time"]
    still_open = len(job["operations"]) - job["completed_operations"]
    raw_values = {
        "job_id": job["job_id"],
        "priority": job["priority"],
        "time_since_arrival": waited,
        "remaining_operations": still_open,
        "completed_operations": job["completed_operations"],
    }
    # Every feature is stored as a float.
    return {name: float(value) for name, value in raw_values.items()}

def _extract_machine_features(machine, current_time):
    """Build the numeric feature dict for one machine.

    ``next_available_time`` is reported relative to ``current_time``, so it
    is zero or negative once the machine's recorded time has passed.
    """
    machine_id = float(machine["machine_id"])
    time_until_free = float(machine["next_available_time"] - current_time)
    return {"machine_id": machine_id, "next_available_time": time_until_free}

def _extract_op_features(op, job, current_time):
    """Build the numeric feature dict for one operation.

    The job id is taken from ``job``; the remaining fields come from ``op``.
    ``current_time`` is accepted for a signature parallel to the other
    extractors but is not used.
    """
    features = {"job_id": float(job["job_id"])}
    for field in ("op_id", "machine", "processing_time", "setup_time"):
        features[field] = float(op[field])
    return features
def prepare_gnn_input(graph, feature_data):
    """Prepares data for GNN with consistent feature dimensions.

    Every node's feature values are zero-padded on the right to a common
    width so that all node types fit into one tensor.

    Args:
        graph: networkx graph whose nodes carry a ``"type"`` attribute of
            ``"job"``, ``"machine"`` or ``"operation"`` plus the id fields
            used to look the node up in ``feature_data``.
        feature_data (dict): feature dicts under ``"jobs"``, ``"machines"``
            and ``"operations"``, keyed by node name.
    Returns:
        tuple: (node_features_tensor, adj_matrix_tensor), both float32, with
        rows in the graph's node iteration order.
    Raises:
        ValueError: if a node has a missing or unsupported ``"type"``. Such
            nodes previously reused the preceding node's features silently.
    """
    # Feature counts per node type
    job_feature_size = 5  # job_id, priority, time_since_arrival, remaining_ops, completed_ops
    machine_feature_size = 2  # machine_id, next_available_time
    operation_feature_size = 5  # job_id, op_id, machine, processing_time, setup_time

    # All rows are padded to the widest feature vector.
    max_feature_size = max(job_feature_size, machine_feature_size, operation_feature_size)

    node_features = []
    node_ids = []

    for node_id, node_data in graph.nodes(data=True):
        node_type = node_data.get("type")
        if node_type == "job":
            source = feature_data["jobs"][f"job_{node_data['job_id']}"]
        elif node_type == "machine":
            source = feature_data["machines"][f"machine_{node_data['machine_id']}"]
        elif node_type == "operation":
            source = feature_data["operations"][f"op_{node_data['job_id']}_{node_data['op_id']}"]
        else:
            # E.g. a node created implicitly by an edge to an unknown machine.
            raise ValueError(f"Node {node_id!r} has unsupported type {node_type!r}")

        features = list(source.values())
        node_ids.append(node_id)
        # Pad with zeros to reach max_feature_size
        node_features.append(features + [0.0] * (max_feature_size - len(features)))

    # Convert to tensors; nodelist keeps adjacency rows aligned with features.
    node_features_tensor = torch.tensor(node_features, dtype=torch.float32)
    adj_matrix = nx.to_numpy_array(graph, nodelist=node_ids)
    adj_matrix_tensor = torch.tensor(adj_matrix, dtype=torch.float32)

    return node_features_tensor, adj_matrix_tensor

def run_scheduling_step(dataset, current_time, gnn_model, heuristic_algo):
    """Run one decision step: build the graph, embed it, and ask the heuristic.

    Args:
        dataset: object exposing ``.jobs`` and ``.machines``.
        current_time: current simulation time.
        gnn_model: callable invoked as ``gnn_model(adj_matrix, node_features)``.
        heuristic_algo: callable invoked as
            ``heuristic_algo(gnn_embeddings, feature_data, available_jobs)``.
    Returns:
        tuple: (action, dataset). ``action`` is whatever the heuristic
        returned; ``dataset`` is passed through unmodified, so applying the
        action is left to the caller.
    """
    # Step 1: Create graph and features
    graph, feature_data = create_graph_and_features(dataset, current_time)

    # A job is schedulable only if it has operations left AND has already
    # arrived; jobs without an arrival time are treated as arrived at 0.
    available_jobs = [
        job["job_id"]
        for job in dataset.jobs
        if job["current_operation"] < len(job["operations"])
        and job.get("arrival_time", 0) <= current_time
    ]

    # Step 2: Convert to tensors
    node_features, adj_matrix = prepare_gnn_input(graph, feature_data)
    # Step 3: GNN forward pass
    gnn_embeddings = gnn_model(adj_matrix, node_features)
    # Step 4: Heuristic Calculation
    action = heuristic_algo(gnn_embeddings, feature_data, available_jobs)

    return action, dataset

# --- 4. Main Loop ---
def _release_finished_machines(machines, current_time):
    """Mark every machine whose busy period has elapsed as available."""
    for machine in machines:
        if machine['next_available_time'] <= current_time:
            machine['status'] = 'available'
            machine['current_job'] = None


def update_system_state(dataset, selected_job, current_time):
    """Apply a scheduling decision to the system state (in place).

    Machines that have finished are released first, so a machine that
    becomes free at ``current_time`` can start a new operation in the same
    tick. The selected job's next operation is then started if its machine
    is available and the job is not still being processed elsewhere.

    Args:
        dataset: object exposing ``.jobs`` and ``.machines``.
        selected_job: node name of the form ``"job_<id>"``, or ``None`` to
            only refresh machine states.
        current_time: current simulation time.
    Returns:
        The same ``dataset`` object.
    """
    machines = dataset.machines
    _release_finished_machines(machines, current_time)

    if selected_job is None:
        return dataset

    job_id = int(selected_job.split('_')[1])
    job = next((j for j in dataset.jobs if j['job_id'] == job_id), None)

    # Unknown job, or a job with no operations left: nothing to start.
    if job is None or job['current_operation'] >= len(job['operations']):
        return dataset

    # Precedence: the next operation cannot start while the job's previous
    # operation still occupies a machine.
    if any(m['status'] == 'busy' and m.get('current_job') == job_id for m in machines):
        return dataset

    current_op = job['operations'][job['current_operation']]
    machine_id = current_op['machine']
    machine = next(
        (m for m in machines
         if m['machine_id'] == machine_id and m['status'] == 'available'),
        None,
    )
    if machine is None:
        return dataset

    # Start processing the job
    processing_time = current_op['processing_time']
    setup_time = current_op['setup_time']
    total_time = processing_time + setup_time

    # Update machine status
    machine['status'] = 'busy'
    machine['current_job'] = job_id
    machine['next_available_time'] = current_time + total_time

    # NOTE: the operation is marked completed and the job advanced when the
    # operation is dispatched, not when it finishes.
    current_op['status'] = 'completed'
    job['current_operation'] += 1
    job['completed_operations'] += 1

    print(f"\nTime {current_time}:")
    print(f"- Started job {job_id} operation {current_op['op_id']} on machine {machine_id}")
    print(f"- Processing time: {processing_time}, Setup time: {setup_time}")
    print(f"- Machine will be available at time {machine['next_available_time']}")

    # A zero-duration operation frees its machine again immediately.
    _release_finished_machines(machines, current_time)
    return dataset

def _load_simulation_state(path):
    """Load the JSON dataset at ``path`` into an object exposing jobs/machines/resources."""
    loaded = FMSDatasetGenerator.load_dataset(path)
    # The generator serves only as an attribute container here: its randomly
    # generated contents are replaced by the loaded data right away.
    state = FMSDatasetGenerator(
        num_jobs=len(loaded["jobs"]),
        num_machines=len(loaded["machines"]),
        num_ops_per_job_range=(3, 8),
        processing_time_range=(2, 20),
        failure_rate=0.05,
        setup_time_range=(0, 3),
        resource_capacity=6,
        load_factor=0.6,
    )
    state.jobs = loaded["jobs"]
    state.machines = loaded["machines"]
    state.resources = loaded["resources"]
    return state


def _count_completed_jobs(dataset):
    """Number of jobs whose operations have all been dispatched."""
    return sum(1 for job in dataset.jobs
               if job['completed_operations'] == len(job['operations']))


def _machine_busy_time(dataset, machine, current_time):
    """Time ``machine`` has spent working up to ``current_time``.

    Sums the durations of the operations dispatched on the machine and
    subtracts the part of the running operation that lies in the future.
    """
    dispatched = sum(op['processing_time'] + op['setup_time']
                     for job in dataset.jobs
                     for op in job['operations']
                     if op['machine'] == machine['machine_id'] and op['status'] == 'completed')
    overhang = max(0, machine['next_available_time'] - current_time)
    return max(0, dispatched - overhang)


def _run_episode(dataset, num_steps, gnn_model, heuristic_algo, status_interval=None):
    """Simulate up to ``num_steps`` ticks; return (dataset, final time)."""
    current_time = 0
    while current_time < num_steps:
        # Print periodic status updates
        if status_interval and current_time % status_interval == 0:
            print(f"\n=== Time {current_time} Status ===")
            in_progress = sum(1 for machine in dataset.machines
                              if machine['status'] == 'busy')
            print(f"Completed jobs: {_count_completed_jobs(dataset)}")
            print(f"Machines busy: {in_progress}")

        # Decide, then actually apply the decision to the system state.
        action, dataset = run_scheduling_step(dataset, current_time, gnn_model, heuristic_algo)
        dataset = update_system_state(dataset, action, current_time)

        if _count_completed_jobs(dataset) == len(dataset.jobs):
            print(f"\nAll jobs completed at time {current_time}!")
            break

        current_time += 1
    return dataset, current_time


def _print_final_statistics(dataset, current_time):
    """Print job/operation completion counts and per-machine utilization."""
    total_operations = sum(len(job['operations']) for job in dataset.jobs)
    completed_operations = sum(job['completed_operations'] for job in dataset.jobs)

    print("\nFinal Statistics:")
    print(f"Total time steps: {current_time}")
    print(f"Completed jobs: {_count_completed_jobs(dataset)}/{len(dataset.jobs)}")
    print(f"Completed operations: {completed_operations}/{total_operations}")
    print("Machine utilization:")
    for machine in dataset.machines:
        busy_time = _machine_busy_time(dataset, machine, current_time)
        utilization = (busy_time / current_time) * 100 if current_time > 0 else 0
        print(f"- Machine {machine['machine_id']}: {utilization:.1f}% utilized")


def main_simulation(train_data_dir, val_data_dir, test_data_dir, gnn_model, heuristic_algo, train_iters):
    """Main simulation loop with detailed progress tracking.

    Simulates every dataset file in ``train_data_dir`` (``train_iters``
    passes, 20 ticks each) and then every file in ``test_data_dir`` (one
    pass, 100 ticks each). Each tick asks the heuristic for an action and
    applies it to the system state. No model parameters are updated here.

    Args:
        train_data_dir: directory of JSON datasets for the training pass.
        val_data_dir: accepted for interface compatibility; not used.
        test_data_dir: directory of JSON datasets for the test pass.
        gnn_model: model passed through to ``run_scheduling_step``.
        heuristic_algo: heuristic passed through to ``run_scheduling_step``.
        train_iters: number of passes over the training directory.
    """
    for _ in range(train_iters):
        for filename in os.listdir(train_data_dir):
            print(f"\nProcessing dataset: {filename}")
            current_dataset = _load_simulation_state(os.path.join(train_data_dir, filename))

            # Print initial state
            print("\nInitial state:")
            print(f"Total jobs: {len(current_dataset.jobs)}")
            print(f"Total machines: {len(current_dataset.machines)}")

            current_dataset, current_time = _run_episode(
                current_dataset, 20, gnn_model, heuristic_algo, status_interval=10)
            _print_final_statistics(current_dataset, current_time)

    # --- Test Loop ---
    for filename in os.listdir(test_data_dir):
        print(f"\nProcessing dataset: {filename}")
        current_dataset = _load_simulation_state(os.path.join(test_data_dir, filename))
        current_dataset, current_time = _run_episode(
            current_dataset, 100, gnn_model, heuristic_algo)
        _print_final_statistics(current_dataset, current_time)

# --- Setup and Run ---
# Model dimensions; input_dim has to equal the padded feature width
# (max_feature_size) produced by prepare_gnn_input.
input_dim = 5
hidden_dim = 64
output_dim = 32

# Dataset locations
train_data_dir = 'fms_train_data'
val_data_dir = 'fms_val_data'
test_data_dir = 'fms_test_data'

# Placeholder model and heuristic
my_gnn_model = MyGCN(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)
my_heuristic = MyHeuristic(params=None)

main_simulation(
    train_data_dir=train_data_dir,
    val_data_dir=val_data_dir,
    test_data_dir=test_data_dir,
    gnn_model=my_gnn_model,
    heuristic_algo=my_heuristic,
    train_iters=1,
)



Processing dataset: fms_dataset_2447.json

Initial state:
Total jobs: 26
Total machines: 10

=== Time 0 Status ===
Completed jobs: 0
Machines busy: 0

=== Time 10 Status ===
Completed jobs: 0
Machines busy: 0

Final Statistics:
Total time steps: 20
Completed jobs: 0/26
Completed operations: 0/154
Machine utilization:
- Machine 0: 0.0% utilized
- Machine 1: 0.0% utilized
- Machine 2: 0.0% utilized
- Machine 3: 0.0% utilized
- Machine 4: 0.0% utilized
- Machine 5: 0.0% utilized
- Machine 6: 0.0% utilized
- Machine 7: 0.0% utilized
- Machine 8: 0.0% utilized
- Machine 9: 0.0% utilized

Processing dataset: fms_dataset_4295.json

Initial state:
Total jobs: 91
Total machines: 10

=== Time 0 Status ===
Completed jobs: 0
Machines busy: 0

=== Time 10 Status ===
Completed jobs: 0
Machines busy: 0

Final Statistics:
Total time steps: 20
Completed jobs: 0/91
Completed operations: 0/501
Machine utilization:
- Machine 0: 0.0% utilized
- Machine 1: 0.0% utilized
- Machine 2: 0.0% utilized
- Machine