# Library

In [None]:
# NOTE: keep GradScaler and autocast (mixed precision) enabled; without them training takes 3x as long.

import os
import random
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.notebook import tqdm
from scipy import spatial
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
import timm
from timm.utils import AverageMeter
import sys
from sentence_transformers import SentenceTransformer
from torch.utils.data.distributed import DistributedSampler
import warnings
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast
from torch.multiprocessing import Process
from torch.nn.parallel import DistributedDataParallel as DDP

# Suppress every Python warning for the rest of the run (this also hides
# deprecation notices raised by the libraries imported above).
warnings.filterwarnings('ignore')

# Config

In [None]:
class CFG:
    """Static hyper-parameters read by ``main`` and forwarded to ``train``."""

    # Architecture identifier handed to timm.create_model.
    model_name: str = 'vit_large_patch16_224'
    # Image size forwarded to get_dataloaders.
    input_size: int = 224
    # Global batch size; train() divides it by the world size.
    batch_size: int = 128
    # Number of passes over the training dataloader.
    num_epochs: int = 25
    # Learning rate given to the AdamW optimizer.
    lr: float = 5e-4
    # Seed used for seed_everything and the train/validation split.
    seed: int = 21

In [None]:
def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device), exports ``PYTHONHASHSEED`` and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not just the current one: training runs on several devices.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run, which
    # defeats the deterministic flag above unless benchmarking is disabled.
    torch.backends.cudnn.benchmark = False

# Seed this process with the configured seed as soon as the helper exists.
seed_everything(CFG.seed)

# Dataset

