In [4]:
import os
import glob
import typing as tt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.distributions import Normal, kl
from tqdm import tqdm
import wandb

# --- 1. Configuration (Upgraded) ---

config = {
    # --- Paths ---
    "DATA_PATH": "world_model_data",
    "CHECKPOINT_DIR": "saved_world_model_rssm", # Changed to directory
    # Set this path to resume training, e.g., "saved_world_model_rssm/latest_checkpoint.pth"
    "LOAD_CHECKPOINT_PATH": None, 
    "RUN_NAME": "wm_train_rssm_v1",

    # --- Training Parameters ---
    "BATCH_SIZE": 64,
    "CHUNK_LENGTH": 50,
    "LEARNING_RATE": 4e-4,
    "NUM_EPOCHS": 50,
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    "GRAD_CLIP_NORM": 100.0,
    "NUM_WORKERS": 0,

    # --- Model Architecture ---
    "INPUT_SHAPE": (1, 54, 132),          # Observation shape (C, H, W)
    "ACTION_DIM_PER_BRANCH": 3,
    "NUM_ACTION_BRANCHES": 3,
    "EMBED_DIM": 256,
    "RNN_HIDDEN_DIM": 512,
    "LATENT_DIM": 512,                  # Stochastic latent dim (z)

    # --- Loss Weights (Upgraded) ---
    "RECON_LOSS_WEIGHT": 1.0,
    "REWARD_LOSS_WEIGHT": 1.0,
    "DONE_LOSS_WEIGHT": 1.0,
    "KL_LOSS_WEIGHT": 0.1,              # Weight for the new KL loss
    "KL_FREE_NATS": 2.0                 # KL "free bits" to prevent posterior collapse
}

# --- 2. Dataset (Unchanged) ---
# ... [你的 WorldModelDataset 代码保持不变] ...
class WorldModelDataset(Dataset):
    """
    Dataset for loading .npz trajectory data and sampling sequences (chunks).
    """
    def __init__(self, data_dir: str, chunk_length: int):
        self.chunk_length = chunk_length
        self.data_files = glob.glob(os.path.join(data_dir, "*.npz"))
        print(f"Found {len(self.data_files)} episode files.")

        self.trajectories = []
        self.chunk_indices = []

        self.load_data()
        self.generate_indices()

    def load_data(self):
        """Loads all .npz files into memory."""
        print("Loading data into memory...")
        for file_path in tqdm(self.data_files):
            try:
                with np.load(file_path) as data:
                    obs = data['observations']
                    act_p0 = data['actions_p0']
                    rew_p0 = data['rewards_p0']
                    dones = data['dones']

                    s_t = obs[:-1]
                    s_next = obs[1:]

                    T = len(act_p0)
                    if not (len(s_t) == T and len(rew_p0) == T and len(dones) == T and len(s_next) == T):
                        print(f"Warning: Skipping {file_path} due to mismatched data lengths.")
                        continue
                    
                    if T == 0:
                        continue

                    self.trajectories.append({
                        's_t': s_t,
                        'a_t': act_p0,
                        'r_next': rew_p0,
                        'd_next': dones,
                        's_next': s_next
                    })
            except Exception as e:
                print(f"Warning: Failed to load {file_path}: {e}")

    def generate_indices(self):
        """Creates indices for all possible sequences of `chunk_length`."""
        print("Generating sequence (chunk) indices...")
        for ep_idx, traj in enumerate(self.trajectories):
            ep_len = len(traj['a_t'])
            if ep_len < self.chunk_length:
                continue

            for start_idx in range(ep_len - self.chunk_length + 1):
                self.chunk_indices.append((ep_idx, start_idx))
        print(f"Total of {len(self.chunk_indices)} training chunks available.")

    def __len__(self) -> int:
        return len(self.chunk_indices)

    def __getitem__(self, idx: int) -> tt.Tuple[torch.Tensor, ...]:
        ep_idx, start_idx = self.chunk_indices[idx]
        end_idx = start_idx + self.chunk_length

        traj = self.trajectories[ep_idx]

        s_t_chunk = traj['s_t'][start_idx:end_idx]
        a_t_chunk = traj['a_t'][start_idx:end_idx]
        r_next_chunk = traj['r_next'][start_idx:end_idx]
        d_next_chunk = traj['d_next'][start_idx:end_idx]
        s_next_chunk = traj['s_next'][start_idx:end_idx]

        num_branches = config["NUM_ACTION_BRANCHES"]
        branch_dim = config["ACTION_DIM_PER_BRANCH"]
        total_action_dim = num_branches * branch_dim

        a_t_one_hot = np.zeros((self.chunk_length, total_action_dim), dtype=np.float32)
        for i in range(num_branches):
            branch_actions = a_t_chunk[:, i]
            # Clamp actions to be valid indices
            branch_actions_clamped = np.clip(branch_actions, 0, branch_dim - 1).astype(int)
            branch_one_hot = np.eye(branch_dim)[branch_actions_clamped]
            a_t_one_hot[:, i*branch_dim : (i+1)*branch_dim] = branch_one_hot

        s_t_tensor = torch.from_numpy(s_t_chunk.astype(np.float32))
        a_t_tensor = torch.from_numpy(a_t_one_hot.astype(np.float32))
        r_next_tensor = torch.from_numpy(r_next_chunk.astype(np.float32))
        d_next_tensor = torch.from_numpy(d_next_chunk.astype(np.float32))
        s_next_tensor = torch.from_numpy(s_next_chunk.astype(np.float32))

        return s_t_tensor, a_t_tensor, r_next_tensor, d_next_tensor, s_next_tensor

