In [None]:
import zipfile
import os
import re
from collections import Counter
from sklearn.model_selection import train_test_split
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models
import torch.optim as optim

In [None]:
# Prompt for a file upload through the Colab helper (only works inside a Colab runtime).
# NOTE(review): the extraction cell reads /content/diatom_dataset.zip, so the uploaded
# archive presumably has to carry that exact name and land in /content — confirm.
from google.colab import files
uploaded = files.upload()

In [None]:
# Unpack the uploaded dataset archive into a working directory.
zip_file_name = 'diatom_dataset.zip'
extract_path = '/content/dataset'
zip_file_path = f'/content/{zip_file_name}'

# exist_ok=True keeps the cell re-runnable and avoids the check-then-create race
# of testing os.path.exists() before calling os.makedirs().
os.makedirs(extract_path, exist_ok=True)

# The context manager closes the archive even if extraction raises.
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)
print("압축 해제 완료")

압축 해제 완료


In [None]:
!rm -rf /content/dataset/__MACOSX # Remove the metadata folder created when the archive is zipped on a Mac

In [None]:
DATASET_PATH = '/content/dataset/'

# Recursively collect every JPEG (.jpg / .jpeg, case-insensitive) under DATASET_PATH.
# The per-directory listing is bound to `filenames` rather than `files` so that it
# does not overwrite the `google.colab.files` module imported by the upload cell.
all_image_paths = []
for root, _, filenames in os.walk(DATASET_PATH):
    for filename in filenames:
        if filename.lower().endswith(('.jpg', '.jpeg')):
            all_image_paths.append(os.path.join(root, filename))

# os.walk yields entries in an arbitrary, filesystem-dependent order. Sorting gives
# the seeded train/validation split the same input order on every run and machine.
all_image_paths.sort()

# Labeling
def extract_label_from_filename(filepath):
    """Return the class label encoded in an image's file name.

    The label is the part of the base name that precedes the first
    underscore. A base name without any underscore is returned unchanged
    (extension included).
    """
    base_name = os.path.basename(filepath)
    prefix, _sep, _rest = base_name.partition('_')
    return prefix

# One label per collected image path, in the same order as all_image_paths.
labels = list(map(extract_label_from_filename, all_image_paths))

# Dataset summary: total images, number of distinct classes, then the
# per-class image counts from most to least frequent.
print(f"이미지 파일: {len(all_image_paths)}")
print(f"클래수(종): {len(set(labels))}")
print("\n")
class_counts = Counter(labels)
for label, count in class_counts.most_common():
    print(f"클래스별 이미지 개수- {label}: {count}개")


이미지 파일: 264
클래수(종): 38


클래스별 이미지 개수- Diatoma tenuis: 11개
클래스별 이미지 개수- Achnanthes inflata: 10개
클래스별 이미지 개수- Chamaepinnularia krasskei: 9개
클래스별 이미지 개수- Chamaepinnularia gandrupii: 9개
클래스별 이미지 개수- Amphora bicapitata: 8개
클래스별 이미지 개수- Caloneis lewisii: 8개
클래스별 이미지 개수- Amphora copulata: 8개
클래스별 이미지 개수- Chamaepinnularia krookii: 8개
클래스별 이미지 개수- Chamaepinnularia witkowskii: 8개
클래스별 이미지 개수- Achnanthes tumescens: 7개
클래스별 이미지 개수- Amphora pediculus: 7개
클래스별 이미지 개수- Diatoma moniliformis: 7개
클래스별 이미지 개수- Amphora ovalis: 7개
클래스별 이미지 개수- Achnanthes felinophila: 7개
클래스별 이미지 개수- Diatoma vulgaris: 7개
클래스별 이미지 개수- Achnanthes longboardia: 7개
클래스별 이미지 개수- Chamaepinnularia mediocris: 7개
클래스별 이미지 개수- Achnanthes undulorostrata: 7개
클래스별 이미지 개수- Caloneis fusus: 7개
클래스별 이미지 개수- Anomoeoneis sphaerophora f. rostrata: 7개
클래스별 이미지 개수- Anomoeoneis monoensis: 7개
클래스별 이미지 개수- Achnanthes mauiensis: 7개
클래스별 이미지 개수- Anomoeoneis sculpta: 7개
클래스별 이미지 개수- Diatoma problematica: 7개
클래스별 이미지 개수- Anomoeoneis sphaerophora: 7개
클래스

In [None]:
# Assign every class name an integer id following alphabetical order, and keep
# the inverse mapping for turning predictions back into names.
sorted_labels = sorted(set(labels))
int_to_label = dict(enumerate(sorted_labels))
label_to_int = {name: idx for idx, name in int_to_label.items()}

int_labels = [label_to_int[name] for name in labels]

# Hold out 20% for validation; stratifying on the class ids keeps the class
# proportions similar in both subsets, and the fixed seed makes the split repeatable.
split_options = dict(test_size=0.2, random_state=42, stratify=int_labels)
train_paths, val_paths, train_labels, val_labels = train_test_split(
    all_image_paths, int_labels, **split_options
)

print(f"학습용 데이터 개수: {len(train_paths)}개")
print(f"검증용 데이터 개수: {len(val_paths)}개")

학습용 데이터 개수: 211개
검증용 데이터 개수: 53개


