In [1]:
from datasets import load_dataset
import torchvision.transforms as transforms
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
import json
import matplotlib.pyplot as plt
import time
from nltk.translate.bleu_score import corpus_bleu

In [2]:
from DecoderWithAttention import DecoderWithAttention
from Encoder import Encoder
from DiffusionDBDataLoader import DiffusionDBDataLoader
from utils import AverageMeter, clip_gradient, accuracy, adjust_learning_rate, save_checkpoint

In [3]:
class Parameters:
    start_epoch = 0
    epochs = 50  # number of epochs to train for (if early stopping is not triggered)
    epochs_since_improvement = 0 
    batch_size = 32
    
    device = "cpu"
    max_img_width = 1472
    max_img_height = 1024

    emb_dim = 512  # dimension of word embeddings
    attention_dim = 512  # dimension of attention linear layers
    decoder_dim = 512  # dimension of decoder RNN
    dropout = 0.5

    encoder_lr = 1e-4  # learning rate for encoder if fine-tuning
    decoder_lr = 4e-4  # learning rate for decoder
    
    fine_tune_encoder = False  # fine-tune encoder?

    grad_clip = 5.  # clip gradients at an absolute value of
    alpha_c = 1.  # regularization parameter for 'doubly stochastic attention', as in the paper

    print_freq = 100  # print training/validation stats every __ batches

    best_bleu4 = 0.  # BLEU-4 score right now

In [4]:
dataset = load_dataset('poloclub/diffusiondb', '2m_first_1k')["train"]
images_dataset = dataset["image"]
prompts_dataset = dataset["prompt"]
print(len(images_dataset))

Found cached dataset diffusiondb (C:/Users/norbe/.cache/huggingface/datasets/poloclub___diffusiondb/2m_first_1k/0.9.1/547894e3a57aa647ead68c9faf148324098f47f2bc1ab6705d670721de9d89d1)


  0%|          | 0/1 [00:00<?, ?it/s]

1000


In [5]:
Parameters.max_img_width = max(dataset["width"])
Parameters.max_img_height = max(dataset["height"])
Parameters.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Read word map
with open("WORDMAP_coco_5_cap_per_img_5_min_word_freq.json", "r") as j:
    word_map_dict = json.load(j)

In [6]:
def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):
    """
    Performs one epoch's training.
    :param train_loader: DataLoader for training data
    :param encoder: encoder model
    :param decoder: decoder model
    :param criterion: loss layer
    :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning)
    :param decoder_optimizer: optimizer to update decoder's weights
    :param epoch: epoch number
    """

    decoder.train()  # train mode (dropout and batchnorm is used)
    encoder.train()

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss (per word decoded)
    top5accs = AverageMeter()  # top5 accuracy

    start = time.time()

    # Batches
    for i, (imgs, caps, caplens) in enumerate(train_loader):
        #imgs = imgs[None, :]
        #caps = caps[None, :]
        #caplens = caplens[None, :]

        data_time.update(time.time() - start)

        # Move to GPU, if available
        imgs = imgs.to(Parameters.device)
        caps = caps.to(Parameters.device)
        caplens = caplens.to(Parameters.device)
        
        # Forward prop.
        imgs = encoder(imgs)
        scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
        
        # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
        targets = caps_sorted[:, 1:]

        # Remove timesteps that we didn't decode at, or are pads
        # pack_padded_sequence is an easy trick to do this
        
        scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data 
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data 
        
        # Calculate loss
        loss = criterion(scores, targets)

        # Add doubly stochastic attention regularization
        loss += Parameters.alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

        # Back prop.
        decoder_optimizer.zero_grad()
        if encoder_optimizer is not None:
            encoder_optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        if Parameters.grad_clip is not None:
            clip_gradient(decoder_optimizer, Parameters.grad_clip)
            if encoder_optimizer is not None:
                clip_gradient(encoder_optimizer, Parameters.grad_clip)

        # Update weights
        decoder_optimizer.step()
        if encoder_optimizer is not None:
            encoder_optimizer.step()

        # Keep track of metrics
        top5 = accuracy(scores, targets, 5)
        losses.update(loss.item(), sum(decode_lengths))
        top5accs.update(top5, sum(decode_lengths))
        batch_time.update(time.time() - start)

        start = time.time()

        # Print status
        if i % Parameters.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                          batch_time=batch_time,
                                                                          data_time=data_time, loss=losses,
                                                                          top5=top5accs))
    

