In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
from transformers import CLIPModel, CLIPTokenizer
import sys
import os
sys.path.append(os.path.abspath('/home/user/dxc/motion/StableMoFusion/'))
from motion_loader import get_dataset_loader  
from tqdm import tqdm
import random
import yaml
from argparse import Namespace
from model import *


In [15]:

# Read the run configuration from config.yaml and expose its top-level keys
# as attributes on `opt`.
with open('config.yaml') as cfg_file:
    config = yaml.safe_load(cfg_file)
opt = Namespace(**config)

# Echo a few settings as a quick sanity check that the config was parsed.
for _key in ('batch_size', 'lr', 'device'):
    print(getattr(opt, _key))

32
0.0001
cuda


## Load dataset

In [16]:
# Make the StableMoFusion checkout importable. The import cell already appends
# this same (normalised) path, so guard the append: re-running this cell must
# not keep stacking duplicate entries onto sys.path.
_repo_root = os.path.abspath('/home/user/dxc/motion/StableMoFusion')
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

# Loader for the training split.
train_loader = get_dataset_loader(
    opt,
    batch_size=opt.batch_size,
    split='train',
    mode='train',
)

# Loader for the test split in 'gt_eval' mode; the training cell passes it to
# evaluate_model after every epoch.
test_loader = get_dataset_loader(
    opt,
    batch_size=opt.batch_size,
    split='test',
    mode='gt_eval',
)


 Loading train mode HumanML3D dataset ...
11111111111111


  0%|          | 0/23384 [00:00<?, ?it/s]

Completing loading t2m dataset

 Loading gt_eval mode HumanML3D dataset ...
11111111111111


  0%|          | 0/4384 [00:00<?, ?it/s]

Completing loading t2m dataset


## Load CLIP model for stage 1 training

In [17]:

# Pretrained CLIP weights and the matching tokenizer, both selected by
# opt.clip_model_name.
clip_model = CLIPModel.from_pretrained(opt.clip_model_name)
clip_tokenizer = CLIPTokenizer.from_pretrained(opt.clip_model_name)

# Stage 1: freeze the CLIP text encoder.
# Only parameters whose name contains "text_model" are touched here; every
# other CLIP parameter keeps whatever requires_grad value it was loaded with.
frozen_text_params = (
    param
    for name, param in clip_model.named_parameters()
    if "text_model" in name
)
for param in frozen_text_params:
    param.requires_grad = False

# Motion-side encoder. Input/embedding sizes and the maximum sequence length
# come from the config; the remaining hyper-parameters are fixed here.
motion_encoder_kwargs = dict(
    input_dim=opt.input_dim,
    embed_dim=opt.embed_dim,
    num_heads=8,
    num_layers=4,
    dim_feedforward=2048,
    dropout=0.2,
    max_seq_length=opt.max_seq_length,
)
motion_encoder = MotionEncoder(**motion_encoder_kwargs)

# Joint model wrapping CLIP and the motion encoder, moved to the configured device.
model = ClipMotionAlignModel(
    clip_model=clip_model,
    motion_encoder=motion_encoder,
    temperature=0.07,
)
model = model.to(opt.device)

In [18]:
# Stage-1 optimizer: AdamW over the parameters that currently require gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(
    trainable_params,
    lr=opt.lr,
    weight_decay=opt.weight_decay,
)

# Early-stopping bookkeeping, read and updated by the training loop.
best_test_loss = float("inf")
no_improve_count = 0
max_no_improve = 3  # stop early after 3 consecutive validations without improvement

## Start training

In [19]:
# Two-stage contrastive training with per-epoch validation, checkpointing and
# early stopping.
#
# Relies on state created by earlier cells: `model`, `clip_model`,
# `clip_tokenizer`, `optimizer`, the data loaders, and the early-stopping
# variables `best_test_loss`, `no_improve_count`, `max_no_improve`.
#
# NOTE(review): early stopping is driven by `test_loader` (split='test'), so
# the reported test metrics are also used for model selection.

# Remember which checkpoint achieved the best validation loss. When early
# stopping fires, the weights left in `model` are those of the last epoch,
# not the best one, so the best checkpoint path is reported at the end.
best_epoch = None
best_model_path = None

