In [1]:
import os
import re
import glob
import random
import numpy as np
import pandas as pd
import ast
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

In [2]:
# Optional: faiss for fast retrieval
try:
    import faiss
    _HAS_FAISS = True
except Exception:
    _HAS_FAISS = False
    print("faiss not installed - retrieval will use sklearn NearestNeighbors fallback.")

In [3]:
# ----------------------------
# Config / Hyperparams
# ----------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

WINDOW = 32        # frames per sample
STEP = 16          # sliding step
INPUT_DIM = 33*3   # per frame flattened dim (if your CSV format differs, see loader)
LATENT_DIM = 256   #64
HIDDEN = 256
BATCH_SIZE = 32
VAE_EPOCHS = 40
PROJ_EPOCHS = 30
LR = 1e-3

POEM_MODEL_NAME = "all-MiniLM-L6-v2"  # sentence-transformers for poem embeddings

In [4]:
# ----------------------------
# Utilities: CSV loader (supports 2 formats)
# Format A: each frame as 33 rows (joint per row) repeated T times (rows = T*33)
# Format B: each frame as a row with 99 columns (x0,y0,z0,...)
# ----------------------------
def load_csv_as_frames(path):
    """
    Load CSV where each row = one frame,
    and each pN cell contains a tuple like ('x','y','z').
    Output shape: (T, 33, 3)
    """
    # 自動偵測分隔符
    with open(path, 'r', encoding='utf-8') as f:
        sample = f.read(1024)
        f.seek(0)
        dialect = csv.Sniffer().sniff(sample, delimiters=[',','\t',';'])
    df = pd.read_csv(path, sep=dialect.delimiter, engine='python')

    # 找出以 p 開頭的欄位
    joint_cols = [c for c in df.columns if c.startswith("p")]
    if len(joint_cols) == 0:
        raise ValueError(f"No joint columns found in {path}. Found columns: {df.columns.tolist()}")

    frames = []
    for _, row in df.iterrows():
        joints = []
        for c in joint_cols:
            cell = str(row[c]).strip()
            if cell.startswith("(") and cell.endswith(")"):
                try:
                    tup = ast.literal_eval(cell)
                    joints.append([float(tup[0]), float(tup[1]), float(tup[2])])
                except Exception:
                    joints.append([0.0, 0.0, 0.0])
            else:
                joints.append([0.0, 0.0, 0.0])
        frames.append(joints)

    frames = np.array(frames, dtype=np.float32)

    # reshape 如果必要
    if frames.ndim == 2 and frames.shape[1] % 3 == 0:
        frames = frames.reshape(frames.shape[0], -1, 3)

    if frames.ndim != 3 or frames.shape[2] != 3:
        raise ValueError(f"Unexpected frame shape {frames.shape} in {path}")

    print(f"Loaded {path} → shape {frames.shape}")
    return frames

In [5]:
# ----------------------------
# Windowing & feature preprocessing
# ----------------------------
def create_windows_from_files(csv_paths, window=WINDOW, step=STEP, do_velocity=True, center_on_hip=True):
    """
    Returns windows np.array shape (N_windows, window, 99)
    """
    windows = []
    for p in csv_paths:
        frames = load_csv_as_frames(p)  # (T,33,3)
        T = frames.shape[0]
        # center on hip (joint 0 assumed hip). if not, adjust mapping.
        if center_on_hip:
            hip = frames[:, 0:1, :]  # (T,1,3)
            frames = frames - hip  # broadcast
        # flatten per frame
        flat = frames.reshape(T, -1)  # (T,99)
        # optional velocity feature
        if do_velocity:
            vel = np.zeros_like(flat)
            vel[1:] = flat[1:] - flat[:-1]
            feat = np.concatenate([flat, vel], axis=1)  # (T,198)
        else:
            feat = flat
        # normalize per-file
        mu = feat.mean(axis=0, keepdims=True)
        sd = feat.std(axis=0, keepdims=True) + 1e-8
        feat = (feat - mu) / sd
        # sliding windows
        for i in range(0, T - window + 1, step):
            w = feat[i:i+window]  # (window, dim)
            windows.append(w.astype(np.float32))
    if len(windows) == 0:
        return np.zeros((0, window, feat.shape[1]), dtype=np.float32)
    return np.stack(windows)  # (N, window, dim)

