In [None]:
import itertools
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoModel, AutoConfig
from tqdm import tqdm
import h5py

# ==== Configuration ====
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Batch size used by the module-level DataLoaders.
BATCH_SIZE = 16
# NOTE(review): train() receives its learning rate and epoch count as arguments,
# so these two globals look unused in the visible cells — confirm before relying on them.
LEARNING_RATE = 1e-4
EPOCHS = 5
# train() stops early after this many consecutive epochs without a validation improvement.
PATIENCE = 2

In [12]:
# ==== Paths ====
# HDF5 file that load_data() reads by default.
DATA_PATH = "concatenated_data.hdf5"
# Checkpoint file that train() writes whenever the validation loss improves.
BEST_MODEL_PATH = "best_bert_udrl.pth"


# ==== Data Loading ====
def load_data(path=DATA_PATH):
    """Read the training arrays from the ``concatenated_data`` group of an HDF5 file.

    Args:
        path: Location of the HDF5 file; defaults to ``DATA_PATH``.

    Returns:
        tuple: ``(states, rewards, times, actions)`` where ``states`` comes from
        the ``observations`` dataset, ``actions`` from ``actions``, and
        ``rewards`` / ``times`` are the ``rewards_to_go`` / ``time_to_go``
        datasets reshaped to a single column (``(-1, 1)``).
    """
    with h5py.File(path, "r") as h5_file:
        group = h5_file["concatenated_data"]
        # Slicing with [:] materialises each dataset in memory, so the arrays
        # remain usable after the file has been closed.
        observations = group["observations"][:]
        actions = group["actions"][:]
        rewards_to_go = group["rewards_to_go"][:]
        time_to_go = group["time_to_go"][:]

    return (
        observations,
        rewards_to_go.reshape(-1, 1),
        time_to_go.reshape(-1, 1),
        actions,
    )


In [13]:
# Load the arrays (states, rewards-to-go, times-to-go, actions) and convert each to a tensor.
X_s, X_r, X_t, y = (torch.tensor(array) for array in load_data())

In [14]:
# Cast every tensor to float32 and bundle them so each index yields one (state, reward, time, action) sample.
dataset = TensorDataset(*(tensor.float() for tensor in (X_s, X_r, X_t, y)))
# 80% / 10% split sizes; the final split takes whatever is left so the sizes sum to len(dataset).
lengths = [int(len(dataset) * fraction) for fraction in (0.8, 0.1)]
lengths += [len(dataset) - sum(lengths)]

In [15]:
train_ds, val_ds, test_ds = random_split(
    dataset,
    lengths,
    # Seeded generator so the train/val/test membership is the same on every run.
    generator=torch.Generator().manual_seed(42),
)

In [16]:
# Only the training loader reshuffles every epoch; validation and test keep a fixed order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader, test_loader = (
    DataLoader(split, batch_size=BATCH_SIZE) for split in (val_ds, test_ds)
)

