# One-Class Soil Classification: Inference Phase

- **Author:** Siddhant Bhardwaj
- **Team Name:** Siddhant Bhardwaj
- **Team Members:** Siddhant Bhardwaj, Sivadhanushya
- **Leaderboard Rank:** 36

In [None]:
"""

Author: Siddhant Bhardwaj
Team Name: Siddhant Bhardwaj
Team Members: Siddhant Bhardwaj, Sivadhanushya
Leaderboard Rank: 36
"""

In [None]:
# --- 1. Setup and Configuration ---
print("--- [1] Initializing: Importing Libraries and Setting Configuration ---")
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt # For displaying sample test images (optional)
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics.pairwise import cosine_similarity

# Configuration. The CLIP model name and the percentile are baked into the
# artefact file names below, so they have to match the training notebook.
BASE_PATH = "/kaggle/input/soil-classification-part-2/soil_competition-2025"
TEST_IDS_PATH = BASE_PATH + "/test_ids.csv"
TEST_IMG_PATH = BASE_PATH + "/test"
CLIP_MODEL_NAME = "openai/clip-vit-base-patch16"  # must equal the training value
SIMILARITY_THRESHOLD_PERCENTILE = 5.5  # must equal the training value
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# File-name-safe model tag: keep the part after the last "/" and turn
# hyphens into underscores.
MODEL_NAME_FOR_FILE = CLIP_MODEL_NAME.rsplit("/", 1)[-1].replace("-", "_")

# Inputs produced by the training phase (must be reachable at these paths).
PROTOTYPE_FILE = "/kaggle/working/soil_prototype_{}.npy".format(MODEL_NAME_FOR_FILE)
THRESHOLD_FILE = "/kaggle/working/similarity_threshold_{}_p{:.1f}.txt".format(
    MODEL_NAME_FOR_FILE, SIMILARITY_THRESHOLD_PERCENTILE
)

# Output of this notebook.
SUBMISSION_FILENAME = "/kaggle/working/submission_{}_inference_p{:.1f}.csv".format(
    MODEL_NAME_FOR_FILE, SIMILARITY_THRESHOLD_PERCENTILE
)
NUM_SAMPLE_TEST_IMAGES_TO_DISPLAY = 3

# Echo the effective settings so a mismatch with training is easy to spot.
print(
    f"Device: {DEVICE}, CLIP Model: {CLIP_MODEL_NAME}",
    f"Using prototype: {PROTOTYPE_FILE}",
    f"Using threshold file: {THRESHOLD_FILE}",
    sep="\n",
)

# In a Jupyter Notebook Cell 3 (Code)
# --- 2. Load CLIP Model & Processor ---
print("\n--- [2] Loading CLIP Model ---")
try:
    # Model weights go to DEVICE; the processor handles image preprocessing.
    clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(DEVICE)
    clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
    clip_model.eval()  # evaluation mode for inference
except Exception as load_error:
    # Report the failure, then let it propagate: nothing below can run
    # without the model.
    print(f"Error loading CLIP model: {load_error}")
    raise
else:
    print("CLIP model and processor loaded.")

# In a Jupyter Notebook Cell 4 (Code)
# --- 3. Load Pre-calculated Prototype and Threshold ---
print("\n--- [3] Loading Prototype and Threshold ---")
try:
    # The prototype is a saved NumPy array; the threshold file holds a single
    # number stored as text.
    soil_prototype = np.load(PROTOTYPE_FILE)
    with open(THRESHOLD_FILE) as threshold_fh:
        threshold_text = threshold_fh.read()
    similarity_threshold = float(threshold_text)
    print(f"Soil prototype loaded. Shape: {soil_prototype.shape}")
    print(f"Similarity threshold loaded: {similarity_threshold:.6f}")
except Exception as load_error:
    # Both artefacts are required for classification, so report and re-raise.
    print(f"Error loading prototype or threshold: {load_error}")
    raise

# In a Jupyter Notebook Cell 5 (Code)
# --- 4. Load Test Data & Display Samples ---
print("\n--- [4] Loading Test Data ---")
try:
    test_df = pd.read_csv(TEST_IDS_PATH)
    test_image_ids_original = test_df["image_id"].tolist()
    # An empty ID list means there is nothing to predict; treat it as an error.
    if len(test_image_ids_original) == 0:
        raise ValueError("No test IDs.")
    print(f"Loaded {len(test_image_ids_original)} test image IDs.")
except Exception as load_error:
    print(f"Error loading test data: {load_error}")
    raise

# Placeholder only: no images are rendered here, the message points to the
# plotting code kept in the main script.
print("Displaying sample test images (refer to main script's plotting for full code)...")