# --- 3. Model Definition (Upgraded to RSSM-like) ---

class Encoder(nn.Module):
    """
    Encodes observations (B*T, C, H, W) into embeddings (B*T, embed_dim).
    (Unchanged, your architecture is correct for the input size)
    """
    def __init__(self, embed_dim: int, input_shape: tt.Tuple[int, int, int]):
        super().__init__()
        c, h, w = input_shape

        self.conv_net = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(),
        )

        with torch.no_grad():
            dummy_input = torch.zeros(1, c, h, w)
            dummy_output = self.conv_net(dummy_input)
            self.conv_out_shape = dummy_output.shape[1:]
            self.cnn_out_dim = dummy_output.numel()
            print(f"Encoder CNN output shape: {self.conv_out_shape}, Flattened: {self.cnn_out_dim}")

        self.flatten = nn.Flatten()
        self.fc = nn.Linear(self.cnn_out_dim, embed_dim)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # Normalize observations
        embed = self.conv_net(obs / 255.0) 
        embed = self.flatten(embed)
        embed = self.fc(embed)
        return embed

class Decoder(nn.Module):
    """
    Decodes latent states (B*T, latent_dim) back into observations (B*T, C, H, W).
    (Unchanged, your architecture is correct)
    """
    def __init__(self, latent_dim: int, cnn_out_dim: int, conv_out_shape: tt.Tuple[int, int, int]):
        super().__init__()
        self.fc = nn.Linear(latent_dim, cnn_out_dim)
        self.unflatten = nn.Unflatten(1, conv_out_shape)

        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2), nn.ReLU(),
            # output_padding=(2, 0) is the key to matching the (54, 132) shape
            nn.ConvTranspose2d(32, 1, kernel_size=8, stride=4, output_padding=(2, 0)),
            nn.Sigmoid() # Output pixels between 0 and 1
        )

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        x = self.fc(latent)
        x = self.unflatten(x)
        x = self.deconv(x)
        return x

