In [2]:
import os
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# Select the CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Collect (path, category) pairs for every source array.
# The category is the file name without its extension. Deriving it with
# os.path works for both "\\" and "/" separators and strips only the trailing
# ".npy". Sorting makes the category order, and therefore the indices stored
# in category.npy, reproducible regardless of the order glob returns.
files = [
    (path, os.path.splitext(os.path.basename(path))[0])
    for path in sorted(glob("data_original/*.npy"))
]

# Persist the category names in the same order as `files`.
categories = [category for _, category in files]
np.save('category.npy', np.array(categories))

# Preprocessing pipeline applied to each image before it is stored.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(224, 224)),
        # Replicate the single gray channel into 3 channels.
        transforms.Grayscale(num_output_channels=3),
        # Produces a float tensor in the 0~1 range, shape (3, 224, 224).
        transforms.ToTensor(),
    ]
)

# Per-file sample budget: the first MAX_SAMPLES rows of each source array are
# converted; the first TRAIN_SAMPLES results are saved as the training split
# and whatever remains is saved as the validation split.
MAX_SAMPLES = 15000
TRAIN_SAMPLES = 13000

for path, _ in files:
    print('-' * 10, path, '변환 시작', '-' * 10)
    # Each source row is reshaped into one 28x28 image.
    data = np.load(path)
    data = data[:MAX_SAMPLES].reshape(-1, 28, 28)
    total = data.shape[0]

    # Build the output directly in host memory. The transform converts each
    # image to PIL, which lives on the CPU, so staging images on the GPU only
    # adds a host->device->host copy per image plus a multi-gigabyte device
    # allocation without speeding anything up.
    resized_data_np = np.empty((total, 3, 224, 224), dtype=np.uint8)

    for i in range(total):
        # transform returns a float tensor (3,224,224) scaled to 0~1.
        img_rgb = transform(torch.from_numpy(data[i]))
        # Round before casting: a bare uint8 cast truncates, so a value that
        # lands just below an integer after the /255 and *255 round trip
        # would otherwise lose one intensity level.
        resized_data_np[i] = (img_rgb * 255).round().to(torch.uint8).numpy()

    train_path = path.replace("_original", "")
    val_path = path.replace("_original", "_val")

    # np.save does not create missing directories, so make sure they exist.
    for out_path in (train_path, val_path):
        os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)

    np.save(train_path, resized_data_np[:TRAIN_SAMPLES])
    np.save(val_path, resized_data_np[TRAIN_SAMPLES:])

    print('-' * 10, path, '변환 종료', '-' * 10)

# Sanity check: load one converted file and display its first image.
items = np.load("data/aircraft carrier.npy")  # expected (13000, 3, 224, 224)

# imshow wants channels last, so move the channel axis of the first sample
# to the end: (3, 224, 224) -> (224, 224, 3).
img = np.moveaxis(items[0], 0, -1)
plt.imshow(img)
plt.axis('off')
plt.show()


---------- data_original\aircraft carrier.npy 변환 시작 ----------
---------- data_original\aircraft carrier.npy 변환 종료 ----------
---------- data_original\airplane.npy 변환 시작 ----------
---------- data_original\airplane.npy 변환 종료 ----------
---------- data_original\alarm clock.npy 변환 시작 ----------
---------- data_original\alarm clock.npy 변환 종료 ----------
---------- data_original\ambulance.npy 변환 시작 ----------
---------- data_original\ambulance.npy 변환 종료 ----------
---------- data_original\angel.npy 변환 시작 ----------
---------- data_original\angel.npy 변환 종료 ----------
---------- data_original\animal migration.npy 변환 시작 ----------
---------- data_original\animal migration.npy 변환 종료 ----------
---------- data_original\ant.npy 변환 시작 ----------
---------- data_original\ant.npy 변환 종료 ----------
---------- data_original\anvil.npy 변환 시작 ----------
---------- data_original\anvil.npy 변환 종료 ----------
---------- data_original\apple.npy 변환 시작 ----------
---------- data_original\apple.npy 변환 종료 ---------

KeyboardInterrupt: 