# Train CoAtNet (ImageFolder)


In [18]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms

from feature_extraction import CoAtNet


In [None]:
# Config
# Dataset root passed to ImageFolder; its existence is checked before training.
# NOTE(review): the recorded error output and the image sanity-check cell in this
# notebook reference "out/1_AULA_F87_Pro" (no "test/") — confirm which path is current.
DATA_ROOT = "out/test/1_AULA_F87_Pro"
# Fraction of the dataset held out for validation (val_size = int(len(dataset) * VAL_RATIO)).
VAL_RATIO = 0.1
# Batch size used by both the training and validation DataLoaders.
BATCH_SIZE = 32
# Number of passes over the training loader.
EPOCHS = 20
# Learning rate given to the Adam optimizer.
LR = 1e-4


In [20]:
# Fail fast when the dataset directory is missing.
if not os.path.isdir(DATA_ROOT):
    raise FileNotFoundError(f"Dataset not found: {DATA_ROOT}")

# Choose the compute device: MPS (GPU backend on macOS) is checked first,
# then CUDA, and CPU is the fallback.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(device)



FileNotFoundError: Dataset not found: out/1_AULA_F87_Pro

In [None]:
# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalize each channel with the given mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Each class must live in its own sub-folder; the folder name is used as the label.
dataset = ImageFolder(DATA_ROOT, transform=transform)
num_classes = len(dataset.classes)
print(f"Classes: {num_classes} -> {dataset.classes}")


