# DiffusionDB prompt capture training

## Imports

In [3]:
from datasets import load_dataset
import torchvision.transforms as transforms
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
import json
import matplotlib.pyplot as plt
import time
from nltk.translate.bleu_score import corpus_bleu

In [4]:
from DecoderWithAttention import DecoderWithAttention
from Encoder import Encoder
from DiffusionDBDataLoader import DiffusionDBDataLoader
from Memory import Memory
from checkpoint_utils import save_checkpoint, load_checkpoint

## System / training parameters

In [5]:
class Parameters:
    """Notebook-wide configuration for the prompt-captioning training run.

    Every setting is a plain class attribute that the rest of the notebook
    reads, and in some cases reassigns, directly on the class (for example
    ``Parameters.device``, ``Parameters.top_5`` and
    ``Parameters.epochs_since_improvement``). No instance is ever created.
    """

    # -- Epoch schedule and early-stopping bookkeeping --
    start_epoch = 0
    epochs = 50
    epochs_since_improvement = 0
    batch_size = 32

    # -- Compute device and input image size --
    # Placeholder value; the notebook reassigns this to a torch.device.
    device = "cpu"
    max_img_width = 720
    max_img_height = 720

    # -- Model architecture --
    encoded_image_size = 14
    embding_dimension = 512  # (sic) other cells reference this exact spelling
    attention_dimension = 512
    decoder_dimension = 512
    dropout_fraction = 0.5

    # -- Optimisation --
    encoder_lr = 1e-4
    decoder_lr = 4e-4
    is_encoder_pretrained = True
    fine_tune_encoder = False
    grad_clip = 5.0  # element-wise clamp applied to gradients
    alpha_c = 1.0  # weight of the attention-regularisation term in the loss

    # -- Logging and best-score tracking --
    print_freq = 2  # print progress every `print_freq` batches
    top_5 = 0.0  # best top-5 accuracy recorded so far by the main loop

    # -- Prompt preprocessing --
    max_prompt_len = 15
    remove_unk = True

In [6]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    Parameters.device = torch.device("cuda")
else:
    Parameters.device = torch.device("cpu")

## Import data

### Import word-map

In [7]:
# Load the vocabulary mapping (presumably token -> index; its length becomes
# the decoder's vocab size). JSON text is UTF-8 by specification, so decode it
# as UTF-8 explicitly rather than with the platform's default locale encoding.
with open("word_map_nostop5.json", "r", encoding="utf-8") as word_map_file:
    word_map_dict = json.load(word_map_file)

### Import images and prompts

In [None]:
# Fetch (or load from the local cache) the `2m_random_1k` configuration of
# DiffusionDB and pull out its image and prompt columns.
dataset = load_dataset('poloclub/diffusiondb', '2m_random_1k')["train"]
images_dataset, prompts_dataset = dataset["image"], dataset["prompt"]

### Train/validation split

In [9]:
# Split off a contiguous slice for validation. The sizes are named once so the
# train/validation boundary cannot drift between the image and prompt slices.
NUM_TRAIN_SAMPLES = 512
NUM_VALIDATION_SAMPLES = 128
_validation_end = NUM_TRAIN_SAMPLES + NUM_VALIDATION_SAMPLES

train_images_dataset = images_dataset[:NUM_TRAIN_SAMPLES]
train_prompts_dataset = prompts_dataset[:NUM_TRAIN_SAMPLES]

validation_images_dataset = images_dataset[NUM_TRAIN_SAMPLES:_validation_end]
validation_prompts_dataset = prompts_dataset[NUM_TRAIN_SAMPLES:_validation_end]

## Train the system

In [10]:
def accuracy(scores, targets, k):
    """Return the top-k accuracy of *scores* against *targets*, in percent.

    A row counts as correct when its target index is among the ``k``
    highest-scoring columns of that row.

    Args:
        scores: score tensor with one row per prediction (in this notebook,
            the packed ``(num_tokens, vocab_size)`` decoder scores).
        targets: tensor of ground-truth indices, one per row of *scores*.
        k: number of top-ranked columns to consider.
    """
    num_rows = targets.size(0)
    top_k = scores.topk(k, dim=1, largest=True, sorted=True).indices
    expected = targets.view(-1, 1).expand_as(top_k)
    num_correct = (top_k == expected).view(-1).float().sum()
    return num_correct.item() * (100.0 / num_rows)

### Training loop function

In [11]:
def clip_gradient(optimizer, grad_clip):
    """Clamp, in place, every gradient managed by *optimizer* to ``[-grad_clip, grad_clip]``.

    Parameters that have no gradient yet (``grad is None``) are skipped.
    """
    params_with_grad = (
        param
        for group in optimizer.param_groups
        for param in group['params']
        if param.grad is not None
    )
    for param in params_with_grad:
        param.grad.data.clamp_(min=-grad_clip, max=grad_clip)
                
