# 7.Img2Emb

In [1]:
import os
import sys
import pickle
import json
import glob
import gc
import random
import time
import unicodedata
import traceback
import datetime
import copy

import numpy as np
import pandas as pd


from matplotlib import pyplot as plt 
from tqdm.notebook import tqdm
from pathlib import Path
from scipy.spatial import distance
from collections import defaultdict
from PIL import Image

from sklearn.model_selection import train_test_split

import torch
import torchvision
from torchvision.models import (
    vit_b_16, ViT_B_16_Weights, 
    vit_l_16, ViT_L_16_Weights,
    vit_h_14, ViT_H_14_Weights,
    regnet_y_32gf, RegNet_Y_32GF_Weights,
    regnet_y_128gf, RegNet_Y_128GF_Weights,
    regnet_y_16gf, RegNet_Y_16GF_Weights,
    efficientnet_v2_l, EfficientNet_V2_L_Weights,
    efficientnet_v2_m, EfficientNet_V2_M_Weights,
    convnext_large, ConvNeXt_Large_Weights,
    swin_v2_b, Swin_V2_B_Weights
)
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import (
    StepLR, MultiStepLR, 
    ConstantLR, LinearLR, 
    ExponentialLR, PolynomialLR, 
    CosineAnnealingLR, CosineAnnealingWarmRestarts, 
    CyclicLR, OneCycleLR, 
    ReduceLROnPlateau
)

