In [5]:
import numpy as np
import torch
from glob import glob

files = [(path, path.replace("data\\", "").replace(".npy", "")) for path in glob("data/*")]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 1000  # 한번에 처리할 배치 크기 (조정 가능)
num_samples_per_file = 10000  # 파일당 샘플 수

for file_idx, (path, label) in enumerate(files):
    print(f'Processing {path}...')
    data_np = np.load(path)[:num_samples_per_file]  # (10000, 224, 224, 3) 혹은 (10000, 224, 224)

    # 흑백이면 채널 차원 추가
    if data_np.ndim == 3:
        data_np = data_np[..., None]  # (N, H, W, 1)

    for start in range(0, num_samples_per_file, batch_size):
        end = start + batch_size
        batch_np = data_np[start:end]  # (batch_size, H, W, C)
        
        # numpy -> torch tensor, 차원 재배치
        batch_torch = torch.from_numpy(batch_np).permute(0, 3, 1, 2).to(dtype=torch.uint8).to(device)
        
        # 여기서 batch_torch를 모델 입력 등 필요한 작업 수행
        print(batch_torch.shape)

print("완료")

files = [(path, path.replace("data\\", "").replace(".npy", "")) for path in glob("data/*")]
total = len(files) * 10000

train_data_shape = (total, 3, 224, 224)  # PyTorch는 채널이 앞에
target_data_shape = (total, )

print(train_data_shape, target_data_shape)

Processing data\aircraft carrier.npy...
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
Processing data\airplane.npy...
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
Processing data\alarm clock.npy...
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224])
torch.Size([1000, 224, 3, 224