In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
from transformers import CLIPModel, CLIPTokenizer
import sys
import os
import random
import numpy as np  

def _seed_everything(value):
    """Seed the Python, NumPy and PyTorch random generators with `value`.

    When CUDA is available, the CUDA generators are seeded as well and cuDNN
    is switched to deterministic mode with benchmarking turned off.
    """
    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(value)
    torch.cuda.manual_seed_all(value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Single seed shared by every random number generator used in this notebook.
seed = 42
_seed_everything(seed)

# Put the project directory on the import path before the imports below.
# NOTE(review): this is a hard-coded absolute path, so the notebook only runs
# on a machine with this exact layout; `motion_loader` and `model` presumably
# live in this directory -- confirm.
sys.path.append(os.path.abspath('/home/user/dxc/motion/CLIP/'))
from motion_loader import get_dataset_loader  
from tqdm import tqdm
import yaml
from argparse import Namespace
# NOTE(review): wildcard import. MotionEncoder, ClipMotionAlignModel,
# clip_contrastive_loss and evaluate_model are used later in this notebook but
# are not defined in it, so they presumably come from here. Importing them by
# name would make the dependency explicit and avoid shadowing earlier imports.
from model import *

In [2]:

# Parse the YAML configuration and expose its top-level keys as attributes
# (opt.batch_size, opt.lr, ...), which is how the rest of the notebook reads them.
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)
opt = Namespace(**config)

# Quick sanity check of a few settings, one value per line.
print(opt.batch_size, opt.lr, opt.device, sep="\n")

32
0.0001
cuda


## Load dataset

In [3]:
# The StableMoFusion checkout is added to the import path before the loaders
# are built (presumably get_dataset_loader imports from it lazily -- confirm).
sys.path.append(os.path.abspath('/home/user/dxc/motion/StableMoFusion'))


def _make_loader(split):
    """Build the 'CMP'-mode data loader for `split` using the batch size in `opt`."""
    return get_dataset_loader(
        opt,
        batch_size=opt.batch_size,
        split=split,
        mode='CMP'
    )


train_loader = _make_loader('train')
test_loader = _make_loader('test')


 Loading CMP mode HumanML3D dataset ...


  0%|          | 0/6960 [00:00<?, ?it/s]

Completing loading t2m dataset

 Loading CMP mode HumanML3D dataset ...


  0%|          | 0/1305 [00:00<?, ?it/s]

Completing loading t2m dataset


## Load the CLIP model for stage 1 training

In [4]:

# Pretrained CLIP weights and the matching tokenizer, both chosen by the config.
clip_model = CLIPModel.from_pretrained(opt.clip_model_name)
clip_tokenizer = CLIPTokenizer.from_pretrained(opt.clip_model_name)

# Stage 1: freeze the CLIP text encoder.
# NOTE(review): only parameters whose name contains "text_model" are frozen;
# any other CLIPModel parameter keeps its default requires_grad setting and is
# therefore still handed to the optimizer. Confirm that this is intended.
frozen_text_params = [
    param
    for name, param in clip_model.named_parameters()
    if "text_model" in name
]
for param in frozen_text_params:
    param.requires_grad = False

# Motion encoder: sizes come from the config, the architecture settings
# (heads, layers, feed-forward width, dropout) are fixed here.
motion_encoder_kwargs = dict(
    input_dim=opt.input_dim,
    embed_dim=opt.embed_dim,
    num_heads=8,
    num_layers=4,
    dim_feedforward=2048,
    dropout=0.2,
    max_seq_length=opt.max_seq_length,
)
motion_encoder = MotionEncoder(**motion_encoder_kwargs)

# Wrap both encoders in the alignment model and move it to the target device.
model = ClipMotionAlignModel(
    clip_model=clip_model,
    motion_encoder=motion_encoder,
    temperature=0.07,
)
model = model.to(opt.device)

In [5]:
# Stage-1 optimizer: only parameters that still require gradients are updated.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(
    trainable_params,
    lr=opt.lr,
    weight_decay=opt.weight_decay,
)

