In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
from transformers import CLIPModel, CLIPTokenizer
import sys
import os
sys.path.append(os.path.abspath('/home/user/dxc/motion/StableMoFusion/'))
from motion_loader import get_dataset_loader  
from tqdm import tqdm
import random
import yaml
from argparse import Namespace
from model import *


In [2]:

with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

opt = Namespace(**config)


print(opt.batch_size)  
print(opt.lr)          
print(opt.device)      

32
0.0001
cuda


## load dataset

In [3]:
sys.path.append(os.path.abspath('/home/user/dxc/motion/StableMoFusion'))
train_loader = get_dataset_loader(
        opt,
        batch_size=opt.batch_size,
        split='train',
        mode='train'
    )
test_loader = get_dataset_loader(
    opt,
    batch_size=opt.batch_size,
    split='test',
    mode='gt_eval'
)


 Loading train mode HumanML3D dataset ...
11111111111111


  0%|          | 0/23384 [00:00<?, ?it/s]

Completing loading t2m dataset

 Loading gt_eval mode HumanML3D dataset ...
11111111111111


  0%|          | 0/4384 [00:00<?, ?it/s]

Completing loading t2m dataset


## load clip model for stage 1 training

In [4]:

clip_model = CLIPModel.from_pretrained(opt.clip_model_name)
clip_tokenizer = CLIPTokenizer.from_pretrained(opt.clip_model_name)

# 初始阶段：冻结整个 CLIP 文本编码器（stage 1）
for name, param in clip_model.named_parameters():
    if "text_model" in name:
        param.requires_grad = False


motion_encoder = MotionEncoder(
    input_dim=opt.input_dim,
    embed_dim=opt.embed_dim,
    num_heads=8,
    num_layers=4,         
    dim_feedforward=2048,
    dropout=0.2,
    max_seq_length=opt.max_seq_length
)
model = ClipMotionAlignModel(
    clip_model=clip_model,
    motion_encoder=motion_encoder,
    temperature=0.07
).to(opt.device)

In [5]:
optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=opt.lr,
        weight_decay=opt.weight_decay
    )

best_test_loss = float("inf")
no_improve_count = 0
max_no_improve = 3  # 连续3次验证无改进则早停

## Start training

In [6]:
for epoch in range(opt.num_epochs):
    
    if epoch + 1 == opt.pretrain_epochs + 1:
        
        for param in clip_model.text_model.encoder.layers[-1].parameters():
            param.requires_grad = True
        for param in clip_model.text_model.final_layer_norm.parameters():
            param.requires_grad = True
        
        optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=opt.lr_finetune,
            weight_decay=opt.weight_decay
        )
        print("Stage 2: Fine-tuning CLIP text encoder's last layer (and final_layer_norm) with lower lr.")

    model.train()
    total_loss = 0.0
    count = 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{opt.num_epochs}")



    for step, batch_data in enumerate(pbar):
        caption, motion, m_length = batch_data

        
        caption = [c.lower() for c in caption]
        text_enc = clip_tokenizer(
            caption,
            padding=True,
            truncation=True,
            max_length=opt.max_length,
            return_tensors="pt"
        )
        input_ids = text_enc["input_ids"].to(opt.device)
        attention_mask = text_enc["attention_mask"].to(opt.device)

       
        if isinstance(motion, list):
            motion = torch.stack([torch.tensor(m, dtype=torch.float32) for m in motion], dim=0)
        else:
            motion = motion.float()
        motion = motion.to(opt.device)
        m_length = m_length.to(opt.device)

        
        motion_emb, text_emb = model(motion, m_length, input_ids, attention_mask)
        loss = clip_contrastive_loss(motion_emb, text_emb, model.logit_scale)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        count += 1
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    avg_loss = total_loss / max(count, 1)
    print(f"Epoch [{epoch+1}/{opt.num_epochs}] - Train Average Loss: {avg_loss:.4f}")

   
    print(f"[Validate at epoch {epoch+1}] ...")
    test_loss = evaluate_model(model, test_loader, clip_tokenizer, opt, desc=f"Epoch_{epoch+1}_Test")
    model_path = f"clip_motion_align_epoch_{epoch+1}.pt"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved: {model_path}")

   
    if test_loss < best_test_loss:
        best_test_loss = test_loss
        no_improve_count = 0
    else:
        no_improve_count += 1
        if no_improve_count >= max_no_improve:
            print("Early stopping triggered!")
            break





