In [None]:
# Installs and imports

!pip install --upgrade gym
!pip install scikit-learn
!pip install pandas
!pip install matplotlib
!pip install numpy
!pip install shimmy
!pip install d3rlpy

In [None]:
import os
import pandas as pd
import tqdm
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
from gym import spaces, Env
import gc
import zipfile
import torch as th
from concurrent.futures import ThreadPoolExecutor
import threading
from collections import defaultdict, Counter
import torch
import d3rlpy
from d3rlpy.dataset import MDPDataset

Preprocessing EdNet: parse the per-student KT4 CSV files into a single DataFrame, split it into sessions, and save train/validation/test splits for further processing.

In [None]:

extract_to = "C:\\Columbia\\RL\\KT4"

# Every per-student KT4 log is a CSV file directly inside `extract_to`.
all_files = [os.path.join(extract_to, f) for f in os.listdir(extract_to) if f.endswith(".csv")]
print(f"Found {len(all_files)} student CSV files")

# Load questions and payments data
questions_file = 'C:\\Columbia\\RL\\EdNet-Contents\\contents\\questions.csv'
payments_file = 'C:\\Columbia\\RL\\EdNet-Contents\\contents\\payments.csv'
# Rename question_id -> item_id so questions can be joined onto the interaction logs.
questions = pd.read_csv(questions_file).rename(columns={'question_id': 'item_id'})
payments = pd.read_csv(payments_file)

# Load all files into a single DataFrame
print("Loading all files into a single DataFrame...")
frames = []
for file in tqdm.tqdm(all_files, desc="Reading Files"):
    student_df = pd.read_csv(file)
    # The file name without directory and extension is used as the user id. basename/splitext
    # work for both '/' and Windows '\\' separators produced by os.path.join.
    student_df['user_id'] = os.path.splitext(os.path.basename(file))[0]
    # Filter students who don't have full access to the platform: drop logs that have an
    # action_type column but no "pay" action in it.
    if "action_type" in student_df.columns and "pay" not in student_df["action_type"].values:
        continue
    frames.append(student_df)

if not frames:
    raise ValueError(f"No student CSV files with a 'pay' action were found in {extract_to!r}")
all_data = pd.concat(frames, ignore_index=True)

# Join with questions.csv so every question row carries its correct_answer (left join keeps
# non-question rows, which get NaN question metadata).
print("Joining with questions.csv...")
all_data = pd.merge(all_data, questions, on='item_id', how='left')
# Drop a leftover CSV index column if one was read in; errors="ignore" keeps this line
# working when the column is absent instead of raising KeyError.
all_data.drop(columns="Unnamed: 0", errors="ignore").to_csv("C:\\Columbia\\RL\\EdNet_KT4_PaidOnly.csv", index=False)

# Sessions are kept only if at least one of their rows comes from one of these sources.
organic_sources = ['diagnosis', 'sprint', 'review', 'in_review']
maxTimeOnOneInteraction = 300000  # 5 minutes (timestamps presumably in milliseconds)

# A new session starts whenever the gap to the user's previous row exceeds
# maxTimeOnOneInteraction.
# NOTE(review): assumes rows are already chronological within each user (each per-student
# file is read in its stored order) — confirm.
# The cumulative sum is not reset per user, so ids are only meaningful together with user_id,
# which is how sessions are keyed everywhere below.
all_data['session_id'] = (
    all_data.groupby('user_id')['timestamp'].diff().fillna(0).gt(maxTimeOnOneInteraction).cumsum()
)
grouped = all_data.groupby(['user_id', 'session_id'])
# .copy() so the column assignment below works on an independent frame (no SettingWithCopyWarning).
valid_sessions = grouped.filter(lambda x: x['source'].isin(organic_sources).any()).copy()

# Calculate correctness for question rows: 1 when the item is a question ('q...') and the
# user's answer matches correct_answer, else 0. Vectorized equivalent of the row-wise check.
print("Calculating correctness...")
is_question_row = valid_sessions['item_id'].astype(str).str.startswith('q')
answered_correctly = valid_sessions['user_answer'] == valid_sessions['correct_answer']
valid_sessions['correct'] = (is_question_row & answered_correctly).astype(int)