Classes: 46 -> [',', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'Backspace', 'Ctrl', 'Enter', 'Shift', 'Spacebar', 'a', 'b', 'c', 'd', 'e', 'endpoint', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'minus', 'n', 'o', 'p', 'plus', 'q', 'questionmark', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [None]:
# Hold out VAL_RATIO of the samples for validation; everything else is training data.
val_size = int(len(dataset) * VAL_RATIO)
train_size = len(dataset) - val_size

# Seed the split so the same samples end up in the validation set on every run.
# With the default (unseeded) generator, re-running this cell reshuffles the
# split: validation accuracies are then not comparable between runs, and
# samples previously used for training can reappear in validation.
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Only the training loader shuffles; validation order does not affect accuracy.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


In [None]:
# Feature extractor plus a linear head mapping 768 features to class logits.
backbone = CoAtNet().to(device)
classifier = nn.Linear(768, num_classes).to(device)

# A single Adam optimizer updates the backbone and the head together.
optimizer = optim.Adam(
    [*backbone.parameters(), *classifier.parameters()],
    lr=LR,
)
criterion = nn.CrossEntropyLoss()


In [None]:
# Train for EPOCHS epochs, validate after each one, and checkpoint the weights
# whenever validation accuracy improves.
#
# best_val starts below any reachable accuracy so the first epoch always
# writes a checkpoint. Starting at 0.0 with a strict ">" comparison meant that
# a run whose validation accuracy stayed at 0 (e.g. an empty validation split)
# never saved anything, while the final message still claimed it had.
best_val = -1.0

for epoch in range(1, EPOCHS + 1):
    backbone.train()
    classifier.train()
    total_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        feats = backbone(images)  # [B, 768]
        logits = classifier(feats)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Validation: eval mode and no gradient tracking.
    backbone.eval()
    classifier.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            feats = backbone(images)
            logits = classifier(feats)
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # An empty validation set is reported as accuracy 0 instead of dividing by zero.
    val_acc = correct / total if total else 0
    print(f"Epoch {epoch:02d} | Loss: {total_loss/len(train_loader):.4f} | Val Acc: {val_acc:.4f}")

    if val_acc > best_val:
        best_val = val_acc
        # Store the class names with the weights so predictions can be
        # mapped back to labels when the checkpoint is loaded.
        torch.save({
            "backbone": backbone.state_dict(),
            "classifier": classifier.state_dict(),
            "classes": dataset.classes
        }, "coatnet_classifier.pth")

# best_val only stays negative when no epoch ran (EPOCHS < 1).
if best_val < 0:
    print("Training done. No checkpoint was saved.")
else:
    print("Training done. Saved: coatnet_classifier.pth")


In [None]:
# --- Prediction on a few samples ---
import random
import torch
from PIL import Image
from torchvision.datasets import ImageFolder

# Rebuild the dataset with the same preprocessing pipeline.
dataset = ImageFolder(DATA_ROOT, transform=transform)

# Restore the saved checkpoint into freshly constructed modules; the head size
# comes from the class list stored in the checkpoint.
ckpt = torch.load('coatnet_classifier.pth', map_location=device)
backbone = CoAtNet().to(device)
classifier = nn.Linear(768, len(ckpt['classes'])).to(device)
backbone.load_state_dict(ckpt['backbone'])
classifier.load_state_dict(ckpt['classifier'])
backbone.eval()
classifier.eval()

# Draw up to eight distinct sample indices at random.
indices = random.sample(range(len(dataset)), k=min(8, len(dataset)))

with torch.no_grad():
    for idx in indices:
        img, label = dataset[idx]
        # Add a batch dimension, then run backbone and head in one pass.
        logits = classifier(backbone(img.unsqueeze(0).to(device)))
        pred = logits.argmax(dim=1).item()
        print(f"idx={idx} true={dataset.classes[label]} pred={ckpt['classes'][pred]}")


idx=2439 true=a pred=a
idx=1961 true=Shift pred=Shift
idx=4944 true=q pred=q
idx=3794 true=i pred=i
idx=1959 true=Shift pred=Shift
idx=5657 true=v pred=v
idx=906 true=6 pred=6
idx=4282 true=m pred=m


In [None]:
# --- Predict per-keystroke using existing mel images in out/test (top-k) ---
from pathlib import Path
from PIL import Image
import torch.nn.functional as F

TOPK = None  # set to None to print all classes

out_root = Path('out/test')
files = [
    'Hello World 2.wav',
    'Hello World.wav',
    'test.wav',
]

# For each recording, locate the folder of per-keystroke images named after
# the file stem and print the class probabilities for every keystroke image.
with torch.no_grad():
    for name in files:
        stem = Path(name).stem
        # Find folders anywhere under out/test whose name equals the stem.
        # Sorted so the chosen folder does not depend on filesystem traversal
        # order when more than one folder matches.
        candidates = sorted(
            p for p in out_root.rglob('*') if p.is_dir() and p.name == stem
        )
        if not candidates:
            print('Missing image folder for:', name)
            continue
        img_dir = candidates[0]

        imgs = sorted(img_dir.glob('keystroke_*.png'))
        if not imgs:
            print('No images in:', img_dir)
            continue

        print(f"== {name} ==")
        print('img_dir:', img_dir)
        if len(candidates) > 1:
            # Make an ambiguous match visible instead of silently picking one.
            print(f"note: {len(candidates)} folders named {stem!r}; using the first in sorted order")
        print("strokes:", len(imgs))

        for i, img_path in enumerate(imgs):
            # Context manager releases the file handle once pixels are converted.
            with Image.open(img_path) as raw:
                img = raw.convert('RGB')
            img_t = transform(img).unsqueeze(0).to(device)
            feat = backbone(img_t)
            logits = classifier(feat)
            probs = F.softmax(logits, dim=1).squeeze(0)
            # TOPK=None means "rank every class"; otherwise cap at the class count.
            k = probs.numel() if TOPK is None else min(TOPK, probs.numel())
            topk = torch.topk(probs, k=k)

            print(f"#{i:03d} img={img_path.name}")
            # Convert once to plain Python floats/ints for formatting and list indexing.
            for p, idx in zip(topk.values.tolist(), topk.indices.tolist()):
                print(f"  {ckpt['classes'][idx]}: {p:.4f}")


== Hello World 2.wav ==
img_dir: out/test/test/Hello World 2
strokes: 13
#000 img=keystroke_0000.png
  c: 0.6635
  a: 0.2281
  Spacebar: 0.0972
  b: 0.0079
  f: 0.0019
  Shift: 0.0011
  g: 0.0001
  plus: 0.0001
  w: 0.0000
  m: 0.0000
  questionmark: 0.0000
  Backspace: 0.0000
  d: 0.0000
  Enter: 0.0000
  endpoint: 0.0000
  e: 0.0000
  z: 0.0000
  5: 0.0000
  minus: 0.0000
  Ctrl: 0.0000
  2: 0.0000
  k: 0.0000
  i: 0.0000
  8: 0.0000
  j: 0.0000
  7: 0.0000
  x: 0.0000
  1: 0.0000
  l: 0.0000
  3: 0.0000
  p: 0.0000
  t: 0.0000
  9: 0.0000
  4: 0.0000
  s: 0.0000
  n: 0.0000
  0: 0.0000
  h: 0.0000
  ,: 0.0000
  6: 0.0000
  o: 0.0000
  v: 0.0000
  y: 0.0000
  r: 0.0000
  q: 0.0000
  u: 0.0000
#001 img=keystroke_0001.png
  a: 0.8300
  c: 0.1198
  Spacebar: 0.0357
  f: 0.0090
  b: 0.0046
  Shift: 0.0005
  w: 0.0002
  g: 0.0001
  d: 0.0000
  Backspace: 0.0000
  endpoint: 0.0000
  questionmark: 0.0000
  2: 0.0000
  plus: 0.0000
  m: 0.0000
  Enter: 0.0000
  3: 0.0000
  minus: 0.0000
  z:

In [None]:
from PIL import Image
import numpy as np

# Compare one training image with one test image: pixel value range and size.
train_img = Image.open('out/1_AULA_F87_Pro/0/keystroke_0000.png').convert('RGB')
test_img = Image.open('out/test/Hello World/keystroke_0000.png').convert('RGB')

# Convert each image to an array once and reuse it for both statistics.
train_px = np.array(train_img)
test_px = np.array(test_img)

print('train min/max:', train_px.min(), train_px.max())
print('test min/max:', test_px.min(), test_px.max())
print('train size:', train_img.size, 'test size:', test_img.size)


train min/max: 0 255
test min/max: 0 255
train size: (900, 600) test size: (900, 600)
