In [2]:
import argparse
from pathlib import Path
from typing import List

import torchvision
import pandas as pd
import torch
import tqdm
import torch.nn.functional as F
from PIL import Image

from transformers import ViTImageProcessor, ViTModel


# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from lightly.transforms.dino_transform import DINOTransform
from lightly.data import LightlyDataset

from src_lightly.models.dino import DINO

In [11]:
def load_image(path: str) -> Image.Image:
    """Open the image at ``path`` and return it converted to RGB.

    The file is opened inside a ``with`` block so its handle is released as
    soon as the pixel data has been converted, rather than staying open until
    the lazily-loaded image object is garbage collected.

    Args:
        path: Location of the image file to read.

    Returns:
        The image converted to ``"RGB"`` mode.
    """
    with Image.open(path) as img:
        # convert() produces a separate image object, so the result remains
        # usable after the source file has been closed.
        return img.convert("RGB")


def load_dataset(
    suspect_dir: str,
    output_dir: str,
    load_image,
    transform,
    batch_size = 64,
    num_workers = 8
):
    """Build one DataLoader for the suspect images and one for the output images.

    Each directory is wrapped in a ``torchvision.datasets.DatasetFolder``
    restricted to ``.png`` files, converted to a ``LightlyDataset`` (which is
    where ``transform`` is applied), and served through a
    ``torch.utils.data.DataLoader`` with default (non-shuffled) ordering.

    Args:
        suspect_dir: Root directory of the suspect images.
        output_dir: Root directory of the output (cropped) images.
        load_image: Callable used by ``DatasetFolder`` to load a sample from
            its path.
        transform: Transform handed to ``LightlyDataset.from_torch_dataset``.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker processes for both loaders.

    Returns:
        A ``(suspect_dataloader, output_dataloader)`` tuple.
    """

    def _make_loader(root):
        # Single code path for both directories so their settings cannot drift.
        folder = torchvision.datasets.DatasetFolder(
            root,
            load_image,
            # torchvision documents ``extensions`` as a tuple of strings.
            extensions=(".png",),
        )
        return torch.utils.data.DataLoader(
            LightlyDataset.from_torch_dataset(folder, transform=transform),
            batch_size=batch_size,
            num_workers=num_workers,
        )

    return _make_loader(suspect_dir), _make_loader(output_dir)
    
    
def load_model(config: dict):
    """Load the pretrained ``facebook/dino-vitb8`` ViT and its image processor.

    Args:
        config: Accepted for interface compatibility; it is not read here and
            the checkpoint name is fixed.

    Returns:
        A ``(model, processor)`` tuple.
    """
    checkpoint_name = 'facebook/dino-vitb8'

    print("loading model")
    image_processor = ViTImageProcessor.from_pretrained(checkpoint_name)
    vit = ViTModel.from_pretrained(checkpoint_name)
    print("model loaded")

    return vit, image_processor

def generate_batch(lst, batch_size):
    """Lazily yield consecutive slices of ``lst`` holding up to ``batch_size`` items.

    The final slice is shorter when ``len(lst)`` is not a multiple of
    ``batch_size``; an empty ``lst`` yields nothing.
    """
    starts = range(0, len(lst), batch_size)
    yield from (lst[start : start + batch_size] for start in starts)


def encode_image(model, preprocess, img_paths: List[str], batch_size: int):
    """Encode images from disk into L2-normalized, flattened feature vectors.

    Paths are processed in chunks of ``batch_size``. Each chunk is loaded with
    ``load_image``, run through ``preprocess`` and ``model``, and the model's
    ``last_hidden_state`` is flattened to one row per image.

    Args:
        model: Callable invoked as ``model(**inputs)`` whose result exposes
            ``last_hidden_state``.
        preprocess: Callable invoked as
            ``preprocess(images=..., return_tensors="pt")``.
        img_paths: Paths of the images to encode; must be non-empty because
            the per-batch results are concatenated with ``torch.cat``.
        batch_size: Maximum number of images per forward pass.

    Returns:
        A CPU tensor with one L2-normalized row per input image.
    """
    # NOTE(review): inputs are not moved to any device here, so this presumably
    # expects ``model`` to be on the CPU — confirm against the caller.
    print("num image found:", len(img_paths))
    with torch.inference_mode(), torch.cuda.amp.autocast():
        vectors = []
        for batch_paths in tqdm.tqdm(generate_batch(img_paths, batch_size)):
            inputs = preprocess(
                images=[load_image(img_path) for img_path in batch_paths],
                return_tensors="pt"
            )
            outputs = model(**inputs)
            vector = outputs.last_hidden_state
            # Reshape by the size of THIS batch: the final batch is smaller
            # than ``batch_size`` whenever the image count is not a multiple
            # of it, and reshaping to ``batch_size`` rows would then fail or
            # mix features from different images.
            vectors.append(vector.reshape(len(batch_paths), -1).to("cpu"))
        vectors = torch.cat(vectors, dim=0)
        # Unit-normalize each row so a dot product equals cosine similarity.
        vectors = F.normalize(vectors, dim=1)
    return vectors