In [22]:
def train(BATCH_SIZE, LEARNING_RATE, EPOCHS):
    """Train a randomly initialised BERT-small to predict actions.

    Each sample is turned into a 3-token sequence
    ``[reward-to-go, time-to-go, state]`` by three linear projections, fed to
    BERT through ``inputs_embeds``, and the hidden state of the last token is
    mapped to the action vector by a linear head. All parts are optimised
    jointly with Adam on an MSE loss against the recorded actions.

    Args:
        BATCH_SIZE: Mini-batch size for the training and validation loaders
            that this function builds from ``train_ds`` / ``val_ds``.
        LEARNING_RATE: Learning rate passed to Adam.
        EPOCHS: Maximum number of epochs. Training stops earlier once the
            validation loss has not improved for ``PATIENCE`` consecutive epochs.

    Returns:
        float: The lowest per-sample average validation loss seen during this
        call (``float("inf")`` if no epoch produced an improvement).

    Side effects:
        Overwrites ``BEST_MODEL_PATH`` with the state dicts of all five modules
        every time the validation loss improves on the best of this call.
    """
    # Build the loaders here so the BATCH_SIZE argument actually takes effect;
    # the module-level train_loader / val_loader are fixed at the global batch size.
    train_batches = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_batches = DataLoader(val_ds, batch_size=BATCH_SIZE)

    # Only the architecture is taken from the hub config; from_config() creates
    # the model with fresh weights.
    config = AutoConfig.from_pretrained("prajjwal1/bert-small")
    config.vocab_size = 1  # dummy since we're using inputs_embeds
    config.max_position_embeddings = 3  # the input sequence is always 3 tokens
    model_bert = AutoModel.from_config(config).to(DEVICE)

    # Input projections (one per token) and the output head.
    # NOTE(review): 105 and 8 are presumably the state and action widths of the
    # dataset — confirm against the HDF5 contents.
    d_r_encoder = nn.Linear(1, config.hidden_size).to(DEVICE)
    d_t_encoder = nn.Linear(1, config.hidden_size).to(DEVICE)
    state_encoder = nn.Linear(105, config.hidden_size).to(DEVICE)
    head = nn.Linear(config.hidden_size, 8).to(DEVICE)
    modules = (model_bert, d_r_encoder, d_t_encoder, state_encoder, head)

    optimizer = optim.Adam(
        [param for module in modules for param in module.parameters()],
        lr=LEARNING_RATE,
    )
    loss_fn = nn.MSELoss()

    def predict(s, r, t):
        """Embed one batch as [reward, time, state] tokens and predict the action."""
        sequence = torch.cat(
            [
                d_r_encoder(r).unsqueeze(1),
                d_t_encoder(t).unsqueeze(1),
                state_encoder(s).unsqueeze(1),
            ],
            dim=1,
        )
        bert_out = model_bert(inputs_embeds=sequence).last_hidden_state
        return head(bert_out[:, -1])

    def run_epoch(batches, optimize):
        """Run one pass over ``batches``; update weights only when ``optimize`` is true."""
        for module in modules:
            module.train(optimize)  # train mode when optimising, eval mode otherwise
        loss_sum, n_seen = 0.0, 0
        with torch.set_grad_enabled(optimize):
            for s, r, t, a in batches:
                s, r, t, a = s.to(DEVICE), r.to(DEVICE), t.to(DEVICE), a.to(DEVICE)
                if optimize:
                    optimizer.zero_grad()
                loss = loss_fn(predict(s, r, t), a)
                if optimize:
                    loss.backward()
                    optimizer.step()
                # Weight by batch size so a short final batch is not over-counted
                # and losses stay comparable across different batch sizes.
                loss_sum += loss.item() * a.size(0)
                n_seen += a.size(0)
        return loss_sum / n_seen

    best_loss = float("inf")
    patience = PATIENCE

    for epoch in range(EPOCHS):
        avg_train_loss = run_epoch(
            tqdm(train_batches, desc=f"Epoch {epoch+1}"), optimize=True
        )
        avg_val_loss = run_epoch(val_batches, optimize=False)

        print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")

        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            patience = PATIENCE
            torch.save({
                'bert': model_bert.state_dict(),
                'd_r': d_r_encoder.state_dict(),
                'd_t': d_t_encoder.state_dict(),
                'state': state_encoder.state_dict(),
                'head': head.state_dict()
            }, BEST_MODEL_PATH)
            print("best model found !!!!!!!!!!!!!!!!!!")
        else:
            patience -= 1
            # "<= 0" (not "== 0") so a non-positive PATIENCE still stops training.
            if patience <= 0:
                print("Early stopping.")
                break

    print("Training complete.")
    return best_loss


In [19]:
# Display which compute device was selected.
DEVICE

device(type='cuda')

