<a href="https://colab.research.google.com/github/Song20011219/song/blob/main/dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**dataset**

In [None]:
import os
import re

import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset

class ASD_Isolated(Dataset):
    """Clip-level dataset built from folders of pre-extracted ``.jpg`` frames.

    Directory layout walked by ``_get_data_info``::

        data_path/
            arm_flapping/<clip_dir>/*.jpg     -> label 0
            hand_flapping/<clip_dir>/*.jpg    -> label 1

    Every clip directory must contain exactly ``frames`` ``.jpg`` files.
    Non-directory entries inside a class folder are ignored.

    Args:
        data_path: Root folder containing one sub-folder per class.
        transform: Optional callable applied to each PIL frame; it must
            return a tensor of shape (C, H, W). When ``None``, frames are
            converted with ``transforms.ToTensor()``.
        frames: Number of frames each clip must contain (default 30).

    Each item is ``{'data': tensor (C, T, H, W), 'label': int64 tensor}``.
    """

    # Class folder names in label order: a folder's position here is its label.
    CLASSES = ("arm_flapping", "hand_flapping")

    def __init__(self, data_path, transform=None, frames=30):
        super().__init__()
        self.data_path = data_path
        self.transform = transform
        self.frames = frames  # number of frames per video sample
        self.data_info = self._get_data_info()

    def _get_data_info(self):
        """Collect ``(clip_dir_path, label_name)`` pairs for every clip folder.

        Folders are visited in sorted order because ``os.listdir`` order is
        arbitrary; sorting keeps dataset indices (and any seeded split built
        on them) reproducible across runs and machines.
        """
        data_info = []
        for label in self.CLASSES:
            label_path = os.path.join(self.data_path, label)
            for video_folder in sorted(os.listdir(label_path)):
                video_folder_path = os.path.join(label_path, video_folder)
                if os.path.isdir(video_folder_path):
                    data_info.append((video_folder_path, label))
        return data_info

    @staticmethod
    def _natural_key(name):
        """Sort key that orders embedded integers numerically ("f2" before "f10").

        ``re.split`` with a capturing group alternates non-digit / digit runs,
        so odd positions are always digit runs and compare as ints. The raw
        name is appended to break ties such as "f1" vs "f01" deterministically.
        """
        parts = re.split(r"(\d+)", name)
        return [int(part) if i % 2 else part for i, part in enumerate(parts)], name

    def read_images(self, folder_path):
        """Load every frame of one clip and stack them as (C, T, H, W).

        Frames are ordered by file name with numeric runs compared as
        integers, so "frame_2.jpg" precedes "frame_10.jpg" even without zero
        padding (plain lexicographic sorting would put 10 before 2).

        Raises:
            ValueError: if the folder does not hold exactly ``self.frames``
                ``.jpg`` files.
        """
        names = sorted(
            (name for name in os.listdir(folder_path) if name.endswith('.jpg')),
            key=self._natural_key,
        )
        # Explicit check (not ``assert``) so it still runs under ``python -O``.
        if len(names) != self.frames:
            raise ValueError(
                f"Expected {self.frames} images, but found {len(names)} in folder {folder_path}"
            )
        # Without a transform, PIL images cannot be stacked; fall back to ToTensor.
        to_tensor = self.transform if self.transform is not None else transforms.ToTensor()
        images = []
        for name in names:
            # The context manager closes the file handle PIL keeps open for lazy
            # decoding; the transform runs while the file is still open.
            with Image.open(os.path.join(folder_path, name)) as image:
                images.append(to_tensor(image))
        images = torch.stack(images, dim=0)  # (T, C, H, W)
        # 3-D convolutions take channels first and time second: (C, T, H, W).
        return images.permute(1, 0, 2, 3)

    def __len__(self):
        """Return the number of clip folders found."""
        return len(self.data_info)

    def __getitem__(self, idx):
        """Return ``{'data': (C, T, H, W) tensor, 'label': int64 tensor}`` for clip ``idx``."""
        folder_path, label = self.data_info[idx]
        images = self.read_images(folder_path)
        label_tensor = torch.tensor(self.CLASSES.index(label), dtype=torch.long)
        return {'data': images, 'label': label_tensor}

