In [1]:
import torch

if torch.cuda.is_available():
    print("GPU 메모리 사용량:")
    print(f"사용 중: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f"예약 중: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

GPU 메모리 사용량:
사용 중: 0.00 MB
예약 중: 0.00 MB


  from .autonotebook import tqdm as notebook_tqdm


Loading model: maxvit_xlarge_tf_384.in21k_ft_in1k
Model classifier replaced for 7 classes.
Using timm's default validation transform for maxvit_xlarge_tf_384.in21k_ft_in1k.
Timm default input size for maxvit_xlarge_tf_384.in21k_ft_in1k: (3, 384, 384)
Timm default mean: (0.5, 0.5, 0.5), std: (0.5, 0.5, 0.5)
Overriding to use specified IMG_SIZE: 384 for validation transform consistency.
Loading datasets from: /home/metaai2/workspace/limseunghwan/open/train and /home/metaai2/workspace/limseunghwan/open/val
Found 342015 training images and 38005 validation images.
Classes: ['Andesite', 'Basalt', 'Etc', 'Gneiss', 'Granite', 'Mud_Sandstone', 'Weathered_Rock']
Calculated Focal Loss alpha (normalized): [0.36379087 0.59434706 1.         0.21558282 0.17148152 0.17810482
 0.4287038 ]
Using Focal Loss with calculated alpha.

Starting training process with maxvit_xlarge_tf_384.in21k_ft_in1k...
Batch Size: 2. If you encounter OOM errors, try reducing it further.
학습 시작: 총 20 에폭, Device: cuda
Top-5 모델

                                                                                                   

KeyboardInterrupt: 

: 

In [1]:
# MaxViT.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset # Added Dataset for TestImageDataset
from torchvision import transforms
import torch.nn.functional as F
import timm
import torchmetrics
from tqdm import tqdm
import time
import os
import numpy as np
import pandas as pd # For inference CSV handling
from PIL import Image # For inference image loading

# --- 1. Focal Loss 클래스 정의 (이전 코드에서 복사) ---
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2., reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.alpha is not None:
            if isinstance(self.alpha, (float, int)):
                 alpha = torch.tensor([self.alpha] * inputs.shape[1], device=inputs.device)
            elif isinstance(self.alpha, list):
                 alpha = torch.tensor(self.alpha, device=inputs.device, dtype=torch.float32)
            elif torch.is_tensor(self.alpha):
                 alpha = self.alpha.to(device=inputs.device, dtype=torch.float32)
            else:
                 raise TypeError("alpha must be float, list or torch.Tensor")

            if alpha.shape[0] != inputs.shape[1]:
                 raise ValueError(f"alpha size {alpha.shape[0]} does not match C {inputs.shape[1]}")

            alpha_t = alpha.gather(0, targets)
            focal_loss = alpha_t * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else: # 'none'
            return focal_loss

# --- 2. 학습 및 검증 함수 정의 (이전 코드에서 복사) ---
def train_and_validate_best_f1(model: nn.Module,
                               train_loader: DataLoader,
                               val_loader: DataLoader,
                               optimizer: optim.Optimizer,
                               criterion: nn.Module,
                               epochs: int,
                               device: torch.device,
                               num_classes: int,
                               save_dir: str,
                               model_name_base: str,
                               top_k: int = 5,
                               gradient_clipping: float = None,
                               lr_scheduler = None,
                               warmup_epochs: int = 0,
                               base_lr: float = 1e-5
                              ):

    history = {'train_losses': [], 'val_losses': [], 'val_macro_f1_scores': [], 'best_model_path': None} # Added best_model_path
    max_val_f1 = 0.0
    best_epoch = -1
    top_k_checkpoints = []

    f1_metric = torchmetrics.F1Score(task='multiclass', num_classes=num_classes, average='macro').to(device)
    patience = 10 # 조기 종료를 위한 인내 에폭 수 (조금 늘려봄)
    epochs_no_improve = 0

    os.makedirs(save_dir, exist_ok=True)

    print(f"학습 시작: 총 {epochs} 에폭, Device: {device}")
    print(f"Top-{top_k} 모델 저장 디렉토리: {save_dir}")
    print(f"평가 기준: Validation Macro F1 Score")

    for epoch in range(epochs):
        epoch_start_time = time.time()
        current_lr = optimizer.param_groups[0]['lr']

        if epoch < warmup_epochs:
            warmup_factor = (epoch + 1) / warmup_epochs
            for param_group in optimizer.param_groups:
                param_group['lr'] = base_lr * warmup_factor
            current_lr = optimizer.param_groups[0]['lr']
        elif lr_scheduler is not None:
             if epoch == warmup_epochs:
                 for param_group in optimizer.param_groups:
                     param_group['lr'] = base_lr
                 current_lr = optimizer.param_groups[0]['lr']
             lr_scheduler.step()
             current_lr = optimizer.param_groups[0]['lr']


        model.train()
        running_train_loss = 0.0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train] LR: {current_lr:.1e}", leave=False)
        for images, labels in train_pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            if gradient_clipping is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
            optimizer.step()
            running_train_loss += loss.item()
            train_pbar.set_postfix(loss=f"{loss.item():.4f}")
        epoch_train_loss = running_train_loss / len(train_loader)
        history['train_losses'].append(epoch_train_loss)

        model.eval()
        running_val_loss = 0.0
        f1_metric.reset()
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val] ", leave=False)
        with torch.no_grad():
            for images, labels in val_pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                running_val_loss += loss.item()
                f1_metric.update(outputs, labels)
                val_pbar.set_postfix(loss=f"{loss.item():.4f}")
        epoch_val_loss = running_val_loss / len(val_loader)
        history['val_losses'].append(epoch_val_loss)
        epoch_val_f1 = f1_metric.compute().item()
        history['val_macro_f1_scores'].append(epoch_val_f1)

        epoch_end_time = time.time()
        epoch_duration = epoch_end_time - epoch_start_time

        print(f"Epoch [{epoch+1}/{epochs}] ({epoch_duration:.2f}s) - "
              f"Train Loss: {epoch_train_loss:.4f}, "
              f"Val Loss: {epoch_val_loss:.4f}, "
              f"Val Macro F1: {epoch_val_f1:.4f}")

        is_top_k = len(top_k_checkpoints) < top_k or epoch_val_f1 > top_k_checkpoints[-1][0]

        if is_top_k:
            checkpoint_filename = f"{model_name_base}_epoch{epoch+1:03d}_f1_{epoch_val_f1:.4f}.pth" #epoch zero-padding
            checkpoint_path = os.path.join(save_dir, checkpoint_filename)
            try:
                torch.save(model.state_dict(), checkpoint_path)
                print(f"  Checkpoint saved to {checkpoint_path}")
                top_k_checkpoints.append((epoch_val_f1, checkpoint_path))
                top_k_checkpoints.sort(key=lambda x: x[0], reverse=True)
                if len(top_k_checkpoints) > top_k:
                    score_to_remove, path_to_remove = top_k_checkpoints.pop()
                    print(f"  Removing checkpoint {os.path.basename(path_to_remove)} (score: {score_to_remove:.4f}) as it's no longer in top-{top_k}")
                    if os.path.exists(path_to_remove):
                        try:
                            os.remove(path_to_remove)
                        except Exception as e_rem:
                            print(f"    Error removing file {path_to_remove}: {e_rem}")
                    else:
                        print(f"    Warning: File to remove not found: {path_to_remove}")
            except Exception as e_save:
                print(f"  Error saving checkpoint: {e_save}")

        if epoch_val_f1 > max_val_f1:
            print(f"  Validation Macro F1 improved ({max_val_f1:.4f} --> {epoch_val_f1:.4f}).")
            max_val_f1 = epoch_val_f1
            best_epoch = epoch
            if top_k_checkpoints: # Best model is the first in sorted top_k
                history['best_model_path'] = top_k_checkpoints[0][1]
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            print(f"  Validation Macro F1 did not improve from the best ({max_val_f1:.4f}). ({epochs_no_improve}/{patience})")

        if epochs_no_improve >= patience and epoch >= warmup_epochs + patience : # Ensure warmup isn't cut short, and some training happened
            print(f"Early stopping triggered at epoch {epoch+1} after {patience} epochs without improvement from the best F1 score.")
            break

    print(f"\n학습 완료.")
    print(f"최종 Top-{top_k} 모델 성능 및 경로:")
    if not top_k_checkpoints:
        print("  No models were saved.")
    else:
        for i, (score, path) in enumerate(top_k_checkpoints):
            print(f"  Top {i+1}: Score={score:.4f}, Path={path}")
        if not history['best_model_path'] and top_k_checkpoints: # Fallback if not set during improvement
             history['best_model_path'] = top_k_checkpoints[0][1]

    print(f"\nOverall Best Epoch: {best_epoch+1 if best_epoch != -1 else 'N/A'}, Overall Best Validation Macro F1: {max_val_f1:.4f}")
    if history['best_model_path']:
        print(f"Path to best model: {history['best_model_path']}")
    else:
        print("No best model path recorded (possibly no improvement or no models saved).")


    return history

