In [None]:
# from google.colab import drive
# import os
# drive.mount('/content/drive')

# os.chdir('/content/drive/MyDrive/gdl')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
import time
import os

# --- CONFIGURATION ---
# Which model variant to train: 'grf' (topological GRF layers),
# 'mp' (mean message passing) or 'baseline' (no graph layers).
MODEL_MODE = 'grf'
# Directory the __main__ driver changes into before reading points.npy / knn_indices.npy.
DATA_PATH = "./data/"

# Optimisation
BATCH_SIZE = 16
LR = 1e-3            # AdamW learning rate
EPOCHS = 500
VAL_SPLIT = 0.2      # fraction of (frame, next frame) pairs held out for validation

# Graph
K_NEIGHBORS = 6      # neighbours per point; presumably must match knn_indices.npy's last axis — TODO confirm

# Hardware
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# --- DATASET ---
class RobotArmDataset(Dataset):
    """Next-frame prediction pairs built from a sequence of point frames.

    Item ``idx`` is ``(points[idx], knn[idx], points[idx + 1])`` as tensors:
    the current frame, its neighbour indices, and the following frame as the
    regression target.  Points are z-scored along their last axis using a
    mean/std taken over all frames and all points; ``self.mean`` and
    ``self.std`` are kept so predictions can be de-normalised later.

    Args:
        points_path: ``.npy`` file of point coordinates; presumably shaped
            (frames, points, 3) — TODO confirm.
        knn_path: ``.npy`` file of neighbour indices with one entry per frame.

    Raises:
        FileNotFoundError: if either file does not exist.
        ValueError: if ``knn`` has fewer frames than the number of items.
    """

    def __init__(self, points_path, knn_path):
        # Check both inputs up front so the error message is consistent.
        for path in (points_path, knn_path):
            if not os.path.exists(path):
                raise FileNotFoundError(f"File {path} not found.")
        self.points = np.load(points_path).astype(np.float32)
        self.knn = np.load(knn_path).astype(np.int64)

        # __getitem__ reads knn[idx] for every idx < len(points) - 1; fail here
        # instead of with an IndexError deep inside a training epoch.
        needed = len(self.points) - 1
        if len(self.knn) < needed:
            raise ValueError(
                f"{knn_path} has {len(self.knn)} frames but {points_path} "
                f"needs at least {needed}."
            )

        # Normalization stats.
        # NOTE(review): computed over every frame, i.e. before any train/val
        # split, so validation frames influence the scaling.
        self.mean = np.mean(self.points, axis=(0, 1))
        self.std = np.std(self.points, axis=(0, 1))
        self.points = (self.points - self.mean) / (self.std + 1e-6)

    def __len__(self):
        # The last frame is only ever a target, never an input; clamp at 0 so an
        # empty points file yields an empty dataset rather than a negative length.
        return max(len(self.points) - 1, 0)

    def __getitem__(self, idx):
        return torch.from_numpy(self.points[idx]), \
               torch.from_numpy(self.knn[idx]), \
               torch.from_numpy(self.points[idx + 1])

# --- MODELS ---
class LinearAttention(nn.Module):
    """Softmax-free (kernelised) self-attention with a ReLU feature map.

    Queries and keys are mapped through ``relu(.) + 1e-6`` so every feature is
    strictly positive.  Each output token is

        out_n = to_out( phi(q_n) @ (sum_m phi(k_m)^T v_m)
                        / (phi(q_n) . sum_m phi(k_m) + 1e-6) )

    The token summary is a (dim x dim) matrix, so cost grows linearly with the
    number of tokens.  Input and output are ``(B, N, dim)``.
    """

    def __init__(self, dim):
        super().__init__()
        # One bias-free projection yields Q, K and V concatenated on the last axis.
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        # Unpacking enforces a 3-D (B, N, D) input.
        _batch, _tokens, _width = x.shape
        query, key, value = torch.chunk(self.to_qkv(x), 3, dim=-1)
        query = query.relu() + 1e-6
        key = key.relu() + 1e-6

        # Summarise all tokens once: K^T V per sample, plus the summed keys.
        memory = torch.einsum("bnd,bne->bde", key, value)
        key_total = key.sum(dim=1)

        inv_norm = 1 / (torch.einsum("bnd,bd->bn", query, key_total) + 1e-6)
        mixed = torch.einsum("bnd,bde->bne", query, memory) * inv_norm.unsqueeze(-1)
        return self.to_out(mixed)