In [None]:
class DiffusionDataset(Dataset):
    """Map-style dataset yielding ``(image, prompt)`` pairs from a DataFrame.

    Args:
        df: DataFrame with an ``image_path`` column (path of the image file)
            and a ``prompt`` column.
        transform: Callable applied to the loaded PIL image.
    """

    def __init__(self, df, transform):
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load and transform the image of row ``idx``; return it with its prompt."""
        row = self.df.iloc[idx]
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it once the pixel data has been decoded.
        with Image.open(row['image_path']) as img:
            # Force three channels: grayscale / palette / RGBA files would
            # otherwise yield 1- or 4-channel tensors, which a 3-channel
            # Normalize transform rejects.
            image = img.convert('RGB')
        image = self.transform(image)
        return image, row['prompt']


class DiffusionCollator:
    """Collate ``(image, prompt)`` samples into an image batch plus prompt embeddings."""

    def __init__(self):
        # The sentence encoder is created on the CPU.
        encoder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
        self.st_model = encoder

    def __call__(self, batch):
        """Stack the images and encode the prompts of ``batch``.

        Args:
            batch: Sequence of ``(image_tensor, prompt)`` pairs.

        Returns:
            Tuple ``(images, prompt_embeddings)``: the stacked image tensor
            and the encoder output for the prompts, requested as a tensor.
        """
        image_seq, prompt_seq = zip(*batch)
        stacked_images = torch.stack(image_seq)
        embedded_prompts = self.st_model.encode(
            prompt_seq,
            show_progress_bar=False,
            convert_to_tensor=True,
        )
        return stacked_images, embedded_prompts
    

def get_dataloaders(trn_df, val_df, input_size, batch_size, rank, world_size, seed=21):
    """Build the training and validation dataloaders for one distributed rank.

    Args:
        trn_df: DataFrame wrapped by the training ``DiffusionDataset``.
        val_df: DataFrame wrapped by the validation ``DiffusionDataset``.
        input_size: Side length (pixels) every image is resized to; images
            become ``input_size x input_size``.
        batch_size: Batch size of each dataloader (per process).
        rank: Rank of the calling process, given to the samplers.
        world_size: Total number of processes, given to the samplers.
        seed: Seed of both ``DistributedSampler`` instances.

    Returns:
        Dict with the keys ``'train'`` and ``'val'`` mapping to DataLoaders.
    """
    # Honour the requested input size instead of a hard-coded 224.
    target_size = (input_size, input_size)
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        normalize])

    # Validation uses no augmentation, only resize + normalisation.
    val_transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        normalize])

    trn_dataset = DiffusionDataset(trn_df, train_transform)
    val_dataset = DiffusionDataset(val_df, val_transform)
    collator = DiffusionCollator()

    train_sampler = DistributedSampler(
        trn_dataset, num_replicas=world_size, rank=rank, seed=seed, shuffle=True)
    val_sampler = DistributedSampler(
        val_dataset, num_replicas=world_size, rank=rank, seed=seed, shuffle=False)

    # Settings shared by both loaders; only dataset, sampler and drop_last differ.
    shared_kwargs = dict(
        batch_size=batch_size,
        pin_memory=True,
        num_workers=12,
        collate_fn=collator)

    return {
        'train': DataLoader(
            dataset=trn_dataset,
            drop_last=True,
            sampler=train_sampler,
            **shared_kwargs),
        'val': DataLoader(
            dataset=val_dataset,
            drop_last=False,
            sampler=val_sampler,
            **shared_kwargs),
    }

# Train

In [None]:
def cosine_similarity(y_trues, y_preds):
    """Return the mean cosine similarity over paired rows of the two inputs.

    Args:
        y_trues: Sequence of reference vectors.
        y_preds: Sequence of predicted vectors, paired with ``y_trues`` by position.

    Returns:
        Mean of ``1 - cosine_distance`` over all pairs.
    """
    similarities = []
    for reference_vec, predicted_vec in zip(y_trues, y_preds):
        distance = spatial.distance.cosine(reference_vec, predicted_vec)
        similarities.append(1 - distance)
    return np.mean(similarities)

In [None]:
def train(rank, world_size, trn_df, val_df, model_name, input_size, batch_size, num_epochs, lr):
    """Run one distributed (DDP) training worker on GPU ``rank``.

    Joins the NCCL process group, builds the dataloaders, loads the initial
    weights from ``vit_large_patch16_224.pth``, then trains with mixed
    precision for ``num_epochs`` epochs, validating after each epoch. Rank 0
    writes the model and optimizer state whenever the validation cosine
    similarity improves.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of worker processes.
        trn_df: Training DataFrame passed to ``get_dataloaders``.
        val_df: Validation DataFrame passed to ``get_dataloaders``.
        model_name: timm architecture name; also the checkpoint file stem.
        input_size: Image size passed to ``get_dataloaders``.
        batch_size: Global batch size; divided by ``world_size`` per process.
        num_epochs: Number of epochs to train.
        lr: Learning rate for AdamW.
    """
    device = torch.device(f'cuda:{rank}')
    # Bind this process to its own GPU before any NCCL collective runs, so
    # the barrier below does not default to device 0 on every rank.
    torch.cuda.set_device(device)

    # Setup the distributed process group
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    dist.barrier()
    print("all processes setup")

    # Modify the batch size for distributed training
    batch_size = batch_size // world_size

    dataloaders = get_dataloaders(trn_df, val_df, input_size, batch_size, rank, world_size)

    base_model = timm.create_model(model_name, pretrained=False, num_classes=384)
    # Deserialize onto the CPU: without map_location every rank would restore
    # the tensors onto whichever GPU they were saved from.
    state_dict = torch.load("vit_large_patch16_224.pth", map_location='cpu')
    # Checkpoints saved from a torch.compile wrapper carry an '_orig_mod.'
    # key prefix; strip it so such files still load into the plain model.
    compiled_prefix = '_orig_mod.'
    state_dict = {
        (key[len(compiled_prefix):] if key.startswith(compiled_prefix) else key): value
        for key, value in state_dict.items()}
    base_model.load_state_dict(state_dict)

    model = torch.compile(base_model)
    model.to(device)
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, fused=True)

    scaler = GradScaler()

    # The scheduler is stepped once per batch, so T_max counts iterations.
    ttl_iters = num_epochs * len(dataloaders['train'])
    scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters, eta_min=1e-6)
    criterion = nn.CosineEmbeddingLoss()

    best_score = -1.0

    for epoch in range(num_epochs):
        # DistributedSampler only reshuffles when told the epoch; without
        # this call every epoch iterates the data in the same order.
        dataloaders['train'].sampler.set_epoch(epoch)

        train_meters = {
            'loss': AverageMeter(),
            'cos': AverageMeter()}

        model.train()
        for X, y in tqdm(dataloaders['train'], leave=False):
            X, y = X.to(device), y.to(device)

            optimizer.zero_grad()
            with autocast():
                X_out = model(X)
                # A target of 1 asks the loss to pull each pair together.
                target = torch.ones(X.size(0), device=device)
                loss = criterion(X_out, y, target)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            trn_loss = loss.item()
            trn_cos = cosine_similarity(
                X_out.detach().cpu().numpy(),
                y.detach().cpu().numpy())

            train_meters['loss'].update(trn_loss, n=X.size(0))
            train_meters['cos'].update(trn_cos, n=X.size(0))

        print('Epoch {:d} / trn/loss={:.4f}, trn/cos={:.4f}'.format(
            epoch + 1,
            train_meters['loss'].avg,
            train_meters['cos'].avg))

        val_meters = {
            'loss': AverageMeter(),
            'cos': AverageMeter()}

        model.eval()
        for X, y in tqdm(dataloaders['val'], leave=False):
            X, y = X.to(device), y.to(device)

            with torch.no_grad():
                with autocast():
                    X_out = model(X)
                    target = torch.ones(X.size(0), device=device)
                    loss = criterion(X_out, y, target)

                val_loss = loss.item()
                val_cos = cosine_similarity(
                    X_out.detach().cpu().numpy(),
                    y.detach().cpu().numpy())

            val_meters['loss'].update(val_loss, n=X.size(0))
            val_meters['cos'].update(val_cos, n=X.size(0))

        print('Epoch {:d} / val/loss={:.4f}, val/cos={:.4f}'.format(
            epoch + 1,
            val_meters['loss'].avg,
            val_meters['cos'].avg))

        # NOTE(review): the score is computed on this rank's validation shard
        # only; it is not reduced across processes.
        epoch_score = val_meters['cos'].avg
        if epoch_score > best_score:
            # Remember the best score so later, worse epochs do not
            # overwrite the checkpoint.
            best_score = epoch_score
            if rank == 0: # ONLY SAVE IF GPU 0
                # Save from the uncompiled module (it shares its parameters
                # with the compiled/DDP wrappers) so the keys carry no
                # '_orig_mod.' prefix.
                # NOTE(review): when model_name is 'vit_large_patch16_224'
                # this overwrites the checkpoint loaded above; confirm intended.
                torch.save(base_model.state_dict(), f'{model_name}.pth')
                torch.save(optimizer.state_dict(), f"{model_name}_optimizer.pth")

    # Release the NCCL communicator before the worker exits.
    dist.destroy_process_group()

In [None]:
def main():
    """Split the dataset and launch one ``train`` process per rank.

    Reads ``filtered_image_data.csv``, holds out 10% of the rows for
    validation, sets the rendezvous address/port and starts ``world_size``
    worker processes, then waits for all of them.

    Raises:
        RuntimeError: If any worker process ends with a non-zero exit code.
    """
    df = pd.read_csv('filtered_image_data.csv')
    trn_df, val_df = train_test_split(df, test_size=0.1, random_state=CFG.seed)
    world_size = 8
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'

    # Arguments shared by every worker; only the rank differs.
    shared_args = (
        world_size, trn_df, val_df, CFG.model_name, CFG.input_size,
        CFG.batch_size, CFG.num_epochs, CFG.lr)

    processes = []
    for rank in range(world_size):
        p = Process(target=train, args=(rank, *shared_args))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # Surface worker crashes instead of returning as if training succeeded.
    failed_ranks = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed_ranks:
        raise RuntimeError(
            f'training worker(s) with rank {failed_ranks} exited with a non-zero exit code')

# Entry point: start training only when this file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    main()