Getting the dataset together
After data cleaning and validation, the records should share a consistent format and be as complete as this kind of data allows.


In [11]:
!python -m pip show open_clip_torch

Name: open_clip_torch
Version: 3.2.0
Summary: Open reproduction of consastive language-image pretraining (CLIP) and related.
Home-page: https://github.com/mlfoundations/open_clip
Author: Gabriel Ilharco, Mitchell Wortsman, Romain Beaumont
Author-email: Ross Wightman <ross@huggingface.co>
License: MIT
Location: c:\Users\Sascha\Desktop\LegaSea Model\fossil_env_311\Lib\site-packages
Requires: ftfy, huggingface-hub, regex, safetensors, timm, torch, torchvision, tqdm
Required-by: 


In [1]:
import pandas as pd
import requests
import os
import re
import csv
from tqdm import tqdm

# === CONFIGURATION ===
# Paths are assembled with os.path.join instead of being typed with "\" separators:
# inside a normal string literal "\S" and "\i" are invalid escape sequences (newer
# Python versions warn about them), and hard-coded backslashes only work on Windows.
DATASET_DIR = "Dataset"                                   # Folder holding all inputs/outputs
EXCEL_FILE = os.path.join(DATASET_DIR, "SelectiveData.xlsx")   # Input Excel file
IMAGE_URL_COLUMN = "images0"        # Column with image URLs
METADATA_COLUMNS = ["reviewer_notes", "category", "type_code", "ShortOrLong"]  # Other columns you want in mapping
OUTPUT_FOLDER = os.path.join(DATASET_DIR, "images")      # Folder to save images
OUTPUT_MAPPING_FILE = os.path.join(DATASET_DIR, "images_mapping.csv")  # Output mapping file
# =====================
# Cell-level cleaner used to drop HTML markup from the spreadsheet values
def strip_tags(text):
    """Return ``text`` with every ``<...>`` span removed.

    Anything that is not a ``str`` (for example a missing cell) is handed back
    untouched, so the function can be applied to every cell of a DataFrame.
    """
    if not isinstance(text, str):
        return text
    return re.sub(r"<[^>]*>", "", text)

# Create output folder
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Read the spreadsheet with every cell as text, then strip HTML tags from all cells.
df = pd.read_excel(EXCEL_FILE, dtype=str)
# DataFrame.applymap is deprecated (it raises a FutureWarning on recent pandas);
# DataFrame.map is its replacement, so use it whenever the installed pandas has it.
if hasattr(pd.DataFrame, "map"):
    df = df.map(strip_tags)
else:
    df = df.applymap(strip_tags)

# Resume support: reload the mapping written by a previous (possibly interrupted) run
# so that already-downloaded URLs are skipped and numbering continues.
if os.path.exists(OUTPUT_MAPPING_FILE):
    # dtype=str keeps the zero-padded ids ("0001") and metadata codes as text instead of
    # letting pandas parse them as numbers; utf-8-sig matches the encoding used when the
    # file is written below.
    existing_mapping = pd.read_csv(OUTPUT_MAPPING_FILE, dtype=str, encoding="utf-8-sig")
    if "url" in existing_mapping.columns:
        downloaded_urls = set(existing_mapping["url"].dropna().tolist())
    else:
        print("‚ö†Ô∏è 'url' column missing in existing mapping file. Continuing without skip logic.")
        downloaded_urls = set()
    counter = len(existing_mapping) + 1  # Continue numbering
    print(f"‚ÑπÔ∏è Found existing mapping with {len(existing_mapping)} images. Resuming at {counter:04d}.")
else:
    existing_mapping = pd.DataFrame()
    downloaded_urls = set()
    counter = 1
    print("‚ÑπÔ∏è No existing mapping found. Starting fresh.")

records = []


def _save_mapping():
    """Write the previous mapping (if any) plus this run's records to OUTPUT_MAPPING_FILE.

    Returns the DataFrame that was written.
    """
    new_rows = pd.DataFrame(records)
    if not existing_mapping.empty:
        combined = pd.concat([existing_mapping, new_rows], ignore_index=True)
    else:
        combined = new_rows
    combined.to_csv(
        OUTPUT_MAPPING_FILE,
        index=False,
        quoting=csv.QUOTE_ALL,
        escapechar='\\',
        encoding='utf-8-sig'
    )
    return combined


