In [12]:
import minari
import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from torch import nn, Tensor
import torch.nn.functional as F

In [65]:
dataset = minari.load_dataset('D4RL/pointmaze/umaze-v2')

for episode_data in dataset.iterate_episodes():
    observations = episode_data.observations
    actions = episode_data.actions
    rewards = episode_data.rewards
    terminations = episode_data.terminations
    truncations = episode_data.truncations
    infos = episode_data.infos
obs_tensor = observations["observation"]
print(obs_tensor.shape)
print(actions.shape)
print(terminations.shape)

(34, 4)
(33, 2)
(33,)


In [67]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# obs_tensor = {key: torch.tensor(value) for key, value in observations.items()} # Keys: ["achieved goal", "desired_goal", "observation"]
# obs_tensor = {key: tensor.to(device) for key, tensor in obs_tensor.items()}

obs_tensor = torch.tensor(obs_tensor, device=device)
action_tensor = torch.tensor(actions, device=device)
reward_tensor = torch.tensor(rewards, device=device)
terminations_tensor = torch.tensor(terminations, device=device)
truncations_tensor = torch.tensor(truncations, device=device)

N_min = min(len(obs_tensor), len(action_tensor), len(reward_tensor), len(terminations_tensor), len(truncations_tensor))

obs_tensor = obs_tensor[:N_min]  # Trim to 33
action_tensor = action_tensor[:N_min]
reward_tensor = reward_tensor[:N_min]
terminations_tensor = terminations_tensor[:N_min]
truncations_tensor = truncations_tensor[:N_min]

print(f"obs tensor shape: {obs_tensor.shape}")
print(f"action shape: {actions.shape}")
print(f"terminations shape:{terminations.shape}")
dataset = TrajectoryDataset(obs_tensor, action_tensor, reward_tensor, terminations_tensor, truncations_tensor)


obs tensor shape: torch.Size([33, 4])
action shape: (33, 2)
terminations shape:(33,)


  obs_tensor = torch.tensor(obs_tensor, device=device)


In [69]:
class TrajectoryDataset(Dataset):
    def __init__(self, observations, actions, rewards, terminations, truncations):
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        self.terminations = terminations
        self.truncations = truncations

    def __len__(self):
        return len(self.observations)

    def __getitem__(self, idx):
        return self.observations[idx], self.actions[idx], self.rewards[idx], self.terminations[idx], self.truncations[idx]

print(f"Observations shape: {obs_tensor.shape}")
print(f"Actions shape: {action_tensor.shape}")
print(f"Rewards shape: {reward_tensor.shape}")
print(f"Terminations shape: {terminations_tensor.shape}")
print(f"Truncations shape: {truncations_tensor.shape}")

dataset = TrajectoryDataset(obs_tensor, action_tensor, reward_tensor, terminations_tensor, truncations_tensor)
print("Length of dataset: ", len(dataset))

Observations shape: torch.Size([33, 4])
Actions shape: torch.Size([33, 2])
Rewards shape: torch.Size([33])
Terminations shape: torch.Size([33])
Truncations shape: torch.Size([33])
Length of dataset:  33


In [74]:
batch_size = 32
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [81]:
def flow_matching_loss(model, x_start, x_end, timesteps):
    """
    Computes the flow matching loss.

    Args:
        model (nn.Module): The trajectory flow model
        x_start (Tensor): Initial observations (N, T, D_obs)
        x_end (Tensor): Target observations (N, T, D_obs)
        timesteps (Tensor): Random time values (N, T, 1)
    
    Returns:
        Tensor: Scalar loss value
    """
    # Generate noisy interpolation between start and end
    alpha = timesteps  # Time-dependent interpolation coefficient
    print("x start, x end, alpha shapes")
    print(x_start.shape, x_end.shape, alpha.shape) 
    # ----------------------------------------------------------------------------------------------------
    # Problemet ligger i shapes til x_start, x_end og alpha, sjekk hvordan de defineres
    # ----------------------------------------------------------------------------------------------------
    x_t = alpha * x_start + (1 - alpha) * x_end  # Interpolated trajectory
    print(f"x_t shape: {x_t.shape}")
    print(f"Timesteps shape: {timesteps.shape}")
    # Predict flow field
    
    velocity_pred = model(x_t, timesteps)
    print(f"Velopcity pred shape: {velocity_pred.shape}")

    # Compute target flow field
    velocity_target = x_end - x_start  # Ideal velocity to reach x_end

    # Compute loss (MSE)
    loss = F.mse_loss(velocity_pred, velocity_target)
    return loss


