In [2]:
from src.train import SphereClassifier, WhaleDataModule
from src.dataset import load_df
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
from pathlib import Path
import pandas as pd

from config.config import Config, load_config

# cuda = torch.device("cuda:0")

## Load model

In [3]:
# model = SphereClassifier.load_from_checkpoint(
#     checkpoint_path="/app/sandbox/happy_whale/kaggle-happywhale-1st-place/result/b6_bottleneck_feature_fix_nb/1/last-v4.ckpt"
# )
# model.to(cuda)
# model.eval()


# image = torch.rand(1, 3, 528, 528).to(cuda)
# logits_ids, logits_species = model(image)

In [4]:
# from torchviz import make_dot

# image = torch.rand(1, 3, 528, 528).to(cuda)
# yhat = model(image)
# make_dot(yhat, params=dict(list(model.named_parameters()))).render(
#     "b6_torchviz", "b6.png"
# )

## Compute embs on train ds

In [6]:
cfg = load_config("config/efficientnet_b6_new.yaml", "config/default.yaml");

used default config lr_backbone: 0.0016
used default config lr_head: 0.016
used default config lr_decay_scale: 0.01
used default config num_classes: 15587
used default config num_species_classes: 26
used default config pretrained: True
used default config val_bbox: fullbody
used default config test_bboxes: ['fullbody', 'fullbody_charm']
used default config bboxes: {'fullbody_charm': 0.15, 'fullbody': 0.6, 'backfin': 0.15, 'detic': 0.05, 'none': 0.05}
used default config bbox_conf_threshold: 0.01
used default config n_data: -1
used default config global_pool: {'arch': 'GeM', 'p': 3, 'train': False}
used default config normalization: batchnorm
used default config optimizer: AdamW
used default config loss_fn: CrossEntropy
used default config loss_id_ratio: 0.437338
used default config margin_coef_id: 0.27126
used default config margin_coef_species: 0.226253
used default config margin_power_id: -0.364399
used default config margin_power_species: -0.720133
used default config s_id: 20.9588


In [7]:
df = load_df("input", cfg, "train.csv", True)
data_module = WhaleDataModule(
    df,
    cfg,
    f"input/train_images",
    cfg.val_bbox,
    -1,
)

detic low conf: 0 / 51033
fullbody low conf: 0 / 51033
fullbody_charm low conf: 10 / 51033
backfin low conf: 1587 / 51033


In [9]:
train_dataset = data_module.get_dataset(df, False)
train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=32,
    pin_memory=True,
    drop_last=False,
)

In [10]:
# predictions = []
# model.eval()

# for batch in tqdm(train_loader):
#     images = batch["image"].to(cuda)
#     out = model(images)
#     bottleneck_feat = model.get_bottleneck_feature(images)
#     feats = model.backbone_head_bn(model.backbone_head(bottleneck_feat))
#     feats = F.normalize(feats, p=2.0, dim=1)
#     predictions.append(feats.detach().cpu())
#     break
# embs = torch.cat(predictions, axis=0).numpy()
# np.savez(f"whale_train_emb.npz", embs=embs)
# out[0].max(), out[1].sort()

## Create Whale Train OSFR protocol

In [12]:
# check scf embs
# a = np.load("whale_train_emb.npz")
# b = np.load("/app/cache/features/scf_embs_whale.npz")
# a["embs"], b["embs"]

## Create whale validation and test sets

In [13]:
validation_set_id_num = 5000
rng = np.random.default_rng(1)

unique_ids, count_ids = np.unique(train_dataset.ids, return_counts=True)
validation_id_idx = rng.choice(
    unique_ids.shape[0], validation_set_id_num, replace=False
)
test_id_idx = np.array(
    sorted(list(set(range(unique_ids.shape[0])) - set(validation_id_idx)))
)
assert validation_id_idx.shape[0] + test_id_idx.shape[0] == unique_ids.shape[0]
assert set(validation_id_idx).intersection(test_id_idx) == set()

validation_unique_ids, validation_count_ids = (
    unique_ids[validation_id_idx],
    count_ids[validation_id_idx],
)
test_unique_ids, test_count_ids = unique_ids[test_id_idx], count_ids[test_id_idx]

In [43]:
from tqdm import tqdm


