In [9]:
import wandb
import ast
import argparse
import os
import networkx as nx
from lineflow.helpers import get_device
from graph_utils import get_line_graph, build_graph_from_info

In [10]:
import numpy as np
from lineflow.simulation import (
    Source,
    Sink,
    Line,
    Assembly,
)

def make_optimal_agent(line):
    """
    Build a heuristic policy that sets the waiting time of ``S_component``.

    Args:
        line: Object indexable by station name. ``line['S_component']`` must
            expose ``processing_time`` and a ``state['waiting_time']`` entry
            whose ``categories`` holds the selectable waiting times.

    Returns:
        A callable ``agent(state, env)`` that returns an action dict for
        ``S_component``.
    """
    component_source = line['S_component']
    candidates = component_source.state['waiting_time'].categories
    source_time = component_source.processing_time

    def agent(state, env):
        """
        Pick the waiting-time category closest to the gap between the
        assembly cycle and the source cycle. ``env`` is accepted but unused.
        """
        # The constants and their left-to-right evaluation order are kept
        # exactly as written so the floating-point result does not change.
        assembly_cycle = state['Assembly']['processing_time'].value + 1 + 1 + 1.1
        source_cycle = source_time*1.1 + 1.1
        target = assembly_cycle - source_cycle

        # The action is the *index* of the nearest category, not its value.
        closest = np.argmin(np.abs(candidates - target))
        return {'S_component': {'waiting_time': closest}}

    return agent


class WTAssembly(Assembly):
    """
    Assembly whose nominal processing time is rescaled inside a random window.

    After ``init`` a window ``(trigger_time, trigger_time + t_jump)`` is
    sampled. While the simulation clock is strictly inside that window, the
    deterministic part of every sampled processing time is multiplied by
    ``factor``; outside of it the multiplier is 1.

    Args:
        name: Station name, forwarded to ``Assembly``.
        R: Ratio used by ``_compute_scaling_factor``.
            NOTE(review): presumably a target throughput ratio - confirm.
        t_jump_max: Upper bound for the sampled window length; the window
            start is drawn from ``[0.25, 0.75] * t_jump_max``.
        **kwargs: Forwarded unchanged to ``Assembly.__init__``.
    """

    def __init__(
        self,
        name,
        R=0.75,
        t_jump_max=2000,
        **kwargs,
    ):
        self.R = R
        self.t_jump_max = t_jump_max
        # Both are filled in by `_sample_trigger_time` once `init` has run.
        self.trigger_time = None
        self.factor = None

        super().__init__(name=name, **kwargs)

    def init(self, random):
        """
        Function that is called after line is built, so all available information is present
        """
        super().init(random)

        # Sampling relies on `self.random`, so it has to happen after the
        # base-class `init` (which is handed `random`) has run.
        self._sample_trigger_time()

    def _compute_scaling_factor(self, T_jump, E=3.1):
        """
        Return the processing-time multiplier for a window of length ``T_jump``.

        Args:
            T_jump: Length of the window.
            E: Constant added to the per-part time in the formula.
                NOTE(review): presumably a fixed per-part overhead - confirm.

        Returns:
            The multiplier as a float. The formula divides by
            ``(R - 1) * T_sim + T_jump``, so that term must not be zero.
        """
        T = self.processing_time
        S = self.processing_std
        # The formula uses twice `t_jump_max` as the horizon.
        T_sim = self.t_jump_max*2

        return 1/T*((T_jump*(T+S+E)) / ((self.R-1)*T_sim+T_jump) - S - E)

    def _sample_trigger_time(self):
        """
        Sample the window length, its scaling factor and its start time.

        Sets ``self.t_jump``, ``self.factor`` and ``self.trigger_time``.
        """
        # Draw from the station's own generator (as `trigger_time` below
        # does) instead of the global `np.random`, so the window length
        # follows the same random source as every other draw of this station.
        self.t_jump = self.random.uniform(
            0.8*self.t_jump_max,
            self.t_jump_max,
        )

        self.factor = self._compute_scaling_factor(self.t_jump)
        self.trigger_time = self.random.uniform(0.25, 0.75)*self.t_jump_max

    def _sample_exp_time(self, time=None, scale=None, rework_probability=0):
        """
        Return ``time * factor * coeff`` plus exponentially distributed noise.

        ``factor`` is ``self.factor`` while ``env.now`` lies strictly inside
        the sampled window and 1 otherwise; ``coeff`` comes from
        ``get_performance_coefficient()``.

        Args:
            time: Deterministic part of the processing time. The default of
                ``None`` cannot be multiplied, so a number must be supplied.
            scale: Scale of the exponential noise term.
            rework_probability: Accepted but not used by this implementation.
        """
        coeff = self.get_performance_coefficient()

        inside_window = (
            self.trigger_time < self.env.now < self.trigger_time + self.t_jump
        )
        factor = self.factor if inside_window else 1

        return time*factor*coeff + self.random.exponential(scale=scale)