# Fill missing values
valid_sessions = valid_sessions.fillna({
    'cursor_time': 0,
    'correct': 0,
    'action_type': 'unknown'
})
print(len(valid_sessions[["user_id", "session_id"]].drop_duplicates()))

# Split at the session level, 70% train / 15% validation / 15% test, so all rows of a
# session end up in the same split.
# NOTE(review): sessions of the same user can land in different splits — confirm this
# user-level overlap is acceptable.
SPLIT_SEED = 101
unique_sessions = valid_sessions[['user_id', 'session_id']].drop_duplicates()
train_sessions, temp_sessions = train_test_split(
    unique_sessions,
    test_size=0.3,
    random_state=SPLIT_SEED
)
# Halve the held-out 30% into validation and test.
val_sessions, test_sessions = train_test_split(
    temp_sessions,
    test_size=0.5,
    random_state=SPLIT_SEED
)

# Inner merges pull back every row that belongs to each split's sessions.
train_data = valid_sessions.merge(train_sessions, on=['user_id', 'session_id'])
val_data = valid_sessions.merge(val_sessions, on=['user_id', 'session_id'])
test_data = valid_sessions.merge(test_sessions, on=['user_id', 'session_id'])

# Persist only the columns the event-extraction step reads.
required_cols = ['user_id', 'session_id', 'timestamp', 'item_id', 'action_type', 'source', 'correct']
for split_df, out_path in [
    (train_data, "C:\\Columbia\\RL\\EdNet_KT4_Train_Main2.csv"),
    (val_data, "C:\\Columbia\\RL\\EdNet_KT4_Val_Main2.csv"),
    (test_data, "C:\\Columbia\\RL\\EdNet_KT4_Test_Main2.csv"),
]:
    split_df[required_cols].to_csv(out_path, index=False)

A few hyperparameters

In [None]:
alpha = 1.0  # weight of the question-correctness reward (multiplies it in compute_reward)
beta = 0.0   # weight of the explanation-engagement reward; 0.0 switches that term off
# Ordered to match the action ids used for training (positions in event_types):
# 0 = question, 1 = textual explanation, 2 = video explanation.
action_labels = ["recommend_question", "recommend_explanation", "recommend_video"]
action_dim = len(action_labels)

max_time = 300000  # normalization factor (appears unused by the cells below)

Now load the train/validation/test splits and reshape them into the form the offline RL algorithms expect.

In [None]:
# Split files written by the preprocessing cell, read from a Drive folder (presumably the
# Colab Google Drive mount).
train_file, val_file, test_file = (
    "/content/drive/MyDrive/RL/EdNet_KT4_Train_Main2.csv",
    "/content/drive/MyDrive/RL/EdNet_KT4_Val_Main2.csv",
    "/content/drive/MyDrive/RL/EdNet_KT4_Test_Main2.csv",
)

train_df, val_df, test_df = map(pd.read_csv, (train_file, val_file, test_file))

In [None]:
# Actions are high-level content types. Each action id is the position of the event type
# in `event_types` (question, textual_explanation, video_explanation):
# 0: question
# 1: textual explanation
# 2: video explanation
# Event types are derived heuristically from action_type / item_id below; adjust the mapping
# if other signals in the data identify the content type better.

# Input columns (as saved by the preprocessing cell):
# user_id, session_id, timestamp, item_id, action_type, source, correct

############################################
# Step 2: Map Interactions to Actions
############################################
# The functions below collapse each session's raw rows into question / textual-explanation /
# video-explanation events. This is a heuristic: the logs record what the student did,
# not which content type was actually recommended.

def is_question(item_id):
    """Return True when `item_id` is a string starting with 'q' (a question item).

    Non-string ids (None, NaN from pandas, numbers) are never questions.
    """
    if not isinstance(item_id, str):
        return False
    return item_id.startswith('q')