# Early-stopping bookkeeping.
best_test_loss = float("inf")  # lowest test loss seen so far
no_improve_count = 0           # consecutive evaluations without improvement
max_no_improve = 3             # stop after 3 consecutive evaluations without improvement

## Start training

In [6]:
# Two-stage contrastive training with early stopping.
#
# Stage 1 uses the optimizer built before this cell. Once `opt.pretrain_epochs`
# epochs have run, stage 2 unfreezes the last CLIP text-encoder layer and the
# final layer norm and continues with `opt.lr_finetune`.
# After every epoch the model is evaluated on `test_loader`, a checkpoint is
# written, and training stops when the test loss has not improved for
# `max_no_improve` consecutive epochs.
#
# The epoch and checkpoint with the lowest test loss are remembered so they can
# be reported when training ends; with early stopping the most recent
# checkpoints are by construction not the best ones.
best_epoch = None        # 1-based epoch number with the lowest test loss so far
best_model_path = None   # checkpoint file written at that epoch

for epoch in range(opt.num_epochs):

    # Stage switch: runs exactly once, at the first epoch after stage 1.
    if epoch == opt.pretrain_epochs:

        for param in clip_model.text_model.encoder.layers[-1].parameters():
            param.requires_grad = True
        for param in clip_model.text_model.final_layer_norm.parameters():
            param.requires_grad = True

        # A fresh optimizer is required so the newly unfrozen parameters are
        # included. NOTE(review): this also discards the AdamW state gathered
        # in stage 1 and applies lr_finetune to every trainable parameter of
        # `model`, not only to the unfrozen CLIP layers -- confirm intended.
        optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=opt.lr_finetune,
            weight_decay=opt.weight_decay
        )
        print("Stage 2: Fine-tuning CLIP text encoder's last layer (and final_layer_norm) with lower lr.")

    model.train()
    total_loss = 0.0
    count = 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{opt.num_epochs}")

    for batch_data in pbar:
        caption, motion, m_length = batch_data

        # Tokenize the lower-cased captions as one padded/truncated batch.
        caption = [c.lower() for c in caption]
        text_enc = clip_tokenizer(
            caption,
            padding=True,
            truncation=True,
            max_length=opt.max_length,
            return_tensors="pt"
        )
        input_ids = text_enc["input_ids"].to(opt.device)
        attention_mask = text_enc["attention_mask"].to(opt.device)

        # The loader may hand back the motions as a list of per-sample arrays
        # or as one tensor; either way end up with a float tensor on the device.
        if isinstance(motion, list):
            motion = torch.stack([torch.tensor(m, dtype=torch.float32) for m in motion], dim=0)
        else:
            motion = motion.float()
        motion = motion.to(opt.device)
        m_length = m_length.to(opt.device)

        motion_emb, text_emb = model(motion, m_length, input_ids, attention_mask)
        loss = clip_contrastive_loss(motion_emb, text_emb, model.logit_scale)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() call synchronizes with the device.
        step_loss = loss.item()
        total_loss += step_loss
        count += 1
        pbar.set_postfix({"loss": f"{step_loss:.4f}"})

    # max(count, 1) guards against an empty loader.
    avg_loss = total_loss / max(count, 1)
    print(f"Epoch [{epoch+1}/{opt.num_epochs}] - Train Average Loss: {avg_loss:.4f}")

    # Evaluate on the test loader and write a checkpoint for this epoch.
    # NOTE(review): early stopping is driven by the test split; if a separate
    # validation split exists it would be the more appropriate choice here.
    print(f"[Validate at epoch {epoch+1}] ...")
    test_loss = evaluate_model(model, test_loader, clip_tokenizer, opt, desc=f"Epoch_{epoch+1}_Test")
    model_path = f"clip_motion_align_epoch_CMP_{epoch+1}.pt"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved: {model_path}")

    # Early stopping: reset the patience counter on improvement, otherwise
    # count the miss and stop once the patience is used up.
    if test_loss < best_test_loss:
        best_test_loss = test_loss
        best_epoch = epoch + 1
        best_model_path = model_path
        no_improve_count = 0
    else:
        no_improve_count += 1
        if no_improve_count >= max_no_improve:
            print("Early stopping triggered!")
            break