class SkeletonWindowDataset(Dataset):
    def __init__(self, windows):
        # windows: numpy array (N, window, dim)
        self.windows = windows
    def __len__(self): return len(self.windows)
    def __getitem__(self, idx):
        return self.windows[idx]

In [6]:
# ----------------------------
# VAE seq model (simple LSTM-based)
# ----------------------------
class SeqEncoder(nn.Module):
    def __init__(self, input_dim, hidden=HIDDEN, latent=LATENT_DIM):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden, batch_first=True)
        self.fc_mu = nn.Linear(hidden, latent)
        self.fc_logvar = nn.Linear(hidden, latent)
    def forward(self, x):
        # x: (B, T, input_dim)
        _, (h, _) = self.lstm(x)
        h = h.squeeze(0)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

class SeqDecoder(nn.Module):
    def __init__(self, latent=LATENT_DIM, hidden=HIDDEN, out_dim=None, seq_len=WINDOW):
        super().__init__()
        self.fc = nn.Linear(latent, hidden)
        self.lstm = nn.LSTM(hidden, hidden, batch_first=True)
        self.out = nn.Linear(hidden, out_dim)
        self.seq_len = seq_len
    def forward(self, z):
        # z: (B, latent)
        h0 = torch.tanh(self.fc(z)).unsqueeze(0)  # (1,B,hidden)
        c0 = torch.zeros_like(h0)
        # feed zeros
        B = z.size(0)
        inp = torch.zeros(B, self.seq_len, h0.size(-1), device=z.device)
        out, _ = self.lstm(inp, (h0, c0))
        return self.out(out)  # (B, T, out_dim)

class VAESeq(nn.Module):
    def __init__(self, input_dim, hidden=HIDDEN, latent=LATENT_DIM, seq_len=WINDOW):
        super().__init__()
        self.enc = SeqEncoder(input_dim, hidden, latent)
        self.dec = SeqDecoder(latent, hidden, input_dim, seq_len)
    def reparam(self, mu, logvar):
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(std)
        return mu + eps * std
    def forward(self, x):
        mu, logvar = self.enc(x)
        z = self.reparam(mu, logvar)
        recon = self.dec(z)
        return recon, mu, logvar, z

In [7]:
# ----------------------------
# Training utilities
# ----------------------------
def train_vae(model, dataloader, epochs=VAE_EPOCHS, lr=LR, beta=1.0):
    model.to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        total_recon = 0.0
        total_kl = 0.0
        for batch in tqdm(dataloader, desc=f"VAE epoch {epoch+1}/{epochs}"):
            x = batch.to(DEVICE)  # (B, T, dim)
            recon, mu, logvar, z = model(x)
            recon_loss = F.mse_loss(recon, x)
            kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            loss = recon_loss + beta * kl
            opt.zero_grad(); loss.backward(); opt.step()
            total_loss += loss.item()
            total_recon += recon_loss.item()
            total_kl += kl.item()
        print(f"Epoch {epoch+1}: loss={total_loss/len(dataloader):.6f}, recon={total_recon/len(dataloader):.6f}, kl={total_kl/len(dataloader):.6f}")
    return model

def encode_all_z(model, dataloader):
    model.to(DEVICE)
    model.eval()
    zs = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Encoding all windows to z"):
            x = batch.to(DEVICE)
            mu, logvar = model.enc(x)
            zs.append(mu.cpu().numpy())
    return np.concatenate(zs, axis=0)  # (N, latent)

In [8]:
# ----------------------------
# Poem embeddings (sentence-transformers)
# ----------------------------
def build_poem_corpus_embeddings(poem_list, model_name=POEM_MODEL_NAME):
    sbert = SentenceTransformer(model_name)
    emb = sbert.encode(poem_list, convert_to_numpy=True, show_progress_bar=True)
    # normalize
    norms = np.linalg.norm(emb, axis=1, keepdims=True) + 1e-10
    emb = emb / norms
    return emb, sbert

