In [2]:
# Print the working directory of the shell spawned by the notebook kernel.
!pwd

/root/kaggle/google-retrieval


In [3]:
from albumentations.pytorch.transforms import ToTensorV2
from torch.utils.data import DataLoader
from torchvision.datasets.folder import default_loader
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from tqdm.auto import tqdm

import os
import glob
import torch
import timm

import albumentations as A
import torchvision.datasets as datasets
import pandas as pd
import numpy as np
import pytorch_lightning as pl
import torch.nn.functional as F
import torch.nn as nn

In [4]:
class GLDConfig:
    """Static settings for the GLD2021 experiments.

    All values are plain class attributes, so they are read directly as
    ``GLDConfig.<name>`` without instantiating the class.
    """

    # ---- run-wide settings ----
    seed = 42
    num_workers = 8
    wandb_logger = False
    project = "GLD2021"

    # ---- model ----
    backbone_name = "tf_efficientnet_b4_ns"
    hidden_dim = 512

    # ---- data ----
    data_root = "/shared/lorenzo/data-gld"  # holds train.csv and the image folders
    split_df_root = "/root/kaggle/google-retrieval/split"  # per-fold split CSVs
    img_size = 384
    num_classes = 81313  # number of distinct landmark_id values in train.csv
    aug = True
    fold_no = 0

    # ---- optimisation ----
    max_epochs = 20
    batch_size = 64
    lr = 3e-4
    lr_scheduler = False

    # ---- early stopping / checkpoints ----
    es_patience = 5
    checkpoint_dir = "/root/kaggle/google-retrieval/pl_output"

In [5]:
# Preview the training index stored next to the images.
train_csv_path = os.path.join(GLDConfig.data_root, "train.csv")
pd.read_csv(train_csv_path)

Unnamed: 0,id,landmark_id
0,17660ef415d37059,1
1,92b6290d571448f6,1
2,cd41bf948edc0340,1
3,fb09f1e98c6d2f70,1
4,25c9dfc7ea69838d,7
...,...,...
1580465,72c3b1c367e3d559,203092
1580466,7a6a2d9ea92684a6,203092
1580467,9401fad4c497e1f9,203092
1580468,aacc960c9a228b5f,203092


In [6]:
# Count the distinct landmark ids in train.csv (NaN, if present, counted once,
# exactly as len(unique()) would).
landmark_ids = pd.read_csv(os.path.join(GLDConfig.data_root, "train.csv"))["landmark_id"]
landmark_ids.nunique(dropna=False)

81313

In [7]:
# Map every raw landmark_id to a contiguous class index, numbered in the order
# the ids first appear in train.csv.
train_landmark_ids = pd.read_csv(
    os.path.join(GLDConfig.data_root, "train.csv"))["landmark_id"]
label_map = {
    landmark_id: class_idx
    for class_idx, landmark_id in enumerate(train_landmark_ids.unique())
}

# NOTE(review): this displays the whole mapping (one entry per class).
label_map

