In [41]:
import torch
import json
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np

import os
from models import Net
from torchvision.transforms import Lambda, Compose, ToTensor, Normalize
from PIL import Image, ImageOps, ImageFilter

In [42]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [43]:

class EmotionImageDataset(Dataset):
    def __init__(self, directory, transform=None):
        self.directory = directory
        self.transform = transform
        self.image_files = [os.path.join(directory, file) for file in os.listdir(directory) if file.endswith(('jpg', 'png'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, image_path


In [44]:
# 이미지 전처리
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 데이터 로더 설정
emotion_paths = {
    'Angry': '/home/work/XAI/BITAmin/Cat_yolo/train_cropped/Angry',
    'Disgusted': '/home/work/XAI/BITAmin/Cat_yolo/train_cropped/Disgusted',
    'Happy': '/home/work/XAI/BITAmin/Cat_yolo/train_cropped/Happy',
    'Normal': '/home/work/XAI/BITAmin/Cat_yolo/train_cropped/Normal',
    'Sad': '/home/work/XAI/BITAmin/Cat_yolo/train_cropped/Sad',
    'Scared': '/home/work/XAI/BITAmin/Cat_yolo/train_cropped/Scared',
    'Surprised': '/home/work/XAI/BITAmin/Cat_yolo/train_cropped/Surprised',
}

data_loaders = {
    emotion: DataLoader(EmotionImageDataset(path, transform=transform), batch_size=10, shuffle=False)
    for emotion, path in emotion_paths.items()
}


In [46]:
# 모델 설정 및 로드
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
model_path = '/home/work/XAI/BITAmin/facial/models/cat_keypoints_model2.pt'
model.eval()

Net(
  (conv1): Conv2d(3, 32, kernel_size=(4, 4), stride=(1, 1))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout1): Dropout(p=0.1, inplace=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout2): Dropout(p=0.2, inplace=False)
  (conv3): Conv2d(64, 128, kernel_size=(2, 2), stride=(1, 1))
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout3): Dropout(p=0.3, inplace=False)
  (conv4): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
  (bn4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 

In [47]:
# 키포인트 추출 및 시각화
for emotion, loader in data_loaders.items():
    print(f"Processing {emotion} images...")
    for images, image_paths in loader:
        images = images.to(device)
        outputs = model(images)
        outputs = outputs.view(outputs.size(0), -1, 2)  # 조정한 모델 출력 형태

        outputs = outputs.cpu().detach().numpy()  # GPU에서 CPU로 텐서 이동 및 numpy 배열로 변환

        for i, output in enumerate(outputs):
            image_np = images[i].cpu().permute(1, 2, 0).numpy()
            image_np = image_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
            image_np = np.clip(image_np, 0, 1)
            keypoints = output

            plt.imshow(image_np)
            plt.scatter(keypoints[:, 0], keypoints[:, 1], s=20, color='red')
            plt.title(f"{emotion} - {os.path.basename(image_paths[i])}")
            plt.axis('off')
            plt.show()

Processing Angry images...


RuntimeError: stack expects each tensor to be equal size, but got [3, 376, 343] at entry 0 and [3, 349, 358] at entry 1