In [1]:
pip install torch

Note: you may need to restart the kernel to use updated packages.


In [1]:
pip uninstall torch

Note: you may need to restart the kernel to use updated packages.




In [1]:
pip install torch

Note: you may need to restart the kernel to use updated packages.


In [3]:
pip install torchvision

Collecting torchvision
  Downloading torchvision-0.18.0-cp311-cp311-win_amd64.whl.metadata (6.6 kB)
Downloading torchvision-0.18.0-cp311-cp311-win_amd64.whl (1.2 MB)
   ---------------------------------------- 0.0/1.2 MB ? eta -:--:--
   ---------------------------------------- 0.0/1.2 MB ? eta -:--:--
   -- ------------------------------------- 0.1/1.2 MB 812.7 kB/s eta 0:00:02
   ------------- -------------------------- 0.4/1.2 MB 3.4 MB/s eta 0:00:01
   ----------------------- ---------------- 0.7/1.2 MB 4.4 MB/s eta 0:00:01
   ---------------------------------- ----- 1.0/1.2 MB 4.9 MB/s eta 0:00:01
   ---------------------------------------- 1.2/1.2 MB 5.0 MB/s eta 0:00:00
Installing collected packages: torchvision
Successfully installed torchvision-0.18.0
Note: you may need to restart the kernel to use updated packages.


In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from PIL import Image

# Dataset class definition
class CustomDataset(Dataset):
    """Image-classification dataset with one sub-folder per class.

    Layout read by ``load_data``::

        root_dir/
            <class_name>/
                *.json      # optional metadata, parsed into ``self.metadata[class_name]``
                원천/
                    *.jpg   # images of <class_name> (extension matched case-insensitively)

    Each sample is ``(image, class_index)``. ``class_index`` is the position of
    the folder name in ``self.classes``, which is sorted so the mapping is
    reproducible. The default DataLoader collate therefore yields an integer
    tensor that ``nn.CrossEntropyLoss`` accepts.

    Args:
        root_dir: directory containing the class folders.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir order is arbitrary; sorting keeps class -> index stable
        # across runs and machines (the saved model head depends on it).
        self.classes = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}
        # class name -> list of parsed JSON documents; kept for reference,
        # not used as training targets.
        self.metadata = {}
        self.data = self.load_data()

    def load_data(self):
        """Collect ``(image_path, class_index)`` pairs and fill ``self.metadata``.

        Returns:
            list of ``(str, int)`` tuples, ordered by class then file name.
        """
        data = []
        for label in self.classes:
            label_path = os.path.join(self.root_dir, label)

            label_metadata = []
            json_files = sorted(f for f in os.listdir(label_path) if f.lower().endswith('.json'))
            for json_file in json_files:
                # Explicit UTF-8: the Windows default locale encoding (e.g. cp949)
                # would fail on or garble non-ASCII JSON content.
                with open(os.path.join(label_path, json_file), 'r', encoding='utf-8') as f:
                    label_metadata.append(json.load(f))
            self.metadata[label] = label_metadata

            image_folder_path = os.path.join(label_path, '원천')
            if not os.path.isdir(image_folder_path):
                continue  # class folder without an image sub-folder contributes no samples
            target = self.class_to_idx[label]
            image_files = sorted(f for f in os.listdir(image_folder_path) if f.lower().endswith('.jpg'))
            data.extend((os.path.join(image_folder_path, f), target) for f in image_files)
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for sample ``idx``."""
        img_path, target = self.data[idx]
        # convert('RGB') forces a 3-channel image (grayscale/CMYK JPEGs would
        # otherwise mismatch the model input) and loads the pixels, so the
        # file handle can be closed right away.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, target

# Image preprocessing.
# The ImageNet-pretrained EfficientNet-B0 weights expect inputs normalized with
# the ImageNet per-channel mean/std, so apply the same normalization here.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Build the dataset. The path can be overridden with the SCALP_TRAINING_DIR
# environment variable instead of editing the notebook.
training_path = os.environ.get('SCALP_TRAINING_DIR', 'F:/유형별 두피 이미지/Training')
dataset = CustomDataset(training_path, transform=transform)
if len(dataset) == 0:
    # Fail with a clear message instead of DataLoader's opaque "num_samples" error.
    raise RuntimeError(f"No .jpg images found under {training_path!r} (expected <class>/원천/*.jpg)")

# Seed torch's global RNG so batch shuffling and newly initialized layers are reproducible.
torch.manual_seed(42)

# DataLoader. On Windows, worker processes are spawned and must re-import
# CustomDataset from __main__, which fails (or hangs) for classes defined in a
# notebook cell, so load in the main process there.
num_workers = 0 if os.name == 'nt' else 4
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=num_workers)

# Create the model: ImageNet-pretrained EfficientNet-B0 from torchvision.
# `pretrained=True` is deprecated since torchvision 0.13; use the weights enum.
model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
# torchvision's EfficientNet has no `_fc` attribute (that name comes from the
# separate `efficientnet_pytorch` package). Its head is
# `classifier = Sequential(Dropout, Linear)`, so swap the Linear layer for one
# sized to our classes.
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, len(dataset.classes))

# Device selection
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    # Explicitly enable training behaviour (dropout, batch-norm statistics);
    # in a notebook an earlier cell may have left the model in eval() mode.
    model.train()
    running_loss = 0.0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss returns the batch mean by default; weight it by the
        # batch size so the epoch figure is a per-sample average.
        running_loss += loss.item() * inputs.size(0)
    epoch_loss = running_loss / len(dataloader.dataset)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# Save the trained weights (state_dict only; rebuild the architecture to load).
torch.save(model.state_dict(), 'efficientnet_model.pth')