def is_explanation(item_id):
    """Return True when `item_id` is a string starting with 'e' (an explanation item).

    Non-string ids (None, NaN from pandas, numbers) are never explanations.
    """
    if not isinstance(item_id, str):
        return False
    return item_id.startswith('e')

def _make_event(ev, event_type, engagement_time, correct=None, source=None):
    """Build one output event record carrying the user/session ids of raw row `ev`."""
    return {
        'user_id': ev['user_id'],
        'session_id': ev['session_id'],
        'event_type': event_type,
        'engagement_time': engagement_time,
        'correct': correct,
        'source': source,
    }


def process_session(session_df):
    """Collapse one session's raw interaction rows into higher-level events.

    Three kinds of records are emitted (keys: user_id, session_id, event_type,
    engagement_time, correct, source):

    * 'textual_explanation': from an 'enter' on an explanation item ('e...') to the next
      'quit' on an explanation item, or to the next 'submit' if no quit came first.
      engagement_time is the timestamp difference; correct/source are None.
    * 'video_explanation': from 'play_audio'/'play_video' to 'pause_audio'/'pause_video',
      or to the next 'submit' if no pause came first. Same fields as above.
    * 'question': one per 'submit' that was preceded (since the previous submit) by a
      'respond' on a question item ('q...'); correct and source are taken from the last
      such respond row, engagement_time is 0.

    Explanation blocks still open when the session ends are discarded.

    Parameters
    ----------
    session_df : pandas.DataFrame
        Rows of one (user_id, session_id) session in chronological order. Must contain
        'user_id', 'session_id', 'action_type', 'timestamp'; 'item_id', 'source' and
        'correct' are read with .get() and may be missing.

    Returns
    -------
    list of dict
        Events in the order in which they were closed.
    """
    records = session_df.to_dict('records')
    events_list = []

    textual_start = None   # (timestamp, item_id) of the open textual explanation, if any
    video_start = None     # (timestamp, item_id) of the open audio/video playback, if any
    last_respond = None    # (timestamp, row) of the latest question respond since the last submit

    for ev in records:
        at = ev['action_type']
        ts = ev['timestamp']
        item = ev.get('item_id', None)

        # Textual explanation: 'enter' (re)starts the timer — a second 'enter' before a
        # 'quit' overwrites the earlier start — and 'quit' closes it.
        if is_explanation(item):
            if at == 'enter':
                textual_start = (ts, item)
            elif at == 'quit' and textual_start is not None:
                events_list.append(_make_event(ev, 'textual_explanation', ts - textual_start[0]))
                textual_start = None

        # Audio/video: only the first play of an open block sets the start time.
        if at in ['play_audio', 'play_video'] and video_start is None:
            video_start = (ts, item)
        elif at in ['pause_audio', 'pause_video'] and video_start is not None:
            events_list.append(_make_event(ev, 'video_explanation', ts - video_start[0]))
            video_start = None

        # Remember only the latest question respond before the next submit.
        # NOTE(review): if several questions are answered before one submit, only the last
        # one becomes a 'question' event — confirm this is intended.
        if at == 'respond' and is_question(item):
            last_respond = (ts, ev)

        if at == 'submit':
            # A submit closes any explanation block that never received its quit/pause.
            if textual_start is not None:
                events_list.append(_make_event(ev, 'textual_explanation', ts - textual_start[0]))
                textual_start = None
            if video_start is not None:
                events_list.append(_make_event(ev, 'video_explanation', ts - video_start[0]))
                video_start = None

            # Correctness and source come from the respond row, not the submit row.
            if last_respond is not None:
                answer_row = last_respond[1]
                events_list.append(_make_event(
                    ev, 'question', 0,
                    correct=answer_row.get('correct', None),
                    source=answer_row.get('source', None),
                ))
            last_respond = None

    return events_list

In [None]:
# Cap on training sessions (presumably to bound preprocessing time); None processes all.
MAX_TRAIN_SESSIONS = 100_000
EVENT_COLUMNS = ['user_id', 'session_id', 'event_type', 'engagement_time', 'correct', 'source']


