# ResNet

In [None]:
# used for data loading
from utils.data_loader import DataLoader

# device configuration
from utils.device import try_gpu

# resnet18 frame feature extractor
from models.resnet18 import FrameFeatureExtractor

# video recognition model
from models.lstm import VideoRecognitionModel

# train & predict functions
from train import train_one_epoch, save_checkpoint
from predict import predict, load_model

# tensor library
import torch
import os
import time
from torch import nn
from torch import optim
from torch.cuda import amp

In [None]:
# Select which stage of the notebook the cells below execute:
#   1: test model forward pass
#   2: inspect multi-modal data loader
#   3: continue training
#   4: predict using trained model
phase = 3

# Compute device chosen by the project helper (presumably a GPU when one is
# available -- confirm in utils.device).
device = try_gpu()

In [12]:
# Shared loader object; the phase-2 cell reuses it as well.
data_loader = DataLoader(num_samples=162)

if phase == 1:
    train_iter = data_loader.load_dataiter('train', modality='depth', frame_ext='png')

    # Only the first batch is examined; the loop exits right after it.
    for batch in train_iter:
        if not isinstance(batch, dict):
            # Non-dict batch: report its type, then describe every element.
            print('batch type:', type(batch))
            try:
                for idx, item in enumerate(batch):
                    if not hasattr(item, 'shape'):
                        print(f'elem[{idx}] type:', type(item))
                        continue
                    print(f'elem[{idx}] shape:', item.shape)
            except Exception as err:
                # Best-effort inspection: show the problem instead of aborting the cell.
                print('Error:', err)
        else:
            # Dict batch: show the frame tensor shape plus the label/length entries.
            print('frames shape:', batch['frames'].shape)
            print('labels:', batch['label'])
            print('lengths:', batch['length'])
        break

In [13]:
if phase == 1:
    # Forward-pass smoke test on random input.
    # B, T, H, W are also read by the depth test cell that follows.
    B, T, H, W = 2, 8, 224, 224

    # 1) RGB input (C=3)
    model_rgb = FrameFeatureExtractor(modality='rgb', pretrained=False).to(device)
    rgb_shape = (B, T, 3, H, W)
    x_rgb = torch.randn(rgb_shape, device=device)
    out_rgb = model_rgb(x_rgb)
    # Expected output shape per the original author: (B, T, 512)
    print('RGB out shape:', out_rgb.shape)

In [14]:
if phase == 1:
    # 2) Single-channel depth/infrared input (C=1).
    # Reuses B, T, H, W set by the RGB test cell, which must have run first.
    model_depth = FrameFeatureExtractor(modality='depth', pretrained=False).to(device)
    depth_shape = (B, T, 1, H, W)
    x_depth = torch.randn(depth_shape, device=device)
    out_depth = model_depth(x_depth)
    # Expected output shape per the original author: (B, T, 512)
    print('Depth out shape:', out_depth.shape)

In [15]:
if phase == 2:
    # Pull a single batch that carries all three modalities and print what it holds.
    train_iter = data_loader.load_multi_modal_dataiter(set='train')
    batch = next(iter(train_iter))

    # Each modality tensor is laid out as (B, T, C, H, W) according to the
    # original notes; the recorded cell output agrees with that.
    for modality in ('rgb', 'depth', 'infrared'):
        print(f'{modality} shape:', batch[modality].shape)

    print('lengths:', batch['lengths'])
    print('labels:', batch['labels'])

rgb shape: torch.Size([4, 128, 3, 224, 224])
depth shape: torch.Size([4, 128, 1, 224, 224])
infrared shape: torch.Size([4, 128, 1, 224, 224])
lengths: tensor([128, 128, 128, 128])
labels: tensor([10, 10,  0, 17])


