# One-Class Soil Classification: Training Phase (CLIP Prototype Method)

- **Author:** Siddhant Bhardwaj
- **Team Name:** Siddhant Bhardwaj
- **Team Members:** Siddhant Bhardwaj, Sivadhanushya
- **Leaderboard Rank:** 36

In [None]:
"""

Author: Siddhant Bhardwaj
Team Name: Siddhant Bhardwaj
Team Members: Siddhant, Sivadhanushya
Leaderboard Rank: 36
"""

In [None]:
# In a Jupyter Notebook Cell 2 (Code)
# --- 1. Setup and Configuration ---
# Status banner for this cell; the library imports and configuration
# constants follow directly below.
print("--- [1] Initializing: Importing Libraries and Setting Configuration ---")
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
from sklearn.metrics.pairwise import cosine_similarity
from IPython.display import display, HTML

# Configuration (should match your other files).
# The artefact file names below embed both the model tag and the percentile,
# so any companion code that reloads them presumably needs identical values.
CLIP_MODEL_NAME = "openai/clip-vit-base-patch16"
# Percentile of the training similarity distribution taken as the threshold
# (computed with np.percentile in step 6).
SIMILARITY_THRESHOLD_PERCENTILE = 5.5

# Competition data locations.
BASE_PATH = '/kaggle/input/soil-classification-part-2/soil_competition-2025'
TRAIN_LABELS_PATH = f'{BASE_PATH}/train_labels.csv'
TRAIN_IMG_PATH = f'{BASE_PATH}/train'

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# File-name-safe model tag: the last '/'-separated component of the model id
# with '-' turned into '_' ("openai/clip-vit-base-patch16" -> "clip_vit_base_patch16").
MODEL_NAME_FOR_FILE = "_".join(CLIP_MODEL_NAME.rpartition('/')[2].split("-"))

# Output files from this "training" phase
PROTOTYPE_FILE = f'soil_prototype_{MODEL_NAME_FOR_FILE}.npy'
THRESHOLD_FILE = f'similarity_threshold_{MODEL_NAME_FOR_FILE}_p{SIMILARITY_THRESHOLD_PERCENTILE:.1f}.txt'
TRAINING_SIMILARITIES_PLOT_FILE = f'training_similarities_dist_{MODEL_NAME_FOR_FILE}_p{SIMILARITY_THRESHOLD_PERCENTILE:.1f}.png'

print(f"Device: {DEVICE}, CLIP Model: {CLIP_MODEL_NAME}, Percentile: {SIMILARITY_THRESHOLD_PERCENTILE}%")
print(f"Will save prototype to: {PROTOTYPE_FILE}")
print(f"Will save threshold to: {THRESHOLD_FILE}")

# In a Jupyter Notebook Cell 3 (Code)
# --- 2. Load CLIP Model & Processor ---
# Loads the pretrained CLIP model named by CLIP_MODEL_NAME onto DEVICE together
# with its matching processor. The model is only used for feature extraction
# in this notebook, so it is switched to eval mode.
print("\n--- [2] Loading CLIP Model ---")
try:
    clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(DEVICE)
    clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
    clip_model.eval()
    print("CLIP model and processor loaded.")
except Exception as e:
    print(f"Error loading CLIP model: {e}")
    # Bare `raise` re-raises the active exception with its original traceback.
    raise

# In a Jupyter Notebook Cell 4 (Code)
# --- 3. Load Training Data & Display Samples ---
# Reads the training image IDs from the labels CSV and shows a few of the
# corresponding images as a visual sanity check.
print("\n--- [3] Loading Training Data ---")
try:
    train_df = pd.read_csv(TRAIN_LABELS_PATH)
    # Drop blank IDs (read as NaN) and force str so every entry is a usable
    # file name for os.path.join during embedding extraction.
    train_image_ids_original = train_df['image_id'].dropna().astype(str).tolist()
    if not train_image_ids_original:
        raise ValueError("No training IDs.")
    print(f"Loaded {len(train_image_ids_original)} training image IDs.")
except Exception as e:
    print(f"Error loading training data: {e}")
    raise  # bare raise keeps the original traceback

