In [None]:
import torch

# Release cached GPU memory left over from earlier runs in this kernel.
# `torch.cuda.ipc_collect()` initialises CUDA lazily and raises on a machine/session
# without a usable GPU, so only touch the CUDA allocator when CUDA is available.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


: 

In [3]:
import platform

# Show which operating system the kernel is running on.
print(f"{platform.system()}")

Linux


In [4]:
import torch

# Report the CUDA version PyTorch was built with, or that CUDA is unavailable.
print(
    f"CUDA Version: {torch.version.cuda}"
    if torch.cuda.is_available()
    else "CUDA는 사용할 수 없습니다."
)

CUDA Version: 12.4


In [5]:
import os

folder_path = './open/train/Weathered_Rock/'

# Count every file below folder_path, descending into all sub-directories.
total_file_count = sum(len(file_names) for _, _, file_names in os.walk(folder_path))

print(f"전체 파일 개수 (서브폴더 포함): {total_file_count}")

전체 파일 개수 (서브폴더 포함): 33452


In [6]:
import timm
import torch

# List the ConvNeXt checkpoints for which timm ships pretrained weights.
# Whichever checkpoint is chosen, the input preprocessing (resolution, interpolation,
# normalisation) must match it. timm provides this through
# `timm.data.resolve_data_config` + `timm.data.create_transform`, which the training
# and inference cells further down use. Do not skip that step when using timm.
available_convnext_models = timm.list_models('*convnext*', pretrained=True)
print(available_convnext_models)