sys.path.append('../input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def set_seed(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

2023-05-01 13:16:17.883907: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-05-01 13:16:18.013009: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## CFG 

In [2]:
batch_size_config = {
    "vit_b_16": {
        True: 256,
        False: 16
    },
    "vit_b_16_linear": {
        True: 256,
        False: 48
    },
    "vit_l_16": {
        True: 16,
        False: 1
    }, 
    "--vit_h_14": {
        True: 4,
        False: 1
    },
    "regnet_y_16gf": {
        True: 64,
        False: 26
    },
    "regnet_y_16gf_linear": {
        True: 64,
        False: 26
    },
    "regnet_y_32gf": {
        True: 16,
        False: 6
    },
    "regnet_y_32gf_linear": {
        True: 16,
        False: 20
    },
    "--regnet_y_128gf": {
        True: 4,
        False: 1
    },
    "--regnet_y_128gf_linear": {
        True: 4,
        False: 1
    },
    "--efficientnet_v2_l": {
        True: 64,
        False: 3
    },
    "--efficientnet_v2_m": {
        True: 64,
        False: 6
    },
    "--convnext_large": {
        True: 64,
        False: 12
    },
    "--swin_v2_b": {
        True: 256,
        False: 16
    }

}

In [3]:
class CFG:
    seed = 42
    text_emb_size = 384
    is_kaggle = (os.environ.get('PWD') == '/kaggle/working')
    
    train_files_dir = "img2emb-data"
    save_model = True 
    
    # DATASET FILTERS
    img_size_min = 256
    img_size_max = 1280
    img_max_ratio_diff = 2
    prompt_words_min = 5
    prompt_words_max = 100
    prompt_is_english = True
    
    drop_duplicates_by_head = True
    drop_duplicates_by_tail = False
    
    drop_duplicates_char_len_2m = 20
    drop_duplicates_char_len_sd2v1 = 30
    drop_duplicates_char_len_sd2v2 = 30
    drop_duplicates_char_len_sd3 = 30

    add_sd2_v1 = False
    add_sd2_v2 = True
    add_sd3 = True
    
    drop_duplicates_char_len_sd2v1_name = drop_duplicates_char_len_sd2v1 if add_sd2_v1 else False
    drop_duplicates_char_len_sd2v2_name = drop_duplicates_char_len_sd2v2 if add_sd2_v2 else False
    drop_duplicates_char_len_sd3_name = drop_duplicates_char_len_sd3 if add_sd3 else False
    
    
    img_dataset_name = f"img_{img_size_min}_{img_size_max}_ratio_{img_max_ratio_diff}".replace(".", "_")
    prompt_dataset_name = f"prompt_{prompt_words_min}_{prompt_words_max}"
    dupl_dataset_name = f"dupl_2m_{drop_duplicates_char_len_2m}_sd2v1_{drop_duplicates_char_len_sd2v1_name}_sd2v2_{drop_duplicates_char_len_sd2v2_name}_sd3_{drop_duplicates_char_len_sd3_name}"
    dataset_name = f"{img_dataset_name}_{prompt_dataset_name}_{dupl_dataset_name}"
    
    # TRAIN TEST SPLIT
    img_model_test_size = 0.05
    
    # MODEL
    # "vit_b_16", "vit_b_16_linear", "regnet_y_16gf", "regnet_y_32gf", "regnet_y_32gf_linear", "regnet_y_16gf_linear"
    img_model_name = "regnet_y_32gf_linear" 
    loss_name = "cosine"
    lr = 1e-5
    # "None", "StepLR", "ExponentialLR", "CosineAnnealingLR", "CosineAnnealingWarmRestartsLR", "CyclicLR" 
    lr_scheduler_name = "CyclicLR" 
    train_only_head = False
    
    # AUGMENTATIONS
    train_aug = True
    test_flip = True
    aug_name = f"flip_{int(train_aug)}"
    
    model_name = f"model_{img_model_name}_lr_{lr:.0e}_sch_{lr_scheduler_name}".replace("-", "_")
    
    # RESOURCES
    batch_size = batch_size_config[img_model_name][is_kaggle]
    num_workers = batch_size if not is_kaggle else 2
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # TRAIN CONFIG
    full_train_epoch_num = 100
    max_epoch_num = full_train_epoch_num * 10
    full_val_epoch_num = 5
    early_stopping_patience = full_val_epoch_num * 5
    
    # PATHS
    
    train_name = f"{dataset_name}_{model_name}"

set_seed(CFG.seed)
CFG.train_name

'img_256_1280_ratio_2_prompt_5_100_dupl_2m_20_sd2v1_False_sd2v2_30_sd3_30_model_regnet_y_32gf_linear_lr_1e_05_sch_CyclicLR'

## Val functions 

In [4]:
def get_img_model(img_model_name):
    if img_model_name == "regnet_y_16gf":
        if CFG.is_kaggle:
            model = regnet_y_16gf()
        else:
            weights = RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1
            model = regnet_y_16gf(weights=weights)
        model.fc = torch.nn.Linear(3024, CFG.text_emb_size)
        
        preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    if img_model_name == "regnet_y_16gf_linear":
        if CFG.is_kaggle:
            model = regnet_y_16gf()
        else:
            weights = RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_LINEAR_V1
            model = regnet_y_16gf(weights=weights)
        model.fc = torch.nn.Linear(3024, CFG.text_emb_size)
        
        preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    elif img_model_name == "regnet_y_32gf":
        if CFG.is_kaggle:
            model = regnet_y_32gf()
        else:
            weights = RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1
            model = regnet_y_32gf(weights=weights)
        model.fc = torch.nn.Linear(3712, CFG.text_emb_size)
        
        preprocess = transforms.Compose([
            transforms.Resize(384, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(384),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
        
    elif img_model_name == "regnet_y_32gf_linear":
        if CFG.is_kaggle:
            model = regnet_y_32gf()
        else:
            weights = RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_LINEAR_V1
            model = regnet_y_32gf(weights=weights)
        model.fc = torch.nn.Linear(3712, CFG.text_emb_size)
        
        preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    
    elif img_model_name == "regnet_y_128gf":
        if CFG.is_kaggle:
            model = regnet_y_128gf()
        else:
            weights = RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_E2E_V1
            model = regnet_y_128gf(weights=weights)
        model.fc = torch.nn.Linear(7392, CFG.text_emb_size)
        
        preprocess = transforms.Compose([
            transforms.Resize(384, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(384),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    elif img_model_name == "regnet_y_128gf_linear":
        if CFG.is_kaggle:
            model = regnet_y_128gf()
        else:
            weights = RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_LINEAR_V1
            model = regnet_y_128gf(weights=weights)
        model.fc = torch.nn.Linear(7392, CFG.text_emb_size)
        
        preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
        
    elif img_model_name == "vit_b_16":
        if CFG.is_kaggle:
            model = vit_b_16(image_size=384)
        else:
            weights = ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
            model = vit_b_16(weights=weights)
        model.heads.head = torch.nn.Linear(768, CFG.text_emb_size)
        
        preprocess = transforms.Compose([
            transforms.Resize(384, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(384),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    
    elif img_model_name == "vit_b_16_linear":
        if CFG.is_kaggle:
            model = vit_b_16(image_size=224)
        else:
            weights = ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1
            model = vit_b_16(weights=weights)
        model.heads.head = torch.nn.Linear(768, CFG.text_emb_size)
        
        preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    
    elif img_model_name == "vit_l_16":
        if CFG.is_kaggle:
            model = vit_l_16(image_size=512)
        else:
            weights = ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1
            model = vit_l_16(weights=weights)

        model.heads.head = torch.nn.Linear(1024, CFG.text_emb_size)
    
        preprocess = transforms.Compose([
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    elif img_model_name == "vit_h_14":
        if CFG.is_kaggle:
            model = vit_h_14(image_size=512)
        else:
            weights = ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1
            model = vit_h_14(weights=weights)

        model.heads.head = torch.nn.Linear(1280, CFG.text_emb_size)
    
        preprocess = transforms.Compose([
            transforms.Resize(518, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(518),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    elif img_model_name == "efficientnet_v2_l":
        if CFG.is_kaggle:
            model = efficientnet_v2_l(image_size=480)
        else:
            weights = EfficientNet_V2_L_Weights.IMAGENET1K_V1
            model = efficientnet_v2_l(weights=weights)

        model.classifier[1] = torch.nn.Linear(1280, CFG.text_emb_size)
    
        preprocess = transforms.Compose([
            transforms.Resize(480, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(480),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], 
                                 std=[0.5, 0.5, 0.5]),
        ])

    elif img_model_name == "efficientnet_v2_m":
        if CFG.is_kaggle:
            model = efficientnet_v2_m(image_size=480)
        else:
            weights = EfficientNet_V2_M_Weights.IMAGENET1K_V1
            model = efficientnet_v2_m(weights=weights)

        model.classifier[1] = torch.nn.Linear(1280, CFG.text_emb_size)
    
        preprocess = transforms.Compose([
            transforms.Resize(480, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(480),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    elif img_model_name == "convnext_large":
        if CFG.is_kaggle:
            model = convnext_large(image_size=224)
        else:
            weights = ConvNeXt_Large_Weights.IMAGENET1K_V1
            model = convnext_large(weights=weights)

        model.classifier[2] = torch.nn.Linear(1536, CFG.text_emb_size)
    
        preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    elif img_model_name == "swin_v2_b":
        if CFG.is_kaggle:
            model = swin_v2_b(image_size=256)
        else:
            weights = Swin_V2_B_Weights.IMAGENET1K_V1
            model = swin_v2_b(weights=weights)

        model.head = torch.nn.Linear(1024, CFG.text_emb_size)
    
        preprocess = transforms.Compose([
            transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
    
    model.to(CFG.device)
    return model, preprocess

def create_submission(pred_arr, img_names):
    imgIds = [i.split('.')[0] for i in img_names]

    EMBEDDING_LENGTH = CFG.text_emb_size
    eIds = list(range(EMBEDDING_LENGTH))

    imgId_eId = [
        '_'.join(map(str, i)) for i in zip(
            np.repeat(imgIds, EMBEDDING_LENGTH),
            np.tile(range(EMBEDDING_LENGTH), len(imgIds)))]
    
    submission = pd.DataFrame(
                    index=imgId_eId,
                    data=np.array(pred_arr).flatten(),
                    columns=['val']).rename_axis('imgId_eId')
    return submission

class CustomDataSet(Dataset):
    def __init__(self, data_dir, img2prompt, img_preprocess):
        self.data_dir = data_dir
        self.img_names = list(img2prompt.keys())
        self.img2prompt = img2prompt
        self.img_preprocess = img_preprocess

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        img_path = os.path.join(self.data_dir, img_name)
        img = Image.open(img_path)
        img_emb = self.img_preprocess(img)
        
        prompt = str(self.img2prompt[img_name])
        
        return img_name, img_emb, prompt

## Train functions 

In [5]:
def get_train_config(train_size, val_size):
    max_batches_per_epoch_train = train_size // CFG.batch_size // CFG.full_train_epoch_num
    max_batches_per_epoch_val = val_size // CFG.batch_size // CFG.full_val_epoch_num
    return max_batches_per_epoch_train, max_batches_per_epoch_val

def filter_metadata(df, 
                    img_size_min, img_size_max, 
                    img_max_ratio_diff, 
                    prompt_words_min, prompt_words_max, 
                    prompt_is_english,
                    drop_duplicates_by_head, 
                    drop_duplicates_by_tail, 
                    drop_duplicates_char_len):
    def is_english_only(string):
        for s in string:
            cat = unicodedata.category(s)         
            if (cat not in ['Ll', 'Lu', 'Nd', 'Po', 'Pd', 'Zs']) or (not cat.isascii()):
                return False
        
        return True
    
    df = df.copy()
    df["prompt"] = df["prompt"].astype(str)
    
    df["size_ratio"] = df["height"] / df["width"]
    df['prompt'] = df['prompt'].str.strip()
    df["num_words"] = df['prompt'].str.split(" ").apply(len)
    df["is_english"] = df["prompt"].apply(is_english_only)
    
    img_hw_cond = (
        df["width"].between(img_size_min, img_size_max) & 
        df["height"].between(img_size_min, img_size_max)
    )
    img_ratio_cond = df["size_ratio"].between(1/img_max_ratio_diff, img_max_ratio_diff)
    prompt_empty_cond = (df["prompt"] != "")
    prompt_num_words_cond = df["num_words"].between(prompt_words_min, prompt_words_max)

    if prompt_is_english:
        df = df[df["is_english"]]
    
    if drop_duplicates_by_head:
        df['head'] = df['prompt'].str[:drop_duplicates_char_len]
        df.drop_duplicates(subset='head', inplace=True)
    
    if drop_duplicates_by_tail:
        df['tail'] = df['prompt'].str[-drop_duplicates_char_len:]
        df.drop_duplicates(subset='tail', inplace=True)
    
    
    df = df[
        img_hw_cond &
        img_ratio_cond &
        prompt_empty_cond &
        prompt_num_words_cond
    ][["image_name", "prompt"]]
    
    df = df.reset_index(drop=True)
    return df

def get_loss(loss_name):
    if loss_name == "cosine":
        loss_fn = torch.nn.CosineEmbeddingLoss()
        return lambda pred, true: loss_fn(pred, true, torch.ones(pred.size(0)).to(CFG.device))
    
def get_scheduler(scheduler_name):
    overfiting_epoch = int(CFG.full_train_epoch_num)
    lr_start = CFG.lr
    
    if scheduler_name == "None":
        scheduler = None
    elif scheduler_name == "StepLR":
        scheduler = lambda optimizer: StepLR(
            optimizer, step_size = overfiting_epoch * 0.5, gamma = 0.5
        )
    elif scheduler_name == "ExponentialLR":
        scheduler = lambda optimizer: ExponentialLR(
            optimizer, gamma = 1.05
        )
    elif scheduler_name == "CosineAnnealingLR":
        scheduler = lambda optimizer: CosineAnnealingLR(
            optimizer, T_max=overfiting_epoch, eta_min=lr_start * 0.1
        )
    elif scheduler_name == "CosineAnnealingWarmRestartsLR":
        scheduler = lambda optimizer: CosineAnnealingWarmRestarts(
            optimizer, T_0=overfiting_epoch, T_mult=1, eta_min=lr_start * 0.01
        )
    elif scheduler_name == "CyclicLR":
        scheduler = lambda optimizer: CyclicLR(
            optimizer, 
            base_lr =lr_start / 10, max_lr = lr_start * 10, 
            step_size_up = overfiting_epoch * 0.5, 
            mode = "triangular2", 
            cycle_momentum=False
        )
    return scheduler

## Data 

In [6]:
train_data_dir = Path("../input/")
metadata_2m = pd.read_parquet(train_data_dir / "DiffusionDB_2M/metadata.parquet")
metadata_2m["image_name"] = "DiffusionDB_2M/" + metadata_2m["image_name"]
metadata_2m = filter_metadata(
    metadata_2m, 
    img_size_min=CFG.img_size_min, 
    img_size_max=CFG.img_size_max, 
    img_max_ratio_diff=CFG.img_max_ratio_diff, 
    prompt_words_min=CFG.prompt_words_min, 
    prompt_words_max=CFG.prompt_words_max,
    prompt_is_english=CFG.prompt_is_english,
    drop_duplicates_by_head=CFG.drop_duplicates_by_head,
    drop_duplicates_by_tail=CFG.drop_duplicates_by_tail,
    drop_duplicates_char_len=CFG.drop_duplicates_char_len_2m
)

metadata = metadata_2m

if CFG.add_sd2_v1:
    metadata_sd2_v1 = pd.read_parquet(train_data_dir / "gustavosta-sd2-v1/metadata.parquet")
    metadata_sd2_v1["image_name"] = "gustavosta-sd2-v1/" + metadata_sd2_v1["image_name"]
    metadata_sd2_v1["height"] = 512
    metadata_sd2_v1["width"] = 512
    metadata_sd2_v1 = filter_metadata(
        metadata_sd2_v1, 
        img_size_min=CFG.img_size_min, 
        img_size_max=CFG.img_size_max, 
        img_max_ratio_diff=CFG.img_max_ratio_diff, 
        prompt_words_min=CFG.prompt_words_min, 
        prompt_words_max=CFG.prompt_words_max,
        prompt_is_english=CFG.prompt_is_english,
        drop_duplicates_by_head=CFG.drop_duplicates_by_head,
        drop_duplicates_by_tail=CFG.drop_duplicates_by_tail,
        drop_duplicates_char_len=CFG.drop_duplicates_char_len_sd2v1
    )
    
    metadata = pd.concat([metadata, metadata_sd2_v1], ignore_index=True)

if CFG.add_sd2_v2:
    metadata_sd2_v2 = pd.read_parquet(train_data_dir / "gustavosta-sd2-v2/metadata.parquet")
    metadata_sd2_v2["image_name"] = "gustavosta-sd2-v2/" + metadata_sd2_v2["image_name"]
    metadata_sd2_v2["height"] = 512
    metadata_sd2_v2["width"] = 512
    metadata_sd2_v2 = filter_metadata(
        metadata_sd2_v2, 
        img_size_min=CFG.img_size_min, 
        img_size_max=CFG.img_size_max, 
        img_max_ratio_diff=CFG.img_max_ratio_diff, 
        prompt_words_min=CFG.prompt_words_min, 
        prompt_words_max=CFG.prompt_words_max,
        prompt_is_english=CFG.prompt_is_english,
        drop_duplicates_by_head=CFG.drop_duplicates_by_head,
        drop_duplicates_by_tail=CFG.drop_duplicates_by_tail,
        drop_duplicates_char_len=CFG.drop_duplicates_char_len_sd2v2
    )
    
    metadata = pd.concat([metadata, metadata_sd2_v2], ignore_index=True)
    
if CFG.add_sd3:
    metadata_sd3 = pd.read_csv(train_data_dir / "sd3/metadata.csv")
    metadata_sd3["image_name"] = "sd3/" + metadata_sd3["image_path"]
    del metadata_sd3["image_path"]
    metadata_sd3["height"] = 512
    metadata_sd3["width"] = 512
    metadata_sd3 = filter_metadata(
        metadata_sd3, 
        img_size_min=CFG.img_size_min, 
        img_size_max=CFG.img_size_max, 
        img_max_ratio_diff=CFG.img_max_ratio_diff, 
        prompt_words_min=CFG.prompt_words_min, 
        prompt_words_max=CFG.prompt_words_max,
        prompt_is_english=CFG.prompt_is_english,
        drop_duplicates_by_head=CFG.drop_duplicates_by_head,
        drop_duplicates_by_tail=CFG.drop_duplicates_by_tail,
        drop_duplicates_char_len=CFG.drop_duplicates_char_len_sd3
    )
    
    metadata = pd.concat([metadata, metadata_sd3], ignore_index=True)
    
metadata

  df = df[
  df = df[
  df = df[


Unnamed: 0,image_name,prompt
0,DiffusionDB_2M/2217ccbd-a1c6-47ac-9a2d-7964972...,"a portrait of a female robot made from code, v..."
1,DiffusionDB_2M/dc71658a-5e4b-4dca-861a-e153551...,"only memories remain, trending on artstation"
2,DiffusionDB_2M/48eb7e17-a3cf-4eb8-96a9-d8e3e23...,dream swimming pool with nobody
3,DiffusionDB_2M/601d9792-eccd-4850-97a7-edbe91d...,a dog doing weights. epic oil painting.
4,DiffusionDB_2M/3c586acb-14dc-43df-8900-954c336...,a dog doing weights on fire. epic oil painting.
...,...,...
689334,sd3/artifacts/sd-img-to-prompts:v29/00995.png,"portrait of modern darna, sonequa martin - gre..."
689335,sd3/artifacts/sd-img-to-prompts:v29/00996.png,a greenhouse with deep green and purple glowin...
689336,sd3/artifacts/sd-img-to-prompts:v29/00998.png,1 9 2 0 s color spirit photography 0 9 1 1 2 1...
689337,sd3/artifacts/sd-img-to-prompts:v29/00999.png,gary busey doing a sweet skateboard trick off ...


In [7]:
metadata["image_name"].str.split("/").str[0].value_counts()

DiffusionDB_2M       625544
gustavosta-sd2-v2     40371
sd3                   23424
Name: image_name, dtype: int64

In [8]:
full_prompt = metadata[["image_name", "prompt"]].values
train_prompt, val_prompt = train_test_split(
    full_prompt, 
    test_size=CFG.img_model_test_size, 
    random_state=CFG.seed,
    shuffle=True
)

CFG.dataset_train_size = len(train_prompt)
CFG.dataset_val_size = len(val_prompt)
CFG.max_batches_per_epoch_train, CFG.max_batches_per_epoch_val = get_train_config(CFG.dataset_train_size, 
                                                                                  CFG.dataset_val_size)

print(CFG.dataset_train_size, CFG.dataset_val_size)
print(CFG.max_batches_per_epoch_train, CFG.max_batches_per_epoch_val)
print(CFG.max_batches_per_epoch_train * CFG.batch_size, CFG.max_batches_per_epoch_val * CFG.batch_size)

train_prompt_dict = {img_name: prompt for img_name, prompt in train_prompt}
val_prompt_dict = {img_name: prompt for img_name, prompt in val_prompt}

654872 34467
327 344
6540 6880


## Model 

In [9]:
st_model = SentenceTransformer('../input/sentence-transformers-222/all-MiniLM-L6-v2/')
img_model, img_preprocess = get_img_model(img_model_name=CFG.img_model_name)

if CFG.train_aug:
    train_img_preprocess = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        img_preprocess,
    ])
else:
    train_img_preprocess = img_preprocess

train_dataset = CustomDataSet(
    data_dir=train_data_dir, 
    img2prompt=train_prompt_dict, 
    img_preprocess=train_img_preprocess,
)
val_dataset = CustomDataSet(
    data_dir=train_data_dir, 
    img2prompt=val_prompt_dict, 
    img_preprocess=img_preprocess,
)

train_dataloader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True,
                                    num_workers=CFG.num_workers)
val_dataloader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=True,
                                  num_workers=CFG.num_workers)

## Train  

In [10]:
def create_summary_writer():
    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    writer = SummaryWriter(log_dir=f"../logs/{CFG.dataset_name}/{CFG.model_name}_{current_time}")
    return writer

def train_epoch(img_model, st_model, train_dataloader, loss_f, optimizer, disable_tqdm=True):
    if CFG.train_only_head:
        img_model.eval()
        img_model.fc.train()
    else:
        img_model.train()
        
    mean_train_loss = 0
    train_batches_n = 0
    for batch_i, (img_names, img_embs, prompts) in enumerate(pbar := tqdm(train_dataloader, disable=disable_tqdm)):
        if batch_i > CFG.max_batches_per_epoch_train:
            break
        
        pred = img_model(img_embs.to(CFG.device))
        prompts_emb = torch.Tensor(st_model.encode(prompts)).to(CFG.device)
        
        loss = loss_f(pred, prompts_emb)

        img_model.zero_grad()
        loss.backward()
        optimizer.step()

        mean_train_loss += float(loss)
        train_batches_n += 1

    mean_train_loss /= train_batches_n
    return mean_train_loss

def val_epoch(img_model, st_model, val_dataloader, loss_f, optimizer, disable_tqdm=True):
    img_model.eval()
    
    mean_val_loss = 0
    val_batches_n = 0
    with torch.no_grad():
        for batch_i, (img_names, img_embs, prompts) in enumerate(pbar := tqdm(val_dataloader, disable=disable_tqdm)):
            if batch_i > CFG.max_batches_per_epoch_val:
                break
                
            img_embs = img_embs.to(CFG.device)
                
            pred = img_model(img_embs)
            
            if CFG.test_flip:
                img_embs_flip = transforms.functional.hflip(img_embs)
                pred_flip = img_model(img_embs_flip)
                pred = (pred + pred_flip) / 2
            
            prompts_emb = torch.Tensor(st_model.encode(prompts)).to(CFG.device)

            loss = loss_f(pred, prompts_emb)

            mean_val_loss += float(loss)
            val_batches_n += 1
            
    mean_val_loss /= val_batches_n
    return mean_val_loss

In [11]:
writer = create_summary_writer()

loss_f = get_loss(loss_name=CFG.loss_name)
optimizer = torch.optim.Adam(img_model.parameters(), lr=CFG.lr)

lr_scheduler = get_scheduler(CFG.lr_scheduler_name)
if lr_scheduler:
    lr_scheduler = lr_scheduler(optimizer)

img_model.to(CFG.device)
st_model.to(CFG.device)

best_val_loss = float('inf')
best_epoch_i = 0

for epoch_i in range(CFG.max_epoch_num):
    mean_train_loss = train_epoch(img_model, st_model, train_dataloader, loss_f, optimizer, disable_tqdm=True)
    mean_val_loss = val_epoch(img_model, st_model, val_dataloader, loss_f, optimizer, disable_tqdm=True)
    
    if lr_scheduler:
        lr_scheduler.step()
    
    train_sim = round(1 - mean_train_loss, 3)
    val_sim = round(1 - mean_val_loss, 3)
    
    ### SAVE BEST MODEL ###
    print(f"Epoch {epoch_i + 1}, lr={optimizer.param_groups[0]['lr']:.2e}: train = {train_sim}, val = {val_sim}", end="; ")
    
    if mean_val_loss < best_val_loss:
        best_epoch_i = epoch_i
        best_val_loss = mean_val_loss
        
        if CFG.save_model:
            torch.save(
                img_model.state_dict(), f"../input/{CFG.train_files_dir}/{CFG.train_name}.torch"
            )
        print(f'new best model')
    elif epoch_i - best_epoch_i > CFG.early_stopping_patience:
        print(f'early stopping')
        break
    else:
        print("continue")
        
    ### HISTORY ###
    writer.add_scalars(
        "Similarity",
        {"train": 1 - mean_train_loss, "val": 1 - mean_val_loss}, 
        global_step=epoch_i + 1
    )
    writer.flush()

Epoch 1, lr=2.98e-06: train = 0.123, val = 0.216; new best model
Epoch 2, lr=4.96e-06: train = 0.386, val = 0.442; new best model
Epoch 3, lr=6.94e-06: train = 0.473, val = 0.487; new best model
Epoch 4, lr=8.92e-06: train = 0.5, val = 0.515; new best model
Epoch 5, lr=1.09e-05: train = 0.523, val = 0.532; new best model
Epoch 6, lr=1.29e-05: train = 0.54, val = 0.548; new best model
Epoch 7, lr=1.49e-05: train = 0.552, val = 0.56; new best model
Epoch 8, lr=1.68e-05: train = 0.564, val = 0.571; new best model
Epoch 9, lr=1.88e-05: train = 0.569, val = 0.576; new best model
Epoch 10, lr=2.08e-05: train = 0.578, val = 0.581; new best model
Epoch 11, lr=2.28e-05: train = 0.585, val = 0.587; new best model
Epoch 12, lr=2.48e-05: train = 0.583, val = 0.592; new best model
Epoch 13, lr=2.67e-05: train = 0.591, val = 0.597; new best model
Epoch 14, lr=2.87e-05: train = 0.595, val = 0.602; new best model
Epoch 15, lr=3.07e-05: train = 0.6, val = 0.601; continue
Epoch 16, lr=3.27e-05: train = 

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4912fdbd00>
Traceback (most recent call last):
  File "/home/rv/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/rv/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1443, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.10/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.10/selectors.py", line 416, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt: 
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this i

Traceback (most recent call last):
  File "/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py", line 3457, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_236048/1361756759.py", line 17, in <module>
    mean_train_loss = train_epoch(img_model, st_model, train_dataloader, loss_f, optimizer, disable_tqdm=True)
  File "/tmp/ipykernel_236048/861239834.py", line 28, in train_epoch
    mean_train_loss += float(loss)
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py", line 2077, in showtraceback
    stb = value._render_traceback_()
AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3/dist-packages/IPython/core/ultratb.py", line 1101, in get_r

TypeError: object of type 'NoneType' has no len()

## Inference 

In [14]:
img_model, img_preprocess = get_img_model(img_model_name=CFG.img_model_name)
img_model.load_state_dict(torch.load(f"../input/{CFG.train_files_dir}/{CFG.train_name}.torch"))
CFG.train_name

'img_256_1280_ratio_1_prompt_5_100_dupl_20_model_regnet_y_16gf_cosine_lr_1e_05_sch_None'

In [11]:
test_data_dir = Path("../input/stable-diffusion-image-to-prompts/images/")
test_image_names = sorted(os.listdir(test_data_dir))
test_prompt_dict = {img_name: "" for img_name in test_image_names}

test_dataset = CustomDataSet(   
    data_dir=test_data_dir, 
    img2prompt=test_prompt_dict, 
    img_preprocess=img_preprocess
)
test_dataloader = DataLoader(
    test_dataset, 
    batch_size=CFG.batch_size, 
    shuffle=False, 
    num_workers=CFG.num_workers
)

In [12]:
img_model.eval()
pred_arr = []
with torch.no_grad():
    for img_names, img_embs, prompts in tqdm(test_dataloader):
        prompts_emb = img_model(img_embs.to(CFG.device))
        pred_arr.extend(prompts_emb.cpu().detach().numpy())
pred_arr = np.array(pred_arr)

  0%|          | 0/1 [00:00<?, ?it/s]

In [13]:
submission = create_submission(pred_arr, test_image_names)
submission

Unnamed: 0_level_0,val
imgId_eId,Unnamed: 1_level_1
20057f34d_0,0.087353
20057f34d_1,0.822619
20057f34d_2,1.290800
20057f34d_3,-0.150833
20057f34d_4,1.020432
...,...
f27825b2c_379,0.155445
f27825b2c_380,0.874851
f27825b2c_381,-0.456304
f27825b2c_382,-0.542236
