In [14]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import os, glob, shutil
import torch, torch.nn as nn, torch.nn.functional as F
import cv2 as cv, numpy as np
from collections import deque

In [9]:
# 당신의 데이터 최상위 폴더(여기 아래 어딘가에 클래스 폴더 6개가 있음)
RAW_ROOT = r"trashnet/data/dataset-resized"

# 1) 맥 압축 잔여물 제거
def clean_mac_artifacts(root: str):
    for d in glob.glob(os.path.join(root, "**", "__MACOSX"), recursive=True):
        shutil.rmtree(d, ignore_errors=True)
    for p in glob.glob(os.path.join(root, "**", "._*"), recursive=True):
        try: os.remove(p)
        except: pass
    for p in glob.glob(os.path.join(root, "**", ".DS_Store"), recursive=True):
        try: os.remove(p)
        except: pass

clean_mac_artifacts(RAW_ROOT)

# 2) 클래스 폴더 깊이 자동 탐지 (cardboard/glass/metal/paper/plastic/trash 찾기)
def find_class_root(root: str):
    expected = {"cardboard","glass","metal","paper","plastic","trash"}
    names = {n for n in os.listdir(root) if os.path.isdir(os.path.join(root,n))}
    # 현재 깊이에 클래스 폴더가 보이면 OK
    if len(expected & names) >= 4:
        return root
    # 한 단계 더 들어가서 확인
    for n in names:
        sub = os.path.join(root, n)
        if not os.path.isdir(sub): 
            continue
        sub_names = {m for m in os.listdir(sub) if os.path.isdir(os.path.join(sub,m))}
        if len(expected & sub_names) >= 4:
            return sub
    return root  # 못 찾으면 원래 root 반환(추후 프린트로 확인)

DATA_DIR = find_class_root(RAW_ROOT)
print("사용할 DATA_DIR:", os.path.abspath(DATA_DIR))
print("하위 폴더 샘플:", os.listdir(DATA_DIR)[:10])

사용할 DATA_DIR: c:\Users\jk316\practice\trashnet\data\dataset-resized\dataset-resized
하위 폴더 샘플: ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']


In [10]:
valid_ext = {".jpg",".jpeg",".png",".bmp"}
def is_ok(p: str):
    name = os.path.basename(p)
    ext  = os.path.splitext(name)[1].lower()
    return (ext in valid_ext) and (not name.startswith("._")) and ("__MACOSX" not in p)

IMG_SIZE = 128

tf_train = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2,0.2,0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])
tf_eval = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

full_ds = datasets.ImageFolder(DATA_DIR, transform=tf_train, is_valid_file=is_ok)
print("클래스:", full_ds.classes, "총 이미지:", len(full_ds))

# 무조건 6개 클래스가 떠야 정상!
assert set(full_ds.classes) >= {"cardboard","glass","metal","paper","plastic","trash"}, "클래스 폴더 인식 실패"

n = len(full_ds)
n_tr = int(0.8*n)
train_ds, val_ds = random_split(full_ds, [n_tr, n-n_tr])
val_ds.dataset.transform = tf_eval  # 검증은 증강 제거

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_ds,   batch_size=32, shuffle=False, num_workers=0)


클래스: ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash'] 총 이미지: 2528