['convnext_atto.d2_in1k', 'convnext_atto_ols.a2_in1k', 'convnext_base.clip_laion2b', 'convnext_base.clip_laion2b_augreg', 'convnext_base.clip_laion2b_augreg_ft_in1k', 'convnext_base.clip_laion2b_augreg_ft_in12k', 'convnext_base.clip_laion2b_augreg_ft_in12k_in1k', 'convnext_base.clip_laion2b_augreg_ft_in12k_in1k_384', 'convnext_base.clip_laiona', 'convnext_base.clip_laiona_320', 'convnext_base.clip_laiona_augreg_320', 'convnext_base.clip_laiona_augreg_ft_in1k_384', 'convnext_base.fb_in1k', 'convnext_base.fb_in22k', 'convnext_base.fb_in22k_ft_in1k', 'convnext_base.fb_in22k_ft_in1k_384', 'convnext_femto.d1_in1k', 'convnext_femto_ols.d1_in1k', 'convnext_large.fb_in1k', 'convnext_large.fb_in22k', 'convnext_large.fb_in22k_ft_in1k', 'convnext_large.fb_in22k_ft_in1k_384', 'convnext_large_mlp.clip_laion2b_augreg', 'convnext_large_mlp.clip_laion2b_augreg_ft_in1k', 'convnext_large_mlp.clip_laion2b_augreg_ft_in1k_384', 'convnext_large_mlp.clip_laion2b_augreg_ft_in12k_384', 'convnext_large_mlp.clip

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# timm identifier of the pretrained ConvNeXt-Large checkpoint to fine-tune.
model_name = "convnext_large.fb_in22k_ft_in1k_384"
transfer_model_convnext = timm.create_model(model_name=model_name, pretrained=True)

In [8]:
# Device holding the model's first parameter (a freshly created timm model lives on CPU).
print(next(iter(transfer_model_convnext.parameters())).device)

cpu


In [9]:
# Display the module tree, e.g. to find the attribute path of the classifier head.
transfer_model_convnext

ConvNeXt(
  (stem): Sequential(
    (0): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
    (1): LayerNorm2d((192,), eps=1e-06, elementwise_affine=True)
  )
  (stages): Sequential(
    (0): ConvNeXtStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): ConvNeXtBlock(
          (conv_dw): Conv2d(192, 192, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=192)
          (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=192, out_features=768, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (norm): Identity()
            (fc2): Linear(in_features=768, out_features=192, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
          (shortcut): Identity()
          (drop_path): Identity()
        )
        (1): ConvNeXtBlock(
          (conv_dw): Conv2d(192, 192, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), g

In [10]:
# Replace the final linear layer of timm's NormMlpClassifierHead (`head.fc`, not
# `head` itself) with a freshly initialised layer sized for the rock classes.
num_classes = 7  # number of rock types
in_features = transfer_model_convnext.head.fc.in_features
transfer_model_convnext.head.fc = torch.nn.Linear(in_features=in_features, out_features=num_classes)

In [12]:
# Inspect the classifier head to confirm that fc now outputs num_classes logits.
transfer_model_convnext.head

NormMlpClassifierHead(
  (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Identity())
  (norm): LayerNorm2d((1536,), eps=1e-06, elementwise_affine=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (pre_logits): Identity()
  (drop): Dropout(p=0.0, inplace=False)
  (fc): Linear(in_features=1536, out_features=7, bias=True)
)

In [None]:
# convnext.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms # torchvision.transforms 사용
import torch.nn.functional as F # FocalLoss 등에서 필요
import timm
import torchmetrics
from tqdm import tqdm
import time
import os
import numpy as np # 필요시 사용

# --- 1. Focal loss ---
class FocalLoss(nn.Module):
    """Multi-class focal loss built on per-sample cross-entropy.

    loss_i = alpha[y_i] * (1 - p_i) ** gamma * CE_i, where p_i = exp(-CE_i) is the
    probability the model assigns to the true class y_i.

    Args:
        alpha: optional class weights. A Python float/int or a 0-d tensor is applied
            to every class; a list or 1-D tensor must have one entry per class.
        gamma: focusing exponent; ``gamma=0`` gives (alpha-weighted) cross-entropy.
        reduction: 'mean', 'sum', or any other value for per-sample losses ('none').
    """

    def __init__(self, alpha=None, gamma=2., reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def _alpha_tensor(self, num_classes, device):
        """Return alpha as a float32 tensor of shape (num_classes,) on ``device``.

        Raises:
            TypeError: alpha has an unsupported type.
            ValueError: alpha's length differs from the number of classes.
        """
        if isinstance(self.alpha, (float, int)):
            # float32 explicitly: an int scalar would otherwise build an int64 tensor.
            alpha = torch.full((num_classes,), float(self.alpha), device=device, dtype=torch.float32)
        elif isinstance(self.alpha, list):
            alpha = torch.tensor(self.alpha, device=device, dtype=torch.float32)
        elif torch.is_tensor(self.alpha):
            alpha = self.alpha.to(device=device, dtype=torch.float32)
            if alpha.dim() == 0:
                # A scalar tensor means "same weight for every class"; without this,
                # alpha.shape[0] below would raise IndexError.
                alpha = alpha.expand(num_classes)
        else:
            raise TypeError("alpha must be float, list or torch.Tensor")

        if alpha.shape[0] != num_classes:
            raise ValueError(f"alpha size {alpha.shape[0]} does not match C {num_classes}")
        return alpha

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: raw logits, assumed (N, C) — TODO confirm no extra spatial dims,
                since alpha is gathered along dim 0 of 1-D targets.
            targets: integer class indices, assumed shape (N,).
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability of the true class
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.alpha is not None:
            alpha = self._alpha_tensor(inputs.shape[1], inputs.device)
            focal_loss = alpha.gather(0, targets) * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss  # 'none' (any other value also falls through here)

# --- 2. Training / validation loop ---


def _set_lr(optimizer: optim.Optimizer, lr: float) -> None:
    """Assign ``lr`` to every parameter group of ``optimizer``."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def _train_one_epoch(model: nn.Module,
                     loader: DataLoader,
                     optimizer: optim.Optimizer,
                     criterion: nn.Module,
                     device: torch.device,
                     gradient_clipping: float,
                     desc: str) -> float:
    """Run one optimisation pass over ``loader``; return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    pbar = tqdm(loader, desc=desc, leave=False)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        if gradient_clipping is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        optimizer.step()
        running_loss += loss.item()
        pbar.set_postfix(loss=f"{loss.item():.4f}")
    # max(..., 1): an empty loader yields 0.0 instead of raising ZeroDivisionError.
    return running_loss / max(len(loader), 1)


def _validate_one_epoch(model: nn.Module,
                        loader: DataLoader,
                        criterion: nn.Module,
                        device: torch.device,
                        f1_metric,
                        desc: str):
    """Evaluate ``model`` on ``loader``; return (mean per-batch loss, macro F1)."""
    model.eval()
    running_loss = 0.0
    f1_metric.reset()
    pbar = tqdm(loader, desc=desc, leave=False)
    with torch.no_grad():
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            f1_metric.update(outputs, labels)
            pbar.set_postfix(loss=f"{loss.item():.4f}")
    return running_loss / max(len(loader), 1), f1_metric.compute().item()


def train_and_validate_best_f1(model: nn.Module,
                               train_loader: DataLoader,
                               val_loader: DataLoader,
                               optimizer: optim.Optimizer,
                               criterion: nn.Module,
                               epochs: int,
                               device: torch.device,
                               num_classes: int,
                               save_dir: str,
                               model_name_base: str,
                               top_k: int = 5,
                               gradient_clipping: float = None,
                               lr_scheduler = None,
                               warmup_epochs: int = 0,
                               base_lr: float = 1e-5,
                               patience: int = 5
                              ):
    """Train ``model`` and keep the ``top_k`` best checkpoints by validation macro F1.

    Each epoch does the following, in order:
      1. Set the LR (linear warmup during the first ``warmup_epochs``).
      2. Train and validate.
      3. Save a checkpoint if the epoch enters the top-k.
      4. Update the early-stopping counter and step ``lr_scheduler``. The scheduler is
         stepped only after warmup, and always after the epoch's ``optimizer.step()``
         calls.

    Args:
        model, train_loader, val_loader, optimizer, criterion: the usual training objects.
        epochs: maximum number of epochs.
        device: device batches are moved to (the model should already be there).
        num_classes: class count for the macro-F1 metric.
        save_dir: directory for checkpoint files (created if missing).
        model_name_base: prefix of checkpoint file names.
        top_k: number of best checkpoints kept on disk; 0 disables saving.
        gradient_clipping: max global grad norm, or None to disable clipping.
        lr_scheduler: optional scheduler stepped once per post-warmup epoch.
            ``ReduceLROnPlateau`` is stepped with the validation macro F1, so it
            should be created with ``mode='max'``.
        warmup_epochs: epochs of linear warmup up to ``base_lr``.
        base_lr: LR reached at the end of warmup.
        patience: epochs without a new best F1 before early stopping.

    Returns:
        dict of per-epoch lists: 'train_losses', 'val_losses', 'val_macro_f1_scores'.
    """
    history = {'train_losses': [], 'val_losses': [], 'val_macro_f1_scores': []}
    max_val_f1 = 0.0   # best F1 so far (drives early stopping)
    best_epoch = -1    # 0-based epoch of max_val_f1; -1 until the first epoch finishes
    top_k_checkpoints = []  # (f1_score, file_path) tuples, kept sorted best-first

    f1_metric = torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average='macro').to(device)
    epochs_no_improve = 0

    os.makedirs(save_dir, exist_ok=True)

    print(f"학습 시작: 총 {epochs} 에폭, Device: {device}")
    print(f"Top-{top_k} 모델 저장 디렉토리: {save_dir}")
    print(f"평가 기준: Validation Macro F1 Score")

    for epoch in range(epochs):
        epoch_start_time = time.time()

        # --- LR: manual linear warmup, then the scheduler takes over ---
        if epoch < warmup_epochs:
            _set_lr(optimizer, base_lr * ((epoch + 1) / warmup_epochs))
        elif epoch == warmup_epochs and warmup_epochs > 0:
            # The first post-warmup epoch trains at exactly base_lr. The scheduler's
            # first step happens at the end of this epoch.
            _set_lr(optimizer, base_lr)
        current_lr = optimizer.param_groups[0]['lr']

        # --- train / validate ---
        epoch_train_loss = _train_one_epoch(
            model, train_loader, optimizer, criterion, device, gradient_clipping,
            desc=f"Epoch {epoch+1}/{epochs} [Train] LR: {current_lr:.1e}")
        history['train_losses'].append(epoch_train_loss)

        epoch_val_loss, epoch_val_f1 = _validate_one_epoch(
            model, val_loader, criterion, device, f1_metric,
            desc=f"Epoch {epoch+1}/{epochs} [Val] ")
        history['val_losses'].append(epoch_val_loss)
        history['val_macro_f1_scores'].append(epoch_val_f1)

        epoch_duration = time.time() - epoch_start_time
        print(f"Epoch [{epoch+1}/{epochs}] ({epoch_duration:.2f}s) - "
              f"Train Loss: {epoch_train_loss:.4f}, "
              f"Val Loss: {epoch_val_loss:.4f}, "
              f"Val Macro F1: {epoch_val_f1:.4f}")

        # --- top-k checkpointing (top_k > 0 guard avoids indexing an empty list) ---
        is_top_k = top_k > 0 and (len(top_k_checkpoints) < top_k
                                  or epoch_val_f1 > top_k_checkpoints[-1][0])
        if is_top_k:
            checkpoint_filename = f"{model_name_base}_epoch{epoch+1}_f1_{epoch_val_f1:.4f}.pth"
            checkpoint_path = os.path.join(save_dir, checkpoint_filename)
            try:
                torch.save(model.state_dict(), checkpoint_path)
                print(f"  Checkpoint saved to {checkpoint_path}")

                top_k_checkpoints.append((epoch_val_f1, checkpoint_path))
                top_k_checkpoints.sort(key=lambda x: x[0], reverse=True)

                # Evict the worst checkpoint once more than top_k are held.
                if len(top_k_checkpoints) > top_k:
                    score_to_remove, path_to_remove = top_k_checkpoints.pop()
                    print(f"  Removing checkpoint {os.path.basename(path_to_remove)} (score: {score_to_remove:.4f}) as it's no longer in top-{top_k}")
                    if os.path.exists(path_to_remove):
                        try:
                            os.remove(path_to_remove)
                        except Exception as e_rem:
                            print(f"    Error removing file {path_to_remove}: {e_rem}")
                    else:
                        print(f"    Warning: File to remove not found: {path_to_remove}")
            except Exception as e_save:
                print(f"  Error saving checkpoint: {e_save}")

        # --- early stopping on the overall best F1 ---
        # The first epoch always sets the reference, even when its F1 is 0.0.
        if best_epoch < 0 or epoch_val_f1 > max_val_f1:
            print(f"  Validation Macro F1 improved ({max_val_f1:.4f} --> {epoch_val_f1:.4f}).")
            max_val_f1 = epoch_val_f1
            best_epoch = epoch
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            print(f"  Validation Macro F1 did not improve from the best ({max_val_f1:.4f}). ({epochs_no_improve}/{patience})")

        # Step the scheduler after this epoch's optimizer steps, as PyTorch requires.
        # Warmup epochs are handled manually above.
        if lr_scheduler is not None and epoch >= warmup_epochs:
            if isinstance(lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                lr_scheduler.step(epoch_val_f1)
            else:
                lr_scheduler.step()

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered at epoch {epoch+1} after {patience} epochs without improvement from the best F1 score.")
            break

    print(f"\n학습 완료.")
    print(f"최종 Top-{top_k} 모델 성능 및 경로:")
    if not top_k_checkpoints:
        print("  No models were saved.")
    else:
        for i, (score, path) in enumerate(top_k_checkpoints):
            print(f"  Top {i+1}: Score={score:.4f}, Path={path}")
    if best_epoch >= 0:
        print(f"\nOverall Best Epoch: {best_epoch+1}, Overall Best Validation Macro F1: {max_val_f1:.4f}")

    return history


# --- 3. Configuration ---
MODEL_NAME = 'convnext_large.fb_in22k_ft_in1k_384'  # timm model identifier
NUM_CLASSES = 7        # number of rock classes
IMG_SIZE = 384         # model input resolution

# --- Data paths (override via environment variables instead of editing the notebook) ---
TRAIN_DATA_DIR = os.environ.get("TRAIN_DATA_DIR", r"/home/metaai2/workspace/limseunghwan/open/train")
VAL_DATA_DIR = os.environ.get("VAL_DATA_DIR", r"/home/metaai2/workspace/limseunghwan/open/val")

# --- Training hyperparameters ---
EPOCHS = 20             # maximum epochs (early stopping may end sooner)
BATCH_SIZE = 8          # adjust to GPU memory
BASE_LR = 1e-5          # peak LR after warmup; starting point for fine-tuning, tune as needed
WEIGHT_DECAY = 1e-2     # used with AdamW
WARMUP_EPOCHS = 5       # epochs of linear LR warmup
GRADIENT_CLIPPING = 1.0 # max grad norm; None disables clipping

# --- System ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# os.cpu_count() can return None when the core count is undeterminable; fall back to
# 2 (-> 1 worker) instead of raising TypeError on `None // 2`.
NUM_WORKERS = (os.cpu_count() or 2) // 2

# --- Output paths ---
SAVE_DIR = './saved_models'
os.makedirs(SAVE_DIR, exist_ok=True)
# Note: train_and_validate_best_f1 names its own top-k files; this path is not passed to it.
MODEL_SAVE_PATH = os.path.join(SAVE_DIR, f'{MODEL_NAME.split(".")[0]}_best_f1.pth')

# --- 4. Load the pretrained model and resize its classifier ---
print(f"Loading model: {MODEL_NAME}")
model = timm.create_model(MODEL_NAME, pretrained=True)

# reset_classifier swaps in a freshly initialised NUM_CLASSES-way output layer.
model.reset_classifier(num_classes=NUM_CLASSES)
print(f"Model head replaced for {NUM_CLASSES} classes.")

# Move the weights to the training device.
model = model.to(DEVICE)

# --- 5. Image transforms ---
# Training: random resized crop, flip, rotation and colour jitter, then ImageNet normalisation.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(
            size=(IMG_SIZE, IMG_SIZE),
            scale=(0.5, 1.0),
            ratio=(0.75, 1.3333),
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(
            degrees=15,
            interpolation=transforms.InterpolationMode.BILINEAR,
            fill=0,
        ),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics
    ]
)

# Validation: deterministic. Prefer the preprocessing timm derives from the model's
# pretrained config; if that fails, fall back to a hand-written
# resize / center-crop / normalise pipeline.
try:
    config = timm.data.resolve_data_config({}, model=model)
    val_transform = timm.data.create_transform(**config, is_training=False)
    print("Using timm's default validation transform.")
except Exception as e:
    print(f"Failed to get timm config ({e}), defining validation transform manually.")
    val_transform = transforms.Compose(
        [
            transforms.Resize(IMG_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(IMG_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

# --- 6. Datasets ---
# ImageFolder expects one sub-directory per class under each root.
from torchvision.datasets import ImageFolder

print(f"Loading datasets from: {TRAIN_DATA_DIR} and {VAL_DATA_DIR}")

# Default to an unweighted focal loss. This is replaced below only when per-class
# counts line up with NUM_CLASSES, so the criterion cell never hits a NameError.
focal_loss_alpha = None

try:
    train_dataset = ImageFolder(root=TRAIN_DATA_DIR, transform=train_transform)
    val_dataset = ImageFolder(root=VAL_DATA_DIR, transform=val_transform)

    print(f"Found {len(train_dataset)} training images and {len(val_dataset)} validation images.")
    print(f"Classes: {train_dataset.classes}")  # label index order used by the model

    # Focal-loss alpha from inverse class frequency, scaled so the rarest class gets 1.0.
    # minlength keeps one slot per discovered class even if a class folder is empty.
    class_counts = np.bincount([s[1] for s in train_dataset.samples],
                               minlength=len(train_dataset.classes))
    if len(class_counts) != NUM_CLASSES:
        print(f"Warning: Number of found classes ({len(class_counts)}) does not match NUM_CLASSES ({NUM_CLASSES}). Adjust NUM_CLASSES or check dataset.")
        print("  Using focal loss without class weights (alpha=None).")
    else:
        total_samples = sum(class_counts)
        class_weights = [total_samples / count if count > 0 else 0 for count in class_counts]
        max_weight = max(class_weights) if any(w > 0 for w in class_weights) else 1  # avoid division by zero
        class_weights = [w / max_weight for w in class_weights]
        focal_loss_alpha = torch.tensor(class_weights, device=DEVICE, dtype=torch.float32)
        print(f"Calculated Focal Loss alpha (normalized): {focal_loss_alpha.cpu().numpy()}")

except FileNotFoundError:
    print(f"Error: Data directory not found. Please check TRAIN_DATA_DIR and VAL_DATA_DIR.")
    # Re-raise rather than exit(): exit() asks a Jupyter kernel to shut down, and in a
    # script it exits with status 0, reporting success for a failed run.
    raise
except Exception as e:
    print(f"Error loading dataset: {e}")
    raise

# DataLoaders: shuffle only the training split; pin_memory speeds up host-to-GPU copies.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# --- 7. Loss, optimiser and LR schedule ---
# Focal loss weighted by the per-class alpha computed above (None means unweighted).
criterion = FocalLoss(alpha=focal_loss_alpha, gamma=2.0).to(DEVICE)
# Plain cross-entropy alternative:
# criterion = nn.CrossEntropyLoss().to(DEVICE)

optimizer = optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=WEIGHT_DECAY)

# Cosine decay across the post-warmup epochs, down to 1% of BASE_LR.
# The training loop drives the warmup itself.
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS - WARMUP_EPOCHS,
    eta_min=BASE_LR * 0.01,
)

# --- 8. Run training ---
if __name__ == "__main__":  # true in a notebook kernel as well as when run as a script
    print("\nStarting training process...")
    # Drop any "org/" prefix and the ".tag" suffix,
    # e.g. 'convnext_large.fb_in22k_ft_in1k_384' -> 'convnext_large'.
    model_name_base = MODEL_NAME.split('/')[-1].split('.')[0]
    history = train_and_validate_best_f1(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        epochs=EPOCHS,
        device=DEVICE,
        num_classes=NUM_CLASSES,
        save_dir=SAVE_DIR,
        model_name_base=model_name_base,
        top_k=5,
        gradient_clipping=GRADIENT_CLIPPING,
        lr_scheduler=lr_scheduler,
        warmup_epochs=WARMUP_EPOCHS,
        base_lr=BASE_LR
    )

    print("\nTraining finished.")
    # Brief summary (the training function prints its own, more detailed one).
    if not history['val_macro_f1_scores']:
        print("No validation scores recorded in history.")
    else:
        print("Training History Summary:")
        # max() keeps the first of equal scores, like list.index would.
        best_epoch_overall, best_f1_overall = max(
            enumerate(history['val_macro_f1_scores'], start=1),
            key=lambda epoch_and_score: epoch_and_score[1],
        )
        print(f"  Overall Best Validation Macro F1 in history: {best_f1_overall:.4f}")
        print(f"  Achieved at Epoch: {best_epoch_overall}")

Loading model: convnext_large.fb_in22k_ft_in1k_384
Model head replaced for 7 classes.
Using timm's default validation transform.
Loading datasets from: /home/metaai2/workspace/limseunghwan/open/train and /home/metaai2/workspace/limseunghwan/open/val
Found 342015 training images and 38005 validation images.
Classes: ['Andesite', 'Basalt', 'Etc', 'Gneiss', 'Granite', 'Mud_Sandstone', 'Weathered_Rock']
Calculated Focal Loss alpha (normalized): [0.36379087 0.59434706 1.         0.21558282 0.17148152 0.17810482
 0.4287038 ]

Starting training process...
학습 시작: 총 25 에폭, Device: cuda
Top-5 모델 저장 디렉토리: ./saved_models
평가 기준: Validation Macro F1 Score


                                                                                                 

KeyboardInterrupt: 

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import timm
import pandas as pd
from PIL import Image
import os
from tqdm import tqdm
import numpy as np # For class names later if needed

# --- 0. Configuration (must match the training run) ---
MODEL_NAME = 'convnext_large.fb_in22k_ft_in1k_384'
NUM_CLASSES = 7
IMG_SIZE = 384

# Paths
# TODO: point SAVED_MODEL_PATH at the best checkpoint produced by training.
SAVED_MODEL_PATH = './saved_models/convnext_large_epoch20_f1_0.9130.pth'
TEST_CSV_PATH = './open/test.csv'
# Image paths in test.csv are resolved relative to this directory,
# e.g. 'test/img1.jpg' -> './open/test/img1.jpg'.
IMAGE_BASE_DIR = './open'
OUTPUT_CSV_PATH = './submission_convnext_epoch20.csv'

# Inference parameters
BATCH_SIZE_INFERENCE = 16  # adjust to GPU memory
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# os.cpu_count() can return None; fall back instead of raising TypeError on `None // 2`.
NUM_WORKERS_INFERENCE = (os.cpu_count() or 2) // 2

# Class names in label-index order. ImageFolder sorts class folder names, and this list
# matches the `Classes:` line printed by the training cell.
CLASS_NAMES = ['Andesite', 'Basalt', 'Etc', 'Gneiss', 'Granite', 'Mud_Sandstone', 'Weathered_Rock']

if len(CLASS_NAMES) != NUM_CLASSES:
    raise ValueError(f"Number of CLASS_NAMES ({len(CLASS_NAMES)}) does not match NUM_CLASSES ({NUM_CLASSES})")

# --- 1. Model loading ---
print(f"Loading model: {MODEL_NAME}")
# Build the architecture with a NUM_CLASSES-way head. The weights come from our own
# checkpoint, so the pretrained weights are not downloaded.
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)
try:
    # weights_only=True limits unpickling to tensors and plain containers, which is all
    # a state_dict holds. Passing the result straight to load_state_dict avoids keeping
    # a second copy of the weights alive in a global.
    model.load_state_dict(torch.load(SAVED_MODEL_PATH, map_location=DEVICE, weights_only=True))
    print(f"Successfully loaded model weights from: {SAVED_MODEL_PATH}")
except FileNotFoundError:
    print(f"ERROR: Model file not found at {SAVED_MODEL_PATH}. Please check the path.")
    # Re-raise instead of exit(), which would ask the notebook kernel to shut down.
    raise
except Exception as e:
    print(f"ERROR: Could not load model weights: {e}")
    raise

model = model.to(DEVICE)
model.eval()
print(f"Model loaded on {DEVICE} and set to evaluation mode.")

# --- 2. Inference transform (same preprocessing as validation during training) ---
try:
    # Start from timm's test-time config for this model, then pin the input size to IMG_SIZE.
    config = timm.data.resolve_data_config({}, model=model, use_test_size=True)
    config.update(input_size=(3, IMG_SIZE, IMG_SIZE))  # (C, H, W)
    inference_transform = timm.data.create_transform(**config, is_training=False)
    print("Using timm's default validation transform for inference.")
except Exception as e:
    print(f"Failed to get timm config for inference ({e}), defining inference transform manually.")
    inference_transform = transforms.Compose(
        [
            transforms.Resize(IMG_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(IMG_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
print(f"Inference Transform: {inference_transform}")

# --- 3. Dataset over the images listed in the test CSV ---
class TestImageDataset(Dataset):
    """Yield (transformed image, path string from the CSV) for each CSV row."""

    def __init__(self, csv_path, img_dir_root, transform=None):
        """
        Args:
            csv_path (string): CSV listing one image path per row.
            img_dir_root (string): directory that the CSV paths are relative to.
            transform (callable, optional): applied to each RGB PIL image.
        """
        self.data_frame = pd.read_csv(csv_path)
        self.img_dir_root = img_dir_root
        self.transform = transform
        # Prefer an 'img_path' column; otherwise assume the first column holds the paths.
        if 'img_path' not in self.data_frame.columns:
            print(f"Warning: 'img_path' column not found in {csv_path}. Assuming first column ('{self.data_frame.columns[0]}') contains image paths.")
            self.img_path_column = self.data_frame.columns[0]
        else:
            self.img_path_column = 'img_path'

    def __len__(self):
        return len(self.data_frame)

    def _relative_path(self, idx):
        """Return the path string stored in row ``idx`` (0-based position) of the CSV."""
        # Positional lookup. DataLoader indices are positions 0..len-1, which match
        # .loc labels only while the frame keeps its default RangeIndex.
        col_pos = self.data_frame.columns.get_loc(self.img_path_column)
        return self.data_frame.iat[idx, col_pos]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        relative_img_path = self._relative_path(idx)
        # e.g. img_dir_root='./open' + 'test/TEST_00000.jpg' -> './open/test/TEST_00000.jpg'
        full_img_path = os.path.join(self.img_dir_root, relative_img_path)

        try:
            # `with` guarantees the file handle is closed once the RGB copy exists.
            with Image.open(full_img_path) as img:
                image = img.convert('RGB')
        except FileNotFoundError:
            print(f"ERROR: Image not found at {full_img_path}. Check IMAGE_BASE_DIR and CSV paths.")
            raise  # fail loudly rather than silently dropping a test prediction
        except Exception as e:
            print(f"ERROR: Could not open image {full_img_path}: {e}")
            raise

        if self.transform:
            image = self.transform(image)

        # Return the CSV path as well so predictions can be mapped back to CSV rows.
        return image, relative_img_path

# --- 4. DataLoader for the test set ---
print(f"Loading test data from: {TEST_CSV_PATH} with image base: {IMAGE_BASE_DIR}")
try:
    test_dataset = TestImageDataset(csv_path=TEST_CSV_PATH,
                                    img_dir_root=IMAGE_BASE_DIR,
                                    transform=inference_transform)
    test_loader = DataLoader(test_dataset,
                             batch_size=BATCH_SIZE_INFERENCE,
                             shuffle=False,  # keep CSV order
                             num_workers=NUM_WORKERS_INFERENCE,
                             pin_memory=True)
    print(f"Found {len(test_dataset)} images for testing.")
except FileNotFoundError:
    print(f"Error: Test CSV file not found at {TEST_CSV_PATH}. Please check the path.")
    # Re-raise instead of exit(), which would ask the notebook kernel to shut down.
    raise
except Exception as e:
    print(f"Error creating test dataset/loader: {e}")
    raise


# --- 5. Prediction ---
def predict_on_test_data(model, loader, device, class_names):
    """Predict a class name for every image yielded by ``loader``.

    Args:
        model: classifier returning logits, assumed shape (batch, num_classes).
        loader: yields (images, paths) batches, e.g. a DataLoader over TestImageDataset.
        device: device the images are moved to (should match the model's).
        class_names: label names indexed by class id.

    Returns:
        (paths, predicted_names): two lists aligned with the loader's iteration order.
    """
    model.eval()
    paths = []
    predicted_ids = []

    with torch.no_grad():
        for batch_images, batch_paths in tqdm(loader, desc="Predicting"):
            logits = model(batch_images.to(device))
            predicted_ids.extend(torch.max(logits, dim=1).indices.cpu().numpy())
            paths.extend(batch_paths)  # batch_paths is a tuple of strings

    return paths, [class_names[i] for i in predicted_ids]

# --- 6. Run inference and write the submission ---
if __name__ == "__main__":
    print("\nStarting inference...")
    original_img_paths_from_loader, predictions = predict_on_test_data(model, test_loader, DEVICE, CLASS_NAMES)

    # Reuse the frame the dataset already loaded instead of reading test.csv again.
    original_test_df = test_dataset.data_frame
    csv_img_path_col = test_dataset.img_path_column

    # Map normalised path -> predicted label, so that CSV rows match even when one side
    # writes './test/a.jpg' and the other 'test/a.jpg'.
    prediction_map = {
        os.path.normpath(p): label
        for p, label in zip(original_img_paths_from_loader, predictions)
    }

    img_paths = original_test_df[csv_img_path_col]
    submission_df_final = pd.DataFrame({
        # ID = file name without directory or extension: 'test/TEST_00000.jpg' -> 'TEST_00000'
        'ID': img_paths.map(lambda x: os.path.splitext(os.path.basename(x))[0]),
        'rock_type': img_paths.map(lambda x: prediction_map.get(os.path.normpath(x))),
    })

    # Rows whose path never came back from the loader have no prediction (None/NaN).
    missing = submission_df_final['rock_type'].isnull()
    if missing.any():
        print("Warning: Some images in the CSV did not get a prediction. Check for mismatches or missing images.")
        print(f"  Rows without a prediction: {int(missing.sum())}")
        # To fill the gaps with a default label instead:
        # submission_df_final['rock_type'] = submission_df_final['rock_type'].fillna(CLASS_NAMES[0])

    submission_df_final.to_csv(OUTPUT_CSV_PATH, index=False)
    print(f"\nInference complete. Predictions saved to: {OUTPUT_CSV_PATH}")
    print(f"Sample predictions (img_path modified):\n{submission_df_final.head()}")

  from .autonotebook import tqdm as notebook_tqdm


Loading model: convnext_large.fb_in22k_ft_in1k_384
Successfully loaded model weights from: ./saved_models/convnext_large_epoch20_f1_0.9130.pth
Model loaded on cuda and set to evaluation mode.
Using timm's default validation transform for inference.
Inference Transform: Compose(
    Resize(size=(384, 384), interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(384, 384))
    MaybeToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
Loading test data from: ./open/test.csv with image base: ./open
Found 95006 images for testing.

Starting inference...


Predicting: 100%|██████████| 5938/5938 [10:45<00:00,  9.20it/s]



Inference complete. Predictions saved to: ./submission_convnext_epoch20.csv
Sample predictions (img_path modified):
           ID      rock_type
0  TEST_00000  Mud_Sandstone
1  TEST_00001  Mud_Sandstone
2  TEST_00002  Mud_Sandstone
3  TEST_00003        Granite
4  TEST_00004        Granite


In [4]:
# Show which device this kernel session would use. The saved output is 'cpu', meaning
# CUDA was not visible in this session, unlike the earlier runs that reported cuda.
torch.device("cuda" if torch.cuda.is_available() else "cpu")

device(type='cpu')

In [6]:
# Check whether PyTorch can see a CUDA device in the current kernel.
torch.cuda.is_available()

False