def encode_image_dataloader(model, dataloader, preprocess):
    """Encode every batch of ``dataloader`` into L2-normalized feature rows.

    For each ``(images, labels)`` batch the images go through ``preprocess``
    and ``model``; the resulting ``last_hidden_state`` is flattened to one row
    per image and moved to the CPU. Labels are ignored.

    Returns:
        A CPU tensor of all rows concatenated in dataloader order, each
        normalized to unit length.
    """
    batch_embeddings = []
    with torch.inference_mode():
        with torch.no_grad():
            for images, _labels in tqdm.tqdm(dataloader):
                model_inputs = preprocess(images=images, return_tensors="pt")
                hidden = model(**model_inputs).last_hidden_state
                flat = hidden.reshape(len(images), -1)
                batch_embeddings.append(flat.to("cpu"))
            stacked = torch.cat(batch_embeddings, dim=0)
            # Unit-length rows so later dot products are cosine similarities.
            normalized = F.normalize(stacked, dim=1)
    return normalized

def encode_lightly_dataloader(model, dataloader):
    """Encode the first view of every batch into L2-normalized feature rows.

    The dataloader is expected to yield ``(imgs, labels, extra)`` triples in
    which ``imgs`` is indexable and ``imgs[0]`` is a batched tensor; only that
    first entry is fed to ``model``. Labels and the third element are ignored.

    Inputs are moved to the device the model's parameters live on, so the
    function works on CPU-only machines as well as on GPU. While encoding, the
    model is put in eval mode (its previous train/eval state is restored
    afterwards) so train-only layers such as dropout do not perturb the
    embeddings.

    Args:
        model: Callable ``torch.nn.Module``-like object mapping a batch of
            images to a feature tensor.
        dataloader: Iterable yielding ``(imgs, labels, extra)`` batches; must
            yield at least one batch because results are concatenated with
            ``torch.cat``.

    Returns:
        A CPU tensor with one L2-normalized row per sample.
    """
    # Follow the model's own device instead of assuming CUDA; if the model
    # exposes no parameters, fall back to CUDA when available, else CPU.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Only toggle eval/train on objects that report being in training mode.
    was_training = bool(getattr(model, "training", False))
    if was_training:
        model.eval()

    vectors = []
    try:
        # inference_mode already disables gradient tracking.
        with torch.inference_mode():
            for imgs, _labels, _ in tqdm.tqdm(dataloader):
                first_view = imgs[0]
                vector = model(first_view.to(device))
                vectors.append(vector.reshape(len(first_view), -1).to("cpu"))
            vectors = torch.cat(vectors, dim=0)
            # Unit-normalize each row so a dot product equals cosine similarity.
            vectors = F.normalize(vectors, dim=1)
    finally:
        if was_training:
            model.train()
    return vectors


In [12]:

import torchvision.transforms as T

# Input locations: the suspect images, the cropped images, and the CSV path
# reserved for the final results.
suspect_dir = Path("data/images/suspects")
cropped_dir = Path("test-detected/output/test")
output_csv = Path("test.csv")

# Per-crop metadata; add each crop's path under ``cropped_dir`` as a new column.
df = pd.read_csv("test-detected/output/test/image_infos.csv")
df["Image_Path"] = df["Image_Name"].apply(lambda name: str(cropped_dir / name))

# Build both dataloaders with a DINOTransform whose flip, colour-jitter,
# grayscale, blur and solarization settings are all set to zero.
# NOTE(review): any cropping/resizing DINOTransform does by default is not
# disabled here, so embeddings may not be fully deterministic — confirm.
suspect_dataloader, output_dataloader = load_dataset(
    suspect_dir=suspect_dir,
    output_dir=cropped_dir,
    load_image=load_image,
    transform=DINOTransform(
        hf_prob=0,
        cj_prob=0,
        random_gray_scale=0,
        gaussian_blur=(0, 0, 0),
        solarization_prob=0,
    ),
    batch_size=64,
)

# Embeddings come from the DINO checkpoint below; the HuggingFace ViT path
# (load_model with encode_image / encode_image_dataloader) is not used here.
model = DINO.load_from_checkpoint(
    "checkpoints/dino/epoch=79-step=15120.ckpt",
    dataloader_suspect=suspect_dataloader,
    suspect_labels=[1, 2, 3],
    num_classes=10,
    knn_k=10,
    knn_t=0.1,
)

# Encode the suspect images.
print("Encoding suspect images")
suspect_vectors = encode_lightly_dataloader(model, suspect_dataloader)

Using cache found in /home/p2026309/.cache/torch/hub/facebookresearch_dino_main


Encoding suspect images


100%|██████████| 25/25 [00:03<00:00,  7.87it/s]


In [13]:
# Encode the cropped images with the same model and encoder function that
# produced ``suspect_vectors``.
print("Encoding cropped images")
# Alternative kept for reference (encoder that takes a separate preprocess step):
# cropped_vectors = encode_image_dataloader(model, output_dataloader, preprocess)
cropped_vectors = encode_lightly_dataloader(model, output_dataloader)


Encoding cropped images


100%|██████████| 54/54 [00:05<00:00,  9.76it/s]


In [14]:
# Cosine similarity between every crop and every suspect image. The encoder
# L2-normalizes its rows, so a plain dot product is the cosine similarity.
# Resulting shape: (num_cropped, num_suspect)
similarity_matrix = torch.matmul(cropped_vectors.float(), suspect_vectors.T.float())

# For each crop, keep the score of its best-matching suspect image.
max_similarity = similarity_matrix.max(dim=1).values

# Not wired up yet: thresholding and CSV export. ``args`` is not defined in
# this notebook, so a threshold value must be supplied before enabling this.
# is_suspect = torch.where(max_similarity > args.similarity_threshold, 1, 0)
# df["class"] = is_suspect.tolist()
# df.to_csv(output_csv, index=False)


In [16]:
# Average, over all crops, of each crop's best similarity to any suspect image.
max_similarity.mean()

tensor(0.9863)