def validate(val_loader, encoder, decoder, criterion):
    """
    Performs one epoch's validation.
    :param val_loader: DataLoader for validation data.
    :param encoder: encoder model
    :param decoder: decoder model
    :param criterion: loss layer
    :return: BLEU-4 score
    """

    decoder.eval()  # eval mode (no dropout or batchnorm)
    if encoder is not None:
        encoder.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top5accs = AverageMeter()

    start = time.time()

    references = list()  # references (true captions) for calculating BLEU-4 score
    hypotheses = list()  # hypotheses (predictions)

    # explicitly disable gradient calculation to avoid CUDA memory error
    # solves the issue #57
    with torch.no_grad():
        # Batches
        for i, (imgs, caps, caplens) in enumerate(val_loader):
            #imgs = imgs[None, :]
            #caps = caps[None, :]
            #caplens = caplens[None, :]
            
            # Move to device, if available
            imgs = imgs.to(Parameters.device)
            caps = caps.to(Parameters.device)
            caplens = caplens.to(Parameters.device)
            
            # Forward prop.
            if encoder is not None:
                imgs = encoder(imgs)
            scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

            # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
            targets = caps_sorted[:, 1:]

            # Remove timesteps that we didn't decode at, or are pads
            # pack_padded_sequence is an easy trick to do this
            scores_copy = scores.clone()
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

            # Calculate loss
            loss = criterion(scores, targets)

            # Add doubly stochastic attention regularization
            loss += Parameters.alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

            # Keep track of metrics
            losses.update(loss.item(), sum(decode_lengths))
            top5 = accuracy(scores, targets, 5)
            top5accs.update(top5, sum(decode_lengths))
            batch_time.update(time.time() - start)

            start = time.time()

            if i % Parameters.print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
                                                                                loss=losses, top5=top5accs))

            # Store references (true captions), and hypothesis (prediction) for each image
            # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]

            # References
            """
            allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
            for j in range(allcaps.shape[0]):
                img_caps = allcaps[j].tolist()
                img_captions = list(
                    map(lambda c: [w for w in c if w not in {word_map_dict['<start>'], word_map_dict['<pad>']}],
                        img_caps))  # remove <start> and pads
                references.append(img_captions)
            
            allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
            """

            img_captions = list(
                map(lambda c: [w for w in c if w not in {word_map_dict['<start>'], word_map_dict['<pad>']}],
                    caps))  # remove <start> and pads
            references.append(img_captions)

            # Hypotheses
            _, preds = torch.max(scores_copy, dim=2)
            preds = preds.tolist()
            temp_preds = list()
            for j, p in enumerate(preds):
                temp_preds.append(preds[j][:decode_lengths[j]])  # remove pads
            preds = temp_preds
            hypotheses.extend(preds)

            assert len(references) == len(hypotheses)

        # Calculate BLEU-4 scores
        bleu4 = corpus_bleu(references, hypotheses)

        print(
            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
                loss=losses,
                top5=top5accs,
                bleu=bleu4))

    return bleu4


In [7]:
decoder = DecoderWithAttention(attention_dim=Parameters.attention_dim,
                                embed_dim=Parameters.emb_dim,
                                decoder_dim=Parameters.decoder_dim,
                                vocab_size=len(word_map_dict),
                                dropout=Parameters.dropout,
                                device=Parameters.device)
#decoder.load_pretrained_embeddings(pretrained_embeddings) # pretrained_embeddings should be of dimensions (len(word_map), emb_dim)
#decoder.fine_tune_embeddings(True)
                                
decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                    lr=Parameters.decoder_lr)
encoder = Encoder()
encoder.fine_tune(Parameters.fine_tune_encoder)
encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                    lr=Parameters.encoder_lr) if Parameters.fine_tune_encoder else None


# Move to GPU, if available
decoder = decoder.to(Parameters.device)
encoder = encoder.to(Parameters.device)

# Loss function
criterion = nn.CrossEntropyLoss().to(Parameters.device)

# Custom dataloaders
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
train_loader = DiffusionDBDataLoader(images_dataset, 
                                    prompts_dataset, 
                                    (Parameters.max_img_width, Parameters.max_img_height),
                                    word_map_dict, 
                                    Parameters.batch_size,
                                    transform=normalize)

val_loader = DiffusionDBDataLoader(images_dataset, 
                                    prompts_dataset, 
                                    (Parameters.max_img_width, Parameters.max_img_height),
                                    word_map_dict, 
                                    Parameters.batch_size,
                                    transform=normalize)



In [8]:
# Epochs
for epoch in range(Parameters.start_epoch, Parameters.epochs):
    print("Epoch:", epoch+1)

    # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
    if Parameters.epochs_since_improvement == 20:
        break
    if Parameters.epochs_since_improvement > 0 and Parameters.epochs_since_improvement % 8 == 0:
        adjust_learning_rate(decoder_optimizer, 0.8)
        if Parameters.fine_tune_encoder:
            adjust_learning_rate(encoder_optimizer, 0.8)

    # One epoch's training
    train(train_loader=train_loader,
            encoder=encoder,
            decoder=decoder,
            criterion=criterion,
            encoder_optimizer=encoder_optimizer,
            decoder_optimizer=decoder_optimizer,
            epoch=epoch)
    
    # One epoch's validation
    recent_bleu4 = validate(val_loader=val_loader,
                            encoder=encoder,
                            decoder=decoder,
                            criterion=criterion)

    # Check if there was an improvement
    is_best = recent_bleu4 > Parameters.best_bleu4
    Parameters.best_bleu4 = max(recent_bleu4, Parameters.best_bleu4)
    if not is_best:
        Parameters.epochs_since_improvement += 1
        print("\nEpochs since last improvement: %d\n" % (Parameters.epochs_since_improvement,))
    else:
        Parameters.epochs_since_improvement = 0

    # Save checkpoint
    save_checkpoint("_DiffusionDB_prompt_capture", epoch, Parameters.epochs_since_improvement, encoder, decoder, encoder_optimizer,
                    decoder_optimizer, recent_bleu4, is_best)

Epoch: 1
32
53
Epoch: [0][0/1000]	Batch Time 610.030 (610.030)	Data Load Time 1.412 (1.412)	Loss 9.7822 (9.7822)	Top-5 Accuracy 0.096 (0.096)
32
48
32
48
32
51


RuntimeError: Length of all samples has to be greater than 0, but found an element in 'lengths' that is <= 0