In [1]:
import os
from glob import glob

def get_file_names(path, depth=3, pattern='*.jpg'):
    """Return the files matching ``pattern`` exactly ``depth`` directory levels below ``path``.

    With the defaults this matches ``path/*/*/*/*.jpg``, i.e. JPEGs nested three
    directories deep (presumably a sharded dataset layout -- TODO confirm).

    Args:
        path: root directory (str or os.PathLike).
        depth: number of intermediate directory levels between ``path`` and the files.
        pattern: glob pattern for the leaf file names.

    Returns:
        Sorted list of matching paths as strings. ``glob`` itself returns files in
        arbitrary, filesystem-dependent order; sorting makes downstream indexing
        (e.g. picking the first file) reproducible.
    """
    file_names = glob(os.path.join(path, *['*'] * depth, pattern))
    return sorted(file_names)

def convert_to_image_ids(fnames):
    """Turn file paths into image IDs.

    An image ID is the file's base name with its final extension removed,
    e.g. ``'/data/0/1/2/abc.jpg' -> 'abc'``. Input order is preserved.
    """
    return [os.path.splitext(os.path.basename(name))[0] for name in fnames]

def get_image_ids(path):
    """Return the image IDs of the files ``get_file_names`` finds under ``path``, in the same order."""
    return convert_to_image_ids(get_file_names(path))

In [2]:
import csv
from dataclasses import dataclass
from typing import Iterable, List

import numpy as np

@dataclass
class RetrievalResult:
    """Retrieval output for one query image: its ID and the index IDs chosen for it."""

    # ID of the query (test) image; written to the submission's `id` column.
    test_id: str
    # Index-image IDs chosen for this query, in the order they are written.
    # Annotated as List[str], but write_submission only space-joins it, so any
    # iterable of str works there.
    chosen_ids: List[str]


def generate_random_results(test_ids, index_ids, k=100, rng=None):
    """Yield a random-baseline RetrievalResult for each test ID.

    For every ID in ``test_ids``, ``k`` distinct IDs are drawn uniformly at
    random without replacement from ``index_ids``.

    Args:
        test_ids: iterable of query image IDs.
        index_ids: 1-D sequence of candidate index image IDs to sample from.
        k: number of IDs to choose per query (default 100). Clamped to
            ``len(index_ids)`` so a small index no longer raises ValueError.
        rng: optional ``numpy.random.Generator`` for reproducible sampling.
            When ``None`` the legacy global ``np.random`` state is used exactly
            as before, so ``np.random.seed`` still controls the output.

    Yields:
        RetrievalResult whose ``chosen_ids`` is a plain list of str, matching
        the dataclass annotation.
    """
    sampler = np.random if rng is None else rng
    # Convert once, instead of np.random.choice re-converting on every call.
    candidates = np.asarray(index_ids)
    n_chosen = min(k, len(candidates))
    for test_id in test_ids:
        chosen_ids = sampler.choice(candidates, n_chosen, replace=False)
        yield RetrievalResult(test_id=test_id, chosen_ids=chosen_ids.tolist())

def write_submission(output_fname, results: Iterable[RetrievalResult]):
    """Write retrieval results to a two-column CSV submission file.

    The header row is ``id,images``. Each following row holds one result's
    ``test_id`` and its ``chosen_ids`` joined by single spaces.

    Args:
        output_fname: destination path; an existing file is overwritten.
        results: any iterable of RetrievalResult (e.g. the generator returned
            by ``generate_random_results``); it is iterated exactly once.
    """
    # newline='' lets the csv writer control line endings itself, as the csv
    # module docs require; otherwise platforms that translate '\n' write '\r\r\n'.
    with open(output_fname, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=('id', 'images'))
        writer.writeheader()
        writer.writerows(
            {'id': result.test_id, 'images': ' '.join(result.chosen_ids)}
            for result in results
        )


In [18]:
from pathlib import Path

# Dataset root, expected to contain `index/` and `test/` image trees.
# Override with the GLD_DATA_DIR environment variable instead of editing the path.
data_path = Path(os.environ.get('GLD_DATA_DIR', '/shared/lorenzo/data-gld'))
index_path = data_path/'index'
test_path = data_path/'test'

index_files = get_file_names(index_path)
test_files = get_file_names(test_path)
# glob silently returns [] for a wrong path; fail here instead of several cells later.
assert index_files, f'no images found under {index_path}'
assert test_files, f'no images found under {test_path}'

# Derive IDs from the lists already globbed: each tree is walked only once and
# every id stays position-aligned with its file.
index_ids = convert_to_image_ids(index_files)
test_ids = convert_to_image_ids(test_files)

In [6]:
# Sanity check: number of index and test image IDs found.
(len(index_ids), len(test_ids))

(76176, 1129)

In [15]:
import shutil
import torch
from PIL import Image, ImageDraw

In [21]:
# Use the first test image as the retrieval query.
# NOTE(review): which file is "first" depends on the order get_file_names
# returns; raises IndexError if test_files is empty.
query_img_name = test_files[0]

In [None]:
from model.cgd import CGD
# NOTE(review): `config` is not defined in any visible cell, so this cell raises
# NameError on a fresh kernel (it has never been executed: `In [None]`).
# TODO: define or load a config exposing backbone_name, gd_config, feature_dim
# and num_classes in an earlier cell before running this line.
model = CGD(config.backbone_name, config.gd_config, config.feature_dim, config.num_classes)