class TopologicalGRFLayer(nn.Module):
    """Graph-diffused gated attention over a per-sample kNN graph.

    Keys and values are propagated ``hops`` times through a sparse adjacency
    built from ``knn_idx``, with each step divided by ``k``.  Each node's query
    is then dotted with its diffused key to give a scalar gate.  The gate
    scales the node's diffused value before the output projection.

    Args:
        dim: feature width D (input and output are ``(B, N, D)``).
        k_neighbors: neighbours per node; must equal ``knn_idx.shape[-1]``
            (the flattened index rows are stacked and must have equal length).
        hops: number of propagation steps.

    ``knn_idx`` is expected to be ``(B, N, k)`` integer indices in ``[0, N)``.
    Out-of-range values would reach into a neighbouring sample's block.

    NOTE(review): the adjacency stores entry (j, i) for every neighbour j of
    node i, so ``sparse.mm`` gives node j the SUM over nodes that list j as a
    neighbour.  Columns sum to k but rows generally do not.  Dividing by k is
    therefore a column normalisation: "hub" nodes grow and nodes listed by no
    one become zero after one hop.  If a row-stochastic neighbour mean was
    intended, swap the two index rows.  Confirm before changing, as it alters
    the behaviour of trained checkpoints.
    """

    def __init__(self, dim, k_neighbors, hops=3):
        super().__init__()
        self.k, self.hops = k_neighbors, hops
        # Parameter creation order (to_qkv, then to_out) is kept for seeded init.
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, knn_idx):
        batch, nodes, width = x.shape
        total = batch * nodes
        dev = x.device

        # Shift each sample's node ids by sample * N so the whole batch forms a
        # single block-diagonal (B*N, B*N) sparse matrix.
        shift = (torch.arange(batch, device=dev) * nodes).view(batch, 1, 1)
        centre = torch.arange(nodes, device=dev).view(1, nodes, 1).expand(batch, nodes, self.k)
        row_ids = (knn_idx + shift).view(-1)   # neighbour j
        col_ids = (centre + shift).view(-1)    # node i whose list contains j
        adjacency = torch.sparse_coo_tensor(
            torch.stack([row_ids, col_ids]),
            torch.ones(row_ids.shape[0], device=dev),
            (total, total),
        )

        query, key, value = self.to_qkv(x).chunk(3, dim=-1)
        value_flat = value.view(total, width)
        key_flat = key.view(total, width)

        step_scale = self.k + 1e-6
        for _hop in range(self.hops):
            value_flat = torch.sparse.mm(adjacency, value_flat) / step_scale
            key_flat = torch.sparse.mm(adjacency, key_flat) / step_scale

        gate = torch.sum(query * key_flat.view(batch, nodes, width), dim=-1, keepdim=True)
        return self.to_out(gate * value_flat.view(batch, nodes, width))

class SimpleMessagePassing(nn.Module):
    """Mean-aggregation message passing over a fixed kNN graph.

    For every node i of sample b the output is
    ``proj(mean_{j in knn_idx[b, i]} x[b, j])``.  The node's own feature is not
    included unless it appears in its own neighbour list.
    Input/output are ``(B, N, dim)``; ``knn_idx`` must hold ``k_neighbors``
    indices per node (it is reshaped to ``(B, N * k)``).
    """

    def __init__(self, dim, k_neighbors):
        super().__init__()
        self.k = k_neighbors
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, knn_idx):
        batch, nodes, width = x.shape
        pairs = nodes * self.k
        # One gather along the node axis fetches every (node, neighbour) feature.
        index = knn_idx.reshape(batch, pairs).unsqueeze(-1).expand(batch, pairs, width).long()
        gathered = torch.gather(x, 1, index).reshape(batch, nodes, self.k, width)
        return self.proj(gathered.mean(dim=2))

