In [2]:
backbone = timm.create_model('hf-hub:BVRA/MegaDescriptor-T-224', num_classes=0, pretrained=True)
for name, param in backbone.named_parameters():
    print(f"{name}: requires_grad = {param.requires_grad}")

patch_embed.proj.weight: requires_grad = True
patch_embed.proj.bias: requires_grad = True
patch_embed.norm.weight: requires_grad = True
patch_embed.norm.bias: requires_grad = True
layers.0.blocks.0.norm1.weight: requires_grad = True
layers.0.blocks.0.norm1.bias: requires_grad = True
layers.0.blocks.0.attn.relative_position_bias_table: requires_grad = True
layers.0.blocks.0.attn.qkv.weight: requires_grad = True
layers.0.blocks.0.attn.qkv.bias: requires_grad = True
layers.0.blocks.0.attn.proj.weight: requires_grad = True
layers.0.blocks.0.attn.proj.bias: requires_grad = True
layers.0.blocks.0.norm2.weight: requires_grad = True
layers.0.blocks.0.norm2.bias: requires_grad = True
layers.0.blocks.0.mlp.fc1.weight: requires_grad = True
layers.0.blocks.0.mlp.fc1.bias: requires_grad = True
layers.0.blocks.0.mlp.fc2.weight: requires_grad = True
layers.0.blocks.0.mlp.fc2.bias: requires_grad = True
layers.0.blocks.1.norm1.weight: requires_grad = True
layers.0.blocks.1.norm1.bias: requires_grad = T

In [8]:
import sys
sys.path.append('..')
import argparse
import shutil
import os
import yaml
import timm
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_lightning import Trainer
import numpy as np
from PIL import Image
import wandb
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

In [4]:
from wildlife_tools.features import DeepFeatures

backbone = timm.create_model('hf-hub:BVRA/MegaDescriptor-T-224', num_classes=0, pretrained=True)


In [5]:
backbone = backbone.eval()

# Generate a dummy batch of images with a shape of (batch_size, channels, height, width)
# Typical input size for Swin Transformer models might be (1, 3, 224, 224)
dummy_batch = torch.randn(1, 3, 224, 224)  # Change 224 to your model’s input size if different

with torch.no_grad():
    output = backbone(dummy_batch)

# Output shape will be (1, dim_embedding)
print("Embedding size:", output.shape[1])

Embedding size: 768


In [6]:
from models.triplet_loss_model import TripletModel

from utils.re_ranking import re_ranking
from data.data_utils import calculate_num_channels
from utils.metrics import compute_distance_matrix
from utils.metrics import evaluate_map, compute_average_precision

# model = TripletModel(backbone_model_name=backbone)

# Embedder (to project features into the desired embedding space)
# embedder = nn.Linear(backbone.feature_info[-1]["num_chs"], 768)

re_ranking = True
distance_matrix = 'euclidean'
query_embeddings = []
query_labels = []
gallery_embeddings = []
gallery_labels = []

def validation_step(batch, batch_idx, dataloader_idx=0):
    x, target = batch
    embeddings = backbone(x)
    if dataloader_idx == 0:
        # Query data
        query_embeddings.append(embeddings)
        query_labels.append(target)
    else:
        # Gallery data
        gallery_embeddings.append(embeddings)
        gallery_labels.append(target)

def on_validation_epoch_end():
    # Concatenate all embeddings and labels
    query_embeddings = torch.cat(query_embeddings)
    query_labels = torch.cat(query_labels)
    gallery_embeddings = torch.cat(gallery_embeddings)
    gallery_labels = torch.cat(gallery_labels)

    # Compute distance matrix
    if re_ranking:
        distmat = re_ranking(query_embeddings, gallery_embeddings, k1=20, k2=6, lambda_value=0.3)
    else:
        distmat = compute_distance_matrix(distance_matrix, query_embeddings, gallery_embeddings, wildlife=True)

    # Compute mAP
    # mAP = torchreid.metrics.evaluate_rank(distmat, query_labels.cpu().numpy(), gallery_labels.cpu().numpy(), use_cython=False)[0]['mAP']
    mAP1 = evaluate_map(distmat, query_labels, gallery_labels, top_k=1)
    mAP5 = evaluate_map(distmat, query_labels, gallery_labels, top_k=5)
    print(mAP5)




  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


In [37]:
import wandb
import timm
import torch
import torch.nn as nn

