In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import io
from PIL import Image

In [None]:
# Stream one GZIP-compressed TFRecord shard of the Pong episodes from the
# gs:// path below.
PONG_SHARD_PATH = "gs://rl_unplugged/atari_episodes_ordered/Pong/run_1-00000-of-00050"
episodes_dataset = tf.data.TFRecordDataset(PONG_SHARD_PATH, compression_type="GZIP")

In [None]:
# Pull one serialized record and decode it as a generic tf.train.Example so
# the stored feature names and types can be listed without a schema.
raw_example = next(iter(episodes_dataset))

example = tf.train.Example()
example.ParseFromString(raw_example.numpy())

# tf.train.Feature keeps its payload in a protobuf oneof called "kind".
# Asking the oneof which member is set identifies the type even when the list
# is empty, whereas a truthiness test on `.value` reports such features as
# "unknown".
KIND_LABELS = {
    "bytes_list": "bytes",
    "int64_list": "int64",
    "float_list": "float",
}

print("Features in the TFRecord example:")
for key, feature in example.features.feature.items():
    # WhichOneof returns None when no member is set -> falls back to "unknown".
    dtype = KIND_LABELS.get(feature.WhichOneof("kind"), "unknown")
    print(f"{key}: {dtype}")

In [None]:

# Parsing schema: every field is a variable-length list, so each one is read
# with VarLenFeature (a sparse tensor that the parse function densifies).
feature_description = {
    name: tf.io.VarLenFeature(dtype)
    for name, dtype in (
        ("observations", tf.string),
        ("actions", tf.int64),
        ("clipped_rewards", tf.float32),
    )
}

def _parse_function(example_proto):
    """Parse one serialized record and return only its first step.

    Args:
        example_proto: Scalar string tensor holding one serialized example.

    Returns:
        A tuple ``(img, action, reward)``: the first observation decoded as a
        3-channel image, the first action and the first clipped reward.
    """
    features = tf.io.parse_single_example(example_proto, feature_description)

    def first_of(name):
        # VarLenFeature yields a SparseTensor; densify it before indexing.
        return tf.sparse.to_dense(features[name])[0]

    first_frame = tf.io.decode_image(first_of("observations"), channels=3)
    return first_frame, first_of("actions"), first_of("clipped_rewards")

dataset_parsed = episodes_dataset.map(_parse_function)

# Preview the first step of each of the first NUM_PREVIEW records.
# _parse_function keeps only index 0 of every record, so each panel comes
# from a different record rather than from consecutive frames.
NUM_PREVIEW = 10
plt.figure(figsize=(20, 4))
for i, (frame, action, reward) in enumerate(dataset_parsed.take(NUM_PREVIEW)):
    # A single row holds every panel; a 2-row grid would leave the lower half
    # of the figure blank.
    plt.subplot(1, NUM_PREVIEW, i + 1)
    plt.imshow(frame.numpy())  # eager tensors expose .numpy() outside of map()
    plt.axis('off')
    plt.title(f"A:{action.numpy()} R:{reward.numpy():.1f}")
plt.show()

In [None]:
# Fetch the repository whose checkout directory is later added to sys.path for the `models` imports.
# NOTE(review): the clone is not pinned to a commit or tag — confirm whether a fixed revision is wanted.
!git clone https://github.com/Coluding/world-models.git

In [None]:
import sys
sys.path.append('/content/world-models')
from models.rssm import RSSM

import torch
from torch.utils.data import Dataset, DataLoader

In [None]:
# Field name -> element dtype. All three fields are variable-length lists, so
# each is wrapped in a VarLenFeature.
_FIELD_DTYPES = {
  "observations": tf.string,
  "actions": tf.int64,
  "clipped_rewards": tf.float32,
}
feature_description = {
  field: tf.io.VarLenFeature(dtype) for field, dtype in _FIELD_DTYPES.items()
}

def parse_sequence(example_proto):
  """Parse one serialized record into full-length frame/action/reward tensors.

  Args:
    example_proto: Scalar string tensor holding one serialized example.

  Returns:
    A tuple ``(imgs, actions, rewards)``. ``imgs`` holds every decoded
    3-channel frame as float32 scaled to [0, 1]; ``actions`` is int64 and
    ``rewards`` is float32, each with one entry per stored step.
  """
  parsed = tf.io.parse_single_example(example_proto, feature_description)
  obs = tf.sparse.to_dense(parsed["observations"])
  actions = tf.sparse.to_dense(parsed["actions"])
  rewards = tf.sparse.to_dense(parsed["clipped_rewards"])

  # fn_output_signature replaces map_fn's deprecated `dtype` argument, and
  # expand_animations=False makes decode_image always return one 3-D frame.
  imgs = tf.map_fn(
    lambda encoded: tf.io.decode_image(encoded, channels=3, expand_animations=False),
    obs,
    fn_output_signature=tf.uint8,
  )
  # NOTE(review): frames keep their decoded resolution (no resize is applied),
  # while the model cell builds the encoder with input_shape=(3, 64, 64) —
  # confirm whether a resize/transpose is expected before training.
  imgs = tf.cast(imgs, tf.float32) / 255.0
  actions = tf.cast(actions, tf.int64)
  rewards = tf.cast(rewards, tf.float32)
  return imgs, actions, rewards