# --- 3. 설정 변수 정의 ---
MODEL_NAME = 'maxvit_xlarge_tf_384.in21k_ft_in1k' # MaxViT XLarge
NUM_CLASSES = 7
IMG_SIZE = 384

# --- 데이터 경로 설정 ---
# TODO: 실제 데이터 경로로 수정하세요
TRAIN_DATA_DIR = r"/home/metaai2/workspace/limseunghwan/open/train"
VAL_DATA_DIR = r"/home/metaai2/workspace/limseunghwan/open/val"
# --- 테스트 데이터 및 결과 경로 설정 (추가) ---
TEST_CSV_PATH = r'/home/metaai2/workspace/limseunghwan/open/test.csv'    # Path to your test.csv
IMAGE_BASE_DIR = r'/home/metaai2/workspace/limseunghwan/open'           # Base directory for images in test.csv
# OUTPUT_CSV_PATH will be generated dynamically based on model name and performance

# --- 학습 하이퍼파라미터 ---
EPOCHS = 50 # 에폭 수 (조절)
BATCH_SIZE = 2 # MaxViT-XLarge는 매우 많은 메모리를 사용
BASE_LR = 1e-5
WEIGHT_DECAY = 1e-2
WARMUP_EPOCHS = 5
GRADIENT_CLIPPING = 1.0

