In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import BertModel, BertConfig, AutoModel, AutoConfig
from tqdm import tqdm
import numpy as np
import h5py

# ==== Configuration ====
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
EPOCHS = 1
PATIENCE = 2  # epochs without validation-loss improvement before early stopping
SEED = 42

# Seed the global RNGs so weight initialisation and DataLoader shuffling are
# reproducible under Restart & Run All.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [5]:
# ==== Paths ====
# Both locations are relative to the notebook's working directory.
DATA_PATH, BEST_MODEL_PATH = (
    "../data/processed/concatenated_data.hdf5",
    "../models/best_bert_udrl.pth",
)


# ==== Data Loading ====
def load_data(path=DATA_PATH):
    """Read the concatenated training arrays from an HDF5 file.

    Args:
        path: HDF5 file containing a ``concatenated_data`` group with the
            datasets ``observations``, ``actions``, ``rewards_to_go`` and
            ``time_to_go``.

    Returns:
        Tuple ``(states, rewards, times, actions)`` of in-memory arrays.
        ``rewards`` and ``times`` are reshaped into column vectors ``(N, 1)``.
    """
    with h5py.File(path, "r") as hdf:
        group = hdf["concatenated_data"]

        def column(key):
            # Scalar-per-sample datasets become (N, 1) to feed nn.Linear(1, ...).
            return group[key][:].reshape(-1, 1)

        states = group["observations"][:]
        actions = group["actions"][:]
        rewards = column("rewards_to_go")
        times = column("time_to_go")
    return states, rewards, times, actions


In [6]:
# Load the HDF5 arrays and wrap them as tensors in one idempotent step.
# torch.as_tensor shares memory with the NumPy arrays where possible instead of
# copying the whole dataset (torch.tensor always copies, doubling peak memory).
X_s, X_r, X_t, y = (torch.as_tensor(arr) for arr in load_data())

In [7]:
dataset = TensorDataset(*(t.float() for t in (X_s, X_r, X_t, y)))

# 80 / 10 / 10 train/val/test split; the test split absorbs any rounding remainder.
n_samples = len(dataset)
n_train, n_val = int(n_samples * 0.8), int(n_samples * 0.1)
lengths = [n_train, n_val, n_samples - n_train - n_val]

In [None]:
# A dedicated, fixed-seed generator keeps the train/val/test partition reproducible.
split_rng = torch.Generator().manual_seed(42)
train_ds, val_ds, test_ds = random_split(dataset, lengths, generator=split_rng)

In [None]:
# Only the training split is shuffled; validation/test keep a fixed order.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=is_train)
    for split, is_train in ((train_ds, True), (val_ds, False), (test_ds, False))
)

In [10]:
# Build an untrained (randomly initialised) BERT-small from its config only.
# Inputs are supplied via `inputs_embeds`, so the word-embedding table is a
# 1-entry dummy, and each sample is a 3-token sequence (reward, horizon, state).
config = AutoConfig.from_pretrained(
    "prajjwal1/bert-small",
    vocab_size=1,
    max_position_embeddings=3,
)
model_bert = AutoModel.from_config(config).to(DEVICE)

config.json:   0%|          | 0.00/286 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [13]:
# Create input projection layers and head.
# The state and action widths are read from the loaded tensors instead of being
# hard-coded, so the layers always match the dataset.
assert X_s.ndim == 2 and y.ndim == 2, "expected (N, state_dim) states and (N, action_dim) actions"
state_dim = X_s.shape[1]
action_dim = y.shape[1]

d_r_encoder = nn.Linear(1, config.hidden_size).to(DEVICE)  # reward-to-go scalar -> token
d_h_encoder = nn.Linear(1, config.hidden_size).to(DEVICE)  # horizon-to-go scalar -> token
state_encoder = nn.Linear(state_dim, config.hidden_size).to(DEVICE)  # state vector -> token
head = nn.Linear(config.hidden_size, action_dim).to(DEVICE)  # last token -> action

