In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
from transformers import CLIPModel, CLIPTokenizer
import sys
import os
import random
import numpy as np  

# Seed every random number generator used by the notebook so runs are repeatable.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)

if torch.cuda.is_available():
    # Seed the CUDA generators as well (single-device call first, then all devices).
    for _cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        _cuda_seed_fn(seed)
    # Ask cuDNN for deterministic kernels and turn its benchmark mode off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


from motion_loader import get_dataset_loader  
from tqdm import tqdm
import yaml
from argparse import Namespace
from model import *

In [5]:

# Read the experiment settings from YAML and expose them as attributes
# (opt.batch_size, opt.lr, ...). The explicit encoding keeps decoding
# independent of the platform's default locale.
with open('config_KIT.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# safe_load yields None for an empty file and may yield a list or scalar;
# fail with a clear message instead of an opaque error from Namespace(**config).
if not isinstance(config, dict):
    raise TypeError(
        f"config_KIT.yaml must contain a top-level mapping, got {type(config).__name__}"
    )

opt = Namespace(**config)

# Echo a few settings so the run records what it was configured with.
print(opt.batch_size)
print(opt.lr)
print(opt.device)

32
0.0001
cuda


## Load dataset

In [6]:

# Both loaders share every setting except the split name.
# NOTE(review): the training cell uses test_loader for early stopping; a held-out
# validation split would avoid selecting on test data — confirm whether the
# loader offers one.
loader_kwargs = dict(batch_size=opt.batch_size, mode='gt_eval')

train_loader = get_dataset_loader(opt, split='train', **loader_kwargs)
test_loader = get_dataset_loader(opt, split='test', **loader_kwargs)

Loading dataset...

 Loading gt_eval mode KIT dataset ...
/home/user/dxc/motion/StableMoFusion/data/kit_std.npy


FileNotFoundError: [Errno 2] No such file or directory: './data/KIT-ML/Mean.npy'

## Load CLIP model for stage 1 training

In [None]:

# Pretrained CLIP weights and the matching tokenizer, both named by the config.
clip_model = CLIPModel.from_pretrained(opt.clip_model_name)
clip_tokenizer = CLIPTokenizer.from_pretrained(opt.clip_model_name)

# Stage 1: switch off gradients for every CLIP parameter whose qualified name
# contains "text_model".
# NOTE(review): parameters that do not match that substring stay trainable
# here — confirm this is intended.
for param_name, weights in clip_model.named_parameters():
    if "text_model" in param_name:
        weights.requires_grad_(False)

# Motion encoder hyper-parameters; the data-dependent sizes come from the config.
encoder_settings = dict(
    input_dim=opt.input_dim,
    embed_dim=opt.embed_dim,
    num_heads=8,
    num_layers=4,
    dim_feedforward=2048,
    dropout=0.2,
    max_seq_length=opt.max_seq_length,
)
motion_encoder = MotionEncoder(**encoder_settings)

# Combine CLIP and the motion encoder, then move the result to the configured device.
model = ClipMotionAlignModel(
    clip_model=clip_model,
    motion_encoder=motion_encoder,
    temperature=0.07,
)
model = model.to(opt.device)

In [None]:
# Hand AdamW only the parameters that currently require gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(
    trainable_params,
    lr=opt.lr,
    weight_decay=opt.weight_decay,
)

# Early-stopping bookkeeping read by the training loop.
max_no_improve = 3        # consecutive epochs without a lower test loss before stopping
no_improve_count = 0
best_test_loss = math.inf

## Start training

In [None]:
# Two-stage training with patience-based early stopping.
#   Stage 1 (the first opt.pretrain_epochs epochs) uses the optimizer built
#   before this cell.
#   Stage 2 (remaining epochs) additionally unfreezes the last CLIP text encoder
#   layer and final_layer_norm, and rebuilds the optimizer with opt.lr_finetune.
# A checkpoint is written after every epoch; the best one is reported at the end.

# Reset the early-stopping state here so that re-running this cell does not
# inherit counters from a previous run (max_no_improve comes from the setup cell).
best_test_loss = float("inf")
no_improve_count = 0
best_epoch = None
best_model_path = None