# Show the first few training images in a grid.
N_SAMPLE_IMAGES = 8
SAMPLE_GRID_COLS = 4
print("Displaying sample training images...")
sample_ids = train_image_ids_original[:N_SAMPLE_IMAGES]
n_rows = -(-len(sample_ids) // SAMPLE_GRID_COLS)  # ceiling division; >= 1 because the list is non-empty
fig, axes = plt.subplots(
    n_rows, SAMPLE_GRID_COLS, figsize=(3 * SAMPLE_GRID_COLS, 3 * n_rows), squeeze=False
)
for ax in axes.flat:
    ax.axis('off')  # also blanks any unused trailing cells of the grid
for ax, img_id in zip(axes.flat, sample_ids):
    sample_path = os.path.join(TRAIN_IMG_PATH, img_id)
    try:
        # Context manager closes the file handle as soon as the pixels are read.
        with Image.open(sample_path) as img:
            ax.imshow(img.convert("RGB"))
        ax.set_title(img_id, fontsize=9)
    except OSError as err:  # missing or unreadable file: report and keep going
        print(f"Warning: could not display '{img_id}': {err}")
plt.tight_layout()
plt.show()

# In a Jupyter Notebook Cell 5 (Code)
# --- 4. CLIP Embedding Function and Extraction for Training Data ---
print("\n--- [4] Extracting Training Image Embeddings ---")
# (Function definition for get_clip_image_embeddings - same as in the main script)
def get_clip_image_embeddings(image_ids_list, img_directory, model, processor, device, model_name_desc):
    """Extract one CLIP image embedding per readable image.

    Args:
        image_ids_list: Image file names, relative to ``img_directory``.
        img_directory: Directory that contains the image files.
        model: CLIP model exposing ``get_image_features(pixel_values=...)``.
        processor: CLIP processor that turns a PIL image into ``pixel_values``.
        device: Torch device the model lives on; processed inputs are moved there.
        model_name_desc: Label shown in the progress bar.

    Returns:
        Tuple ``(embeddings, valid_ids)``. ``embeddings`` stacks one flattened
        feature vector per successfully processed image; row ``i`` belongs to
        ``valid_ids[i]``. Missing files and images that fail to load/encode are
        skipped with a warning. If nothing could be processed, returns
        ``(np.array([]), [])``.
    """
    embeddings_list = []
    valid_ids_processed = []
    missing_ids = []
    model.eval()
    for img_id in tqdm(image_ids_list, desc=f"Extracting Embeddings ({model_name_desc})"):
        # str() guards against non-string IDs (e.g. NaN from a blank CSV cell),
        # which would otherwise make os.path.join raise outside the try below.
        img_path = os.path.join(img_directory, str(img_id))
        if not os.path.exists(img_path):
            missing_ids.append(img_id)
            continue
        try:
            # Context manager closes the file handle immediately instead of
            # leaving it open until garbage collection; convert() loads the pixels.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
            # Image-only call: text options such as padding do not apply here.
            inputs = processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                image_features = model.get_image_features(pixel_values=inputs.pixel_values)
            embeddings_list.append(image_features.cpu().numpy().flatten())
            valid_ids_processed.append(img_id)
        except Exception as e:
            # Best-effort: one bad image must not abort the whole extraction.
            print(f"Warning: Error processing '{img_id}': {e}. Skipping.")
    if missing_ids:
        # One summary line instead of per-file spam interleaved with the progress bar.
        print(
            f"Warning: {len(missing_ids)} image file(s) not found in '{img_directory}' "
            f"and skipped (first: '{missing_ids[0]}')."
        )
    if not embeddings_list:
        return np.array([]), []
    return np.array(embeddings_list), valid_ids_processed

# Keep the IDs that were actually embedded (row i of the matrix <-> ID i) so
# skipped images are reported rather than silently dropped.
train_clip_embeddings, train_valid_image_ids = get_clip_image_embeddings(
    train_image_ids_original, TRAIN_IMG_PATH, clip_model, clip_processor, DEVICE, MODEL_NAME_FOR_FILE
)
if train_clip_embeddings.shape[0] == 0:
    print("Error: No training embeddings extracted.")
    raise ValueError("No training embeddings")
n_skipped_train = len(train_image_ids_original) - len(train_valid_image_ids)
if n_skipped_train:
    print(f"Warning: {n_skipped_train} of {len(train_image_ids_original)} training images were skipped.")
print(f"Training embeddings shape: {train_clip_embeddings.shape}")

# In a Jupyter Notebook Cell 6 (Code)
# --- 5. Calculate Soil Prototype & Save ---
# The prototype is the column-wise mean over all training embeddings (one row
# per image), i.e. the centroid of the training set in CLIP feature space.
# It is persisted as a .npy file under PROTOTYPE_FILE.
print("\n--- [5] Calculating and Saving Soil Prototype ---")
soil_prototype = train_clip_embeddings.mean(axis=0)
np.save(PROTOTYPE_FILE, soil_prototype)
print(f"Soil prototype calculated. Shape: {soil_prototype.shape}. Saved to {PROTOTYPE_FILE}")

# In a Jupyter Notebook Cell 7 (Code)
# --- 6. Determine & Save Similarity Threshold ---
print("\n--- [6] Determining and Saving Similarity Threshold ---")
# Score every training embedding against the prototype. cosine_similarity
# returns an (n, 1) matrix, which is flattened to n scores.
prototype_row = soil_prototype.reshape(1, -1)
similarities_to_prototype_train = cosine_similarity(train_clip_embeddings, prototype_row).ravel()

# The threshold is the chosen low percentile of the training scores, so about
# SIMILARITY_THRESHOLD_PERCENTILE % of training images fall below it. It is
# presumably used downstream as the accept/reject cut-off; stored as plain text.
similarity_threshold = np.percentile(similarities_to_prototype_train, SIMILARITY_THRESHOLD_PERCENTILE)
with open(THRESHOLD_FILE, 'w') as threshold_fh:
    threshold_fh.write(str(similarity_threshold))
print(f"Similarity Threshold ({SIMILARITY_THRESHOLD_PERCENTILE}th percentile): {similarity_threshold:.6f}. Saved to {THRESHOLD_FILE}")

# Histogram (+ KDE) of the training scores with the threshold marked; the
# figure is written to disk and then shown inline.
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(similarities_to_prototype_train, kde=True, bins=50, color="forestgreen", ax=ax)
ax.set_title(f'Training Similarities to {MODEL_NAME_FOR_FILE} Prototype (Pctl: {SIMILARITY_THRESHOLD_PERCENTILE}%)', fontsize=12)
ax.set_xlabel('Cosine Similarity')
ax.set_ylabel('Frequency')
ax.axvline(similarity_threshold, color='r', linestyle='dashed', label=f'Threshold: {similarity_threshold:.4f}')
ax.legend()
ax.grid(True, linestyle='--', alpha=0.6)
fig.savefig(TRAINING_SIMILARITIES_PLOT_FILE)
plt.show()
print(f"Training similarities distribution plot saved to {TRAINING_SIMILARITIES_PLOT_FILE}")
print("--- 'Training' Phase Complete: Prototype and Threshold are ready. ---")