try:
    for _, row in tqdm(df.iterrows(), total=len(df)):
        url = row[IMAGE_URL_COLUMN]

        # Skip missing or invalid URLs
        if pd.isna(url) or not str(url).startswith("http"):
            print(f"‚ö†Ô∏è Skipping invalid URL: {url}")
            continue

        # Already fetched by an earlier run.
        # NOTE(review): URLs downloaded during *this* run are not added to
        # downloaded_urls, so a URL that appears in several rows is fetched once per
        # row on a fresh run but skipped on resume - confirm which is intended.
        if url in downloaded_urls:
            continue

        url = str(url)
        image_id = f"{counter:04d}"

        # File extension comes from the URL path (query string removed); anything
        # outside the known image extensions is saved as .jpg.
        ext = os.path.splitext(url.split("?")[0])[1].lower()
        if ext not in [".jpg", ".jpeg", ".png", ".gif", ".webp"]:
            ext = ".jpg"

        filename = f"{image_id}{ext}"
        filepath = os.path.join(OUTPUT_FOLDER, filename)

        try:
            # Download image
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            with open(filepath, "wb") as f:
                f.write(response.content)

            # Create mapping record
            record = {
                "id": image_id,
                "filename": filename,
                "url": url
            }

            # Add metadata columns dynamically
            for col in METADATA_COLUMNS:
                record[col] = row.get(col, None)

            records.append(record)
            # The counter only advances on success, so ids stay gap-free.
            counter += 1

        except Exception as e:
            # Best effort: one failing URL must not abort the whole run.
            print(f"‚ö†Ô∏è Failed to download {url}: {e}")
except BaseException:
    # Interrupted (e.g. the cell was stopped) or crashed part-way: persist whatever was
    # downloaded so far so the next run can resume, then let the error propagate.
    if records:
        _save_mapping()
    raise

# Save mapping file
mapping_df = _save_mapping()
print(f"\n‚úÖ Done! Downloaded {len(records)} images.")
print(f"üìÅ Images saved in: {OUTPUT_FOLDER}")
print(f"üìÑ Mapping saved in: {OUTPUT_MAPPING_FILE}")


  df = df.applymap(strip_tags)


‚ÑπÔ∏è No existing mapping found. Starting fresh.


 46%|‚ñà‚ñà‚ñà‚ñà‚ñå     | 2472/5402 [14:34<17:12,  2.84it/s]

‚ö†Ô∏è Skipping invalid URL: nan


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5402/5402 [31:07<00:00,  2.89it/s]


‚úÖ Done! Downloaded 5401 images.
üìÅ Images saved in: Dataset\images
üìÑ Mapping saved in: Dataset\images_mapping.csv





Imports and Model Setup

In [1]:
from transformers import CLIPProcessor, CLIPModel, Blip2Processor, Blip2ForConditionalGeneration
from datasets import load_dataset
from PIL import Image
import torch
import matplotlib.pyplot as plt
from torchcam.methods import SmoothGradCAMpp
from torchvision.transforms.functional import to_pil_image
import os


Loading my Dataset
Structure:
fossil_dataset/
  train/
    ammonite/
    trilobite/
    coral/
  val/
    ...


In [None]:
from datasets import load_dataset

# Fixed seed so the 90/10 split is the same on every Restart & Run All.
SPLIT_SEED = 42

dataset = load_dataset("imagefolder", data_dir="fossil_dataset")
# Only the "train" split is re-split here.
# NOTE(review): the folder layout described above also has a val/ directory, which
# this line does not use - confirm that is intended.
dataset = dataset["train"].train_test_split(test_size=0.1, seed=SPLIT_SEED)
print(dataset)


Load and fine-tune CLIP

In [None]:
model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
model = CLIPModel.from_pretrained(model_id)
processor = CLIPProcessor.from_pretrained(model_id)

# The imagefolder loader stores each class folder as an integer id, while CLIP's text
# encoder needs strings. Keep the id -> folder-name table for the conversion
# (None when the label column is already plain text).
label_names = getattr(dataset["train"].features["label"], "names", None)


