In [1]:
import os
import pickle
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torch
from torch import optim
from tqdm.notebook import tqdm
import torch.nn.functional as F
import itertools
from model import *
from dataset_loader import create_data_loader
from loss_func import *

In [2]:
# ---- Training hyper-parameters ----
lr = 0.001         # learning rate handed to the optimizer
epoches = 2000     # upper bound of the epoch loop
batch_size = 256   # number of samples per mini-batch
start_epoch = 0    # first epoch index to run; load_checkpoint() overwrites it

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(device)

cuda


In [3]:
# Build the five sub-networks in a fixed order and move each one to `device`.
poseS1, poseS2, poseS3, transB1, transB2 = (
    net_cls().to(device)
    for net_cls in (PoseS1, PoseS2, PoseS3, TransB1, TransB2)
)

# Loss function aliases, one per sub-network.
Ls1 = poseLoss      # loss on the leaf-joint positions
Ls2 = poseLoss      # loss on the positions of all joints except the hip
Ls3 = poseLoss      # loss on the 6D rotations of all joints except the hip
Lb1 = crossEntropy  # foot-contact probability
Lb2 = ver_n_loss    # displacement loss over 1, 3, 9 and 27 consecutive frames

# A single Adam optimizer over every parameter that currently requires a gradient.
trainable_params = [
    param
    for net in (poseS1, poseS2, poseS3, transB1, transB2)
    for param in net.parameters()
    if param.requires_grad
]
optimizer = optim.Adam(trainable_params, lr=lr)

In [4]:
def load_checkpoint(filepath):
    """Restore the five sub-networks and the optimizer from a checkpoint file.

    Side effects:
        * sets the module-level ``start_epoch`` to ``checkpoint['epoch']``;
        * loads the saved weights into ``poseS1``, ``poseS2``, ``poseS3``,
          ``transB1`` and ``transB2`` (in that order);
        * loads the saved state into the module-level ``optimizer``;
        * marks every parameter of every sub-network as trainable and puts
          every sub-network into training mode.

    Args:
        filepath: path of a checkpoint containing the keys ``'epoch'``,
            ``'modes'``, ``'models_state_dict'`` and ``'optimizer_state_dict'``.

    Returns:
        The object stored under ``checkpoint['modes']``.
    """
    global start_epoch

    # map_location lets a checkpoint written on a GPU be loaded on a CPU-only host.
    # NOTE(review): the checkpoint holds pickled module instances under 'modes';
    # PyTorch versions whose torch.load defaults to weights_only=True may refuse
    # to unpickle them - confirm against the installed version.
    checkpoint = torch.load(filepath, map_location=device)

    start_epoch = checkpoint['epoch']
    models = checkpoint['modes']  # module instances stored alongside the weights
    models_state_dict = checkpoint['models_state_dict']  # one state dict per sub-network

    networks = [poseS1, poseS2, poseS3, transB1, transB2]
    for model, state_dict in zip(networks, models_state_dict):
        model.load_state_dict(state_dict)  # load the saved weights
        # Reset every restored network to a trainable state. Previously this ran
        # after the loop and therefore only affected the last network.
        for parameter in model.parameters():
            parameter.requires_grad = True
        model.train()

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # restore optimizer state

    return models
# Load the checkpoint; this also sets start_epoch and restores the optimizer state.
models = load_checkpoint("./checkpoint/checkpoint_best_transB1.pth")

# Freeze every sub-network except transB2 so that only transB2 receives gradients.
for frozen_net in (poseS1, poseS2, poseS3, transB1):
    for p in frozen_net.parameters():
        p.requires_grad = False



In [5]:
train_loader, test_loader, valid_loader = create_data_loader(batch_size)

训练集大小36858， 验证集大小9215， 测试集大小46073


In [7]:
def _with_noise(tensor, std):
    """Return `tensor` plus zero-mean Gaussian noise with standard deviation `std`.

    The noise is sampled directly on `device` instead of on the CPU followed by a copy.
    """
    return tensor + torch.normal(mean=0, std=std, size=tensor.shape, device=device)


def _to_numpy(tensor):
    """Detach `tensor` from the graph and return it as a NumPy array on the host."""
    return tensor.detach().cpu().numpy()


def _make_checkpoint(resume_epoch):
    """Collect everything load_checkpoint() reads into a single dict.

    `resume_epoch` is stored under 'epoch'. load_checkpoint() assigns that value
    to start_epoch, so it must be the index of the epoch to run next on resume.
    """
    return {
        # Freshly constructed module instances, kept because load_checkpoint()
        # reads and returns checkpoint['modes'].
        'modes': [PoseS1(), PoseS2(), PoseS3(), TransB1(), TransB2()],
        'models_state_dict': [poseS1.state_dict(), poseS2.state_dict(), poseS3.state_dict(),
                              transB1.state_dict(), transB2.state_dict()],
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': resume_epoch,
    }


# Make sure the output folders exist before the first save.
os.makedirs('checkpoint', exist_ok=True)
os.makedirs('sample', exist_ok=True)