for epoch in range(opt.num_epochs):

    # Enter stage 2 exactly once, on the first epoch after the pretraining epochs.
    if epoch == opt.pretrain_epochs:

        for param in clip_model.text_model.encoder.layers[-1].parameters():
            param.requires_grad = True
        for param in clip_model.text_model.final_layer_norm.parameters():
            param.requires_grad = True

        # Rebuild the optimizer so the newly trainable parameters are included;
        # every trainable parameter now uses opt.lr_finetune.
        optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=opt.lr_finetune,
            weight_decay=opt.weight_decay
        )
        print("Stage 2: Fine-tuning CLIP text encoder's last layer (and final_layer_norm) with lower lr.")

    model.train()
    total_loss = 0.0
    count = 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{opt.num_epochs}")

    for caption, motion, m_length in pbar:

        # Tokenize the lower-cased captions as one padded batch.
        caption = [c.lower() for c in caption]
        text_enc = clip_tokenizer(
            caption,
            padding=True,
            truncation=True,
            max_length=opt.max_length,
            return_tensors="pt"
        )
        input_ids = text_enc["input_ids"].to(opt.device)
        attention_mask = text_enc["attention_mask"].to(opt.device)

        # The loader may yield a list of per-sample motions or one batched tensor.
        # as_tensor avoids the copy-construct warning when elements are already tensors.
        if isinstance(motion, list):
            motion = torch.stack(
                [torch.as_tensor(m, dtype=torch.float32) for m in motion], dim=0
            )
        else:
            motion = motion.float()
        motion = motion.to(opt.device)
        m_length = m_length.to(opt.device)

        # Forward pass, then the contrastive loss computed with the model's logit_scale.
        motion_emb, text_emb = model(motion, m_length, input_ids, attention_mask)
        loss = clip_contrastive_loss(motion_emb, text_emb, model.logit_scale)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        count += 1
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    # max(count, 1) guards against an empty loader.
    avg_loss = total_loss / max(count, 1)
    print(f"Epoch [{epoch+1}/{opt.num_epochs}] - Train Average Loss: {avg_loss:.4f}")

    print(f"[Validate at epoch {epoch+1}] ...")
    test_loss = evaluate_model(model, test_loader, clip_tokenizer, opt, desc=f"Epoch_{epoch+1}_Test")
    model_path = f"clip_motion_align_epoch_KIT_{epoch+1}.pt"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved: {model_path}")

    # Early stopping: remember the best epoch, stop after max_no_improve
    # consecutive epochs without a lower test loss.
    if test_loss < best_test_loss:
        best_test_loss = test_loss
        best_epoch = epoch + 1
        best_model_path = model_path
        no_improve_count = 0
    else:
        no_improve_count += 1
        if no_improve_count >= max_no_improve:
            print("Early stopping triggered!")
            break

print("Training completed!")
# The last checkpoint is usually not the best one when early stopping fires,
# so say which saved file holds the lowest test loss.
if best_model_path is not None:
    print(f"Best test loss {float(best_test_loss):.4f} at epoch {best_epoch} ({best_model_path})")

Epoch 1/300:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 1/300: 100%|██████████| 30/30 [00:03<00:00,  7.75it/s, loss=2.5532]


Epoch [1/300] - Train Average Loss: 2.7950
[Validate at epoch 1] ...


Epoch_1_Test: 100%|██████████| 22/22 [00:01<00:00, 13.66it/s]


Epoch_1_Test Average Contrastive Loss: 2.4003
Epoch_1_Test M->T Retrieval (per 32 samples): R@1=0.216, R@2=0.384, R@3=0.503
Epoch_1_Test T->M Retrieval (per 32 samples): R@1=0.224, R@2=0.402, R@3=0.528
Model saved: clip_motion_align_epoch_KIT_1.pt


Epoch 2/300: 100%|██████████| 30/30 [00:02<00:00, 11.04it/s, loss=1.9454]


Epoch [2/300] - Train Average Loss: 2.1803
[Validate at epoch 2] ...


Epoch_2_Test: 100%|██████████| 22/22 [00:01<00:00, 21.72it/s]