for epoch in range(opt.num_epochs):

    # Stage 2 starts once `opt.pretrain_epochs` epochs have completed.
    if epoch == opt.pretrain_epochs:
        # Unfreeze the last CLIP text-encoder layer and the final layer norm.
        for param in clip_model.text_model.encoder.layers[-1].parameters():
            param.requires_grad = True
        for param in clip_model.text_model.final_layer_norm.parameters():
            param.requires_grad = True

        # Rebuild the optimizer so it includes the newly trainable parameters.
        # Every trainable parameter now uses opt.lr_finetune, and the new
        # optimizer starts with fresh state.
        optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=opt.lr_finetune,
            weight_decay=opt.weight_decay
        )
        print("Stage 2: Fine-tuning CLIP text encoder's last layer (and final_layer_norm) with lower lr.")

    model.train()
    total_loss = 0.0
    count = 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{opt.num_epochs}")

    for batch_data in pbar:
        caption, motion, m_length = batch_data

        # Tokenize the lower-cased captions, padded to the longest caption in
        # the batch and truncated to opt.max_length tokens.
        caption = [c.lower() for c in caption]
        text_enc = clip_tokenizer(
            caption,
            padding=True,
            truncation=True,
            max_length=opt.max_length,
            return_tensors="pt"
        )
        input_ids = text_enc["input_ids"].to(opt.device)
        attention_mask = text_enc["attention_mask"].to(opt.device)

        # The loader may hand back motion as a list of per-sample items or as
        # one batched tensor; normalise both to a float32 tensor on the device.
        # as_tensor accepts arrays and existing tensors alike without the
        # copy-construct warning torch.tensor emits for tensor inputs.
        if isinstance(motion, list):
            motion = torch.stack(
                [torch.as_tensor(m, dtype=torch.float32) for m in motion], dim=0
            )
        else:
            motion = motion.float()
        motion = motion.to(opt.device)
        m_length = m_length.to(opt.device)

        # Forward pass and contrastive loss between motion and text embeddings.
        motion_emb, text_emb = model(motion, m_length, input_ids, attention_mask)
        loss = clip_contrastive_loss(motion_emb, text_emb, model.logit_scale)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar loss once and reuse it for the running total and the
        # progress bar.
        loss_value = loss.item()
        total_loss += loss_value
        count += 1
        pbar.set_postfix({"loss": f"{loss_value:.4f}"})

    # max(count, 1) avoids a division by zero if the loader yielded no batches.
    avg_loss = total_loss / max(count, 1)
    print(f"Epoch [{epoch+1}/{opt.num_epochs}] - Train Average Loss: {avg_loss:.4f}")

    # Validate, then checkpoint the weights for this epoch.
    print(f"[Validate at epoch {epoch+1}] ...")
    test_loss = evaluate_model(model, test_loader, clip_tokenizer, opt, desc=f"Epoch_{epoch+1}_Test")
    model_path = f"clip_motion_align_epoch_{epoch+1}.pt"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved: {model_path}")

    # Early stopping: give up after `max_no_improve` consecutive epochs
    # without a new best validation loss.
    if test_loss < best_test_loss:
        best_test_loss = test_loss
        best_epoch = epoch + 1
        best_model_path = model_path
        no_improve_count = 0
    else:
        no_improve_count += 1
        if no_improve_count >= max_no_improve:
            print("Early stopping triggered!")
            break


print("Training completed!")
# Point at the checkpoint to load for evaluation or inference. Nothing is
# reported if no epoch in this run improved on the incoming best_test_loss.
if best_model_path is not None:
    print(f"Best test loss {float(best_test_loss):.4f} at epoch {best_epoch}: {best_model_path}")

Epoch 1/300:   0%|          | 0/767 [00:00<?, ?it/s]

Epoch 1/300: 100%|██████████| 767/767 [01:47<00:00,  7.11it/s, loss=1.3735]


Epoch [1/300] - Train Average Loss: 1.8478
[Validate at epoch 1] ...


Epoch_1_Test: 100%|██████████| 145/145 [00:13<00:00, 10.87it/s]


Epoch_1_Test Average Contrastive Loss: 1.5263
Epoch_1_Test M->T Retrieval (per 32 samples): R@1=0.502, R@2=0.681, R@3=0.780
Epoch_1_Test T->M Retrieval (per 32 samples): R@1=0.526, R@2=0.714, R@3=0.798
Model saved: clip_motion_align_epoch_1.pt


Epoch 2/300: 100%|██████████| 767/767 [01:40<00:00,  7.64it/s, loss=1.1631]


Epoch [2/300] - Train Average Loss: 1.3659
[Validate at epoch 2] ...


Epoch_2_Test: 100%|██████████| 145/145 [00:10<00:00, 14.14it/s]


Epoch_2_Test Average Contrastive Loss: 1.3374
Epoch_2_Test M->T Retrieval (per 32 samples): R@1=0.553, R@2=0.739, R@3=0.827
Epoch_2_Test T->M Retrieval (per 32 samples): R@1=0.583, R@2=0.756, R@3=0.834
Model saved: clip_motion_align_epoch_2.pt


Epoch 3/300: 100%|██████████| 767/767 [01:42<00:00,  7.51it/s, loss=0.9696]


Epoch [3/300] - Train Average Loss: 1.1982
[Validate at epoch 3] ...


