In [17]:
# -----------------------------
# 1. Imports and Seed Setup
# -----------------------------
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os

# Reproducibility: seed PyTorch's and NumPy's global random number generators
# with the same fixed value so random draws repeat from run to run.
torch.manual_seed(42)
np.random.seed(42)


In [18]:
# -----------------------------
# 2. Enhanced PointNet2D
# -----------------------------
class PointNet2D(nn.Module):
    """Shared per-point MLP followed by a max-pool over the point axis.

    Every point is pushed through three ``Conv1d(kernel_size=1)`` ->
    ``BatchNorm1d`` -> ``ReLU`` stages (``input_dim -> 64 -> 128 -> emb_dim``),
    then the per-point features are reduced with a max over all points.

    Args:
        input_dim: Number of coordinates per point.
        emb_dim: Size of the pooled embedding.
    """

    def __init__(self, input_dim=2, emb_dim=256):
        super().__init__()
        channel_widths = (input_dim, 64, 128, emb_dim)
        stages = []
        # Layers are created in the same order as a hand-written Sequential,
        # so parameter names (mlp.0, mlp.1, ...) and initialisation are unchanged.
        for c_in, c_out in zip(channel_widths[:-1], channel_widths[1:]):
            stages.append(nn.Conv1d(c_in, c_out, kernel_size=1))
            stages.append(nn.BatchNorm1d(c_out))
            stages.append(nn.ReLU())
        self.mlp = nn.Sequential(*stages)

    def forward(self, x, mask=None):
        """Embed a batch of point sets.

        Args:
            x: Points of shape (B, N, input_dim).
            mask: Optional (B, N) tensor; positions equal to 0 are excluded
                from the max-pool by overwriting their features with -100.

        Returns:
            Tensor of shape (B, emb_dim).
        """
        # Conv1d expects channels first: (B, N, C) -> (B, C, N).
        point_feats = self.mlp(x.transpose(1, 2))  # (B, emb_dim, N)

        if mask is not None:
            excluded = (mask == 0).unsqueeze(1)  # (B, 1, N), broadcast over channels
            point_feats = point_feats.masked_fill(excluded, -1e2)

        pooled, _ = point_feats.max(dim=2)  # (B, emb_dim)
        return pooled


In [19]:
# -----------------------------
# 3. LSTM Slice Encoder
# -----------------------------
class LSTMSliceEncoder(nn.Module):
    """Collapse a sequence of slice embeddings into one vector with an LSTM.

    The output is the final hidden state of the last LSTM layer. When the LSTM
    is bidirectional, the last two entries of ``h_n`` (the two directions of
    the top layer) are concatenated along the feature axis.

    Args:
        input_dim: Feature size of each sequence element.
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
    """

    def __init__(self, input_dim=256, hidden_dim=256, num_layers=2, bidirectional=True):
        super().__init__()
        lstm_config = dict(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )
        self.lstm = nn.LSTM(**lstm_config)

    def forward(self, x):
        """Encode ``x`` of shape (B, S, D) into (B, 2H) if bidirectional, else (B, H)."""
        # nn.LSTM returns (output, (h_n, c_n)); only h_n is needed here.
        final_hidden = self.lstm(x)[1][0]
        if not self.lstm.bidirectional:
            return final_hidden[-1]  # (B, H)
        top_first, top_second = final_hidden[-2], final_hidden[-1]
        return torch.cat([top_first, top_second], dim=-1)  # (B, 2H)