Epoch_2_Test Average Contrastive Loss: 2.1095
Epoch_2_Test M->T Retrieval (per 32 samples): R@1=0.283, R@2=0.472, R@3=0.591
Epoch_2_Test T->M Retrieval (per 32 samples): R@1=0.314, R@2=0.499, R@3=0.632
Model saved: clip_motion_align_epoch_KIT_2.pt


Epoch 3/300: 100%|██████████| 30/30 [00:02<00:00, 11.68it/s, loss=1.9369]


Epoch [3/300] - Train Average Loss: 1.8235
[Validate at epoch 3] ...


Epoch_3_Test: 100%|██████████| 22/22 [00:00<00:00, 22.31it/s]


Epoch_3_Test Average Contrastive Loss: 2.0060
Epoch_3_Test M->T Retrieval (per 32 samples): R@1=0.301, R@2=0.483, R@3=0.591
Epoch_3_Test T->M Retrieval (per 32 samples): R@1=0.335, R@2=0.540, R@3=0.686
Model saved: clip_motion_align_epoch_KIT_3.pt


Epoch 4/300: 100%|██████████| 30/30 [00:02<00:00, 11.20it/s, loss=1.5030]


Epoch [4/300] - Train Average Loss: 1.5885
[Validate at epoch 4] ...


Epoch_4_Test: 100%|██████████| 22/22 [00:00<00:00, 22.64it/s]


Epoch_4_Test Average Contrastive Loss: 1.7729
Epoch_4_Test M->T Retrieval (per 32 samples): R@1=0.362, R@2=0.591, R@3=0.706
Epoch_4_Test T->M Retrieval (per 32 samples): R@1=0.393, R@2=0.608, R@3=0.741
Model saved: clip_motion_align_epoch_KIT_4.pt


Epoch 5/300: 100%|██████████| 30/30 [00:02<00:00, 11.14it/s, loss=1.5383]


Epoch [5/300] - Train Average Loss: 1.3343
[Validate at epoch 5] ...


Epoch_5_Test: 100%|██████████| 22/22 [00:00<00:00, 22.14it/s]


Epoch_5_Test Average Contrastive Loss: 1.7477
Epoch_5_Test M->T Retrieval (per 32 samples): R@1=0.385, R@2=0.604, R@3=0.713
Epoch_5_Test T->M Retrieval (per 32 samples): R@1=0.409, R@2=0.621, R@3=0.740
Model saved: clip_motion_align_epoch_KIT_5.pt
Stage 2: Fine-tuning CLIP text encoder's last layer (and final_layer_norm) with lower lr.


Epoch 6/300: 100%|██████████| 30/30 [00:02<00:00, 10.82it/s, loss=1.2861]


Epoch [6/300] - Train Average Loss: 1.1751
[Validate at epoch 6] ...


Epoch_6_Test: 100%|██████████| 22/22 [00:00<00:00, 22.45it/s]


Epoch_6_Test Average Contrastive Loss: 1.7107
Epoch_6_Test M->T Retrieval (per 32 samples): R@1=0.381, R@2=0.591, R@3=0.726
Epoch_6_Test T->M Retrieval (per 32 samples): R@1=0.409, R@2=0.612, R@3=0.749
Model saved: clip_motion_align_epoch_KIT_6.pt


Epoch 7/300: 100%|██████████| 30/30 [00:02<00:00, 10.60it/s, loss=0.9888]


Epoch [7/300] - Train Average Loss: 1.1113
[Validate at epoch 7] ...


Epoch_7_Test: 100%|██████████| 22/22 [00:01<00:00, 21.81it/s]


Epoch_7_Test Average Contrastive Loss: 1.6461
Epoch_7_Test M->T Retrieval (per 32 samples): R@1=0.425, R@2=0.636, R@3=0.751
Epoch_7_Test T->M Retrieval (per 32 samples): R@1=0.433, R@2=0.655, R@3=0.767
Model saved: clip_motion_align_epoch_KIT_7.pt


Epoch 8/300: 100%|██████████| 30/30 [00:02<00:00, 10.67it/s, loss=0.8657]