class UnifiedInterlacer(nn.Module):
    """Per-point regressor that interleaves graph layers with linear attention.

    Pipeline (each stage pre-normalised with LayerNorm and added residually):
        embedding -> graph branch l1 -> LinearAttention l2 -> graph branch l3 -> head
    The head is ``nn.Linear(embed_dim, 3)``, so the output is 3 values per point.

    Modes:
        'grf'      -- l1/l3 are TopologicalGRFLayer.
        'mp'       -- l1/l3 are SimpleMessagePassing.
        'baseline' -- l1/l3 are nn.Identity and are never called.  Each graph
                      branch instead contributes the LayerNorm output itself,
                      and ``knn`` is ignored.

    Args:
        mode: one of ``MODES``.
        input_dim: per-point input feature size.
        embed_dim: hidden width.
        k_neighbors: neighbours per point for the graph layers; defaults to the
            module-level ``K_NEIGHBORS``.

    Raises:
        ValueError: for an unknown ``mode``.  Previously an unknown mode built
            nn.Identity layers, and forward() then failed with a TypeError when
            calling them with two arguments.
    """

    MODES = ('grf', 'mp', 'baseline')

    def __init__(self, mode='grf', input_dim=3, embed_dim=64, k_neighbors=None):
        super().__init__()
        if mode not in self.MODES:
            raise ValueError(f"Unknown mode {mode!r}; expected one of {self.MODES}")
        if k_neighbors is None:
            k_neighbors = K_NEIGHBORS

        # Submodule names and creation order are unchanged, so existing
        # checkpoints load and seeded initialisation is reproducible.
        self.mode = mode
        self.embedding = nn.Linear(input_dim, embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)

        if mode == 'grf':
            self.l1 = TopologicalGRFLayer(embed_dim, k_neighbors)
            self.l3 = TopologicalGRFLayer(embed_dim, k_neighbors)
        elif mode == 'mp':
            self.l1 = SimpleMessagePassing(embed_dim, k_neighbors)
            self.l3 = SimpleMessagePassing(embed_dim, k_neighbors)
        else:  # 'baseline': no graph layers; forward() bypasses l1/l3
            self.l1 = nn.Identity()
            self.l3 = nn.Identity()

        self.l2 = LinearAttention(embed_dim)
        self.head = nn.Linear(embed_dim, 3)

    def _graph_branch(self, layer, h_norm, knn):
        """Residual update from a graph layer, or ``h_norm`` itself in baseline mode."""
        if self.mode == 'baseline':
            return h_norm
        return layer(h_norm, knn)

    def forward(self, x, knn):
        # x presumably (B, N, input_dim) and knn (B, N, k) — TODO confirm against the dataset.
        h = self.embedding(x)
        h = h + self._graph_branch(self.l1, self.norm1(h), knn)
        h = h + self.l2(self.norm2(h))
        h = h + self._graph_branch(self.l3, self.norm3(h), knn)
        return self.head(h)

# --- MAIN ---
def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader`` and return the sample-weighted mean loss.

    With an ``optimizer`` the model is put in train mode and stepped after every
    batch.  Without one it is put in eval mode and run with gradients disabled.
    Each batch loss is weighted by its batch size, so a short final batch does
    not skew the epoch average.  Returns NaN when the loader yields nothing.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, n_samples = 0.0, 0
    with torch.set_grad_enabled(training):
        for x, knn, y in loader:
            x, knn, y = x.to(DEVICE), knn.to(DEVICE), y.to(DEVICE)
            if training:
                optimizer.zero_grad()
            loss = criterion(model(x, knn), y)
            if training:
                loss.backward()
                optimizer.step()
            batch_n = x.shape[0]
            loss_sum += loss.item() * batch_n
            n_samples += batch_n
    return loss_sum / n_samples if n_samples else float('nan')


def _save_loss_plot(train_losses, val_losses, mode, path):
    """Save train/validation loss curves (log-scale y axis) for ``mode`` to ``path``."""
    epochs = range(1, len(train_losses) + 1)
    # Hundreds of epochs turn per-point markers into a smear; draw ~25 per curve.
    every = max(1, len(train_losses) // 25)
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epochs, train_losses, label='Training Loss', marker='o', markevery=every)
    ax.plot(epochs, val_losses, label='Validation Loss', marker='s', markevery=every)
    ax.set_title(f'Training and Validation Loss - {mode.upper()}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss (MSE)')
    ax.set_yscale('log')
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)