In [20]:
# -----------------------------
# 4. Cd Regressor MLP
# -----------------------------
class CdRegressor(nn.Module):
    """MLP head that maps a feature vector to one scalar per sample.

    Architecture: ``input_dim -> 128 -> 64 -> 1`` with ReLU activations and a
    dropout of 0.3 after the first hidden layer.

    Args:
        input_dim: Size of the incoming feature vector.
    """

    def __init__(self, input_dim):
        super().__init__()
        head_layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ]
        self.net = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return predictions of shape (B,) for inputs of shape (B, input_dim)."""
        scores = self.net(x)  # (B, 1)
        # Drop the trailing singleton feature axis.
        return scores.squeeze(1)

In [21]:
# -----------------------------
# 5. Full Model Assembly
# -----------------------------
class CdPredictorNet(nn.Module):
    """End-to-end model: per-slice point encoder -> sequence encoder -> regressor.

    Args:
        pointnet: Module called as ``pointnet(points, mask)`` with inputs of
            shape (B*S, N, D) and (B*S, N); must return (B*S, E).
        lstm_encoder: Module called with the (B, S, E) slice embeddings; must
            return one vector per batch element.
        regressor: Module mapping that vector to the final prediction.
    """

    def __init__(self, pointnet, lstm_encoder, regressor):
        super().__init__()
        self.pointnet = pointnet
        self.lstm_encoder = lstm_encoder
        self.regressor = regressor

    def forward(self, slices, point_mask, slice_mask=None):
        """Predict one value per batch element.

        Args:
            slices: (B, S, N, D) point coordinates per slice.
            point_mask: (B, S, N) mask forwarded to ``pointnet``.
            slice_mask: Optional (B, S) mask. Embeddings of slices whose mask
                value is 0 are zeroed before the sequence encoder so padded
                slices no longer feed it an arbitrary sentinel embedding.
                NOTE(review): assumes a 0/1 mask with 1 marking a real slice,
                mirroring how the point mask is used - confirm against the
                preprocessing that writes ``slice_mask``.

        Returns:
            Whatever ``regressor`` returns for the encoded batch.
        """
        B, S, N, D = slices.shape
        # reshape (unlike view) also accepts non-contiguous inputs.
        flat_slices = slices.reshape(B * S, N, D)
        flat_mask = point_mask.reshape(B * S, N)

        slice_embs = self.pointnet(flat_slices, flat_mask)  # (B*S, E)
        slice_embs = slice_embs.reshape(B, S, -1)  # (B, S, E)

        if slice_mask is not None:
            # Align device/dtype here so callers that leave the mask on the CPU still work.
            keep = slice_mask.to(device=slice_embs.device, dtype=slice_embs.dtype)
            slice_embs = slice_embs * keep.reshape(B, S, 1)

        car_emb = self.lstm_encoder(slice_embs)
        return self.regressor(car_emb)

In [22]:
# -----------------------------
# 6. Dataset Loader
# -----------------------------
class CarSlicesDataset(torch.utils.data.Dataset):
    """Dataset of per-car slice arrays paired with a drag-coefficient label.

    Each item is read from ``<npz_dir>/<car_id>_axis-x.npz`` (arrays
    ``slices``, ``point_mask``, ``slice_mask``) and labelled with the
    ``"Average Cd"`` value of the CSV row whose ``"Design"`` equals the car id.

    Args:
        ids_txt: Text file with one car id per line; blank lines are ignored.
        npz_dir: Directory containing the per-car ``.npz`` files.
        csv_path: CSV file with ``"Design"`` and ``"Average Cd"`` columns.
        max_cars: If truthy, keep only the first ``max_cars`` ids.
    """

    def __init__(self, ids_txt, npz_dir, csv_path, max_cars=None):
        # Context manager closes the ids file; blank lines would otherwise
        # become empty car ids that fail later at load time.
        with open(ids_txt) as ids_file:
            self.car_ids = [line.strip() for line in ids_file if line.strip()]
        if max_cars:
            self.car_ids = self.car_ids[:max_cars]
        self.npz_dir = npz_dir
        self.df = pd.read_csv(csv_path)

    def __len__(self):
        return len(self.car_ids)

    def __getitem__(self, idx):
        """Return ``(slices, point_mask, slice_mask, cd)`` as float32 tensors.

        Raises:
            KeyError: If the CSV has no row for this car id. (An IndexError
                here would be mistaken for end-of-sequence by plain
                ``for item in dataset`` iteration.)
        """
        car_id = self.car_ids[idx]
        npz_path = os.path.join(self.npz_dir, f"{car_id}_axis-x.npz")
        # Close the archive after copying the arrays out; otherwise one file
        # handle stays open per loaded sample.
        with np.load(npz_path) as data:
            slices = torch.tensor(data["slices"], dtype=torch.float32)
            point_mask = torch.tensor(data["point_mask"], dtype=torch.float32)
            slice_mask = torch.tensor(data["slice_mask"], dtype=torch.float32)

        cd_matches = self.df.loc[self.df["Design"] == car_id, "Average Cd"].to_numpy()
        if cd_matches.size == 0:
            raise KeyError(f"No 'Average Cd' entry found for design {car_id!r}")
        cd_value = cd_matches[0]  # first match, as before, if a design is duplicated
        return slices, point_mask, slice_mask, torch.tensor(cd_value, dtype=torch.float32)

In [25]:
import glob
from tqdm import tqdm

def train_model(resume=True, num_epochs=50, max_cars=1000, checkpoint_dir="../outputs/checkpoints", early_stopping_patience=5):
    """Train the Cd predictor with per-epoch checkpoints, resume and early stopping.

    Args:
        resume: If True, continue from the most recently written checkpoint in
            ``checkpoint_dir`` (regular or interrupt checkpoint), if any.
        num_epochs: Index of the last epoch to run (epochs are 1-based).
        max_cars: Passed to ``CarSlicesDataset`` to cap the number of cars.
        checkpoint_dir: Directory for checkpoints; created if missing.
        early_stopping_patience: Stop after this many consecutive epochs
            without a strictly lower average training MSE.

    Returns:
        tuple: ``(model, epoch_losses)`` - the trained module and the list of
        average training MSE values, one per completed epoch. Callers must
        unpack both values.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize model components
    pointnet = PointNet2D()
    lstm_encoder = LSTMSliceEncoder()
    regressor = CdRegressor(input_dim=512)
    model = CdPredictorNet(pointnet, lstm_encoder, regressor).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()

    os.makedirs(checkpoint_dir, exist_ok=True)
    start_epoch = 1
    best_loss = float("inf")
    patience_counter = 0
    epoch_losses = []

    # Resume from checkpoint if exists
    if resume:
        # Interrupt checkpoints use a different filename prefix, so they need
        # their own pattern or they would never be picked up.
        checkpoints = (
            glob.glob(f"{checkpoint_dir}/epoch_*.pt")
            + glob.glob(f"{checkpoint_dir}/interrupted_epoch_*.pt")
        )
        if checkpoints:
            # Most recently written file wins; sorting names lexicographically
            # would rank epoch_100 before epoch_99.
            latest_ckpt = max(checkpoints, key=os.path.getmtime)
            print(f"🔄 Resuming from checkpoint: {latest_ckpt}")
            state = torch.load(latest_ckpt, map_location=device)
            model.load_state_dict(state['model'])
            optimizer.load_state_dict(state['optimizer'])
            start_epoch = state['epoch'] + 1
            epoch_losses = state.get('epoch_losses', [])
            best_loss = min(epoch_losses) if epoch_losses else float("inf")
        else:
            print("⏩ No previous checkpoint found, starting fresh.")

    # Load dataset
    dataset = CarSlicesDataset(
        ids_txt="../data/subset_dir/train_design_ids.txt",
        npz_dir="../outputs/pad_masked_slices",
        csv_path="../data/DrivAerNetPlusPlus_Drag_8k_cleaned.csv",
        max_cars=max_cars
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

    # Last epoch that ran to completion; this is what an interrupt checkpoint
    # must record so that a resume re-runs the unfinished epoch.
    completed_epoch = start_epoch - 1
    # Bound up front so the interrupt handler can always name the checkpoint.
    epoch = start_epoch

    try:
        for epoch in range(start_epoch, num_epochs + 1):
            model.train()
            total_loss = 0.0
            pbar = tqdm(dataloader, desc=f"Epoch {epoch}", unit="batch")

            for slices, point_mask, slice_mask, cd_gt in pbar:
                slices = slices.to(device)
                point_mask = point_mask.to(device)
                slice_mask = slice_mask.to(device)  # keep every model input on one device
                cd_gt = cd_gt.to(device)

                pred = model(slices, point_mask, slice_mask)
                loss = loss_fn(pred, cd_gt)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Weight by batch size so a smaller final batch is averaged correctly.
                batch_loss = loss.item() * slices.size(0)
                total_loss += batch_loss
                pbar.set_postfix(loss=batch_loss / slices.size(0))

            avg_loss = total_loss / len(dataloader.dataset)
            epoch_losses.append(avg_loss)
            completed_epoch = epoch
            print(f"✅ Epoch {epoch}: Avg MSE = {avg_loss:.5f}")

            # Save checkpoint
            ckpt_path = os.path.join(checkpoint_dir, f"epoch_{epoch:02d}_mse_{avg_loss:.4f}.pt")
            torch.save({
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch_losses': epoch_losses
            }, ckpt_path)
            print(f"💾 Checkpoint saved: {ckpt_path}")

            # Early stopping
            if avg_loss < best_loss:
                best_loss = avg_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= early_stopping_patience:
                    print(f"🛑 Early stopping triggered at epoch {epoch}.")
                    break

    except KeyboardInterrupt:
        print("\n⛔ Interrupted by user. Saving last checkpoint...")
        torch.save({
            # Store the last *finished* epoch: the weights are mid-epoch, so
            # resuming restarts the interrupted epoch instead of skipping it.
            'epoch': completed_epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch_losses': epoch_losses
        }, os.path.join(checkpoint_dir, f"interrupted_epoch_{epoch}.pt"))
        print("🧷 Last checkpoint saved. Safe to resume later.")

    return model, epoch_losses

# train_model() returns (model, epoch_losses); unpack both so `model` is the
# network itself (binding the whole tuple makes a later model.eval() fail).
model, epoch_losses = train_model()

⏩ No previous checkpoint found, starting fresh.


Epoch 1:   0%|          | 0/500 [00:13<?, ?batch/s]


⛔ Interrupted by user. Saving last checkpoint...
🧷 Last checkpoint saved. Safe to resume later.





In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import json
import os

def plot_training_analysis(epoch_losses, output_dir="../outputs/analysis"):
    """Persist the per-epoch losses and plot the loss curve and its deltas.

    Writes ``epoch_losses.json``, ``loss_curve.png`` and (when at least two
    epochs exist) ``loss_deltas.png`` into ``output_dir``.

    Args:
        epoch_losses: Sequence of average losses, one per completed epoch.
            May be empty, e.g. when training was interrupted during epoch 1.
        output_dir: Destination directory; created if missing.

    Returns:
        None.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Plain Python floats are always JSON-serialisable, whatever numeric type
    # the caller collected.
    losses = [float(value) for value in epoch_losses]

    # Save epoch loss log to disk
    loss_log_path = os.path.join(output_dir, "epoch_losses.json")
    with open(loss_log_path, "w") as f:
        json.dump(losses, f)
    print(f"📄 Saved epoch loss log to: {loss_log_path}")

    if not losses:
        # Nothing to draw; the (empty) log above is still written.
        print("⚠️ No completed epochs to plot.")
        return

    epochs = list(range(1, len(losses) + 1))

    # Plot raw loss curve
    fig, ax = plt.subplots(figsize=(8, 5))
    sns.lineplot(x=epochs, y=losses, marker="o", ax=ax)
    ax.set_title("📉 Epoch-wise Training MSE Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("MSE Loss")
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "loss_curve.png"))
    plt.show()

    if len(losses) < 2:
        # A delta needs two consecutive epochs.
        return

    # Plot loss changes (epoch k minus epoch k-1, shown at epoch k)
    loss_deltas = [current - previous for previous, current in zip(losses, losses[1:])]
    fig, ax = plt.subplots(figsize=(8, 5))
    sns.barplot(x=epochs[1:], y=loss_deltas, ax=ax)
    ax.set_title("📊 Change in MSE Loss Across Epochs")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Δ Loss")
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "loss_deltas.png"))
    plt.show()

# Example usage (paste this after training):
# model, epoch_losses = train_model()
# plot_training_analysis(epoch_losses)


In [None]:
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score

# Load full car IDs from training split.
# Paths are relative to the notebook's working directory and use the same
# "../" prefix as the training cell, so both cells resolve the same files.
with open("../data/subset_dir/train_design_ids.txt") as f:
    all_ids = [line.strip() for line in f]

# Select 100 unseen IDs after first 1000 used in training
test_ids = all_ids[1000:1100]

# Load CSV for ground truth Cd values
df = pd.read_csv("../data/DrivAerNetPlusPlus_Drag_8k_cleaned.csv")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# train_model() returns (model, epoch_losses). Accept either the unpacked
# module or the raw tuple so this cell works however the call was written.
if isinstance(model, tuple):
    model = model[0]
model.to(device)
model.eval()

preds = []
trues = []

for car_id in test_ids:
    path = f"../outputs/pad_masked_slices/{car_id}_axis-x.npz"
    # Close each archive once its arrays are copied into tensors.
    with np.load(path) as data:
        slices = torch.tensor(data["slices"], dtype=torch.float32).unsqueeze(0).to(device)
        point_mask = torch.tensor(data["point_mask"], dtype=torch.float32).unsqueeze(0).to(device)
        slice_mask = torch.tensor(data["slice_mask"], dtype=torch.float32).unsqueeze(0).to(device)

    with torch.no_grad():
        cd_pred = model(slices, point_mask, slice_mask).item()

    cd_true = df.loc[df["Design"] == car_id, "Average Cd"].values[0]

    preds.append(cd_pred)
    trues.append(cd_true)

    print(f"🚗 {car_id} → Predicted Cd: {cd_pred:.4f} | True Cd: {cd_true:.4f}")

# Compute and print R²
r2 = r2_score(trues, preds)
print(f"\n📊 R² Score on unseen 100-car subset: {r2:.4f}")
