In [None]:
# Imports

import os
from glob import glob
from datetime import datetime
import json
import matplotlib.pyplot as plt
from tqdm import trange
import pandas as pd
import numpy as np
import torch


# Load schedules

def get_data_files(directory="/proj/fair-ai/data"):
    """Return the paths of all ``*.csv`` files directly inside *directory*.

    Args:
        directory: Folder to search (not recursive).

    Returns:
        A list of matching paths, in the order ``glob`` yields them.
    """
    csv_pattern = os.path.join(directory, "*.csv")
    matches = glob(csv_pattern)
    return matches
    
def load_data(select=None, directory="/proj/fair-ai/data"):
    """Load per-cell load measurements from CSV files into a wide DataFrame.

    Each CSV must provide the columns ``rdiTimestamp``, ``cell``,
    ``pmRrcConnLevSum`` and ``pmRrcConnLevSamp``. The load is computed as
    ``pmRrcConnLevSum / pmRrcConnLevSamp``, outliers are capped by
    ``handle_load_outliers`` and the result is pivoted by ``make_bs_df``.

    Args:
        select: ``None`` to load one randomly chosen file, or a tuple
            ``(start, end)`` of ``"%Y-%m-%d"`` date strings to load every
            file whose name carries a date inside that inclusive range. The
            date is the second-to-last ``_``-separated field of the file name.
        directory: Folder containing the ``*.csv`` files.

    Returns:
        The DataFrame produced by ``make_bs_df`` (time index, one column per
        cell).

    Raises:
        ValueError: If ``select`` is malformed, no file matches the date
            range, the directory holds no CSV file when ``select`` is
            ``None``, or a file name's date field does not parse.
        IndexError: If a file name contains no underscore at all.
    """
    # glob returns paths in arbitrary order; sorting makes the seeded random
    # pick below reproducible across machines and file systems.
    csv_files = sorted(glob(os.path.join(directory, "*.csv")))

    selected_files = []

    if select is not None:
        if isinstance(select, tuple) and len(select) == 2:
            start_date = datetime.strptime(select[0], "%Y-%m-%d")
            end_date = datetime.strptime(select[1], "%Y-%m-%d")

            for file in csv_files:
                # Parse the date from the file name only, so underscores in
                # the directory path cannot shift the field that is read.
                file_date_str = os.path.basename(file).split("_")[-2]
                file_date = datetime.strptime(file_date_str, "%Y-%m-%d")

                if start_date <= file_date <= end_date:
                    selected_files.append(file)

            if not selected_files:
                raise ValueError(f"No files found within the specified date range: {select}")
        else:
            raise ValueError("Invalid input. 'select' must be either None or a tuple of two date strings.")
    else:
        if not csv_files:
            # Fail with an explicit message instead of numpy's generic
            # "cannot be empty" error from np.random.choice.
            raise ValueError(f"No CSV files found in directory: {directory}")
        selected_files = [np.random.choice(csv_files)]

    dataframes = []
    for file in selected_files:
        df = pd.read_csv(file, usecols=["rdiTimestamp", "cell", "pmRrcConnLevSum", "pmRrcConnLevSamp"])
        df["load"] = df.pmRrcConnLevSum/df.pmRrcConnLevSamp
        df.rename({"rdiTimestamp": "time"}, axis=1, inplace=True)
        df["time"] = pd.to_datetime(df.time, format="%Y-%m-%d %H.%M")
        df.set_index("time", inplace=True)
        dataframes.append(df[["cell", "load"]])

    combined_df = pd.concat(dataframes).sort_index()
    combined_df_capped = handle_load_outliers(combined_df)

    bs_df = make_bs_df(combined_df_capped)

    return bs_df

def handle_load_outliers(df, q=0.999):
    """Cap the ``load`` column at its *q*-quantile.

    The input frame is left untouched; a new frame with the capped column is
    returned, so callers holding a reference to *df* do not see it change.

    Args:
        df: DataFrame with a numeric ``load`` column.
        q: Quantile (0..1) of ``load`` used as the upper bound.

    Returns:
        A copy of *df* whose ``load`` values above the quantile are replaced
        by the quantile value.
    """
    # Bracket access avoids relying on attribute-style column assignment.
    upper_bound = df["load"].quantile(q)
    return df.assign(load=df["load"].clip(upper=upper_bound))

def make_bs_df(df):
    """Pivot long-format load rows into one column per cell.

    The existing index is kept as the row index; any cell column that has a
    missing value for at least one row is dropped.
    """
    per_cell = df.pivot(columns="cell", values="load")
    complete_cells = per_cell.dropna(axis="columns")
    return complete_cells

