<a href="https://colab.research.google.com/github/Song20011219/song/blob/main/dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive so that files under MyDrive
# (e.g. /content/drive/MyDrive/output_frames) are readable from this runtime.
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


**dataset**

In [None]:
import os
import re

from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class ASD_Isolated(Dataset):
    """Clip-level dataset built from folders of pre-extracted ``.jpg`` frames.

    Expected layout under ``data_path``::

        data_path/
            arm_flapping/<clip_dir>/*.jpg    -> label 0
            hand_flapping/<clip_dir>/*.jpg   -> label 1

    Every clip directory must contain exactly ``frames`` ``.jpg`` files.
    Clips and frames are both ordered with a natural sort, so ``f2`` comes
    before ``f10``. Each sample is ``{'data': Tensor, 'label': LongTensor}``.
    When ``transform`` returns ``(C, H, W)`` tensors, ``data`` has shape
    ``(C, frames, H, W)``.

    Args:
        data_path: Root directory containing one sub-directory per label.
        transform: Callable applied to each frame (an RGB PIL image). It must
            return a tensor. If ``None``, ``transforms.ToTensor()`` is used.
        frames: Number of frames each clip must contain (default 30).
    """

    # Label names in index order: "arm_flapping" -> 0, "hand_flapping" -> 1.
    LABELS = ("arm_flapping", "hand_flapping")

    def __init__(self, data_path, transform=None, frames=30):
        super().__init__()
        self.data_path = data_path
        self.transform = transform
        self.frames = frames  # number of frames per video sample
        self.data_info = self._get_data_info()

    @staticmethod
    def _natural_key(name):
        """Sort key comparing embedded digit runs numerically ("f2" < "f10")."""
        # With a capturing group, re.split yields text at even indices and
        # digit runs at odd indices, so element types line up across names.
        parts = re.split(r'(\d+)', name)
        # The raw name is a tie-breaker (e.g. "f01" vs "f1") for a total order.
        return [int(p) if i % 2 else p for i, p in enumerate(parts)], name

    def _get_data_info(self):
        """Return ``(clip_dir_path, label_name)`` pairs in a deterministic order."""
        data_info = []
        for label in self.LABELS:
            label_path = os.path.join(self.data_path, label)
            # os.listdir order is filesystem-dependent; sort so indices are reproducible.
            for video_folder in sorted(os.listdir(label_path), key=self._natural_key):
                video_folder_path = os.path.join(label_path, video_folder)
                if os.path.isdir(video_folder_path):  # skip stray files
                    data_info.append((video_folder_path, label))
        return data_info

    def read_images(self, folder_path):
        """Load the ``.jpg`` frames of one clip, stacked time-second.

        Frames are ordered by natural sort of their file names, so
        ``frame_2.jpg`` comes before ``frame_10.jpg``.

        Returns:
            Tensor of shape ``(C, T, H, W)``, assuming the transform yields
            ``(C, H, W)`` per frame.

        Raises:
            ValueError: if the folder does not hold exactly ``self.frames``
                ``.jpg`` files.
        """
        names = sorted(
            (f for f in os.listdir(folder_path) if f.endswith('.jpg')),
            key=self._natural_key,
        )
        if len(names) != self.frames:
            # An explicit raise (not assert) so the check survives `python -O`.
            raise ValueError(
                f"Expected {self.frames} images, but found {len(names)} in folder {folder_path}"
            )
        # Without a user transform, PIL images cannot be stacked; fall back to ToTensor.
        to_tensor = self.transform if self.transform is not None else transforms.ToTensor()
        images = []
        for name in names:
            # The `with` block releases the file handle. convert() loads the pixels
            # and turns grayscale/CMYK JPEGs into 3-channel RGB.
            with Image.open(os.path.join(folder_path, name)) as img:
                images.append(to_tensor(img.convert('RGB')))
        # (T, C, H, W) -> (C, T, H, W): channels first, then time. This is
        # presumably the layout the downstream (3D) CNN expects.
        return torch.stack(images, dim=0).permute(1, 0, 2, 3)

    def __len__(self):
        """Number of clip folders found under ``data_path``."""
        return len(self.data_info)

    def __getitem__(self, idx):
        """Return ``{'data': frames tensor, 'label': 0-d long tensor}`` for clip *idx*."""
        folder_path, label = self.data_info[idx]
        images = self.read_images(folder_path)
        label_tensor = torch.tensor(self.LABELS.index(label), dtype=torch.long)
        return {'data': images, 'label': label_tensor}

# Smoke test: build the dataset from Google Drive and inspect its first sample.
if __name__ == '__main__':
    transform = transforms.Compose(
        [
            # Resize every frame to 128x128 pixels.
            transforms.Resize([128, 128]),
            transforms.ToTensor(),
            # (x - 0.5) / 0.5 per channel: maps ToTensor's [0, 1] range to [-1, 1].
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )
    dataset = ASD_Isolated(
        data_path="/content/drive/MyDrive/output_frames",
        transform=transform,
    )
    print(f"Dataset size: {len(dataset)}")

    # Loading one sample exercises the full read/transform/stack path.
    sample = dataset[0]
    print(f"Sample image shape: {sample['data'].shape}, Label: {sample['label']}")


Dataset size: 163
Sample image shape: torch.Size([3, 30, 128, 128]), Label: 0