class WorldModel(nn.Module):
    """
    Stochastic World Model (RSSM-like).
    
    This model predicts the *next* state (t+1) given the current state (t)
    and action (t). It uses the *actual* next observation (s_{t+1}) to
    form a posterior distribution, which is trained to match a prior
    distribution predicted from history (s_{0...t}, a_t).
    """
    def __init__(self,
                 input_shape: tt.Tuple[int, int, int],
                 embed_dim: int,
                 action_dim: int,
                 rnn_hidden_dim: int,
                 latent_dim: int):
        super().__init__()
        self.encoder = Encoder(embed_dim, input_shape)
        
        # RNN to process history (s_0...s_t)
        self.rnn = nn.GRU(embed_dim, rnn_hidden_dim, batch_first=True)

        # Input to transition models: (h_t, a_t)
        transition_input_dim = rnn_hidden_dim + action_dim

        # --- Prior Network (The "Imagination") ---
        # p(z_{t+1} | h_t, a_t)
        self.prior_net = nn.Sequential(
            nn.Linear(transition_input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * latent_dim) # mean and std
        )

        # --- Posterior Network (The "Reality-Check") ---
        # q(z_{t+1} | h_t, a_t, s_{t+1})
        # We add the embedding of the *next* state
        posterior_input_dim = transition_input_dim + embed_dim 
        self.posterior_net = nn.Sequential(
            nn.Linear(posterior_input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * latent_dim) # mean and std
        )
        
        # --- Prediction Heads ---
        # These predict outcomes from the stochastic latent state z
        self.decoder = Decoder(
            latent_dim,
            self.encoder.cnn_out_dim,
            self.encoder.conv_out_shape
        )
        self.reward_head = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        self.done_head = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        
        self.latent_dim = latent_dim

    def forward(self, 
                    s_t: torch.Tensor, 
                    a_t: torch.Tensor, 
                    s_next: torch.Tensor
                ) -> tt.Tuple[torch.Tensor, ...]:
            
            B, T, C, H, W = s_t.shape

            # 1. Encode current observation s_t
            # 使用 .reshape() 替代 .view()
            s_t_flat = s_t.reshape(B * T, C, H, W) 
            embed_t_flat = self.encoder(s_t_flat)
            # 使用 .reshape() 替代 .view()
            embed_t = embed_t_flat.reshape(B, T, -1) 

            # 2. Get history context h_t from RNN
            h_t_seq, _ = self.rnn(embed_t)
            # 使用 .reshape() 替代 .view() (这是导致你错误的行)
            h_t_flat = h_t_seq.reshape(B * T, -1) 
            
            # 3. Flatten action a_t
            # 使用 .reshape() 替代 .view()
            a_t_flat = a_t.reshape(B * T, -1) 

            # --- 4. Form Prior Distribution ---
            # p(z_{t+1} | h_t, a_t)
            transition_input_flat = torch.cat([h_t_flat, a_t_flat], dim=-1)
            prior_mean_std = self.prior_net(transition_input_flat)
            prior_mean, prior_std_pre_act = torch.chunk(prior_mean_std, 2, dim=-1)
            prior_std = F.softplus(prior_std_pre_act) + 1e-4
            prior_dist = Normal(prior_mean, prior_std)

            # --- 5. Form Posterior Distribution ---
            # q(z_{t+1} | h_t, a_t, s_{t+1})
            
            # Encode *next* observation s_next
            # 使用 .reshape() 替代 .view()
            s_next_flat = s_next.reshape(B * T, C, H, W) 
            embed_next_flat = self.encoder(s_next_flat) # Already normalized in encoder

            posterior_input_flat = torch.cat([transition_input_flat, embed_next_flat], dim=-1)
            post_mean_std = self.posterior_net(posterior_input_flat)
            post_mean, post_std_pre_act = torch.chunk(post_mean_std, 2, dim=-1)
            post_std = F.softplus(post_std_pre_act) + 1e-4
            post_dist = Normal(post_mean, post_std)

            # --- 6. Sample from Posterior for Predictions ---
            z_next_flat = post_dist.rsample() # Reparameterization trick

            # --- 7. Predict next state, reward, done ---
            s_next_pred_flat = self.decoder(z_next_flat)
            r_next_pred_flat = self.reward_head(z_next_flat)
            d_next_pred_flat = self.done_head(z_next_flat)

            # Reshape back to (B, T, ...)
            # 使用 .reshape() 替代 .view()
            s_next_pred = s_next_pred_flat.reshape(B, T, C, H, W) 
            r_next_pred = r_next_pred_flat.reshape(B, T, 1)
            d_next_pred = d_next_pred_flat.reshape(B, T, 1)

            return s_next_pred, r_next_pred, d_next_pred, post_dist, prior_dist
    
    def predict_from_prior(self, 
                            s_t_seq: torch.Tensor, 
                            a_t_seq: torch.Tensor
                            ) -> tt.Tuple[torch.Tensor, ...]:
            """
            Used for evaluation/visualization. 
            Predicts s_{t+1} using only the *prior* (imagination).
            Input tensors should be (B, T, ...).
            """
            B, T, C, H, W = s_t_seq.shape

            # Use .reshape()
            s_t_flat = s_t_seq.reshape(B * T, C, H, W)
            embed_t_flat = self.encoder(s_t_flat)
            # Use .reshape()
            embed_t = embed_t_flat.reshape(B, T, -1)

            h_t_seq, _ = self.rnn(embed_t)
            # Use .reshape() (This was the likely source of the error)
            h_t_flat = h_t_seq.reshape(B * T, -1)
            
            # Use .reshape()
            a_t_flat = a_t_seq.reshape(B * T, -1)

            transition_input_flat = torch.cat([h_t_flat, a_t_flat], dim=-1)
            prior_mean_std = self.prior_net(transition_input_flat)
            prior_mean, prior_std_pre_act = torch.chunk(prior_mean_std, 2, dim=-1)
            prior_std = F.softplus(prior_std_pre_act) + 1e-4
            
            # Sample from the prior mean for stability during viz
            z_next_flat = prior_mean 

            s_next_pred_flat = self.decoder(z_next_flat)
            r_next_pred_flat = self.reward_head(z_next_flat)
            d_next_pred_flat = self.done_head(z_next_flat)

            # Use .reshape()
            s_next_pred = s_next_pred_flat.reshape(B, T, C, H, W)
            # Use .reshape()
            r_next_pred = r_next_pred_flat.reshape(B, T, 1)
            # Use .reshape()
            d_next_pred = d_next_pred_flat.reshape(B, T, 1)

            return s_next_pred, r_next_pred, d_next_pred


# --- 4. Training Loop (Upgraded) ---

if __name__ == "__main__":
    run = wandb.init(
        project="WorldModel-Training-RSSM",
        config=config,
        name=config["RUN_NAME"]
    )
    # Update config with wandb sweeps, if any
    config = wandb.config 

    device = torch.device(config.DEVICE)
    print(f"Using device: {device}")

    if not os.path.exists(config.CHECKPOINT_DIR):
        os.makedirs(config.CHECKPOINT_DIR)

    dataset = WorldModelDataset(config.DATA_PATH, config.CHUNK_LENGTH)
    dataloader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True
    )

    vis_batch_size = max(1, config.BATCH_SIZE // 4)
    vis_dataloader_iter = iter(DataLoader(
        dataset,
        batch_size=vis_batch_size,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True
    ))

    total_action_dim = config.NUM_ACTION_BRANCHES * config.ACTION_DIM_PER_BRANCH
    model = WorldModel(
        input_shape=config.INPUT_SHAPE,
        embed_dim=config.EMBED_DIM,
        action_dim=total_action_dim,
        rnn_hidden_dim=config.RNN_HIDDEN_DIM,
        latent_dim=config.LATENT_DIM
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

    recon_loss_fn = nn.MSELoss()
    reward_loss_fn = nn.MSELoss()
    done_loss_fn = nn.BCEWithLogitsLoss()

    print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    wandb.watch(model, log="all", log_freq=100)

    # --- Checkpoint Loading Logic ---
    start_epoch = 0
    global_step = 0
    
    if config.LOAD_CHECKPOINT_PATH and os.path.exists(config.LOAD_CHECKPOINT_PATH):
        print(f"Loading checkpoint from: {config.LOAD_CHECKPOINT_PATH}")
        try:
            checkpoint = torch.load(config.LOAD_CHECKPOINT_PATH, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint.get('epoch', 0) + 1 # Resume from the *next* epoch
            global_step = checkpoint.get('global_step', 0)
            print(f"Successfully loaded checkpoint. Resuming from Epoch {start_epoch}, Global Step {global_step}")
        except Exception as e:
            print(f"Warning: Failed to load checkpoint. {e}. Starting from scratch.")
            start_epoch = 0
            global_step = 0
    else:
        print("Starting training from scratch.")


    for epoch in range(start_epoch, config.NUM_EPOCHS):
        print(f"\n--- Epoch {epoch+1} / {config.NUM_EPOCHS} ---")
        model.train()

        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
            s_t, a_t, r_next, d_next, s_next = batch
            
            s_t = s_t.to(device)
            a_t = a_t.to(device)
            r_next = r_next.to(device).unsqueeze(-1)
            d_next = d_next.to(device).unsqueeze(-1)
            s_next = s_next.to(device)

            # --- Model Forward Pass ---
            # Pass s_next to the model for posterior calculation
            s_next_pred, r_next_pred, d_next_pred, post_dist, prior_dist = model(s_t, a_t, s_next)

            # --- Target Preparation ---
            # The decoder outputs sigmoid (0,1), so target is normalized
            s_next_target = s_next / 255.0 
            r_next_target = r_next
            d_next_target = d_next.float()

            # --- Loss Calculation ---
            loss_recon = recon_loss_fn(s_next_pred, s_next_target)
            loss_reward = reward_loss_fn(r_next_pred, r_next_target)
            loss_done = done_loss_fn(d_next_pred, d_next_target)

            # --- New KL Divergence Loss ---
            kl_div = kl.kl_divergence(post_dist, prior_dist)
            # Apply "free nats" to prevent KL from collapsing to zero
            loss_kl_raw = kl_div.mean()
            loss_kl = kl_div.clamp(min=config.KL_FREE_NATS).mean()

            total_loss = (
                config.RECON_LOSS_WEIGHT * loss_recon +
                config.REWARD_LOSS_WEIGHT * loss_reward +
                config.DONE_LOSS_WEIGHT * loss_done +
                config.KL_LOSS_WEIGHT * loss_kl
            )

            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.GRAD_CLIP_NORM)
            optimizer.step()

            if global_step % 100 == 0:
                wandb.log({
                    "train/total_loss": total_loss.item(),
                    "train/recon_loss": loss_recon.item(),
                    "train/reward_loss": loss_reward.item(),
                    "train/done_loss": loss_done.item(),
                    "train/kl_loss": loss_kl.item(),
                    "train/kl_loss_raw": loss_kl_raw.item(),
                    "epoch": epoch,
                }, step=global_step)

            global_step += 1

        print(f"Epoch {epoch+1} finished. Total Loss: {total_loss.item():.4f}, KL Loss: {loss_kl.item():.4f}")

        # --- Visualization ---
        try:
            model.eval()

            try:
                vis_s_t, vis_a_t, _, _, vis_s_next = next(vis_dataloader_iter)
            except StopIteration:
                vis_dataloader_iter = iter(DataLoader(
                    dataset, batch_size=vis_batch_size, shuffle=True,
                    num_workers=config.NUM_WORKERS, pin_memory=True
                ))
                vis_s_t, vis_a_t, _, _, vis_s_next = next(vis_dataloader_iter)

            vis_s_t = vis_s_t.to(device)
            vis_a_t = vis_a_t.to(device)
            vis_s_next = vis_s_next.to(device)

            with torch.no_grad():
                # Use the "imagination" (prior) for visualization
                vis_s_next_pred, _, _ = model.predict_from_prior(vis_s_t, vis_a_t)

            s_next_target = (vis_s_next[0] / 255.0).cpu().detach()
            s_next_pred_vis = vis_s_next_pred[0].cpu().detach()

            comparison_video_tensor = torch.cat(
                [s_next_target, s_next_pred_vis], dim=3
            ).clamp(0, 1)

            if comparison_video_tensor.shape[1] == 1:
                comparison_video_tensor = comparison_video_tensor.expand(-1, 3, -1, -1)
            
            video_data = (comparison_video_tensor * 255).to(torch.uint8)

            wandb.log({
                "eval/prediction_video": wandb.Video(
                    video_data,
                    fps=8,
                    caption="Left: Ground Truth (s_t+1), Right: Prediction (from prior)",
                    format="gif"
                )
            }, step=global_step)

        except Exception as e:
            print(f"Failed to log video: {e}")
        finally:
            model.train()

        # --- Checkpoint Saving ---
        checkpoint_data = {
            'epoch': epoch,
            'global_step': global_step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }
        
        # Save epoch-specific checkpoint
        save_path_epoch = os.path.join(config.CHECKPOINT_DIR, f"wm_epoch_{epoch+1}.pth")
        torch.save(checkpoint_data, save_path_epoch)
        
        # Save latest checkpoint
        save_path_latest = os.path.join(config.CHECKPOINT_DIR, "latest_checkpoint.pth")
        torch.save(checkpoint_data, save_path_latest)
        
        print(f"Model saved to: {save_path_latest}")

    print("Training complete.")
    wandb.finish()

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▃█
train/done_loss,█▁▁
train/kl_loss,▁▁▁
train/kl_loss_raw,▁▇█
train/recon_loss,█▁▁
train/reward_loss,█▁▂
train/total_loss,█▁▁

0,1
epoch,3.0
train/done_loss,0.0
train/kl_loss,2.0
train/kl_loss_raw,1.08824
train/recon_loss,0.00646
train/reward_loss,0.00501
train/total_loss,0.21147


Using device: cuda
Found 464 episode files.
Loading data into memory...


100%|██████████| 464/464 [00:09<00:00, 48.98it/s]


Generating sequence (chunk) indices...
Total of 905264 training chunks available.
Encoder CNN output shape: torch.Size([64, 5, 15]), Flattened: 4800
Model parameters: 5942723
Starting training from scratch.

--- Epoch 1 / 50 ---


Epoch 1: 100%|██████████| 14145/14145 [1:26:01<00:00,  2.74it/s]


Epoch 1 finished. Total Loss: 0.2058, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 2 / 50 ---


Epoch 2: 100%|██████████| 14145/14145 [1:25:59<00:00,  2.74it/s]


Epoch 2 finished. Total Loss: 0.2053, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 3 / 50 ---


Epoch 3: 100%|██████████| 14145/14145 [1:25:56<00:00,  2.74it/s]


Epoch 3 finished. Total Loss: 0.2045, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 4 / 50 ---


Epoch 4: 100%|██████████| 14145/14145 [1:25:44<00:00,  2.75it/s]


Epoch 4 finished. Total Loss: 0.2048, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 5 / 50 ---


Epoch 5: 100%|██████████| 14145/14145 [1:25:32<00:00,  2.76it/s]


Epoch 5 finished. Total Loss: 0.2057, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 6 / 50 ---


Epoch 6: 100%|██████████| 14145/14145 [1:25:35<00:00,  2.75it/s]


Epoch 6 finished. Total Loss: 0.2044, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 7 / 50 ---


Epoch 7: 100%|██████████| 14145/14145 [1:25:36<00:00,  2.75it/s]


Epoch 7 finished. Total Loss: 0.2049, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 8 / 50 ---


Epoch 8: 100%|██████████| 14145/14145 [1:25:36<00:00,  2.75it/s]


Epoch 8 finished. Total Loss: 0.2044, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 9 / 50 ---


Epoch 9: 100%|██████████| 14145/14145 [1:25:35<00:00,  2.75it/s]


Epoch 9 finished. Total Loss: 0.2043, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 10 / 50 ---


Epoch 10: 100%|██████████| 14145/14145 [1:25:34<00:00,  2.75it/s]


Epoch 10 finished. Total Loss: 0.2047, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 11 / 50 ---


Epoch 11: 100%|██████████| 14145/14145 [1:25:34<00:00,  2.75it/s]


Epoch 11 finished. Total Loss: 0.2046, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 12 / 50 ---


Epoch 12: 100%|██████████| 14145/14145 [1:25:34<00:00,  2.75it/s]


Epoch 12 finished. Total Loss: 0.2043, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 13 / 50 ---


Epoch 13: 100%|██████████| 14145/14145 [1:25:42<00:00,  2.75it/s]


Epoch 13 finished. Total Loss: 0.2046, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 14 / 50 ---


Epoch 14: 100%|██████████| 14145/14145 [1:25:43<00:00,  2.75it/s]


Epoch 14 finished. Total Loss: 0.2047, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 15 / 50 ---


Epoch 15: 100%|██████████| 14145/14145 [1:25:43<00:00,  2.75it/s]


Epoch 15 finished. Total Loss: 0.2051, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 16 / 50 ---


Epoch 16: 100%|██████████| 14145/14145 [1:25:45<00:00,  2.75it/s]


Epoch 16 finished. Total Loss: 0.2043, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 17 / 50 ---


Epoch 17: 100%|██████████| 14145/14145 [1:25:45<00:00,  2.75it/s]


Epoch 17 finished. Total Loss: 0.2046, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 18 / 50 ---


Epoch 18: 100%|██████████| 14145/14145 [1:25:45<00:00,  2.75it/s]


Epoch 18 finished. Total Loss: 0.2047, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 19 / 50 ---


Epoch 19: 100%|██████████| 14145/14145 [1:25:43<00:00,  2.75it/s]


Epoch 19 finished. Total Loss: 0.2051, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 20 / 50 ---


Epoch 20: 100%|██████████| 14145/14145 [1:25:33<00:00,  2.76it/s]

Epoch 20 finished. Total Loss: 0.2043, KL Loss: 2.0000





Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 21 / 50 ---


Epoch 21: 100%|██████████| 14145/14145 [1:25:32<00:00,  2.76it/s]


Epoch 21 finished. Total Loss: 0.2047, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 22 / 50 ---


Epoch 22: 100%|██████████| 14145/14145 [1:25:32<00:00,  2.76it/s]


Epoch 22 finished. Total Loss: 0.2042, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 23 / 50 ---


Epoch 23: 100%|██████████| 14145/14145 [1:25:32<00:00,  2.76it/s]


Epoch 23 finished. Total Loss: 0.2051, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 24 / 50 ---


Epoch 24: 100%|██████████| 14145/14145 [1:25:31<00:00,  2.76it/s]


Epoch 24 finished. Total Loss: 0.2043, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 25 / 50 ---


Epoch 25: 100%|██████████| 14145/14145 [1:25:29<00:00,  2.76it/s]


Epoch 25 finished. Total Loss: 0.2047, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 26 / 50 ---


Epoch 26: 100%|██████████| 14145/14145 [1:25:30<00:00,  2.76it/s]


Epoch 26 finished. Total Loss: 0.2047, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 27 / 50 ---


Epoch 27: 100%|██████████| 14145/14145 [1:25:28<00:00,  2.76it/s]


Epoch 27 finished. Total Loss: 0.2046, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 28 / 50 ---


Epoch 28: 100%|██████████| 14145/14145 [1:25:28<00:00,  2.76it/s]


Epoch 28 finished. Total Loss: 0.2046, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 29 / 50 ---


Epoch 29: 100%|██████████| 14145/14145 [1:25:30<00:00,  2.76it/s]


Epoch 29 finished. Total Loss: 0.2047, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 30 / 50 ---


Epoch 30: 100%|██████████| 14145/14145 [1:25:31<00:00,  2.76it/s]


Epoch 30 finished. Total Loss: 0.2046, KL Loss: 2.0000


Model saved to: saved_world_model_rssm\latest_checkpoint.pth

--- Epoch 31 / 50 ---


Epoch 31:  96%|█████████▌| 13587/14145 [1:22:11<03:22,  2.76it/s]


KeyboardInterrupt: 