class WaitingTime(Line):
    """
    Line with two sources feeding one assembly that delivers to a sink.

    ``S_main`` supplies carriers, ``S_component`` supplies components and is
    the only station with an actionable waiting time. While building the
    layout the class also records a plain-dict description of the topology in
    ``graph_info`` (``{'nodes': {...}, 'edges': [...]}``).

    Args:
        processing_time_source: Processing time of ``S_component``.
        transition_time: Transition time of the connection from
            ``S_component`` to the assembly.
        with_jump: If true, the assembly is a ``WTAssembly`` instead of a
            plain ``Assembly``.
        t_jump_max: Forwarded to ``WTAssembly``; required when ``with_jump``
            is true.
        assembly_condition: Value stored under ``"assembly_condition"`` in the
            part specs of ``S_component``.
        scrap_factor: Forwarded to ``Line.__init__``.
        R: Forwarded to ``WTAssembly``.
        **kwargs: Forwarded unchanged to ``Line.__init__``.
    """

    def __init__(
        self,
        processing_time_source=5,
        transition_time=5,
        with_jump=False,
        t_jump_max=None,
        assembly_condition=35,
        scrap_factor=1,
        R=0.75,
        **kwargs,
    ):
        self.processing_time_source = processing_time_source
        self.transition_time = transition_time
        self.with_jump = with_jump
        self.t_jump_max = t_jump_max
        self.assembly_condition = assembly_condition
        self.R = R
        # NOTE(review): nothing in this class ever fills `components`; the
        # dynamic graph only carries state once someone assigns a
        # name -> station mapping here. Confirm who is expected to do that.
        self.components = None
        self.static_graph_info = None
        if self.with_jump:
            assert self.t_jump_max is not None, (
                "t_jump_max is required when with_jump=True"
            )
        # Attributes are set before the base constructor runs because
        # `build` reads them.
        # NOTE(review): presumably `build` is invoked from `Line.__init__` - confirm.
        super().__init__(scrap_factor=scrap_factor, **kwargs)

    def build(self):
        """
        Create the stations, connect them and record the topology.

        Resets ``self.graph_info`` on every call; ``self.static_graph_info``
        is bound to the first recorded description only.
        """
        # builds the line graph and records the graph information
        # ran in _make_objects of the line class
        self.graph_info = {
            'nodes': {},
            'edges': []
        }
        source_main = Source(
            'S_main',
            position=(300, 300),
            processing_time=0,
            carrier_capacity=2,
            actionable_waiting_time=False,
            unlimited_carriers=True,
        )

        # Record node information
        self.graph_info['nodes']['S_main'] = {
            'type': 'Source'
        }
        source_component = Source(
            'S_component',
            position=(500, 450),
            processing_time=self.processing_time_source,
            waiting_time=0,
            waiting_time_step=1,
            carrier_capacity=1,
            part_specs=[{
                "assembly_condition": self.assembly_condition
            }],
            unlimited_carriers=True,
            actionable_waiting_time=True,
        )
        self.graph_info['nodes']['S_component'] = {
            'type': 'Source',
        }
        if self.with_jump:
            assembly = WTAssembly(
                'Assembly',
                t_jump_max=self.t_jump_max,
                position=(500, 300),
                R=self.R,
                processing_time=20,
                NOK_part_error_time=5,
            )
            self.graph_info['nodes']['Assembly'] = {
                'type': 'WTAssembly'
            }
        else:
            assembly = Assembly(
                'Assembly',
                position=(500, 300),
                processing_time=20,
                NOK_part_error_time=5,
            )
            self.graph_info['nodes']['Assembly'] = {
                'type': 'Assembly'
            }

        sink = Sink('Sink', processing_time=0, position=(700, 300))
        self.graph_info['nodes']['Sink'] = {
            'type': 'Sink'
        }
        assembly.connect_to_component_input(
            station=source_component,
            capacity=3,
            transition_time=self.transition_time,  # time needed to transfer parts from source to assembly
        )
        self.graph_info['edges'].append({
            'source': 'S_component',
            'target': 'Assembly'
        })
        assembly.connect_to_input(source_main, capacity=2, transition_time=2)
        self.graph_info['edges'].append({
            'source': 'S_main',
            'target': 'Assembly'
        })

        sink.connect_to_input(assembly, capacity=2, transition_time=2)

        self.graph_info['edges'].append({
            'source': 'Assembly',
            'target': 'Sink'
        })
        if self.static_graph_info is None:
            self.static_graph_info = self.graph_info

    def get_dynamic_graph(self):
        """
        Returns the dynamic graph of the line
        """
        # Enhance graph_info with current state
        dynamic_graph_info = self._create_dynamic_graph_info()

        return build_graph_from_info(dynamic_graph_info, output_format='heterodata')

    @staticmethod
    def _state_feature(state):
        """Map one state entry to a float feature (0.0 when not numeric)."""
        if hasattr(state, 'value_index'):
            return float(state.value_index)
        # `bool` is a subclass of `int`, so a single check covers the three
        # plain numeric types.
        if isinstance(state.value, (bool, int, float)):
            return float(state.value)
        return 0.0

    def _create_dynamic_graph_info(self):
        """
        Create graph info enhanced with current state data.

        Returns:
            A dict with ``'nodes'``, ``'edges'`` and ``'metadata'``. Every
            node is a shallow copy of its entry in ``graph_info``; nodes with
            a matching entry in ``self.components`` that exposes a non-None
            ``state`` additionally get ``'current_state'`` and
            ``'state_features'``.
        """
        dynamic_graph_info = {
            'nodes': {},
            'edges': self.graph_info['edges'].copy(),
            'metadata': {
                'type': 'dynamic',
                'simulation_time': self.env.now if hasattr(self, 'env') else 0,
                'has_state_info': True
            }
        }

        # `components` is initialised to None in `__init__`; a membership
        # test on None raises TypeError, so treat "not populated" as empty.
        components = getattr(self, 'components', None) or {}

        # Enhance each node with current state
        for node_name, static_info in self.graph_info['nodes'].items():
            enhanced_node = static_info.copy()

            if node_name in components:
                component_state = getattr(components[node_name], 'state', None)

                if component_state is not None:
                    current_state = {}
                    state_features = []
                    for state_name, state in component_state.states.items():
                        current_state[state_name] = state.value
                        # Numerical features for the hetero graph
                        state_features.append(self._state_feature(state))
                    enhanced_node['current_state'] = current_state
                    enhanced_node['state_features'] = state_features

            dynamic_graph_info['nodes'][node_name] = enhanced_node

        return dynamic_graph_info