Epoch_3_Test: 100%|██████████| 145/145 [00:10<00:00, 14.40it/s]


Epoch_3_Test Average Contrastive Loss: 1.2602
Epoch_3_Test M->T Retrieval (per 32 samples): R@1=0.578, R@2=0.764, R@3=0.839
Epoch_3_Test T->M Retrieval (per 32 samples): R@1=0.599, R@2=0.772, R@3=0.847
Model saved: clip_motion_align_epoch_3.pt


Epoch 4/300: 100%|██████████| 767/767 [02:06<00:00,  6.04it/s, loss=0.9263]


Epoch [4/300] - Train Average Loss: 1.0762
[Validate at epoch 4] ...


Epoch_4_Test: 100%|██████████| 145/145 [00:10<00:00, 14.01it/s]


Epoch_4_Test Average Contrastive Loss: 1.2273
Epoch_4_Test M->T Retrieval (per 32 samples): R@1=0.598, R@2=0.784, R@3=0.861
Epoch_4_Test T->M Retrieval (per 32 samples): R@1=0.615, R@2=0.787, R@3=0.865
Model saved: clip_motion_align_epoch_4.pt


Epoch 5/300: 100%|██████████| 767/767 [01:37<00:00,  7.90it/s, loss=1.2109]


Epoch [5/300] - Train Average Loss: 1.0004
[Validate at epoch 5] ...


Epoch_5_Test: 100%|██████████| 145/145 [00:12<00:00, 11.85it/s]


Epoch_5_Test Average Contrastive Loss: 1.1726
Epoch_5_Test M->T Retrieval (per 32 samples): R@1=0.619, R@2=0.791, R@3=0.864
Epoch_5_Test T->M Retrieval (per 32 samples): R@1=0.630, R@2=0.801, R@3=0.866
Model saved: clip_motion_align_epoch_5.pt
Stage 2: Fine-tuning CLIP text encoder's last layer (and final_layer_norm) with lower lr.


Epoch 6/300: 100%|██████████| 767/767 [01:44<00:00,  7.37it/s, loss=0.6682]


Epoch [6/300] - Train Average Loss: 0.8040
[Validate at epoch 6] ...


Epoch_6_Test: 100%|██████████| 145/145 [00:10<00:00, 13.95it/s]


Epoch_6_Test Average Contrastive Loss: 1.0273
Epoch_6_Test M->T Retrieval (per 32 samples): R@1=0.659, R@2=0.825, R@3=0.895
Epoch_6_Test T->M Retrieval (per 32 samples): R@1=0.664, R@2=0.828, R@3=0.894
Model saved: clip_motion_align_epoch_6.pt


Epoch 7/300: 100%|██████████| 767/767 [01:44<00:00,  7.34it/s, loss=1.1500]


Epoch [7/300] - Train Average Loss: 0.7119
[Validate at epoch 7] ...


Epoch_7_Test: 100%|██████████| 145/145 [00:20<00:00,  7.21it/s]


Epoch_7_Test Average Contrastive Loss: 1.0261
Epoch_7_Test M->T Retrieval (per 32 samples): R@1=0.664, R@2=0.828, R@3=0.895
Epoch_7_Test T->M Retrieval (per 32 samples): R@1=0.674, R@2=0.836, R@3=0.896
Model saved: clip_motion_align_epoch_7.pt


Epoch 8/300: 100%|██████████| 767/767 [01:51<00:00,  6.90it/s, loss=0.6629]


Epoch [8/300] - Train Average Loss: 0.6702
[Validate at epoch 8] ...


Epoch_8_Test: 100%|██████████| 145/145 [00:10<00:00, 13.30it/s]


Epoch_8_Test Average Contrastive Loss: 1.0127
Epoch_8_Test M->T Retrieval (per 32 samples): R@1=0.669, R@2=0.826, R@3=0.896
Epoch_8_Test T->M Retrieval (per 32 samples): R@1=0.680, R@2=0.836, R@3=0.899
Model saved: clip_motion_align_epoch_8.pt


Epoch 9/300: 100%|██████████| 767/767 [01:42<00:00,  7.48it/s, loss=0.3763]


Epoch [9/300] - Train Average Loss: 0.6254
[Validate at epoch 9] ...


Epoch_9_Test: 100%|██████████| 145/145 [00:10<00:00, 13.39it/s]


Epoch_9_Test Average Contrastive Loss: 0.9889
Epoch_9_Test M->T Retrieval (per 32 samples): R@1=0.676, R@2=0.837, R@3=0.895
Epoch_9_Test T->M Retrieval (per 32 samples): R@1=0.691, R@2=0.850, R@3=0.901
Model saved: clip_motion_align_epoch_9.pt


Epoch 10/300: 100%|██████████| 767/767 [01:47<00:00,  7.14it/s, loss=0.5926]