def main(mode=None, seed=0):
    """Train one UnifiedInterlacer variant and save its losses, plot and weights.

    Reads ``points.npy`` / ``knn_indices.npy`` from the current working
    directory.  Writes ``loss_plot_<mode>.png``, ``losses_<mode>.npz`` and
    ``model_<mode>.pth`` there.

    Args:
        mode: 'grf', 'mp' or 'baseline'.  Defaults to the global ``MODEL_MODE``,
            read at call time, so "set MODEL_MODE, then call main()" still works.
        seed: seeds weight init, DataLoader shuffling and the train/val split.
            Every mode is then evaluated on the same validation frames.  Pass
            ``None`` for the old unseeded behaviour.
    """
    mode = MODEL_MODE if mode is None else mode

    split_kwargs = {}
    if seed is not None:
        # Without this, each call drew a different random split, so the
        # baseline/mp/grf validation losses were not comparable.
        torch.manual_seed(seed)
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)

    # NOTE(review): normalisation stats span all frames (before the split), and
    # items are random (frame, next frame) pairs.  Validation pairs are
    # therefore presumably temporal neighbours of training pairs, so expect an
    # optimistic validation loss.
    dataset = RobotArmDataset("points.npy", "knn_indices.npy")
    train_size = int(len(dataset) * (1 - VAL_SPLIT))
    val_size = len(dataset) - train_size
    train_ds, val_ds = random_split(dataset, [train_size, val_size], **split_kwargs)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    model = UnifiedInterlacer(mode=mode).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=LR)
    criterion = nn.MSELoss()

    train_losses, val_losses = [], []
    print(f"Start Training: {mode.upper()}")
    for ep in range(EPOCHS):
        train_loss = _run_epoch(model, train_loader, criterion, optimizer)
        val_loss = _run_epoch(model, val_loader, criterion)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f"Ep {ep+1}: Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")

    plot_path = f'loss_plot_{mode}.png'
    _save_loss_plot(train_losses, val_losses, mode, plot_path)
    print(f"Loss plot saved to {plot_path}")

    losses_path = f'losses_{mode}.npz'
    np.savez(losses_path,
             train_losses=np.array(train_losses),
             val_losses=np.array(val_losses))
    print(f"Losses saved to {losses_path}")

    # SAVE MODEL WEIGHTS (include losses in checkpoint).
    # NOTE: mean/std are numpy arrays.  torch.load with weights_only=True (the
    # default in recent PyTorch) may refuse them unless they are allow-listed.
    ckpt_path = f"model_{mode}.pth"
    torch.save({
        'model_state_dict': model.state_dict(),
        'mean': dataset.mean,
        'std': dataset.std,
        'mode': mode,
        'train_losses': train_losses,
        'val_losses': val_losses
    }, ckpt_path)
    print(f"Model saved to {ckpt_path}")

if __name__ == "__main__":
    # In a notebook __name__ is "__main__", so this runs on every cell execution.
    # os.chdir with a relative DATA_PATH is not idempotent: a second run would
    # look for DATA_PATH inside DATA_PATH.  Always restore the starting directory.
    _start_dir = os.getcwd()
    os.chdir(DATA_PATH)
    try:
        # Rebinding the module-level MODEL_MODE is how main() selects the variant.
        for MODEL_MODE in ('baseline', 'mp', 'grf'):
            main()
    finally:
        os.chdir(_start_dir)


Using device: cuda
Start Training: BASELINE
Ep 1: Train Loss: 0.435942, Val Loss: 0.117731
Ep 2: Train Loss: 0.081333, Val Loss: 0.064164
Ep 3: Train Loss: 0.058039, Val Loss: 0.056372
Ep 4: Train Loss: 0.051227, Val Loss: 0.050842
Ep 5: Train Loss: 0.046237, Val Loss: 0.045099
Ep 6: Train Loss: 0.042346, Val Loss: 0.041847
Ep 7: Train Loss: 0.038286, Val Loss: 0.036727
Ep 8: Train Loss: 0.034563, Val Loss: 0.033371
Ep 9: Train Loss: 0.030055, Val Loss: 0.029248
Ep 10: Train Loss: 0.028208, Val Loss: 0.031879
Ep 11: Train Loss: 0.025277, Val Loss: 0.026574
Ep 12: Train Loss: 0.024223, Val Loss: 0.026323
Ep 13: Train Loss: 0.024496, Val Loss: 0.024218
Ep 14: Train Loss: 0.023349, Val Loss: 0.027794
Ep 15: Train Loss: 0.023980, Val Loss: 0.023757
Ep 16: Train Loss: 0.022008, Val Loss: 0.023401
Ep 17: Train Loss: 0.021762, Val Loss: 0.023962
Ep 18: Train Loss: 0.020996, Val Loss: 0.021346
Ep 19: Train Loss: 0.021132, Val Loss: 0.023190
Ep 20: Train Loss: 0.021778, Val Loss: 0.021515
Ep 21