In [1]:
import zipfile
import os
import re
from collections import Counter
from sklearn.model_selection import train_test_split
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models
import torch.optim as optim

In [2]:
from google.colab import files
uploaded = files.upload()

KeyboardInterrupt: 

In [3]:
zip_file_name = 'diatom_dataset.zip'

extract_path = '/content/dataset'

zip_file_path = f'/content/{zip_file_name}'

if not os.path.exists(extract_path):
    os.makedirs(extract_path)

with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
  zip_ref.extractall(extract_path)
print("압축 해제 완료")

압축 해제 완료


In [4]:
!rm -rf /content/dataset/__MACOSX # Mac 압축시 생기는 메타데이터 제거

In [5]:
DATASET_PATH = '/content/dataset/'

all_image_paths = []
for root, _, files in os.walk(DATASET_PATH):
    for file in files:
        if file.lower().endswith(('.jpg', '.jpeg')):
            all_image_paths.append(os.path.join(root, file))

# labeling
def extract_label_from_filename(filepath):
    filename = os.path.basename(filepath)
    label = filename.split('_')[0]
    return label

labels = [extract_label_from_filename(p) for p in all_image_paths]

print(f"이미지 파일: {len(all_image_paths)}")
print(f"클래수(종): {len(set(labels))}")
print("\n")
for label, count in Counter(labels).most_common():
    print(f"클래스별 이미지 개수- {label}: {count}개")


이미지 파일: 674
클래수(종): 94


클래스별 이미지 개수- Fragilariforma nitzschioides: 12개
클래스별 이미지 개수- Diatoma tenuis: 11개
클래스별 이미지 개수- Fragilaria recapitellata: 11개
클래스별 이미지 개수- Achnanthes inflata: 10개
클래스별 이미지 개수- Geissleria cascadensis: 10개
클래스별 이미지 개수- Fragilaria crotonensis: 10개
클래스별 이미지 개수- Lindavia intermedia: 10개
클래스별 이미지 개수- Fragilaria synegrotesca: 10개
클래스별 이미지 개수- Chamaepinnularia gandrupii: 9개
클래스별 이미지 개수- Fragilariforma bicapitata: 9개
클래스별 이미지 개수- Geissleria acceptata: 9개
클래스별 이미지 개수- Chamaepinnularia krasskei: 9개
클래스별 이미지 개수- Geissleria decussis: 9개
클래스별 이미지 개수- Fragilariforma virescens: 9개
클래스별 이미지 개수- Fragilaria vaucheriae: 9개
클래스별 이미지 개수- Epithemia sorex: 8개
클래스별 이미지 개수- Geissleria kriegeri: 8개
클래스별 이미지 개수- Amphora copulata: 8개
클래스별 이미지 개수- Epithemia adnata: 8개
클래스별 이미지 개수- Lindavia antiqua: 8개
클래스별 이미지 개수- Chamaepinnularia krookii: 8개
클래스별 이미지 개수- Amphora bicapitata: 8개
클래스별 이미지 개수- Lindavia praetermissa: 8개
클래스별 이미지 개수- Chamaepinnularia witkowskii: 8개
클래스별 이미지 개수- Caloneis lewisii: 8개

In [6]:
sorted_labels = sorted(list(set(labels)))
label_to_int = {label: i for i, label in enumerate(sorted_labels)}
int_to_label = {i: label for i, label in enumerate(sorted_labels)}

int_labels = [label_to_int[label] for label in labels]


train_paths, val_paths, train_labels, val_labels = train_test_split(
    all_image_paths,
    int_labels,
    test_size=0.2,
    random_state=42,
    stratify=int_labels
)

print(f"학습용 데이터 개수: {len(train_paths)}개")
print(f"검증용 데이터 개수: {len(val_paths)}개")

학습용 데이터 개수: 539개
검증용 데이터 개수: 135개


In [13]:
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

class DiatomDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        image = Image.open(self.file_paths[idx]).convert("RGB")
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)

        return image, label

train_dataset = DiatomDataset(train_paths, train_labels, transform=transform_train)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

val_dataset = DiatomDataset(val_paths, val_labels, transform=transform_val)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [18]:
model = models.resnet18(pretrained=True)

num_features = model.fc.in_features
num_classes = len(sorted_labels)
model.fc = nn.Linear(num_features, num_classes)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [19]:
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=0.001)

NUM_EPOCHS = 20

for epoch in range(NUM_EPOCHS):
    # 훈련 모드
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)

    # 평가 모드
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    epoch_acc = 100 * correct / total

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] | Loss: {epoch_loss:.4f} | Val Accuracy: {epoch_acc:.2f}%")

print("\n 학습 완료!")

Epoch [1/20] | Loss: 4.7410 | Val Accuracy: 8.89%
Epoch [2/20] | Loss: 3.2234 | Val Accuracy: 20.74%
Epoch [3/20] | Loss: 2.4365 | Val Accuracy: 38.52%
Epoch [4/20] | Loss: 1.8568 | Val Accuracy: 41.48%
Epoch [5/20] | Loss: 1.5480 | Val Accuracy: 31.85%
Epoch [6/20] | Loss: 1.2677 | Val Accuracy: 47.41%
Epoch [7/20] | Loss: 0.9869 | Val Accuracy: 50.37%
Epoch [8/20] | Loss: 0.7672 | Val Accuracy: 54.07%
Epoch [9/20] | Loss: 0.6593 | Val Accuracy: 55.56%
Epoch [10/20] | Loss: 0.5814 | Val Accuracy: 59.26%
Epoch [11/20] | Loss: 0.5198 | Val Accuracy: 56.30%
Epoch [12/20] | Loss: 0.4251 | Val Accuracy: 62.22%
Epoch [13/20] | Loss: 0.3270 | Val Accuracy: 68.89%
Epoch [14/20] | Loss: 0.2654 | Val Accuracy: 67.41%
Epoch [15/20] | Loss: 0.3077 | Val Accuracy: 65.19%
Epoch [16/20] | Loss: 0.2534 | Val Accuracy: 68.89%
Epoch [17/20] | Loss: 0.2277 | Val Accuracy: 62.96%
Epoch [18/20] | Loss: 0.2203 | Val Accuracy: 65.93%
Epoch [19/20] | Loss: 0.1812 | Val Accuracy: 68.15%
Epoch [20/20] | Loss: 