sequence_length = 50
batch_size = 16


def _has_enough_steps(imgs, actions, rewards):
  """Return True for records with at least `sequence_length` steps."""
  return tf.shape(imgs)[0] >= sequence_length


def _truncate(imgs, actions, rewards):
  """Keep only the first `sequence_length` steps of every stream."""
  return imgs[:sequence_length], actions[:sequence_length], rewards[:sequence_length]


# Same shard as opened at the top of the notebook.
episodes_dataset = tf.data.TFRecordDataset(
    "gs://rl_unplugged/atari_episodes_ordered/Pong/run_1-00000-of-00050",
    compression_type="GZIP",
)

# parse -> drop short records -> truncate to a fixed length -> batch.
dataset = (
    episodes_dataset
    .map(parse_sequence)
    .filter(_has_enough_steps)
    .map(_truncate)
    .batch(batch_size, drop_remainder=True)
)

In [None]:

# Parse the first record eagerly and report the shape of each tensor.
example = next(iter(episodes_dataset.take(1)))
imgs, actions, rewards = parse_sequence(example)

for tensor in (imgs, actions, rewards):
  print(tensor.shape)

# Show the final ten frames of that record on a 2 x 5 grid.
last_ten_imgs = imgs[-10:]

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(15, 6))
for i, ax in enumerate(axes.ravel()):
  ax.imshow(last_ten_imgs[i], vmin=0, vmax=1)
  ax.set_title(f"Image {i+1}")
  ax.set_axis_off()
fig.tight_layout()
plt.show()


In [None]:
# Peek at a single batch and print the shape of each batched tensor.
for imgs, actions, rewards in dataset.take(1):
  for label, tensor in (("imgs:", imgs), ("actions:", actions), ("rewards:", rewards)):
    print(label, tensor.shape)

In [None]:
# Visualise the first NUM_STEPS_SHOWN steps of the first sequence in one batch.
NUM_STEPS_SHOWN = 10
for imgs, actions, rewards in dataset.take(1):
  seq_imgs = imgs[0]
  seq_actions = actions[0]
  seq_rewards = rewards[0]

  plt.figure(figsize=(20, 4))
  for i in range(NUM_STEPS_SHOWN):
    # One row is enough for all panels; a 2-row grid would leave the lower
    # half of the figure blank.
    plt.subplot(1, NUM_STEPS_SHOWN, i + 1)
    plt.imshow(seq_imgs[i].numpy())
    plt.axis("off")
    plt.title(f"A:{seq_actions[i].numpy()} R:{seq_rewards[i].numpy():.1f}")
  plt.show()

In [None]:
def parse_sequence_for_inspection(example_proto):
    """Decode every frame of one record and print the resulting tensor shape.

    Args:
        example_proto: Scalar string tensor holding one serialized example.

    Returns:
        uint8 tensor with all decoded 3-channel frames of the record.
    """
    parsed = tf.io.parse_single_example(example_proto, feature_description)
    obs = tf.sparse.to_dense(parsed["observations"])
    # fn_output_signature replaces map_fn's deprecated `dtype` argument, and
    # expand_animations=False makes decode_image always return one 3-D frame.
    imgs = tf.map_fn(
        lambda encoded: tf.io.decode_image(encoded, channels=3, expand_animations=False),
        obs,
        fn_output_signature=tf.uint8,
    )
    print("Original image shape (before resizing):", imgs.shape)
    return imgs

# Run the inspection on the first record of the shard.
example = next(iter(episodes_dataset.take(1)))
imgs = parse_sequence_for_inspection(example)

In [None]:
class AtariSequenceDataset(Dataset):
  """In-memory torch ``Dataset`` built from the first elements of an iterable.

  Every element of the source must be an ``(imgs, actions, rewards)`` triple
  whose members expose ``.numpy()``. Elements are stored unchanged, so when the
  source is already batched each item returned by ``__getitem__`` is a whole
  batch (leading batch axis included) rather than a single sequence.
  """

  def __init__(self, tf_dataset, num_batches):
    """Materialise at most ``num_batches`` elements of ``tf_dataset``.

    Args:
      tf_dataset: Iterable yielding ``(imgs, actions, rewards)`` triples.
      num_batches: Maximum number of elements to keep in memory.
    """
    self.data = []
    # zip() advances range() first, so iteration stops before requesting an
    # element past the limit; enumerate() + break would fetch (and decode) one
    # extra element only to throw it away.
    for _, (imgs, actions, rewards) in zip(range(num_batches), tf_dataset):
      self.data.append((imgs.numpy(), actions.numpy(), rewards.numpy()))

  def __len__(self):
    """Return the number of cached elements."""
    return len(self.data)

  def __getitem__(self, idx):
    """Return element ``idx`` as ``(float32 imgs, int64 actions, float32 rewards)``."""
    imgs, actions, rewards = self.data[idx]
    # torch.tensor copies the cached arrays, so callers may modify the result
    # without altering the cache.
    return (
      torch.tensor(imgs, dtype=torch.float32),
      torch.tensor(actions, dtype=torch.long),
      torch.tensor(rewards, dtype=torch.float32),
    )