# In a Jupyter Notebook Cell 6 (Code)
# --- 5. CLIP Embedding Function and Extraction for Test Data ---
print("\n--- [5] Extracting Test Image Embeddings ---")
# (Function definition for get_clip_image_embeddings - same as in the training notebook)
def get_clip_image_embeddings(image_ids_list, img_directory, model, processor, device, model_name_desc):
    """Compute one CLIP image embedding per readable image.

    Images are processed one at a time. An image whose file does not exist,
    or whose loading/encoding raises, is skipped (best effort) and is absent
    from both return values.

    Args:
        image_ids_list: Image file names, relative to ``img_directory``.
        img_directory: Directory that contains the image files.
        model: CLIP model exposing ``get_image_features``; put into eval mode.
        processor: CLIP processor used to turn a PIL image into model inputs.
        device: Device the processed inputs are moved to before inference.
        model_name_desc: Label shown in the progress-bar description.

    Returns:
        A tuple ``(embeddings, valid_ids)``. ``embeddings`` is a NumPy array
        with one flattened feature vector per successfully processed image,
        in processing order, exactly as returned by the model (no
        normalization is applied here). ``valid_ids`` lists the matching
        image IDs. When nothing could be processed the result is
        ``(np.array([]), [])``.
    """
    embeddings_list = []
    valid_ids_processed = []
    missing_ids = []
    model.eval()
    for img_id in tqdm(image_ids_list, desc=f"Extracting Embeddings ({model_name_desc})"):
        img_path = os.path.join(img_directory, img_id)
        if not os.path.exists(img_path):
            # Remember the ID so the skip can be reported after the loop.
            missing_ids.append(img_id)
            continue
        try:
            # Image.open keeps the file handle open until the image is closed;
            # convert() returns an independent copy, so the handle can be
            # released right away instead of leaking one per image.
            with Image.open(img_path) as img_file:
                image = img_file.convert("RGB")
            inputs = processor(text=None, images=image, return_tensors="pt", padding=True).to(device)
            with torch.no_grad():
                image_features = model.get_image_features(pixel_values=inputs.pixel_values)
            embeddings_list.append(image_features.cpu().numpy().flatten())
            valid_ids_processed.append(img_id)
        except Exception as e:
            # Deliberate best effort: one unreadable image must not abort the run.
            print(f"Warning: Error processing '{img_id}': {e}. Skipping.")
            continue
    if missing_ids:
        # Missing files used to be dropped without a trace; make it visible.
        print(
            f"Warning: {len(missing_ids)} image file(s) not found in "
            f"'{img_directory}' and skipped (first: '{missing_ids[0]}')."
        )
    if not embeddings_list:
        return np.array([]), []
    return np.array(embeddings_list), valid_ids_processed

# Start from empty results so both names exist even when there are no IDs.
test_clip_embeddings = np.array([])
test_image_ids_processed_for_submission = []
if test_image_ids_original:
    test_clip_embeddings, test_image_ids_processed_for_submission = get_clip_image_embeddings(
        image_ids_list=test_image_ids_original,
        img_directory=TEST_IMG_PATH,
        model=clip_model,
        processor=clip_processor,
        device=DEVICE,
        model_name_desc=MODEL_NAME_FOR_FILE,
    )
print(f"Test embeddings shape: {test_clip_embeddings.shape}")

# In a Jupyter Notebook Cell 7 (Code)
# --- 6. Classification and Submission File Generation ---
print("\n--- [6] Classification and Submission ---")

# Only IDs in test_image_ids_processed_for_submission get a row in the CSV.
# Any test ID that is not in that list is dropped, so report the gap rather
# than writing a short submission silently.
processed_id_set = set(test_image_ids_processed_for_submission)
unprocessed_test_ids = [
    img_id for img_id in test_image_ids_original if img_id not in processed_id_set
]
if unprocessed_test_ids:
    print(
        f"Warning: {len(unprocessed_test_ids)} of {len(test_image_ids_original)} "
        "test images were not embedded and will be absent from the submission."
    )

test_labels = []
if test_clip_embeddings.shape[0] > 0 and len(test_image_ids_processed_for_submission) > 0:
    print("Classifying test images...")
    # Cosine similarity of every test embedding to the single prototype vector.
    similarities_to_prototype_test = cosine_similarity(
        test_clip_embeddings, soil_prototype.reshape(1, -1)
    ).flatten()
    # Label 1 when the similarity reaches the threshold, otherwise 0.
    test_labels = (similarities_to_prototype_test >= similarity_threshold).astype(int).tolist()
else:
    print("No test embeddings to classify.")

submission_df = pd.DataFrame({
    'image_id': test_image_ids_processed_for_submission,
    'label': test_labels,
})

if not submission_df.empty:
    print("\nPredicted label distribution in submission:")
    print(submission_df['label'].value_counts(normalize=True).to_string())
else:
    print("\nSubmission DataFrame is empty.")

# The file is written even when empty, so downstream steps always find it.
submission_df.to_csv(SUBMISSION_FILENAME, index=False)
print(f"\nSubmission file '{SUBMISSION_FILENAME}' created in /kaggle/working/.")
print("--- Inference Phase Complete ---")