In [17]:
if phase == 2:
    # Move the modality tensors and labels of the previously fetched batch to the device.
    rgb, depth, infrared, labels = (
        batch[key].to(device) for key in ('rgb', 'depth', 'infrared', 'labels')
    )
    # Lengths are deliberately not moved: the original note says
    # pack_padded_sequence needs them on the CPU and that collate already
    # returns a CPU tensor.
    lengths = batch['lengths']

    # Forward-pass smoke test with multi-modal input.
    model = VideoRecognitionModel(num_classes=20, num_frames=8).to(device)
    # lengths is forwarded on the assumption that forward() accepts it.
    logits = model(rgb, depth, infrared, lengths=lengths)
    # Expected shape: (B, num_classes)
    print('logits shape:', logits.shape)

logits shape: torch.Size([4, 20])


In [None]:
# continue training
if phase == 3:
    # Number of ADDITIONAL epochs to run on top of the loaded checkpoint.
    epochs = 10
    ckpt_dir = 'checkpoints'

    device = try_gpu()
    print("Using device:", device)

    # Data iterator
    dl = DataLoader()

    train_loader = dl.load_multi_modal_dataiter(set='all',
                                                frames_per_clip=64,
                                                batch_size=8,
                                                shuffle=True,
                                                num_workers=4)

    # 1) Model
    model = VideoRecognitionModel(num_classes=20,
                                  num_frames=64,
                                  lstm_hidden_size=256,
                                  lstm_num_layers=1)
    model = model.to(device)

    # 2) Load the checkpoint; map_location makes this work on CPU and GPU alike.
    ckpt_path = os.path.join(ckpt_dir, "best.pth")
    checkpoint = torch.load(ckpt_path, map_location=device)

    # 3) Two on-disk formats are accepted: a full training checkpoint (a dict
    #    holding 'model_state' next to optimizer/scheduler state) or a bare
    #    model state_dict.
    has_train_state = isinstance(checkpoint, dict) and 'model_state' in checkpoint
    state_dict = checkpoint['model_state'] if has_train_state else checkpoint
    model.load_state_dict(state_dict)

    # Loss / optimizer / LR schedule
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.5)

    # Defaults describe a fresh run; they are overridden by whatever training
    # state the checkpoint actually carries. A bare state_dict carries none, so
    # its keys must not be probed for 'optim_state' & co.
    start_epoch = 1
    best_metric = float('inf')  # best (lowest) training loss seen so far
    if has_train_state:
        if 'optim_state' in checkpoint:
            optimizer.load_state_dict(checkpoint['optim_state'])
        if 'scheduler_state' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state'])
        start_epoch = checkpoint.get('epoch', 0) + 1
        # Restoring the previous best keeps a worse resumed epoch from
        # overwriting best.pth. Checkpoints without this key fall back to inf,
        # i.e. the first resumed epoch becomes the new best.
        best_metric = checkpoint.get('best_metric', float('inf'))

    # Continue the epoch numbering from the checkpoint so per-epoch checkpoint
    # files written by the earlier run are not overwritten.
    end_epoch = start_epoch + epochs - 1

    for epoch in range(start_epoch, end_epoch + 1):
        t0 = time.time()
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device, None, grad_clip=5.0)

        scheduler.step()

        elapsed = time.time() - t0
        print(f"Epoch {epoch}/{end_epoch} | Time {elapsed:.1f}s | Train loss {train_loss:.4f} acc {train_acc:.4f}")

        # No validation pass is run here, so the best model is chosen by
        # training loss (lower is better).
        is_best = train_loss < best_metric
        if is_best:
            best_metric = train_loss

        ckpt = {
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optim_state': optimizer.state_dict(),
            'scheduler_state': scheduler.state_dict(),
            'best_metric': best_metric
        }
        save_checkpoint(ckpt, os.path.join(ckpt_dir, f"ckpt_epoch{epoch}.pth"))
        if is_best:
            save_checkpoint(ckpt, os.path.join(ckpt_dir, "best.pth"))


    print("Training finished. Best train loss:", best_metric)


In [None]:
if phase == 4:
    # Load the best checkpoint through the project helper, then run prediction.
    # NOTE(review): the loaded model is never handed to predict(); confirm
    # whether predict() is meant to receive it or loads its own model.
    best_ckpt_path = 'checkpoints/best.pth'
    model = load_model(best_ckpt_path)

    predict()