In [1]:
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import numpy as np 
from PIL import Image
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import spacy
from random import seed
from random import random
import torchtext

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [4]:
diffusiondb = load_dataset('poloclub/diffusiondb', 'large_first_1k')

train_df = pd.DataFrame(diffusiondb["train"])
train_df = train_df[["image", "prompt"]]
del diffusiondb
train_df



  0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,image,prompt
0,<PIL.WebPImagePlugin.WebPImageFile image mode=...,"goddess portrait, ismail inceoglu"
1,<PIL.WebPImagePlugin.WebPImageFile image mode=...,"goddess portrait, ismail inceoglu"
2,<PIL.WebPImagePlugin.WebPImageFile image mode=...,portrait of king of candy mr harry haribo oil ...
3,<PIL.WebPImagePlugin.WebPImageFile image mode=...,super epic realistic nature photo trending on ...
4,<PIL.WebPImagePlugin.WebPImageFile image mode=...,super epic realistic nature photo trending on ...
...,...,...
995,<PIL.WebPImagePlugin.WebPImageFile image mode=...,"portrait of haribo bear in future city, color ..."
996,<PIL.WebPImagePlugin.WebPImageFile image mode=...,"photo of terrifying witch, hyper detailed, flo..."
997,<PIL.WebPImagePlugin.WebPImageFile image mode=...,"portrait of haribo bear in future city, color ..."
998,<PIL.WebPImagePlugin.WebPImageFile image mode=...,"portrait of haribo bear in future city, color ..."


In [5]:
nlp = spacy.load("en_core_web_sm")

def tokenize(text):
    return [tok.text for tok in nlp(text)]

tokenize("Hallo ich bin Lukas")

['Hallo', 'ich', 'bin', 'Lukas']

In [6]:
word_counts = Counter()
c = 0
for sentence in train_df["prompt"]:
    doc = nlp(sentence)
    
    # Iterate over each token in the processed sentence
    for token in doc:
        # Check if the token is a word (excluding punctuation and whitespace)
        if token.is_alpha:
            # Increment the count for the word
            word_counts[token.text] = word_counts.get(token.text, 0) + 1
    
    c += 1
    if c > 20: 
      break

# Print the word counts
for word, count in word_counts.items():
    print(f"{word}: {count}")

goddess: 2
portrait: 7
ismail: 2
inceoglu: 2
of: 10
king: 1
candy: 1
mr: 1
harry: 1
haribo: 1
oil: 1
painting: 1
bloody: 1
conquest: 1
tap: 1
e: 1
super: 8
epic: 12
realistic: 8
nature: 8
photo: 8
trending: 8
on: 8
instagram: 8
with: 8
lonely: 8
person: 8
in: 12
yellow: 8
raincoat: 8
standing: 8
at: 8
a: 12
distance: 8
beautiful: 4
princess: 4
wearing: 4
evil: 4
black: 4
oily: 4
tar: 4
by: 4
hr: 4
giger: 4
greg: 4
rutkowski: 4
luis: 4
royo: 4
and: 8
wayne: 4
barlowe: 4
k: 4
mountains: 4
lake: 4
walley: 4
liminal: 4
emperor: 4
palpatine: 4
the: 4
desert: 8
tatooine: 4
film: 4
still: 4
wide: 4
shot: 4
heat: 4
sci: 4
fi: 4
dramatic: 4
light: 4
young: 2
mark: 2
hamill: 2
as: 2
child: 2
star: 4
wars: 2


In [7]:
nlp = spacy.load("en_core_web_sm")