for epoch in range(start_epoch, epoches):
    loss_sum = 0
    # Running sums for the progress bar; divided by the batch count when displayed.
    loss_dict = {"loss_ls1":0, "loss_ls2":0, "loss_ls3":0, "loss_lb1":0, "loss_lb2":0, "foot_acc":0}
    bar = tqdm(enumerate(train_loader), postfix="training", total=len(train_loader))

    # Nothing inside the batch loop changes the mode, so set it once per epoch.
    # NOTE(review): the sub-networks frozen earlier are also put in train mode here;
    # if they contain dropout/batch-norm layers, confirm this is intended.
    for net in (poseS1, poseS2, poseS3, transB1, transB2):
        net.train()

    for idx, (seq_len, x0, p_leaf_gt, p_all_gt, pose_6d_gt, support_leg_gt, root_velocity_gt, root_ori, mask) in bar:
        x0 = x0.to(device)
        p_leaf_gt = p_leaf_gt.to(device)
        p_all_gt = p_all_gt.to(device)
        pose_6d_gt = pose_6d_gt.to(device)
        support_leg_gt = support_leg_gt.to(device)
        root_velocity_gt = root_velocity_gt.to(device)
        root_ori = root_ori.to(device)
        mask = mask.to(device)

        # -------- pose 1: leaf-joint positions (target perturbed with noise)
        p_leaf = poseS1(x0, seq_len)
        loss_ls1 = Ls1(p_leaf, _with_noise(p_leaf_gt, 0.04), mask)
        loss_dict["loss_ls1"] += loss_ls1.item()

        # -------- pose 2: all joint positions (target perturbed with noise)
        x1 = torch.cat([p_leaf, x0], dim=-1)
        p_all = poseS2(x1, seq_len)
        loss_ls2 = Ls2(p_all, _with_noise(p_all_gt, 0.025), mask)
        loss_dict["loss_ls2"] += loss_ls2.item()

        # -------- pose 3: 6D rotations
        x2 = torch.cat([p_all, x0], dim=-1)
        r6d_all = poseS3(x2, seq_len)
        loss_ls3 = Ls3(r6d_all, pose_6d_gt, mask)
        loss_dict["loss_ls3"] += loss_ls3.item()

        # -------- transB1: supporting-foot probability
        support_leg_prob = transB1(x1, seq_len)
        loss_lb1 = Lb1(support_leg_prob, support_leg_gt, mask)
        foot_acc = foot_accuracy(support_leg_prob, support_leg_gt, mask)
        loss_dict["loss_lb1"] += loss_lb1.item()
        loss_dict["foot_acc"] += foot_acc.item() * 100

        # -------- transB2: root velocity (input perturbed with noise)
        ve_hat = transB2(_with_noise(x2, 0.025), seq_len)
        loss_lb2 = Lb2(ve_hat, root_velocity_gt, mask)
        loss_dict["loss_lb2"] += loss_lb2.item()

        # The angle-error metric is currently disabled and reported as 0.
        #avg_error = compute_angle_dif(r6d_all.detach().cpu(), pose_6d_gt.detach().cpu(), mask.detach().cpu())
        avg_error = 0

        # Only the transB2 loss is back-propagated; the other losses are logged only.
        loss_total = loss_lb2

        optimizer.zero_grad()
        loss_sum += loss_total.item()
        loss_total.backward()
        optimizer.step()

        batches_done = idx + 1
        bar.set_description(f"[Epoch {epoch+1}/ {epoches}]")
        bar.set_postfix(total_loss= loss_sum / batches_done, avg_angle_error=avg_error,
                        **{key: value / batches_done for key, value in loss_dict.items()})

        if idx%30 == 0:
            # Mid-epoch snapshot: this epoch is unfinished, so a resume restarts it.
            torch.save(_make_checkpoint(epoch), f'checkpoint/checkpoint_latest.pth')

            # Dump the first sequence of the batch, restricted to the frames its mask selects.
            sample_idx = 0
            valid = mask[sample_idx].int().bool()
            # NOTE(review): the ground-truth velocity is divided by 60 while the
            # prediction is not - presumably a frame-rate conversion; confirm.
            v_gt = root_velocity_gt[sample_idx][valid].reshape(-1, 3) / 60
            v_pre = ve_hat[sample_idx][valid].reshape(-1, 3)
            foot_pre = support_leg_prob[sample_idx][valid].reshape(-1, 2)
            foot_gt = support_leg_gt[sample_idx][valid].reshape(-1, 2)

            sample_dict = {
                "pose": _to_numpy(r6d_all[sample_idx][valid]),
                'pose_gt': _to_numpy(pose_6d_gt[sample_idx][valid]),
                'leg': _to_numpy(support_leg_prob[sample_idx][valid]),
                'leg_gt': _to_numpy(support_leg_gt[sample_idx][valid]),
                'foot_pre': _to_numpy(foot_pre),
                'foot_gt': _to_numpy(foot_gt),
                'v': _to_numpy(v_pre),
                'v_gt': _to_numpy(v_gt),
                'root_ori': _to_numpy(root_ori[sample_idx][valid]),
            }
            # The context manager closes the file even if pickling fails.
            with open(f'sample/sample_data_{epoch}_{idx}.pkl', 'wb') as sample_file:
                pickle.dump(sample_dict, sample_file)

    # End-of-epoch snapshot: this epoch is complete, so a resume starts at the next one.
    # Storing `epoch` here would make a resumed run repeat the finished epoch.
    torch.save(_make_checkpoint(epoch + 1), f'checkpoint/checkpoint_{epoch+1}_pose1.pth')


  0%|          | 0/143 [00:00<?, ?it/s, training]