def _events_from_sessions(df, max_sessions=None, desc="Processing sessions"):
    """Run process_session over every (user_id, session_id) group of `df` (at most
    `max_sessions` groups) and return the concatenated events as a DataFrame."""
    all_events = []
    grouped = df.groupby(['user_id', 'session_id'], as_index=False)
    for n_done, (_, grp) in enumerate(tqdm.tqdm(grouped, desc=desc)):
        if max_sessions is not None and n_done >= max_sessions:
            break
        # Stable sort: rows that share a timestamp keep their stored order, which the
        # enter/respond/submit state machine in process_session depends on.
        all_events.extend(process_session(grp.sort_values('timestamp', kind='mergesort')))
    # Explicit columns keep the schema even when no events were produced.
    return pd.DataFrame(all_events, columns=EVENT_COLUMNS)


final_train_df = _events_from_sessions(train_df, max_sessions=MAX_TRAIN_SESSIONS)
final_val_df = _events_from_sessions(val_df)
final_test_df = _events_from_sessions(test_df)

In [None]:
# Event vocabulary; an action id is the position of its event type in this list.
event_types = ["question", "textual_explanation", "video_explanation"]
event_type_to_idx = dict((name, position) for position, name in enumerate(event_types))

# Source vocabulary comes from the training events only (first-appearance order); sources
# not seen here are encoded as an all-zero block by featurize_event.
if 'source' in final_train_df.columns:
    sources = final_train_df['source'].dropna().unique().tolist()
else:
    sources = []
source_to_idx = dict((name, position) for position, name in enumerate(sources))

state_length = 5            # number of past events stacked into one state
max_engagement = 30000.0    # Arbitrary normalization factor for engagement time

def one_hot_encode(idx, size):
    """Return a float64 vector of length `size` holding 1.0 at position `idx`, 0.0 elsewhere.

    `idx` follows normal numpy indexing (negative values count from the end; out of range
    raises IndexError).
    """
    encoded = np.zeros(size)
    encoded[idx] = 1.0
    return encoded

def featurize_event(row):
    """Encode one event record as a flat feature vector.

    Layout: [event_type one-hot (len(event_types)) | correct | engagement | source one-hot (len(sources))]

    * correct: 0.0 when missing (explanation events carry None/NaN), else float(correct).
    * engagement: engagement_time / max_engagement, capped at 1.0; 0.0 when missing.
    * source: one-hot over `sources`; an all-zero block when the source is missing or unseen.
    """
    event_part = one_hot_encode(event_type_to_idx[row['event_type']], len(event_types))

    correct_value = 0.0 if pd.isnull(row['correct']) else float(row['correct'])

    raw_engagement = row['engagement_time']
    if pd.isnull(raw_engagement):
        engagement_value = 0.0
    else:
        engagement_value = float(raw_engagement) / max_engagement
    engagement_value = min(engagement_value, 1.0)  # clip

    source_key = row['source']
    if source_key not in source_to_idx:
        source_part = np.zeros(len(sources))
    else:
        source_part = one_hot_encode(source_to_idx[source_key], len(sources))

    return np.concatenate([event_part, [correct_value, engagement_value], source_part])

def featurize_state(events, state_length=5):
    """Stack the features of the most recent `state_length` events into one flat vector.

    events: sequence of event records (dicts), oldest first.
    Shorter histories are left-padded with zero vectors; longer ones keep only the last
    `state_length` events. The result has length state_length * per-event feature width.
    """
    encoded = [featurize_event(event) for event in events]

    if encoded:
        width = len(encoded[0])
    else:
        # No history: featurize a placeholder event just to learn the per-event width.
        placeholder = {"event_type": "question", "correct": 0, "engagement_time": 0, "source": None}
        width = len(featurize_event(placeholder))

    missing = state_length - len(encoded)
    if missing > 0:
        encoded = [np.zeros(width) for _ in range(missing)] + encoded
    else:
        encoded = encoded[-state_length:]

    return np.concatenate(encoded)

In [None]:
# Same event vocabulary as the featurization cell (the two must stay in sync); declared
# again so this cell can run on its own. Action id = position in the list.
event_types = "question textual_explanation video_explanation".split()
event_type_to_idx = dict(zip(event_types, range(len(event_types))))