class DiffusionDataset(Dataset):
    def __init__(self, dataframe):
        self.data = dataframe
        self.vocab_size, self.word2index = self.build_vocab()
        self.transformed_images = self.transform_images()
        self.tokenized_prompts = self.tokenize_and_index_prompts()
        self.eos_index = self.word2index["<EOS>"] 

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = self.transformed_images[index]
        prompt = self.tokenized_prompts[index]

        #print(f"Padded Prompt Shape: {padded_prompt.shape}\nPadded Prompt: {padded_prompt}\n")
        return image, torch.tensor(prompt)

    def build_vocab(self):
        word_counts = Counter()

        # Das zählt nur die Buchstaben und deren Häufigkeit
        #for tokens in self.data["prompt"]:
        #    word_counts.update(tokens)

        for sentence in self.data["prompt"]:
          doc = nlp(sentence)
          for token in doc:
            if token.is_alpha:
                word_counts[token.text] = word_counts.get(token.text, 0) + 1

        vocab = [word for word, count in word_counts.most_common(5000)]
        vocab_size = len(vocab) + 3  # Increment vocab_size by 2 for <UNK>, <SOS> and <EOS> tags

        word2index = {word: i+3 for i, word in enumerate(vocab)}  # Shift indices by 3 for <UNK> and <EOS>
        word2index["<UNK>"] = 0
        word2index["<SOS>"] = 1
        word2index["<EOS>"] = 2

        return vocab_size, word2index

    def transform_images(self):
      transform = transforms.Compose([transforms.Resize((256,256)),
                                      transforms.ToTensor()])

      # Convert the PIL image to Torch tensor of size 512x512
      return self.data["image"].apply(transform).to_list()



    def tokenize_and_index_prompts(self):
        return self.data["prompt"].apply(self.tokenize).apply(self.tokens_to_indices).tolist()

    def tokenize(self, text):
      return [tok.text for tok in nlp(text)]

    def tokens_to_indices(self, tokens):
        return [self.word2index["<SOS>"]] + [self.word2index.get(word, 0) for word in tokens] + [self.word2index["<EOS>"]]

In [8]:
class EncoderCNN(nn.Module):
    def __init__(self, embed_size, n_layers, hid_dim):
        super().__init__()

        self.resnet = models.resnet50(pretrained=True)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, hid_dim)

        self.hidden_size = hid_dim
        self.n_layers = n_layers

    def forward(self, images):
        """Extract feature vectors from input images."""
        with torch.no_grad():
            features = self.resnet(images)

        batch_size = features.size(0)
        hidden = features.unsqueeze(0).expand(self.n_layers, batch_size, self.hidden_size)
        # Initialize the cell state with zeros
        cell = torch.zeros(self.n_layers, batch_size, self.hidden_size).to(features.device)
        return hidden, cell



class DecoderRNN(nn.Module):
    def __init__(self, output_dim, emb_dim, n_layers, hid_dim, dropout):
        super().__init__()
        
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        
        self.embedding = nn.Embedding(output_dim, emb_dim)
        
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout = dropout)
        
        self.fc_out = nn.Linear(hid_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, input, hidden, cell):
        
        #input = [batch size]
        #hidden = [n layers * n directions, batch size, hid dim]
        #cell = [n layers * n directions, batch size, hid dim]
        
        #n directions in the decoder will both always be 1, therefore:
        #hidden = [n layers, batch size, hid dim]
        #context = [n layers, batch size, hid dim]

        input = input.unsqueeze(0)
        #input = [1, batch size]
        
        embedded = self.dropout(self.embedding(input))
        
        #embedded = [1, batch size, emb dim]
        output, (hidden, cell) = self.rnn(embedded, (hidden.contiguous(), cell.contiguous()))
        
        #output = [seq len, batch size, hid dim * n directions]
        #hidden = [n layers * n directions, batch size, hid dim]
        #cell = [n layers * n directions, batch size, hid dim]
        
        #seq len and n directions will always be 1 in the decoder, therefore:
        #output = [1, batch size, hid dim]
        #hidden = [n layers, batch size, hid dim]
        #cell = [n layers, batch size, hid dim]
        
        prediction = self.fc_out(output.squeeze(0))
        
        #prediction = [batch size, output dim]
        
        return prediction, hidden, cell

In [9]:
from torch.jit import script_if_tracing
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
       
    def forward(self, src, trg, teacher_forcing_ratio = float(0.5)):

        #trg = [batch size, trg len]
        #teacher_forcing_ratio is probability to use teacher forcing
        #e.g. if teacher_forcing_ratio is 0.75 we use ground-truth inputs 75% of the time

        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim
        
        #tensor to store decoder outputs
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        
        #last hidden state of the encoder is used as the initial hidden state of the decoder
        hidden, cell = self.encoder(src)
        
        #first input to the decoder is the <sos> tokens
        #input = trg[0,:]
        input = trg[:, 0]

        for t in range(1, trg_len):
            #insert input token embedding, previous hidden and previous cell states
            #receive output tensor (predictions) and new hidden and cell states
            output, hidden, cell = self.decoder(input, hidden, cell)
            #output, hidden, cell = self.decoder(input, hidden.unsqueeze(0), cell.unsqueeze(0))

            #place predictions in a tensor holding predictions for each token
            outputs[t] = output
            
            #get the highest predicted token from our predictions
            top1 = output.argmax(1) 

            #if teacher forcing, use actual next token as next input
            #if not, use predicted token)

            input = trg[:,t] if random() < teacher_forcing_ratio else top1
            #input = trg[t] if teacher_force else top1.squeeze(0)
        return outputs