print("Training completed!")
if best_model_path is not None:
    print(f"Best test loss {best_test_loss:.4f} at epoch {best_epoch}: {best_model_path}")

Epoch 1/300: 100%|██████████| 177/177 [00:18<00:00,  9.46it/s, loss=2.2014]


Epoch [1/300] - Train Average Loss: 2.7224
[Validate at epoch 1] ...


Epoch_1_Test: 100%|██████████| 32/32 [00:02<00:00, 13.42it/s]


Epoch_1_Test Average Contrastive Loss: 2.2290
Epoch_1_Test M->T Retrieval (per 32 samples): R@1=0.295, R@2=0.474, R@3=0.592
Epoch_1_Test T->M Retrieval (per 32 samples): R@1=0.354, R@2=0.507, R@3=0.600
Model saved: clip_motion_align_epoch_CMP_1.pt


Epoch 2/300: 100%|██████████| 177/177 [00:18<00:00,  9.81it/s, loss=1.6996]


Epoch [2/300] - Train Average Loss: 1.9050
[Validate at epoch 2] ...


Epoch_2_Test: 100%|██████████| 32/32 [00:02<00:00, 15.72it/s]


Epoch_2_Test Average Contrastive Loss: 1.8109
Epoch_2_Test M->T Retrieval (per 32 samples): R@1=0.403, R@2=0.605, R@3=0.705
Epoch_2_Test T->M Retrieval (per 32 samples): R@1=0.424, R@2=0.624, R@3=0.721
Model saved: clip_motion_align_epoch_CMP_2.pt


Epoch 3/300: 100%|██████████| 177/177 [00:17<00:00,  9.84it/s, loss=1.5163]


Epoch [3/300] - Train Average Loss: 1.4865
[Validate at epoch 3] ...


Epoch_3_Test: 100%|██████████| 32/32 [00:02<00:00, 15.78it/s]


Epoch_3_Test Average Contrastive Loss: 1.4372
Epoch_3_Test M->T Retrieval (per 32 samples): R@1=0.541, R@2=0.732, R@3=0.805
Epoch_3_Test T->M Retrieval (per 32 samples): R@1=0.549, R@2=0.733, R@3=0.835
Model saved: clip_motion_align_epoch_CMP_3.pt


Epoch 4/300: 100%|██████████| 177/177 [00:18<00:00,  9.56it/s, loss=1.3489]


Epoch [4/300] - Train Average Loss: 1.2181
[Validate at epoch 4] ...


Epoch_4_Test: 100%|██████████| 32/32 [00:02<00:00, 15.56it/s]


Epoch_4_Test Average Contrastive Loss: 1.3728
Epoch_4_Test M->T Retrieval (per 32 samples): R@1=0.560, R@2=0.749, R@3=0.830
Epoch_4_Test T->M Retrieval (per 32 samples): R@1=0.574, R@2=0.752, R@3=0.845
Model saved: clip_motion_align_epoch_CMP_4.pt


Epoch 5/300: 100%|██████████| 177/177 [00:18<00:00,  9.53it/s, loss=1.0210]


Epoch [5/300] - Train Average Loss: 1.0197
[Validate at epoch 5] ...


Epoch_5_Test: 100%|██████████| 32/32 [00:02<00:00, 15.30it/s]


Epoch_5_Test Average Contrastive Loss: 1.2117
Epoch_5_Test M->T Retrieval (per 32 samples): R@1=0.608, R@2=0.782, R@3=0.856
Epoch_5_Test T->M Retrieval (per 32 samples): R@1=0.626, R@2=0.790, R@3=0.858
Model saved: clip_motion_align_epoch_CMP_5.pt
Stage 2: Fine-tuning CLIP text encoder's last layer (and final_layer_norm) with lower lr.


