To visualize how accuracy and loss evolve over training (one point per epoch) in your notebook, record the metrics at the end of each epoch and then plot them with matplotlib. The code below modifies your RLTrainer class to store these per-epoch metrics and plot them after training.

Step 1: Modify RLTrainer to Log Metrics
Add lists to store loss and accuracy values for each epoch.

# Put this in RLTrainer.__init__: per-epoch history lists (training loss,
# training accuracy, test accuracy) that train() appends to once per epoch.
self.epoch_train_loss, self.epoch_train_acc, self.epoch_test_acc = [], [], []

Update the train method to log metrics after each epoch:
# Updated train method: records per-epoch metrics for plotting
def train(self, train_ds, test_ds, epochs=5):
    """Train for ``epochs`` epochs, evaluating on ``test_ds`` after each one.

    At the end of every epoch the training loss, training accuracy and test
    accuracy are appended, as plain Python floats, to
    ``self.epoch_train_loss``, ``self.epoch_train_acc`` and
    ``self.epoch_test_acc``. The lists are created on first use if
    ``__init__`` did not define them. Existing history is never cleared, so
    repeated calls keep extending it.

    Args:
        train_ds: iterable of ``(imgs, lbls)`` batches fed to ``self.train_step``.
        test_ds: iterable of ``(imgs, lbls)`` batches fed to ``self.test_step``.
        epochs: number of epochs to run (0 runs nothing).
    """
    # Guard against an __init__ that was not updated to create the history
    # lists; otherwise the appends below would raise AttributeError.
    for attr in ("epoch_train_loss", "epoch_train_acc", "epoch_test_acc"):
        if not hasattr(self, attr):
            setattr(self, attr, [])

    for ep in range(1, epochs + 1):
        # Metrics accumulate across batches, so start each epoch from zero.
        self.train_loss.reset_state()
        self.train_acc.reset_state()
        self.test_acc.reset_state()

        for step, (imgs, lbls) in enumerate(train_ds):
            # Presumably the total loss and its ce/pg/entropy terms, per the
            # labels in the log line below — confirm against train_step.
            loss, ce_loss, pg_loss, ent_bonus = self.train_step(imgs, lbls)
            if step % 100 == 0:
                print(f"[E{ep} S{step}] loss={loss:.4f}, ce={ce_loss:.4f}, pg={pg_loss:.4f}, ent={ent_bonus:.4f}, acc={self.train_acc.result():.4f}")

        for imgs, lbls in test_ds:
            self.test_step(imgs, lbls)

        # Read each metric once per epoch. float() turns a scalar tensor into a
        # plain Python float, which plots and serializes cleanly.
        epoch_loss = float(self.train_loss.result())
        epoch_train_acc = float(self.train_acc.result())
        epoch_test_acc = float(self.test_acc.result())
        self.epoch_train_loss.append(epoch_loss)
        self.epoch_train_acc.append(epoch_train_acc)
        self.epoch_test_acc.append(epoch_test_acc)
        print(f"→ E{ep} TrainAcc: {epoch_train_acc:.4f}, TestAcc: {epoch_test_acc:.4f}")


Step 2: Add a Visualization Function
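Add a method to RLTrainer that plots loss and accuracy against epochs. It uses matplotlib, so make sure your notebook has run `import matplotlib.pyplot as plt`.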
# Add this method to RLTrainer (needs `import matplotlib.pyplot as plt`)
def plot_metrics(self):
    """Plot the per-epoch history recorded by ``train``.

    Draws one figure with two side-by-side panels: training loss vs. epoch,
    and training/test accuracy vs. epoch. The figure is displayed with
    ``plt.show()``. Relies on a module-level ``plt`` (matplotlib.pyplot).

    If no epochs have been recorded yet, prints a hint and returns without
    plotting. This also covers the case where the history lists do not exist.
    """
    train_loss = getattr(self, "epoch_train_loss", [])
    train_acc = getattr(self, "epoch_train_acc", [])
    test_acc = getattr(self, "epoch_test_acc", [])
    if not train_loss:
        print("No metrics recorded yet - run train() before plot_metrics().")
        return

    epochs = range(1, len(train_loss) + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 6))

    # marker="o" keeps single-epoch runs visible: a one-point line has no
    # segment to draw, so without markers the panel would look empty.
    ax_loss.plot(epochs, train_loss, marker="o", label="Train Loss")
    ax_loss.set_title("Loss vs. Epochs")
    ax_loss.set_xlabel("Epochs")
    ax_loss.set_ylabel("Loss")
    ax_loss.legend()

    ax_acc.plot(epochs, train_acc, marker="o", label="Train Accuracy")
    ax_acc.plot(epochs, test_acc, marker="o", label="Test Accuracy")
    ax_acc.set_title("Accuracy vs. Epochs")
    ax_acc.set_xlabel("Epochs")
    ax_acc.set_ylabel("Accuracy")
    ax_acc.legend()

    fig.tight_layout()
    plt.show()

Step 3: Call the Plot Function
Once `train()` has finished, call the `plot_metrics` method to visualize the results.

# Train for 5 epochs, then plot the recorded loss/accuracy history.
trainer.train(train_ds=train_ds, test_ds=test_ds, epochs=5)
trainer.plot_metrics()

This produces a figure with two side-by-side plots: training loss vs. epochs, and training and test accuracy vs. epochs.