In [10]:
dataset = DiffusionDataset(train_df)
del train_df

In [11]:
embed_size = 512
hidden_size = 256
output_size = dataset.vocab_size
n_layers = 2
dec_dropout = 0.5

batch_size = 2
num_epochs = 10
clip = 1

# seed random number generator
seed(1)

In [12]:
dataset.vocab_size

1484

In [13]:
c = 0
for img in dataset.transformed_images:
  print(img.size())
  c += 1
  if c > 5:
    break


torch.Size([3, 256, 256])
torch.Size([3, 256, 256])
torch.Size([3, 256, 256])
torch.Size([3, 256, 256])
torch.Size([3, 256, 256])
torch.Size([3, 256, 256])


In [14]:
encoder = EncoderCNN(embed_size, n_layers, hidden_size).to(device)
decoder = DecoderRNN(output_size, embed_size, n_layers, hidden_size, dec_dropout).to(device)
model = Seq2Seq(encoder, decoder, device).to(device)

criterion = nn.CrossEntropyLoss()
#params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
#optimizer = optim.Adam(params, lr=0.001)
optimizer = optim.Adam(model.parameters())



In [15]:
def collate_fn(data):
    # Sort a data list by caption length (descending order).
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, prompts = zip(*data)

    # Merge images (from tuple of 3D tensor to 4D tensor).
    images = torch.stack(images, 0)

    # Merge prompts (from tuple of 1D tensor to 2D tensor).
    lengths = [len(prompt) for prompt in prompts]
    padded_prompts = torch.zeros(len(prompts), max(lengths)).long()
    for i, cap in enumerate(prompts):
        end = lengths[i]
        padded_prompts[i, :end] = cap[:end]

    return images, padded_prompts, lengths

In [16]:
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [17]:
x = [ x[1] for x in next(iter(data_loader)) ]
x[1]

tensor([  1,   6,  14,  64,   0, 386, 169, 116,   5,   6, 537, 389, 387, 388,
        162,  17,   6, 538, 146, 191,   0, 158,  13,   2,   0,   0,   0,   0,
          0,   0,   0,   0,   0])

In [18]:
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)
        
model.apply(init_weights)

Seq2Seq(
  (encoder): EncoderCNN(
    (resnet): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplac

In [19]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 26,488,588 trainable parameters


In [20]:
#TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
#TODO
#criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)

In [21]:
def translate_output(output, word2index):
    index2word = {index: word for word, index in word2index.items()}  # Create index-to-word dictionary
    translated_sentences = []
    for seq in output:
        sentence = []
        for idx in seq:
            word = index2word.get(idx.item(), "<UNK>")
            if word == "<EOS>":
                break
            sentence.append(word)
        translated_sentence = " ".join(sentence)
        translated_sentences.append(translated_sentence)
    return translated_sentences

In [22]:
for i, (images, prompts, trg_lengths) in enumerate(data_loader):
    example_images = images
    example_prompts = prompts
    break

# Translate the example prompt
translated_example_prompts = translate_output(prompts, dataset.word2index)


In [23]:
def get_translations(images, prompts): 
    model.eval()
    with torch.no_grad():
        # Move images and prompts to the device
        images = images.to(device)
        prompts = prompts.to(device)

        # Perform forward pass for the images and prompts
        outputs = model(images, prompts)

        # Get the predicted words with the highest probability
        top1 = outputs.argmax(2).transpose(0, 1)

        # Translate the predicted output to words
        translated_output = translate_output(top1, dataset.word2index)

        return translated_output
    

In [26]:
# Training loop
translations_list = []  # List to store translated sentences

for epoch in range(num_epochs):
    model.train()
    for i, (images, prompts, trg_lengths) in enumerate(data_loader):
        images = images.to(device)
        prompts = prompts.to(device)

        # TODO add packing?
        targets = pack_padded_sequence(prompts, trg_lengths, batch_first=True)[0]

        optimizer.zero_grad()
        output = model(images, prompts)

        # Remove the <sos> token and reshape the output and target tensors
        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim).contiguous()

        trg = prompts.transpose(0, 1)[1:].contiguous().view(-1)

        loss = criterion(output, trg)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        if i % 50 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                  .format(epoch + 1, num_epochs, i, len(data_loader), loss.item(), np.exp(loss.item())))

    # Get translations after each epoch and append to the list
    translations_list.append(get_translations(example_images, example_prompts))