Epoch 6/300: 100%|██████████| 177/177 [00:19<00:00,  9.11it/s, loss=0.8055]


Epoch [6/300] - Train Average Loss: 0.7359
[Validate at epoch 6] ...


Epoch_6_Test: 100%|██████████| 32/32 [00:02<00:00, 15.57it/s]


Epoch_6_Test Average Contrastive Loss: 1.0161
Epoch_6_Test M->T Retrieval (per 32 samples): R@1=0.664, R@2=0.821, R@3=0.896
Epoch_6_Test T->M Retrieval (per 32 samples): R@1=0.665, R@2=0.842, R@3=0.901
Model saved: clip_motion_align_epoch_CMP_6.pt


Epoch 7/300: 100%|██████████| 177/177 [00:19<00:00,  9.11it/s, loss=0.6336]


Epoch [7/300] - Train Average Loss: 0.6391
[Validate at epoch 7] ...


Epoch_7_Test: 100%|██████████| 32/32 [00:02<00:00, 15.92it/s]


Epoch_7_Test Average Contrastive Loss: 1.0283
Epoch_7_Test M->T Retrieval (per 32 samples): R@1=0.656, R@2=0.830, R@3=0.901
Epoch_7_Test T->M Retrieval (per 32 samples): R@1=0.669, R@2=0.849, R@3=0.903
Model saved: clip_motion_align_epoch_CMP_7.pt


Epoch 8/300: 100%|██████████| 177/177 [00:19<00:00,  9.16it/s, loss=0.5799]


Epoch [8/300] - Train Average Loss: 0.6104
[Validate at epoch 8] ...


Epoch_8_Test: 100%|██████████| 32/32 [00:02<00:00, 15.57it/s]


Epoch_8_Test Average Contrastive Loss: 0.9902
Epoch_8_Test M->T Retrieval (per 32 samples): R@1=0.673, R@2=0.831, R@3=0.897
Epoch_8_Test T->M Retrieval (per 32 samples): R@1=0.668, R@2=0.830, R@3=0.902
Model saved: clip_motion_align_epoch_CMP_8.pt


Epoch 9/300: 100%|██████████| 177/177 [00:19<00:00,  9.02it/s, loss=0.5914]


Epoch [9/300] - Train Average Loss: 0.5607
[Validate at epoch 9] ...


Epoch_9_Test: 100%|██████████| 32/32 [00:02<00:00, 15.69it/s]


Epoch_9_Test Average Contrastive Loss: 0.9535
Epoch_9_Test M->T Retrieval (per 32 samples): R@1=0.689, R@2=0.843, R@3=0.908
Epoch_9_Test T->M Retrieval (per 32 samples): R@1=0.685, R@2=0.853, R@3=0.914
Model saved: clip_motion_align_epoch_CMP_9.pt


Epoch 10/300: 100%|██████████| 177/177 [00:19<00:00,  9.13it/s, loss=0.5823]


Epoch [10/300] - Train Average Loss: 0.5178
[Validate at epoch 10] ...


Epoch_10_Test: 100%|██████████| 32/32 [00:02<00:00, 15.56it/s]


Epoch_10_Test Average Contrastive Loss: 0.9004
Epoch_10_Test M->T Retrieval (per 32 samples): R@1=0.703, R@2=0.851, R@3=0.904
Epoch_10_Test T->M Retrieval (per 32 samples): R@1=0.713, R@2=0.857, R@3=0.920
Model saved: clip_motion_align_epoch_CMP_10.pt


Epoch 11/300: 100%|██████████| 177/177 [00:19<00:00,  9.07it/s, loss=0.4455]


Epoch [11/300] - Train Average Loss: 0.4967
[Validate at epoch 11] ...


Epoch_11_Test: 100%|██████████| 32/32 [00:02<00:00, 15.60it/s]


Epoch_11_Test Average Contrastive Loss: 0.8948
Epoch_11_Test M->T Retrieval (per 32 samples): R@1=0.703, R@2=0.861, R@3=0.906
Epoch_11_Test T->M Retrieval (per 32 samples): R@1=0.707, R@2=0.866, R@3=0.926
Model saved: clip_motion_align_epoch_CMP_11.pt