def get_busy_time(df_bs):
    """Return the index label of the row with the highest mean across columns."""
    row_means = df_bs.mean(axis=1)
    busiest_label = row_means.idxmax()
    return busiest_label

def get_max_load(df_bs):
    """Return the largest value found anywhere in *df_bs*."""
    column_maxima = df_bs.max()
    return column_maxima.max()

def get_cells(df_bs):
    """Return the column labels of *df_bs* as a plain list."""
    cell_labels = df_bs.columns
    return cell_labels.tolist()

def choose_clients(cells, n_clients):
    """Draw *n_clients* distinct entries from *cells* using NumPy's global RNG.

    Sampling is without replacement, so *n_clients* must not exceed the
    number of cells.
    """
    picked = np.random.choice(cells, size=n_clients, replace=False)
    return picked

def normalize_load(df_bs, max_load):
    """Scale every value in *df_bs* by dividing it by *max_load*."""
    scaled = df_bs / max_load
    return scaled

def get_client_loads(df_bs, clients):
    """Select the *clients* columns of *df_bs* and relabel them "0", "1", ...

    The new labels follow the order in which the columns were selected.
    """
    selected = df_bs[clients]
    n_selected = selected.shape[1]
    selected.columns = list(map(str, range(n_selected)))
    return selected

def get_client_checkpoints(df_clients, start, stop):
    """Return the rows of *df_clients* whose labels lie in ``start``..``stop``.

    This is a label-based ``.loc`` slice; ``None`` for either bound leaves
    that side open.
    """
    label_window = slice(start, stop)
    return df_clients.loc[label_window]

def reduce_checkpoints(df_checkpoints, n_cp=160):
    """Resample *df_checkpoints* to exactly *n_cp* rows.

    Row positions are spread evenly from the first to the last row and
    rounded to the nearest integer, so the first and last rows are always
    kept. When the frame has fewer than *n_cp* rows, rows are repeated.

    Args:
        df_checkpoints: DataFrame to resample (must not be empty).
        n_cp: Number of rows in the result. Defaults to 160, the value that
            was previously hard-coded.

    Returns:
        A DataFrame with *n_cp* rows taken from *df_checkpoints*.
    """
    last_position = len(df_checkpoints) - 1
    ix = np.linspace(0, last_position, n_cp).round().astype(int)
    return df_checkpoints.iloc[ix]

def client_schedules_to_dict(df_client_schedules):
    """Convert the frame to ``{column label: list of column values}``."""
    schedules_by_client = df_client_schedules.to_dict(orient="list")
    return schedules_by_client

def generate_schedules(df_bs, n_clients=8, start=None, stop=None, visualize=False):
    """Build per-client load schedules from a wide per-cell load frame.

    Randomly picks *n_clients* cells, restricts their loads to the
    ``start``..``stop`` label window, and resamples the window with
    ``reduce_checkpoints``.

    Args:
        df_bs: DataFrame with one column per cell.
        n_clients: Number of cells to sample as clients.
        start: First index label of the window (``None`` = from the beginning).
        stop: Last index label of the window (``None`` = to the end).
        visualize: When true, plot the average load and the per-client
            min/avg/max loads.

    Returns:
        A 2-D array with one row per client and one column per checkpoint.
    """
    chosen_cells = choose_clients(get_cells(df_bs), n_clients)
    client_loads = get_client_loads(df_bs, chosen_cells)
    windowed = get_client_checkpoints(client_loads, start=start, stop=stop)
    reduced = reduce_checkpoints(windowed)

    if visualize:
        visualize_average_load(df_bs, start=start, stop=stop)
        visualize_client_load(reduced)

    # Rows of `reduced` are checkpoints; transpose so rows become clients.
    return reduced.values.T

def visualize_average_load(df_bs, start=None, stop=None):
    """Plot the row-wise mean of *df_bs* with start/stop marker lines.

    A falsy *start* or *stop* places the marker at the first or last index
    label respectively.
    """
    mean_load = df_bs.mean(axis=1)

    plt.figure(figsize=(12, 4))
    plt.plot(mean_load, label='Average Load')

    # Fall back to the edges of the series when no bound is given.
    start_marker = start if start else mean_load.index[0]
    plt.axvline(start_marker, color='r', linestyle='-', label='Start')

    stop_marker = stop if stop else mean_load.index[-1]
    plt.axvline(stop_marker, color='g', linestyle='-.', label='Stop')

    plt.xlabel('Time')
    plt.ylabel('Average Load')
    plt.title('Average Load of Base Stations')
    plt.legend()
    plt.show()
    
