# Preprocess, train and use Trustee

## Imports for pre processing

In [16]:
import pandas as pd
import numpy as np
import pickle
import argparse

## Constants (Directly imported from the puffer paper)

In [17]:
VIDEO_DURATION = 180180
PKT_BYTES = 1500
MILLION = 1000000
PAST_CHUNKS = 8
FUTURE_CHUNKS = 5

## Steps:
    1. 2 CSVs video sent and video acked
    2. Parse them to relevant data types
    3. calculate the transmission time from sent and acked
    4. Consider the past 8 chunks and pad the values if necessary
    5. Create the examples for the future chunks

In [24]:
sent_df = pd.read_csv('/mnt/md0/cs190n/video_sent.csv')
acked_df = pd.read_csv('/mnt/md0/cs190n/video_acked.csv')

In [21]:
def prepare_raw_data(video_sent_path, video_acked_path, time_start=None, time_end=None):
    """
    Load data from files and calculate chunk transmission times.
    """
    video_sent_df = pd.read_csv(video_sent_path)
    video_acked_df = pd.read_csv(video_acked_path)

    # Rename "time (ns GMT)" to "time" for convenience
    video_sent_df.rename(columns={'time (ns GMT)': 'time'}, inplace=True)
    video_acked_df.rename(columns={'time (ns GMT)': 'time'}, inplace=True)

    # Convert nanosecond timestamps to datetime
    video_sent_df['time'] = pd.to_datetime(video_sent_df['time'], unit='ns')
    video_acked_df['time'] = pd.to_datetime(video_acked_df['time'], unit='ns')

    # Filter by time range
    if time_start:
        time_start = pd.to_datetime(time_start)
        video_sent_df = video_sent_df[video_sent_df['time'] >= time_start]
        video_acked_df = video_acked_df[video_acked_df['time'] >= time_start]
    if time_end:
        time_end = pd.to_datetime(time_end)
        video_sent_df = video_sent_df[video_sent_df['time'] <= time_end]
        video_acked_df = video_acked_df[video_acked_df['time'] <= time_end]

    # Process the data
    return calculate_trans_times(video_sent_df, video_acked_df)

In [22]:
def calculate_trans_times(video_sent_df, video_acked_df):
    """
    Calculate transmission times from video_sent and video_acked datasets using session_id.
    """
    d = {}
    last_video_ts = {}

    for _, row in video_sent_df.iterrows():
        session = row['session_id']  # Use only session_id to track sessions
        if session not in d:
            d[session] = {}
            last_video_ts[session] = None

        video_ts = int(row['video_ts'])
        if last_video_ts[session] is not None:
            if video_ts != last_video_ts[session] + VIDEO_DURATION:
                continue

        last_video_ts[session] = video_ts
        d[session][video_ts] = {
            'sent_ts': pd.Timestamp(row['time']),
            'size': float(row['size']) / PKT_BYTES,
            'delivery_rate': float(row['delivery_rate']) / PKT_BYTES,
            'cwnd': float(row['cwnd']),
            'in_flight': float(row['in_flight']),
            'min_rtt': float(row['min_rtt']) / MILLION,
            'rtt': float(row['rtt']) / MILLION,
        }

    for _, row in video_acked_df.iterrows():
        session = row['session_id']  # Use only session_id
        if session not in d:
            continue

        video_ts = int(row['video_ts'])
        if video_ts not in d[session]:
            continue

        dsv = d[session][video_ts]
        sent_ts = dsv['sent_ts']
        acked_ts = pd.Timestamp(row['time'])
        dsv['acked_ts'] = acked_ts
        dsv['trans_time'] = (acked_ts - sent_ts).total_seconds()

    return d

In [27]:
prepare_raw_data("/mnt/md0/cs190n/video_sent.csv", "/mnt/md0/cs190n/video_acked.csv")