In [76]:
class TrajectoryFlowModel(nn.Module):
    def __init__(self, obs_dim, hidden_dim=128, num_layers=3):
        """
        A neural network that estimates the velocity field for flow matching.
        
        Args:
            obs_dim (int): Dimensionality of observations (D_obs)
            hidden_dim (int): Number of hidden units in the MLP
            num_layers (int): Number of layers in the MLP
        """
        super().__init__()
        
        layers = []
        input_dim = obs_dim + 1  # We include time `t` as an input
        
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            input_dim = hidden_dim  # Keep hidden layer size consistent
        
        layers.append(nn.Linear(hidden_dim, obs_dim))  # Output has the same shape as observations
        self.network = nn.Sequential(*layers)

    def forward(self, x, t):
        """
        Forward pass for the trajectory flow model.
        
        Args:
            x (Tensor): Input observations of shape (N, T, D_obs)
            t (Tensor): Time conditioning of shape (N, T, 1)
        
        Returns:
            Tensor: Predicted velocity field of shape (N, T, D_obs)
        """
        # Concatenate time `t` to observations
        xt = torch.cat([x, t], dim=-1)  # Shape: (N, T, D_obs + 1)
        velocity = self.network(xt)  # Predict flow field
        return velocity  # Shape: (N, T, D_obs)


In [82]:
epochs = 1000
lr = 1e-3
D_Obs = 4 # Magic number here,equal to the batch_trajs.shape found in the for loop
model = TrajectoryFlowModel(D_Obs)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

for epoch in range(epochs):
    for batch in train_loader:
        batch_trajs = batch[0]
        # print(batch_trajs)
        T, D_Obs = batch_trajs.shape
        print(f"T: {T}, D_obs: {D_Obs}")

        split_point = T // 2
        x_start = batch_trajs[:, :split_point]
        x_end = batch_trajs[:, split_point:]

        # Generate random timesteps for this batch
        timesteps = torch.rand(batch_size, split_point, 1).to(device)

        optimizer.zero_grad()

        loss = flow_matching_loss(model, x_start, x_end, timesteps)
        
        loss.backward()
        optimizer.step()
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")
'''
observations = observations.reshape(-1, D_obs)
actions = actions.reshape(-1, D_action)

Might have to pad if trajectories are different length

Possible training loop example:
for obs, act in zip(observations, actions):
    action_pred = model(obs)
    loss = loss_fn(action_pred, act)
    loss.backward()
'''

T: 32, D_obs: 4
x start, x end, alpha shapes
torch.Size([32, 4]) torch.Size([32, 0]) torch.Size([32, 16, 1])


RuntimeError: The size of tensor a (16) must match the size of tensor b (32) at non-singleton dimension 1

In [None]:
# class Flow(nn.Module):
#     def __init__(self, dim: int = 2, h: int = 64):
#         super().__init__()
#         self.net = nn.Sequential(
#             nn.Linear(dim + 1, h), nn.ELU(),
#             nn.Linear(h, h), nn.ELU(),
#             nn.Linear(h, h), nn.ELU(),
#             nn.Linear(h, dim))

#     def forward(self, t: Tensor, x_t: Tensor) -> Tensor:
#         return self.net(torch.cat((t, x_t), -1))

#     def step(self, x_t: Tensor, t_start: Tensor, t_end: Tensor) -> Tensor:
#         t_start = t_start.view(1, 1).expand(x_t.shape[0], 1)
#         return x_t + (t_end - t_start) * self(t=t_start + (t_end - t_start) / 2, x_t= x_t + self(x_t=x_t, t=t_start) * (t_end - t_start) / 2)