# Library

In [1]:
import os
import random
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.notebook import tqdm
from scipy import spatial
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
import timm    #并行节点操作
from timm.utils import AverageMeter
from torch.cuda.amp import autocast,GradScaler
import sys
sys.path.append('../input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer
import warnings
warnings.filterwarnings('ignore')
import unicodedata

In [2]:

from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
from torchvision import transforms
import pickle
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from glob import glob
from torch import nn
from PIL import Image
from pathlib import Path
from transformers import AutoModel, AutoProcessor
import cv2
import os
import random
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
from timm.models.vision_transformer import Mlp ,PatchEmbed,_cfg

BATCHSIZE=128
SAVE_OPT_CKP = True
SAVE_MODEL_CKP = True
UNFREEZE_START = 18 # set it to lower number when significantly more samples are included.

torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn

# run_name = f'open-clip224-l14'
import gc

In [3]:
wheels_path = "/kaggle/input/open-clip-wheels/open_clip_wheels"
open_clip_whl_path = f"{wheels_path}/open_clip_torch-2.14.0-py3-none-any.whl"
!pip install --no-index --find-links $wheels_path $open_clip_whl_path -q
import open_clip

[0m

# Config

In [4]:
class CFG:
    model_path = '/kaggle/input/stable-diffusion-convnext-baseline-train/convnext_large_d.pth'
    model_name = 'convnext_large_d'  
    input_size = 256
    batch_size = 16
    num_epochs = 1
    lr = 1e-9
    seed = 42


In [5]:
def seed_everything(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    
    if torch.cuda.is_available(): 
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True    #每次使用相同算法


seed_everything(CFG.seed)

# Dataset

In [6]:
class DiffusionDataset(Dataset):
    def __init__(self, df, transform):
        self.df = df
        self.transform = transform
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = Image.open(row['filepath'])
        image = self.transform(image)
        prompt = row['prompt']
        return image, prompt


In [7]:
class DiffusionCollator:
    def __init__(self):
        self.st_model = SentenceTransformer(
            '/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2',
            device='cpu'
        )
    
    def __call__(self, batch):
        images, prompts = zip(*batch)
        images = torch.stack(images)
        prompt_embeddings = self.st_model.encode(
            prompts, 
            show_progress_bar=False, 
            convert_to_tensor=True
        )
        return images, prompt_embeddings
    

In [8]:
def get_dataloaders(
    trn_df,
    val_df,
    input_size,
    batch_size
):
    transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    trn_dataset = DiffusionDataset(trn_df, transform)
    val_dataset = DiffusionDataset(val_df, transform)
    collator = DiffusionCollator()
    
    dataloaders = {}
    dataloaders['train'] = DataLoader(
        dataset=trn_dataset,
        shuffle=True,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=2,
        drop_last=True,
        collate_fn=collator
    )
    dataloaders['val'] = DataLoader(
        dataset=val_dataset,
        shuffle=False,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=2,
        drop_last=False,
        collate_fn=collator
    )
    return dataloaders

# Train

In [9]:
class Net(nn.Module):
    def __init__(self, model):
        super(Net, self).__init__()
        clip = model
        self.vision = clip.visual
        self.fc = nn.Linear(768, 384)

    def forward(self, x):
        out = self.vision(x)
        return self.fc(out)

In [10]:
def cosine_similarity(y_trues, y_preds):
    return np.mean([
        1 - spatial.distance.cosine(y_true, y_pred) 
        for y_true, y_pred in zip(y_trues, y_preds)
    ])



def train(
    trn_df,
    val_df,
    model_name,
    input_size,
    batch_size,
    num_epochs,
    lr,
):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataloaders = get_dataloaders(
        trn_df,
        val_df,
        input_size,
        batch_size
    )

#     model = timm.create_model(
#         model_name,
#         pretrained=True,
#         num_classes=384
#     )
    
    model, _, preprocess = open_clip.create_model_and_transforms(model_name)
    model = Net(model)
    transform = preprocess
    state_dict = torch.load(CFG.model_path)
    model.load_state_dict(state_dict)
#     model.set_grad_checkpointing()
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    ttl_iters = num_epochs * len(dataloaders['train'])
    scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters, eta_min=1e-6)
    criterion = nn.CosineEmbeddingLoss()
    
    best_score = -1.0

    for epoch in range(num_epochs):
        train_meters = {
            'loss': AverageMeter(),
            'cos': AverageMeter(),
        }
        model.train()
        for X, y in tqdm(dataloaders['train'], leave=False):
            X, y = X.to(device), y.to(device)

            optimizer.zero_grad()
            X_out = model(X)
            target = torch.ones(X.size(0)).to(device)
            loss = criterion(X_out, y, target)
            loss.backward()

            optimizer.step()
            scheduler.step()

            trn_loss = loss.item()
            trn_cos = cosine_similarity(
                X_out.detach().cpu().numpy(), 
                y.detach().cpu().numpy()
            )

            train_meters['loss'].update(trn_loss, n=X.size(0))
            train_meters['cos'].update(trn_cos, n=X.size(0))

        print('Epoch {:d} / trn/loss={:.4f}, trn/cos={:.4f}'.format(
            epoch + 1,
            train_meters['loss'].avg,
            train_meters['cos'].avg))

        val_meters = {
            'loss': AverageMeter(),
            'cos': AverageMeter(),
        }
        model.eval()
        for X, y in tqdm(dataloaders['val'], leave=False):
            X, y = X.to(device), y.to(device)

            with torch.no_grad():
                X_out = model(X)
                target = torch.ones(X.size(0)).to(device)
                loss = criterion(X_out, y, target)

                val_loss = loss.item()
                val_cos = cosine_similarity(
                    X_out.detach().cpu().numpy(), 
                    y.detach().cpu().numpy()
                )

            val_meters['loss'].update(val_loss, n=X.size(0))
            val_meters['cos'].update(val_cos, n=X.size(0))

        print('Epoch {:d} / val/loss={:.4f}, val/cos={:.4f}'.format(
            epoch + 1,
            val_meters['loss'].avg,
            val_meters['cos'].avg))
        
        if val_meters['cos'].avg > best_score:
            best_score = val_meters['cos'].avg
            torch.save(model.state_dict(), f'{model_name}.pth')

In [11]:
df = pd.read_csv('/kaggle/input/k/shoheiazuma/diffusiondb-data-cleansing/diffusiondb.csv')
trn_df, val_df = train_test_split(df, test_size=0.1, random_state=CFG.seed)

Unnamed: 0,filepath,prompt
0,/kaggle/input/diffusiondb-2m-part-0001-to-0100...,"a portrait of a female robot made from code, v..."
1,/kaggle/input/diffusiondb-2m-part-0001-to-0100...,dream swimming pool with nobody
2,/kaggle/input/diffusiondb-2m-part-0001-to-0100...,a beautiful paint of cultists dancing surround...
3,/kaggle/input/diffusiondb-2m-part-0001-to-0100...,"frontal portrait of ragged, worried twin women..."
4,/kaggle/input/diffusiondb-2m-part-0001-to-0100...,a stunning portrait of an asian samurai with l...
...,...,...
154315,/kaggle/input/diffusiondb-2m-part-1901-to-2000...,"obama transformed into a penguin, a combinatio..."
154316,/kaggle/input/diffusiondb-2m-part-1901-to-2000...,"new york invaded by nazis, concept art"
154317,/kaggle/input/diffusiondb-2m-part-1901-to-2000...,"a owlish, aquiline picture of an owl sitting o..."
154318,/kaggle/input/diffusiondb-2m-part-1901-to-2000...,"a owlish, elaborate painting of an owl sitting..."


In [13]:
train(trn_df, val_df, CFG.model_name, CFG.input_size, CFG.batch_size, CFG.num_epochs, CFG.lr)

  0%|          | 0/8680 [00:00<?, ?it/s]

Epoch 1 / trn/loss=0.3347, trn/cos=0.6653


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 1 / val/loss=0.3280, val/cos=0.6720
