In [2]:
from pathlib import Path
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from data_loader import build_loaders
from train_ddp   import LitTimm, seed_everything

IMG_SIZE = 224
BATCH_SIZE = 512
MODEL_NAME = "vit_xsmall_patch16_clip_224"
CKPT_PATH  = Path("model_weights/0423_1455_epoch=39_val_f1=0.8880.ckpt")   # ★수정★

seed_everything(777)

In [4]:
_, val_loader, test_loader, le = build_loaders(
    get_test=True,
    data_dir="./data/train",
    test_size=0.2,
    img_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=0, # Window + 노트북 환경에서는 0으로
    prefetch_factor=None, # Window + 노트북 환경에서는 None으로
    seed=777,
    mean=(0.38771973, 0.39787053, 0.40713646),
    std =(0.2130759 , 0.21581389, 0.22090413),
)

NUM_CLASSES = len(le.classes_)
print("✓ classes:", le.classes_)

✓ classes: ['Andesite' 'Basalt' 'Etc' 'Gneiss' 'Granite' 'Mud_Sandstone'
 'Weathered_Rock']


In [5]:
ckpt = torch.load(CKPT_PATH, map_location="cpu")
state = ckpt["state_dict"]

model = LitTimm(model_name=MODEL_NAME, num_classes=NUM_CLASSES,
                lr=1e-4, class_weights=None)          # lr/weights는 의미 X
model.load_state_dict(ckpt["state_dict"], strict=True)
model.eval().cuda()
print("✓ checkpoint loaded")

✓ checkpoint loaded


In [10]:
# ====================== 4) Validation 성능 계산 ======================
from sklearn.metrics import f1_score
import numpy as np, torch
from tqdm.auto import tqdm

def validation(model, criterion, val_loader, device):
    model.eval()
    val_loss = []
    preds, true_labels = [], []

    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc="Val"):
            imgs = imgs.float().to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            logits = model(imgs)
            loss   = criterion(logits, labels)

            preds += logits.argmax(1).cpu().numpy().tolist()
            true_labels += labels.cpu().numpy().tolist()
            val_loss.append(loss.item())

    _val_loss  = np.mean(val_loss)
    _val_score = f1_score(true_labels, preds, average='macro')
    return _val_loss, _val_score


device = torch.device("cuda")
criterion = torch.nn.CrossEntropyLoss().to(device)

val_loss, val_f1 = validation(model, criterion, val_loader, device)
print(f"⚡ Validation  |  loss: {val_loss:.4f}  |  Macro-F1: {val_f1:.4f}")

Val:   0%|          | 0/149 [00:00<?, ?it/s]

⚡ Validation  |  loss: 0.0791  |  Macro-F1: 0.9698


In [6]:
pred_labels = []
pred_probs = []

with torch.no_grad():
    for imgs in tqdm(test_loader, desc="Test"):  # imgs, _  → imgs
        imgs = imgs.cuda(non_blocking=True)
        logits = model(imgs)
        probs = torch.softmax(logits, dim=1)
        pred_labels.extend(logits.argmax(1).cpu().numpy())
        pred_probs.extend(probs.cpu().numpy())

# 숫자 → 클래스명 
pred_classes = le.inverse_transform(np.array(pred_labels))
pred_probs = np.array(pred_probs)

submission = pd.read_csv('./data/sample_submission.csv')
submission['rock_type'] = pred_classes

Test:   0%|          | 0/186 [00:00<?, ?it/s]


Prediction probabilities for first few samples:

Sample 0:
Andesite: 0.0002
Basalt: 0.0000
Etc: 0.0000
Gneiss: 0.0000
Granite: 0.0000
Mud_Sandstone: 0.9998
Weathered_Rock: 0.0000

Sample 1:
Andesite: 0.0013
Basalt: 0.0099
Etc: 0.0004
Gneiss: 0.0001
Granite: 0.0000
Mud_Sandstone: 0.9882
Weathered_Rock: 0.0000

Sample 2:
Andesite: 0.0001
Basalt: 0.0000
Etc: 0.0000
Gneiss: 0.0000
Granite: 0.0000
Mud_Sandstone: 0.9999
Weathered_Rock: 0.0000

Sample 3:
Andesite: 0.0000
Basalt: 0.0000
Etc: 0.0016
Gneiss: 0.0000
Granite: 0.9984
Mud_Sandstone: 0.0000
Weathered_Rock: 0.0000

Sample 4:
Andesite: 0.0000
Basalt: 0.0000
Etc: 0.0000
Gneiss: 0.0000
Granite: 0.9999
Mud_Sandstone: 0.0000
Weathered_Rock: 0.0000


In [7]:
SUBMIT_NAME = "./Submission/submission.csv"
submission.to_csv(SUBMIT_NAME, index=False)
print("🎉 saved:", SUBMIT_NAME)

🎉 saved: ./Submission/submission.csv


In [9]:
probs = torch.softmax(logits, dim=1).cpu().numpy()
probs

array([[1.51008744e-05, 1.54826157e-05, 2.39382473e-07, ...,
        2.99387084e-05, 6.71633679e-05, 4.03645208e-06],
       [7.25029804e-06, 1.71697284e-05, 2.72593006e-05, ...,
        1.25071565e-05, 3.11056239e-04, 4.43403860e-06],
       [5.66039489e-07, 1.33902927e-06, 1.66006830e-05, ...,
        2.50340581e-05, 8.90584197e-05, 1.33376557e-06],
       ...,
       [3.19618353e-04, 1.42469522e-04, 4.01976079e-01, ...,
        4.17127414e-03, 1.23140035e-05, 3.67581117e-04],
       [4.73372756e-06, 1.96558426e-06, 1.62325728e-06, ...,
        4.18798882e-05, 1.72887048e-05, 3.18055868e-07],
       [7.17210128e-07, 4.11534211e-06, 1.11504574e-03, ...,
        3.98348639e-05, 4.46218677e-04, 1.13047599e-05]],
      shape=(286, 7), dtype=float32)

In [8]:
submission

Unnamed: 0,ID,rock_type
0,TEST_00000,Mud_Sandstone
1,TEST_00001,Mud_Sandstone
2,TEST_00002,Mud_Sandstone
3,TEST_00003,Granite
4,TEST_00004,Granite
...,...,...
95001,TEST_95001,Gneiss
95002,TEST_95002,Gneiss
95003,TEST_95003,Gneiss
95004,TEST_95004,Gneiss