def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):
    """Run one training epoch over ``train_loader``.

    Args:
        train_loader: iterable of ``(imgs, caps, caplens)`` batches that also
            supports ``len()`` (used in the progress message).
        encoder: image encoder; its output is passed straight to ``decoder``.
        decoder: attention decoder returning
            ``(scores, prompts_sorted, decode_lengths, alphas)``.
        criterion: loss applied to the packed scores and targets.
        encoder_optimizer: optimizer for the encoder, or ``None`` when the
            encoder is frozen (it is then not zeroed, clipped or stepped).
        decoder_optimizer: optimizer for the decoder.
        epoch: epoch index, used only for logging.

    Returns:
        ``top5accs.avg``: the per-batch top-5 training accuracies (percent)
        accumulated in ``Memory``, each weighted by ``sum(decode_lengths)``.
    """
    decoder.train()
    # NOTE(review): the encoder is put in train mode even when it is frozen
    # (encoder_optimizer is None). If Encoder contains BatchNorm or dropout
    # layers they still behave in training mode — confirm this is intended.
    encoder.train()

    batch_time = Memory()  # per-batch wall-clock time, including data loading
    losses = Memory()
    top5accs = Memory()

    start = time.time()

    for i, (imgs, caps, caplens) in enumerate(train_loader):
        imgs = imgs.to(Parameters.device)
        caps = caps.to(Parameters.device)
        caplens = caplens.to(Parameters.device)

        imgs = encoder(imgs)
        scores, prompts_sorted, decode_lengths, alphas = decoder(imgs, caps, caplens)

        # Targets are the prompt tokens shifted left by one (position 0 is
        # presumably the start token, which is never predicted).
        targets = prompts_sorted[:, 1:]

        # Pack away padded timesteps so loss and accuracy only see real tokens.
        scores = pack_padded_sequence(scores, decode_lengths, batch_first=True, enforce_sorted=False).data
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True, enforce_sorted=False).data

        # Cross-entropy plus an attention regulariser that pushes alphas summed
        # over dim 1 towards 1 (assumes alphas is (batch, steps, pixels) —
        # i.e. "doubly stochastic" attention; confirm against the decoder).
        loss = criterion(scores, targets)
        loss += Parameters.alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

        decoder_optimizer.zero_grad()
        if encoder_optimizer is not None:
            encoder_optimizer.zero_grad()
        loss.backward()

        if Parameters.grad_clip is not None:
            clip_gradient(decoder_optimizer, Parameters.grad_clip)
            if encoder_optimizer is not None:
                clip_gradient(encoder_optimizer, Parameters.grad_clip)

        decoder_optimizer.step()
        if encoder_optimizer is not None:
            encoder_optimizer.step()

        # Weight the running averages by the number of real (unpadded) tokens.
        num_tokens = sum(decode_lengths)
        top5 = accuracy(scores, targets, 5)
        losses.update(loss.item(), num_tokens)
        top5accs.update(top5, num_tokens)
        batch_time.update(time.time() - start)

        start = time.time()

        if i % Parameters.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'
                  .format(epoch, i, len(train_loader), batch_time=batch_time, loss=losses, top5=top5accs))

    return top5accs.avg

### Validation loop function

In [12]:
def validate(val_loader, encoder, decoder, criterion):
    """Evaluate the model on ``val_loader`` without updating any weights.

    Args:
        val_loader: iterable of ``(imgs, caps, caplens)`` batches that also
            supports ``len()`` (used in the progress message).
        encoder: image encoder, or ``None`` to feed the loader's images to
            the decoder directly.
        decoder: attention decoder returning
            ``(scores, prompts_sorted, decode_lengths, alphas)``.
        criterion: loss applied to the packed scores and targets.

    Returns:
        ``top5accs.avg``: the per-batch top-5 validation accuracies (percent)
        accumulated in ``Memory``, each weighted by ``sum(decode_lengths)``.
    """
    decoder.eval()
    if encoder is not None:
        encoder.eval()

    losses = Memory()
    top5accs = Memory()

    # No autograd graph is needed for evaluation.
    with torch.no_grad():
        for i, (imgs, caps, caplens) in enumerate(val_loader):
            imgs = imgs.to(Parameters.device)
            caps = caps.to(Parameters.device)
            caplens = caplens.to(Parameters.device)

            if encoder is not None:
                imgs = encoder(imgs)
            scores, prompts_sorted, decode_lengths, alphas = decoder(imgs, caps, caplens)

            # Same target shift and padding removal as in train().
            targets = prompts_sorted[:, 1:]
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True, enforce_sorted=False).data
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True, enforce_sorted=False).data

            loss = criterion(scores, targets)
            loss += Parameters.alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

            num_tokens = sum(decode_lengths)
            losses.update(loss.item(), num_tokens)
            top5 = accuracy(scores, targets, 5)
            top5accs.update(top5, num_tokens)

            if i % Parameters.print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), loss=losses, top5=top5accs))

        print('\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}'.format(loss=losses, top5=top5accs))

    return top5accs.avg


### Define dataloaders, loss functions, and optimizers