# Smoke test: build the dataset from Drive and inspect the first clip.
if __name__ == '__main__':
    # Resize every frame to 128x128, convert to a tensor, then map [0, 1] to [-1, 1].
    transform = transforms.Compose(
        [
            transforms.Resize([128, 128]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
        ]
    )
    dataset = ASD_Isolated(
        data_path="/content/drive/MyDrive/output_frames",
        transform=transform,
    )
    print("Dataset size: {}".format(len(dataset)))

    sample = dataset[0]
    print("Sample image shape: {}, Label: {}".format(sample['data'].shape, sample['label']))


Dataset size: 163
Sample image shape: torch.Size([3, 30, 128, 128]), Label: 0


**train**

In [None]:
# 导入所需的库
import os
import sys
from datetime import datetime
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
from models.Conv3D import CNN3D, resnet18, resnet34, resnet50, resnet101, r2plus1d_18
from dataset import ASD_Isolated
from train import train_epoch
from validation import val_epoch

# Paths (fill in the placeholders before running).
data_path = "YOUR_DATASET_PATH"  # root folder holding the arm_flapping/ and hand_flapping/ clip folders
label_path = "YOUR_LABEL_PATH"  # label file path; NOTE: the ASD_Isolated in the dataset cell takes no label file
model_path = "/home/haodong/Data/cnn3d_models"
# Take a single timestamp so the log file and the TensorBoard run share the same
# suffix (two separate datetime.now() calls can straddle a second boundary).
run_timestamp = datetime.now()
log_path = "log/cnn3d_{:%Y-%m-%d_%H-%M-%S}.log".format(run_timestamp)
sum_path = "runs/slr_cnn3d_{:%Y-%m-%d_%H-%M-%S}".format(run_timestamp)

# Log to both a file and the console, and create the TensorBoard writer.
# logging.FileHandler does not create missing directories, so make sure "log/" exists.
os.makedirs(os.path.dirname(log_path), exist_ok=True)
logging.basicConfig(level=logging.INFO, format='%(message)s', handlers=[logging.FileHandler(log_path), logging.StreamHandler()])
logger = logging.getLogger('SLR')
logger.info('Logging to file...')
writer = SummaryWriter(sum_path)

# Restrict this process to physical GPU 2. This only takes effect if CUDA has not
# been initialised yet in the process. NOTE(review): on a machine with fewer than
# three GPUs (e.g. a single-GPU Colab runtime) index 2 does not exist, so no GPU is
# visible and training falls back to the CPU - adjust before running.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# Device selection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_classes = 2  # arm_flapping vs hand_flapping
epochs = 100
batch_size = 16
learning_rate = 1e-5
log_interval = 20  # forwarded to train_epoch
sample_size = 128  # frame height/width given to the model (the dataset cell resizes to 128x128)
sample_duration = 30  # frames per clip; must equal the dataset's frames per clip
attention = False
drop_p = 0.0  # in this cell, only used by the commented-out CNN3D model
hidden1, hidden2 = 512, 256  # in this cell, only used by the commented-out CNN3D model

# Train a 3D CNN on the ASD clips.
if __name__ == '__main__':
    # Per-frame preprocessing, matching the transform used in the dataset cell.
    # Defined here so this code does not depend on a global from another cell.
    transform = transforms.Compose([
        transforms.Resize([sample_size, sample_size]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Load data. ASD_Isolated only accepts a root folder and a transform (no label
    # file, frame count or train/val flag), so build one dataset and split it with
    # a fixed seed to keep the train/val partition reproducible across runs.
    val_ratio = 0.2
    split_seed = 42
    dataset = ASD_Isolated(data_path=data_path, transform=transform)
    if dataset.frames != sample_duration:
        raise ValueError(
            "sample_duration ({}) must match the dataset's frames per clip ({})".format(
                sample_duration, dataset.frames))
    val_size = int(len(dataset) * val_ratio)
    train_size = len(dataset) - val_size
    train_set, val_set = random_split(
        dataset, [train_size, val_size], generator=torch.Generator().manual_seed(split_seed))
    logger.info("数据集样本数: {}".format(len(train_set)+len(val_set)))
    logger.info("train/val samples: {}/{}".format(train_size, val_size))

    # Cap worker processes at the CPU count; pinned memory only helps host->GPU copies.
    num_workers = min(16, os.cpu_count() or 1)
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    # Shuffling adds nothing for evaluation, so keep the validation order fixed.
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)

    # Create the model (alternatives kept commented out for quick switching).
    # model = CNN3D(sample_size=sample_size, sample_duration=sample_duration, drop_p=drop_p,
    #             hidden1=hidden1, hidden2=hidden2, num_classes=num_classes).to(device)
    model = resnet18(pretrained=True, progress=True, sample_size=sample_size, sample_duration=sample_duration,
                    attention=attention, num_classes=num_classes).to(device)
    # model = r2plus1d_18(pretrained=True, num_classes=num_classes).to(device)
    # Run the model in parallel when several GPUs are visible.
    if torch.cuda.device_count() > 1:
        logger.info("使用{}个GPU".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)
    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # torch.save does not create missing directories.
    os.makedirs(model_path, exist_ok=True)

    # Start training
    logger.info("开始训练".center(60, '#'))
    for epoch in range(epochs):
        # Train for one epoch
        train_epoch(model, criterion, optimizer, train_loader, device, epoch, logger, log_interval, writer)

        # Validate
        val_epoch(model, criterion, val_loader, device, epoch, logger, writer)

        # Save a checkpoint every epoch. NOTE: when wrapped in DataParallel the
        # state-dict keys carry a "module." prefix.
        torch.save(model.state_dict(), os.path.join(model_path, "slr_cnn3d_epoch{:03d}.pth".format(epoch+1)))
        logger.info("第{}轮训练完成的模型已保存".format(epoch+1).center(60, '#'))

    logger.info("训练完成".center(60, '#'))
    # Flush any pending TensorBoard events to disk.
    writer.close()


In [None]:
# Mount Google Drive so the frame folders under /content/drive are readable.
import google.colab.drive as drive

drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


**test**