# Library

In [1]:
import os
import random
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.notebook import tqdm
from scipy import spatial
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
import timm
from timm.utils import AverageMeter
import sys
from sentence_transformers import SentenceTransformer
import warnings
from torchvision import transforms
import torch.cuda.amp as amp

# Silence every Python warning so library notices do not flood the cell output.
warnings.filterwarnings('ignore')
# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on a machine without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  warn(f"Failed to load image Python extension: {e}")


# Config

In [4]:
class CFG:
    """Hyperparameters shared by the cells of this notebook."""

    # Reproducibility: used as the random_state of the train/validation split.
    seed = 21

    # Backbone: model identifier handed to timm and the image resize target.
    model_name = 'vit_large_patch16_224'
    input_size = 224

    # Optimisation settings forwarded to train().
    num_epochs = 25
    batch_size = 128
    lr = 5e-4

# Dataset

In [5]:
class DiffusionDataset(Dataset):
    """Dataset of (image, prompt) pairs described by a DataFrame.

    Every row of ``df`` must provide an ``image_path`` column (a path readable
    by ``PIL.Image.open``) and a ``prompt`` column.
    """

    def __init__(self, df, transform):
        """Store the table and the image transform.

        Args:
            df: DataFrame with ``image_path`` and ``prompt`` columns.
            transform: Callable applied to the loaded RGB PIL image.
        """
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(transformed_image, prompt)`` for positional index ``idx``."""
        row = self.df.iloc[idx]
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it deterministically. convert('RGB') decodes the
        # pixels and guarantees three channels, so RGBA / palette / grayscale
        # files do not break the 3-channel Normalize applied downstream.
        with Image.open(row['image_path']) as img:
            image = img.convert('RGB')
        image = self.transform(image)
        return image, row['prompt']


class DiffusionCollator:
    """Collate function turning (image, prompt) samples into training tensors."""

    def __init__(self):
        # The sentence encoder is deliberately kept on the CPU: the original
        # author notes that encoding on the GPU crashed CUDA with many images.
        self.st_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')

    def __call__(self, batch):
        """Stack the images and encode the prompts of one batch.

        Args:
            batch: Sequence of ``(image_tensor, prompt)`` pairs.

        Returns:
            Tuple ``(images, prompt_embeddings)``: the stacked image tensor and
            whatever ``st_model.encode`` returns for the prompts.
        """
        image_tensors, prompt_texts = zip(*batch)
        stacked_images = torch.stack(image_tensors)
        embeddings = self.st_model.encode(
            prompt_texts,
            show_progress_bar=False,
            convert_to_tensor=True,
        )
        return stacked_images, embeddings
    
    
def get_dataloaders(trn_df, val_df, input_size, batch_size, num_workers=12):
    """Build the training and validation DataLoaders.

    Args:
        trn_df: DataFrame of training rows (``image_path`` / ``prompt``).
        val_df: DataFrame of validation rows (``image_path`` / ``prompt``).
        input_size: Side length, in pixels, of the square images fed to the model.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Worker processes per DataLoader (defaults to 12).

    Returns:
        Dict with keys ``'train'`` and ``'val'`` mapping to DataLoaders whose
        batches are produced by ``DiffusionCollator``.
    """
    # Resize(int) only fixes the shorter edge and keeps the aspect ratio, so
    # non-square images would yield tensors of differing shapes that cannot
    # be stacked into a batch. An explicit (h, w) forces a square output.
    # The `input_size` argument is honoured instead of the CFG global.
    target_size = (input_size, input_size)
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        normalize])

    # Validation uses no augmentation: only resize and normalisation.
    val_transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        normalize])

    trn_dataset = DiffusionDataset(trn_df, train_transform)
    val_dataset = DiffusionDataset(val_df, val_transform)
    collator = DiffusionCollator()

    dataloaders = {}
    dataloaders['train'] = DataLoader(
        dataset=trn_dataset,
        shuffle=True,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
        drop_last=True,  # keep every training batch at full size
        collate_fn=collator
    )
    dataloaders['val'] = DataLoader(
        dataset=val_dataset,
        shuffle=False,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
        drop_last=False,  # evaluate on every validation sample
        collate_fn=collator
    )
    return dataloaders

# Train

In [6]:
def cosine_similarity(y_trues, y_preds):
    """Return the mean cosine similarity between paired rows.

    Args:
        y_trues: Iterable of reference vectors.
        y_preds: Iterable of predicted vectors, paired with ``y_trues`` by position.

    Returns:
        The mean over all pairs of ``1 - cosine_distance(true, pred)``.
    """
    similarities = []
    for true_vec, pred_vec in zip(y_trues, y_preds):
        distance = spatial.distance.cosine(true_vec, pred_vec)
        similarities.append(1 - distance)
    return np.mean(similarities)

In [7]:
def train(trn_df, val_df, model_name, input_size, batch_size, num_epochs, lr):
    """Fine-tune a timm image model to predict prompt embeddings.

    The model is initialised from ``f'{model_name}.pth'``, trained with AdamW,
    mixed precision and a per-iteration cosine learning-rate schedule, and
    evaluated after every epoch. Whenever the mean validation cosine
    similarity improves, the model and optimizer states are saved.

    Args:
        trn_df: Training DataFrame (``image_path`` / ``prompt`` columns).
        val_df: Validation DataFrame (``image_path`` / ``prompt`` columns).
        model_name: timm model identifier; also the checkpoint file stem.
        input_size: Image side length forwarded to ``get_dataloaders``.
        batch_size: Batch size forwarded to ``get_dataloaders``.
        num_epochs: Number of passes over the training data.
        lr: Initial learning rate for AdamW.

    Uses the module-level ``device``. Returns nothing; progress is printed.
    """
    dataloaders = get_dataloaders(trn_df, val_df, input_size, batch_size)

    # The head width (384) must equal the width of the prompt embeddings
    # produced by the collator, because CosineEmbeddingLoss compares the
    # model output and the embedding vector against each other.
    model = timm.create_model(model_name, pretrained=False, num_classes=384)

    # Derive the checkpoint name from `model_name` (it was hard-coded before)
    # so the weights loaded always match the architecture requested.
    # NOTE(review): this is the same file the best model is saved to below,
    # so a run overwrites its own starting checkpoint - confirm that resuming
    # in place is intended.
    checkpoint_path = f'{model_name}.pth'
    # Load onto the CPU first; the model is moved to `device` afterwards, so
    # the state dict is not materialised on whichever device it was saved from.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.set_grad_checkpointing()

    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scaler = amp.GradScaler()
    # The scheduler is stepped once per batch, so T_max is counted in iterations.
    ttl_iters = num_epochs * len(dataloaders['train'])
    scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters, eta_min=1e-6)
    criterion = nn.CosineEmbeddingLoss()

    best_score = -1.0

    for epoch in range(num_epochs):
        train_meters = {
            'loss': AverageMeter(),
            'cos': AverageMeter()}

        model.train()
        for X, y in tqdm(dataloaders['train'], leave=False):
            X, y = X.to(device), y.to(device)

            optimizer.zero_grad()
            with amp.autocast():
                X_out = model(X)
                # A target of +1 for every pair asks the loss to maximise
                # the cosine similarity between output and embedding.
                target = torch.ones(X.size(0), device=device)
                loss = criterion(X_out, y, target)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            trn_loss = loss.item()
            # Cast to float32 before the NumPy metric: under autocast the
            # model output may be reduced precision.
            trn_cos = cosine_similarity(
                X_out.detach().float().cpu().numpy(),
                y.detach().float().cpu().numpy())

            train_meters['loss'].update(trn_loss, n=X.size(0))
            train_meters['cos'].update(trn_cos, n=X.size(0))

        print('Epoch {:d} / trn/loss={:.4f}, trn/cos={:.4f}'.format(
            epoch + 1,
            train_meters['loss'].avg,
            train_meters['cos'].avg))

        val_meters = {
            'loss': AverageMeter(),
            'cos': AverageMeter()}

        model.eval()
        for X, y in tqdm(dataloaders['val'], leave=False):
            X, y = X.to(device), y.to(device)

            with torch.no_grad():
                with amp.autocast():
                    X_out = model(X)
                    target = torch.ones(X.size(0), device=device)
                    loss = criterion(X_out, y, target)

                val_loss = loss.item()
                val_cos = cosine_similarity(
                    X_out.detach().float().cpu().numpy(),
                    y.detach().float().cpu().numpy())

            val_meters['loss'].update(val_loss, n=X.size(0))
            val_meters['cos'].update(val_cos, n=X.size(0))

        print('Epoch {:d} / val/loss={:.4f}, val/cos={:.4f}'.format(
            epoch + 1,
            val_meters['loss'].avg,
            val_meters['cos'].avg))

        # Keep only the checkpoint with the best validation cosine similarity.
        if val_meters['cos'].avg > best_score:
            best_score = val_meters['cos'].avg
            torch.save(model.state_dict(), f'{model_name}.pth')
            torch.save(optimizer.state_dict(), f"{model_name}_optimizer.pth")

In [8]:
# Seed every random number generator in play so that shuffling, augmentation
# and any randomly initialised state are repeatable from run to run; the
# split below was previously the only seeded step.
random.seed(CFG.seed)
np.random.seed(CFG.seed)
torch.manual_seed(CFG.seed)

df = pd.read_csv('midjourney_ea_al_data.csv')
# Hold out 10% of the rows for validation.
trn_df, val_df = train_test_split(df, test_size=0.1, random_state=CFG.seed)

In [None]:
# Launch training with the hyperparameters from CFG.
train(
    trn_df=trn_df,
    val_df=val_df,
    model_name=CFG.model_name,
    input_size=CFG.input_size,
    batch_size=CFG.batch_size,
    num_epochs=CFG.num_epochs,
    lr=CFG.lr,
)

  0%|          | 0/2334 [00:00<?, ?it/s]

Epoch 1 / trn/loss=0.4169, trn/cos=0.5831


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch 1 / val/loss=0.4050, val/cos=0.5949


  0%|          | 0/2334 [00:00<?, ?it/s]

Epoch 2 / trn/loss=0.3920, trn/cos=0.6080


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch 2 / val/loss=0.3829, val/cos=0.6171


  0%|          | 0/2334 [00:00<?, ?it/s]

Epoch 3 / trn/loss=0.3714, trn/cos=0.6286


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch 3 / val/loss=0.3696, val/cos=0.6304


  0%|          | 0/2334 [00:00<?, ?it/s]