from utils.triplet_loss_utils import TripletLoss
from utils.optimizer import get_optimizer, get_lr_scheduler_config
from utils.weights_initializer import weights_init_kaiming, weights_init_classifier

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_metric_learning import losses, miners
from torch import nn

from wildlife_tools.similarity.cosine import CosineSimilarity
from utils.metrics import evaluate_map, compute_average_precision

from utils.re_ranking import re_ranking
from data.data_utils import calculate_num_channels
from utils.metrics import compute_distance_matrix
from utils.triplet_loss_utils import KnnClassifier


class SimpleModel(pl.LightningModule):
    def __init__(self, 
                 backbone_model_name="resnet50", 
                 config=None, pretrained=True, 
                 embedding_size=768, margin=0.2, 
                 mining_type="semihard", 
                 lr=0.001, 
                 preprocess_lvl=0, 
                 re_ranking=True, 
                 outdir="results"):
        super().__init__()
        self.config = config
        self.re_ranking = re_ranking
        self.distance_matrix = 'cosine'
            
        # Backbone (ResNet without the final FC layer)
        self.backbone = timm.create_model(model_name=backbone_model_name, pretrained=pretrained, num_classes=0)

        self.embedder = nn.Linear(self.backbone.feature_info[-1]["num_chs"], embedding_size)

    # Can experiment with different embedders or need to adjust the embedding layer frequently.
    def forward(self, x):
        features = self.backbone(x) # Extract features using the backbone
        return features
        embeddings = self.embedder(features) # Project features into the embedding space
        embeddings = nn.functional.normalize(embeddings, p=2, dim=1)  # L2 normalization
        return embeddings

    def training_step(self, batch, batch_idx):
        images, labels = batch
        embeddings = self(images)
        mined_triplets = self.miner(embeddings, labels)
        loss = self.loss_fn(embeddings, labels, mined_triplets)
        self.log("train/loss", loss,  on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_validation_epoch_start(self):
        self.query_embeddings = []
        self.query_labels = []
        self.gallery_embeddings = []
        self.gallery_labels = []

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        x, target = batch
        embeddings = self(x)
        if dataloader_idx == 0:
            # Query data
            self.query_embeddings.append(embeddings)
            self.query_labels.append(target)
        else:
            # Gallery data
            self.gallery_embeddings.append(embeddings)
            self.gallery_labels.append(target)

    def on_validation_epoch_end(self):
        # Concatenate all embeddings and labels
        query_embeddings = torch.cat(self.query_embeddings)
        query_labels = torch.cat(self.query_labels)
        gallery_embeddings = torch.cat(self.gallery_embeddings)
        gallery_labels = torch.cat(self.gallery_labels)

        print(f"size of gallery embeddings: {gallery_embeddings.size()}")
        print(f"size of gallery labels: {gallery_labels.size()}")
        print(f"size of query embeddings: {query_embeddings.size()}")
        print(f"size of query labels: {query_labels.size()}")

        # Compute distance matrix
        if self.re_ranking:
            distmat = re_ranking(query_embeddings, gallery_embeddings, k1=20, k2=6, lambda_value=0.3)
        else:
            distmat = compute_distance_matrix(self.distance_matrix, query_embeddings, gallery_embeddings, wildlife=True)

        # Compute mAP
        # mAP = torchreid.metrics.evaluate_rank(distmat, query_labels.cpu().numpy(), gallery_labels.cpu().numpy(), use_cython=False)[0]['mAP']
        mAP1 = evaluate_map(distmat, query_labels, gallery_labels, top_k=1)
        mAP5 = evaluate_map(distmat, query_labels, gallery_labels, top_k=5)
        self.log('val/mAP1', mAP1)
        self.log('val/mAP5', mAP5)

        similarity_function = CosineSimilarity()
        similarity = similarity_function(query_embeddings, gallery_embeddings)["cosine"]
        print("Similarity matrix: \n", similarity.shape)
        print(similarity)

        # Convert gallery_labels to numpy if necessary
        gallery_labels = gallery_labels.cpu().numpy() if isinstance(gallery_labels, torch.Tensor) else gallery_labels
        print("Gallery labels: \n", gallery_labels)
        print("Query labels: \n", query_labels)

        query_map = [i for i in range(len(query_labels))]

        # Nearest neighbor classifier using KNN with k=1
        classifier = KnnClassifier(k=1)
        preds = classifier(similarity)
        print(f"preds: {preds}")

        # Calculate accuracy
        accuracy = (preds == query_labels.cpu().numpy()).mean()
        print(f"accuracy: {accuracy}")
        self.log('val/accuracy', accuracy)

        # # Calculate Recall@K (choose K=5)
        # K = 5
        # top_k_preds = np.argsort(-similarity, axis=1)[:, :K]  # Get top K indices for each query
        # recall_at_k = 0
        # for i, query_label in enumerate(query_labels):
        #     top_k_labels = gallery_labels[top_k_preds[i]]  # Get top K labels for the query
        #     if query_label in top_k_labels:
        #         recall_at_k += 1

        # # Log Recall@K
        # recall_at_k /= len(query_labels)
        # self.log(f'val/Recall@{K}', recall_at_k)


In [38]:
model = SimpleModel(backbone_model_name='hf-hub:BVRA/MegaDescriptor-T-224',
                    re_ranking=False)

In [39]:
from wildlife_datasets import analysis, datasets, loader
from data.raptors_wildlife import Raptors, WildlifeReidDataModule

root = '/Users/amee/Documents/code/master-thesis/datasets/EDA-whaleshark/'

dataset = datasets.WhaleSharkID(root)
data = WildlifeReidDataModule(data_dir=root, 
                              metadata=dataset.df, 
                              cache_path='/Users/amee/Documents/code/master-thesis/EagleID/dataset/dataframe/cache_whaleshark.csv', 
                              size = 224, 
                              preprocess_lvl=1, 
                              only_cache=True, 
                              batch_size=4)

Split: time-unaware closed-set
Samples: train/test/unassigned/total = 6108/1585/0/7693
Classes: train/test/unassigned/total = 543/512/0/543
Samples: train only/test only        = 31/0
Classes: train only/test only/joint  = 31/0/512

Fraction of train set     = 79.40%
Fraction of test set only = 0.00%
Train set size: 6108
Test set size: 1585


  ckpt = torch.load(file, map_location="cpu")
  ckpt = torch.load(file, map_location="cpu")


No segmentation data found: Empty list or list containing empty list.
Removed 1 rows with invalid segmentation data.
Removed 0 rows with invalid segmentation data.
Training Set
Length: 2090
Number of individuals: 315
Mean images/individual: 6.634920634920635
Min images/individual: 2
Max images/individual: 50
Test Set
Length: 584
Number of individuals: 295
Mean images per individual: 1.9796610169491526
Min images per individual: 1
Max images per individual: 18
length of query dataset: 125
length of gallery dataset: 459


In [40]:
trainer = Trainer(accelerator="cpu")
trainer.validate(model, dataloaders=data.val_dataloader())

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/amee/miniconda3/envs/pytorch_env/lib/python3.12/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/amee/miniconda3/envs/pytorch_env/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


Validation: |          | 0/? [00:00<?, ?it/s]

size of gallery embeddings: torch.Size([459, 768])
size of gallery labels: torch.Size([459])
size of query embeddings: torch.Size([125, 768])
size of query labels: torch.Size([125])
Distance matrix type should be np for rerankin: <class 'numpy.ndarray'>
Similarity matrix: 
 (125, 459)
[[    0.48836     0.41789     0.71838 ...     0.37283     0.66797     0.28333]
 [    0.18126     0.26546     0.77685 ...     0.12444     0.73263     0.11917]
 [    0.55198     0.54416     0.53006 ...     0.43508     0.45751     0.22246]
 ...
 [    0.14237     0.16995     0.72094 ...     0.18472     0.56607      0.1169]
 [    0.74552     0.48677     0.28043 ...     0.49063     0.33863     0.27755]
 [    0.61553     0.39169     0.47036 ...     0.72542      0.3144     0.19776]]
Gallery labels: 
 [  0   1   2   2   3   4   5   6   7   8   8   9  10  11  11  11  12  12  13  14  15  16  17  18  19  20  21  21  22  23  24  24  24  24  25  26  26  26  26  26  26  26  27  28  29  30  31  32  32  33  34  35  36  37

  a, b = torch.tensor(a), torch.tensor(b)
  results = pd.DataFrame(results).T.fillna(method="ffill").T


[{'val/mAP1': 0.0, 'val/mAP5': 0.0, 'val/accuracy': 0.00800000037997961},
 {'val/mAP1': 0.0, 'val/mAP5': 0.0, 'val/accuracy': 0.00800000037997961}]

In [20]:
from utils.triplet_loss_utils import KnnClassifier
from wildlife_tools.features import DeepFeatures

query_loader, gallery_loader = data.val_dataloader()
gallery_dataset = gallery_loader.dataset
query_dataset = query_loader.dataset

backbone = timm.create_model('hf-hub:BVRA/MegaDescriptor-T-224', num_classes=0, pretrained=True)
extractor = DeepFeatures(backbone)
query, database = extractor(query_dataset), extractor(gallery_dataset)

print(f'Query features shape: {query.shape}, Database features shape: {database.shape}')
# Cosine similarity between deep features
similarity_function = CosineSimilarity()
similarity = similarity_function(query, database)['cosine']
print("Similarity matrix: \n", similarity.shape)

### Debug: Check indices and sizes
print(f"Database labels map size: {len(gallery_dataset.labels_map)}")
print(f"Sample indices from similarity: {np.argmax(similarity, axis=1)[:10]}")

# Nearest neigbour classifier using the similarity
classifier = KnnClassifier(k=1, database_labels=gallery_dataset.labels)
print(f"size of gallery dataset labels: {len(gallery_dataset.labels)}")
preds = classifier(similarity)
preds = gallery_dataset.labels_map[preds]
print("Prediction \t", preds)
print("Ground truth \t", query_dataset.labels_string)

acc = sum(preds == query_dataset.labels_string) / len(query_dataset.labels_string)
print('\n Accuracy: ', acc)

100%|█████████████████████████████████████████████████████████████████| 1/1 [00:22<00:00, 22.07s/it]
100%|█████████████████████████████████████████████████████████████████| 4/4 [00:46<00:00, 11.50s/it]

Query features shape: (125, 768), Database features shape: (459, 768)
Similarity matrix: 
 (125, 459)
Database labels size: 295
Sample indices from similarity: [156 289 395 281 351 330  16   7  20  21]
size of gallery dataset labels: 459
125
459
Prediction 	 ['4507ee90-84b6-fecb-cd98-2fbddd8707fb' '9cc4a537-c0a2-3d06-3780-6b327d292303' 'de473d56-8c61-13cb-6684-933726b40a3d' '96583bb9-245a-e067-da52-deca5b665631' 'bb27ae2e-1c8a-9a2b-e28f-eb8d05360bda' 'b2bb892c-9e2d-0ceb-a7a6-d81c90b937cf' '0d0160ac-3076-83fc-76ca-7bf0189d86e1' '0749736f-ca90-ed15-a460-1488f8ad9522'
 '0f46e82c-b6a7-5819-9852-003f6861922c' '0fca832e-87ea-029e-4a39-0c1ae5a49416' '0e35135a-e5a5-ac85-1867-78d694cc3d86' '535ee47e-9a8d-156b-439e-00ceb0bc08ec' '95d65d3c-d851-6fdf-2d80-b91ef84459cf' 'a9429b5d-fe9f-08f9-3f79-c6497a00fcfc' 'b1a459de-6869-191f-3650-c345bd1552d4' '77ec3c66-4342-b777-89fd-79f15a5f0c6b'
 'cf744cd1-16d2-45c6-bd09-7149f9143159' '1cf57545-7a91-ab69-4ecc-5ef84fadd959' 'bc57a667-372a-587e-7c48-a3036c1bc7e


  results = pd.DataFrame(results).T.fillna(method="ffill").T


[{'val/mAP1': 0.0, 'val/mAP5': 0.004666666500270367},
 {'val/mAP1': 0.0, 'val/mAP5': 0.004666666500270367}]

In [None]:
config_file_path = yaml.safe_load(args.config)
with open(config_file_path, 'r') as config_file:
    config = yaml.safe_load(config_file)

data = ArtportalenDataModule(data_dir=config['dataset'], preprocess_lvl=config['preprocess_lvl'], batch_size=config['batch_size'], size=config['img_size'], mean=config['transforms']['mean'], std=config['transforms']['std'])
data.prepare_testing_data(config['dataset'])
dataloader = data.test_dataloader()

model = SimpleModel(config=config, pretrained=False, num_classes=data.num_classes)
if args.gpu:
    checkpoint = torch.load(config['checkpoint'])
else:
    checkpoint = torch.load(config['checkpoint'], map_location=torch.device('cpu'), weights_only=True)
model.load_state_dict(checkpoint["state_dict"])
model.to(torch.device('cpu'))


trainer = Trainer(accelerator="cpu")
# trainer.fit(model, data)
trainer.test(model, dataloaders=dataloader, ckpt_path=config['checkpoint'])
trainer.validate(model, dataloaders=data.val_dataloader())