In [None]:
# Build the models. Each model is moved to the target device *before* its
# optimizer is constructed, as the torch.optim documentation recommends.
decoder = DecoderWithAttention(attention_dimension=Parameters.attention_dimension,
                               embedding_dimension=Parameters.embding_dimension,
                               decoder_dimension=Parameters.decoder_dimension,
                               vocab_size=len(word_map_dict),
                               dropout_fraction=Parameters.dropout_fraction,
                               device=Parameters.device)
decoder = decoder.to(Parameters.device)

encoder = Encoder(Parameters.encoded_image_size, Parameters.is_encoder_pretrained)
encoder.fine_tune(Parameters.fine_tune_encoder)
encoder = encoder.to(Parameters.device)

decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                     lr=Parameters.decoder_lr)
# The encoder only gets an optimizer when it is fine-tuned; train() skips a
# None optimizer, leaving the encoder's weights untouched.
encoder_optimizer = (torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                      lr=Parameters.encoder_lr)
                     if Parameters.fine_tune_encoder else None)

criterion = nn.CrossEntropyLoss().to(Parameters.device)

# ImageNet channel statistics (presumably matching the pretrained encoder).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])


def _make_loader(images, prompts):
    """Build a DiffusionDBDataLoader with the shared preprocessing settings."""
    return DiffusionDBDataLoader(images,
                                 prompts,
                                 (Parameters.max_img_width, Parameters.max_img_height),
                                 word_map_dict,
                                 Parameters.batch_size,
                                 transform=normalize,
                                 max_length=Parameters.max_prompt_len,
                                 remove_unk=Parameters.remove_unk)


train_loader = _make_loader(train_images_dataset, train_prompts_dataset)
val_loader = _make_loader(validation_images_dataset, validation_prompts_dataset)

### Load a checkpoint (if one exists)

In [14]:
# Set to True to resume from a saved checkpoint; False trains from scratch.
RESUME_FROM_CHECKPOINT = False

if RESUME_FROM_CHECKPOINT:
    # NOTE(review): this loads the "resnet50" run, while the main loop saves
    # under "resnet50_no-unk_15_no_stop" — confirm which run should resume.
    (encoder, decoder, encoder_optimizer, decoder_optimizer,
     epoch, epoch_since_improvment, history) = load_checkpoint(
        "resnet50", encoder, decoder, encoder_optimizer, decoder_optimizer, best=True)

    print("Prev epoch:", epoch)
    # Continue after the saved epoch instead of restarting the counter at 0.
    Parameters.start_epoch = epoch + 1
    Parameters.epochs_since_improvement = epoch_since_improvment
    load_history = history
else:
    load_history = []

### Main function

In [None]:
def adjust_learning_rate(optimizer, factor):
    """Scale, in place, the learning rate of every parameter group in *optimizer* by *factor*."""
    for group in optimizer.param_groups:
        group['lr'] = group['lr'] * factor

# Schedule constants: stop after EARLY_STOP_PATIENCE epochs without a
# validation improvement; multiply learning rates by LR_DECAY_FACTOR after
# every LR_DECAY_PATIENCE epochs without improvement.
EARLY_STOP_PATIENCE = 20
LR_DECAY_PATIENCE = 8
LR_DECAY_FACTOR = 0.8

for epoch in range(Parameters.start_epoch, Parameters.epochs):
    print("epoch nr", epoch)
    # `>=` rather than `==` so that a counter restored from a checkpoint that
    # is already past the limit still stops training.
    if Parameters.epochs_since_improvement >= EARLY_STOP_PATIENCE:
        break
    if Parameters.epochs_since_improvement > 0 and Parameters.epochs_since_improvement % LR_DECAY_PATIENCE == 0:
        adjust_learning_rate(decoder_optimizer, LR_DECAY_FACTOR)
        if encoder_optimizer is not None:
            adjust_learning_rate(encoder_optimizer, LR_DECAY_FACTOR)

    top5_avg = train(train_loader=train_loader,
                     encoder=encoder,
                     decoder=decoder,
                     criterion=criterion,
                     encoder_optimizer=encoder_optimizer,
                     decoder_optimizer=decoder_optimizer,
                     epoch=epoch)

    recent_top5 = validate(val_loader=val_loader,
                           encoder=encoder,
                           decoder=decoder,
                           criterion=criterion)

    # History keeps (train top-5, validation top-5) per epoch.
    load_history.append((top5_avg, recent_top5))

    # Model selection and early stopping must use held-out validation
    # accuracy. Training accuracy keeps rising while the model overfits,
    # which would make every epoch "best" and disable the patience logic.
    is_best = recent_top5 > Parameters.top_5
    Parameters.top_5 = max(recent_top5, Parameters.top_5)
    if not is_best:
        Parameters.epochs_since_improvement += 1
        print("\nEpochs since last improvement: %d\n" % (Parameters.epochs_since_improvement))
    else:
        Parameters.epochs_since_improvement = 0

    save_checkpoint("resnet50_no-unk_15_no_stop", epoch, Parameters.epochs_since_improvement, encoder, decoder, encoder_optimizer,
                    decoder_optimizer, recent_top5, load_history, is_best)

    print("Epoch:", epoch, "Top5-Validation", recent_top5)