In [23]:
def grid_search():
    """Train once for every hyperparameter combination and collect the outcomes.

    Sweeps batch size, learning rate and epoch budget, calling ``train`` for
    each combination. Because ``train`` writes every run's checkpoint to the
    same ``BEST_MODEL_PATH``, a per-configuration copy is saved after each run
    so earlier checkpoints are not lost.

    Returns:
        list[dict]: One record per run with the keys ``batch_size``,
        ``learning_rate``, ``epochs``, ``val_loss`` (whatever ``train``
        returned; ``None`` if it returns nothing) and ``checkpoint`` (path of
        the per-configuration copy, or ``None`` if the run wrote no checkpoint).
    """
    import os
    import shutil

    # Define hyperparameters grid
    batch_sizes = [16, 8]
    learning_rates = [1e-4, 5e-5]
    epochs_list = [10, 20]

    def checkpoint_stamp():
        """Modification time of the shared checkpoint, or None if it does not exist."""
        if not os.path.exists(BEST_MODEL_PATH):
            return None
        return os.stat(BEST_MODEL_PATH).st_mtime_ns

    stem, ext = os.path.splitext(BEST_MODEL_PATH)
    results = []

    # Lower-case loop names so they are not mistaken for the module-level constants.
    for batch_size, learning_rate, epochs in itertools.product(batch_sizes, learning_rates, epochs_list):
        print(f"Running grid search with BATCH_SIZE={batch_size}, LEARNING_RATE={learning_rate}, EPOCHS={epochs}")
        stamp_before = checkpoint_stamp()
        val_loss = train(batch_size, learning_rate, epochs)

        # Only snapshot when this run actually rewrote the shared checkpoint;
        # otherwise a stale file from an earlier run would be mislabelled.
        checkpoint = None
        stamp_after = checkpoint_stamp()
        if stamp_after is not None and stamp_after != stamp_before:
            checkpoint = f"{stem}_bs{batch_size}_lr{learning_rate}_ep{epochs}{ext}"
            shutil.copyfile(BEST_MODEL_PATH, checkpoint)

        results.append({
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "epochs": epochs,
            "val_loss": val_loss,
            "checkpoint": checkpoint,
        })

    # Report the winner when train() hands back a numeric validation loss.
    scored = [run for run in results if isinstance(run["val_loss"], (int, float))]
    if scored:
        best = min(scored, key=lambda run: run["val_loss"])
        print(
            f"Best configuration: BATCH_SIZE={best['batch_size']}, "
            f"LEARNING_RATE={best['learning_rate']}, EPOCHS={best['epochs']} "
            f"(Val Loss = {best['val_loss']:.4f})"
        )

    return results

In [24]:
# Launch the full sweep. Long-running: the logged run below took roughly 7 minutes per epoch.
grid_search()

Running grid search with BATCH_SIZE=16, LEARNING_RATE=0.0001, EPOCHS=10


Epoch 1: 100%|██████████| 49970/49970 [06:58<00:00, 119.27it/s]


Epoch 1: Train Loss = 0.0391, Val Loss = 0.0248
best model found !!!!!!!!!!!!!!!!!!


Epoch 2: 100%|██████████| 49970/49970 [07:00<00:00, 118.94it/s]


Epoch 2: Train Loss = 0.0243, Val Loss = 0.0213
best model found !!!!!!!!!!!!!!!!!!


Epoch 3: 100%|██████████| 49970/49970 [06:57<00:00, 119.66it/s]


Epoch 3: Train Loss = 0.0222, Val Loss = 0.0202
best model found !!!!!!!!!!!!!!!!!!


Epoch 4: 100%|██████████| 49970/49970 [06:48<00:00, 122.43it/s]


Epoch 4: Train Loss = 0.0212, Val Loss = 0.0198
best model found !!!!!!!!!!!!!!!!!!


Epoch 5: 100%|██████████| 49970/49970 [06:46<00:00, 122.97it/s]


Epoch 5: Train Loss = 0.0205, Val Loss = 0.0192
best model found !!!!!!!!!!!!!!!!!!


Epoch 6: 100%|██████████| 49970/49970 [06:48<00:00, 122.33it/s]


Epoch 6: Train Loss = 0.0200, Val Loss = 0.0195


Epoch 7: 100%|██████████| 49970/49970 [06:46<00:00, 122.93it/s]


Epoch 7: Train Loss = 0.0197, Val Loss = 0.0190
best model found !!!!!!!!!!!!!!!!!!


Epoch 8: 100%|██████████| 49970/49970 [06:41<00:00, 124.57it/s]


Epoch 8: Train Loss = 0.0193, Val Loss = 0.0188
best model found !!!!!!!!!!!!!!!!!!


Epoch 9: 100%|██████████| 49970/49970 [06:39<00:00, 125.02it/s]


Epoch 9: Train Loss = 0.0190, Val Loss = 0.0184
best model found !!!!!!!!!!!!!!!!!!


Epoch 10: 100%|██████████| 49970/49970 [06:43<00:00, 123.83it/s]


Epoch 10: Train Loss = 0.0188, Val Loss = 0.0188
Training complete.
Running grid search with BATCH_SIZE=16, LEARNING_RATE=0.0001, EPOCHS=20


Epoch 1:  38%|███▊      | 18751/49970 [02:31<04:12, 123.47it/s]


KeyboardInterrupt: 