# ResNet-18 frame feature extractor

In [4]:
# used for data loading
from utils.data_loader import DataLoader

# device configuration
from utils.device import try_gpu

# resnet18 frame feature extractor
from models.resnet18 import FrameFeatureExtractor

# tensor library
import torch

In [5]:
# Build the project's data loader, then request the 'train' split
# with the RGB modality and JPEG frame files.
data_loader = DataLoader()
train_iter = data_loader.load_dataiter(
    'train',
    modality='rgb',
    frame_ext='jpg',
)

In [6]:
# Peek at the first batch only, to see what the iterator yields.
for batch in train_iter:
    if not isinstance(batch, dict):
        # Non-dict batch: report the container type, then describe each element.
        print('batch type:', type(batch))
        try:
            for i, elem in enumerate(batch):
                if not hasattr(elem, 'shape'):
                    print(f'elem[{i}] type:', type(elem))
                    continue
                print(f'elem[{i}] shape:', elem.shape)
        except Exception as e:
            # Best-effort inspection: report the problem instead of aborting the cell.
            print('Error:', e)
    else:
        # Dict batch: show the frame tensor's shape plus the label and length entries.
        print('frames shape:', batch['frames'].shape)
        print('labels:', batch['label'])
        print('lengths:', batch['length'])
    # One batch is enough for inspection.
    break

frames shape: torch.Size([8, 128, 3, 224, 224])
labels: tensor([18,  2,  9, 15, 15,  2, 15,  8])
lengths: tensor([128, 128, 128, 128, 128, 128, 128, 128])


In [7]:
# Smoke-test the model's forward pass on a small random input.
device = try_gpu()
B, T, H, W = 2, 8, 224, 224  # dummy-input dimensions, used below as (B, T, C, H, W)

# 1) RGB input (C=3)
model_rgb = FrameFeatureExtractor(modality='rgb', pretrained=False).to(device)
x_rgb = torch.randn(B, T, 3, H, W, device=device)
out_rgb = model_rgb(x_rgb)
print('RGB out shape:', out_rgb.shape)  # expected: (B, T, 512)

# Enforce the expected shape so a regression fails this cell instead of
# relying on someone reading the printed output.
assert out_rgb.shape == (B, T, 512), f'unexpected RGB output shape: {tuple(out_rgb.shape)}'

RGB out shape: torch.Size([2, 8, 512])


In [8]:
# 2) Single-channel depth/infrared input (C=1).
# Uses `device` and `B`, `T`, `H`, `W`, which this cell does not define itself.
model_depth = FrameFeatureExtractor(modality='depth', pretrained=False).to(device)
x_depth = torch.randn(B, T, 1, H, W, device=device)
out_depth = model_depth(x_depth)
print('Depth out shape:', out_depth.shape)  # expected: (B, T, 512)

# Enforce the expected shape so a regression fails this cell instead of
# relying on someone reading the printed output.
assert out_depth.shape == (B, T, 512), f'unexpected depth output shape: {tuple(out_depth.shape)}'

Depth out shape: torch.Size([2, 8, 512])