Epoch 12/300: 100%|██████████| 177/177 [00:19<00:00,  9.15it/s, loss=0.3894]


Epoch [12/300] - Train Average Loss: 0.4691
[Validate at epoch 12] ...


Epoch_12_Test: 100%|██████████| 32/32 [00:02<00:00, 15.59it/s]


Epoch_12_Test Average Contrastive Loss: 0.8508
Epoch_12_Test M->T Retrieval (per 32 samples): R@1=0.713, R@2=0.866, R@3=0.918
Epoch_12_Test T->M Retrieval (per 32 samples): R@1=0.714, R@2=0.871, R@3=0.926
Model saved: clip_motion_align_epoch_CMP_12.pt


Epoch 13/300: 100%|██████████| 177/177 [00:19<00:00,  9.16it/s, loss=0.6457]


Epoch [13/300] - Train Average Loss: 0.4353
[Validate at epoch 13] ...


Epoch_13_Test: 100%|██████████| 32/32 [00:02<00:00, 15.57it/s]


Epoch_13_Test Average Contrastive Loss: 0.8605
Epoch_13_Test M->T Retrieval (per 32 samples): R@1=0.724, R@2=0.874, R@3=0.921
Epoch_13_Test T->M Retrieval (per 32 samples): R@1=0.736, R@2=0.873, R@3=0.917
Model saved: clip_motion_align_epoch_CMP_13.pt


Epoch 14/300: 100%|██████████| 177/177 [00:19<00:00,  9.16it/s, loss=0.4258]


Epoch [14/300] - Train Average Loss: 0.4168
[Validate at epoch 14] ...


Epoch_14_Test: 100%|██████████| 32/32 [00:01<00:00, 16.39it/s]


Epoch_14_Test Average Contrastive Loss: 0.8106
Epoch_14_Test M->T Retrieval (per 32 samples): R@1=0.730, R@2=0.876, R@3=0.928
Epoch_14_Test T->M Retrieval (per 32 samples): R@1=0.723, R@2=0.871, R@3=0.931
Model saved: clip_motion_align_epoch_CMP_14.pt


Epoch 15/300: 100%|██████████| 177/177 [00:19<00:00,  9.29it/s, loss=0.3557]


Epoch [15/300] - Train Average Loss: 0.4022
[Validate at epoch 15] ...


Epoch_15_Test: 100%|██████████| 32/32 [00:02<00:00, 15.68it/s]


Epoch_15_Test Average Contrastive Loss: 0.8424
Epoch_15_Test M->T Retrieval (per 32 samples): R@1=0.724, R@2=0.867, R@3=0.922
Epoch_15_Test T->M Retrieval (per 32 samples): R@1=0.726, R@2=0.866, R@3=0.929
Model saved: clip_motion_align_epoch_CMP_15.pt


Epoch 16/300: 100%|██████████| 177/177 [00:19<00:00,  9.13it/s, loss=0.2716]


Epoch [16/300] - Train Average Loss: 0.3798
[Validate at epoch 16] ...


Epoch_16_Test: 100%|██████████| 32/32 [00:02<00:00, 15.81it/s]


Epoch_16_Test Average Contrastive Loss: 0.7899
Epoch_16_Test M->T Retrieval (per 32 samples): R@1=0.726, R@2=0.882, R@3=0.930
Epoch_16_Test T->M Retrieval (per 32 samples): R@1=0.754, R@2=0.883, R@3=0.930
Model saved: clip_motion_align_epoch_CMP_16.pt


Epoch 17/300: 100%|██████████| 177/177 [00:19<00:00,  9.18it/s, loss=0.2762]


Epoch [17/300] - Train Average Loss: 0.3740
[Validate at epoch 17] ...


Epoch_17_Test: 100%|██████████| 32/32 [00:02<00:00, 15.86it/s]


