In [2]:
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
import timm
import numpy as np
import pandas as pd
import os
import sys
import faiss
from tqdm import tqdm
from collections import Counter
from torch.utils.data import Subset
from collections import defaultdict
from sklearn.metrics import silhouette_score
import csv
from pathlib import Path

from matplotlib import pyplot as plt
import random
from PIL import Image


In [17]:
# --- Tiling / input geometry -------------------------------------------------
N = 3            # each quadrat image is split into an N x N grid of tiles
TRT_SIZE = 224   # side length in pixels that every tile is resized to before the network

# --- Paths -------------------------------------------------------------------
# Everything is resolved relative to the directory the notebook is started from.
project_root = os.path.abspath('.')
if sys.path.count(project_root) == 0:
    # lets modules that live next to the notebook be imported
    sys.path.append(project_root)

training_data_path, inference_data_path = (
    os.path.join(project_root, relative_dir)
    for relative_dir in (
        "PlantCLEF2025_data/images_max_side_800",
        "PlantCLEF2025_data/test_images/images",
    )
)

# --- Preprocessing -----------------------------------------------------------
# Resize -> tensor -> per-channel normalisation. The mean/std values match the
# widely used ImageNet channel statistics.
transform = transforms.Compose(
    [
        transforms.Resize((TRT_SIZE, TRT_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

def tile_image_nxn(img, tiles_per_side=3, target_size=224):
    """Cut an image into a square grid of tiles and resize every tile.

    Args:
        img: Image object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method whose result has ``resize(size, resample)``
            (a PIL image in this notebook).
        tiles_per_side: Number of tiles along each axis; the grid holds
            ``tiles_per_side ** 2`` tiles.
        target_size: Side length, in pixels, each tile is resized to.

    Returns:
        List of resized tiles in row-major order (left to right, then top to
        bottom).

    Raises:
        ValueError: If ``tiles_per_side`` is smaller than 1.
    """
    if tiles_per_side < 1:
        raise ValueError(f"tiles_per_side must be >= 1, got {tiles_per_side}")

    w, h = img.size

    # Edge k sits at floor(k * side / tiles_per_side), computed with integers.
    # Going through a float tile width (w / n, then multiplying back) is not
    # guaranteed to reproduce the full side for every grid size, which can
    # leave the last row/column one pixel short.
    x_edges = [k * w // tiles_per_side for k in range(tiles_per_side + 1)]
    y_edges = [k * h // tiles_per_side for k in range(tiles_per_side + 1)]

    tiles = []
    for row in range(tiles_per_side):
        for col in range(tiles_per_side):
            # crop box is (left, top, right, bottom)
            box = (x_edges[col], y_edges[row], x_edges[col + 1], y_edges[row + 1])
            tile = img.crop(box)
            tiles.append(tile.resize((target_size, target_size), Image.BICUBIC))

    return tiles


class QuadratNxNDataset(Dataset):
    """Map-style dataset yielding all tiles of one quadrat image per item.

    Args:
        datafolder: Directory that directly contains the image files.
        transform: Callable applied to every tile; it must return tensors of
            identical shape so they can be stacked.
        tiles_per_side: Grid size handed to ``tile_image_nxn``.
        target_size: Tile side length handed to ``tile_image_nxn``.
    """

    def __init__(self, datafolder, transform, tiles_per_side=3, target_size=224):
        # Sorted path strings of every regular file whose name contains a dot.
        # Directories that happen to match "*.*" are skipped because they
        # cannot be opened as images.
        self.paths = sorted(
            str(p) for p in Path(datafolder).glob("*.*") if p.is_file()
        )
        self.transform = transform
        self.tiles_per_side = tiles_per_side
        self.target_size = target_size

    def __len__(self):
        """Number of image files found in the folder."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(tiles, path)`` for the image at position ``idx``.

        ``tiles`` is the stack of transformed tiles, i.e. a tensor whose first
        dimension is ``tiles_per_side ** 2``; ``path`` is the image path string.
        """
        path = self.paths[idx]

        # convert() returns an independent, fully loaded image, so the file
        # handle can be released as soon as the with-block ends.
        with Image.open(path) as raw:
            img = raw.convert("RGB")

        tiles = tile_image_nxn(
            img,
            tiles_per_side=self.tiles_per_side,
            target_size=self.target_size,
        )

        stacked = torch.stack([self.transform(t) for t in tiles])
        return stacked, path


In [None]:
# Check the tile_image_nxn function
# img = Image.open('2024-CEV3-20240602.jpg')
# plt.imshow(img)

# tiles = tile_image_nxn(img, tiles_per_side=3, target_size=518)
# fig, axes = plt.subplots(3, 3)
# axes = axes.flatten()
# for i, tile in enumerate(tiles):
#     axes[i].imshow(tile)

In [18]:
# Loader over the test/quadrat images: every item is the full set of tiles
# cut from one quadrat image, together with that image's path.
inference_loader = DataLoader(
    dataset=QuadratNxNDataset(
        datafolder=inference_data_path,
        transform=transform,
        tiles_per_side=N,
        target_size=TRT_SIZE,
    ),
    batch_size=1,     # one quadrat image (all of its N x N tiles) per batch
    shuffle=False,    # keep the images in the dataset's sorted-path order
    num_workers=4,
)

In [23]:
# Sanity check: pull a single batch and show what the loader yields.
data_batch, labels_batch = next(iter(inference_loader))

# Tensor axes: (batch, tiles, channels, height, width); ToTensor emits C x H x W.
# The "label" is the tuple of quadrat image paths (one path, since batch_size=1).
print(
    f"Batch data shape: {data_batch.shape}",
    f"Batch label - path for quadrat image: {labels_batch}",
    sep="\n",
)

Batch data shape: torch.Size([1, 9, 3, 224, 224])
Batch label - path for quadrat image: ('/sfs/weka/scratch/hl9h/PlantCLEF2025_data/test_images/images/2024-CEV3-20240602.jpg',)


In [None]:
# Get the class labels and indices from the single-plant training data set
print(f"Loading data from: {training_data_path}")

# ImageFolder treats every sub-directory of the root as one class. Errors
# (missing folder, no class directories, ...) are left to propagate: a notebook
# cell cannot `return`, and swallowing the exception would only leave the names
# below undefined for the following cells.
full_training = datasets.ImageFolder(
    training_data_path,
    transform=transform,
)

# Fail loudly if the folder tree exists but yielded no samples.
if not full_training.samples:
    raise RuntimeError(f"ERROR: No images found in {training_data_path}.")

classes = full_training.classes             # list of all class names (sorted)
mapping_dict = full_training.class_to_idx   # mapping from class name to class index

# NOTE(review): the model-loading cell reads NUM_CLASS, which is not assigned
# anywhere else in the visible notebook. Deriving it from the class list assumes
# the checkpoint head was trained on exactly these classes - confirm.
NUM_CLASS = len(classes)

# Reverse Mapping (Index to Class Name)
def idx_to_class(idx, mapping_dict):
    """Return the class name that ``mapping_dict`` maps to index ``idx``.

    Args:
        idx: Class index to look up.
        mapping_dict: Dictionary from class name to class index (for example
            ``ImageFolder.class_to_idx``).

    Returns:
        The class name whose value in ``mapping_dict`` equals ``idx``.

    Raises:
        KeyError: If no class is mapped to ``idx``.
    """
    # Invert name -> index into index -> name. The local name differs from the
    # function name so the function is not shadowed inside its own body.
    name_by_index = {index: name for name, index in mapping_dict.items()}
    return name_by_index[idx]

# Get top K indices from an array
def get_topk(array, k=3):
    """Return the indices of the ``k`` largest values, best first.

    The original signature ``(k=3, array)`` is not valid Python (a parameter
    without a default may not follow one with a default), so ``array`` now
    comes first and ``k`` keeps its default of 3.

    Args:
        array: Array-like of scores. For inputs with more than one dimension
            the top-k is taken independently along the last axis.
        k: Number of indices to return per row. Values larger than the last
            axis return every index; ``k <= 0`` returns an empty result.

    Returns:
        Integer ``numpy.ndarray`` of indices ordered from highest to lowest
        score; its last axis has length ``min(k, array.shape[-1])``.
    """
    scores = np.asarray(array)
    if k <= 0:
        # A slice of [-0:] would select everything, so handle this explicitly.
        return np.empty(scores.shape[:-1] + (0,), dtype=np.intp)

    ascending = np.argsort(scores, axis=-1)
    # The last k entries are the largest; reverse them so the best comes first.
    return ascending[..., -k:][..., ::-1]

    

In [8]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the fine-tuned ResNet50 parameters.
# NOTE(review): NUM_CLASS is read below but is not assigned in this cell; it has
# to be set (number of outputs of the fine-tuned head) before the cell is run.
from resnet50 import resnet50
resnet50_finetuned = resnet50.get_resnet50_pretrained(num_classes=NUM_CLASS, fine_tune=True)
state_dict = torch.load(
    'resnet50/resnet50_finetuned_plantCLEF_ep30.pth',
    map_location=device,   # lets a checkpoint saved on GPU load on a CPU-only machine
    weights_only=True,     # weights_only=True is recommended for security
)
resnet50_finetuned.load_state_dict(state_dict)

# Inference only: evaluation mode and frozen parameters.
resnet50_finetuned.eval().to(device)
for p in resnet50_finetuned.parameters():
    p.requires_grad = False

for batch, path in tqdm(inference_loader, desc="Testing"):  # each batch is one quadrat image
    path = path[0]
    # The loader yields [1, tiles, C, H, W]; merge the first two axes so the
    # network receives an ordinary 4-D image batch of shape [tiles, C, H, W].
    # Assumes the model takes (batch, channels, height, width) input like a
    # standard ResNet - TODO confirm against resnet50.get_resnet50_pretrained.
    batch = batch.flatten(0, 1).to(device)
    # Use the model loaded above; the name `model` is not defined in this cell.
    logits = resnet50_finetuned(batch)
    # TODO: `logits` (one row per tile) is not consumed yet - each iteration
    # overwrites the previous result.
    





        # NOTE(review): orphaned fragment. No enclosing `def` or loop header is
        # visible in this cell, and `imgs`, `labels`, `all_embs` and `all_labels`
        # are not assigned anywhere in it, so the `return` at the end sits outside
        # any function. Restore the missing header or remove this tail.
        # Move the image batch to the compute device.
        imgs = imgs.to(device)
        # Feature extraction through the model's forward_features method
        # (presumably a timm model - confirm which `model` this referred to).
        feats = model.forward_features(imgs)
        if feats.ndim == 3:
            # 3-D output: keep index 0 along the second axis for every sample.
            cls_embs = feats[:, 0, :]  # [CLS] token
        else:
            # Any other rank is used as-is.
            cls_embs = feats
        # Accumulate per-batch results as NumPy arrays on the host.
        all_embs.append(cls_embs.cpu().numpy())
        all_labels.append(labels.numpy())
    # Join the per-batch arrays into one embeddings array and one labels array.
    return np.concatenate(all_embs), np.concatenate(all_labels)


def aggregate_predictions(tile_preds, tile_paths, tiles_per_quadrat=9, min_votes=1):
    """Pool tile-level predictions into one species list per quadrat.

    Consecutive groups of ``tiles_per_quadrat`` entries of ``tile_preds`` are
    treated as one quadrat. All labels in a group are pooled and counted, and a
    label is kept when it occurs at least ``min_votes`` times.

    Args:
        tile_preds: Array of predictions, one entry (row) per tile; slices of it
            must provide ``.flatten()``.
        tile_paths: Sequence with one image path per tile, aligned with
            ``tile_preds``; the path at the start of each group names the quadrat.
        tiles_per_quadrat: Number of consecutive tiles that form one quadrat.
        min_votes: Minimum number of occurrences for a label to be kept.

    Returns:
        ``defaultdict(list)`` mapping the quadrat id (file name without
        extension) to its kept labels, in order of first occurrence.
    """
    species_by_quadrat = defaultdict(list)

    for group_start in range(0, len(tile_preds), tiles_per_quadrat):
        group = tile_preds[group_start:group_start + tiles_per_quadrat]
        quadrat_name = Path(tile_paths[group_start]).stem

        # Every label predicted for any tile of this quadrat casts one vote.
        votes = Counter(group.flatten())

        kept = []
        for label, n_votes in votes.items():
            if n_votes >= min_votes:
                kept.append(label)
        species_by_quadrat[quadrat_name] = kept

    return species_by_quadrat

def write_submission(pred_dict, out_csv="submission.csv"):
    """Write per-quadrat predictions to a two-column CSV file.

    The file starts with the header ``quadrat_id,species_ids``; every following
    row holds a quadrat id and its species rendered as ``[a, b, c]``.

    Args:
        pred_dict: Mapping from quadrat id to an iterable of species ids.
        out_csv: Path of the CSV file to create (an existing file is overwritten).
    """
    # newline="" leaves line-ending handling to the csv module.
    with open(out_csv, "w", newline="") as handle:
        rows = csv.writer(handle)
        rows.writerow(["quadrat_id", "species_ids"])
        rows.writerows(
            [quadrat_id, "[" + ", ".join(map(str, species_ids)) + "]"]
            for quadrat_id, species_ids in pred_dict.items()
        )


Initializing inference dataset for a 3x3 grid.


NameError: name 'glob' is not defined

In [None]:


    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = timm.create_model("timm/vit_base_patch14_reg4_dinov2.lvd142m", pretrained=True)
    #model = timm.create_model("vit_base_patch16_224", pretrained=True)

    # NOTE(review): absolute, user-specific path - the cell only runs on the
    # machine that has this file; consider a path relative to the project.
    checkpoint_path = '/home/jme3qd/Downloads/model_best.pth.tar'

    # SECURITY: weights_only=False unpickles arbitrary Python objects, so this
    # must only ever be pointed at a checkpoint from a trusted source.
    # The tensors are loaded to CPU first; the model is moved to `device` below.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # Accept both a full training checkpoint (weights stored under
    # 'state_dict') and a bare state dict.
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    # strict=False skips keys that do not match instead of raising, so report
    # what was skipped: missing keys keep the weights the model was created with.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Checkpoint loaded with {len(load_result.missing_keys)} missing and "
            f"{len(load_result.unexpected_keys)} unexpected keys."
        )
        print(f"  missing (first 5): {load_result.missing_keys[:5]}")
        print(f"  unexpected (first 5): {load_result.unexpected_keys[:5]}")

    # Inference only: evaluation mode and frozen parameters.
    model.eval().to(device)
    for p in model.parameters():
        p.requires_grad = False