print("Training completed!")

Epoch 1/300:   0%|          | 0/767 [00:00<?, ?it/s]

Epoch 1/300: 100%|██████████| 767/767 [01:45<00:00,  7.30it/s, loss=1.3750]


Epoch [1/300] - Train Average Loss: 1.8424
[Validate at epoch 1] ...


Epoch_1_Test: 100%|██████████| 145/145 [00:11<00:00, 12.90it/s]


Epoch_1_Test Average Contrastive Loss: 1.5630
Epoch_1_Test M->T Retrieval (per 32 samples): R@1=0.494, R@2=0.675, R@3=0.770
Epoch_1_Test T->M Retrieval (per 32 samples): R@1=0.514, R@2=0.698, R@3=0.790
Model saved: clip_motion_align_epoch_1.pt


Epoch 2/300: 100%|██████████| 767/767 [01:44<00:00,  7.36it/s, loss=1.2818]


Epoch [2/300] - Train Average Loss: 1.3815
[Validate at epoch 2] ...


Epoch_2_Test: 100%|██████████| 145/145 [00:11<00:00, 12.96it/s]


Epoch_2_Test Average Contrastive Loss: 1.3516
Epoch_2_Test M->T Retrieval (per 32 samples): R@1=0.557, R@2=0.739, R@3=0.818
Epoch_2_Test T->M Retrieval (per 32 samples): R@1=0.571, R@2=0.759, R@3=0.834
Model saved: clip_motion_align_epoch_2.pt


Epoch 3/300: 100%|██████████| 767/767 [01:45<00:00,  7.30it/s, loss=1.4553]


Epoch [3/300] - Train Average Loss: 1.2163
[Validate at epoch 3] ...


Epoch_3_Test: 100%|██████████| 145/145 [00:10<00:00, 13.57it/s]


Epoch_3_Test Average Contrastive Loss: 1.2700
Epoch_3_Test M->T Retrieval (per 32 samples): R@1=0.580, R@2=0.763, R@3=0.844
Epoch_3_Test T->M Retrieval (per 32 samples): R@1=0.594, R@2=0.772, R@3=0.850
Model saved: clip_motion_align_epoch_3.pt


Epoch 4/300: 100%|██████████| 767/767 [01:59<00:00,  6.39it/s, loss=1.1510]


Epoch [4/300] - Train Average Loss: 1.0818
[Validate at epoch 4] ...


Epoch_4_Test: 100%|██████████| 145/145 [00:10<00:00, 13.76it/s]


Epoch_4_Test Average Contrastive Loss: 1.2076
Epoch_4_Test M->T Retrieval (per 32 samples): R@1=0.606, R@2=0.776, R@3=0.858
Epoch_4_Test T->M Retrieval (per 32 samples): R@1=0.617, R@2=0.785, R@3=0.861
Model saved: clip_motion_align_epoch_4.pt


Epoch 5/300: 100%|██████████| 767/767 [01:46<00:00,  7.22it/s, loss=0.8676]


Epoch [5/300] - Train Average Loss: 1.0001
[Validate at epoch 5] ...


Epoch_5_Test: 100%|██████████| 145/145 [00:10<00:00, 13.38it/s]


Epoch_5_Test Average Contrastive Loss: 1.1730
Epoch_5_Test M->T Retrieval (per 32 samples): R@1=0.610, R@2=0.790, R@3=0.866
Epoch_5_Test T->M Retrieval (per 32 samples): R@1=0.634, R@2=0.801, R@3=0.869
Model saved: clip_motion_align_epoch_5.pt


TypeError: '<=' not supported between instances of 'float' and 'str'