# --- 시스템 설정 ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_WORKERS = os.cpu_count() // 2 if os.cpu_count() else 4

# --- 저장 경로 설정 ---
MODEL_ARTIFACT_NAME_BASE = MODEL_NAME.split('/')[-1].split('.')[0] if '/' in MODEL_NAME else MODEL_NAME.split('.')[0]
SAVE_DIR = f'./saved_models_{MODEL_ARTIFACT_NAME_BASE}_{IMG_SIZE}' # 모델 저장 디렉토리
os.makedirs(SAVE_DIR, exist_ok=True)

# --- 클래스 이름 정의 (중요!) ---
# TODO: ***매우 중요*** 당신의 데이터셋에 맞는 실제 클래스 이름을 순서대로 정의하세요.
#       train_dataset.classes 에서 가져올 수 있습니다.
#       예시: CLASS_NAMES = ['andesite', 'gneiss', 'granite', 'mudstone', 'quartzite', 'rhyolite', 'sandstone']
CLASS_NAMES = ['Andesite', 'Basalt', 'Etc', 'Gneiss', 'Granite', 'Mud_Sandstone', 'Weathered_Rock'] # 예시 값, 반드시 수정!
# CLASS_NAMES = [] # 아래에서 train_dataset.classes로 채우도록 시도합니다.

# --- 4. 모델 로드 및 수정 ---
print(f"Loading model: {MODEL_NAME}")
# 모델은 학습 루프 시작 시 또는 추론 시점에 로드/생성됩니다.
# 여기서는 전역으로 선언하지 않고, 필요할 때 생성하도록 변경 가능.
# 우선은 기존 방식 유지
model_train = timm.create_model(MODEL_NAME, pretrained=True, num_classes=NUM_CLASSES, img_size=IMG_SIZE)
print(f"Model classifier adapted for {NUM_CLASSES} classes.")
model_train = model_train.to(DEVICE)


# --- 5. 데이터 변환 정의 ---
# 학습 변환
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=(IMG_SIZE, IMG_SIZE), scale=(0.6, 1.0), ratio=(0.75, 1.3333), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5), # 추가
    transforms.RandomRotation(degrees=30, interpolation=transforms.InterpolationMode.BILINEAR, fill=0), # 각도 증가
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1), # hue jitter 추가
    transforms.ToTensor(),
])

# 검증/추론 변환 (일관성 유지)
try:
    # 학습에 사용된 모델과 동일한 구조의 모델 인스턴스를 사용하여 config를 가져옵니다.
    # model_train이 이미 생성되어 있으므로 이를 사용합니다.
    config = timm.data.resolve_data_config({}, model=model_train)
    print(f"Using timm's default validation transform config for {MODEL_NAME}.")
    print(f"Timm default input size: {config['input_size']}")
    print(f"Timm default mean: {config['mean']}, std: {config['std']}")

    val_inference_transform = transforms.Compose([
        transforms.Resize(IMG_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=config['mean'], std=config['std'])
    ])
    train_transform.transforms.append(transforms.Normalize(mean=config['mean'], std=config['std']))

except Exception as e:
    print(f"Failed to get timm config ({e}), defining transforms manually (using ImageNet defaults).")
    _imagenet_mean = [0.485, 0.456, 0.406]
    _imagenet_std = [0.229, 0.224, 0.225]
    val_inference_transform = transforms.Compose([
        transforms.Resize(IMG_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std)
    ])
    train_transform.transforms.append(transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std))

# --- 6. 데이터셋 및 데이터 로더 준비 ---
from torchvision.datasets import ImageFolder