def compute_engagement_stats(df):
    """Return (textual_mean, textual_std, video_mean, video_std) of explanation engagement times.

    Statistics use the non-null 'engagement_time' values of rows whose 'event_type' is
    'textual_explanation' or 'video_explanation' respectively (intended for the training set).

    Fallbacks keep the later z-score (eng - mean) / std finite:
    * no samples -> mean 0.0, std 1.0
    * std of 0, or undefined std (a single sample gives NaN with pandas' default ddof=1) -> std 1.0
    """
    def _mean_std(times):
        if len(times) == 0:
            return 0.0, 1.0
        mean, std = times.mean(), times.std()
        # NaN != 0, so the single-sample case needs its own check.
        if pd.isna(std) or std == 0:
            std = 1.0
        return mean, std

    textual_times = df.loc[df['event_type'] == 'textual_explanation', 'engagement_time'].dropna()
    video_times = df.loc[df['event_type'] == 'video_explanation', 'engagement_time'].dropna()

    textual_mean, textual_std = _mean_std(textual_times)
    video_mean, video_std = _mean_std(video_times)
    return textual_mean, textual_std, video_mean, video_std

def compute_reward(row, textual_mean, textual_std, video_mean, video_std, alpha = 0.5, beta = 0.5):
    """Reward for a single event record.

    * 'question': alpha * 1.0 when row['correct'] == 1, else 0.0.
    * 'textual_explanation' / 'video_explanation': beta * clip(z, 0, 1), where
      z = (engagement_time - mean) / std with the matching textual/video statistics, so
      engagement at or below the mean earns nothing and one std above the mean earns beta.
      A missing (None/NaN) engagement_time earns 0.0 instead of propagating NaN into the reward.
    * any other event type: 0.0.
    """
    et = row['event_type']
    if et == 'question':
        return alpha * (1.0 if row.get('correct', 0) == 1 else 0.0)

    if et == 'textual_explanation':
        mean, std = textual_mean, textual_std
    elif et == 'video_explanation':
        mean, std = video_mean, video_std
    else:
        # Should not happen
        return 0.0

    eng = row.get('engagement_time', 0.0)
    if pd.isna(eng):
        return 0.0
    normalized = (eng - mean) / std
    return beta * min(max(normalized, 0.0), 1.0)

def build_dataset(df, textual_mean, textual_std, video_mean, video_std, state_length=5, alpha = 0.5, beta = 0.5):
    """Convert an event table into offline-RL transitions.

    Parameters
    ----------
    df : pandas.DataFrame
        Events with columns user_id, session_id, event_type, engagement_time, correct, source.
        Rows are used in their existing order within each session (assumed chronological).
    textual_mean, textual_std, video_mean, video_std : float
        Engagement statistics, e.g. from compute_engagement_stats on the training events.
    state_length : int
        Number of past events stacked into a state.
    alpha, beta : float
        Reward weights forwarded to compute_reward.

    Returns
    -------
    list of (state, action, reward, next_state, done), one per event i of every session:
        state      = featurize_state of up to `state_length` events before i (zero-padded)
        action     = event_type_to_idx of event i
        reward     = compute_reward of event i
        next_state = featurize_state of the window ending with (and including) event i
        done       = True for the final event of the session
    """
    transitions = []

    grouped = df.groupby(['user_id', 'session_id'], as_index=False)
    for _, grp in tqdm.tqdm(grouped, desc="Building dataset"):
        events = grp.to_dict('records')
        last = len(events) - 1

        # Every event — including the session's last one — yields a transition, so the final
        # event's reward is kept and `done` marks the true end of the session.
        for i, current_event in enumerate(events):
            s = featurize_state(events[max(0, i - state_length):i], state_length=state_length)
            a = event_type_to_idx[current_event['event_type']]
            r = compute_reward(current_event, textual_mean, textual_std, video_mean, video_std, alpha=alpha, beta=beta)
            s_next = featurize_state(events[max(0, (i + 1) - state_length):(i + 1)], state_length=state_length)
            transitions.append((s, a, r, s_next, i == last))

    return transitions


