In [12]:
import pandas as pd
import seaborn as sns
import traintools
from trainconstants import *
from collections import defaultdict
import networkx as nx

sns.set()
%matplotlib inline

df = pd.read_csv("datasets/nsdata1.txt",  delimiter=r"\s+")
df['departure_time'] = df['departure_time'].map(traintools.to_minutes_past_midnight)
df['arrival_time'] = df['arrival_time'].map(traintools.to_minutes_past_midnight)
df.head()

Unnamed: 0,train_number,from,to,departure_time,arrival_time,first_class,second_class
0,2123,2,3,420,460,4,58
1,2123,3,4,463,518,14,328
2,2127,1,2,408,475,47,340
3,2127,2,3,481,521,35,272
4,2127,3,4,523,578,19,181


In [13]:
# Per-time minimum number of type-3 trains, as computed by the traintools helper.
# NOTE(review): the helper is not visible here; the displayed output suggests it returns a
# one-column DataFrame ("number of trains") indexed by minutes past midnight — confirm.
minimum_number_of_type_3_trains = traintools.minimum_number_of_trains_at_time_t(df, TYPE_3_TRAIN)

# DataFrame.idxmax() returns a Series with one label per column. Take the first column's
# label explicitly by position instead of calling int() on the Series itself, which is
# deprecated for single-element Series in recent pandas releases.
rush_hour_peak = int(minimum_number_of_type_3_trains.idxmax().iloc[0])
minimum_number_of_type_3_trains.loc[rush_hour_peak]

number of trains    22.0
Name: 1055, dtype: float64

In [14]:
# Create the new "loop" in the schedule: the timeline is re-based so that it starts at the
# rush-hour peak, runs to the end of the data, and then continues with the stretch from the
# first non-zero entry up to the peak, shifted one day (1440 minutes) forward.

# First index label (per column) at which the value is not 0; this is a Series keyed by
# column name, so it must be read by position with .iloc, not with start[0].
start = minimum_number_of_type_3_trains.ne(0).idxmax()

# Keep the original index values as a regular column so they can be shifted below.
# Note: this adds the column to the shared frame in place.
minimum_number_of_type_3_trains['time'] = minimum_number_of_type_3_trains.index

# NOTE(review): .loc slices include both endpoints, so the row at rush_hour_peak ends up in
# both `tail` (time 0 after re-basing) and `head` (time 1440) — confirm this is intended.
tail = minimum_number_of_type_3_trains.loc[rush_hour_peak:].copy()
head = minimum_number_of_type_3_trains.loc[start.iloc[0]:rush_hour_peak].copy()
head['time'] = head['time'] + 1440  # 1440 = 24 * 60 minutes, i.e. the following day

new_loop = pd.concat([tail, head])
new_loop['time'] = new_loop['time'] - rush_hour_peak  # re-base so the peak is at time 0
new_loop.head()

# TODO: keep this logic, but also store the actual number of passengers in new_loop.

Unnamed: 0,number of trains,time
1055,22.0,0
1056,22.0,1
1057,22.0,2
1058,22.0,3
1059,18.0,4


In [25]:
def nodes(timetable):
    """Collect the (station, time) nodes described by a timetable.

    For every row, in row order, two tuples are emitted: the departure node
    ``(from, departure_time)`` followed by the arrival node ``(to, arrival_time)``.
    Duplicates are kept.

    Args:
        timetable: DataFrame with the columns 'from', 'departure_time',
            'to' and 'arrival_time'.

    Returns:
        A flat list of (station, time) tuples, two per timetable row.
    """
    collected = []

    for _, trip in timetable.iterrows():
        departure = (trip['from'], trip['departure_time'])
        arrival = (trip['to'], trip['arrival_time'])
        collected.append(departure)
        collected.append(arrival)

    return collected


# Build the node list for the loaded timetable.
# NOTE(review): this rebinds the name `nodes` from the function to the list it returns, so
# `nodes(...)` cannot be called again until the `def` is re-executed.
nodes = nodes(df)
def node_dict(nodes):
    """Group nodes by station and sort each station's nodes.

    Args:
        nodes: iterable of indexable nodes whose first element is the station,
            e.g. (station, time) tuples.

    Returns:
        A ``defaultdict(list)`` mapping each station to the list of its nodes in
        ascending order. Stations appear in order of first occurrence, and
        duplicate nodes are kept.
    """
    grouped = defaultdict(list)

    for entry in nodes:
        station = entry[0]
        grouped[station].append(entry)

    # Sorting the tuples orders a station's nodes by their remaining fields (the time).
    for station_nodes in grouped.values():
        station_nodes.sort()

    return grouped




def connnect_nodes(nodes, df):
    """Build an undirected graph over the schedule's (station, time) nodes.

    Two kinds of edges are added, in this order:

    * one edge per timetable row, between ``(from, departure_time)`` and
      ``(to, arrival_time)``, whose ``weight`` is the tuple
      ``(first_class, second_class)`` of that row;
    * for every station, an edge with ``weight=(0, 0)`` between each pair of
      neighbouring nodes in that station's sorted node list (from ``node_dict``).

    If the same node is listed twice for a station, the neighbouring pair is the
    node with itself, which results in a self-loop. Adding an edge that already
    exists overwrites its ``weight``.

    Args:
        nodes: iterable of (station, time) tuples.
        df: timetable DataFrame with the columns 'from', 'departure_time', 'to',
            'arrival_time', 'first_class' and 'second_class'.

    Returns:
        The populated ``nx.Graph``.
    """
    graph = nx.Graph()
    per_station = node_dict(nodes)

    # Trip edges: one per scheduled ride, weighted with its passenger counts.
    for _, trip in df.iterrows():
        departure = (trip['from'], trip['departure_time'])
        arrival = (trip['to'], trip['arrival_time'])
        load = (trip['first_class'], trip['second_class'])
        graph.add_edge(departure, arrival, weight=load)

    # Same-station edges: link each node to the next one in the station's sorted list.
    for station_nodes in per_station.values():
        for earlier, later in zip(station_nodes, station_nodes[1:]):
            graph.add_edge(earlier, later, weight=(0, 0))

    return graph

# Build the schedule graph from the node list and the timetable, then print the
# class of the returned object as a quick sanity check.
g = connnect_nodes(nodes, df)
print(type(g))


<class 'networkx.classes.graph.Graph'>