In [9]:
# ----------------------------
# Projection heads and contrastive training
# ----------------------------
class ProjectionHead(nn.Module):
    def __init__(self, in_dim, out_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 128),
            nn.ReLU(),
            nn.Linear(128, out_dim)
        )
    def forward(self, x):
        return F.normalize(self.net(x), dim=-1)

def contrastive_train(proj_skel, proj_poem, z_loader, poem_emb, poem_list, epochs=PROJ_EPOCHS, lr=1e-3, pseudo_pairs=None):
    # z_loader yields (B, latent)
    proj_skel.to(DEVICE); proj_poem.to(DEVICE)
    opt = torch.optim.Adam(list(proj_skel.parameters()) + list(proj_poem.parameters()), lr=lr)
    # poem_emb: (M, d)
    poem_emb_t = torch.tensor(poem_emb, dtype=torch.float32).to(DEVICE)
    M = poem_emb.shape[0]
    temperature = 0.07
    for epoch in range(epochs):
        proj_skel.train(); proj_poem.train()
        total_loss = 0.0
        for batch_idx, z in enumerate(tqdm(z_loader, desc=f"Contrastive epoch {epoch+1}/{epochs}")):
            z = z.to(DEVICE)
            # sample same-size poem batch
            B = z.size(0)
            # Strategy: if pseudo_pairs provided (list of poem indices per z), use them; else random sample
            if pseudo_pairs is not None:
                # pseudo_pairs is dict mapping sample idx of the dataset to poem idx
                # Here, simple approach: we sample random poems as "positives" for each z based on pseudo_pairs; if not found choose random.
                pos_idxs = []
                for i in range(batch_idx*B, batch_idx*B + B):
                    pos = pseudo_pairs.get(i, random.randrange(M))
                    pos_idxs.append(pos)
                pos_tensor = poem_emb_t[pos_idxs]  # (B, d)
            else:
                # random positive selection (weak) - we treat any poem as candidate positive to force distributional alignment
                rand_idx = np.random.choice(M, size=B, replace=False if M>=B else True)
                pos_tensor = poem_emb_t[rand_idx]
            # project
            z_proj = proj_skel(z)  # (B, out)
            e_proj = proj_poem(pos_tensor)  # (B, out)
            # compute similarity matrix with negatives (z_proj vs full poem_emb projected)
            # project full poem_emb too
            poem_proj_all = proj_poem(poem_emb_t)  # (M, out)
            # logits: z_proj @ poem_proj_all.T
            logits = z_proj @ poem_proj_all.T  # (B, M)
            # create labels: each row's positive index is pos_idxs
            if pseudo_pairs is not None:
                labels = torch.tensor(pos_idxs, dtype=torch.long, device=DEVICE)
            else:
                # we used rand_idx - covert to tensor
                labels = torch.tensor(rand_idx, dtype=torch.long, device=DEVICE)
            loss = F.cross_entropy(logits / temperature, labels)
            opt.zero_grad(); loss.backward(); opt.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}: contrastive_loss={total_loss/(batch_idx+1):.6f}")
    return proj_skel, proj_poem

In [10]:
# ----------------------------
# Retrieval (FAISS or fallback)
# ----------------------------
def build_faiss_index(poem_emb):
    # poem_emb normalized
    d = poem_emb.shape[1]
    if _HAS_FAISS:
        index = faiss.IndexFlatIP(d)
        index.add(poem_emb.astype('float32'))
        return index
    else:
        # fallback to simple numpy search
        return None

def retrieve_poems_by_vector(vectors, poem_emb, topk=5, index=None):
    # vectors: (B, d) numpy normalized
    if index is not None and _HAS_FAISS:
        faiss.normalize_L2(vectors)
        D, I = index.search(vectors.astype('float32'), topk)
        return I, D
    else:
        # brute-force cosine similarity
        poem = poem_emb  # (M,d)
        sims = vectors @ poem.T
        idxs = np.argsort(-sims, axis=1)[:, :topk]
        vals = np.take_along_axis(sims, idxs, axis=1)
        return idxs, vals