Epoch [10/300] - Train Average Loss: 0.5808
[Validate at epoch 10] ...


Epoch_10_Test: 100%|██████████| 145/145 [00:10<00:00, 14.48it/s]


Epoch_10_Test Average Contrastive Loss: 0.9776
Epoch_10_Test M->T Retrieval (per 32 samples): R@1=0.675, R@2=0.843, R@3=0.902
Epoch_10_Test T->M Retrieval (per 32 samples): R@1=0.686, R@2=0.844, R@3=0.905
Model saved: clip_motion_align_epoch_10.pt


Epoch 11/300: 100%|██████████| 767/767 [01:48<00:00,  7.06it/s, loss=0.5308]


Epoch [11/300] - Train Average Loss: 0.5597
[Validate at epoch 11] ...


Epoch_11_Test: 100%|██████████| 145/145 [00:13<00:00, 10.95it/s]


Epoch_11_Test Average Contrastive Loss: 0.9677
Epoch_11_Test M->T Retrieval (per 32 samples): R@1=0.686, R@2=0.849, R@3=0.906
Epoch_11_Test T->M Retrieval (per 32 samples): R@1=0.684, R@2=0.850, R@3=0.908
Model saved: clip_motion_align_epoch_11.pt


Epoch 12/300: 100%|██████████| 767/767 [01:41<00:00,  7.57it/s, loss=0.6088]


Epoch [12/300] - Train Average Loss: 0.5258
[Validate at epoch 12] ...


Epoch_12_Test: 100%|██████████| 145/145 [00:10<00:00, 14.45it/s]


Epoch_12_Test Average Contrastive Loss: 0.9872
Epoch_12_Test M->T Retrieval (per 32 samples): R@1=0.685, R@2=0.842, R@3=0.900
Epoch_12_Test T->M Retrieval (per 32 samples): R@1=0.693, R@2=0.845, R@3=0.905
Model saved: clip_motion_align_epoch_12.pt


Epoch 13/300: 100%|██████████| 767/767 [01:40<00:00,  7.65it/s, loss=0.5552]


Epoch [13/300] - Train Average Loss: 0.4936
[Validate at epoch 13] ...


Epoch_13_Test: 100%|██████████| 145/145 [00:09<00:00, 14.53it/s]


Epoch_13_Test Average Contrastive Loss: 0.9213
Epoch_13_Test M->T Retrieval (per 32 samples): R@1=0.693, R@2=0.858, R@3=0.913
Epoch_13_Test T->M Retrieval (per 32 samples): R@1=0.708, R@2=0.859, R@3=0.912
Model saved: clip_motion_align_epoch_13.pt


Epoch 14/300: 100%|██████████| 767/767 [01:47<00:00,  7.13it/s, loss=0.5086]


Epoch [14/300] - Train Average Loss: 0.4660
[Validate at epoch 14] ...


Epoch_14_Test: 100%|██████████| 145/145 [00:10<00:00, 13.64it/s]


Epoch_14_Test Average Contrastive Loss: 0.9463
Epoch_14_Test M->T Retrieval (per 32 samples): R@1=0.690, R@2=0.851, R@3=0.908
Epoch_14_Test T->M Retrieval (per 32 samples): R@1=0.707, R@2=0.852, R@3=0.908
Model saved: clip_motion_align_epoch_14.pt


Epoch 15/300: 100%|██████████| 767/767 [01:38<00:00,  7.76it/s, loss=0.4169]


Epoch [15/300] - Train Average Loss: 0.4450
[Validate at epoch 15] ...


Epoch_15_Test: 100%|██████████| 145/145 [00:12<00:00, 11.88it/s]


Epoch_15_Test Average Contrastive Loss: 0.9549
Epoch_15_Test M->T Retrieval (per 32 samples): R@1=0.687, R@2=0.845, R@3=0.907
Epoch_15_Test T->M Retrieval (per 32 samples): R@1=0.700, R@2=0.852, R@3=0.905
Model saved: clip_motion_align_epoch_15.pt


Epoch 16/300: 100%|██████████| 767/767 [01:39<00:00,  7.68it/s, loss=0.4101]


Epoch [16/300] - Train Average Loss: 0.4231
[Validate at epoch 16] ...


Epoch_16_Test: 100%|██████████| 145/145 [00:12<00:00, 11.45it/s]


Epoch_16_Test Average Contrastive Loss: 0.9447
Epoch_16_Test M->T Retrieval (per 32 samples): R@1=0.703, R@2=0.857, R@3=0.913
Epoch_16_Test T->M Retrieval (per 32 samples): R@1=0.705, R@2=0.856, R@3=0.913
Model saved: clip_motion_align_epoch_16.pt
Early stopping triggered!
Training completed!