{1: 0,
 7: 1,
 9: 2,
 11: 3,
 12: 4,
 17: 5,
 22: 6,
 23: 7,
 24: 8,
 27: 9,
 29: 10,
 30: 11,
 32: 12,
 34: 13,
 36: 14,
 37: 15,
 41: 16,
 43: 17,
 48: 18,
 50: 19,
 51: 20,
 56: 21,
 58: 22,
 60: 23,
 63: 24,
 65: 25,
 66: 26,
 72: 27,
 79: 28,
 81: 29,
 82: 30,
 83: 31,
 84: 32,
 87: 33,
 89: 34,
 90: 35,
 99: 36,
 100: 37,
 101: 38,
 103: 39,
 104: 40,
 110: 41,
 111: 42,
 115: 43,
 117: 44,
 118: 45,
 123: 46,
 124: 47,
 131: 48,
 134: 49,
 135: 50,
 136: 51,
 139: 52,
 141: 53,
 143: 54,
 145: 55,
 146: 56,
 149: 57,
 151: 58,
 155: 59,
 166: 60,
 168: 61,
 172: 62,
 173: 63,
 174: 64,
 175: 65,
 177: 66,
 178: 67,
 183: 68,
 185: 69,
 187: 70,
 189: 71,
 191: 72,
 192: 73,
 199: 74,
 202: 75,
 209: 76,
 212: 77,
 213: 78,
 215: 79,
 216: 80,
 219: 81,
 221: 82,
 223: 83,
 225: 84,
 226: 85,
 228: 86,
 230: 87,
 232: 88,
 240: 89,
 242: 90,
 243: 91,
 244: 92,
 245: 93,
 247: 94,
 250: 95,
 259: 96,
 260: 97,
 262: 98,
 263: 99,
 264: 100,
 270: 101,
 272: 102,
 274: 103,
 277: 

In [8]:
# Load the training split of the configured fold (previously hard-coded to
# fold 0 through an f-string without placeholders) and preview its labels after
# remapping raw landmark ids to contiguous class indices.
df = pd.read_csv(os.path.join(
    GLDConfig.split_df_root, f"train_df_fold{GLDConfig.fold_no}.csv"))
df['landmark_id'].map(label_map)

0              0
1              0
2              0
3              1
4              1
           ...  
1264371    81312
1264372    81312
1264373    81312
1264374    81312
1264375    81312
Name: landmark_id, Length: 1264376, dtype: int64

In [9]:
# Pair every image id with its remapped class index; show the last pair.
mapped_labels = df['landmark_id'].map(label_map)
data = list(zip(df['id'], mapped_labels))
data[-1]

('d9e338c530dca106', 81312)

In [10]:
# List the test images; the pattern expects them three directory levels below test/.
test_pattern = os.path.join(GLDConfig.data_root, "test/*/*/*/*.jpg")
glob.glob(test_pattern)

['/shared/lorenzo/data-gld/test/5/f/5/5f5a7b3e3f2c17ca.jpg',
 '/shared/lorenzo/data-gld/test/5/f/2/5f27abd2ea15f147.jpg',
 '/shared/lorenzo/data-gld/test/5/f/6/5f6238137076d11b.jpg',
 '/shared/lorenzo/data-gld/test/5/f/f/5ffd474f62502596.jpg',
 '/shared/lorenzo/data-gld/test/5/f/1/5f14c1bbc07f34e6.jpg',
 '/shared/lorenzo/data-gld/test/5/f/1/5f1c08a01ec523d1.jpg',
 '/shared/lorenzo/data-gld/test/5/f/c/5fc41fa23cda8d6f.jpg',
 '/shared/lorenzo/data-gld/test/5/f/c/5fc68cb1c569dfa2.jpg',
 '/shared/lorenzo/data-gld/test/5/6/a/56aaed126d7cb956.jpg',
 '/shared/lorenzo/data-gld/test/5/6/d/56d525bfd1da5e15.jpg',
 '/shared/lorenzo/data-gld/test/5/1/f/51f3f1fddd1393ef.jpg',
 '/shared/lorenzo/data-gld/test/5/1/a/51ae88ec101d2221.jpg',
 '/shared/lorenzo/data-gld/test/5/1/c/51ca94b1da9a713b.jpg',
 '/shared/lorenzo/data-gld/test/5/1/3/513caf882acf8af6.jpg',
 '/shared/lorenzo/data-gld/test/5/1/d/51deb719fbe6de33.jpg',
 '/shared/lorenzo/data-gld/test/5/a/7/5a7f391ec3d00fa9.jpg',
 '/shared/lorenzo/data-g

In [11]:
class GLDataset(datasets.VisionDataset):
    """Image dataset over the landmark data found under ``root``.

    * ``split`` in ``('train', 'val')``: items are ``(image, target)`` pairs.
      Image ids and raw landmark ids come from
      ``{split_df_root}/{split}_df_fold{fold_no}.csv``; ``target`` is the
      contiguous class index assigned to the landmark id by order of first
      appearance in ``{root}/train.csv``. Images are read from
      ``{root}/train/<c0>/<c1>/<c2>/<id>.jpg`` where ``c0..c2`` are the first
      three characters of the id.
    * ``split`` in ``('test', 'index')``: items are images only, discovered by
      globbing ``{root}/{split}/*/*/*/*.jpg``.

    Args:
        root: Directory holding ``train.csv`` and the image folders.
        split_df_root: Directory holding the per-fold split CSVs.
        fold_no: Fold number used to pick the split CSV.
        seed: Stored on the instance as ``self.seed``; not otherwise used by
            this class.
        split: One of ``'train'``, ``'val'``, ``'test'``, ``'index'``.
        transform: Optional callable applied to the image array. An
            ``albumentations.Compose`` (or subclass) is called with
            ``image=...`` and its ``'image'`` result is used; any other
            callable is called with the array positionally.

    Raises:
        AssertionError: If ``split`` is not one of the supported values.
        ValueError: If the split CSV contains a landmark id that does not
            occur in ``train.csv`` (it would otherwise become a NaN target).
    """

    _LABELLED_SPLITS = ('train', 'val')
    _SPLITS = _LABELLED_SPLITS + ('test', 'index')

    def __init__(self, root, split_df_root, fold_no=0, seed=42, split='train', transform=None):
        super().__init__(root, transform=transform)
        assert split in self._SPLITS, f"split must be one of {self._SPLITS}, got {split!r}"
        self.loader = default_loader
        self.split = split
        self.seed = seed

        if split in self._LABELLED_SPLITS:
            # The label map is only needed for labelled splits, so train.csv
            # is no longer parsed when building a test/index dataset.
            label_map = self._build_label_map(root)
            df = pd.read_csv(os.path.join(
                split_df_root, f"{split}_df_fold{fold_no}.csv"))
            targets = df['landmark_id'].map(label_map)
            if targets.isna().any():
                raise ValueError(
                    f"{split} fold {fold_no} contains landmark ids missing from train.csv")
            self.data = list(zip(df['id'], targets))
        else:
            # glob returns paths in arbitrary order; sort so that item indices
            # are reproducible across runs and machines.
            self.data = sorted(glob.glob(os.path.join(root, f"{split}/*/*/*/*.jpg")))

    @staticmethod
    def _build_label_map(root):
        """Return ``{landmark_id: class_index}`` in order of first appearance in train.csv."""
        landmark_ids = pd.read_csv(
            os.path.join(root, "train.csv"), usecols=["landmark_id"])["landmark_id"]
        return {landmark_id: idx for idx, landmark_id in enumerate(landmark_ids.unique())}

    def __getitem__(self, index):
        """Load item ``index``: ``(image, target)`` for train/val, else ``image``."""
        labelled = self.split in self._LABELLED_SPLITS
        if labelled:
            image_id, target = self.data[index]
            # Images are sharded by the first three characters of their id.
            path = os.path.join(self.root, 'train',
                                image_id[0], image_id[1], image_id[2], f"{image_id}.jpg")
        else:
            path = self.data[index]

        img = np.array(self.loader(path))

        if self.transform is not None:
            # isinstance (rather than an exact type check) also routes
            # Compose subclasses through the keyword-based albumentations API.
            if isinstance(self.transform, A.Compose):
                img = self.transform(image=img)['image']
            else:
                img = self.transform(img)

        if labelled:
            return img, target
        return img

    def __len__(self):
        """Number of items in the selected split."""
        return len(self.data)

In [13]:
# Build the validation split (default fold) and check that an item loads.
dataset = GLDataset(
    root=GLDConfig.data_root,
    split_df_root=GLDConfig.split_df_root,
    split='val',
)
# NOTE: for the 'val' split an item is an (image, label) pair, so `img` holds that tuple.
img = dataset[0]
len(dataset)

316094