Epoch_17_Test Average Contrastive Loss: 0.8065
Epoch_17_Test M->T Retrieval (per 32 samples): R@1=0.740, R@2=0.884, R@3=0.926
Epoch_17_Test T->M Retrieval (per 32 samples): R@1=0.745, R@2=0.876, R@3=0.926
Model saved: clip_motion_align_epoch_CMP_17.pt


Epoch 18/300: 100%|██████████| 177/177 [00:19<00:00,  9.14it/s, loss=0.4987]


Epoch [18/300] - Train Average Loss: 0.3485
[Validate at epoch 18] ...


Epoch_18_Test: 100%|██████████| 32/32 [00:02<00:00, 15.91it/s]


Epoch_18_Test Average Contrastive Loss: 0.7774
Epoch_18_Test M->T Retrieval (per 32 samples): R@1=0.747, R@2=0.885, R@3=0.926
Epoch_18_Test T->M Retrieval (per 32 samples): R@1=0.747, R@2=0.885, R@3=0.932
Model saved: clip_motion_align_epoch_CMP_18.pt


Epoch 19/300: 100%|██████████| 177/177 [00:19<00:00,  9.15it/s, loss=0.3709]


Epoch [19/300] - Train Average Loss: 0.3459
[Validate at epoch 19] ...


Epoch_19_Test: 100%|██████████| 32/32 [00:02<00:00, 15.72it/s]


Epoch_19_Test Average Contrastive Loss: 0.7510
Epoch_19_Test M->T Retrieval (per 32 samples): R@1=0.750, R@2=0.889, R@3=0.936
Epoch_19_Test T->M Retrieval (per 32 samples): R@1=0.759, R@2=0.898, R@3=0.938
Model saved: clip_motion_align_epoch_CMP_19.pt


Epoch 20/300: 100%|██████████| 177/177 [00:19<00:00,  9.17it/s, loss=0.3626]


Epoch [20/300] - Train Average Loss: 0.3309
[Validate at epoch 20] ...


Epoch_20_Test: 100%|██████████| 32/32 [00:02<00:00, 15.90it/s]


Epoch_20_Test Average Contrastive Loss: 0.7458
Epoch_20_Test M->T Retrieval (per 32 samples): R@1=0.739, R@2=0.888, R@3=0.936
Epoch_20_Test T->M Retrieval (per 32 samples): R@1=0.740, R@2=0.892, R@3=0.938
Model saved: clip_motion_align_epoch_CMP_20.pt


Epoch 21/300: 100%|██████████| 177/177 [00:19<00:00,  9.17it/s, loss=0.2003]


Epoch [21/300] - Train Average Loss: 0.3204
[Validate at epoch 21] ...


Epoch_21_Test: 100%|██████████| 32/32 [00:01<00:00, 16.07it/s]


Epoch_21_Test Average Contrastive Loss: 0.7669
Epoch_21_Test M->T Retrieval (per 32 samples): R@1=0.747, R@2=0.879, R@3=0.940
Epoch_21_Test T->M Retrieval (per 32 samples): R@1=0.740, R@2=0.875, R@3=0.933
Model saved: clip_motion_align_epoch_CMP_21.pt


Epoch 22/300: 100%|██████████| 177/177 [00:19<00:00,  9.24it/s, loss=0.2798]


Epoch [22/300] - Train Average Loss: 0.3007
[Validate at epoch 22] ...


Epoch_22_Test: 100%|██████████| 32/32 [00:02<00:00, 15.80it/s]


Epoch_22_Test Average Contrastive Loss: 0.7133
Epoch_22_Test M->T Retrieval (per 32 samples): R@1=0.779, R@2=0.897, R@3=0.935
Epoch_22_Test T->M Retrieval (per 32 samples): R@1=0.768, R@2=0.903, R@3=0.941
Model saved: clip_motion_align_epoch_CMP_22.pt


Epoch 23/300: 100%|██████████| 177/177 [00:19<00:00,  9.19it/s, loss=0.2971]


Epoch [23/300] - Train Average Loss: 0.2940
[Validate at epoch 23] ...