print(f"Loading datasets from: {TRAIN_DATA_DIR} and {VAL_DATA_DIR}")
try:
    train_dataset = ImageFolder(root=TRAIN_DATA_DIR, transform=train_transform)
    val_dataset = ImageFolder(root=VAL_DATA_DIR, transform=val_inference_transform) # 검증은 val_inference_transform 사용

    print(f"Found {len(train_dataset)} training images and {len(val_dataset)} validation images.")
    
    # CLASS_NAMES 업데이트 시도
    if not CLASS_NAMES or len(CLASS_NAMES) != NUM_CLASSES:
        print("CLASS_NAMES not predefined or mismatched. Attempting to use 'train_dataset.classes'.")
        if hasattr(train_dataset, 'classes') and len(train_dataset.classes) == NUM_CLASSES:
            CLASS_NAMES = train_dataset.classes
            print(f"Successfully set CLASS_NAMES from train_dataset: {CLASS_NAMES}")
        else:
            print(f"ERROR: Could not automatically determine CLASS_NAMES. Please define it manually and ensure it matches NUM_CLASSES={NUM_CLASSES}.")
            if hasattr(train_dataset, 'classes'):
                print(f"Found classes in dataset: {train_dataset.classes} (count: {len(train_dataset.classes)})")
            exit()
    elif len(CLASS_NAMES) != NUM_CLASSES:
        print(f"ERROR: Predefined CLASS_NAMES length ({len(CLASS_NAMES)}) does not match NUM_CLASSES ({NUM_CLASSES}).")
        exit()
    else:
        print(f"Using predefined CLASS_NAMES: {CLASS_NAMES}")


    class_counts = np.bincount([s[1] for s in train_dataset.samples])
    focal_loss_alpha = None
    if len(class_counts) == NUM_CLASSES:
        total_samples = sum(class_counts)
        class_weights_raw = [total_samples / count if count > 0 else 0 for count in class_counts]
        max_weight = max(class_weights_raw) if any(w > 0 for w in class_weights_raw) else 1
        if max_weight > 0:
            class_weights_normalized = [w / max_weight for w in class_weights_raw]
            focal_loss_alpha = torch.tensor(class_weights_normalized, device=DEVICE, dtype=torch.float32)
            print(f"Calculated Focal Loss alpha (normalized): {focal_loss_alpha.cpu().numpy()}")
        else:
            print("Warning: All class counts are zero. Cannot calculate Focal Loss alpha.")
    else:
        print(f"Warning: Number of found classes ({len(class_counts)}) in dataset does not match NUM_CLASSES ({NUM_CLASSES}). Focal Loss alpha set to None.")


except FileNotFoundError:
    print(f"Error: Data directory not found. Please check TRAIN_DATA_DIR and VAL_DATA_DIR.")
    exit()
except Exception as e:
    print(f"Error loading dataset: {e}")
    exit()

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True) # BATCH_SIZE는 학습과 동일하게 설정

# --- 7. 손실 함수, 옵티마이저, 스케줄러 정의 ---
if focal_loss_alpha is not None:
    criterion = FocalLoss(alpha=focal_loss_alpha, gamma=2.0).to(DEVICE)
    print("Using Focal Loss with calculated alpha.")
else:
    criterion = FocalLoss(gamma=2.0).to(DEVICE) # 또는 nn.CrossEntropyLoss().to(DEVICE)
    print("Using Focal Loss without alpha (or CrossEntropyLoss if preferred).")

optimizer = optim.AdamW(model_train.parameters(), lr=BASE_LR, weight_decay=WEIGHT_DECAY)
lr_scheduler = None
if EPOCHS > WARMUP_EPOCHS:
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS - WARMUP_EPOCHS, eta_min=BASE_LR * 0.01)
else: # WARMUP_EPOCHS >= EPOCHS 인 경우 스케줄러 없음
    print("Warning: WARMUP_EPOCHS >= EPOCHS. No LR scheduler will be used after warmup.")


# --- 8. Custom Dataset for Test Images (추가된 부분) ---
class TestImageDataset(Dataset):
    def __init__(self, csv_path, img_dir_root, transform=None):
        self.data_frame = pd.read_csv(csv_path)
        self.img_dir_root = img_dir_root
        self.transform = transform
        if 'img_path' not in self.data_frame.columns:
            print(f"Warning: 'img_path' column not found in {csv_path}. Assuming first column ('{self.data_frame.columns[0]}') contains image paths.")
            self.img_path_column = self.data_frame.columns[0]
        else:
            self.img_path_column = 'img_path'

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        relative_img_path = self.data_frame.loc[idx, self.img_path_column]
        full_img_path = os.path.join(self.img_dir_root, relative_img_path)
        try:
            image = Image.open(full_img_path).convert('RGB')
        except FileNotFoundError:
            print(f"ERROR: Image not found at {full_img_path}. Check IMAGE_BASE_DIR and CSV paths.")
            # To avoid crashing the whole batch, return a dummy tensor and None path
            # This requires handling in the prediction loop if it occurs.
            # For now, we let it raise during dataset iteration if critical.
            raise
        except Exception as e:
            print(f"ERROR: Could not open image {full_img_path}: {e}")
            raise

        if self.transform:
            image = self.transform(image)
        return image, relative_img_path