def preprocess(examples):
    """Turn a batch of {"image", "label"} rows into CLIP model inputs.

    Args:
        examples: batched mapping with "image" (PIL images) and "label"
            (class ids, or strings when the column has no class-name table).

    Returns:
        The processor output for the batch (tokenized class-name captions padded to
        the model's maximum length, plus the preprocessed pixel values).
    """
    if label_names is None:
        captions = [str(label) for label in examples["label"]]
    else:
        captions = [label_names[label] for label in examples["label"]]
    return processor(
        text=captions,
        images=examples["image"],
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )


# Drop the raw columns after encoding: Trainer keeps a "label" column and the default
# collator renames it to "labels", which CLIPModel.forward does not accept.
proc_dataset = dataset.map(preprocess, batched=True, remove_columns=["image", "label"])

from transformers import TrainingArguments, Trainer, default_data_collator


def collate_with_loss(features):
    """Collate encoded rows and ask CLIP to return its contrastive loss.

    CLIPModel only computes a loss when called with return_loss=True; without it
    Trainer has nothing to optimise.
    """
    batch = default_data_collator(features)
    batch["return_loss"] = True
    return batch


training_args = TrainingArguments(
    output_dir="./clip_fossil_finetuned",
    per_device_train_batch_size=8,
    num_train_epochs=5,
    learning_rate=5e-6,
    fp16=torch.cuda.is_available(),  # mixed precision is only usable on a CUDA device
    logging_steps=50,
    save_steps=500
)

# NOTE(review): captions are bare class names, so one batch can contain several
# identical captions that the contrastive loss treats as mismatches - consider richer
# captions or a class-balanced sampler.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=proc_dataset["train"],
    eval_dataset=proc_dataset["test"],
    data_collator=collate_with_loss
)

trainer.train()


Add the prediction function
Usage:
labels = ["ammonite", "trilobite", "coral", "crinoid"]
prediction, confidence = classify_image("example_fossil.jpg", labels)
print(f"Predicted fossil: {prediction} (Confidence: {confidence:.2f})")


In [None]:
def classify_image(image_path, labels):
    """Zero-shot classify one image against candidate text labels with CLIP.

    Args:
        image_path: path of the image file to classify.
        labels: non-empty sequence of candidate label strings.

    Returns:
        Tuple ``(best_label, probability)`` where ``probability`` is the softmax
        score of the best label over the supplied candidates.

    Raises:
        ValueError: if ``labels`` is empty.
    """
    labels = list(labels)
    if not labels:
        raise ValueError("labels must contain at least one candidate label")

    image = Image.open(image_path)
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    # Trainer may have moved the model to the GPU; the inputs must live on the same device.
    inputs = inputs.to(model.device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    best_idx = probs.argmax(dim=1).item()
    return labels[best_idx], probs[0, best_idx].item()


Add BLIP2 -> fine tune for domain
usage:
explanation = generate_explanation("example_fossil.jpg", prediction)
print("Explanation:", explanation)



In [None]:
# NOTE(review): "Salesforce/blip2-flan-t5-base" does not appear among the published
# Salesforce BLIP-2 checkpoints (OPT 2.7b/6.7b, Flan-T5 XL/XXL), so from_pretrained
# would fail to resolve it. flan-t5-xl is the smallest Flan-T5 variant - confirm it
# fits the available memory.
blip_id = "Salesforce/blip2-flan-t5-xl"
blip_processor = Blip2Processor.from_pretrained(blip_id)
blip_model = Blip2ForConditionalGeneration.from_pretrained(blip_id)

def generate_explanation(image_path, fossil_label):
    """Ask the BLIP-2 model which visual features support a predicted label.

    Args:
        image_path: path of the image file to describe.
        fossil_label: label text inserted into the prompt.

    Returns:
        The generated text, decoded with special tokens removed
        (at most 100 newly generated tokens).
    """
    picture = Image.open(image_path)
    prompt = f"This is likely a {fossil_label}. Explain which visual features indicate this identification."
    model_inputs = blip_processor(images=picture, text=prompt, return_tensors="pt")
    generated_ids = blip_model.generate(**model_inputs, max_new_tokens=100)
    return blip_processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True)


GRAD CAM
usage:
visualize_attention("example_fossil.jpg", text_prompt=prediction)