Epoch [8/300] - Train Average Loss: 1.0205
[Validate at epoch 8] ...


Epoch_8_Test: 100%|██████████| 22/22 [00:00<00:00, 22.42it/s]


Epoch_8_Test Average Contrastive Loss: 1.6461
Epoch_8_Test M->T Retrieval (per 32 samples): R@1=0.391, R@2=0.614, R@3=0.746
Epoch_8_Test T->M Retrieval (per 32 samples): R@1=0.445, R@2=0.659, R@3=0.776
Model saved: clip_motion_align_epoch_KIT_8.pt


Epoch 9/300: 100%|██████████| 30/30 [00:02<00:00, 10.74it/s, loss=0.7568]


Epoch [9/300] - Train Average Loss: 1.0636
[Validate at epoch 9] ...


Epoch_9_Test: 100%|██████████| 22/22 [00:01<00:00, 21.97it/s]


Epoch_9_Test Average Contrastive Loss: 1.6372
Epoch_9_Test M->T Retrieval (per 32 samples): R@1=0.405, R@2=0.619, R@3=0.751
Epoch_9_Test T->M Retrieval (per 32 samples): R@1=0.419, R@2=0.624, R@3=0.754
Model saved: clip_motion_align_epoch_KIT_9.pt


Epoch 10/300: 100%|██████████| 30/30 [00:02<00:00, 10.44it/s, loss=0.7973]


Epoch [10/300] - Train Average Loss: 0.9406
[Validate at epoch 10] ...


Epoch_10_Test: 100%|██████████| 22/22 [00:01<00:00, 21.51it/s]


Epoch_10_Test Average Contrastive Loss: 1.6563
Epoch_10_Test M->T Retrieval (per 32 samples): R@1=0.406, R@2=0.635, R@3=0.761
Epoch_10_Test T->M Retrieval (per 32 samples): R@1=0.402, R@2=0.653, R@3=0.771
Model saved: clip_motion_align_epoch_KIT_10.pt


Epoch 11/300: 100%|██████████| 30/30 [00:02<00:00, 10.71it/s, loss=0.7034]


Epoch [11/300] - Train Average Loss: 0.9351
[Validate at epoch 11] ...


Epoch_11_Test: 100%|██████████| 22/22 [00:00<00:00, 22.31it/s]


Epoch_11_Test Average Contrastive Loss: 1.6523
Epoch_11_Test M->T Retrieval (per 32 samples): R@1=0.413, R@2=0.625, R@3=0.757
Epoch_11_Test T->M Retrieval (per 32 samples): R@1=0.433, R@2=0.643, R@3=0.760
Model saved: clip_motion_align_epoch_KIT_11.pt


Epoch 12/300: 100%|██████████| 30/30 [00:02<00:00, 10.49it/s, loss=0.8813]


Epoch [12/300] - Train Average Loss: 0.9008
[Validate at epoch 12] ...


Epoch_12_Test: 100%|██████████| 22/22 [00:01<00:00, 21.78it/s]


Epoch_12_Test Average Contrastive Loss: 1.6173
Epoch_12_Test M->T Retrieval (per 32 samples): R@1=0.420, R@2=0.635, R@3=0.751
Epoch_12_Test T->M Retrieval (per 32 samples): R@1=0.445, R@2=0.655, R@3=0.774
Model saved: clip_motion_align_epoch_KIT_12.pt


Epoch 13/300: 100%|██████████| 30/30 [00:02<00:00, 10.70it/s, loss=0.6806]


Epoch [13/300] - Train Average Loss: 0.8440
[Validate at epoch 13] ...


Epoch_13_Test: 100%|██████████| 22/22 [00:00<00:00, 22.10it/s]


Epoch_13_Test Average Contrastive Loss: 1.6327
Epoch_13_Test M->T Retrieval (per 32 samples): R@1=0.412, R@2=0.624, R@3=0.771
Epoch_13_Test T->M Retrieval (per 32 samples): R@1=0.433, R@2=0.649, R@3=0.786
Model saved: clip_motion_align_epoch_KIT_13.pt


Epoch 14/300: 100%|██████████| 30/30 [00:02<00:00, 11.83it/s, loss=0.9139]


Epoch [14/300] - Train Average Loss: 0.8830
[Validate at epoch 14] ...


Epoch_14_Test: 100%|██████████| 22/22 [00:00<00:00, 22.18it/s]


Epoch_14_Test Average Contrastive Loss: 1.5854
Epoch_14_Test M->T Retrieval (per 32 samples): R@1=0.428, R@2=0.663, R@3=0.780
Epoch_14_Test T->M Retrieval (per 32 samples): R@1=0.466, R@2=0.662, R@3=0.788
Model saved: clip_motion_align_epoch_KIT_14.pt


Epoch 15/300: 100%|██████████| 30/30 [00:02<00:00, 10.56it/s, loss=1.1938]


Epoch [15/300] - Train Average Loss: 0.8690
[Validate at epoch 15] ...


Epoch_15_Test: 100%|██████████| 22/22 [00:01<00:00, 21.91it/s]


Epoch_15_Test Average Contrastive Loss: 1.6273
Epoch_15_Test M->T Retrieval (per 32 samples): R@1=0.443, R@2=0.669, R@3=0.767
Epoch_15_Test T->M Retrieval (per 32 samples): R@1=0.406, R@2=0.646, R@3=0.767
Model saved: clip_motion_align_epoch_KIT_15.pt


Epoch 16/300: 100%|██████████| 30/30 [00:02<00:00, 11.08it/s, loss=1.0129]


Epoch [16/300] - Train Average Loss: 0.8606
[Validate at epoch 16] ...


Epoch_16_Test: 100%|██████████| 22/22 [00:00<00:00, 25.96it/s]


Epoch_16_Test Average Contrastive Loss: 1.5764
Epoch_16_Test M->T Retrieval (per 32 samples): R@1=0.443, R@2=0.662, R@3=0.764
Epoch_16_Test T->M Retrieval (per 32 samples): R@1=0.463, R@2=0.668, R@3=0.781
Model saved: clip_motion_align_epoch_KIT_16.pt


Epoch 17/300: 100%|██████████| 30/30 [00:02<00:00, 10.64it/s, loss=0.8548]


Epoch [17/300] - Train Average Loss: 0.8229
[Validate at epoch 17] ...


Epoch_17_Test: 100%|██████████| 22/22 [00:00<00:00, 22.04it/s]


Epoch_17_Test Average Contrastive Loss: 1.6506
Epoch_17_Test M->T Retrieval (per 32 samples): R@1=0.452, R@2=0.639, R@3=0.756
Epoch_17_Test T->M Retrieval (per 32 samples): R@1=0.433, R@2=0.668, R@3=0.778
Model saved: clip_motion_align_epoch_KIT_17.pt


Epoch 18/300: 100%|██████████| 30/30 [00:02<00:00, 11.75it/s, loss=1.2670]


Epoch [18/300] - Train Average Loss: 0.8090
[Validate at epoch 18] ...


Epoch_18_Test: 100%|██████████| 22/22 [00:00<00:00, 26.33it/s]


Epoch_18_Test Average Contrastive Loss: 1.6195
Epoch_18_Test M->T Retrieval (per 32 samples): R@1=0.425, R@2=0.632, R@3=0.763
Epoch_18_Test T->M Retrieval (per 32 samples): R@1=0.463, R@2=0.655, R@3=0.764
Model saved: clip_motion_align_epoch_KIT_18.pt


Epoch 19/300: 100%|██████████| 30/30 [00:02<00:00, 10.91it/s, loss=0.6178]


Epoch [19/300] - Train Average Loss: 0.7818
[Validate at epoch 19] ...


Epoch_19_Test: 100%|██████████| 22/22 [00:00<00:00, 22.47it/s]


Epoch_19_Test Average Contrastive Loss: 1.6379
Epoch_19_Test M->T Retrieval (per 32 samples): R@1=0.419, R@2=0.648, R@3=0.773
Epoch_19_Test T->M Retrieval (per 32 samples): R@1=0.433, R@2=0.619, R@3=0.753
Model saved: clip_motion_align_epoch_KIT_19.pt
Early stopping triggered!
Training completed!