{'RhMF72kUi5Yin0hEqud4YQJiQ7UmZgNxk9YGk6A8UL0=': {42209327160: {'sent_ts': Timestamp('2024-11-20 11:59:47.439000'),
   'size': 604.8413333333333,
   'delivery_rate': 866.1366666666667,
   'cwnd': 1090.0,
   'in_flight': 0.0,
   'min_rtt': 0.043153,
   'rtt': 0.054763,
   'acked_ts': Timestamp('2024-11-20 11:59:47.564000'),
   'trans_time': 0.125},
  42209507340: {'sent_ts': Timestamp('2024-11-20 11:59:47.568000'),
   'size': 631.3986666666667,
   'delivery_rate': 7770.301333333334,
   'cwnd': 1096.0,
   'in_flight': 0.0,
   'min_rtt': 0.043153,
   'rtt': 0.05542,
   'acked_ts': Timestamp('2024-11-20 11:59:47.693000'),
   'trans_time': 0.125},
  42209687520: {'sent_ts': Timestamp('2024-11-20 11:59:47.697000'),
   'size': 358.016,
   'delivery_rate': 7753.74,
   'cwnd': 1010.0,
   'in_flight': 0.0,
   'min_rtt': 0.043153,
   'rtt': 0.054114,
   'acked_ts': Timestamp('2024-11-20 11:59:47.788000'),
   'trans_time': 0.091},
  42209867700: {'sent_ts': Timestamp('2024-11-20 11:59:47.792000'),

In [28]:
def append_past_chunks(ds, next_ts, row):
    i = 1
    past_chunks = []
    while i <= PAST_CHUNKS:
        ts = next_ts - i * VIDEO_DURATION
        if ts in ds and 'trans_time' in ds[ts]:
            past_chunks = [ds[ts]['delivery_rate'],
                           ds[ts]['cwnd'], ds[ts]['in_flight'],
                           ds[ts]['min_rtt'], ds[ts]['rtt'],
                           ds[ts]['size'], ds[ts]['trans_time']] + past_chunks
        else:
            nts = ts + VIDEO_DURATION  # padding with the nearest ts
            padding = [ds[nts]['delivery_rate'],
                       ds[nts]['cwnd'], ds[nts]['in_flight'],
                       ds[nts]['min_rtt'], ds[nts]['rtt']]
            if nts == next_ts:
                padding += [0, 0]  # next_ts is the first chunk to send
            else:
                padding += [ds[nts]['size'], ds[nts]['trans_time']]
            break
        i += 1
    if i != PAST_CHUNKS + 1:  # break in the middle; padding must exist
        while i <= PAST_CHUNKS:
            past_chunks = padding + past_chunks
            i += 1
    row += past_chunks

In [29]:
def prepare_input_output(d):
    ret = [{'in': [], 'out': []} for _ in range(5)]  # FUTURE_CHUNKS = 5

    for session in d:
        ds = d[session]

        for next_ts in ds:
            if 'trans_time' not in ds[next_ts]:
                continue

            row = []

            # Append past chunks
            append_past_chunks(ds, next_ts, row)

            # Append the TCP info of the next chunk
            row += [ds[next_ts]['delivery_rate'],
                    ds[next_ts]['cwnd'], ds[next_ts]['in_flight'],
                    ds[next_ts]['min_rtt'], ds[next_ts]['rtt']]

            # Generate FUTURE_CHUNKS rows
            for i in range(5):  # FUTURE_CHUNKS = 5
                row_i = row.copy()

                ts = next_ts + i * VIDEO_DURATION
                if ts in ds and 'trans_time' in ds[ts]:
                    row_i += [ds[ts]['size']]

                    ret[i]['in'].append(row_i)
                    ret[i]['out'].append(ds[ts]['trans_time'])

    return ret

In [30]:
def save_processed_data(output_file, processed_data):
    """
    Save processed data to a file.
    """
    with open(output_file, 'wb') as f:
        pickle.dump(processed_data, f)
    print(f"Processed data saved to {output_file}")

In [33]:
if __name__ == '__main__':
    DEFAULT_VIDEO_SENT_PATH = '/mnt/md0/cs190n/video_sent.csv'
    DEFAULT_VIDEO_ACKED_PATH = '/mnt/md0/cs190n/video_acked.csv'
    DEFAULT_OUTPUT_FILE = '/home/satyandra/output.pkl'
    
    #Latest datasets can be found at https://puffer.stanford.edu/results/
    
    parser = argparse.ArgumentParser(description="Process video streaming datasets.")
    parser.add_argument('--video_sent_path', type=str, help='Path to the video_sent dataset CSV file.')
    parser.add_argument('--video_acked_path', type=str, help='Path to the video_acked dataset CSV file.')
    parser.add_argument('--output_file', type=str, help='Path to save the processed data.')
    parser.add_argument('--time_start', type=str, default=None, help='Start time for filtering data (RFC3339 format).')
    parser.add_argument('--time_end', type=str, default=None, help='End time for filtering data (RFC3339 format).')
    #args = parser.parse_args()
    #processed_data = prepare_input_output(prepare_raw_data(args.video_sent_path, args.video_acked_path,
    #    time_start=args.time_start, time_end=args.time_end))
    # save_processed_data(args.output_file, processed_data)
    processed_data = prepare_input_output(prepare_raw_data(DEFAULT_VIDEO_SENT_PATH, DEFAULT_VIDEO_ACKED_PATH,
        time_start=None, time_end=None))
    save_processed_data(DEFAULT_OUTPUT_FILE, processed_data)

Processed data saved to /home/satyandra/output.pkl


In [34]:
d = pd.read_pickle("/home/satyandra/output.pkl")

In [40]:
d[0]["out"][:5]

[0.125, 0.125, 0.091, 0.103, 0.12]

In [41]:
d[0]["in"][:5]

[[866.1366666666667,
  1090.0,
  0.0,
  0.043153,
  0.054763,
  0,
  0,
  866.1366666666667,
  1090.0,
  0.0,
  0.043153,
  0.054763,
  0,
  0,
  866.1366666666667,
  1090.0,
  0.0,
  0.043153,
  0.054763,
  0,
  0,
  866.1366666666667,
  1090.0,
  0.0,
  0.043153,
  0.054763,
  0,
  0,
  866.1366666666667,
  1090.0,
  0.0,
  0.043153,
  0.054763,
  0,
  0,
  866.1366666666667,
  1090.0,
  0.0,
  0.043153,
  0.054763,
  0,
  0,
  866.1366666666667,
  1090.0,
  0.0,
  0.043153,
  0.054763,
  0,
  0,
  866.1366666666667,
  1090.0,
  0.0,
  0.043153,
  0.054763,
  0,
  0,
  866.1366666666667,
  1090.0,
  0.0,
  0.043153,
  0.054763,
  604.8413333333333],
 [866.1366666666667,
  1090.0,
  0.0,
  0.043153,
  0.054763,
  604.8413333333333,
  0.125,
  866.1366666666667,
  1090.0,
  0.0,
  0.043153,
  0.054763,
  604.8413333333333,
  0.125,
  866.1366666666667,
  1090.0,
  0.0,
  0.043153,
  0.054763,
  604.8413333333333,
  0.125,
  866.1366666666667,
  1090.0,
  0.0,
  0.043153,
  0.054763,
  

## Model selection
1. Define the model.
2. Normalize the data
3. Discretize the labels
4. Plot the classification report
5. Use trustee to identify the important features
6. Plot the Trustee tree

In [42]:
#model side imports
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import torch
import pickle
import numpy as np
from os import path
from sklearn.metrics import mean_squared_error
import pandas as pd
from trustee import ClassificationTrustee
import graphviz
from sklearn import tree

In [43]:
# Constants
BATCH_SIZE = 32
NUM_EPOCHS = 500
DEVICE = torch.device('cpu')

inference = True

In [44]:
class Model:
    #Model constants
    PAST_CHUNKS = 8
    FUTURE_CHUNKS = 5
    DIM_IN = 62
    COLUMNS = [j + str(i) for i in range(PAST_CHUNKS + 1) for j in ['delivery_rate', 'cwnd', 'in_flight', 'min_rtt', 'rtt', 'size', 'trans_time']][:DIM_IN]
    DIM_OUT = 21  # BIN_MAX + 1
    DIM_H1 = 64
    DIM_H2 = 64
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-4
    BIN_SIZE = 0.5  # seconds
    BIN_MAX = 20

    def __init__(self):
        self.model = torch.nn.Sequential(
            torch.nn.Linear(Model.DIM_IN, Model.DIM_H1),
            torch.nn.ReLU(),
            torch.nn.Linear(Model.DIM_H1, Model.DIM_H2),
            torch.nn.ReLU(),
            torch.nn.Linear(Model.DIM_H2, Model.DIM_OUT),
        ).double().to(device=DEVICE)
        self.loss_fn = torch.nn.CrossEntropyLoss().to(device=DEVICE)
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=Model.LEARNING_RATE,
                                          weight_decay=Model.WEIGHT_DECAY)
        self.obs_size = None
        self.obs_mean = None
        self.obs_std = None

    def update_obs_stats(self, raw_in):
        if self.obs_size is None:
            self.obs_size = len(raw_in)
            self.obs_mean = np.mean(raw_in, axis=0)
            self.obs_std = np.std(raw_in, axis=0)
            return
        old_size = self.obs_size
        new_size = len(raw_in)
        self.obs_size = old_size + new_size
        old_mean = self.obs_mean
        new_mean = np.mean(raw_in, axis=0)
        self.obs_mean = (old_mean * old_size + new_mean * new_size) / self.obs_size
        old_std = self.obs_std
        old_sum_square = old_size * (np.square(old_std) + np.square(old_mean))
        new_sum_square = np.sum(np.square(raw_in), axis=0)
        mean_square = (old_sum_square + new_sum_square) / self.obs_size
        self.obs_std = np.sqrt(mean_square - np.square(self.obs_mean))

    def normalize_input(self, raw_in, update_obs=False):
        z = np.array(raw_in)
        if update_obs:
            self.update_obs_stats(z)
        assert self.obs_size is not None
        for col in range(len(self.obs_mean)):
            z[:, col] -= self.obs_mean[col]
            if self.obs_std[col] != 0:
                z[:, col] /= self.obs_std[col]
        return z

    def discretize_output(self, raw_out):
        z = np.array(raw_out)
        z = np.floor((z + 0.5 * Model.BIN_SIZE) / Model.BIN_SIZE).astype(int)
        return np.clip(z, 0, Model.BIN_MAX)

    def train(self, train_input, train_output, test_input, test_output):
        train_input = torch.from_numpy(self.normalize_input(train_input, update_obs=inference)).to(DEVICE)
        train_output = torch.from_numpy(self.discretize_output(train_output)).to(DEVICE)
        test_input = torch.from_numpy(self.normalize_input(test_input, update_obs=False)).to(DEVICE)
        test_output = torch.from_numpy(self.discretize_output(test_output)).to(DEVICE)

        for epoch in range(NUM_EPOCHS):
            self.model.train()
            perm = np.random.permutation(len(train_input))
            train_input = train_input[perm]
            train_output = train_output[perm]

            num_batches = int(np.ceil(len(train_input) / BATCH_SIZE))
            epoch_loss = 0

            for i in range(num_batches):
                start_idx = i * BATCH_SIZE
                end_idx = min((i + 1) * BATCH_SIZE, len(train_input))

                batch_input = train_input[start_idx:end_idx]
                batch_output = train_output[start_idx:end_idx]

                # Forward pass
                predictions = self.model(batch_input)
                loss = self.loss_fn(predictions, batch_output)

                # Backward pass
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()

            print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Loss: {epoch_loss / num_batches}")

            # Evaluate after each epoch
            self.evaluate(test_input, test_output)

    def load(self, model_path):
        checkpoint = torch.load(model_path)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        self.obs_size = checkpoint['obs_size']
        self.obs_mean = checkpoint['obs_mean']
        self.obs_std = checkpoint['obs_std']

    def save(self, model_path):
        assert (self.obs_size is not None)

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'obs_size': self.obs_size,
            'obs_mean': self.obs_mean,
            'obs_std': self.obs_std,
        }, model_path)

    def predict(self, x):
        with torch.no_grad():
            x = x.to_numpy()
            x = self.normalize_input(x, update_obs=inference)
            x = torch.from_numpy(x).to(DEVICE)
            y_scores = self.model(x)
            y_predicted = torch.max(y_scores, 1)[1].to(device=DEVICE)
            ret = y_predicted.detach().cpu().numpy()
            return ret
    def predict_discrete(self, x):
        with torch.no_grad():
            x = x.to_numpy()
            x = self.normalize_input(x, update_obs=inference)
            x = torch.from_numpy(x).to(DEVICE)
            y_scores = self.model(x)
            y_predicted = torch.max(y_scores, 1)[1].to(device=DEVICE)
            ret = y_predicted.detach().cpu().numpy()
            return y_scores, ret

    def predict_cont(self, x):
        with torch.no_grad():
            x = x.to_numpy()
            x = self.normalize_input(x, update_obs=inference)
            x = torch.from_numpy(x).to(DEVICE)
            y_scores = self.model(x)
            y_predicted = torch.max(y_scores, 1)[1].to(device=DEVICE)
            ret = y_predicted.double().numpy()
            for i in range(len(ret)):
                bin_id = ret[i]
                if bin_id == 0:  # the first bin is defined differently
                    ret[i] = 0.25 * Model.BIN_SIZE
                else:
                    ret[i] = bin_id * Model.BIN_SIZE
            return ret

    def evaluate_with_trustee(self, test_input, test_output):
        self.model.eval()
        pd_input = pd.DataFrame(test_input, columns=self.COLUMNS)
        test_output_discretized = self.discretize_output(test_output)
        with torch.no_grad():
            predictions_prob, class_preds = self.predict_discrete(pd_input)
            # Cross-Entropy Loss
            cross_entropy_loss = self.loss_fn(predictions_prob, torch.from_numpy(test_output_discretized)).item()
            # Print metrics
            print(f"Test Cross-Entropy Loss: {cross_entropy_loss}")
            print("Classification Report:")
            print(classification_report(test_output_discretized, class_preds, zero_division=0))
            trustee = ClassificationTrustee(expert=model)
            trustee.fit(pd_input, test_output_discretized, num_iter=10, num_stability_iter=2, samples_size=0.3, verbose=True)
            dt, pruned_dt, agreement, reward = trustee.explain()
            dt_y_pred = dt.predict(pd_input)
            print("Model explanation global fidelity report:")
            print(classification_report(class_preds, dt_y_pred))
            dot_data = tree.export_graphviz(pruned_dt, class_names=[str(i)for i in range(21)], feature_names=model.COLUMNS,filled=True,rounded=True,special_characters=True)
            graph = graphviz.Source(dot_data)
            fil = graph.render("~/trustee_tree_puffer_pruned", format="png")

    def evaluate(self, test_input, test_output):
        self.model.eval()
        pd_input = pd.DataFrame(test_input, columns=self.COLUMNS)
        with torch.no_grad():
            predictions = self.predict_cont(pd_input)
            # Print metrics

            # Mean Squared Error
            mse_loss = mean_squared_error(test_output, predictions)
            print(f"Test Mean Squared Error: {mse_loss}")


In [45]:
with open('/mnt/md0/satya/output.pkl', 'rb') as f:
    processed_data = pickle.load(f)
model = Model()

for i, chunk_data in enumerate(processed_data):
    input_data = np.array(chunk_data['in'])
    output_data = np.array(chunk_data['out'])
    if not inference:
        print(f"Training model for future chunk {i + 1}")
        # Train-test split (70% train, 30% test)
        train_input, test_input, train_output, test_output = train_test_split(
            input_data, output_data, test_size=0.3, random_state=42
        )
        model.train(train_input, train_output, test_input, test_output)
    else:
        model.load("/mnt/md0/cs190n/py-0-checkpoint-200.pt")
        # the model is available at https://storage.googleapis.com/puffer-models/puffer-ttp/bbr-20221001-1.tar.gz
        model.evaluate(input_data, output_data)

  checkpoint = torch.load(model_path)


Test Mean Squared Error: 0.6606137356852603
Test Mean Squared Error: 0.6137877031558187


  checkpoint = torch.load(model_path)


Test Mean Squared Error: 0.5599636530967992
Test Mean Squared Error: 0.5114938987398462


  checkpoint = torch.load(model_path)
  checkpoint = torch.load(model_path)


Test Mean Squared Error: 0.46656162041617716


  checkpoint = torch.load(model_path)


In [None]:
model.evaluate_with_trustee(input_data, output_data)

![title](trustee_tree_puffer_pruned.png)