# --- 9. Inference Function (추가된 부분) ---
def run_inference(
    model_path: str,
    model_architecture_name: str, # e.g., 'maxvit_xlarge_tf_384.in21k_ft_in1k'
    num_classes_inf: int,
    img_size_inf: int,
    device_inf: torch.device,
    test_csv_path_inf: str,
    image_base_dir_inf: str,
    output_csv_path_inf: str,
    class_names_inf: list,
    batch_size_inf: int = 16, # Can be different from training batch size
    num_workers_inf: int = 4
):
    print(f"\n--- Starting Inference ---")
    print(f"Loading model for inference: {model_architecture_name} from {model_path}")

    # Load the model structure
    model_inf = timm.create_model(model_architecture_name, pretrained=False, num_classes=num_classes_inf, img_size=img_size_inf)
    try:
        model_inf.load_state_dict(torch.load(model_path, map_location=device_inf))
        print(f"Successfully loaded model weights from: {model_path}")
    except FileNotFoundError:
        print(f"ERROR: Model file not found at {model_path}. Cannot run inference.")
        return
    except Exception as e:
        print(f"ERROR: Could not load model weights for inference: {e}")
        return

    model_inf = model_inf.to(device_inf)
    model_inf.eval()

    # Re-create the inference transform (should be same as val_transform)
    # We use the loaded model_inf to resolve config for robustness
    try:
        config_inf = timm.data.resolve_data_config({}, model=model_inf)
        current_inference_transform = transforms.Compose([
            transforms.Resize(img_size_inf, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(img_size_inf),
            transforms.ToTensor(),
            transforms.Normalize(mean=config_inf['mean'], std=config_inf['std'])
        ])
        print("Using timm's default transform for inference based on loaded model.")
    except Exception as e:
        print(f"Failed to get timm config for inference model ({e}), defining manually (ImageNet defaults).")
        current_inference_transform = transforms.Compose([
            transforms.Resize(img_size_inf, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(img_size_inf),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    print(f"Inference Transform: {current_inference_transform}")

    # DataLoader for Test Set
    print(f"Loading test data from: {test_csv_path_inf} with image base: {image_base_dir_inf}")
    try:
        test_dataset_inf = TestImageDataset(csv_path=test_csv_path_inf,
                                        img_dir_root=image_base_dir_inf,
                                        transform=current_inference_transform)
        test_loader_inf = DataLoader(test_dataset_inf,
                                 batch_size=batch_size_inf,
                                 shuffle=False,
                                 num_workers=num_workers_inf,
                                 pin_memory=True)
        print(f"Found {len(test_dataset_inf)} images for testing.")
    except FileNotFoundError:
        print(f"Error: Test CSV file not found at {test_csv_path_inf}. Cannot run inference.")
        return
    except Exception as e:
        print(f"Error creating test dataset/loader for inference: {e}")
        return

    # Prediction
    all_preds_indices = []
    all_filenames = []
    with torch.no_grad():
        for inputs, filenames_batch in tqdm(test_loader_inf, desc="Predicting"):
            inputs = inputs.to(device_inf)
            outputs = model_inf(inputs)
            _, predicted_indices = torch.max(outputs, 1)
            all_preds_indices.extend(predicted_indices.cpu().numpy())
            all_filenames.extend(list(filenames_batch))

    predicted_class_names_inf = [class_names_inf[idx] for idx in all_preds_indices]

    # Create submission DataFrame
    original_test_df = pd.read_csv(test_csv_path_inf)
    csv_img_path_col = test_dataset_inf.img_path_column

    prediction_map = {
        os.path.normpath(p): label
        for p, label in zip(all_filenames, predicted_class_names_inf)
    }
    mapped_labels = original_test_df[csv_img_path_col].apply(
        lambda x: prediction_map.get(os.path.normpath(x))
    )

    submission_df_final = pd.DataFrame()
    submission_df_final['ID'] = original_test_df[csv_img_path_col].apply(
        lambda x: os.path.splitext(os.path.basename(x))[0]
    )
    submission_df_final['rock_type'] = mapped_labels

    if submission_df_final['rock_type'].isnull().any():
        num_null = submission_df_final['rock_type'].isnull().sum()
        print(f"Warning: {num_null} images in the CSV did not get a prediction. Check for mismatches or missing images.")
        # Example: Fill with a default class if needed (e.g., the most frequent one or 'Etc')
        # default_class_for_nan = class_names_inf[0] # Or any other logic
        # submission_df_final['rock_type'].fillna(default_class_for_nan, inplace=True)
        # print(f"Filled {num_null} NaNs with '{default_class_for_nan}'.")


    os.makedirs(os.path.dirname(output_csv_path_inf), exist_ok=True)
    submission_df_final.to_csv(output_csv_path_inf, index=False)
    print(f"Inference complete. Predictions saved to: {output_csv_path_inf}")
    print(f"Sample predictions:\n{submission_df_final.head()}")
    print(f"--- Inference Finished ---")


# --- 10. 학습 및 검증 실행 (그리고 추론) ---
if __name__ == "__main__":
    print(f"\nStarting training process with {MODEL_NAME}...")
    print(f"Batch Size: {BATCH_SIZE}. If OOM errors, reduce it.")
    
    history = train_and_validate_best_f1(
        model=model_train, # Use the globally defined model for training
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        epochs=EPOCHS,
        device=DEVICE,
        num_classes=NUM_CLASSES,
        save_dir=SAVE_DIR,
        model_name_base=MODEL_ARTIFACT_NAME_BASE,
        top_k=5,
        gradient_clipping=GRADIENT_CLIPPING,
        lr_scheduler=lr_scheduler,
        warmup_epochs=WARMUP_EPOCHS,
        base_lr=BASE_LR
    )

    print("\nTraining finished.")
    if history['val_macro_f1_scores']:
        best_f1_overall = max(history['val_macro_f1_scores'])
        best_epoch_overall = history['val_macro_f1_scores'].index(best_f1_overall) + 1
        print(f"  Overall Best Validation Macro F1 in history: {best_f1_overall:.4f}")
        print(f"  Achieved at Epoch: {best_epoch_overall}")
    else:
        print("No validation scores recorded in history.")

    # --- 자동 추론 실행 ---
    if history.get('best_model_path') and os.path.exists(history['best_model_path']):
        best_model_for_inference = history['best_model_path']
        # Extract F1 score from filename for output CSV name
        try:
            f1_score_from_filename = float(os.path.basename(best_model_for_inference).split('_f1_')[1].replace('.pth',''))
            output_csv_filename = f"submission_{MODEL_ARTIFACT_NAME_BASE}_f1_{f1_score_from_filename:.4f}.csv"
        except: # Fallback if filename parsing fails
            output_csv_filename = f"submission_{MODEL_ARTIFACT_NAME_BASE}_best.csv"
        
        final_output_csv_path = os.path.join(SAVE_DIR, output_csv_filename) # Save submission in the model save directory

        if not CLASS_NAMES:
             print("ERROR: CLASS_NAMES are not set. Cannot run inference. Please define them correctly.")
        else:
            run_inference(
                model_path=best_model_for_inference,
                model_architecture_name=MODEL_NAME, # Use the same architecture as training
                num_classes_inf=NUM_CLASSES,
                img_size_inf=IMG_SIZE,
                device_inf=DEVICE,
                test_csv_path_inf=TEST_CSV_PATH,
                image_base_dir_inf=IMAGE_BASE_DIR,
                output_csv_path_inf=final_output_csv_path,
                class_names_inf=CLASS_NAMES,
                batch_size_inf=BATCH_SIZE * 2, # Inference can often use larger batch size
                num_workers_inf=NUM_WORKERS
            )
    else:
        print("\nNo best model path found or model file does not exist. Skipping inference.")

  from .autonotebook import tqdm as notebook_tqdm


Loading model: maxvit_xlarge_tf_384.in21k_ft_in1k
Model classifier adapted for 7 classes.
Using timm's default validation transform config for maxvit_xlarge_tf_384.in21k_ft_in1k.
Timm default input size: (3, 384, 384)
Timm default mean: (0.5, 0.5, 0.5), std: (0.5, 0.5, 0.5)
Loading datasets from: /home/metaai2/workspace/limseunghwan/open/train and /home/metaai2/workspace/limseunghwan/open/val
Found 342015 training images and 38005 validation images.
Using predefined CLASS_NAMES: ['Andesite', 'Basalt', 'Etc', 'Gneiss', 'Granite', 'Mud_Sandstone', 'Weathered_Rock']
Calculated Focal Loss alpha (normalized): [0.36379087 0.59434706 1.         0.21558282 0.17148152 0.17810482
 0.4287038 ]
Using Focal Loss with calculated alpha.

Starting training process with maxvit_xlarge_tf_384.in21k_ft_in1k...
Batch Size: 2. If OOM errors, reduce it.
학습 시작: 총 50 에폭, Device: cuda
Top-5 모델 저장 디렉토리: ./saved_models_maxvit_xlarge_tf_384_384
평가 기준: Validation Macro F1 Score


                                                                                                   

KeyboardInterrupt: 

In [None]:
# MaxViT_inference.py
import torch
import torch.nn as nn # Not strictly needed for inference if model is loaded, but good practice
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import timm
import pandas as pd
from PIL import Image
import os
from tqdm import tqdm
import numpy as np # For potential use, e.g. if needing to map class indices to names manually

# --- 1. Configuration (Adjust these paths and parameters) ---
# Model and Data Parameters (should match how the loaded model was trained)
MODEL_ARCHITECTURE_NAME = 'maxvit_xlarge_tf_384.in21k_ft_in1k' # Architecture of your saved model
NUM_CLASSES = 7        # Number of classes your model was trained on
IMG_SIZE = 384         # Image size your model was trained with

# --- Paths ---
# TODO: IMPORTANT! Update this to the path of YOUR TRAINED MaxViT model .pth file
SAVED_MODEL_PATH = './saved_models_maxvit_xlarge_384/maxvit_xlarge_tf_384_epoch15_f1_0.9192.pth' # EXAMPLE PATH!
TEST_CSV_PATH = r'/home/metaai2/workspace/limseunghwan/open/test.csv'    # Path to your test.csv
IMAGE_BASE_DIR = r'/home/metaai2/workspace/limseunghwan/open'           # Base directory for images in test.csv
OUTPUT_CSV_DIR = './submissions_maxvit/maxvit_xlarge_tf_384_epoch15_f1_0.9192' # Directory to save the output CSV
# OUTPUT_CSV_PATH will be generated dynamically based on model name.

# --- Inference Parameters ---
BATCH_SIZE_INFERENCE = 4 # Adjust based on your GPU memory (MaxViT XLarge is demanding)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_WORKERS_INFERENCE = os.cpu_count() // 2 if os.cpu_count() else 4

# --- Class Names (CRITICAL!) ---
# TODO: ***매우 중요*** 당신의 데이터셋에 맞는 실제 클래스 이름을 순서대로 정의하세요.
#       이 순서는 모델이 학습될 때 클래스에 할당된 인덱스와 일치해야 합니다.
#       (예: train_dataset.classes 에서 가져온 순서)
CLASS_NAMES = ['Andesite', 'Basalt', 'Etc', 'Gneiss', 'Granite', 'Mud_Sandstone', 'Weathered_Rock'] # 예시 값, 반드시 수정!

if len(CLASS_NAMES) != NUM_CLASSES:
    raise ValueError(f"Number of CLASS_NAMES ({len(CLASS_NAMES)}) does not match NUM_CLASSES ({NUM_CLASSES}). Please check your configuration.")

# --- 2. Model Loading ---
print(f"Loading model architecture: {MODEL_ARCHITECTURE_NAME}")
# Create the model structure (pretrained=False because we load our own weights)
model = timm.create_model(
    MODEL_ARCHITECTURE_NAME,
    pretrained=False,
    num_classes=NUM_CLASSES,
    img_size=IMG_SIZE  # Ensure img_size is passed if model supports/requires it at creation
)

print(f"Loading trained weights from: {SAVED_MODEL_PATH}")
if not os.path.exists(SAVED_MODEL_PATH):
    print(f"ERROR: Model weights file not found at {SAVED_MODEL_PATH}. Please check the path.")
    exit()

try:
    # Load the state dictionary
    model.load_state_dict(torch.load(SAVED_MODEL_PATH, map_location=DEVICE))
    print(f"Successfully loaded model weights onto {DEVICE}.")
except Exception as e:
    print(f"ERROR: Could not load model weights: {e}")
    exit()

model = model.to(DEVICE)
model.eval() # Set the model to evaluation mode
print("Model ready for inference.")

# --- 3. Data Transformations for Inference ---
# Re-create the validation/inference transform used during training.
# It's best if this matches exactly what was used for the validation set.
print("Defining inference transform...")
try:
    # Attempt to use timm's recommended settings for the loaded model
    # Pass the loaded model instance to resolve_data_config
    config = timm.data.resolve_data_config({}, model=model)
    # Override input_size to be sure it matches your training
    config['input_size'] = (3, IMG_SIZE, IMG_SIZE) # (C, H, W)
    
    inference_transform = timm.data.create_transform(**config, is_training=False)
    print(f"Using timm's default transform based on loaded model config: Mean={config['mean']}, Std={config['std']}")
except Exception as e:
    print(f"Failed to get timm config for inference model ({e}). Defining transform manually using ImageNet defaults.")
    inference_transform = transforms.Compose([
        transforms.Resize(IMG_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet defaults
    ])
print(f"Inference Transform: {inference_transform}")


# --- 4. Custom Dataset for Test Images ---
class TestImageDataset(Dataset):
    def __init__(self, csv_path, img_dir_root, transform=None):
        self.data_frame = pd.read_csv(csv_path)
        self.img_dir_root = img_dir_root
        self.transform = transform
        # Determine the image path column name
        if 'img_path' not in self.data_frame.columns:
            print(f"Warning: 'img_path' column not found in {csv_path}. Assuming first column ('{self.data_frame.columns[0]}') contains image paths.")
            self.img_path_column = self.data_frame.columns[0]
        else:
            self.img_path_column = 'img_path'
        print(f"Using column '{self.img_path_column}' from CSV for image paths.")

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        relative_img_path = self.data_frame.loc[idx, self.img_path_column]
        full_img_path = os.path.join(self.img_dir_root, relative_img_path)

        try:
            image = Image.open(full_img_path).convert('RGB')
        except FileNotFoundError:
            print(f"ERROR: Image not found at {full_img_path}. Check IMAGE_BASE_DIR and CSV paths.")
            # To prevent crashing the DataLoader, one might return a placeholder or skip.
            # For now, raising an error is fine to highlight the issue.
            raise
        except Exception as e:
            print(f"ERROR: Could not open image {full_img_path}: {e}")
            raise

        if self.transform:
            image = self.transform(image)

        return image, relative_img_path # Return image and its original path for mapping

# --- 5. DataLoader for Test Set ---
print(f"Loading test data from CSV: {TEST_CSV_PATH}")
print(f"Image base directory: {IMAGE_BASE_DIR}")
try:
    test_dataset = TestImageDataset(csv_path=TEST_CSV_PATH,
                                    img_dir_root=IMAGE_BASE_DIR,
                                    transform=inference_transform)
    if len(test_dataset) == 0:
        print(f"ERROR: No images found or loaded from {TEST_CSV_PATH}. Please check the CSV and image paths.")
        exit()
        
    test_loader = DataLoader(test_dataset,
                             batch_size=BATCH_SIZE_INFERENCE,
                             shuffle=False, # No need to shuffle for inference
                             num_workers=NUM_WORKERS_INFERENCE,
                             pin_memory=True)
    print(f"Successfully created DataLoader with {len(test_dataset)} images for testing.")
except FileNotFoundError:
    print(f"ERROR: Test CSV file not found at {TEST_CSV_PATH}. Please check the path.")
    exit()
except Exception as e:
    print(f"ERROR: Could not create test dataset or DataLoader: {e}")
    exit()

# --- 6. Prediction Function ---
def predict_on_test_data(model_to_predict, loader, device, class_names_map):
    model_to_predict.eval() # Ensure model is in eval mode
    all_predicted_indices = []
    all_original_filenames = [] # To store the original filenames/paths from CSV

    with torch.no_grad(): # Disable gradient calculations for inference
        for images_batch, filenames_batch in tqdm(loader, desc="Predicting"):
            images_batch = images_batch.to(device)
            
            outputs = model_to_predict(images_batch)
            _, predicted_indices_batch = torch.max(outputs, 1) # Get the index of the max log-probability

            all_predicted_indices.extend(predicted_indices_batch.cpu().numpy())
            all_original_filenames.extend(list(filenames_batch)) # filenames_batch is a tuple of strings

    # Convert predicted indices to actual class names
    predicted_class_names_list = [class_names_map[idx] for idx in all_predicted_indices]
    
    return all_original_filenames, predicted_class_names_list

# --- 7. Run Inference and Save Results ---
if __name__ == "__main__":
    print("\nStarting inference process...")

    # Generate output CSV path
    os.makedirs(OUTPUT_CSV_DIR, exist_ok=True)
    model_filename_base = os.path.splitext(os.path.basename(SAVED_MODEL_PATH))[0]
    output_csv_filename = f"submission_{model_filename_base}.csv"
    final_output_csv_path = os.path.join(OUTPUT_CSV_DIR, output_csv_filename)

    original_paths_from_loader, string_predictions = predict_on_test_data(
        model,
        test_loader,
        DEVICE,
        CLASS_NAMES
    )

    # Create a DataFrame for submission
    # Ensure the order of predictions matches the order of images in the original test.csv
    
    # Load the original test CSV to get the correct order and 'ID' column format if needed
    original_test_df = pd.read_csv(TEST_CSV_PATH)
    
    # Create a mapping from the image paths returned by DataLoader to their predictions
    # Use os.path.normpath to handle potential path separator differences (e.g. / vs \)
    prediction_map = {
        os.path.normpath(path): pred_label 
        for path, pred_label in zip(original_paths_from_loader, string_predictions)
    }

    # Map predictions back to the original CSV's image paths
    # This ensures that even if DataLoader reorders (it shouldn't with shuffle=False),
    # or if there are missing images handled gracefully, the mapping is correct.
    csv_img_path_col = test_dataset.img_path_column # Get the column name used for image paths
    
    mapped_predictions = original_test_df[csv_img_path_col].apply(
        lambda x: prediction_map.get(os.path.normpath(x))
    )

    # Create the submission DataFrame
    submission_df = pd.DataFrame()
    # The 'ID' column in submission usually requires the filename without extension
    submission_df['ID'] = original_test_df[csv_img_path_col].apply(
        lambda x: os.path.splitext(os.path.basename(x))[0]
    )
    submission_df['rock_type'] = mapped_predictions # Use the mapped predictions

    # Check for any images that didn't get a prediction (should not happen if all images load)
    if submission_df['rock_type'].isnull().any():
        num_null = submission_df['rock_type'].isnull().sum()
        print(f"WARNING: {num_null} images from the CSV did not receive a prediction. "
              "This might be due to missing image files or errors during loading. "
              "Consider filling NaNs if appropriate for the submission.")
        # Example: Fill with a default class if needed (e.g., the most frequent one or 'Etc')
        # default_class_for_nan = CLASS_NAMES[0] # Or any other logic
        # submission_df['rock_type'].fillna(default_class_for_nan, inplace=True)
        # print(f"Filled {num_null} NaNs with '{default_class_for_nan}'.")


    # Save the submission file
    submission_df.to_csv(final_output_csv_path, index=False)
    print(f"\nInference complete. Predictions saved to: {final_output_csv_path}")
    print(f"Sample of the submission file:\n{submission_df.head()}")

Loading model architecture: maxvit_xlarge_tf_384.in21k_ft_in1k
Loading trained weights from: ./saved_models_maxvit_xlarge_384/maxvit_xlarge_tf_384_epoch15_f1_0.9192.pth
Successfully loaded model weights onto cuda.
Model ready for inference.
Defining inference transform...
Using timm's default transform based on loaded model config: Mean=(0.5, 0.5, 0.5), Std=(0.5, 0.5, 0.5)
Inference Transform: Compose(
    Resize(size=(384, 384), interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(384, 384))
    MaybeToTensor()
    Normalize(mean=tensor([0.5000, 0.5000, 0.5000]), std=tensor([0.5000, 0.5000, 0.5000]))
)
Loading test data from CSV: /home/metaai2/workspace/limseunghwan/open/test.csv
Image base directory: /home/metaai2/workspace/limseunghwan/open
Using column 'img_path' from CSV for image paths.
Successfully created DataLoader with 95006 images for testing.

Starting inference process...


Predicting:   8%|▊         | 1841/23752 [03:07<37:15,  9.80it/s]


KeyboardInterrupt: 

: 