# Every module's parameters are optimised jointly (same order as before).
trainable_modules = (model_bert, d_r_encoder, d_h_encoder, state_encoder, head)
optimizer = optim.Adam(
    [param for module in trainable_modules for param in module.parameters()],
    lr=LEARNING_RATE,
)
loss_fn = nn.MSELoss()

In [14]:
def _predict_actions(s, r, h):
    """Predict actions for one batch (shared by the training and validation loops).

    The reward-to-go ``r``, horizon-to-go ``h`` and state ``s`` are each projected
    to ``config.hidden_size`` and stacked as a 3-token sequence, which is fed to
    BERT via ``inputs_embeds``. The head reads the last (state) token.

    Args:
        s: state batch on DEVICE; last dim must match ``state_encoder``'s input.
        r: reward-to-go batch on DEVICE, shape (B, 1).
        h: horizon-to-go batch on DEVICE, shape (B, 1).

    Returns:
        Tensor of predicted actions, one row per sample.
    """
    encoded_r = d_r_encoder(r).unsqueeze(1)  # reward to go
    encoded_h = d_h_encoder(h).unsqueeze(1)  # horizon to go
    encoded_s = state_encoder(s).unsqueeze(1)  # state
    sequence = torch.cat([encoded_r, encoded_h, encoded_s], dim=1)
    bert_out = model_bert(inputs_embeds=sequence).last_hidden_state
    return head(bert_out[:, -1])


def _set_train_mode(training):
    """Switch every trainable module (BERT, encoders, head) to train or eval mode."""
    for module in (model_bert, d_r_encoder, d_h_encoder, state_encoder, head):
        module.train(training)


def train():
    """Train BERT plus the projection layers with early stopping on validation loss.

    Runs up to ``EPOCHS`` epochs over ``train_loader`` and evaluates on
    ``val_loader`` after each epoch. A checkpoint is saved to ``BEST_MODEL_PATH``
    whenever the validation loss improves. Training stops after ``PATIENCE``
    consecutive epochs without improvement.

    Raises:
        ValueError: if the training or validation loader yields no batches.
    """
    import os

    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must both be non-empty")

    # torch.save does not create missing parent directories.
    model_dir = os.path.dirname(BEST_MODEL_PATH)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    best_loss = float("inf")
    patience = PATIENCE

    for epoch in range(EPOCHS):
        _set_train_mode(True)
        train_loss_sum, train_count = 0.0, 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
            s, r, h, a = (t.to(DEVICE) for t in batch)
            optimizer.zero_grad()
            loss = loss_fn(_predict_actions(s, r, h), a)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a short final batch does not count as much
            # as a full one in the epoch mean.
            train_loss_sum += loss.item() * a.size(0)
            train_count += a.size(0)
        avg_train_loss = train_loss_sum / train_count

        # Validation
        _set_train_mode(False)
        val_loss_sum, val_count = 0.0, 0
        with torch.no_grad():
            for batch in val_loader:
                s, r, h, a = (t.to(DEVICE) for t in batch)
                loss = loss_fn(_predict_actions(s, r, h), a)
                val_loss_sum += loss.item() * a.size(0)
                val_count += a.size(0)
        avg_val_loss = val_loss_sum / val_count

        print(
            f"Epoch {epoch + 1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}"
        )

        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            patience = PATIENCE
            torch.save(
                {
                    "bert": model_bert.state_dict(),
                    "d_r": d_r_encoder.state_dict(),
                    # Key "d_t" holds the horizon encoder; kept for checkpoint compatibility.
                    "d_t": d_h_encoder.state_dict(),
                    "state": state_encoder.state_dict(),
                    "head": head.state_dict(),
                },
                BEST_MODEL_PATH,
            )
        else:
            patience -= 1
            if patience == 0:
                print("Early stopping.")
                break

    print("Training complete.")


In [15]:
# Run training with early stopping; the best checkpoint is written to BEST_MODEL_PATH.
train()

Epoch 1:   3%|▎         | 1374/49970 [01:07<39:35, 20.45it/s]


KeyboardInterrupt: 