Epoch [1/10], Step [0/500], Loss: 7.3156, Perplexity: 1503.5406
Epoch [1/10], Step [50/500], Loss: 5.2023, Perplexity: 181.6851
Epoch [1/10], Step [100/500], Loss: 4.2118, Perplexity: 67.4762
Epoch [1/10], Step [150/500], Loss: 3.3689, Perplexity: 29.0469
Epoch [1/10], Step [200/500], Loss: 4.7000, Perplexity: 109.9488
Epoch [1/10], Step [250/500], Loss: 3.6381, Perplexity: 38.0210
Epoch [1/10], Step [300/500], Loss: 3.5041, Perplexity: 33.2519
Epoch [1/10], Step [350/500], Loss: 3.6949, Perplexity: 40.2407
Epoch [1/10], Step [400/500], Loss: 4.1044, Perplexity: 60.6058
Epoch [1/10], Step [450/500], Loss: 3.1979, Perplexity: 24.4803
Epoch [2/10], Step [0/500], Loss: 2.9493, Perplexity: 19.0924
Epoch [2/10], Step [50/500], Loss: 2.3402, Perplexity: 10.3836
Epoch [2/10], Step [100/500], Loss: 4.4046, Perplexity: 81.8290
Epoch [2/10], Step [150/500], Loss: 3.7858, Perplexity: 44.0709
Epoch [2/10], Step [200/500], Loss: 3.6382, Perplexity: 38.0223
Epoch [2/10], Step [250/500], Loss: 3.7721

In [27]:
translations_first_example, translations_second_example = zip(*translations_list)

print("Translated first Prompt:")
print(translated_example_prompts[0])


print("Translated outputs over epochs:")
print()
# Print the translated prompts over epochs
for i, sentence in enumerate(translations_first_example):
    print(f"Epoch {i+1} predicted sentence {sentence}")

#print()

#print("Translated second Prompt:")
#print(translated_example_prompts[0])

# Print the translated prompts over epochs
#for i, sentence in enumerate(translations_first_example):
#    print(f"Epoch {i+1} predicted sentence {sentence}")


Translated first Prompt:
<SOS> full body of beautiful baroque girl <UNK> soft lighting <UNK> realistic wide angle <UNK> glamour pose <UNK> sharp focus <UNK> <UNK> k high <UNK> <UNK> <UNK> megapixel <UNK> insanely detailed <UNK> intricate <UNK> elegant <UNK> art by artgerm and wlop
Translated outputs over epochs:

Epoch 1 predicted sentence <UNK> a of of <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK>
Epoch 2 predicted sentence <UNK> a epic of a <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> artgerm and <UNK> <UNK>
Epoch 3 predicted sentence <UNK> a of of <UNK> <UNK> <UNK> <UNK> <UNK> lighting <UNK> <UNK> <UNK> <UNK> <UNK> sharp focus <UNK> <UNK> <UNK> <UNK> <UN

In [None]:
testset = load_dataset('poloclub/diffusiondb', 'large_random_1k')
test_df = pd.DataFrame(testset["train"])
test_df = train_df[["image", "prompt"]]
del testset
test_df

In [None]:
test_loader = DiffusionDataset(test_df, batch_size = batch_size, shuffle=False)

In [None]:
model.eval()
with torch.no_grad():
  losses = []
  for (images, prompts) in test_loader:
    images = images.to(device)
    prompts = prompts.to(device)
    output = model(images, prompts)

    output_dim = output.shape[-1]
    output = output[1:].view(-1, output_dim).contiguous() 

    trg = prompts.transpose(0, 1)  # Transpose dimensions 0 and 1
    trg = trg[1:].contiguous()  # Remove the first token <sos>
    trg = trg.view(-1)  # Flatten the tensor

    loss = criterion(output, targets)
    losses.append(loss.item())
    
avg_loss = np.sum(losses) / len(test_loader)
print(f"Avg Loss: {avg_loss}, Avg Perplexity: {np.exp(avg_loss)}")