In [10]:
import pandas as pd
import torch
import os
import timm
import glob
import numpy as np
import safetensors.torch
from datasets import load_from_disk
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

TEST_DATA_PATH = "processed_bird_test_data"
BASE_CHECKPOINT_DIR = "old_hyperparam_model_checkpoints"
MODEL_NAME = "coatnet_0_rw_224"
OUTPUT_FILENAME = "submission_old_model_fixed.csv"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {DEVICE}")

# Resolve the checkpoint folder: look in the current working directory first,
# then one level up (presumably for when the notebook runs from a subfolder).
# Fail fast if neither location exists.
if not os.path.exists(BASE_CHECKPOINT_DIR):
    _parent_candidate = f"../{BASE_CHECKPOINT_DIR}"
    if not os.path.exists(_parent_candidate):
        raise FileNotFoundError(f"Folder {BASE_CHECKPOINT_DIR} not found")
    BASE_CHECKPOINT_DIR = _parent_candidate

# Collect "checkpoint-<step>" folders that actually contain saved weights,
# in either safetensors or PyTorch .bin format.
checkpoints = glob.glob(os.path.join(BASE_CHECKPOINT_DIR, "checkpoint-*"))
valid_checkpoints = []

print(f"Searching for weights in {BASE_CHECKPOINT_DIR}...")
for ckpt in checkpoints:
    has_weights = any(
        os.path.exists(os.path.join(ckpt, weights_name))
        for weights_name in ("model.safetensors", "pytorch_model.bin")
    )
    if not has_weights:
        continue
    step_suffix = ckpt.split("-")[-1]
    try:
        step_num = int(step_suffix)
    except ValueError:
        # The text after the last "-" is not an integer step; skip this folder.
        continue
    valid_checkpoints.append((step_num, ckpt))

if not valid_checkpoints:
    raise FileNotFoundError("No valid checkpoints found.")

# NOTE: despite the name `best_ckpt_step`, this selects the checkpoint with the
# HIGHEST step number (the most recent save), not the best-scoring one.
valid_checkpoints.sort(key=lambda entry: entry[0], reverse=True)
best_ckpt_step, WEIGHTS_PATH = valid_checkpoints[0]
print(f"Selected checkpoint: {WEIGHTS_PATH}")

print(f"Loading test data from {TEST_DATA_PATH}...")
dataset_raw = load_from_disk(TEST_DATA_PATH)

# `load_from_disk` may return a single Dataset or a dict-like collection of
# splits. Use the "test" split when present, fall back to the only split of a
# single-split collection, and fail loudly otherwise. Passing a multi-split
# collection through would break `column_names`, `len()` and indexing below.
if isinstance(dataset_raw, dict):
    if "test" in dataset_raw:
        test_ds = dataset_raw["test"]
    elif len(dataset_raw) == 1:
        test_ds = next(iter(dataset_raw.values()))
    else:
        raise KeyError(
            f"{TEST_DATA_PATH} contains splits {list(dataset_raw)} but no 'test' split"
        )
else:
    test_ds = dataset_raw

if "id" in test_ds.column_names:
    submission_ids = list(test_ds["id"])
    print(f"Found {len(submission_ids)} IDs in dataset.")
else:
    # NOTE(review): fallback IDs start at 0 - confirm this matches the ID
    # scheme the submission expects.
    print("Warning: 'id' column not found. Creating sequential IDs.")
    submission_ids = list(range(len(test_ds)))

# Deterministic evaluation preprocessing:
# - resize the shorter side to 256;
# - take the central 224x224 crop;
# - convert to a tensor;
# - normalise per channel with the commonly used ImageNet mean/std.
# NOTE(review): this must mirror the validation transform used in training - confirm.
normalize = Normalize(std=[0.229, 0.224, 0.225], mean=[0.485, 0.456, 0.406])
eval_steps = [Resize(256), CenterCrop(224), ToTensor(), normalize]
val_transform = Compose(eval_steps)

class SimpleTestDataset(Dataset):
    """Map-style dataset that yields only the image of each record, as RGB.

    Args:
        hf_dataset: indexable collection whose items expose an ``"image"``
            entry with a ``.convert`` method (presumably a PIL image).
        transform: optional callable applied to the RGB image; skipped when
            falsy.
    """

    def __init__(self, hf_dataset, transform=None):
        self.hf_dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        record = self.hf_dataset[idx]
        rgb = record["image"].convert("RGB")
        if not self.transform:
            return rgb
        return self.transform(rgb)

# Shuffling stays off so batch order matches `submission_ids` row-for-row;
# num_workers=0 loads every batch in the main process.
test_dataset = SimpleTestDataset(test_ds, transform=val_transform)
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    num_workers=0,
    shuffle=False,
)

print(f"Initializing model {MODEL_NAME}...")
# Build the architecture without pretrained weights and with a 200-way head.
# Every parameter is then overwritten from the selected checkpoint below.
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=200)

path_safe = os.path.join(WEIGHTS_PATH, "model.safetensors")
path_bin = os.path.join(WEIGHTS_PATH, "pytorch_model.bin")

if os.path.exists(path_safe):
    print("Loading SafeTensors...")
    state_dict = safetensors.torch.load_file(path_safe, device=DEVICE)
else:
    print("Loading PyTorch Bin...")
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered .bin file cannot execute arbitrary code when loaded.
    state_dict = torch.load(path_bin, map_location=DEVICE, weights_only=True)

# Strict loading (the default) raises on missing or unexpected keys. This
# surfaces architecture or num_classes mismatches instead of silently ignoring them.
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()

# Batched inference: keep the argmax class index per image. Gradient tracking
# is disabled because nothing here is trained.
all_preds = []
print("Running prediction...")

with torch.no_grad():
    for imgs in tqdm(test_loader, desc="Predicting"):
        imgs = imgs.to(DEVICE)
        logits = model(imgs)
        preds = logits.argmax(dim=1).cpu().numpy()
        all_preds.extend(preds)

# Model class indices are 0-based; shift by one for the submission labels
# (presumably the label space is 1-based - confirm against the task spec).
final_preds = np.asarray(all_preds) + 1

# Refuse to write a misaligned submission. Raising (instead of printing) makes
# the cell fail visibly, so a stale CSV with the same name from an earlier run
# cannot be mistaken for fresh output.
if len(final_preds) != len(submission_ids):
    raise ValueError(
        f"Mismatch! IDs: {len(submission_ids)} vs Preds: {len(final_preds)}"
    )

df = pd.DataFrame({
    "id": submission_ids,
    "label": final_preds,
})

# Sorting only reorders rows; each (id, label) pair stays together.
# NOTE(review): if ids are strings this order is lexicographic ("10" < "2") - confirm.
df = df.sort_values(by="id")

df.to_csv(OUTPUT_FILENAME, index=False)
print(f"Saved {OUTPUT_FILENAME}")
print(df.head())

Using device: cpu
Searching for weights in old_hyperparam_model_checkpoints...
Selected checkpoint: old_hyperparam_model_checkpoints\checkpoint-1785
Loading test data form processed_bird_test_data...
Found 4000 IDs in dataset.
Initializing model coatnet_0_rw_224...
Loading SafeTensors...
Running prediction...


Predicting: 100%|████████████████████████████████████████████████████████████████████| 125/125 [05:55<00:00,  2.85s/it]

Saved submission_old_model_fixed.csv
   id  label
0   1     69
1   2     60
2   3     52
3   4     12
4   5     74



