In [15]:
import torch
import os
import numpy as np
from tqdm import tqdm
# Run configuration: dataset tag, backbone name and number of CV folds.
ds_tag   = "sensors"
backbone = "bptransformer"
N_folds  = 5
# One embedding dump per fold.
embed_files = [
    f"/data1/bubble3jh/bp_L2P/code/train/embeds_mean/{backbone}/{ds_tag}_fold{i}_ALL.pt" for i in range(N_folds)
]

all_embeddings = []
all_groups     = []

for f in embed_files:
    # Best-effort: a missing fold is reported and skipped rather than fatal.
    if not os.path.exists(f):
        print(f"[!] missing: {f}")
        continue

    # map_location="cpu" lets dumps that were written from a GPU load on a
    # CPU-only host; the tensors are only concatenated here.
    embs = torch.load(f, map_location="cpu")
    all_embeddings.append(embs["embeddings"])       # (N_i, D)
    # Assumes column 0 of the labels (SP) alone identifies the group -> adjust if needed.
    all_groups.append(embs["labels"][:, 0].long())

# torch.cat() on an empty list fails with an opaque error, so report the
# actual problem (no fold file could be loaded) instead.
if not all_embeddings:
    raise FileNotFoundError(
        f"none of the {len(embed_files)} expected embedding files exist "
        f"(backbone={backbone!r}, ds_tag={ds_tag!r})"
    )

embeddings_tensor = torch.cat(all_embeddings, dim=0)  # (N_total, D)
groups_tensor     = torch.cat(all_groups, dim=0)      # (N_total,)

# -----------------------------------------------------
# 3. Compute the mean embedding of each group
# -----------------------------------------------------
N_groups = 4  # group ids are expected to be 0 .. N_groups - 1

groups = groups_tensor.cpu().numpy()                 # int64
embeds = embeddings_tensor.cpu().numpy()             # presumably float32 -- TODO confirm

# Samples labelled outside the expected id range take part in no mean at all;
# make that visible instead of dropping them silently.
unexpected = np.setdiff1d(groups, np.arange(N_groups))
if unexpected.size:
    print(f"[!] ignoring samples with unexpected group ids: {unexpected.tolist()}")

mean_embeddings = []
for group in tqdm(range(N_groups), desc="Computing group means"):
    group_embeds = embeds[groups == group]
    if group_embeds.shape[0] == 0:
        # np.mean over an empty slice returns NaN with only a generic
        # RuntimeWarning. Keep the NaN row so the result still has one row
        # per group, but say which group is empty.
        print(f"[!] group {group} has no samples; its mean embedding is NaN")
        mean_embedding = np.full(embeds.shape[1:], np.nan, dtype=embeds.dtype)
    else:
        mean_embedding = np.mean(group_embeds, axis=0)
    mean_embeddings.append(mean_embedding)

# -----------------------------------------------------
# 4. Save
# -----------------------------------------------------
# Stack into a single ndarray first: torch.tensor() on a list of ndarrays is
# slow and emits a UserWarning.
mean_embeddings_tensor = torch.from_numpy(np.stack(mean_embeddings))  # (n_groups, D)
# torch.save() does not create missing parent directories.
os.makedirs(backbone, exist_ok=True)
torch.save(mean_embeddings_tensor, f"{backbone}/{ds_tag}_mean_embeddings.pt")
print(f"Saved mean embedding: {ds_tag}_mean_embeddings.pt {tuple(mean_embeddings_tensor.shape)}")


Computing group means: 100%|██████████| 4/4 [00:00<00:00, 8738.13it/s]

Saved mean embedding: sensors_mean_embeddings.pt (4, 128)