In [13]:
# Instantiate the line with its defaults (with_jump=False, so a plain Assembly is used).
line = WaitingTime()

In [5]:
from lineflow.learning.helpers import (
    make_stacked_vec_env,
)

In [6]:
# Display the topology recorded by `WaitingTime.build`.
# NOTE(review): the captured output below lists fields ('properties',
# 'connection_type', ...) that the current `build` does not record - it
# looks like output from an earlier version of the class; re-run to refresh.
line.graph_info

{'nodes': {'S_main': {'type': 'Source',
   'processing_time': 0,
   'carrier_capacity': 2,
   'actionable_waiting_time': False,
   'properties': {'is_main_source': True, 'controllable': False}},
  'S_component': {'type': 'Source',
   'processing_time': 5,
   'waiting_time': 0,
   'waiting_time_step': 1,
   'carrier_capacity': 1,
   'part_specs': [{'assembly_condition': 35}],
   'actionable_waiting_time': True,
   'properties': {'is_component_source': True, 'controllable': True}},
  'Assembly': {'type': 'Assembly',
   'processing_time': 20,
   'NOK_part_error_time': 5,
   'properties': {'has_jump_behavior': False, 'is_assembly': True}},
  'Sink': {'type': 'Sink',
   'processing_time': 0,
   'properties': {'is_sink': True}}},
 'edges': [{'source': 'S_component',
   'target': 'Assembly',
   'connection_type': 'component_input',
   'capacity': 3,
   'transition_time': 5,
   'properties': {'is_component_feed': True}},
  {'source': 'S_main',
   'target': 'Assembly',
   'connection_type': 'st

In [14]:
# Wrap the line in a single, unstacked vectorised training environment.
# `simulation_end` is written as 100 + 1, exactly as before.
env_train = make_stacked_vec_env(
    line=line,
    simulation_end=100 + 1,
    reward="parts",
    n_envs=1,
    n_stack=1,
)

In [15]:
# Show the observation space and the action space, one per line.
print(env_train.observation_space, env_train.action_space, sep="\n")

Box(0.0, [[  2. 199.   2.  99.   2.  inf  inf  inf   2.   1.   1.   1.]], (1, 12), float32)
MultiDiscrete([100])


In [16]:
# Reset the environment; the notebook displays whatever `reset` returns.
env_train.reset()

(array([[ 1.,  0.,  0.,  0.,  1.,  0.,  1., 20.,  1.,  0.,  0.,  0.]],
       dtype=float32),
 {'name': 'WaitingTime', 'T': 1, 'n_parts': 0, 'n_scrap_parts': 0})

In [None]:
from stable_baselines3 import (
    PPO,
    A2C,
)

In [None]:
# Open a Weights & Biases run and derive a per-run log directory from the run id.
run = wandb.init(project='Lineflow', sync_tensorboard=True)
log_path = os.path.join("./logs", run.id)

In [None]:
# Keyword arguments shared by the model constructor; algorithm-specific
# entries are added separately before the model is built.
model_args = dict(
    policy='MlpPolicy',
    env=env_train,
    n_steps=500,
    gamma=0.99,  # discount factor
    learning_rate=3e-4,
    use_sde=False,
    normalize_advantage=True,
    device=get_device(),
    tensorboard_log=log_path,
    stats_window_size=10,
    verbose=0,
    seed=None,
)

In [None]:
# Select PPO and add its algorithm-specific settings to the shared arguments.
model_cls = PPO
model_args.update(
    batch_size=500,  # mini-batch size
    n_epochs=5,  # number of times to go over experiences with mini-batches
    clip_range=0.2,
    max_grad_norm=0.5,
    ent_coef=0.01,
)

In [None]:
# Build the model from the collected arguments and train it for 10000 timesteps.
model = model_cls(**model_args)
model.learn(total_timesteps=10000)

In [None]:
# Inspect the first recorded edge of the line topology.
line.graph_info['edges'][0]

In [None]:
def hetero_to_networkx(data):
    """
    Convert a heterogeneous graph object into a ``networkx.MultiDiGraph``.

    Args:
        data: Object exposing ``node_types``, ``edge_types`` (triples of
            source type, relation, destination type), ``data[node_type].num_nodes``
            and ``data[edge_type].edge_index`` whose first two rows hold the
            source and destination indices.

    Returns:
        A ``MultiDiGraph`` whose nodes are ``(node_type, index)`` tuples
        carrying a ``node_type`` attribute, and whose edges are keyed by the
        relation name and carry it as ``rel_type``.
    """
    graph = nx.MultiDiGraph()

    # One node per (type, local index) pair.
    for kind in data.node_types:
        count = data[kind].num_nodes
        for idx in range(count):
            graph.add_node((kind, idx), node_type=kind)

    # One edge per column of each relation's edge index.
    for triple in data.edge_types:
        src_kind, relation, dst_kind = triple
        pairs = data[triple].edge_index
        for s, t in zip(pairs[0], pairs[1]):
            tail = (src_kind, int(s))
            head = (dst_kind, int(t))
            graph.add_edge(tail, head, key=relation, rel_type=relation)

    return graph


In [None]:
import matplotlib.pyplot as plt

# NOTE(review): `data` is not assigned in any visible cell; presumably it is
# the result of `line.get_dynamic_graph()` - confirm and define it before
# running this cell.
G = hetero_to_networkx(data)

# Layout (fixed seed so the picture is reproducible)
pos = nx.spring_layout(G, seed=42)

# Single source of truth for node colours; the legend is derived from it.
node_colors = {
    'Assembly': 'skyblue',
    'Sink': 'lightgreen',
    'Source': 'salmon',
    'Switch': 'lightcoral'
}
# Node types missing from the table (e.g. 'WTAssembly') get a neutral colour
# instead of raising KeyError.
fallback_color = 'lightgray'

# Draw nodes, one call per node type
for node_type in data.node_types:
    nx.draw_networkx_nodes(
        G, pos,
        nodelist=[n for n in G.nodes if n[0] == node_type],
        node_color=node_colors.get(node_type, fallback_color),
        label=node_type,
        node_size=500
    )

# Draw edges
nx.draw_networkx_edges(G, pos, arrows=True)

# Draw labels: first letter of the node type followed by the node index
nx.draw_networkx_labels(G, pos, labels={n: f"{n[0][0]}{n[1]}" for n in G.nodes}, font_size=10)

# Legend: every entry of the colour table, plus any extra type found in `data`
legend_colors = dict(node_colors)
for node_type in data.node_types:
    legend_colors.setdefault(node_type, fallback_color)

plt.legend(handles=[
    plt.Line2D([0], [0], marker='o', color='w', label=type_name, markerfacecolor=color, markersize=10)
    for type_name, color in legend_colors.items()
])
plt.axis('off')
plt.show()


In [None]:
# Build the heuristic waiting-time policy for this line.
agent = make_optimal_agent(line)

In [None]:
# Display the heterogeneous graph object.
# NOTE(review): `data` is not assigned in any visible cell - confirm where it comes from.
data

In [None]:
# Display the store of the 'Source' node type.
# NOTE(review): relies on `data`, which is not assigned in any visible cell.
data['Source']