def visualize_client_load(df_client):
    """Draw a stacked bar chart of min, average and max load per client.

    Clients (columns) are ordered by ascending maximum load. Each bar is
    stacked as min, then (avg - min), then (max - avg), so the segment tops
    sit at the min, average and max values.
    """
    order = df_client.max().sort_values().index
    ordered = df_client[order]

    lowest = ordered.min()
    average = ordered.mean()
    highest = ordered.max()

    labels = ordered.columns
    positions = np.arange(len(labels))
    bar_width = 0.5

    _, ax = plt.subplots(figsize=(12, 4))

    ax.bar(positions, lowest, bar_width, label='Min Load')
    ax.bar(positions, average - lowest, bar_width, bottom=lowest, label='Average Load')
    ax.bar(positions, highest - average, bar_width, bottom=average, label='Max Load')

    ax.axhline(0, color='grey', linewidth=0.8)
    ax.set_ylabel('Load')
    ax.set_title('Min, Avg, and Max Load for each client')
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.legend()

    plt.show()
    

# Noise 

def checkpoint_noise_var(load_schedules, mean_var=0.5):
    """Derive noise variances inversely proportional to the load.

    The variances are ``scale / load``, with ``scale`` chosen so that their
    overall mean equals *mean_var*.

    Returns:
        A tuple ``(variances, scale)``.
    """
    inverse_load = 1 / load_schedules
    scale = mean_var / inverse_load.mean()
    scaled_var = scale * inverse_load
    return scaled_var, scale


def sample_noise(cp_noise_var, cp_train_samps=10, cp_val_samps=2):
    """Draw zero-mean Gaussian noise samples for every client checkpoint.

    Training noise is drawn for all checkpoints; validation noise is drawn
    only for the first half of the checkpoints and then flattened per client.

    Args:
        cp_noise_var: 2-D array of noise variances, one row per client and
            one column per checkpoint.
        cp_train_samps: Training samples drawn per checkpoint.
        cp_val_samps: Validation samples drawn per checkpoint.

    Returns:
        A tuple ``(train_noise, val_noise)`` where ``train_noise`` has shape
        ``(n_clients, n_cp, cp_train_samps)`` and ``val_noise`` has shape
        ``(n_clients, (n_cp // 2) * cp_val_samps)``.
    """
    cp_noise_std = np.sqrt(cp_noise_var)
    # Derive the sizes from the input instead of assuming 8 clients x 160
    # checkpoints.
    n_clients, n_cp = cp_noise_std.shape
    # Integer division: a float bound (n_cp / 2) is not a valid slice index.
    n_val_cp = n_cp // 2

    # A trailing axis lets each std broadcast across its samples.
    cp_noise_std_train = np.expand_dims(cp_noise_std, axis=-1)
    cp_noise_std_val = np.expand_dims(cp_noise_std[:, :n_val_cp], axis=-1)

    train_noise = np.random.normal(
        0, cp_noise_std_train, size=(n_clients, n_cp, cp_train_samps)
    )
    val_noise = np.random.normal(
        0, cp_noise_std_val, size=(n_clients, n_val_cp, cp_val_samps)
    ).reshape(n_clients, -1)
    return train_noise, val_noise


# Dynamics

def get_load_dynamics(mean_var=0.5, seed=None, visualize=False):
    """Produce load schedules, noise variances and noise samples as tensors.

    Seeds NumPy's global RNG, loads the files dated 2023-02-27 to 2023-02-28,
    generates schedules for 8 clients and derives load-dependent noise.

    Args:
        mean_var: Target mean of the noise variances.
        seed: Seed passed to ``np.random.seed`` (``None`` = unseeded).
        visualize: Forwarded to ``generate_schedules`` to enable plots.

    Returns:
        ``(load_schedules, noise_var, (train_noise, val_noise))``, each
        converted with ``torch.tensor``.
    """
    np.random.seed(seed)

    base_station_loads = load_data(("2023-02-27", "2023-02-28"))
    schedules = generate_schedules(base_station_loads, 8, visualize=visualize)
    # Simplification: floor the loads at 1 (presumably so the 1/load in
    # checkpoint_noise_var stays bounded).
    schedules = np.maximum(schedules, 1)

    variances, _ = checkpoint_noise_var(schedules, mean_var)
    noise_train, noise_val = sample_noise(variances)

    noise_tensors = (torch.tensor(noise_train), torch.tensor(noise_val))
    return torch.tensor(schedules), torch.tensor(variances), noise_tensors

In [None]:
# Build schedules, noise variances and noise samples with a fixed seed (100),
# showing the diagnostic plots along the way.
load_schedules, noise_var, (train_noise, val_noise) = get_load_dynamics(seed=100, visualize=True)