textual_mean, textual_std, video_mean, video_std = compute_engagement_stats(final_train_df)

Build the train, validation, and test transition datasets

In [None]:
# Same training-set engagement statistics and reward weights for every split.
train_data, val_data, test_data = (
    build_dataset(split_events, textual_mean, textual_std, video_mean, video_std,
                  state_length=5, alpha=alpha, beta=beta)
    for split_events in (final_train_df, final_val_df, final_test_df)
)

In [None]:
def to_mdpdataset(transitions):
    """Pack (state, action, reward, next_state, done) tuples into a d3rlpy MDPDataset.

    Observations/rewards become float32, actions int64, terminals bool. next_state
    (position 3) is not passed on — presumably the dataset derives next observations
    from consecutive rows within an episode; confirm against the d3rlpy docs.
    """
    def _column(position, dtype):
        return np.array([step[position] for step in transitions], dtype=dtype)

    return MDPDataset(
        observations=_column(0, np.float32),
        actions=_column(1, np.int64),
        rewards=_column(2, np.float32),
        terminals=_column(4, np.bool_),
    )

In [None]:
train_dataset, val_dataset, test_dataset = map(to_mdpdataset, (train_data, val_data, test_data))

Functions to evaluate and visualize the trained policies

In [None]:
# Action ids are positions in event_types, so the event type names double as action labels.
event_types = "question textual_explanation video_explanation".split()
action_labels = event_types  # same list object, not a copy

def evaluate_bc_policy(bc_policy, dataset, name="BC Policy", trajectory_length=50):
    """Print and return how often `bc_policy` picks the logged action on one sampled trajectory.

    BC has no Q-values, so the metric is plain action-match accuracy over the steps of a
    trajectory of up to `trajectory_length` steps drawn with dataset.sample_trajectory.

    Returns
    -------
    float
        Fraction of steps where the predicted action equals the logged one (0.0 when the
        sampled trajectory is empty).
    """
    traj = dataset.sample_trajectory(trajectory_length)
    observations = np.asarray(traj.observations)
    # Logged actions may come as (N, 1); flatten to line up with the predictions.
    logged_actions = np.asarray(traj.actions).reshape(-1)

    if len(observations) == 0:
        accuracy = 0.0
    else:
        # One batched predict call instead of one model call per step.
        predicted = np.asarray(bc_policy.predict(observations)).reshape(-1)
        accuracy = float(np.mean(predicted == logged_actions))

    print(f"{name} Action Match Accuracy: {accuracy:.4f}")
    return accuracy


def visualize_bc_policy(bc_policy, dataset, action_labels, num_states=5, trajectory_length=5):
    """Spot-check a BC policy on randomly chosen states.

    For each of `num_states` sampled trajectories (up to `trajectory_length` steps), one
    random step is picked. The logged action and the policy's predicted action are printed,
    and a bar chart marks the predicted action with 1 (others 0). BC has no Q-values, so the
    chart only shows which action was chosen.
    """
    for _ in range(num_states):
        sampled = dataset.sample_trajectory(trajectory_length)
        obs_seq, act_seq = sampled.observations, sampled.actions

        if not len(obs_seq):
            print("Empty trajectory, skipping.")
            continue

        step = np.random.randint(len(obs_seq))
        predicted = bc_policy.predict(obs_seq[step][np.newaxis, :])[0]
        logged = act_seq[step][0]

        print(f"True Next Action: {action_labels[logged]}, Predicted Next Action: {action_labels[predicted]}")

        indicator = np.zeros(len(action_labels))
        indicator[predicted] = 1.0

        plt.figure(figsize=(6,4))
        plt.bar(range(len(action_labels)), indicator, tick_label=action_labels)
        plt.title("Chosen Action for a Sampled State")
        plt.xlabel("Action")
        plt.ylabel("Selection Indicator")
        plt.show()