Epoch_23_Test: 100%|██████████| 32/32 [00:02<00:00, 15.55it/s]


Epoch_23_Test Average Contrastive Loss: 0.7289
Epoch_23_Test M->T Retrieval (per 32 samples): R@1=0.759, R@2=0.899, R@3=0.937
Epoch_23_Test T->M Retrieval (per 32 samples): R@1=0.753, R@2=0.898, R@3=0.938
Model saved: clip_motion_align_epoch_CMP_23.pt


Epoch 24/300: 100%|██████████| 177/177 [00:18<00:00,  9.45it/s, loss=0.2312]


Epoch [24/300] - Train Average Loss: 0.2813
[Validate at epoch 24] ...


Epoch_24_Test: 100%|██████████| 32/32 [00:02<00:00, 15.94it/s]


Epoch_24_Test Average Contrastive Loss: 0.6954
Epoch_24_Test M->T Retrieval (per 32 samples): R@1=0.777, R@2=0.891, R@3=0.939
Epoch_24_Test T->M Retrieval (per 32 samples): R@1=0.780, R@2=0.899, R@3=0.945
Model saved: clip_motion_align_epoch_CMP_24.pt


Epoch 25/300: 100%|██████████| 177/177 [00:18<00:00,  9.71it/s, loss=0.1204]


Epoch [25/300] - Train Average Loss: 0.2795
[Validate at epoch 25] ...


Epoch_25_Test: 100%|██████████| 32/32 [00:02<00:00, 15.75it/s]


Epoch_25_Test Average Contrastive Loss: 0.6870
Epoch_25_Test M->T Retrieval (per 32 samples): R@1=0.775, R@2=0.891, R@3=0.946
Epoch_25_Test T->M Retrieval (per 32 samples): R@1=0.775, R@2=0.904, R@3=0.939
Model saved: clip_motion_align_epoch_CMP_25.pt


Epoch 26/300: 100%|██████████| 177/177 [00:19<00:00,  9.06it/s, loss=0.3347]


Epoch [26/300] - Train Average Loss: 0.2589
[Validate at epoch 26] ...


Epoch_26_Test: 100%|██████████| 32/32 [00:02<00:00, 15.76it/s]


Epoch_26_Test Average Contrastive Loss: 0.7497
Epoch_26_Test M->T Retrieval (per 32 samples): R@1=0.763, R@2=0.878, R@3=0.930
Epoch_26_Test T->M Retrieval (per 32 samples): R@1=0.751, R@2=0.889, R@3=0.945
Model saved: clip_motion_align_epoch_CMP_26.pt


Epoch 27/300: 100%|██████████| 177/177 [00:18<00:00,  9.36it/s, loss=0.2329]


Epoch [27/300] - Train Average Loss: 0.2525
[Validate at epoch 27] ...


Epoch_27_Test: 100%|██████████| 32/32 [00:01<00:00, 16.54it/s]


Epoch_27_Test Average Contrastive Loss: 0.7312
Epoch_27_Test M->T Retrieval (per 32 samples): R@1=0.754, R@2=0.885, R@3=0.943
Epoch_27_Test T->M Retrieval (per 32 samples): R@1=0.764, R@2=0.884, R@3=0.944
Model saved: clip_motion_align_epoch_CMP_27.pt


Epoch 28/300: 100%|██████████| 177/177 [00:18<00:00,  9.39it/s, loss=0.2670]


Epoch [28/300] - Train Average Loss: 0.2463
[Validate at epoch 28] ...


Epoch_28_Test: 100%|██████████| 32/32 [00:01<00:00, 16.12it/s]


Epoch_28_Test Average Contrastive Loss: 0.7094
Epoch_28_Test M->T Retrieval (per 32 samples): R@1=0.758, R@2=0.888, R@3=0.943
Epoch_28_Test T->M Retrieval (per 32 samples): R@1=0.758, R@2=0.891, R@3=0.942
Model saved: clip_motion_align_epoch_CMP_28.pt
Early stopping triggered!
Training completed!
