In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import zipfile
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

In [2]:
from google.colab import files
uploaded = files.upload()

In [5]:
zip_file_name = 'diatom_dataset.zip'

extract_path = '/content/dataset'

zip_file_path = f'/content/{zip_file_name}'

if not os.path.exists(extract_path):
    os.makedirs(extract_path)

with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
  zip_ref.extractall(extract_path)
print("압축 해제 완료")

압축 해제 완료


In [6]:
!rm -rf /content/dataset/__MACOSX

In [7]:
DATASET_PATH = '/content/dataset/diatom_dataset'

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

full_dataset = datasets.ImageFolder(root=DATASET_PATH, transform=transform)

print(f"전체 이미지 개수: {len(full_dataset)}")
print(f"전체 클래스: {len(full_dataset.classes)}개")

전체 이미지 개수: 687
전체 클래스: 62개


In [8]:
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [9]:
class DiatomClassifier(nn.Module):
    def __init__(self, num_classes):
        super(DiatomClassifier, self).__init__()
        self.backbone = models.resnet18(pretrained=True)
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.backbone(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = len(full_dataset.classes)
model = DiatomClassifier(num_classes).to(device)



Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 198MB/s]


In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [11]:
def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for images, labels in tqdm(loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    return total_loss / len(loader), 100 * correct / total


def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return total_loss / len(loader), 100 * correct / total


EPOCHS = 15
for epoch in range(EPOCHS):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"[{epoch+1}/{EPOCHS}] Train Acc: {train_acc:.2f}% | Test Acc: {test_acc:.2f}%")


100%|██████████| 18/18 [00:09<00:00,  1.86it/s]


[1/15] Train Acc: 23.32% | Test Acc: 42.75%


100%|██████████| 18/18 [00:07<00:00,  2.26it/s]


[2/15] Train Acc: 72.86% | Test Acc: 52.90%


100%|██████████| 18/18 [00:07<00:00,  2.46it/s]


[3/15] Train Acc: 93.08% | Test Acc: 68.12%


100%|██████████| 18/18 [00:07<00:00,  2.42it/s]


[4/15] Train Acc: 98.18% | Test Acc: 72.46%


100%|██████████| 18/18 [00:07<00:00,  2.28it/s]


[5/15] Train Acc: 99.82% | Test Acc: 73.91%


100%|██████████| 18/18 [00:07<00:00,  2.27it/s]


[6/15] Train Acc: 100.00% | Test Acc: 76.09%


100%|██████████| 18/18 [00:08<00:00,  2.17it/s]


[7/15] Train Acc: 99.64% | Test Acc: 76.09%


100%|██████████| 18/18 [00:07<00:00,  2.45it/s]


[8/15] Train Acc: 100.00% | Test Acc: 76.81%


100%|██████████| 18/18 [00:08<00:00,  2.24it/s]


[9/15] Train Acc: 100.00% | Test Acc: 76.09%


100%|██████████| 18/18 [00:08<00:00,  2.20it/s]


[10/15] Train Acc: 100.00% | Test Acc: 75.36%


100%|██████████| 18/18 [00:07<00:00,  2.26it/s]


[11/15] Train Acc: 100.00% | Test Acc: 76.81%


100%|██████████| 18/18 [00:07<00:00,  2.44it/s]


[12/15] Train Acc: 100.00% | Test Acc: 77.54%


100%|██████████| 18/18 [00:07<00:00,  2.40it/s]


[13/15] Train Acc: 100.00% | Test Acc: 77.54%


100%|██████████| 18/18 [00:08<00:00,  2.24it/s]


[14/15] Train Acc: 100.00% | Test Acc: 76.81%


100%|██████████| 18/18 [00:07<00:00,  2.25it/s]


[15/15] Train Acc: 100.00% | Test Acc: 77.54%


In [13]:
from collections import defaultdict

# 예측 결과 수집
y_true, y_pred = [], []
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = torch.argmax(outputs, dim=1)
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

# 종 이름 → 속 이름 매핑
def get_genus(species_name):
    return species_name.split(' ')[0]

species_to_genus = {i: get_genus(name) for i, name in enumerate(full_dataset.classes)}
genus_correct, genus_total = defaultdict(int), defaultdict(int)

for i in range(len(y_true)):
    genus = species_to_genus[y_true[i]]
    genus_total[genus] += 1
    if y_true[i] == y_pred[i]:
        genus_correct[genus] += 1

print("\n[속별 정확도]")
for genus in genus_total:
    acc = 100 * genus_correct[genus] / genus_total[genus]
    print(f"{genus}: {acc:.2f}%")


[속별 정확도]
Achnanthes: 75.00%
Fragilariforma: 25.00%
Fragilaria: 52.94%
Epithemia: 92.86%
Geissleria: 85.71%
Caloneis: 100.00%
Anomoeoneis: 75.00%
Amphora: 60.00%
Chamaepinnularia: 100.00%


In [15]:
from collections import Counter
import os

base_path = "/content/dataset/diatom_dataset"
genus_counts = Counter()

for folder in os.listdir(base_path):
    if not os.path.isdir(os.path.join(base_path, folder)):
        continue
    genus = folder.split(' ')[0]  # 띄어쓰기 기준
    genus_counts[genus] += len(os.listdir(os.path.join(base_path, folder)))

print("\n[속별 이미지 개수]")
for g, c in genus_counts.items():
    print(f"{g}: {c}")



[속별 이미지 개수]
Amphora: 51
Fragilariforma: 63
Geissleria: 61
Fragilaria: 81
Epithemia: 77
Anomoeoneis: 46
Achnanthes: 50
Caloneis: 49
Chamaepinnularia: 8