In [11]:
# ----------------------------
# Utility: simple heuristic pseudo-pair (optional warm start)
# A few rules: hand high -> candidate poems indices for "hand" poems, rotation -> rotation poems, etc.
# This creates a mapping from window index -> poem index (pseudo label)
# ----------------------------
def simple_heuristic_pseudo_pairs(windows, poem_list):
    """
    windows: numpy (N, window, dim) - here dim likely 198 if velocity used
    poem_list: list of poem strings - must contain keywords we match
    returns: dict mapping sample_idx -> poem_idx
    """
    mapping = {}
    N = windows.shape[0]
    lower = [p for p in poem_list if '手' in p or '手舞' in p]
    foot = [p for p in poem_list if '足' in p or '踏' in p]
    spin = [p for p in poem_list if '風' in p or '旋' in p or '行' in p]
    calm = [p for p in poem_list if '雲' in p or '細' in p or '悠' in p]
    idx_map = { 'hand': [poem_list.index(x) for x in lower] if lower else [],
                'foot': [poem_list.index(x) for x in foot] if foot else [],
                'spin': [poem_list.index(x) for x in spin] if spin else [],
                'calm': [poem_list.index(x) for x in calm] if calm else [] }
    for i in range(N):
        w = windows[i]  # (window, dim)
        # approximate: check mean height of hand joints -> hand likely high
        # we assumed original flat was [x0,y0,z0,...], velocity appended -> dim maybe 198, so find y indices
        dim = w.shape[1]
        # find y positions indices heuristically: y occur every 3rd starting at 1: idx 1,4,7...
        y_idxs = list(range(1, min(99, dim), 3))
        # compute mean y of typical hand joint idx in flattened frame (hand indices might vary, we use approximate)
        # safe heuristic: use top 33 joints' y mean
        mean_y = w[:, y_idxs].mean()
        # variance -> large movement
        var = w.var()
        # pick mapping
        if mean_y > 0.2:
            pool = idx_map['hand']
        elif var > 1.0:
            pool = idx_map['spin'] if idx_map['spin'] else idx_map['foot']
        else:
            pool = idx_map['calm']
        if pool:
            mapping[i] = random.choice(pool)
    return mapping

In [12]:
# ----------------------------
# Compose chapter from retrieved poem indices
# ----------------------------
def compose_chapters_from_indices(idx_seq, poem_list, per_chapter=4, repeat_rule=True):
    # idx_seq: list or array of poem indices (N)
    chapters = []
    for i in range(0, len(idx_seq), per_chapter):
        block = idx_seq[i:i+per_chapter].tolist()
        # if block shorter than per_chapter, pad by repeating last
        while len(block) < per_chapter:
            block.append(block[-1] if len(block)>0 else 0)
        lines = [poem_list[idx] for idx in block]
        # apply repeat rule: if consecutive identical -> duplicate line to form 重章
        if repeat_rule:
            lines2 = []
            for j in range(len(lines)):
                lines2.append(lines[j])
                if j>0 and lines[j]==lines[j-1]:
                    lines2.append(lines[j])  # repeat
            lines = lines2
        chapters.append(lines)
    return chapters