In [None]:
# Put the CLIP model in evaluation mode and attach a SmoothGradCAM++ extractor to
# its vision tower.
# NOTE(review): the extractor is built on model.vision_model, while the forward pass
# below goes through the full `model` - confirm torchcam records what it needs.
model.eval()
cam_extractor = SmoothGradCAMpp(model.vision_model)

def visualize_attention(image_path, text_prompt="a fossil"):
    """Overlay a class-activation heatmap for ``text_prompt`` on the image and show it.

    Args:
        image_path: path of the image file to visualise.
        text_prompt: text passed to the processor alongside the image.

    Returns:
        None; the figure is displayed with matplotlib.
    """
    image = Image.open(image_path)
    inputs = processor(images=image, text=[text_prompt], return_tensors="pt")
    # The forward result is not used below.
    outputs = model(**inputs)
    # NOTE(review): torchcam extractors are documented as being called with
    # (class_idx, scores) after a forward pass; here the pixel tensor is passed
    # instead - verify this call against the torchcam API.
    cams = cam_extractor(inputs["pixel_values"])
    cam = cams[0][0]
    # Scale by the maximum before converting to a PIL image.
    heatmap = to_pil_image(cam / cam.max())
    plt.figure(figsize=(6,6))
    plt.imshow(image)
    # NOTE(review): the heatmap is drawn at its own resolution; nothing here resizes
    # it to the image size - confirm the overlay lines up.
    plt.imshow(heatmap, alpha=0.5, cmap="jet")
    plt.axis("off")
    plt.show()


Full pipeline
fossil_pipeline("example_fossil.jpg", ["ammonite", "trilobite", "coral", "crinoid"])

In [None]:
def fossil_pipeline(image_path, labels):
    """Classify an image, print the prediction and its explanation, then show the heatmap.

    Args:
        image_path: path of the image file to analyse.
        labels: candidate label strings handed to ``classify_image``.
    """
    # Stage 1: zero-shot classification
    label, conf = classify_image(image_path, labels)
    print(f"ü¶¥ Predicted fossil: {label} ({conf:.2f})")

    # Stage 2: textual explanation for the predicted label
    print("\nüí¨ Explanation:", generate_explanation(image_path, label))

    # Stage 3: attention overlay for the predicted label
    visualize_attention(image_path, text_prompt=label)


In [None]:

#Imports and Setup
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
from torchcam.methods import SmoothGradCAMpp
from torchcam.utils import overlay_mask
import matplotlib.pyplot as plt
import pandas as pd
import os

# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# -------------------------------
# Step 1: Load Your Dataset
# -------------------------------

# Point these at your annotation CSV and at the folder its image paths are relative to
dataset_csv = "fossil_dataset/annotations.csv"
dataset_dir = "fossil_dataset/images/"

# The CSV is read through Hugging Face datasets
dataset = load_dataset("csv", data_files=dataset_csv)

# Peek at the first row
print(dataset['train'][0])

# -------------------------------
# Step 2: Load Pretrained CLIP
# -------------------------------

model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)
model = model.to(device)

# ===============================
# Step 3: Preprocessing Function
# ===============================

def preprocess(example):
    """Encode one annotation row (image file name + caption) into CLIP inputs.

    Args:
        example: mapping with an "image" entry (file name relative to
            ``dataset_dir``) and a "caption" entry (text).

    Returns:
        The processor output for the image/caption pair as PyTorch tensors.
    """
    image_path = os.path.join(dataset_dir, example['image'])
    image = Image.open(image_path).convert("RGB")
    # truncation=True cuts captions to the text encoder's maximum length; longer
    # captions would otherwise exceed CLIP's context window.
    inputs = processor(text=example['caption'], images=image, return_tensors="pt", padding=True, truncation=True)
    return inputs

# Test preprocessing
sample_inputs = preprocess(dataset['train'][0])
print(sample_inputs.keys())

# ===============================
# Step 4: Fine-Tuning Setup
# ===============================

# Simple PyTorch DataLoader
class FossilDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(PIL image, caption)`` pairs.

    Wraps an indexable collection of rows; each row provides an "image" file name
    (resolved against ``dataset_dir``) and a "caption" string.
    """

    def __init__(self, hf_dataset):
        # Underlying row collection; only len() and integer indexing are used.
        self.dataset = hf_dataset

    def __len__(self):
        """Number of rows in the wrapped collection."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load the image of row ``idx`` as RGB and return it with its caption."""
        record = self.dataset[idx]
        with_path = os.path.join(dataset_dir, record['image'])
        picture = Image.open(with_path).convert("RGB")
        return picture, record['caption']