In [11]:
class TrashClassifier(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3,16,3,padding=1), nn.ReLU(), nn.MaxPool2d(2),  # 128→64
            nn.Conv2d(16,32,3,padding=1), nn.ReLU(), nn.MaxPool2d(2), # 64→32
            nn.Conv2d(32,64,3,padding=1), nn.ReLU(), nn.MaxPool2d(2), # 32→16
        )
        self.gap = nn.AdaptiveAvgPool2d((1,1))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64,128), nn.ReLU(),
            nn.Linear(128, num_classes)
        )
    def forward(self, x):
        x = self.features(x)
        x = self.gap(x)
        return self.classifier(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = TrashClassifier(num_classes=len(full_ds.classes)).to(device)


In [12]:
EPOCHS = 20
LR = 1e-3

opt  = torch.optim.Adam(model.parameters(), lr=LR)
crit = nn.CrossEntropyLoss()

# 과거 잘못된 산출물 제거
for f in ["model.pth","classes.txt"]:
    if os.path.exists(f):
        os.remove(f)

best_val = 0.0
for epoch in range(1, EPOCHS+1):
    model.train()
    run_loss = 0.0
    for x,y in train_loader:
        x,y = x.to(device), y.to(device)
        opt.zero_grad()
        out = model(x)
        loss = crit(out, y)
        loss.backward()
        opt.step()
        run_loss += loss.item()

    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x,y in val_loader:
            x,y = x.to(device), y.to(device)
            pred = model(x).argmax(1)
            total += y.size(0)
            correct += (pred==y).sum().item()
    val_acc = correct / max(1,total)
    print(f"[{epoch:02d}] loss={run_loss/len(train_loader):.4f}  val_acc={val_acc*100:.2f}%")

    if val_acc > best_val:
        best_val = val_acc
        torch.save(model.state_dict(), "model.pth")
        with open("classes.txt","w",encoding="utf-8") as f:
            f.write("\n".join(full_ds.classes))
        print("✔ 저장: model.pth / classes.txt")

print("최고 검증 정확도:", f"{best_val*100:.2f}%")


[01] loss=1.7096  val_acc=31.62%
✔ 저장: model.pth / classes.txt
[02] loss=1.5023  val_acc=35.97%
✔ 저장: model.pth / classes.txt
[03] loss=1.4326  val_acc=37.35%
✔ 저장: model.pth / classes.txt
[04] loss=1.3610  val_acc=46.64%
✔ 저장: model.pth / classes.txt
[05] loss=1.3009  val_acc=50.40%
✔ 저장: model.pth / classes.txt
[06] loss=1.2404  val_acc=51.58%
✔ 저장: model.pth / classes.txt
[07] loss=1.1975  val_acc=54.35%
✔ 저장: model.pth / classes.txt
[08] loss=1.1676  val_acc=52.57%
[09] loss=1.1140  val_acc=56.72%
✔ 저장: model.pth / classes.txt
[10] loss=1.1055  val_acc=60.08%
✔ 저장: model.pth / classes.txt
[11] loss=1.0747  val_acc=58.70%
[12] loss=1.0457  val_acc=59.09%
[13] loss=1.0469  val_acc=59.29%
[14] loss=1.0358  val_acc=62.25%
✔ 저장: model.pth / classes.txt
[15] loss=0.9819  val_acc=62.45%
✔ 저장: model.pth / classes.txt
[16] loss=0.9994  val_acc=63.44%
✔ 저장: model.pth / classes.txt
[17] loss=0.9646  val_acc=64.43%
✔ 저장: model.pth / classes.txt
[18] loss=0.9601  val_acc=65.22%
✔ 저장: model.pth 

In [15]:
with open("classes.txt","r",encoding="utf-8") as f:
    CLASSES = [l.strip() for l in f if l.strip()]

# 추론 전용 모델 로드(동일 구조)
infer_model = TrashClassifier(num_classes=len(CLASSES)).to(device)
infer_model.load_state_dict(torch.load("model.pth", map_location=device))
infer_model.eval()

def preprocess(bgr):
    rgb = cv.cvtColor(bgr, cv.COLOR_BGR2RGB)
    rgb = cv.resize(rgb, (IMG_SIZE,IMG_SIZE), interpolation=cv.INTER_AREA)
    x = rgb.astype(np.float32)/255.0
    x = (x-0.5)/0.5
    x = np.transpose(x, (2,0,1))
    return torch.from_numpy(x).unsqueeze(0).to(device)

cap = cv.VideoCapture(0)
assert cap.isOpened(), "웹캠을 열 수 없습니다. 인덱스를 1,2 등으로 바꿔보세요."

smooth = deque(maxlen=5)  # 흔들림 완화

while True:
    ok, frame = cap.read()
    if not ok: break
    with torch.no_grad():
        p = torch.softmax(infer_model(preprocess(frame)), dim=1).cpu().numpy()[0]
    smooth.append(p)
    avg = np.mean(smooth, axis=0)
    idx = int(np.argmax(avg))
    label, conf = CLASSES[idx], float(avg[idx])
    cv.putText(frame, f"{label} {conf*100:.1f}%", (10,40),
                cv.FONT_HERSHEY_SIMPLEX, 1.2, (0,255,0), 3)
    cv.imshow("Trash Classifier", frame)
    if cv.waitKey(1) == ord('q'): 
        break

cap.release()
cv.destroyAllWindows()


  infer_model.load_state_dict(torch.load("model.pth", map_location=device))