def create_whale_dataset(train_dataset, unique_ids, count_ids, full_ds_embs, ds_name):

    out_of_gallery_ids = unique_ids[count_ids == 1]  # single image ids
    in_gallery_ids = unique_ids[count_ids > 1]
    print(out_of_gallery_ids.shape, in_gallery_ids.shape)

    # construct gallery and probe temlates
    image_path_to_template_id = {}
    image_path_to_subject_id = {}
    image_path_to_emb = {}
    image_path_to_unc = {}

    # select embeddings
    img_names = []  # train_dataset.x_paths
    for subject in tqdm(unique_ids):
        local_idx = train_dataset.ids == subject
        subject_images_paths = train_dataset.x_paths[local_idx]
        subject_embs = full_ds_embs["embs"][local_idx]
        subject_unc = full_ds_embs["unc"][local_idx]
        for image_path, emb, unc in zip(
            subject_images_paths, subject_embs, subject_unc
        ):
            img_names.append(image_path)
            image_path_to_emb[image_path] = emb[np.newaxis, :]
            image_path_to_unc[image_path] = unc[np.newaxis, :]

    gallery_templates = []
    known_probe_templates = []

    subject_id = 0
    gallery_template_id = 0
    probe_template_id = 10000
    for subject in tqdm(in_gallery_ids):
        subject_images_paths = train_dataset.x_paths[train_dataset.ids == subject]

        image_count = len(subject_images_paths)
        for i, image_path in enumerate(subject_images_paths):
            image_path_to_subject_id[image_path] = subject_id
            if i < image_count // 2:
                image_path_to_template_id[image_path] = gallery_template_id
            if i >= image_count // 2:
                image_path_to_template_id[image_path] = probe_template_id

        gallery_templates.append(
            (subject_images_paths[: image_count // 2], gallery_template_id, subject_id)
        )
        known_probe_templates.append(
            (subject_images_paths[image_count // 2 :], probe_template_id, subject_id)
        )
        gallery_template_id += 1
        probe_template_id += 1
        subject_id += 1

    assert gallery_template_id < 10000
    unknown_probe_templates = []

    for probe_subject in tqdm(out_of_gallery_ids):
        probe_images_paths = train_dataset.x_paths[train_dataset.ids == probe_subject]
        for image_path in probe_images_paths:
            image_path = str(image_path)
            image_path_to_subject_id[image_path] = subject_id
            image_path_to_template_id[image_path] = probe_template_id
        unknown_probe_templates.append(
            (probe_images_paths, probe_template_id, subject_id)
        )
        probe_template_id += 1
        subject_id += 1

    # assert len(image_path_to_template_id) == len(train_dataset.x_paths)
    # assert len(image_path_to_subject_id) == len(train_dataset.x_paths)
    assert len(set(image_path_to_subject_id.values())) == len(unique_ids)
    assert len(set(image_path_to_template_id.values())) == len(unique_ids) + len(
        in_gallery_ids
    )
    print(
        len(gallery_templates), len(known_probe_templates), len(unknown_probe_templates)
    )
    print(
        len(unknown_probe_templates)
        / (len(known_probe_templates) + len(unknown_probe_templates))
    )
    print(len(known_probe_templates) + len(unknown_probe_templates))

    # create meta files
    # tid mid
    identification_ds_path = Path(f"/app/datasets/{ds_name}")
    identification_ds_path.mkdir(exist_ok=True)
    meta_path = identification_ds_path / "meta"
    embeddings_path = identification_ds_path / "embeddings"
    embeddings_path.mkdir(exist_ok=True)
    meta_path.mkdir(exist_ok=True)

    mids = np.arange(len(img_names))
    tids = []
    sids = []

    for image_path in img_names:
        tids.append(image_path_to_template_id[image_path])
        sids.append(image_path_to_subject_id[image_path])

    out_file_tid_mid = meta_path / Path(f"{ds_name}_face_tid_mid.txt")
    with open(out_file_tid_mid, "w") as fd:
        for name, tid, sid, mid in zip(img_names, tids, sids, mids):
            fd.write(f"{name} {tid} {mid} {sid}\n")

    out_file_probe = meta_path / Path(f"{ds_name}_1N_probe_mixed.csv")
    out_file_gallery = meta_path / Path(f"{ds_name}_1N_gallery_G1.csv")

    tids_probe = []
    sids_probe = []
    names_probe = []
    for probe_meta in known_probe_templates + unknown_probe_templates:
        tids_probe.extend([probe_meta[1]] * len(probe_meta[0]))
        sids_probe.extend([probe_meta[2]] * len(probe_meta[0]))
        names_probe.extend([x.split("/")[-1] for x in probe_meta[0]])

    tids_gallery = []
    sids_gallery = []
    names_gallery = []

    for gallery_meta in gallery_templates:
        tids_gallery.extend([gallery_meta[1]] * len(gallery_meta[0]))
        sids_gallery.extend([gallery_meta[2]] * len(gallery_meta[0]))
        names_gallery.extend([x.split("/")[-1] for x in gallery_meta[0]])

    assert len(tids_gallery) + len(tids_probe) == len(img_names)
    probe = pd.DataFrame(
        {
            "TEMPLATE_ID": tids_probe,
            "SUBJECT_ID": sids_probe,
            "FILENAME": names_probe,
        }
    )
    gallery = pd.DataFrame(
        {
            "TEMPLATE_ID": tids_gallery,
            "SUBJECT_ID": sids_gallery,
            "FILENAME": names_gallery,
        }
    )

    probe.to_csv(out_file_probe, sep=",", index=False)
    gallery.to_csv(out_file_gallery, sep=",", index=False)

    # save embedding
    embs = []
    uncs = []
    for image_name in image_path_to_emb.keys():
        embs.append(image_path_to_emb[image_name])
        uncs.append(image_path_to_unc[image_name])
    embs = np.concatenate(embs, axis=0)
    uncs = np.concatenate(uncs, axis=0)
    print(embs.shape, uncs.shape)
    np.savez(embeddings_path / f"scf_embs_{ds_name}.npz", embs=embs, unc=uncs)

In [48]:
full_ds_embs = np.load("/app/cache/features/scf_embs_whale.npz")
create_whale_dataset(
    train_dataset,
    validation_unique_ids,
    validation_count_ids,
    full_ds_embs,
    "whale_val",
)

  0%|          | 2/5000 [00:00<04:54, 16.95it/s]

(2985,) (2015,)


100%|██████████| 5000/5000 [04:11<00:00, 19.88it/s]
100%|██████████| 2015/2015 [00:00<00:00, 37515.80it/s]
100%|██████████| 2985/2985 [00:00<00:00, 44559.75it/s]


2015 2015 2985
0.597
5000
(17271, 512) (17271, 1)


In [49]:
create_whale_dataset(
    train_dataset, test_unique_ids, test_count_ids, full_ds_embs, "whale"
)

  0%|          | 2/10587 [00:00<09:16, 19.01it/s]

(6273,) (4314,)


100%|██████████| 10587/10587 [08:54<00:00, 19.83it/s]
100%|██████████| 4314/4314 [00:00<00:00, 38457.28it/s]
100%|██████████| 6273/6273 [00:00<00:00, 44744.40it/s]


4314 4314 6273
0.5925191272315103
10587
(33762, 512) (33762, 1)


In [45]:
a = np.load("/app/datasets/whale_val/embeddings/scf_embs_whale.npz")
a["embs"][0] @ a["embs"][1], a["embs"][0] @ a["embs"][2]

(0.62920386, -0.009191476)

In [46]:
a["embs"][4] @ a["embs"][5], a["embs"][4] @ a["embs"][6], a["embs"][4] @ a["embs"][
    9
], a["embs"][4] @ a["embs"][10], a["embs"][4] @ a["embs"][16]

(0.80075425, 0.8756635, 0.8109231, 0.709551, 0.12658957)

In [47]:
full_ds_embs["embs"].shape, full_ds_embs["unc"].shape

((51033, 512), (51033, 1))

In [None]:
create_whale_dataset(
    train_dataset, test_unique_ids, test_count_ids, full_ds_embs, "whale"
)

In [None]:
np.unique(train_dataset.x_paths).shape

In [None]:
emb_dir = identification_ds_path / "embeddings"
emb_dir.mkdir(exist_ok=True)
np.savez(emb_dir / "b6_embs_whale.npz", **a, unc=np.ones((a["embs"].shape[0], 1)) * 30)

In [None]:
np.ones((a["embs"].shape[0], 1)).shape