In [13]:
# ----------------------------
# Main flow as functions
# ----------------------------
def run_pipeline(csv_folder, poem_list, out_dir="outputs", use_pseudo_pairs=True):
    os.makedirs(out_dir, exist_ok=True)
    csv_paths = glob.glob(os.path.join(csv_folder, "*.csv"))
    print(f"Found {len(csv_paths)} csv files.")
    # Step 1: windows
    windows = create_windows_from_files(csv_paths, window=WINDOW, step=STEP)
    print("Total windows:", windows.shape)
    # Save small sample
    np.save(os.path.join(out_dir, "windows.npy"), windows)
    # Step 2: optionally reduce dim with PCA (helps training speed)
    N, W, D = windows.shape
    windows_reshaped = windows.reshape(N*W, D)
    pca = PCA(n_components=min(64, D))
    windows_pca = pca.fit_transform(windows_reshaped)
    windows_pca = windows_pca.reshape(N, W, -1)
    input_dim = windows_pca.shape[-1]
    print("PCA reduced dim:", input_dim)
    # dataloader
    ds = SkeletonWindowDataset(windows_pca)
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    # Step 3: train VAE
    vae = VAESeq(input_dim=input_dim, hidden=HIDDEN, latent=LATENT_DIM, seq_len=WINDOW)
    print("Training VAE...")
    train_vae(vae, dl, epochs=VAE_EPOCHS, lr=LR, beta=1.0)
    # encode all windows to z
    full_dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False)
    z_all = encode_all_z(vae, full_dl)  # (N, latent)
    np.save(os.path.join(out_dir, "z_all.npy"), z_all)
    # Step 4: poem embeddings
    poem_emb, sbert = build_poem_corpus_embeddings(poem_list)
    np.save(os.path.join(out_dir, "poem_emb.npy"), poem_emb)
    # Step 5: projection heads + contrastive
    proj_skel = ProjectionHead(LATENT_DIM, out_dim=64)
    proj_poem = ProjectionHead(poem_emb.shape[1], out_dim=64)
    # create dataloader for z
    z_ds = torch.tensor(z_all, dtype=torch.float32)
    z_dl_train = DataLoader(z_ds, batch_size=BATCH_SIZE, shuffle=True)
    pseudo_pairs = None
    if use_pseudo_pairs:
        print("Creating heuristic pseudo pairs...")
        pseudo_pairs = simple_heuristic_pseudo_pairs(windows, poem_list)
        print("Pseudo pairs count:", len(pseudo_pairs))
    print("Training projection heads (contrastive)...")
    proj_skel, proj_poem = contrastive_train(proj_skel, proj_poem, z_dl_train, poem_emb, poem_list, epochs=PROJ_EPOCHS, lr=1e-3, pseudo_pairs=pseudo_pairs)
    # Step 6: build index for retrieval
    # compute poem projection for indexing
    poem_emb_t = torch.tensor(poem_emb, dtype=torch.float32)
    with torch.no_grad():
        poem_proj = proj_poem(poem_emb_t.to(DEVICE)).cpu().numpy().astype('float32')
    # normalize
    poem_proj = poem_proj / (np.linalg.norm(poem_proj, axis=1, keepdims=True) + 1e-10)
    index = build_faiss_index(poem_proj)
    # Step 7: example retrieval on dataset z_all
    with torch.no_grad():
        z_t = torch.tensor(z_all, dtype=torch.float32).to(DEVICE)
        z_proj_all = proj_skel(z_t).cpu().numpy().astype('float32')
    # normalize
    z_proj_all = z_proj_all / (np.linalg.norm(z_proj_all, axis=1, keepdims=True) + 1e-10)
    idxs, vals = retrieve_poems_by_vector(z_proj_all, poem_proj, topk=3, index=index)
    # pick first poem choice per window
    chosen = idxs[:,0]
    chapters = compose_chapters_from_indices(chosen, poem_list, per_chapter=4, repeat_rule=True)
    # save results
    out_txt = os.path.join(out_dir, "generated_chapters.txt")
    with open(out_txt, "w", encoding="utf-8") as f:
        for i, chap in enumerate(chapters):
            f.write(f"章 {i+1}:\n")
            for line in chap:
                f.write(line + "\n")
            f.write("\n")
    print("Saved generated chapters to", out_txt)
    return {
        "vae": vae,
        "proj_skel": proj_skel,
        "proj_poem": proj_poem,
        "poem_emb": poem_emb,
        "sbert": sbert,
        "index": index,
        "chapters": chapters
    }

In [14]:
# ----------------------------
# Example poem list (請用你自己的詩庫擴充)
# ----------------------------
DEFAULT_POEMS = [
    "手舞風翔", "手擎明月", "風行林間", "雲翻細波", "足踏流水",
    "踏雲而行", "身躍山河", "躍浪飛雲", "手揮秋水", "揚波微瀾",
    "心念君子", "月出東山", "波瀾不驚", "雲行雨落", "霓虹閃爍"
]

In [15]:
import csv

In [16]:
# ----------------------------
# If run as script
# ----------------------------
if __name__ == "__main__":
    import types
    args = types.SimpleNamespace(
        csv_folder=r"C:\Users\user\anaconda_projects\codelab\dance_csv",
        out_dir=r"C:\Users\user\anaconda_projects\codelab\outputs",
        use_pseudo=True,
        poem_file=None
    )

    if args.poem_file and os.path.exists(args.poem_file):
        with open(args.poem_file, "r", encoding="utf-8") as f:
            poem_list = [line.strip() for line in f if line.strip()]
    else:
        poem_list = DEFAULT_POEMS

    res = run_pipeline(args.csv_folder, poem_list, out_dir=args.out_dir, use_pseudo_pairs=args.use_pseudo)
    print("Pipeline finished. Example chapters:")
    for i, c in enumerate(res['chapters'][:5]):
        print(f"章{i+1}:")
        for line in c:
            print("  ", line)
        print()


Found 251 csv files.
Loaded C:\Users\user\anaconda_projects\codelab\dance_csv\Ballet_1.csv → shape (241, 33, 3)
Loaded C:\Users\user\anaconda_projects\codelab\dance_csv\Ballet_10.csv → shape (70, 33, 3)
Loaded C:\Users\user\anaconda_projects\codelab\dance_csv\Ballet_11.csv → shape (832, 33, 3)
Loaded C:\Users\user\anaconda_projects\codelab\dance_csv\Ballet_12.csv → shape (9, 33, 3)
Loaded C:\Users\user\anaconda_projects\codelab\dance_csv\Ballet_13.csv → shape (18, 33, 3)
Loaded C:\Users\user\anaconda_projects\codelab\dance_csv\Ballet_14.csv → shape (1167, 33, 3)
Loaded C:\Users\user\anaconda_projects\codelab\dance_csv\Ballet_15.csv → shape (3, 33, 3)
Loaded C:\Users\user\anaconda_projects\codelab\dance_csv\Ballet_16.csv → shape (48, 33, 3)
Loaded C:\Users\user\anaconda_projects\codelab\dance_csv\Ballet_17.csv → shape (481, 33, 3)
Loaded C:\Users\user\anaconda_projects\codelab\dance_csv\Ballet_18.csv → shape (61, 33, 3)
Loaded C:\Users\user\anaconda_projects\codelab\dance_csv\Ballet_19.

VAE epoch 1/40: 100%|████████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.53it/s]


Epoch 1: loss=2.917431, recon=2.900261, kl=0.017170


VAE epoch 2/40: 100%|████████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 36.25it/s]


Epoch 2: loss=2.915561, recon=2.902065, kl=0.013496


VAE epoch 3/40: 100%|████████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.17it/s]


Epoch 3: loss=2.881481, recon=2.866378, kl=0.015103


VAE epoch 4/40: 100%|████████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 36.28it/s]


Epoch 4: loss=2.859492, recon=2.838108, kl=0.021384


VAE epoch 5/40: 100%|████████████████████████████████████████████████████████████████| 518/518 [00:13<00:00, 37.15it/s]


Epoch 5: loss=2.839634, recon=2.810887, kl=0.028747


VAE epoch 6/40: 100%|████████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 36.28it/s]


Epoch 6: loss=2.800184, recon=2.776034, kl=0.024151


VAE epoch 7/40: 100%|████████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.32it/s]


Epoch 7: loss=2.777922, recon=2.751874, kl=0.026048


VAE epoch 8/40: 100%|████████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 36.26it/s]


Epoch 8: loss=2.750128, recon=2.720310, kl=0.029818


VAE epoch 9/40: 100%|████████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.91it/s]


Epoch 9: loss=2.700275, recon=2.664807, kl=0.035468


VAE epoch 10/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.91it/s]


Epoch 10: loss=2.671767, recon=2.633208, kl=0.038559


VAE epoch 11/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.99it/s]


Epoch 11: loss=2.646609, recon=2.604835, kl=0.041773


VAE epoch 12/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.22it/s]


Epoch 12: loss=2.629034, recon=2.584681, kl=0.044353


VAE epoch 13/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.96it/s]


Epoch 13: loss=2.607290, recon=2.557989, kl=0.049300


VAE epoch 14/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 36.17it/s]


Epoch 14: loss=2.580631, recon=2.526583, kl=0.054048


VAE epoch 15/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.25it/s]


Epoch 15: loss=2.546998, recon=2.488635, kl=0.058363


VAE epoch 16/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.50it/s]


Epoch 16: loss=2.523890, recon=2.463969, kl=0.059920


VAE epoch 17/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.28it/s]


Epoch 17: loss=2.501090, recon=2.439919, kl=0.061171