In [None]:
# Cache the first `num_batches` elements of the TF pipeline in memory, then
# serve them one element at a time in shuffled order.
num_batches = 100
atari_dataset = AtariSequenceDataset(tf_dataset=dataset, num_batches=num_batches)
dataloader = DataLoader(dataset=atari_dataset, batch_size=1, shuffle=True)

In [None]:
from models.models import EncoderCNN, DecoderCNN, RewardModel
from models.rssm import RSSM
from models.dynamics import DynamicsModel

# Build the world-model components and their optimizer.
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Size hyperparameters handed to the model constructors below.
action_dim = 6
# NOTE(review): embedding_dim is not referenced in this cell; embedding_size is
# the value actually passed to the models — confirm whether it can be dropped.
embedding_dim = 1024
hidden_size = 1024
state_size = 30
embedding_size = 16384

# Encoder and decoder are both configured for (3, 64, 64) inputs/outputs.
# NOTE(review): presumably channel-first 64x64 images — verify that the data
# pipeline delivers frames in that layout and resolution.
encoder = EncoderCNN(in_channels=3, embedding_dim=embedding_size, input_shape=(3, 64, 64)).to(device)
decoder = DecoderCNN(hidden_size, state_size, embedding_size, use_bn=True, output_shape=(3, 64, 64)).to(device)
reward_model = RewardModel(hidden_size, state_size).to(device)
dynamics_model = DynamicsModel(hidden_size, state_size, action_dim, embedding_size).to(device)

# The RSSM wraps the four sub-models; Adam optimises whatever rssm.parameters() yields.
rssm = RSSM(encoder, decoder, reward_model, dynamics_model, hidden_size, state_size, action_dim, embedding_size, device=device)
optimizer = torch.optim.Adam(rssm.parameters(), lr=1e-3)


In [None]:
import torch.distributions as dist

def compute_loss(prior_states, posterior_states, imgs, decoder, device):
  """Return KL(posterior || prior) plus the image reconstruction MSE.

  Args:
    prior_states: Sequence whose first two items are the prior mean and the
      prior log-variance.
    posterior_states: Sequence whose first two items are the posterior mean
      and the posterior log-variance.
    imgs: Target images the reconstruction is compared against.
    decoder: Callable mapping the posterior means to reconstructed images.
    device: Device ``imgs`` is moved to before the MSE is computed.

  Returns:
    Scalar tensor: the KL divergence summed over the last dimension and
    averaged over the remaining ones, plus the mean-squared reconstruction
    error.
  """
  prior_mean, prior_logvar = prior_states[0], prior_states[1]
  post_mean, post_logvar = posterior_states[0], posterior_states[1]

  # The training loop hands in log-variances, but Normal's second argument is
  # a standard deviation. Convert with std = exp(0.5 * logvar); feeding the
  # log-variance directly gives a wrong KL and is an invalid (non-positive)
  # scale whenever the log-variance is <= 0.
  prior_dist = dist.Normal(prior_mean, torch.exp(0.5 * prior_logvar))
  post_dist = dist.Normal(post_mean, torch.exp(0.5 * post_logvar))

  kl_loss = torch.mean(torch.sum(dist.kl_divergence(post_dist, prior_dist), dim=-1))
  recon_imgs = decoder(post_mean)
  recon_loss = torch.nn.functional.mse_loss(recon_imgs, imgs.to(device))
  return kl_loss + recon_loss

NUM_EPOCHS = 10

rssm.train()
for epoch in range(NUM_EPOCHS):
  # `rewards` is yielded by the loader but not used by the loss below.
  for imgs, actions, rewards in dataloader:
    # Drop the leading axis 0 of each loaded tensor (squeeze only removes it
    # when its size is 1) and move the tensors to the training device.
    imgs = imgs.squeeze(0).to(device)
    actions = actions.squeeze(0).to(device)

    # Encode the observations, then roll the model out over the actions.
    embedded_obs = rssm.encode(imgs)
    rollout = rssm.generate_rollout(actions, obs=embedded_obs)
    (hiddens, prior_states, posterior_states,
     prior_means, prior_logvars,
     posterior_means, posterior_logvars) = rollout

    prior = (prior_means, prior_logvars)
    posterior = (posterior_means, posterior_logvars)
    loss = compute_loss(prior, posterior, imgs, rssm.decoder, device)

    # Standard optimisation step: clear old gradients, backprop, update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item()}")

print("Training complete.")