def collate_pairs(batch):
    """Regroup ``[(image, caption), ...]`` into ``(images, captions)`` lists.

    The DataLoader's default collate function cannot stack PIL images, so the pairs
    are kept as plain Python lists and handed to the CLIP processor later.
    """
    images, captions = zip(*batch)
    return list(images), list(captions)


train_dataset = FossilDataset(dataset['train'])
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_pairs)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-6)

# ===============================
# Step 5: Fine-Tuning Loop (Simplified)
# ===============================

model.train()
for epoch in range(1):  # adjust epochs
    for images, captions in train_loader:
        # Prepare batch (captions are truncated to the text encoder's maximum length)
        inputs = processor(
            text=list(captions),
            images=list(images),
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        outputs = model(**inputs)

        # Symmetric CLIP contrastive loss: the i-th image should match the i-th caption
        # and vice versa. (logits_per_image is the transpose of logits_per_text in the
        # transformers CLIP model, so a squared difference between the two is always
        # zero and gives no training signal.)
        logits_per_image = outputs.logits_per_image
        logits_per_text = outputs.logits_per_text
        targets = torch.arange(logits_per_image.shape[0], device=logits_per_image.device)
        loss = (
            torch.nn.functional.cross_entropy(logits_per_image, targets)
            + torch.nn.functional.cross_entropy(logits_per_text, targets)
        ) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Reports the loss of the last batch of the epoch.
    print(f"Epoch {epoch+1} done. Loss: {loss.item():.4f}")

# Back to inference mode for the explainability steps that follow.
model.eval()

# ===============================
# Step 6: Visual Explainability
# ===============================

# Initialize Grad-CAM on the vision tower.
# The transformers CLIPModel exposes the image encoder as `vision_model`; `visual` is
# not an attribute of this class and raises AttributeError.
cam_extractor = SmoothGradCAMpp(model.vision_model)

def show_gradcam(image_path):
    """Compute an activation map for one image and display it overlaid on the image.

    Args:
        image_path: path of the image file to visualise.

    Returns:
        None; the overlay is displayed with matplotlib.
    """
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    
    outputs = model.get_image_features(**inputs)
    # Index of the largest component of the image feature vector.
    # NOTE(review): feature dimensions are not class scores - confirm this is the
    # intended CAM target.
    target_class = outputs.argmax(dim=-1).item()
    
    # NOTE(review): torchcam extractors are documented as being called with
    # (class_idx, scores); the arguments here are in the opposite order - verify.
    activation_map = cam_extractor(outputs[0].unsqueeze(0), target_class)
    
    # Overlay
    # NOTE(review): torchcam's overlay_mask is documented as taking PIL images for both
    # image and mask; the first argument here is a tensor - verify.
    result = overlay_mask(transforms.ToTensor()(image), transforms.ToPILImage()(activation_map[0].cpu()), alpha=0.5)
    plt.imshow(result)
    plt.axis('off')
    plt.show()

# Example: run on the first training image
show_gradcam(os.path.join(dataset_dir, dataset['train'][0]['image']))

# ===============================
# Step 7: Textual Explanation
# ===============================

def explain_text(image_path, candidate_captions):
    """Print CLIP's image-text similarity logit for each candidate caption.

    Args:
        image_path: path of the image file to score.
        candidate_captions: list of caption strings to compare with the image.

    Returns:
        None; one ``caption: score`` line is printed per caption.
    """
    image = Image.open(image_path).convert("RGB")
    # One image against all captions gives a (1, n_captions) logit matrix; repeating
    # the image per caption only produces identical extra rows.
    inputs = processor(text=candidate_captions, images=image, return_tensors="pt", padding=True).to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    # Similarity scores
    logits = outputs.logits_per_image
    for caption, score in zip(candidate_captions, logits[0]):
        print(f"{caption}: {score.item():.4f}")

# Example: score three candidate descriptions against the first training image
candidate_texts = ["Trilobite fragment", "Ammonite shell", "Unknown fossil fragment"]
explain_text(
    os.path.join(dataset_dir, dataset['train'][0]['image']),
    candidate_texts,
)