VAE epoch 18/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.34it/s]


Epoch 18: loss=2.477388, recon=2.415656, kl=0.061732


VAE epoch 19/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.36it/s]


Epoch 19: loss=2.458771, recon=2.396975, kl=0.061796


VAE epoch 20/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.19it/s]


Epoch 20: loss=2.442786, recon=2.380891, kl=0.061895


VAE epoch 21/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.36it/s]


Epoch 21: loss=2.427774, recon=2.365789, kl=0.061985


VAE epoch 22/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 34.69it/s]


Epoch 22: loss=2.412202, recon=2.351271, kl=0.060931


VAE epoch 23/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:13<00:00, 37.02it/s]


Epoch 23: loss=2.397276, recon=2.336437, kl=0.060839


VAE epoch 24/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.58it/s]


Epoch 24: loss=2.382177, recon=2.320778, kl=0.061399


VAE epoch 25/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.63it/s]


Epoch 25: loss=2.367975, recon=2.305387, kl=0.062588


VAE epoch 26/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.96it/s]


Epoch 26: loss=2.354235, recon=2.290677, kl=0.063558


VAE epoch 27/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.12it/s]


Epoch 27: loss=2.343374, recon=2.277996, kl=0.065378


VAE epoch 28/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:15<00:00, 34.53it/s]


Epoch 28: loss=2.330316, recon=2.263130, kl=0.067186


VAE epoch 29/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.62it/s]


Epoch 29: loss=2.321166, recon=2.252280, kl=0.068886


VAE epoch 30/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 34.86it/s]


Epoch 30: loss=2.311038, recon=2.240899, kl=0.070139


VAE epoch 31/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.92it/s]


Epoch 31: loss=2.305841, recon=2.234353, kl=0.071487


VAE epoch 32/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.03it/s]


Epoch 32: loss=2.296987, recon=2.224751, kl=0.072236


VAE epoch 33/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.42it/s]


Epoch 33: loss=2.292430, recon=2.218912, kl=0.073518


VAE epoch 34/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 34.77it/s]


Epoch 34: loss=2.284864, recon=2.210609, kl=0.074255


VAE epoch 35/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 34.92it/s]


Epoch 35: loss=2.278342, recon=2.202966, kl=0.075375


VAE epoch 36/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 34.79it/s]


Epoch 36: loss=2.272650, recon=2.196778, kl=0.075872


VAE epoch 37/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.05it/s]


Epoch 37: loss=2.268915, recon=2.192232, kl=0.076683


VAE epoch 38/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.29it/s]


Epoch 38: loss=2.263576, recon=2.186176, kl=0.077400


VAE epoch 39/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 35.56it/s]


Epoch 39: loss=2.258508, recon=2.180900, kl=0.077608


VAE epoch 40/40: 100%|███████████████████████████████████████████████████████████████| 518/518 [00:14<00:00, 34.91it/s]


Epoch 40: loss=2.253534, recon=2.175319, kl=0.078215


Encoding all windows to z: 100%|████████████████████████████████████████████████████| 519/519 [00:01<00:00, 436.08it/s]


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  return forward_call(*args, **kwargs)


Creating heuristic pseudo pairs...
Pseudo pairs count: 16591
Training projection heads (contrastive)...


Contrastive epoch 1/30: 100%|███████████████████████████████████████████████████████| 519/519 [00:03<00:00, 166.27it/s]


Epoch 1: contrastive_loss=2.210767


Contrastive epoch 2/30: 100%|███████████████████████████████████████████████████████| 519/519 [00:02<00:00, 188.91it/s]


Epoch 2: contrastive_loss=2.191796


Contrastive epoch 3/30: 100%|███████████████████████████████████████████████████████| 519/519 [00:03<00:00, 134.67it/s]


Epoch 3: contrastive_loss=2.189641


Contrastive epoch 4/30: 100%|███████████████████████████████████████████████████████| 519/519 [00:02<00:00, 191.98it/s]


Epoch 4: contrastive_loss=2.188840


Contrastive epoch 5/30: 100%|███████████████████████████████████████████████████████| 519/519 [00:02<00:00, 201.89it/s]


Epoch 5: contrastive_loss=2.188264