def evaluate_q_policy(policy, dataset, name="Policy", trajectory_length=50, numSamplings = 100):
    """Print and return the mean Q(s, a) over logged (state, action) pairs.

    Draws `numSamplings` trajectories (up to `trajectory_length` steps each) with
    dataset.sample_trajectory and scores every step with policy.predict_value. Note this
    averages Q over the dataset's actions, not over the policy's own greedy actions.

    Empty trajectories are skipped. If every sample is empty, prints "N/A (no data)" and
    returns None; otherwise returns the mean as a float.
    """
    q_vals = []
    for _ in range(numSamplings):
        traj = dataset.sample_trajectory(trajectory_length)
        observations = np.asarray(traj.observations)
        if len(observations) == 0:
            continue
        actions = np.asarray(traj.actions).reshape(-1)
        # One batched call per trajectory; row k pairs observations[k] with actions[k].
        q = np.asarray(policy.predict_value(observations, actions)).reshape(-1)
        q_vals.extend(q.tolist())

    if not q_vals:
        print(f"{name} Average Q-value: N/A (no data)")
        return None

    mean_q = float(np.mean(q_vals))
    print(f"{name} Average Q-value: {mean_q:.4f}")
    return mean_q


def decode_event_vector(event_vec):
    """Turn one per-event feature slice (as built by featurize_event) back into a readable dict.

    Expected layout:
        [event_type one-hot (len(event_types)) | correct | normalized engagement | source one-hot (len(sources))]

    An all-zero event_type block (the left padding added by featurize_state) decodes as
    "PAD"; an all-zero or empty source block decodes as "None".
    """
    n_types = len(event_types)
    n_sources = len(sources)

    type_block = event_vec[:n_types]
    correct_raw = event_vec[n_types]
    engagement = event_vec[n_types + 1]
    source_block = event_vec[n_types + 2:n_types + 2 + n_sources]

    type_index = np.argmax(type_block) if np.max(type_block) > 0 else None
    type_name = "PAD" if type_index is None else event_types[type_index]

    source_name = "None"
    if n_sources > 0 and np.max(source_block) > 0:
        source_name = sources[np.argmax(source_block)]

    return {
        "event_type": type_name,
        "correct": int(round(correct_raw)),
        "engagement_norm": engagement,
        "source": source_name
    }

def decode_state(state_vec, state_length=5):
    """Split a flat state vector (as built by featurize_state) into `state_length` decoded events.

    Events come back in the order they were stacked (oldest first; padding slots decode
    as event_type "PAD"). Raises AssertionError if the vector length does not equal
    state_length * per-event width.
    """
    width = len(event_types) + 1 + 1 + len(sources)  # event_type + correct + eng + sources
    assert len(state_vec) == state_length * width, "State vector length mismatch."
    return [decode_event_vector(state_vec[k * width:(k + 1) * width]) for k in range(state_length)]

