In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from dataset import VRDataset

# Training configuration: every tunable value used by the notebook lives here.
config = {
    'batch_size': 8,
    'num_epochs': 2,
    'lr': 1e-4,
    'lr_step_size': 10,
    'lr_gamma': 0.1,
    'weight_decay': 1e-5,
    'save_dir': './checkpoints',
    'log_interval': 10,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    # Seed for torch's global RNG and for the train/val split, so a fresh
    # "Restart & Run All" reproduces the same split.
    'seed': 42,
    # Fraction of samples assigned to training; the rest goes to validation.
    'train_ratio': 0.8,
    # Worker processes per DataLoader.
    'num_workers': 4,
}
torch.manual_seed(config['seed'])

# Load the full (unsplit) training data.
full_dataset = VRDataset(
    video_root='./data/train/video',
    motion_csv_root='./data/train/csv',
    temporal_length=32,
)

# Split into training and validation subsets. A dedicated seeded generator keeps the
# split independent of any RNG use elsewhere (e.g. inside dataset construction).
train_size = int(config['train_ratio'] * len(full_dataset))
val_size = len(full_dataset) - train_size
split_generator = torch.Generator().manual_seed(config['seed'])
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

# Build the DataLoaders: shuffle only the training set.
train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=config['num_workers'],
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=config['num_workers'],
)

# Report the training-set size and the number of batches per epoch (printed, not saved).
print(f"训练集样本量: {len(train_dataset)}")
print(f"Batch数目: {len(train_loader)}")


In [None]:
from model import build_model
import torch.nn.functional as F

# Construct the network via the project's build_model (pretrained=True), then move it
# to the device selected in config.
model = build_model(pretrained=True)
model = model.to(config['device'])

# Quaternion rotation-angle loss
def quaternion_angle_loss(q_pred, q_true, eps=1e-7):
    """Mean squared rotation angle (radians^2) between predicted and target quaternions.

    Both inputs are normalised along the last dimension first, so they need not be
    unit quaternions. ``q`` and ``-q`` describe the same rotation, so the absolute
    value of the dot product is used.

    Args:
        q_pred: predicted quaternions, components along the last dimension
            (presumably size 4 -- confirm against the model head).
        q_true: target quaternions, broadcast-compatible with ``q_pred``.
        eps: margin below 1.0 for the ``|dot|`` clamp. ``acos`` has an infinite
            derivative at 1, so clamping to exactly 1.0 makes a perfect prediction
            produce NaN gradients; ``1 - eps`` keeps the backward pass finite.

    Returns:
        Scalar tensor: mean over all leading dims of ``(2 * acos(|<q_pred, q_true>|))**2``.
    """
    # F.normalize divides by max(norm, tiny eps), so an all-zero quaternion yields a
    # finite result instead of 0/0 = NaN. It is identical to x / ||x|| otherwise.
    q_pred = F.normalize(q_pred, dim=-1)
    q_true = F.normalize(q_true, dim=-1)
    dot_abs = torch.sum(q_pred * q_true, dim=-1).abs()
    # Keep acos strictly inside its domain so both forward and backward stay finite.
    dot_abs = torch.clamp(dot_abs, max=1.0 - eps)
    angle_diff = 2 * torch.acos(dot_abs)
    return torch.mean(angle_diff ** 2)

# Display the model architecture in the cell output (nothing is written to disk).
print(str(model))


In [None]:
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

# Adam over all model parameters, using the learning rate and weight decay from config.
optimizer = Adam(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

# StepLR multiplies the learning rate by `lr_gamma` once every `lr_step_size` calls
# to scheduler.step().
# NOTE(review): the training loop steps once per epoch, so with num_epochs=2 and
# lr_step_size=10 the decay never triggers.
scheduler = StepLR(optimizer, step_size=config['lr_step_size'], gamma=config['lr_gamma'])

# Console confirmation only; no optimizer state is saved here.
print(f"Optimizer initialized with learning rate: {config['lr']}")


In [None]:
from tqdm import tqdm
from datetime import datetime
import os
# Run timestamp (used in the best-model filename) and best validation loss seen so far.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
best_val_loss = float('inf')
os.makedirs(config['save_dir'], exist_ok=True)

# Validation and checkpointing are live code behind a flag. The default (False) keeps
# the train-only behaviour. Set config['run_validation'] = True to evaluate on
# val_loader each epoch and save checkpoints into config['save_dir'].
run_validation = config.get('run_validation', False)


def _batch_to_device(batch):
    """Return (video, motion, target) from `batch`, moved to config['device'] as float32."""
    return tuple(
        batch[key].to(config['device'], dtype=torch.float32)
        for key in ('video', 'motion', 'target')
    )


def _evaluate(loader, epoch):
    """Mean per-batch quaternion_angle_loss of `model` over `loader`, computed without gradients."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(loader, desc=f'Epoch {epoch + 1} Validation'):
            video, motion, target = _batch_to_device(batch)
            total_loss += quaternion_angle_loss(model(video, motion), target).item()
    # max(..., 1) avoids ZeroDivisionError for an empty loader.
    return total_loss / max(len(loader), 1)


for epoch in range(config['num_epochs']):
    # ---- training ----
    model.train()
    epoch_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch + 1} Training')

    for batch_idx, batch in enumerate(progress_bar):
        video, motion, target = _batch_to_device(batch)

        optimizer.zero_grad()
        outputs = model(video, motion)
        loss = quaternion_angle_loss(outputs, target)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        # Report through the progress bar's postfix only; extra print() calls here
        # would break the tqdm line.
        if batch_idx % config['log_interval'] == 0:
            progress_bar.set_postfix({'loss': loss.item()})

    # max(..., 1) avoids ZeroDivisionError for an empty loader.
    avg_train_loss = epoch_loss / max(len(train_loader), 1)
    print(f'Epoch {epoch + 1} Train Loss: {avg_train_loss:.4f}')

    # Append mode: the file accumulates entries across runs of this cell.
    with open('./train_loss.txt', 'a') as f:
        f.write(f'Epoch {epoch + 1} Train Loss: {avg_train_loss:.4f}\n')

    # ---- validation and checkpointing (opt-in) ----
    if run_validation:
        avg_val_loss = _evaluate(val_loader, epoch)
        print(f'Epoch {epoch + 1} Val Loss: {avg_val_loss:.4f}')
        with open('./val_loss.txt', 'a') as f:
            f.write(f'Epoch {epoch + 1} Val Loss: {avg_val_loss:.4f}\n')

        # Best-model and per-epoch checkpoints share one key layout.
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_val_loss,
        }
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(checkpoint, os.path.join(config['save_dir'], f'best_model_{timestamp}.pt'))
        # Per-epoch checkpoints live in save_dir next to the best model.
        torch.save(checkpoint, os.path.join(config['save_dir'], f'checkpoint_epoch_{epoch + 1}.pt'))

    # The learning-rate schedule advances once per epoch.
    scheduler.step()