In [None]:
# Settings shared by the training and validation pipelines.
# NOTE(review): the mean/std values look like the usual ImageNet channel
# statistics expected by the pretrained backbone — confirm.
IMAGE_SIZE = (224, 224)
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]


def _tensor_steps():
    """Return fresh ToTensor + Normalize steps that end both pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ]


# Training: resize, then random flip / rotation augmentation, then tensor + normalize.
transform_train = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    *_tensor_steps(),
])

# Validation: same preprocessing without any random augmentation.
transform_val = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    *_tensor_steps(),
])

class DiatomDataset(Dataset):
    """Dataset that loads images from file paths and pairs them with labels.

    Args:
        file_paths: Sequence of image file paths.
        labels: Sequence of labels, aligned index-by-index with ``file_paths``.
        transform: Optional callable applied to the RGB PIL image before it
            is returned.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per file path)."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``."""
        # Open inside a context manager so the underlying file handle is closed
        # deterministically; convert() returns a new, fully loaded RGB image that
        # stays valid after the source file is closed.
        with Image.open(self.file_paths[idx]) as source:
            image = source.convert("RGB")
        label = self.labels[idx]
        # Compare against None explicitly so a transform object that happens to
        # be falsy is still applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, label

BATCH_SIZE = 16

# Build both datasets first, each with its own transform pipeline.
train_dataset = DiatomDataset(train_paths, train_labels, transform=transform_train)
val_dataset = DiatomDataset(val_paths, val_labels, transform=transform_val)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Load a ResNet-18 with ImageNet weights. The `pretrained=True` flag is deprecated
# in torchvision; the explicit weights enum selects the same checkpoint.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Freeze every parameter, then re-enable gradients for the last residual stage
# (layer4) so only it and the new head are fine-tuned.
model.requires_grad_(False)
model.layer4.requires_grad_(True)

# Replace the classification head with one output per class in this dataset.
# A freshly constructed nn.Linear is trainable by default.
num_features = model.fc.in_features
num_classes = len(sorted_labels)
model.fc = nn.Linear(num_features, num_classes)

model.to(device)

print("수정된 마지막 레이어")
print(model.fc)

수정된 마지막 레이어
Linear(in_features=512, out_features=38, bias=True)


In [None]:
criterion = nn.CrossEntropyLoss()

# Separate learning rates per parameter group.
HEAD_LR = 0.001       # the new head (fc) is trained comparatively fast
BACKBONE_LR = 0.0001  # the backbone stage (layer4) is trained very cautiously
param_groups = [
    {'params': model.fc.parameters(), 'lr': HEAD_LR},
    {'params': model.layer4.parameters(), 'lr': BACKBONE_LR},
]
optimizer = optim.Adam(param_groups, lr=0.001)

NUM_EPOCHS = 25

# Train for NUM_EPOCHS epochs, reporting the mean training loss and the
# validation accuracy after each epoch.
for epoch in range(NUM_EPOCHS):
    # Training mode
    model.train()
    running_loss = 0.0
    # The batch labels are bound to `targets`, not `labels`, so the loop does not
    # overwrite the notebook-level `labels` list that the label-encoding cell reads.
    for images, targets in train_loader:
        images, targets = images.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # The loss is a per-batch mean (criterion uses the default reduction), so
        # weight it by the batch size to get a correct per-sample epoch average
        # even when the last batch is smaller.
        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)

    # Evaluation mode
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.to(device), targets.to(device)
            outputs = model(images)
            # argmax over the class dimension replaces the legacy
            # `torch.max(outputs.data, 1)` idiom; gradients are already disabled.
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    epoch_acc = 100 * correct / total

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] | Loss: {epoch_loss:.4f} | Val Accuracy: {epoch_acc:.2f}%")

print("\n 학습 완료!")

Epoch [1/25] | Loss: 3.8107 | Val Accuracy: 24.53%
Epoch [2/25] | Loss: 2.4566 | Val Accuracy: 33.96%
Epoch [3/25] | Loss: 1.7150 | Val Accuracy: 49.06%
Epoch [4/25] | Loss: 1.0654 | Val Accuracy: 60.38%
Epoch [5/25] | Loss: 0.7992 | Val Accuracy: 67.92%
Epoch [6/25] | Loss: 0.5499 | Val Accuracy: 66.04%
Epoch [7/25] | Loss: 0.3958 | Val Accuracy: 71.70%
Epoch [8/25] | Loss: 0.2626 | Val Accuracy: 71.70%
Epoch [9/25] | Loss: 0.2716 | Val Accuracy: 75.47%
Epoch [10/25] | Loss: 0.2688 | Val Accuracy: 73.58%
Epoch [11/25] | Loss: 0.1668 | Val Accuracy: 81.13%
Epoch [12/25] | Loss: 0.1506 | Val Accuracy: 75.47%
Epoch [13/25] | Loss: 0.1339 | Val Accuracy: 73.58%
Epoch [14/25] | Loss: 0.1140 | Val Accuracy: 73.58%
Epoch [15/25] | Loss: 0.0988 | Val Accuracy: 73.58%
Epoch [16/25] | Loss: 0.1014 | Val Accuracy: 77.36%
Epoch [17/25] | Loss: 0.0976 | Val Accuracy: 77.36%
Epoch [18/25] | Loss: 0.0676 | Val Accuracy: 73.58%
Epoch [19/25] | Loss: 0.0693 | Val Accuracy: 77.36%
Epoch [20/25] | Loss: