In [None]:
# Mount Google Drive (Colab only) so the dataset archive can be copied to local disk.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Copy the dataset archive from Drive onto the Colab VM's local disk.
!cp "/content/drive/MyDrive/converted_datasets.zip" /content/

In [None]:
# Flush pending Drive writes and unmount Drive.
# NOTE(review): after this, files the trainer saves with relative paths (checkpoints,
# training_curve.png) land on the VM's local disk, not on Drive.
drive.flush_and_unmount()

In [None]:
!rm -rf "/content/converted_datasets.zip"

In [None]:
import torch

# Report the runtime: PyTorch version, CUDA availability and, when a GPU exists, its name.
# get_device_name(0) raises on CPU-only runtimes, so only query it when CUDA is available.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("GPU: none (running on CPU)")

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
import random
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import logging
import matplotlib.pyplot as plt

# Logging. logging.basicConfig() does nothing if the root logger already has handlers
# (which can be the case inside a notebook kernel), and then the root level stays at
# WARNING and logger.info(...) is dropped. Setting this logger's own level keeps the
# INFO messages visible in either case.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def compute_cvar(ranges, range_max, alpha=0.05):
    """Return a proximity-risk score in [0, 1] from the nearest ``alpha`` tail of a scan.

    The smallest ``max(1, int(alpha * n))`` finite readings are averaged (a CVaR-style
    tail mean), and ``1 - tail_mean / range_max`` is returned, clipped to [0, 1]. A value
    near 1 means an obstacle is close. The function returns 1.0 when the scan has no
    finite reading or when ``range_max`` is 0.
    """
    finite_ranges = ranges[np.isfinite(ranges)]
    if range_max == 0 or finite_ranges.size == 0:
        return 1.0
    tail_count = max(1, int(alpha * finite_ranges.size))
    nearest = np.sort(finite_ranges)[:tail_count]
    risk = 1.0 - (nearest.mean() / range_max)
    return float(np.clip(risk, 0.0, 1.0))


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first sequence.

    Args:
        d_model: feature size of the inputs. Odd sizes are supported.
        max_len: longest sequence length covered by the precomputed table.
    """

    def __init__(self, d_model, max_len=3500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() *
                        -(torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # When d_model is odd there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(pos * div)[:, :d_model // 2]
        # Stored as [1, max_len, d_model] so it broadcasts over the batch dimension.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x + pe`` for ``x`` of shape [B, T, d_model] with ``T <= max_len``.

        Raises:
            ValueError: if the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(1)}")
        return x + self.pe[:, :seq_len]


class ModalityEncoder(nn.Module):
    """Project one input modality to ``d_model`` and run a Transformer encoder over time.

    Args:
        inp_dim: per-timestep feature size of this modality.
        d_model: model width.
        nhead: attention heads per encoder layer.
        layers: number of stacked encoder layers.
        dropout: dropout rate inside the encoder and after the positional encoding.
    """

    def __init__(self, inp_dim, d_model, nhead, layers, dropout=0.3):
        super().__init__()
        self.proj = nn.Linear(inp_dim, d_model)
        self.pos = PositionalEncoding(d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout, batch_first=True)
        self.enc = nn.TransformerEncoder(layer, layers)
        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Encode ``x`` of shape [B, T, inp_dim] into [B, T, d_model].

        ``mask`` is an optional bool [B, T] key-padding mask (True = step ignored as an
        attention key). A batch row that is masked at every step is left unmasked instead,
        because a softmax over keys that are all masked can produce NaN.
        """
        # Scale the projection by sqrt(d_model) before adding the positional encoding.
        x = self.proj(x) * (self.proj.out_features ** 0.5)
        x = self.pos(x)
        x = self.drop(x)
        if mask is not None:
            mask = mask & ~mask.all(dim=1, keepdim=True)
        x = self.enc(x, src_key_padding_mask=mask)
        return self.norm(x)


class CrossModalAttention(nn.Module):
    """Post-norm cross-attention block: ``q`` attends to ``k``/``v``, then a feed-forward layer.

    Both sub-layers use a residual connection followed by LayerNorm.
    """

    def __init__(self, d_model, nhead, dropout=0.3):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model)
        )
        self.drop = nn.Dropout(dropout)

    def forward(self, q, k, v, kmask=None):
        """Return the updated queries, shape [B, Tq, d_model].

        ``kmask`` is an optional bool [B, Tk] key-padding mask (True = key ignored). A batch
        row whose keys are all masked is left unmasked instead, because a softmax over no
        valid keys can produce NaN, which would then reach the loss.
        """
        if kmask is not None:
            kmask = kmask & ~kmask.all(dim=1, keepdim=True)
        out, _ = self.attn(q, k, v, key_padding_mask=kmask)
        q = self.norm1(q + self.drop(out))
        ff = self.ff(q)
        return self.norm2(q + self.drop(ff))


class GMMHead(nn.Module):
    """Map per-step features to a K-component diagonal Gaussian mixture over an action.

    ``forward(x)`` with ``x`` of shape [B, T, d_model] returns
    ``(logits [B, T, K], means [B, T, K, act_dim], stds [B, T, K, act_dim])``.
    """

    def __init__(self, d_model, act_dim, comps=5):
        super().__init__()
        self.K, self.ad = comps, act_dim
        self.logits = nn.Linear(d_model, comps)
        self.means = nn.Linear(d_model, comps * act_dim)
        self.log_stds = nn.Linear(d_model, comps * act_dim)

    def forward(self, x):
        batch, steps, _ = x.shape
        per_component = (batch, steps, self.K, self.ad)
        mixture_logits = self.logits(x)
        mu = self.means(x).view(per_component)
        # Clamp the log-std to [-5, 2] before exponentiating so sigma stays bounded.
        sigma = self.log_stds(x).view(per_component).clamp(-5, 2).exp()
        return mixture_logits, mu, sigma


class FormationGoalEncoder(nn.Module):
    """Two-layer MLP that embeds a goal vector into ``d_model`` features (applied per step)."""

    def __init__(self, goal_dim=13, d_model=256, dropout=0.3):
        super().__init__()
        stack = [
            nn.Linear(goal_dim, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, d_model),
        ]
        self.enc = nn.Sequential(*stack)

    def forward(self, x):
        # [..., goal_dim] -> [..., d_model]
        embedded = self.enc(x)
        return embedded


class SwarmTransformerWithGoal(nn.Module):
    """Goal-conditioned multimodal Transformer policy with a Gaussian-mixture action head.

    Odometry, scan and gossip streams are encoded separately. Three cross-attention stages
    then fuse them: the state attends to the goal, then to the scan, then to the gossip.
    The result is mapped to a per-timestep mixture over the action.

    The constructor defaults reproduce the original hard-coded architecture, so
    ``SwarmTransformerWithGoal()`` and existing state dicts behave as before.
    """

    def __init__(self, odom_dim=13, scan_dim=361, gossip_dim=70, goal_dim=13,
                 d_model=256, nhead=8, layers=4, act_dim=2, comps=5):
        super().__init__()
        self.odom_enc = ModalityEncoder(odom_dim, d_model, nhead, layers)
        self.scan_enc = ModalityEncoder(scan_dim, d_model, nhead, layers)
        self.gossip_enc = ModalityEncoder(gossip_dim, d_model, nhead, layers)
        self.goal_enc = FormationGoalEncoder(goal_dim=goal_dim, d_model=d_model)
        self.cross1 = CrossModalAttention(d_model, nhead)  # State + Goal
        self.cross2 = CrossModalAttention(d_model, nhead)  # Goal-aware + Perception
        self.cross3 = CrossModalAttention(d_model, nhead)  # Final + Communication
        self.gmm_head = GMMHead(d_model, act_dim, comps=comps)

    def forward(self, odom, scan, gossip, formation_goal, smask, gmask):
        """Return ``(logits [B,T,K], means [B,T,K,act_dim], stds [B,T,K,act_dim])``.

        ``smask`` and ``gmask`` are bool [B, T] masks (True = step ignored as an attention
        key) for the scan and gossip streams.
        """
        # NOTE(review): the odometry encoder and the goal cross-attention get no mask, so
        # real steps can also attend to zero-padded steps of shorter sequences in a batch.
        o = self.odom_enc(odom)              # [B, T, d_model]
        s = self.scan_enc(scan, smask)       # [B, T, d_model]
        g = self.gossip_enc(gossip, gmask)   # [B, T, d_model]
        fg = self.goal_enc(formation_goal)   # [B, T, d_model]

        # Stage 1: fuse state with goal
        goal_aware_state = self.cross1(o, fg, fg, None)
        # Stage 2: fuse with perception
        perception_fusion = self.cross2(goal_aware_state, s, s, smask)
        # Stage 3: fuse with communication/gossip
        full = self.cross3(perception_fusion, g, g, gmask)

        return self.gmm_head(full)

    def loss(self, logits, means, stds, tgt, pad_mask=None):
        """Mean negative log-likelihood of ``tgt`` under the predicted mixture.

        Args:
            logits: [B, T, K] mixture logits.
            means, stds: [B, T, K, D] component parameters.
            tgt: [B, T, D] target actions.
            pad_mask: optional bool [B, T]; True marks padded steps, which are left out of
                the average. ``None`` averages over every step.

        Returns:
            Scalar tensor. If every step is padded, the result is a zero that stays
            connected to the graph, so ``backward()`` still works.
        """
        _, _, K, D = means.shape
        lf = logits.reshape(-1, K)
        mf = means.reshape(-1, K, D)
        sf = stds.reshape(-1, K, D)
        tf = tgt.reshape(-1, D)
        if pad_mask is not None:
            keep = ~pad_mask.reshape(-1).to(torch.bool)
            if not keep.any():
                return logits.sum() * 0.0
            lf, mf, sf, tf = lf[keep], mf[keep], sf[keep], tf[keep]
        mix = torch.distributions.Categorical(logits=lf)
        comp = torch.distributions.Independent(torch.distributions.Normal(mf, sf), 1)
        gmm = torch.distributions.MixtureSameFamily(mix, comp)
        return -gmm.log_prob(tf).mean()


class SwarmRobotDataset(Dataset):
    """Per-robot trajectories read lazily from ``.npz`` files.

    Each key of an ``.npz`` file is one robot. ``data[robot].item()`` must yield a dict with
    ``'odom'``, ``'cmd_vel'``, ``'scan'`` (one range array or ``None`` per step) and
    ``'gossip'`` (per step, an array of neighbour rows with 7 values per row).

    Args:
        files: ``.npz`` paths; each (file, robot) pair becomes one sample.
        max_nb: neighbour slots kept per step; gossip is flattened to ``max_nb * 7`` values.
        scan_dim: number of range bins after resampling. A CVaR risk value is prepended,
            giving ``scan_dim + 1`` features per step.
        seq_len: if set, longer trajectories are randomly cropped to this many steps.
        normalize: if True, z-score statistics are computed from up to 1000 random samples
            of this dataset and applied in ``__getitem__``.
    """

    # Stds at or below this value mark a (near-)constant feature. Such a feature is only
    # centred (std replaced by 1), because dividing by ~0 would turn tiny deviations in
    # val/test data into huge values.
    MIN_STD = 1e-6

    def __init__(self, files, max_nb=10, scan_dim=360, seq_len=None, normalize=True):
        self.files = files
        self.max_nb = max_nb
        self.scan_dim = scan_dim
        self.seq_len = seq_len
        self.normalize = normalize
        self.samples = []
        for f in files:
            # NOTE(review): allow_pickle=True unpickles arbitrary objects; load trusted files only.
            with np.load(f, allow_pickle=True) as data:
                self.samples.extend((f, robot) for robot in data.files)
        assert self.samples, "No samples loaded"
        if self.normalize:
            self._compute_stats()
            logger.info(f"Computed stats on {len(self.samples)} samples")
            logger.info(f"odom_mean: {self.om.tolist()}, odom_std: {self.os.tolist()}")
            logger.info(f"scan_mean (first5): {self.sm.tolist()[:5]}, scan_std (first5): {self.ss.tolist()[:5]}")
            logger.info(f"gossip_mean: {self.gm.tolist()}, gossip_std: {self.gs.tolist()}")
            logger.info(f"cmd_vel_mean: {self.cm.tolist()}, cmd_vel_std: {self.cs.tolist()}")

    def _compute_stats(self):
        """Fit per-feature mean and std on up to 1000 randomly chosen samples."""
        od, sc, go, cmd = [], [], [], []
        for f, robot in random.sample(self.samples, min(len(self.samples), 1000)):
            odom, scan, gossip, _, _, cmdv = self._load_one(f, robot)
            od.append(odom)
            sc.append(scan)
            go.append(gossip)
            cmd.append(cmdv)
        self.om, self.os = self._mean_std(torch.cat(od, 0))
        self.sm, self.ss = self._mean_std(torch.cat(sc, 0))
        self.gm, self.gs = self._mean_std(torch.cat(go, 0))
        self.cm, self.cs = self._mean_std(torch.cat(cmd, 0))

    @classmethod
    def _mean_std(cls, x):
        """Return the column-wise mean and guarded std of an [N, F] tensor."""
        std = x.std(0) + 1e-8
        # The comparison is False for NaN too (e.g. N == 1), so those columns also get std 1.
        std = torch.where(std > cls.MIN_STD, std, torch.ones_like(std))
        return x.mean(0), std

    def _load_one(self, f, robot):
        """Load one robot's trajectory as ``(odom, scan, gossip, scan_mask, gossip_mask, cmd_vel)``.

        The masks are bool [T] tensors that are True where the scan was ``None`` or the
        gossip was empty.
        """
        # The context manager closes the NpzFile handle; .item() has already unpickled the dict.
        with np.load(f, allow_pickle=True) as data:
            rd = data[robot].item()
        odom = torch.tensor(rd['odom'], dtype=torch.float32)
        cmdv = torch.tensor(rd['cmd_vel'], dtype=torch.float32)
        scan, mask_s = self._scan_features(rd['scan'])
        gossip, mask_g = self._gossip_features(rd['gossip'])
        return odom, scan, gossip, mask_s, mask_g, cmdv

    def _scan_features(self, raw_scans):
        """Convert per-step scans to a ``[T, scan_dim + 1]`` float tensor and a ``[T]`` bool mask.

        Each row is ``[cvar, r_1 / r_max, ...]``, where ``r_max`` is the largest finite reading
        of that scan. Non-finite readings (inf and NaN) are replaced by ``r_max``. Scans of
        another length are linearly resampled to ``scan_dim`` bins. ``None`` scans become
        all-zero rows with mask True.
        """
        rows, missing = [], []
        for r in raw_scans:
            if r is None:
                rows.append(np.zeros(self.scan_dim + 1))
                missing.append(True)
                continue
            array = r.astype(np.float32)
            valid = array[np.isfinite(array)]
            range_max = valid.max() if valid.size > 0 else 1.0
            cvar = compute_cvar(array, float(range_max))
            # NaN readings must not reach the model; an all-zero scan must not produce 0/0.
            scale = range_max if range_max > 0 else 1.0
            array = np.where(np.isfinite(array), array, range_max) / scale
            if array.size != self.scan_dim:
                array = np.interp(np.linspace(0, array.size - 1, self.scan_dim),
                                  np.arange(array.size), array)
            rows.append(np.concatenate(([cvar], array)))
            missing.append(False)
        return torch.tensor(np.vstack(rows), dtype=torch.float32), torch.tensor(missing)

    def _gossip_features(self, raw_gossip):
        """Convert per-step neighbour rows to a ``[T, max_nb * 7]`` float tensor and a ``[T]`` bool mask.

        Up to ``max_nb`` rows (7 values each) are laid out back to back, and unused slots are
        0. Steps without neighbours are all-zero with mask True.
        """
        width = self.max_nb * 7
        rows, missing = [], []
        for g in raw_gossip:
            flat = np.zeros(width)
            empty = g.size == 0
            if not empty:
                for i, row in enumerate(g[:self.max_nb]):
                    flat[i * 7:(i + 1) * 7] = row
            rows.append(flat)
            missing.append(empty)
        return torch.tensor(np.vstack(rows), dtype=torch.float32), torch.tensor(missing)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one (optionally cropped and normalised) trajectory as a dict of tensors.

        ``'formation_goal'`` repeats the last (normalised) odometry row of the returned
        window at every step.
        """
        f, robot = self.samples[idx]
        odom, scan, gossip, ms, mg, cmdv = self._load_one(f, robot)
        L = odom.size(0)
        if self.seq_len and L > self.seq_len:
            # NOTE(review): the crop is random for val/test datasets too, so their metrics
            # vary between epochs.
            i = random.randint(0, L - self.seq_len)
            odom, scan, gossip, ms, mg, cmdv = [
                x[i:i + self.seq_len] for x in (odom, scan, gossip, ms, mg, cmdv)
            ]
        if self.normalize:
            odom = (odom - self.om) / self.os
            scan = (scan - self.sm) / self.ss
            gossip = (gossip - self.gm) / self.gs
            cmdv = (cmdv - self.cm) / self.cs

        formation_goal_vec = odom[-1].unsqueeze(0).repeat(odom.size(0), 1)  # [L, odom_dim]
        return {
            'odom': odom, 'scan': scan, 'gossip': gossip,
            'scan_mask': ms, 'gossip_mask': mg, 'cmd_vel': cmdv,
            'formation_goal': formation_goal_vec
        }


def collate_fn(batch):
    """Right-pad a list of ``SwarmRobotDataset`` items to the longest sequence in the batch.

    Feature widths are read from the first item (the default dataset yields odom 13,
    scan 361, gossip 70, cmd_vel 2, formation_goal 13). Padded steps are 0 in the feature
    tensors and True in ``scan_mask``/``gossip_mask``. The returned dict also contains
    ``'pad_mask'`` (bool [B, L], True at padded steps), so consumers can exclude padding
    from losses and metrics.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")
    B = len(batch)
    L = max(x['odom'].size(0) for x in batch)
    first = batch[0]

    def feature_buffer(key):
        # Zero-filled [B, L, width] buffer; width taken from the first item.
        return torch.zeros(B, L, first[key].size(-1))

    out = {
        'odom': feature_buffer('odom'),
        'scan': feature_buffer('scan'),
        'gossip': feature_buffer('gossip'),
        'scan_mask': torch.ones(B, L, dtype=torch.bool),
        'gossip_mask': torch.ones(B, L, dtype=torch.bool),
        'cmd_vel': feature_buffer('cmd_vel'),
        'formation_goal': feature_buffer('formation_goal'),
    }
    pad_mask = torch.ones(B, L, dtype=torch.bool)
    for i, x in enumerate(batch):
        l = x['odom'].size(0)
        for key, buf in out.items():
            buf[i, :l] = x[key]
        pad_mask[i, :l] = False
    out['pad_mask'] = pad_mask
    return out


class SwarmTrainer:
    """Train, validate and test a GMM policy model.

    The model must provide ``forward(odom, scan, gossip, formation_goal, scan_mask,
    gossip_mask) -> (logits, means, stds)`` and ``loss(logits, means, stds, tgt)``.
    Whenever the validation loss improves, the model ``state_dict`` is saved to
    ``best_form_model_{epoch}.pth``. After the last epoch, a loss curve is written to
    ``training_curve.png`` and ``te`` is evaluated with the final weights.

    Args:
        model, tl, vl, te: the model and the train/val/test DataLoaders.
        device: torch device string.
        lr, weight_decay: AdamW settings (defaults are the original hard-coded values).
        t_max: cosine-annealing period in epochs.
    """

    _KEYS = ('odom', 'scan', 'gossip', 'formation_goal', 'scan_mask', 'gossip_mask', 'cmd_vel')

    def __init__(self, model, tl, vl, te, device, lr=1e-5, weight_decay=1e-4, t_max=100):
        self.m, self.tl, self.vl, self.te = model.to(device), tl, vl, te
        self.dev = device
        self.opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        self.sch = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=t_max)
        self.best = float('inf')
        self.train_losses, self.val_losses = [], []
        self.train_maes, self.val_maes = [], []
        self.train_rmses, self.val_rmses = [], []

    def _masked_loss(self, logits, means, stds, cmd, valid=None):
        """Return the model NLL over ``valid`` steps only, or over all steps if ``valid`` is None."""
        if valid is None:
            return self.m.loss(logits, means, stds, cmd)
        # Pack the real steps into a single [1, N, ...] sequence, so that the model's loss
        # (which averages over every position it receives) never sees padding.
        return self.m.loss(logits[valid].unsqueeze(0), means[valid].unsqueeze(0),
                           stds[valid].unsqueeze(0), cmd[valid].unsqueeze(0))

    @staticmethod
    def _point_error(logits, means, cmd, valid=None):
        """Return the error of the highest-weight component's mean vs ``cmd``, restricted to ``valid`` if given."""
        top = logits.argmax(dim=-1)                                   # [B, T]
        idx = top[..., None, None].expand(-1, -1, 1, means.size(-1))  # [B, T, 1, D]
        err = means.gather(2, idx).squeeze(2) - cmd                   # [B, T, D]
        return err if valid is None else err[valid]

    def _run(self, loader, train, desc=None):
        """Make one pass over ``loader`` and return batch-averaged ``(nll, mae, rmse)``.

        MAE and RMSE are in the units the loader yields (normalised units for a normalising
        dataset). If a batch has ``'pad_mask'`` (True = padded step), padded steps are
        excluded from the loss and the metrics.

        Raises:
            ValueError: if ``loader`` yields no batches.
        """
        if train:
            self.m.train()
        else:
            self.m.eval()
        desc = desc or ("Train" if train else "Val")
        total_loss, total_mae, total_rmse, n_batches = 0.0, 0.0, 0.0, 0
        for b in tqdm(loader, desc=desc):
            od, sc, go, fg, sm, gm, cmd = [b[k].to(self.dev) for k in self._KEYS]
            pad = b.get('pad_mask')
            valid = None if pad is None else ~pad.to(self.dev)
            with torch.set_grad_enabled(train):
                logits, means, stds = self.m(od, sc, go, fg, sm, gm)
                loss = self._masked_loss(logits, means, stds, cmd, valid)
                if train:
                    self.opt.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.m.parameters(), 1.0)
                    self.opt.step()
            # Metrics need no autograd graph.
            with torch.no_grad():
                err = self._point_error(logits, means, cmd, valid)
                total_mae += err.abs().mean().item()
                total_rmse += err.pow(2).mean().sqrt().item()
            total_loss += loss.item()
            n_batches += 1
        if n_batches == 0:
            raise ValueError(f"{desc}: loader yielded no batches")
        return total_loss / n_batches, total_mae / n_batches, total_rmse / n_batches

    def _plot_curves(self, path="training_curve.png"):
        """Plot the train and validation loss per epoch and save the figure to ``path``."""
        fig, ax = plt.subplots()
        ax.plot(self.train_losses, label='Train')
        ax.plot(self.val_losses, label='Val')
        ax.set(xlabel='Epoch index', ylabel='NLL loss', title='Training curve')
        ax.legend()
        fig.savefig(path)

    def train(self, epochs=100):
        """Run ``epochs`` train/val epochs, then evaluate on the test loader."""
        for e in range(1, epochs + 1):
            tr_loss, tr_mae, tr_rmse = self._run(self.tl, True)
            val_loss, val_mae, val_rmse = self._run(self.vl, False)
            self.train_losses.append(tr_loss)
            self.val_losses.append(val_loss)
            self.train_maes.append(tr_mae)
            self.val_maes.append(val_mae)
            self.train_rmses.append(tr_rmse)
            self.val_rmses.append(val_rmse)

            if val_loss < self.best:
                self.best = val_loss
                torch.save(self.m.state_dict(), f"best_form_model_{e}.pth")

            self.sch.step()
            print()
            print(f"Epoch: {e}/{epochs} | Train Loss: {tr_loss:.4f} | Val Loss: {val_loss:.4f} "
                  f"Train MAE: {tr_mae:.4f} | Val MAE:{val_mae:.4f} "
                  f"Train RMSE {tr_rmse:.4f} | Val RMSE {val_rmse:.4f}", flush=True)
            logger.info(f"Epoch {e}/{epochs} Train Loss {tr_loss:.4f} Val Loss {val_loss:.4f} "
                        f"Train MAE {tr_mae:.4f} Val MAE {val_mae:.4f} "
                        f"Train RMSE {tr_rmse:.4f} Val RMSE {val_rmse:.4f}")

        self._plot_curves("training_curve.png")
        te_loss, te_mae, te_rmse = self._run(self.te, False, desc="Test")
        logger.info(f"Test Loss: {te_loss:.4f} | Test MAE: {te_mae:.4f} | Test RMSE: {te_rmse:.4f}")


def main(data_dir="/content/converted_datasets", seed=42, epochs=100):
    """Build the train/val/test datasets from ``data_dir`` and train a fresh model.

    Files (not robots) are split 70/15/15 with ``random_state=42``. The val and test
    datasets reuse the train-set normalisation statistics.

    Args:
        data_dir: directory containing the ``.npz`` trajectory files.
        seed: seed for Python ``random``, NumPy and torch (crops, stats sampling, weight init).
        epochs: number of training epochs.

    Raises:
        FileNotFoundError: if ``data_dir`` contains no ``.npz`` files.
    """
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    print("------------------ Starting Code -----------------------")
    print("Device : ", DEVICE)
    DATA_DIR = data_dir
    # NOTE(review): glob order depends on the filesystem, so the file split can differ
    # between runtimes. Sorting would fix that, but it must be applied everywhere the split
    # is recreated (e.g. when resuming training), otherwise train and test files would mix.
    files = list(Path(DATA_DIR).glob("*.npz"))
    if not files:
        raise FileNotFoundError(f"No .npz files found in {DATA_DIR}; was the dataset archive extracted?")
    tr, tmp = train_test_split(files, test_size=0.3, random_state=42)
    vl, te = train_test_split(tmp, test_size=0.5, random_state=42)
    print("Splits made:-")
    print(f"Train: {len(tr)} | Val: {len(vl)} | Test: {len(te)}")
    print("--------------------------------------------------------")
    print("Creating Dataset Classes.")
    train_ds = SwarmRobotDataset(tr, seq_len=3500, normalize=True)
    print("Train Dataset Class created.")
    val_ds = SwarmRobotDataset(vl, seq_len=3500, normalize=False)
    print("Val Dataset Class created.")
    test_ds = SwarmRobotDataset(te, seq_len=3500, normalize=False)
    print("Test Dataset Class created.")
    print((f"Train: {len(train_ds)} | Val: {len(val_ds)} | Test: {len(test_ds)}"))
    # Normalise val/test with the train statistics so no information leaks from them.
    for ds in (val_ds, test_ds):
        ds.om, ds.os = train_ds.om, train_ds.os
        ds.sm, ds.ss = train_ds.sm, train_ds.ss
        ds.gm, ds.gs = train_ds.gm, train_ds.gs
        ds.cm, ds.cs = train_ds.cm, train_ds.cs
        ds.normalize = True

    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, collate_fn=collate_fn, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=4, shuffle=False, collate_fn=collate_fn, num_workers=0)

    model = SwarmTransformerWithGoal()
    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print("--------------------------------------------------------")
    print("Starting Training.")
    print("--------------------------------------------------------")
    trainer = SwarmTrainer(model, train_loader, val_loader, test_loader, DEVICE)
    trainer.train(epochs=epochs)


# A notebook kernel runs cells with __name__ == "__main__", so executing this cell starts
# the full training run.
if __name__ == "__main__":
    main()

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
import random
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import logging
import matplotlib.pyplot as plt

# Logging. logging.basicConfig() does nothing if the root logger already has handlers
# (which can be the case inside a notebook kernel), and then the root level stays at
# WARNING and logger.info(...) is dropped. Setting this logger's own level keeps the
# INFO messages visible in either case.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def compute_cvar(ranges, range_max, alpha=0.05):
    """Return a proximity-risk score in [0, 1] from the nearest ``alpha`` tail of a scan.

    The smallest ``max(1, int(alpha * n))`` finite readings are averaged (a CVaR-style
    tail mean), and ``1 - tail_mean / range_max`` is returned, clipped to [0, 1]. A value
    near 1 means an obstacle is close. The function returns 1.0 when the scan has no
    finite reading or when ``range_max`` is 0.
    """
    finite_ranges = ranges[np.isfinite(ranges)]
    if range_max == 0 or finite_ranges.size == 0:
        return 1.0
    tail_count = max(1, int(alpha * finite_ranges.size))
    nearest = np.sort(finite_ranges)[:tail_count]
    risk = 1.0 - (nearest.mean() / range_max)
    return float(np.clip(risk, 0.0, 1.0))


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first sequence.

    Args:
        d_model: feature size of the inputs. Odd sizes are supported.
        max_len: longest sequence length covered by the precomputed table.
    """

    def __init__(self, d_model, max_len=3500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() *
                        -(torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # When d_model is odd there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(pos * div)[:, :d_model // 2]
        # Stored as [1, max_len, d_model] so it broadcasts over the batch dimension.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x + pe`` for ``x`` of shape [B, T, d_model] with ``T <= max_len``.

        Raises:
            ValueError: if the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(1)}")
        return x + self.pe[:, :seq_len]


class ModalityEncoder(nn.Module):
    """Project one input modality to ``d_model`` and run a Transformer encoder over time.

    Args:
        inp_dim: per-timestep feature size of this modality.
        d_model: model width.
        nhead: attention heads per encoder layer.
        layers: number of stacked encoder layers.
        dropout: dropout rate inside the encoder and after the positional encoding.
    """

    def __init__(self, inp_dim, d_model, nhead, layers, dropout=0.3):
        super().__init__()
        self.proj = nn.Linear(inp_dim, d_model)
        self.pos = PositionalEncoding(d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout, batch_first=True)
        self.enc = nn.TransformerEncoder(layer, layers)
        self.norm = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Encode ``x`` of shape [B, T, inp_dim] into [B, T, d_model].

        ``mask`` is an optional bool [B, T] key-padding mask (True = step ignored as an
        attention key). A batch row that is masked at every step is left unmasked instead,
        because a softmax over keys that are all masked can produce NaN.
        """
        # Scale the projection by sqrt(d_model) before adding the positional encoding.
        x = self.proj(x) * (self.proj.out_features ** 0.5)
        x = self.pos(x)
        x = self.drop(x)
        if mask is not None:
            mask = mask & ~mask.all(dim=1, keepdim=True)
        x = self.enc(x, src_key_padding_mask=mask)
        return self.norm(x)


class CrossModalAttention(nn.Module):
    """Post-norm cross-attention block: ``q`` attends to ``k``/``v``, then a feed-forward layer.

    Both sub-layers use a residual connection followed by LayerNorm.
    """

    def __init__(self, d_model, nhead, dropout=0.3):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model)
        )
        self.drop = nn.Dropout(dropout)

    def forward(self, q, k, v, kmask=None):
        """Return the updated queries, shape [B, Tq, d_model].

        ``kmask`` is an optional bool [B, Tk] key-padding mask (True = key ignored). A batch
        row whose keys are all masked is left unmasked instead, because a softmax over no
        valid keys can produce NaN, which would then reach the loss.
        """
        if kmask is not None:
            kmask = kmask & ~kmask.all(dim=1, keepdim=True)
        out, _ = self.attn(q, k, v, key_padding_mask=kmask)
        q = self.norm1(q + self.drop(out))
        ff = self.ff(q)
        return self.norm2(q + self.drop(ff))


class GMMHead(nn.Module):
    """Map per-step features to a K-component diagonal Gaussian mixture over an action.

    ``forward(x)`` with ``x`` of shape [B, T, d_model] returns
    ``(logits [B, T, K], means [B, T, K, act_dim], stds [B, T, K, act_dim])``.
    """

    def __init__(self, d_model, act_dim, comps=5):
        super().__init__()
        self.K, self.ad = comps, act_dim
        self.logits = nn.Linear(d_model, comps)
        self.means = nn.Linear(d_model, comps * act_dim)
        self.log_stds = nn.Linear(d_model, comps * act_dim)

    def forward(self, x):
        batch, steps, _ = x.shape
        per_component = (batch, steps, self.K, self.ad)
        mixture_logits = self.logits(x)
        mu = self.means(x).view(per_component)
        # Clamp the log-std to [-5, 2] before exponentiating so sigma stays bounded.
        sigma = self.log_stds(x).view(per_component).clamp(-5, 2).exp()
        return mixture_logits, mu, sigma


class FormationGoalEncoder(nn.Module):
    """Two-layer MLP that embeds a goal vector into ``d_model`` features (applied per step)."""

    def __init__(self, goal_dim=13, d_model=256, dropout=0.3):
        super().__init__()
        stack = [
            nn.Linear(goal_dim, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, d_model),
        ]
        self.enc = nn.Sequential(*stack)

    def forward(self, x):
        # [..., goal_dim] -> [..., d_model]
        embedded = self.enc(x)
        return embedded


class SwarmTransformerWithGoal(nn.Module):
    """Goal-conditioned multimodal Transformer policy with a Gaussian-mixture action head.

    Odometry, scan and gossip streams are encoded separately. Three cross-attention stages
    then fuse them: the state attends to the goal, then to the scan, then to the gossip.
    The result is mapped to a per-timestep mixture over the action.

    The constructor defaults reproduce the original hard-coded architecture, so
    ``SwarmTransformerWithGoal()`` and existing state dicts behave as before.
    """

    def __init__(self, odom_dim=13, scan_dim=361, gossip_dim=70, goal_dim=13,
                 d_model=256, nhead=8, layers=4, act_dim=2, comps=5):
        super().__init__()
        self.odom_enc = ModalityEncoder(odom_dim, d_model, nhead, layers)
        self.scan_enc = ModalityEncoder(scan_dim, d_model, nhead, layers)
        self.gossip_enc = ModalityEncoder(gossip_dim, d_model, nhead, layers)
        self.goal_enc = FormationGoalEncoder(goal_dim=goal_dim, d_model=d_model)
        self.cross1 = CrossModalAttention(d_model, nhead)  # State + Goal
        self.cross2 = CrossModalAttention(d_model, nhead)  # Goal-aware + Perception
        self.cross3 = CrossModalAttention(d_model, nhead)  # Final + Communication
        self.gmm_head = GMMHead(d_model, act_dim, comps=comps)

    def forward(self, odom, scan, gossip, formation_goal, smask, gmask):
        """Return ``(logits [B,T,K], means [B,T,K,act_dim], stds [B,T,K,act_dim])``.

        ``smask`` and ``gmask`` are bool [B, T] masks (True = step ignored as an attention
        key) for the scan and gossip streams.
        """
        # NOTE(review): the odometry encoder and the goal cross-attention get no mask, so
        # real steps can also attend to zero-padded steps of shorter sequences in a batch.
        o = self.odom_enc(odom)              # [B, T, d_model]
        s = self.scan_enc(scan, smask)       # [B, T, d_model]
        g = self.gossip_enc(gossip, gmask)   # [B, T, d_model]
        fg = self.goal_enc(formation_goal)   # [B, T, d_model]

        # Stage 1: fuse state with goal
        goal_aware_state = self.cross1(o, fg, fg, None)
        # Stage 2: fuse with perception
        perception_fusion = self.cross2(goal_aware_state, s, s, smask)
        # Stage 3: fuse with communication/gossip
        full = self.cross3(perception_fusion, g, g, gmask)

        return self.gmm_head(full)

    def loss(self, logits, means, stds, tgt, pad_mask=None):
        """Mean negative log-likelihood of ``tgt`` under the predicted mixture.

        Args:
            logits: [B, T, K] mixture logits.
            means, stds: [B, T, K, D] component parameters.
            tgt: [B, T, D] target actions.
            pad_mask: optional bool [B, T]; True marks padded steps, which are left out of
                the average. ``None`` averages over every step.

        Returns:
            Scalar tensor. If every step is padded, the result is a zero that stays
            connected to the graph, so ``backward()`` still works.
        """
        _, _, K, D = means.shape
        lf = logits.reshape(-1, K)
        mf = means.reshape(-1, K, D)
        sf = stds.reshape(-1, K, D)
        tf = tgt.reshape(-1, D)
        if pad_mask is not None:
            keep = ~pad_mask.reshape(-1).to(torch.bool)
            if not keep.any():
                return logits.sum() * 0.0
            lf, mf, sf, tf = lf[keep], mf[keep], sf[keep], tf[keep]
        mix = torch.distributions.Categorical(logits=lf)
        comp = torch.distributions.Independent(torch.distributions.Normal(mf, sf), 1)
        gmm = torch.distributions.MixtureSameFamily(mix, comp)
        return -gmm.log_prob(tf).mean()


class SwarmRobotDataset(Dataset):
    """Per-robot trajectories read lazily from ``.npz`` files.

    Each key of an ``.npz`` file is one robot. ``data[robot].item()`` must yield a dict with
    ``'odom'``, ``'cmd_vel'``, ``'scan'`` (one range array or ``None`` per step) and
    ``'gossip'`` (per step, an array of neighbour rows with 7 values per row).

    Args:
        files: ``.npz`` paths; each (file, robot) pair becomes one sample.
        max_nb: neighbour slots kept per step; gossip is flattened to ``max_nb * 7`` values.
        scan_dim: number of range bins after resampling. A CVaR risk value is prepended,
            giving ``scan_dim + 1`` features per step.
        seq_len: if set, longer trajectories are randomly cropped to this many steps.
        normalize: if True, z-score statistics are computed from up to 1000 random samples
            of this dataset and applied in ``__getitem__``.
    """

    # Stds at or below this value mark a (near-)constant feature. Such a feature is only
    # centred (std replaced by 1), because dividing by ~0 would turn tiny deviations in
    # val/test data into huge values.
    MIN_STD = 1e-6

    def __init__(self, files, max_nb=10, scan_dim=360, seq_len=None, normalize=True):
        self.files = files
        self.max_nb = max_nb
        self.scan_dim = scan_dim
        self.seq_len = seq_len
        self.normalize = normalize
        self.samples = []
        for f in files:
            # NOTE(review): allow_pickle=True unpickles arbitrary objects; load trusted files only.
            with np.load(f, allow_pickle=True) as data:
                self.samples.extend((f, robot) for robot in data.files)
        assert self.samples, "No samples loaded"
        if self.normalize:
            self._compute_stats()
            logger.info(f"Computed stats on {len(self.samples)} samples")
            logger.info(f"odom_mean: {self.om.tolist()}, odom_std: {self.os.tolist()}")
            logger.info(f"scan_mean (first5): {self.sm.tolist()[:5]}, scan_std (first5): {self.ss.tolist()[:5]}")
            logger.info(f"gossip_mean: {self.gm.tolist()}, gossip_std: {self.gs.tolist()}")
            logger.info(f"cmd_vel_mean: {self.cm.tolist()}, cmd_vel_std: {self.cs.tolist()}")

    def _compute_stats(self):
        """Fit per-feature mean and std on up to 1000 randomly chosen samples."""
        od, sc, go, cmd = [], [], [], []
        for f, robot in random.sample(self.samples, min(len(self.samples), 1000)):
            odom, scan, gossip, _, _, cmdv = self._load_one(f, robot)
            od.append(odom)
            sc.append(scan)
            go.append(gossip)
            cmd.append(cmdv)
        self.om, self.os = self._mean_std(torch.cat(od, 0))
        self.sm, self.ss = self._mean_std(torch.cat(sc, 0))
        self.gm, self.gs = self._mean_std(torch.cat(go, 0))
        self.cm, self.cs = self._mean_std(torch.cat(cmd, 0))

    @classmethod
    def _mean_std(cls, x):
        """Return the column-wise mean and guarded std of an [N, F] tensor."""
        std = x.std(0) + 1e-8
        # The comparison is False for NaN too (e.g. N == 1), so those columns also get std 1.
        std = torch.where(std > cls.MIN_STD, std, torch.ones_like(std))
        return x.mean(0), std

    def _load_one(self, f, robot):
        """Load one robot's trajectory as ``(odom, scan, gossip, scan_mask, gossip_mask, cmd_vel)``.

        The masks are bool [T] tensors that are True where the scan was ``None`` or the
        gossip was empty.
        """
        # The context manager closes the NpzFile handle; .item() has already unpickled the dict.
        with np.load(f, allow_pickle=True) as data:
            rd = data[robot].item()
        odom = torch.tensor(rd['odom'], dtype=torch.float32)
        cmdv = torch.tensor(rd['cmd_vel'], dtype=torch.float32)
        scan, mask_s = self._scan_features(rd['scan'])
        gossip, mask_g = self._gossip_features(rd['gossip'])
        return odom, scan, gossip, mask_s, mask_g, cmdv

    def _scan_features(self, raw_scans):
        """Convert per-step scans to a ``[T, scan_dim + 1]`` float tensor and a ``[T]`` bool mask.

        Each row is ``[cvar, r_1 / r_max, ...]``, where ``r_max`` is the largest finite reading
        of that scan. Non-finite readings (inf and NaN) are replaced by ``r_max``. Scans of
        another length are linearly resampled to ``scan_dim`` bins. ``None`` scans become
        all-zero rows with mask True.
        """
        rows, missing = [], []
        for r in raw_scans:
            if r is None:
                rows.append(np.zeros(self.scan_dim + 1))
                missing.append(True)
                continue
            array = r.astype(np.float32)
            valid = array[np.isfinite(array)]
            range_max = valid.max() if valid.size > 0 else 1.0
            cvar = compute_cvar(array, float(range_max))
            # NaN readings must not reach the model; an all-zero scan must not produce 0/0.
            scale = range_max if range_max > 0 else 1.0
            array = np.where(np.isfinite(array), array, range_max) / scale
            if array.size != self.scan_dim:
                array = np.interp(np.linspace(0, array.size - 1, self.scan_dim),
                                  np.arange(array.size), array)
            rows.append(np.concatenate(([cvar], array)))
            missing.append(False)
        return torch.tensor(np.vstack(rows), dtype=torch.float32), torch.tensor(missing)

    def _gossip_features(self, raw_gossip):
        """Convert per-step neighbour rows to a ``[T, max_nb * 7]`` float tensor and a ``[T]`` bool mask.

        Up to ``max_nb`` rows (7 values each) are laid out back to back, and unused slots are
        0. Steps without neighbours are all-zero with mask True.
        """
        width = self.max_nb * 7
        rows, missing = [], []
        for g in raw_gossip:
            flat = np.zeros(width)
            empty = g.size == 0
            if not empty:
                for i, row in enumerate(g[:self.max_nb]):
                    flat[i * 7:(i + 1) * 7] = row
            rows.append(flat)
            missing.append(empty)
        return torch.tensor(np.vstack(rows), dtype=torch.float32), torch.tensor(missing)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one (optionally cropped and normalised) trajectory as a dict of tensors.

        ``'formation_goal'`` repeats the last (normalised) odometry row of the returned
        window at every step.
        """
        f, robot = self.samples[idx]
        odom, scan, gossip, ms, mg, cmdv = self._load_one(f, robot)
        L = odom.size(0)
        if self.seq_len and L > self.seq_len:
            # NOTE(review): the crop is random for val/test datasets too, so their metrics
            # vary between epochs.
            i = random.randint(0, L - self.seq_len)
            odom, scan, gossip, ms, mg, cmdv = [
                x[i:i + self.seq_len] for x in (odom, scan, gossip, ms, mg, cmdv)
            ]
        if self.normalize:
            odom = (odom - self.om) / self.os
            scan = (scan - self.sm) / self.ss
            gossip = (gossip - self.gm) / self.gs
            cmdv = (cmdv - self.cm) / self.cs

        formation_goal_vec = odom[-1].unsqueeze(0).repeat(odom.size(0), 1)  # [L, odom_dim]
        return {
            'odom': odom, 'scan': scan, 'gossip': gossip,
            'scan_mask': ms, 'gossip_mask': mg, 'cmd_vel': cmdv,
            'formation_goal': formation_goal_vec
        }


def collate_fn(batch):
    """Right-pad a list of ``SwarmRobotDataset`` items to the longest sequence in the batch.

    Feature widths are read from the first item (the default dataset yields odom 13,
    scan 361, gossip 70, cmd_vel 2, formation_goal 13). Padded steps are 0 in the feature
    tensors and True in ``scan_mask``/``gossip_mask``. The returned dict also contains
    ``'pad_mask'`` (bool [B, L], True at padded steps), so consumers can exclude padding
    from losses and metrics.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")
    B = len(batch)
    L = max(x['odom'].size(0) for x in batch)
    first = batch[0]

    def feature_buffer(key):
        # Zero-filled [B, L, width] buffer; width taken from the first item.
        return torch.zeros(B, L, first[key].size(-1))

    out = {
        'odom': feature_buffer('odom'),
        'scan': feature_buffer('scan'),
        'gossip': feature_buffer('gossip'),
        'scan_mask': torch.ones(B, L, dtype=torch.bool),
        'gossip_mask': torch.ones(B, L, dtype=torch.bool),
        'cmd_vel': feature_buffer('cmd_vel'),
        'formation_goal': feature_buffer('formation_goal'),
    }
    pad_mask = torch.ones(B, L, dtype=torch.bool)
    for i, x in enumerate(batch):
        l = x['odom'].size(0)
        for key, buf in out.items():
            buf[i, :l] = x[key]
        pad_mask[i, :l] = False
    out['pad_mask'] = pad_mask
    return out


class SwarmTrainer:
    """Train, validate and test a GMM policy model, with checkpoint-based resumption.

    The model must provide ``forward(odom, scan, gossip, formation_goal, scan_mask,
    gossip_mask) -> (logits, means, stds)`` and ``loss(logits, means, stds, tgt)``.
    Whenever the validation loss improves, a checkpoint dict (epoch, model, optimizer and
    scheduler state, best loss) is written to ``best_form_model_{epoch}.pth``.

    Args:
        model, tl, vl, te: the model and the train/val/test DataLoaders.
        device: torch device string.
        lr, weight_decay: AdamW settings (defaults are the original hard-coded values).
        t_max: cosine-annealing period in epochs.
    """

    _KEYS = ('odom', 'scan', 'gossip', 'formation_goal', 'scan_mask', 'gossip_mask', 'cmd_vel')

    def __init__(self, model, tl, vl, te, device, lr=1e-5, weight_decay=1e-4, t_max=100):
        self.m, self.tl, self.vl, self.te = model.to(device), tl, vl, te
        self.dev = device
        self.opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        self.sch = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, T_max=t_max)
        self.best = float('inf')
        # First epoch that train() runs. 81 is the original hard-coded resume point;
        # load_model_only() replaces it with (saved epoch + 1) when the checkpoint records one.
        self.start_epoch = 81
        self.train_losses, self.val_losses = [], []
        self.train_maes, self.val_maes = [], []
        self.train_rmses, self.val_rmses = [], []

    def _masked_loss(self, logits, means, stds, cmd, valid=None):
        """Return the model NLL over ``valid`` steps only, or over all steps if ``valid`` is None."""
        if valid is None:
            return self.m.loss(logits, means, stds, cmd)
        # Pack the real steps into a single [1, N, ...] sequence, so that the model's loss
        # (which averages over every position it receives) never sees padding.
        return self.m.loss(logits[valid].unsqueeze(0), means[valid].unsqueeze(0),
                           stds[valid].unsqueeze(0), cmd[valid].unsqueeze(0))

    @staticmethod
    def _point_error(logits, means, cmd, valid=None):
        """Return the error of the highest-weight component's mean vs ``cmd``, restricted to ``valid`` if given."""
        top = logits.argmax(dim=-1)                                   # [B, T]
        idx = top[..., None, None].expand(-1, -1, 1, means.size(-1))  # [B, T, 1, D]
        err = means.gather(2, idx).squeeze(2) - cmd                   # [B, T, D]
        return err if valid is None else err[valid]

    def _run(self, loader, train, desc=None):
        """Make one pass over ``loader`` and return batch-averaged ``(nll, mae, rmse)``.

        MAE and RMSE are in the units the loader yields (normalised units for a normalising
        dataset). If a batch has ``'pad_mask'`` (True = padded step), padded steps are
        excluded from the loss and the metrics.

        Raises:
            ValueError: if ``loader`` yields no batches.
        """
        if train:
            self.m.train()
        else:
            self.m.eval()
        desc = desc or ("Train" if train else "Val")
        total_loss, total_mae, total_rmse, n_batches = 0.0, 0.0, 0.0, 0
        for b in tqdm(loader, desc=desc):
            od, sc, go, fg, sm, gm, cmd = [b[k].to(self.dev) for k in self._KEYS]
            pad = b.get('pad_mask')
            valid = None if pad is None else ~pad.to(self.dev)
            with torch.set_grad_enabled(train):
                logits, means, stds = self.m(od, sc, go, fg, sm, gm)
                loss = self._masked_loss(logits, means, stds, cmd, valid)
                if train:
                    self.opt.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.m.parameters(), 1.0)
                    self.opt.step()
            # Metrics need no autograd graph.
            with torch.no_grad():
                err = self._point_error(logits, means, cmd, valid)
                total_mae += err.abs().mean().item()
                total_rmse += err.pow(2).mean().sqrt().item()
            total_loss += loss.item()
            n_batches += 1
        if n_batches == 0:
            raise ValueError(f"{desc}: loader yielded no batches")
        return total_loss / n_batches, total_mae / n_batches, total_rmse / n_batches

    def save_checkpoint(self, epoch, path):
        """Save the epoch, model/optimizer/scheduler state and best val loss to ``path``."""
        torch.save({
            'epoch': epoch,
            'model_state': self.m.state_dict(),
            'optimizer_state': self.opt.state_dict(),
            'scheduler_state': self.sch.state_dict(),
            'best': self.best
        }, path)

    def _plot_curves(self, path="training_curve.png"):
        """Plot the train and validation loss per recorded epoch and save the figure to ``path``."""
        fig, ax = plt.subplots()
        ax.plot(self.train_losses, label='Train')
        ax.plot(self.val_losses, label='Val')
        ax.set(xlabel='Epoch index (this run)', ylabel='NLL loss', title='Training curve')
        ax.legend()
        fig.savefig(path)

    def train(self, epochs=100, start_epoch=None):
        """Run epochs ``start_epoch..epochs`` inclusive, then evaluate on the test loader.

        ``start_epoch`` defaults to ``self.start_epoch``.
        """
        first = self.start_epoch if start_epoch is None else start_epoch
        for e in range(first, epochs + 1):
            tr_loss, tr_mae, tr_rmse = self._run(self.tl, True)
            val_loss, val_mae, val_rmse = self._run(self.vl, False)
            self.train_losses.append(tr_loss)
            self.val_losses.append(val_loss)
            self.train_maes.append(tr_mae)
            self.val_maes.append(val_mae)
            self.train_rmses.append(tr_rmse)
            self.val_rmses.append(val_rmse)

            if val_loss < self.best:
                self.best = val_loss
                checkpoint_path = f"best_form_model_{e}.pth"
                self.save_checkpoint(e, checkpoint_path)

            self.sch.step()
            print()
            print(f"Epoch: {e}/{epochs} | Train Loss: {tr_loss:.4f} | Val Loss: {val_loss:.4f} "
                  f"Train MAE: {tr_mae:.4f} | Val MAE:{val_mae:.4f} "
                  f"Train RMSE {tr_rmse:.4f} | Val RMSE {val_rmse:.4f}", flush=True)
            logger.info(f"Epoch {e}/{epochs} Train Loss {tr_loss:.4f} Val Loss {val_loss:.4f} "
                        f"Train MAE {tr_mae:.4f} Val MAE {val_mae:.4f} "
                        f"Train RMSE {tr_rmse:.4f} Val RMSE {val_rmse:.4f}")

        self._plot_curves("training_curve.png")
        te_loss, te_mae, te_rmse = self._run(self.te, False, desc="Test")
        logger.info(f"Test Loss: {te_loss:.4f} | Test MAE: {te_mae:.4f} | Test RMSE: {te_rmse:.4f}")

    def load_model_only(self, path):
        """Restore training state from ``path``.

        Two formats are accepted:
        - a checkpoint dict written by ``save_checkpoint``: restores the model weights, plus
          the optimizer state, scheduler state, best val loss and start epoch that are present;
        - a bare ``state_dict`` (as saved by ``torch.save(model.state_dict(), ...)``): only
          the model weights are restored.
        """
        checkpoint = torch.load(path, map_location=self.dev)
        if isinstance(checkpoint, dict) and 'model_state' in checkpoint:
            self.m.load_state_dict(checkpoint['model_state'])
            if 'optimizer_state' in checkpoint:
                self.opt.load_state_dict(checkpoint['optimizer_state'])
            if 'scheduler_state' in checkpoint:
                self.sch.load_state_dict(checkpoint['scheduler_state'])
            if 'best' in checkpoint:
                # Keep the previous best, so resuming does not overwrite it unconditionally.
                self.best = checkpoint['best']
            if 'epoch' in checkpoint:
                self.start_epoch = checkpoint['epoch'] + 1
        else:
            # NOTE(review): a bare state_dict has no optimizer/scheduler state, so the cosine
            # schedule restarts from its initial learning rate.
            self.m.load_state_dict(checkpoint)
        print(f"Loaded model weights from {path}")

In [None]:
import torch
from tqdm import tqdm

def evaluate_test_set(model, test_loader, device):
    """Evaluate the goal-conditioned transformer on ``test_loader``; print and return metrics.

    The model is called as ``model(odom, scan, gossip, formation_goal, scan_mask,
    gossip_mask)`` and returns ``(logits, means, stds)``. The loss is
    ``model.loss(logits, means, stds, cmd_vel)`` (reported as NLL). The point
    prediction is, per timestep, the ``means`` entry of the highest-logit component.

    All three metrics are means over batches of per-batch values (RMSE is a mean
    of per-batch RMSEs, not a pooled RMSE). This matches the other evaluators in
    this notebook.

    Returns:
        ``(avg_loss, avg_mae, avg_rmse)``.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_mae = 0.0
    total_rmse = 0.0
    n_batches = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating Test Set"):
            odom = batch['odom'].to(device)
            scan = batch['scan'].to(device)
            gossip = batch['gossip'].to(device)
            formation_goal = batch['formation_goal'].to(device)
            scan_mask = batch['scan_mask'].to(device)
            gossip_mask = batch['gossip_mask'].to(device)
            cmd_vel = batch['cmd_vel'].to(device)

            logits, means, stds = model(odom, scan, gossip, formation_goal, scan_mask, gossip_mask)
            loss = model.loss(logits, means, stds, cmd_vel)
            total_loss += loss.item()

            # Select the mean of the highest-logit component at each timestep.
            # NOTE(review): assumes logits are (B, T, K) and means are (B, T, K, D)
            # — confirm against the model's output head.
            maxidx = logits.argmax(dim=-1)
            index = maxidx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, means.size(-1))
            pred = torch.gather(means, 2, index).squeeze(2)

            err = pred - cmd_vel
            total_mae += err.abs().mean().item()
            total_rmse += (err ** 2).mean().sqrt().item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("evaluate_test_set: test_loader yielded no batches")

    avg_loss = total_loss / n_batches
    avg_mae = total_mae / n_batches
    avg_rmse = total_rmse / n_batches

    print(f"Test Loss (NLL): {avg_loss:.4f}")
    print(f"Test MAE: {avg_mae:.4f}")
    print(f"Test RMSE: {avg_rmse:.4f}")

    return avg_loss, avg_mae, avg_rmse

In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("------------------ Starting Code -----------------------")
print("Device : ", DEVICE)
DATA_DIR = "/content/converted_datasets"
# Sort the file list: Path.glob yields entries in filesystem order, which can differ
# between runtimes, so the seeded split below is only reproducible with a fixed order.
# NOTE(review): checkpoints trained from an unsorted list may have used a different
# split — confirm the test files were not seen during training.
files = sorted(Path(DATA_DIR).glob("*.npz"))
if not files:
    raise FileNotFoundError(f"No .npz files found in {DATA_DIR}")
tr, tmp = train_test_split(files, test_size=0.3, random_state=42)
vl, te = train_test_split(tmp, test_size=0.5, random_state=42)
print("Splits made:-")
print(f"Train: {len(tr)} | Val: {len(vl)} | Test: {len(te)}")
print("--------------------------------------------------------")
print("Creating Dataset Classes.")
train_ds = SwarmRobotDataset(tr, seq_len=3500, normalize=True)
print("Train Dataset Class created.")
val_ds = SwarmRobotDataset(vl, seq_len=3500, normalize=False)
print("Val Dataset Class created.")
test_ds = SwarmRobotDataset(te, seq_len=3500, normalize=False)
print("Test Dataset Class created.")
print(f"Train: {len(train_ds)} | Val: {len(val_ds)} | Test: {len(test_ds)}")
# Val/test reuse the training set's normalization statistics, then enable normalization.
for ds in (val_ds, test_ds):
    ds.om, ds.os = train_ds.om, train_ds.os
    ds.sm, ds.ss = train_ds.sm, train_ds.ss
    ds.gm, ds.gs = train_ds.gm, train_ds.gs
    ds.cm, ds.cs = train_ds.cm, train_ds.cs
    ds.normalize = True

test_loader = DataLoader(test_ds, batch_size=4, shuffle=False, collate_fn=collate_fn, num_workers=0)

model = SwarmTransformerWithGoal()
logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print("--------------------------------------------------------")
print("Starting Evaluation.")
print("--------------------------------------------------------")
checkpoint_path = "best_form_model_99.pth"  # Specify the correct checkpoint path here

# The checkpoint is either a dict holding the weights under 'model_state' or a bare
# state_dict; map_location lets GPU-saved weights load on a CPU-only runtime.
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state'] if 'model_state' in checkpoint else checkpoint)
model.to(DEVICE)

evaluate_test_set(model, test_loader, DEVICE)

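## MLP Baseline

A per-timestep MLP policy trained with MSE loss, used as a simple baseline against the transformer model above.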

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.optim as optim
import os

class MLPPolicy(nn.Module):
    """Per-timestep MLP baseline mapping concatenated observations + goal to a command.

    Each timestep is processed independently (no temporal context). The four
    inputs are concatenated on the last axis and passed through Linear/ReLU
    hidden layers, followed by a final Linear projection.

    Args:
        input_dim: combined last-axis width of ``odom``, ``scan`` and ``gossip``.
        goal_dim: last-axis width of ``formation_goal``.
        hidden_sizes: widths of the hidden layers (any iterable of ints).
        output_dim: width of the predicted command vector.
    """

    def __init__(self, input_dim, goal_dim=13, hidden_sizes=(128, 128), output_dim=2):
        super().__init__()
        # Immutable default: a shared list default is a classic Python pitfall.
        layers = []
        in_dim = input_dim + goal_dim
        for h in hidden_sizes:
            layers.append(nn.Linear(in_dim, h))
            layers.append(nn.ReLU())
            in_dim = h
        layers.append(nn.Linear(in_dim, output_dim))
        # Same Sequential layout (net.0, net.2, ...) so existing checkpoints still load.
        self.net = nn.Sequential(*layers)

    def forward(self, odom, scan, gossip, formation_goal):
        """Return per-timestep predictions; leading axes (e.g. B, T) are preserved."""
        x = torch.cat([odom, scan, gossip, formation_goal], dim=-1)
        # nn.Linear acts on the last axis, so no explicit (B*T) flatten is needed.
        return self.net(x)


def validate_mlp(model, val_loader, device):
    """Compute validation MSE/MAE/RMSE for a model called as ``model(odom, scan, gossip, formation_goal)``.

    Metrics are means over batches of per-batch values. MAE and RMSE are printed;
    only the MSE is returned, because it drives best-checkpoint selection in
    ``train_mlp``.

    Raises:
        ValueError: if ``val_loader`` yields no batches (instead of ZeroDivisionError).
    """
    model.eval()
    criterion = nn.MSELoss()
    total_loss = total_mae = total_rmse = 0.0
    n_batches = 0
    with torch.no_grad():
        for batch in val_loader:
            odom = batch['odom'].to(device)
            scan = batch['scan'].to(device)
            gossip = batch['gossip'].to(device)
            fg = batch['formation_goal'].to(device)
            cmd = batch['cmd_vel'].to(device)

            pred = model(odom, scan, gossip, fg)
            total_loss += criterion(pred, cmd).item()
            err = pred - cmd
            total_mae += err.abs().mean().item()
            total_rmse += (err ** 2).mean().sqrt().item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("validate_mlp: val_loader yielded no batches")
    avg_loss = total_loss / n_batches
    avg_mae = total_mae / n_batches
    avg_rmse = total_rmse / n_batches
    print(f"Val MSE: {avg_loss:.6f}, MAE: {avg_mae:.6f}, RMSE: {avg_rmse:.6f}")
    return avg_loss


def train_mlp(model, train_loader, val_loader, device, epochs=30, lr=1e-4, save_path=""):
    """Train ``model`` with AdamW on MSE, validating every epoch with ``validate_mlp``.

    Whenever validation MSE improves, the weights are saved to
    ``best_mlp_{epoch}.pth`` (1-based epoch). If ``save_path`` is non-empty, they
    are additionally written there, giving a stable path to the latest best
    weights (previously ``save_path`` was accepted but ignored).

    Returns:
        The trained ``model`` (moved to ``device``).

    Raises:
        ValueError: if ``train_loader`` has no batches.
    """
    if len(train_loader) == 0:
        raise ValueError("train_mlp: train_loader has no batches")
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            odom = batch['odom'].to(device)
            scan = batch['scan'].to(device)
            gossip = batch['gossip'].to(device)
            fg = batch['formation_goal'].to(device)
            cmd = batch['cmd_vel'].to(device)

            pred = model(odom, scan, gossip, fg)
            loss = criterion(pred, cmd)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} Train MSE: {avg_train_loss:.6f}")

        val_loss = validate_mlp(model, val_loader, device)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), f"best_mlp_{epoch + 1}.pth")
            if save_path:
                torch.save(model.state_dict(), save_path)
            print(f"Saved best model with val loss {best_val_loss:.6f} at epoch {epoch+1}")

    return model

In [None]:
# Train the per-timestep MLP baseline on the train/val datasets built earlier.
device = "cuda" if torch.cuda.is_available() else "cpu"
_loader_opts = {"batch_size": 4, "collate_fn": collate_fn, "num_workers": 0}
train_loader = DataLoader(train_ds, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_opts)

# 444 = 13 (odom) + 361 (scan) + 70 (gossip) feature widths, as in main()'s input_dim.
mlp_model = MLPPolicy(input_dim=444, goal_dim=13)
trained_mlp = train_mlp(mlp_model, train_loader, val_loader, device)

In [None]:
def evaluate_mlp_on_test(model, test_loader, device, checkpoint_path=None):
    """Evaluate a model called as ``model(odom, scan, gossip, formation_goal)`` on ``test_loader``.

    Args:
        model: module to evaluate; moved to ``device`` and put in eval mode.
        test_loader: iterable of batch dicts with 'odom', 'scan', 'gossip',
            'formation_goal' and 'cmd_vel' tensors.
        device: target device.
        checkpoint_path: optional ``state_dict`` file to load first (mapped onto ``device``).

    Returns:
        ``(mse, mae, rmse)``, each a mean over batches of the per-batch value.
        This matches ``evaluate_rnn`` / ``evaluate_test_set``; previously nothing
        was returned.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    if checkpoint_path is not None:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"Loaded model weights from {checkpoint_path}")

    model.to(device)
    model.eval()
    criterion = nn.MSELoss()
    total_loss = total_mae = total_rmse = 0.0
    n_batches = 0
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Test Evaluation"):
            odom = batch['odom'].to(device)
            scan = batch['scan'].to(device)
            gossip = batch['gossip'].to(device)
            fg = batch['formation_goal'].to(device)
            cmd = batch['cmd_vel'].to(device)

            pred = model(odom, scan, gossip, fg)
            total_loss += criterion(pred, cmd).item()
            err = pred - cmd
            total_mae += err.abs().mean().item()
            total_rmse += (err ** 2).mean().sqrt().item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("evaluate_mlp_on_test: test_loader yielded no batches")
    avg_loss = total_loss / n_batches
    avg_mae = total_mae / n_batches
    avg_rmse = total_rmse / n_batches
    print(f"\nTest MSE: {avg_loss:.6f}\nTest MAE: {avg_mae:.6f}\nTest RMSE: {avg_rmse:.6f}")
    return avg_loss, avg_mae, avg_rmse

In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("------------------ Starting Code -----------------------")
print("Device : ", DEVICE)
DATA_DIR = "/content/converted_datasets"
# Sort the file list: Path.glob yields entries in filesystem order, which can differ
# between runtimes, so the seeded split below is only reproducible with a fixed order.
# NOTE(review): checkpoints trained from an unsorted list may have used a different
# split — confirm the test files were not seen during training.
files = sorted(Path(DATA_DIR).glob("*.npz"))
if not files:
    raise FileNotFoundError(f"No .npz files found in {DATA_DIR}")
tr, tmp = train_test_split(files, test_size=0.3, random_state=42)
vl, te = train_test_split(tmp, test_size=0.5, random_state=42)
print("Splits made:-")
print(f"Train: {len(tr)} | Val: {len(vl)} | Test: {len(te)}")
print("--------------------------------------------------------")
print("Creating Dataset Classes.")
# train_ds is rebuilt here only so val/test can reuse its normalization statistics.
train_ds = SwarmRobotDataset(tr, seq_len=3500, normalize=True)
print("Train Dataset Class created.")
val_ds = SwarmRobotDataset(vl, seq_len=3500, normalize=False)
print("Val Dataset Class created.")
test_ds = SwarmRobotDataset(te, seq_len=3500, normalize=False)
print("Test Dataset Class created.")
print(f"Train: {len(train_ds)} | Val: {len(val_ds)} | Test: {len(test_ds)}")
for ds in (val_ds, test_ds):
    ds.om, ds.os = train_ds.om, train_ds.os
    ds.sm, ds.ss = train_ds.sm, train_ds.ss
    ds.gm, ds.gs = train_ds.gm, train_ds.gs
    ds.cm, ds.cs = train_ds.cm, train_ds.cs
    ds.normalize = True

test_loader = DataLoader(test_ds, batch_size=4, shuffle=False, collate_fn=collate_fn, num_workers=0)

model = MLPPolicy(input_dim=444, goal_dim=13)
logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print("--------------------------------------------------------")
print("Starting Evaluation.")
print("--------------------------------------------------------")
# evaluate_mlp_on_test loads the checkpoint itself (with map_location), so no
# separate torch.load is needed here; a load without map_location would fail on
# a CPU-only runtime for GPU-saved weights.
model.to(DEVICE)

evaluate_mlp_on_test(model, test_loader, DEVICE, checkpoint_path="best_mlp_29.pth")  # Load best model weights

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.optim as optim
from pathlib import Path
from sklearn.model_selection import train_test_split
import logging

# Root logging at INFO (a no-op if already configured earlier), then this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class RNNPolicy(nn.Module):
    """GRU baseline mapping per-timestep observations + goal to a command sequence.

    The four inputs are concatenated on the last axis, run through a
    batch-first GRU, and every GRU output step is projected by a Linear layer
    to ``output_dim``.
    """

    def __init__(self, input_dim, goal_dim=13, hidden_dim=128, num_layers=1, output_dim=2):
        super().__init__()
        gru_input_size = input_dim + goal_dim
        # Attribute names `rnn` / `fc` define the state_dict keys; keep them stable.
        self.rnn = nn.GRU(gru_input_size, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, odom, scan, gossip, formation_goal):
        features = torch.cat((odom, scan, gossip, formation_goal), dim=-1)
        hidden_seq = self.rnn(features)[0]  # final hidden state is not needed
        return self.fc(hidden_seq)

def validate_rnn(model, val_loader, device):
    """Compute validation MSE/MAE/RMSE for a model called as ``model(odom, scan, gossip, formation_goal)``.

    Metrics are means over batches of per-batch values. All three are printed;
    the MSE is returned for best-checkpoint selection in ``train_rnn``.

    Raises:
        ValueError: if ``val_loader`` yields no batches (instead of ZeroDivisionError).
    """
    model.eval()
    criterion = nn.MSELoss()
    total_loss = total_mae = total_rmse = 0.0
    n_batches = 0
    with torch.no_grad():
        for batch in val_loader:
            odom = batch['odom'].to(device)
            scan = batch['scan'].to(device)
            gossip = batch['gossip'].to(device)
            fg = batch['formation_goal'].to(device)
            cmd = batch['cmd_vel'].to(device)

            pred = model(odom, scan, gossip, fg)
            total_loss += criterion(pred, cmd).item()
            err = pred - cmd
            total_mae += err.abs().mean().item()
            total_rmse += (err ** 2).mean().sqrt().item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("validate_rnn: val_loader yielded no batches")
    avg_loss = total_loss / n_batches
    avg_mae = total_mae / n_batches
    avg_rmse = total_rmse / n_batches
    print(f"Val MSE: {avg_loss:.6f}, MAE: {avg_mae:.6f}, RMSE: {avg_rmse:.6f}")
    return avg_loss

def train_rnn(model, train_loader, val_loader, device, epochs=30, lr=1e-4, save_path=""):
    """Train ``model`` with AdamW on MSE, validating every epoch with ``validate_rnn``.

    Whenever validation MSE improves, the weights are saved to
    ``best_rnn_{epoch}.pth`` (1-based epoch). If ``save_path`` is non-empty, they
    are additionally written there, giving a stable path to the latest best
    weights (previously ``save_path`` was accepted but ignored).

    Returns:
        The trained ``model`` (moved to ``device``).

    Raises:
        ValueError: if ``train_loader`` has no batches.
    """
    if len(train_loader) == 0:
        raise ValueError("train_rnn: train_loader has no batches")
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
            odom = batch['odom'].to(device)
            scan = batch['scan'].to(device)
            gossip = batch['gossip'].to(device)
            fg = batch['formation_goal'].to(device)
            cmd = batch['cmd_vel'].to(device)

            pred = model(odom, scan, gossip, fg)
            loss = criterion(pred, cmd)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1} Train MSE: {avg_train_loss:.6f}")

        val_loss = validate_rnn(model, val_loader, device)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), f"best_rnn_{epoch + 1}.pth")
            if save_path:
                torch.save(model.state_dict(), save_path)
            print(f"Saved best model with val loss {best_val_loss:.6f} at epoch {epoch + 1}")

    return model

def main():
    """Build seq_len=256 train/val/test loaders and train the GRU baseline.

    Returns:
        ``(trained_model, test_loader)``, so the trained RNN can be evaluated on
        the matching seq_len=256 test split (previously both were discarded).
    """
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print("------------------ Starting RNN Training -----------------------")
    print("Device:", DEVICE)
    DATA_DIR = "/content/converted_datasets"  # update to your dataset path
    # Sorted: Path.glob order is filesystem-dependent, and the seeded split below
    # is only reproducible when its input order is fixed.
    files = sorted(Path(DATA_DIR).glob("*.npz"))
    if not files:
        raise FileNotFoundError(f"No .npz files found in {DATA_DIR}")
    tr, tmp = train_test_split(files, test_size=0.3, random_state=42)
    vl, te = train_test_split(tmp, test_size=0.5, random_state=42)
    print("Splits made:-")
    print(f"Train: {len(tr)} | Val: {len(vl)} | Test: {len(te)}")
    print("--------------------------------------------------------")

    train_ds = SwarmRobotDataset(tr, seq_len=256, normalize=True)
    # Same pattern as the transformer cells: val/test are built with normalize=False
    # (their own statistics would be overwritten below anyway) and normalization is
    # switched on with the training-set statistics.
    # NOTE(review): assumes SwarmRobotDataset applies normalization lazily using
    # om/os/sm/ss/gm/gs/cm/cs — confirm against the dataset class.
    val_ds = SwarmRobotDataset(vl, seq_len=256, normalize=False)
    test_ds = SwarmRobotDataset(te, seq_len=256, normalize=False)

    # Share normalization stats
    for ds in (val_ds, test_ds):
        ds.om, ds.os = train_ds.om, train_ds.os
        ds.sm, ds.ss = train_ds.sm, train_ds.ss
        ds.gm, ds.gs = train_ds.gm, train_ds.gs
        ds.cm, ds.cs = train_ds.cm, train_ds.cs
        ds.normalize = True

    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, collate_fn=collate_fn, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=4, shuffle=False, collate_fn=collate_fn, num_workers=0)

    input_dim = 13 + 361 + 70  # odom + scan + gossip features
    model = RNNPolicy(input_dim=input_dim, goal_dim=13, hidden_dim=128, num_layers=1)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print("--------------------------------------------------------")
    print("Starting Training.")
    print("--------------------------------------------------------")

    # Train and validate
    trained_model = train_rnn(model, train_loader, val_loader, DEVICE, epochs=30, lr=1e-4, save_path="rnn_best.pth")
    return trained_model, test_loader

# In a Jupyter kernel __name__ is "__main__", so this runs when the cell executes.
if __name__ == "__main__":
    rnn_model, rnn_test_loader = main()

In [None]:
import torch
from tqdm import tqdm

def evaluate_rnn(model, test_loader, device):
    """Evaluate a model called as ``model(odom, scan, gossip, formation_goal)`` on ``test_loader``.

    The caller is responsible for moving ``model`` to ``device``.

    Returns:
        ``(mse, mae, rmse)``, each a mean over batches of the per-batch value
        (RMSE is a mean of per-batch RMSEs, not a pooled RMSE).

    Raises:
        ValueError: if ``test_loader`` yields no batches (instead of ZeroDivisionError).
    """
    model.eval()
    criterion = torch.nn.MSELoss()
    total_loss = total_mae = total_rmse = 0.0
    n_batches = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating Test Set"):
            odom = batch['odom'].to(device)
            scan = batch['scan'].to(device)
            gossip = batch['gossip'].to(device)
            fg = batch['formation_goal'].to(device)
            cmd = batch['cmd_vel'].to(device)

            pred = model(odom, scan, gossip, fg)
            total_loss += criterion(pred, cmd).item()
            err = pred - cmd
            total_mae += err.abs().mean().item()
            total_rmse += (err ** 2).mean().sqrt().item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("evaluate_rnn: test_loader yielded no batches")

    avg_loss = total_loss / n_batches
    avg_mae = total_mae / n_batches
    avg_rmse = total_rmse / n_batches

    print(f"Test MSE: {avg_loss:.6f}")
    print(f"Test MAE: {avg_mae:.6f}")
    print(f"Test RMSE: {avg_rmse:.6f}")

    return avg_loss, avg_mae, avg_rmse

# Evaluate the best RNN checkpoint on the test loader.
# NOTE(review): `test_loader` here is the global built in an earlier evaluation cell
# (seq_len=3500), whereas main() trained the RNN on seq_len=256 windows — confirm
# this is the intended evaluation protocol.
device = "cuda" if torch.cuda.is_available() else "cpu"  # defined here, not inherited from another cell
input_dim = 13 + 361 + 70  # odom + scan + gossip features, as in main()
model = RNNPolicy(input_dim=input_dim, goal_dim=13, hidden_dim=128, num_layers=1)
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
model.load_state_dict(torch.load("best_rnn_29.pth", map_location=device))  # Specify the correct checkpoint path
model.to(device)
evaluate_rnn(model, test_loader, device)