Contrastive epoch 6/30: 100%|███████████████████████████████████████████████████████| 519/519 [00:02<00:00, 192.26it/s]


Epoch 6: contrastive_loss=2.187526


Contrastive epoch 7/30: 100%|███████████████████████████████████████████████████████| 519/519 [00:02<00:00, 189.30it/s]


Epoch 7: contrastive_loss=2.187554


Contrastive epoch 8/30: 100%|███████████████████████████████████████████████████████| 519/519 [00:04<00:00, 109.78it/s]


Epoch 8: contrastive_loss=2.187181


Contrastive epoch 9/30: 100%|███████████████████████████████████████████████████████| 519/519 [00:03<00:00, 168.79it/s]


Epoch 9: contrastive_loss=2.186625


Contrastive epoch 10/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:02<00:00, 175.11it/s]


Epoch 10: contrastive_loss=2.186677


Contrastive epoch 11/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:02<00:00, 205.07it/s]


Epoch 11: contrastive_loss=2.186673


Contrastive epoch 12/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:04<00:00, 124.07it/s]


Epoch 12: contrastive_loss=2.186725


Contrastive epoch 13/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:04<00:00, 120.12it/s]


Epoch 13: contrastive_loss=2.187453


Contrastive epoch 14/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:04<00:00, 110.01it/s]


Epoch 14: contrastive_loss=2.186139


Contrastive epoch 15/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:03<00:00, 170.05it/s]


Epoch 15: contrastive_loss=2.186160


Contrastive epoch 16/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:02<00:00, 214.82it/s]


Epoch 16: contrastive_loss=2.185924


Contrastive epoch 17/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:03<00:00, 140.98it/s]


Epoch 17: contrastive_loss=2.185917


Contrastive epoch 18/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:03<00:00, 169.38it/s]


Epoch 18: contrastive_loss=2.186197


Contrastive epoch 19/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:02<00:00, 220.69it/s]


Epoch 19: contrastive_loss=2.186043


Contrastive epoch 20/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:02<00:00, 212.04it/s]


Epoch 20: contrastive_loss=2.185769


Contrastive epoch 21/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:02<00:00, 204.21it/s]


Epoch 21: contrastive_loss=2.185504


Contrastive epoch 22/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:03<00:00, 152.92it/s]


Epoch 22: contrastive_loss=2.186121


Contrastive epoch 23/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:04<00:00, 119.67it/s]


Epoch 23: contrastive_loss=2.185616


Contrastive epoch 24/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:04<00:00, 109.60it/s]


Epoch 24: contrastive_loss=2.185782


Contrastive epoch 25/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:04<00:00, 120.05it/s]


Epoch 25: contrastive_loss=2.186202


Contrastive epoch 26/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:04<00:00, 108.15it/s]


Epoch 26: contrastive_loss=2.185682


Contrastive epoch 27/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:03<00:00, 161.52it/s]


Epoch 27: contrastive_loss=2.185677


Contrastive epoch 28/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:04<00:00, 108.89it/s]


Epoch 28: contrastive_loss=2.186355


Contrastive epoch 29/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:04<00:00, 121.22it/s]


Epoch 29: contrastive_loss=2.186301


Contrastive epoch 30/30: 100%|██████████████████████████████████████████████████████| 519/519 [00:04<00:00, 107.84it/s]


Epoch 30: contrastive_loss=2.185552
Saved generated chapters to C:\Users\user\anaconda_projects\codelab\outputs\generated_chapters.txt
Pipeline finished. Example chapters:
章1:
   手舞風翔
   手舞風翔
   手舞風翔
   手舞風翔
   手舞風翔
   手舞風翔
   手舞風翔

章2:
   手舞風翔
   手舞風翔
   手舞風翔
   手舞風翔
   手舞風翔
   手舞風翔
   手舞風翔

章3:
   手舞風翔
   手舞風翔
   手舞風翔
   手舞風翔
   手舞風翔
   手舞風翔
   手舞風翔

章4:
   手舞風翔
   手舞風翔
   手舞風翔
   雲行雨落
   手舞風翔

章5:
   手舞風翔
   手舞風翔
   手舞風翔
   手舞風翔
   手舞風翔
   手舞風翔
   手舞風翔