def visualize_q_policy(policy, dataset, action_labels, num_states=5, trajectory_length=5, state_length=5):
    """Show decoded histories, logged vs. greedy actions, and Q-values for a few sampled states.

    For each of `num_states` trajectories sampled from `dataset` (up to `trajectory_length`
    steps), step index 4 (or the last step of a shorter trajectory) is inspected:
      * its state is decoded with decode_state and printed,
      * Q(s, a) is computed for every action in `action_labels`; the argmax is reported next
        to the action actually taken in the data,
      * the decoded states and executed actions of all steps up to it are printed,
      * the Q-values are plotted as a bar chart.

    Parameters
    ----------
    policy : trained Q-learning algorithm exposing predict_value(observations, actions)
    dataset : object exposing sample_trajectory(length)
    action_labels : list of str indexed by action id
    num_states, trajectory_length : int
    state_length : int, default 5
        Events stacked per observation; must match how the states were built.
    """
    n_actions = len(action_labels)

    def _label(action):
        # Generic name for ids outside action_labels.
        return action_labels[action] if action < n_actions else f"Action_{action}"

    for _ in range(num_states):
        traj = dataset.sample_trajectory(trajectory_length)
        observations = traj.observations
        actions = traj.actions

        if len(observations) == 0:
            print("Empty trajectory, skipping.")
            continue

        idx = min(4, len(observations) - 1)  # Hardcode for now
        current_state = observations[idx]

        # predict_value pairs observations and actions row by row, so repeat the single
        # state once per action to obtain Q(s, a) for every action in one call.
        obs_batch = np.repeat(current_state[np.newaxis, :], n_actions, axis=0)
        q_vals = np.asarray(policy.predict_value(obs_batch, np.arange(n_actions))).reshape(-1)

        true_action = actions[idx][0]
        predicted_action = np.argmax(q_vals)

        decoded_events = decode_state(current_state, state_length=state_length)
        print("---- PREDICTION POINT ----")
        print(f"State (last {state_length} events):")
        for i, ev in enumerate(decoded_events, start=1):
            print(f"  Step {i}: Event_Type={ev['event_type']}, Correct={ev['correct']}, "
                  f"Engagement_Norm={ev['engagement_norm']:.2f}, Source={ev['source']}")

        print(f"True Next Action: {_label(true_action)}, Predicted Next Action: {_label(predicted_action)}")

        # Decoded state and executed action at every step up to the prediction point.
        print("\nEntire trajectory up to this point:")
        for t in range(idx + 1):
            evs = decode_state(observations[t], state_length=state_length)
            print(f"Time {t}:")
            for e_id, ev_info in enumerate(evs, start=1):
                print(f"  Step {e_id}: {ev_info}")
            print(f"  Executed Action: {_label(actions[t][0])}\n")

        plt.figure(figsize=(6,4))
        plt.bar(range(n_actions), q_vals, tick_label=action_labels)
        plt.title("Q-values for the sampled state")
        plt.xlabel("Action")
        plt.ylabel("Predicted Q-value")
        plt.show()
        print("="*50 + "\n")
        print("="*50 + "\n")

Train the offline RL algorithms (BC, CQL, BCQ, SAC) and evaluate each on the test set

In [None]:
############################################
# Step 5: Train Offline RL Algorithms
############################################
# GPU when PyTorch can see one, otherwise CPU; this value is passed to every .create() call.
# (device=True would force cuda:0 and fail on CPU-only machines.)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
enable_ddp = False  # single-process training; not passed to create() here
bc = d3rlpy.algos.DiscreteBCConfig(batch_size=1024, learning_rate=0.01).create(device=device)
print("Training BC...")
bc.fit(train_dataset, n_steps=50000, logging_steps=1000)  # increase n_steps as needed

In [None]:
evaluate_bc_policy(bc, test_dataset, name="BC")
print("Visualizing BC policy:")
visualize_bc_policy(bc, test_dataset, action_labels=action_labels)

In [None]:
# `device` is set in the BC training cell (GPU if available, else CPU).
cql = d3rlpy.algos.DiscreteCQLConfig(batch_size=1024).create(device=device)
print("Training CQL...")
cql.fit(train_dataset, n_steps=50000, logging_steps=1000)

In [None]:
evaluate_q_policy(cql, test_dataset, name="CQL")
print("Visualizing CQL policy:")
visualize_q_policy(cql, test_dataset, action_labels)

In [None]:
# `device` is set in the BC training cell (GPU if available, else CPU).
bcq = d3rlpy.algos.DiscreteBCQConfig(batch_size=1024).create(device=device)
print("Training BCQ...")
bcq.fit(train_dataset, n_steps=50000, logging_steps=1000)


In [None]:
evaluate_q_policy(bcq, test_dataset, name="BCQ")
print("Visualizing BCQ policy:")
visualize_q_policy(bcq, test_dataset, action_labels)

In [None]:
# `device` is set in the BC training cell (GPU if available, else CPU).
sac = d3rlpy.algos.DiscreteSACConfig(batch_size=1024).create(device=device)
print("Training SAC...")
sac.fit(train_dataset, n_steps=50000, logging_steps=1000)

In [None]:
evaluate_q_policy(sac, test_dataset, name="SAC")
print("Visualizing SAC policy:")
visualize_q_policy(sac, test_dataset, action_labels)