<div style="width: 30%; float: right; margin: 10px; margin-right: 5%;">
    <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/d/d3/FHNW_Logo.svg/2560px-FHNW_Logo.svg.png" width="500" style="float: left; filter: invert(50%);"/>
</div>

<h1 style="text-align: left; margin-top: 10px; float: left; width: 60%;">
    del Mini-Challenge 2 <br> 
</h1>

<p style="clear: both; text-align: left;">
    Bearbeitet durch Si Ben Tran im HS 2023.<br>Bachelor of Science FHNW in Data Science.
</p>


Ziel:  
Vertiefung in ein eher aktuelles Paper aus der Forschung und Umsetzung eines darin beschriebenen oder verwandten Tasks - gemäss Vereinbarung mit dem Fachcoach. 

Beispiel:  
Implementiere, trainiere und validiere ein Deep Learning Modell für Image Captioning wie beschrieben im Paper Show and Tell.

Zeitlicher Rahmen:  
Wird beim Schritt 1 verbindlich festgelegt.

Beurteilung:  
Beurteilt wird auf Basis des abgegebenen Notebooks:  
•	Vollständige und korrekte Umsetzung der vereinbarten Aufgabestellung.  
•	Klare, gut-strukturierte Umsetzung.   
•	Schlüssige Beschreibung und Interpretation der Ergebnisse. Gut gewählte und gut kommentierte Plots und Tabellen.  
•	Vernünftiger Umgang mit (Computing-)Ressourcen.  
•	Verständliche Präsentation der Ergebnisse.  

Referenzen, Key Words  
•	Word Embedding (z.B. word2vec, glove), um Wörter in numerische Vektoren in einem geeignet dimensionierten Raum zu mappen. Siehe z.B. Andrew Ng, Coursera: [Link](https://www.coursera.org/lecture/nlp-sequence-models/learning-word-embeddings-APM5s)      
•	Bild Embedding mittels vortrainierten (evt. retrained) Netzwerken wie beispielsweise ResNet, GoogLeNet, EfficientNet oder ähnlich Transfer-Learning.  
•	Seq2Seq Models bekannt für Sprach-Übersetzung. 

Daten:   
•	Gemäss Vereinbarung (für Captioning: [Flickr8k-Daten](https://www.kaggle.com/adityajn105/flickr8k/activity)).

•	Absprache/Beschluss mit Coach und Beschluss, was evaluiert werden soll.
 

# 1 Setup und Imports

In [None]:
# autoreload
%load_ext autoreload
%autoreload 2

import os
os.chdir('../')

In [None]:
import tqdm 
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

# Torch
import torch
print(torch.__version__)
#import torch.nn as nn
#import torch.optim as optim
#from torch.utils.data import Dataset, DataLoader

from src.gpu_setup import DeviceSetup


In [None]:
# Create the project helper from src/gpu_setup.py with a fixed seed and run its setup.
# NOTE(review): presumably this selects the compute device and seeds the RNGs — confirm in src/gpu_setup.py.
# Later cells read the selected device from `device_setup.device`.
device_setup = DeviceSetup(seed=42)
device_setup.setup()

---

# 2 Daten
Wir erkennen in der Spalte *image*, dass eine JPG-Bilddatei mehrere *captions* hat.

Bei der Visualisierung der Bilder erkennen wir:
- Personen oder Tiere
- Unterschiedliche Grössen
- Unterschiedliche Auflösung


In [None]:
class DataExplorer:
    """
    Exploratory helper for an image-captioning dataset.

    The captions file is loaded with ``pd.read_csv`` into ``self.data``. The methods
    below read its ``image`` column (image file name) and its ``caption`` column
    (one caption per row).
    """

    def __init__(self, image_path, captions_path):
        """
        Args:
            image_path (str): Directory that contains the image files.
            captions_path (str): Path of the captions file readable by ``pd.read_csv``.
        """
        self.image_path = image_path
        self.data = pd.read_csv(captions_path)

    def _get_image_unique(self):
        """
        Return the unique image IDs in order of first appearance.
        """
        return self.data['image'].unique()

    def _get_word_counts(self):
        """
        Return a Series with the number of whitespace-separated words per caption.
        """
        return self.data['caption'].str.split().str.len()

    def _read_image(self, image_id):
        """
        Open the image file that belongs to ``image_id`` and return the PIL image.
        The caller is responsible for closing it.
        """
        # os.path.join avoids a doubled separator when image_path already ends with "/"
        return Image.open(os.path.join(self.image_path, image_id))

    def _get_captions(self, image_id):
        """
        Return all captions of ``image_id`` joined by newlines ('' if there are none).
        """
        # Boolean mask instead of a Python loop with label lookups: faster and
        # independent of the DataFrame index (works on filtered frames as well).
        matching = self.data.loc[self.data['image'] == image_id, 'caption']
        return '\n'.join(matching)

    def plot_n_m_image_caption(self, n, m):
        """
        Plot an n x m grid of randomly drawn images (with replacement), each titled
        with all of its captions.
        """
        image_unique = self._get_image_unique()
        # squeeze=False keeps `ax` two-dimensional even for n == 1 or m == 1
        fig, ax = plt.subplots(n, m, figsize=(16, 20), squeeze=False)
        for i in range(n):
            for j in range(m):
                image_id = image_unique[np.random.randint(0, len(image_unique))]
                with self._read_image(image_id) as image:
                    ax[i, j].imshow(np.asarray(image))
                ax[i, j].set_title(self._get_captions(image_id))
        plt.tight_layout()
        plt.show()

    def plot_image_size(self):
        """
        Scatter-plot width against height for every unique image.
        """
        widths, heights = [], []
        for image_id in self._get_image_unique():
            # Only the size is needed; close the file handle right away
            with self._read_image(image_id) as image:
                width, height = image.size
            widths.append(width)
            heights.append(height)

        fig, ax = plt.subplots(1, 1, figsize=(8, 8))
        ax.set_xlabel('width of image')
        ax.set_ylabel('height of image')
        # One scatter call for all points instead of one call per image
        ax.scatter(widths, heights)
        ax.set_title('Distribution size of images')
        plt.tight_layout()
        plt.show()

    def plot_caption_distribution(self):
        """
        Plot a histogram of the number of words per caption.
        """
        word_counts = self._get_word_counts()
        plt.figure(figsize=(12, 8))
        plt.hist(word_counts, bins=25, color = 'limegreen', edgecolor='black', linewidth=1.2)
        plt.title("Distribution of Number of Words per Caption")
        plt.xlabel("Number of Words")
        plt.ylabel("Frequency")
        plt.show()

    def get_caption_distribution(self):
        """
        Print summary statistics (including the 95 % quantile) of the words per caption.
        """
        word_counts = self._get_word_counts()
        print(word_counts.describe(percentiles=[0.25, 0.5, 0.75, 0.95]))

    def plot_caption_ecdf(self, proportion=0.95, n_words=19):
        """
        Plot the ECDF of the number of words per caption.

        Args:
            proportion (float): Height of the horizontal reference line.
            n_words (int): Position of the vertical reference line.
        """
        word_counts_sorted = self._get_word_counts().sort_values()
        y = np.arange(1, len(word_counts_sorted) + 1) / len(word_counts_sorted)
        plt.figure(figsize=(12, 8))
        plt.plot(word_counts_sorted, y, color='limegreen')
        plt.axhline(y=proportion, color='r', linestyle='-')
        plt.axvline(x=n_words, color='r', linestyle='-')
        # +1 so that the largest word count gets a tick as well
        plt.xticks(np.arange(np.min(word_counts_sorted), np.max(word_counts_sorted) + 1, 1.0))
        plt.title("ECDF of Number of Words per Caption")
        plt.xlabel("Number of Words")
        plt.ylabel("Proportion")
        plt.show()

    def plot_most_common_words(self, top_n=20):
        """
        Plot a bar chart of the ``top_n`` most frequent whitespace-separated words.
        """
        from collections import Counter
        # Update a single Counter in place; summing one Counter per caption
        # creates a new Counter for every row.
        counter = Counter()
        for caption in self.data['caption']:
            counter.update(caption.split())
        most_common = counter.most_common(top_n)
        words = [word for word, _ in most_common]
        counts = [count for _, count in most_common]
        plt.figure(figsize=(12, 8))
        plt.bar(words, counts, color = 'limegreen', edgecolor='black', linewidth=1.2)
        plt.title("Most Common Words in Captions")
        plt.xlabel("Words")
        plt.ylabel("Frequency")
        plt.xticks(rotation=45)
        plt.show()

In [None]:
# Locations of the Flickr8k images and captions, relative to the working directory
# set in the setup cell.
image_path = "data/Flickr8K/images/"
captions_path = "data/Flickr8K/captions.txt"

# Load the captions table; `flicker_data` is the DataFrame read from captions_path.
flicker_data_explorer = DataExplorer(image_path, captions_path)
flicker_data = flicker_data_explorer.data

## 2.1 Dataframe

In [None]:
# Display the captions DataFrame (columns used below: `image`, `caption`)
flicker_data

## 2.2 Visualisierungen der Bilder

In [None]:
#flicker_data_explorer.plot_n_m_image_caption(2, 2)

## 2.3 Grössen der Bilder

In [None]:
#flicker_data_explorer.plot_image_size()

## 2.4 Caption Länge

In [None]:
#flicker_data_explorer.plot_caption_distribution()
#flicker_data_explorer.plot_caption_ecdf()
#flicker_data_explorer.get_caption_distribution()

## 2.5 Häufigste Wörter

In [None]:
#flicker_data_explorer.plot_most_common_words()

---

# 3 Preprocessing der Bilder

Wir werden die Bilder wie folgt vorbereiten, damit das Modell sie verarbeiten kann: 

`ToPILImage()`: Dieser Schritt konvertiert das Eingabebild in ein PIL (Python Imaging Library) Bildformat. Dies ist erforderlich, wenn das Eingabebild nicht bereits im PIL-Format vorliegt.

`CenterCrop((500, 500))`: Hier wird das Bild auf eine Größe von 500x500 Pixel zentriert zugeschnitten. Dies ist nützlich, um das Bild auf eine bestimmte Größe zu bringen und sicherzustellen, dass wichtige Merkmale in der Mitte erhalten bleiben.

`Resize((224, 224))`: Das Bild wird auf eine Größe von 224x224 Pixel skaliert. Dies ist eine häufig verwendete Größe für viele neuronale Netzwerke, insbesondere in der Bildklassifikation, wie z.B. Convolutional Neural Networks (CNNs).

`ToTensor()`: Hier wird das Bild in einen PyTorch-Tensor konvertiert. Die meisten neuronalen Netzwerke in PyTorch und anderen Frameworks arbeiten mit Tensoren als Eingabe.

`Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])`: Diese Transformation normalisiert die Pixelwerte des Bildes kanalweise (Mittelwert abziehen, durch die Standardabweichung teilen). Die angegebenen Mittelwerte und Standardabweichungen sind die üblichen Werte für Netzwerke, die auf dem ImageNet-Datensatz vortrainiert wurden; vortrainierte Modelle wie ResNet erwarten ihre Eingaben in dieser Normalisierung.

`RandomHorizontalFlip`: Führt mit einer Wahrscheinlichkeit von `horizontal_flip_prob` eine zufällige horizontale Spiegelung des Bildes durch.

`RandomVerticalFlip`: Führt mit einer Wahrscheinlichkeit von `vertical_flip_prob` eine zufällige vertikale Spiegelung des Bildes durch.

`RandomRotation`: Führt eine zufällige Rotation des Bildes um den angegebenen Winkel (`rotation_degree` Grad) durch.

`ColorJitter`: Verändert die Helligkeit, den Kontrast, die Sättigung und den Farbton des Bildes zufällig, um die Farbvariationen zu erhöhen.

In [None]:
from torchvision.transforms import Compose, CenterCrop, Resize, ToTensor, ToPILImage, Normalize

# Final input resolution of the network and the intermediate crop size.
# (`center_crpp` is presumably a typo for "center_crop"; the name is kept as is.)
target_size = (224, 224)
center_crpp = (500, 500)
# Per-channel normalisation constants (described as the ImageNet values in the text above)
mean_values = [0.485, 0.456, 0.406]
std_values = [0.229,0.224,0.225]

# Transformations for the image
# Deterministic pipeline handed to the datasets below: crop -> resize -> tensor -> normalise.
# It contains no ToPILImage() step, so it assumes the dataset passes an image that
# CenterCrop/ToTensor accept directly — TODO confirm in src/flickerdataset.py.
image_transformations = Compose([
    CenterCrop(center_crpp),
    Resize(target_size),
    ToTensor(),
    Normalize(mean=mean_values,
                std=std_values)
])

from torchvision.transforms import RandomHorizontalFlip, RandomRotation, RandomVerticalFlip, ColorJitter

# Augmentation strength
rotation_degree = 45
horizontal_flip_prob = 0.5
vertical_flip_prob = 0.5

# Augmented variant of the pipeline above (random flips, rotation, colour jitter).
# It is not passed to any dataset in the cells visible in this notebook.
# NOTE(review): unlike `image_transformations`, this pipeline starts with ToPILImage(),
# so it expects a tensor/ndarray input rather than a PIL image — verify which input
# type the dataset provides before switching to it.
image_transforms_augmented = Compose([
    ToPILImage(),
    RandomHorizontalFlip(p=horizontal_flip_prob),
    RandomVerticalFlip(p=vertical_flip_prob),
    RandomRotation(degrees=rotation_degree),
    ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
    CenterCrop(center_crpp),
    Resize(target_size),
    ToTensor(),
    Normalize(mean=mean_values,
                std=std_values)
])

---

# 4 PreProcessing von Captions

In [None]:
import nltk

class CaptionProcessor:
    """
    Builds a vocabulary from captions and converts between captions, tokens and indices.

    The four special tokens always occupy the first vocabulary slots in the order
    start (0), stop (1), unknown (2), padding (3).
    """

    def __init__(self, max_length: int):
        """
        Args:
            max_length (int): Fixed length to which index sequences are padded/truncated.
        """
        self.max_length = max_length
        self.vocabulary = []
        self.start_token = "<sos>"
        self.stop_token = "<eos>"
        self.unknown_token = "<unk>"
        self.padding_token = "<pad>"
        self.token_to_index = {}
        self.index_to_token = {}

    def create_vocab(self, dataframe, caption_column: str):
        """
        Create the vocabulary and both lookup tables from a DataFrame containing captions.
        Args:
            dataframe (pandas.DataFrame): A DataFrame containing caption data.
            caption_column (str): The name of the column in the DataFrame that contains captions.
        """
        # Unique lower-cased, whitespace-separated words of all captions
        unique_words = set(" ".join(dataframe[caption_column].values).lower().split())
        # Tokenize the unique words with NLTK (sorted first so the input is deterministic)
        tokens = nltk.tokenize.word_tokenize(" ".join(sorted(unique_words)))
        special_tokens = [self.start_token, self.stop_token, self.unknown_token, self.padding_token]
        # De-duplicate and sort: the tokenizer can emit the same token for several
        # words, and set iteration order differs between Python processes. A stable,
        # duplicate-free order keeps the indices (and vocabulary size) identical across
        # runs, which is required to reuse a saved model.
        words = sorted(set(tokens) - set(special_tokens))
        self.vocabulary = special_tokens + words
        self.token_to_index = {token: index for index, token in enumerate(self.vocabulary)}
        self.index_to_token = {index: token for index, token in enumerate(self.vocabulary)}

    def caption_to_tokens(self, dataframe, caption_column):
        """
        Tokenize captions: lower-case, split on whitespace, wrap in start/stop tokens.
        Args:
            dataframe (pandas.DataFrame): A DataFrame containing caption data.
            caption_column (str): The name of the column in the DataFrame that contains captions.
        Returns:
            pandas.DataFrame: The SAME DataFrame object (modified in place) with an
                              additional "tokenized_caption" column.
        """
        start, stop = [self.start_token], [self.stop_token]
        # Lower-casing and splitting happen in one step so the lower-cased text is
        # what actually gets split (the vocabulary is lower-case as well).
        # NOTE(review): the vocabulary is built with nltk.word_tokenize, captions are
        # split on whitespace; tokens that differ between the two map to <unk>.
        dataframe["tokenized_caption"] = dataframe[caption_column].apply(
            lambda caption: start + caption.lower().split() + stop
        )
        return dataframe

    def tokens_to_indices(self, tokens):
        """
        Converts a list of tokens into a list of exactly ``max_length`` vocabulary indices.
        Unknown tokens map to the unknown index; short inputs are right-padded with the
        padding index, long inputs are truncated.
        Args:
            tokens (list): List of tokens representing a caption.
        Returns:
            list: List of indices representing the caption.
        """
        unknown_index = self.token_to_index[self.unknown_token]
        padding_index = self.token_to_index[self.padding_token]
        # Dictionary lookup instead of a linear search through the vocabulary list
        indices = [self.token_to_index.get(token, unknown_index) for token in tokens]
        # A negative repeat count yields an empty list, so long inputs are not padded
        indices += [padding_index] * (self.max_length - len(indices))
        return indices[:self.max_length]

    def indices_to_tokens(self, indices):
        """
        Converts a list of indices to their corresponding tokens.
        Indices that are not in the vocabulary become the unknown token.
        Args:
            indices (list): List of indices representing a caption.
        Returns:
            list: List of tokens corresponding to the input indices.
        """
        return [self.index_to_token.get(index, self.unknown_token) for index in indices]

    def tokens_to_caption(self, tokens):
        """
        Converts a list of tokens into a human-readable caption.
        All special tokens (start, stop, unknown, padding) are dropped and the result
        is limited to ``max_length`` tokens.
        Args:
            tokens (list): List of tokens representing a caption.
        Returns:
            str: Human-readable caption.
        """
        special_tokens = {self.start_token, self.stop_token, self.unknown_token, self.padding_token}
        words = [token for token in tokens if token not in special_tokens]
        return " ".join(words[:self.max_length])

In [None]:
# Build the vocabulary and add the "tokenized_caption" column to the captions frame
caption_processor = CaptionProcessor(max_length=20)
caption_processor.create_vocab(flicker_data, "caption")
flicker_data_tokenized = caption_processor.caption_to_tokens(flicker_data, "caption")
display(flicker_data_tokenized.head())
print("Vocab size:", len(caption_processor.vocabulary))

# testing tokens_to_indices ("Ben" is capitalised, the vocabulary is lower-case -> <unk>)
test_tokens = ["<sos>", "a", "group", "of", "people", "are", "Ben", "standing", "around", "a", "table", "Ben", "<pad>", "<pad>", "<pad>", "<eos>"]
test_indices = caption_processor.tokens_to_indices(test_tokens)
print(test_indices)

# testing indices_to_tokens: round-trip the indices computed above. Hard-coded index
# values are only valid for one particular vocabulary and would otherwise decode to
# arbitrary words or <unk>.
test_tokens = caption_processor.indices_to_tokens(test_indices)
print(test_tokens)

# testing tokens_to_caption
test_tokens = ['<sos>', 'a', 'group', 'of', 'people', 'are', 'standing', 'around', 'a', 'table', '<unk>', '<eos>']
test_caption = caption_processor.tokens_to_caption(test_tokens)
print(test_caption)


---

# 5 Word Embedding


In [None]:
# Embedding the tokenized_caption

---

# 6 Data Preparation

## 6.1 Train-Validation-Test Split

In [None]:
from sklearn.model_selection import train_test_split
# Split on unique image names so that all captions of one image end up in the same subset
images = flicker_data_tokenized["image"].unique()

# 60 % train; the remaining 40 % is halved into validation and test (20 % / 20 %)
train_images, val_test_images = train_test_split(images, test_size=0.4, random_state=42)
val_images, test_images = train_test_split(val_test_images, test_size=0.5, random_state=42)

# Select the caption rows that belong to each image subset
image_column = flicker_data_tokenized["image"]
train_df = flicker_data_tokenized[image_column.isin(train_images)]
val_df = flicker_data_tokenized[image_column.isin(val_images)]
test_df = flicker_data_tokenized[image_column.isin(test_images)]

# Show each subset together with its label
for split_label, split_frame in (("train_data", train_df), ("val_data", val_df), ("test_data", test_df)):
    display(split_label, split_frame)

## 6.2 Dataset

In [None]:
from torch.utils.data import DataLoader
from multiprocessing import cpu_count
from src.flickerdataset import Flicker8kDataset

# One dataset per split, all sharing the same image transformations and caption processor
train_data, val_data, test_data = (
    Flicker8kDataset(split_df, image_path, image_transformations, caption_processor)
    for split_df in (train_df, val_df, test_df)
)

# Report the number of samples per split
for split_name, split_data in (("train_data", train_data), ("val_data", val_data), ("test_data", test_data)):
    print(split_name, len(split_data))

# Inspect one training sample: the dataset yields an (image, caption) pair
image_train, caption_train = train_data[1]
print("image_train", image_train.shape)
print("caption_train", caption_train.shape)

## 6.3 DataLoader


In [None]:
batch_size = 2

# Only the training data is shuffled. Validation and test loaders keep the dataset
# order so that evaluation is reproducible and batch positions can be related back
# to rows of the underlying DataFrame.
train_loader = DataLoader(train_data, batch_size=batch_size, pin_memory=False, num_workers=0, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, pin_memory=False, num_workers=0, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, pin_memory=False, num_workers=0, shuffle=False)

In [None]:
# Number of batches each loader yields per epoch
for loader_name, loader in (("train_loader", train_loader), ("val_loader", val_loader), ("test_loader", test_loader)):
    print(f"Number of batches in {loader_name}: {len(loader)}")

In [None]:
# Fetch one batch and print the batched tensor shapes.
# Note: the batch is drawn from `val_loader`, although the variables and the printed
# labels are called "*_train".
image_train, caption_train = next(iter(val_loader))
print("image_train", image_train.shape)
print("caption_train", caption_train.shape)


--- 

# 7 CNN Encoder

In [None]:
# load Resnet18 model in this class 
import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights

import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import ResNet18_Weights

class EncoderCNN(nn.Module):
    """
    Image encoder: ImageNet-pretrained ResNet-18 whose final layer is replaced by a
    trainable projection to ``embedding_dim`` followed by batch normalisation.

    Args:
        embedding_dim (int): Size of the feature vector returned per image.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

        # Pretrained backbone; its existing parameters are excluded from gradient updates
        backbone = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        for weight in backbone.parameters():
            weight.requires_grad = False

        # New head, created after the freeze and therefore trainable
        in_features = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Linear(in_features, self.embedding_dim),
            nn.BatchNorm1d(self.embedding_dim, momentum=0.01),
        )
        self.resnet = backbone

    def forward(self, images):
        """Return the output of the modified ResNet for a batch of images."""
        return self.resnet(images)
    

In [None]:
# Size of the image feature vector produced by the encoder head.
# The decoder concatenates this vector with its word embeddings, so the decoder's
# embedding size has to equal this value.
embedding_dim = 128
encoder = EncoderCNN(embedding_dim)
encoder.to(device_setup.device)

In [None]:
# test the encoder
# Smoke test: push one training batch through the encoder and print the tensor shapes.
image_train, caption_train = next(iter(train_loader))
# Inputs must live on the same device as the model
image_train = image_train.to(device_setup.device)
caption_train = caption_train.to(device_setup.device)
print("image_train", image_train.shape)
print("caption_train", caption_train.shape)
# `features` is reused by the decoder test cell further down
features = encoder(image_train)
print("features", features.shape)


---

# 8 LSTM Decoder

In [None]:
# load DecoderLSTM model 
class DecoderLSTM(nn.Module):
    """
    LSTM caption decoder.

    The image feature vector is fed to the LSTM as the first time step, followed by
    the embedded caption tokens; it must therefore have size ``embedding_dim``.

    Args:
        embedding_dim (int): Size of the word embeddings (and of the image features).
        hidden_size (int): Hidden size of the LSTM.
        vocab_size (int): Number of tokens in the vocabulary.
        num_layers (int): Number of stacked LSTM layers.
    """

    def __init__(self, embedding_dim, hidden_size, vocab_size, num_layers=1):
        super(DecoderLSTM, self).__init__()
        # define the properties
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        # define the layers
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        # initialize weights
        self.init_weights()

    def forward(self, features, captions):
        """
        Compute vocabulary logits for every time step (teacher forcing).
        Args:
            features (torch.Tensor): Image features of shape (batch, embedding_dim).
            captions (torch.Tensor): Caption indices of shape (batch, seq_len).
        Returns:
            torch.Tensor: Logits of shape (batch * seq_len, vocab_size), flattened
                          batch-major (all time steps of sample 0, then sample 1, ...).
        """
        # Drop the last position so that, after prepending the image feature,
        # the sequence length equals captions.size(1) again
        captions = captions[:, :-1]
        embeddings = self.embed(captions)
        # Image feature becomes the first time step
        embeddings = torch.cat((features.unsqueeze(1), embeddings), dim=1)
        hiddens, _ = self.lstm(embeddings)
        # (batch, seq_len, hidden) -> (batch * seq_len, hidden)
        outputs = hiddens.reshape(-1, hiddens.size(2))
        outputs = self.linear(outputs)
        return outputs

    @torch.no_grad()
    def greedy_sample(self, features, states=None, max_length=20, eos_index=1):
        """
        Generate captions by always choosing the most likely next token.
        Args:
            features (torch.Tensor): Image features of shape (batch, embedding_dim).
            states: Optional initial LSTM state.
            max_length (int): Maximum number of decoding steps.
            eos_index (int): Index of the stop token; 1 matches the vocabulary layout
                             created by CaptionProcessor (start=0, stop=1).
        Returns:
            list[list[int]]: One list of token indices per image, without the stop
                             token and anything after it.
        """
        batch_size = features.size(0)
        sampled_ids = [[] for _ in range(batch_size)]
        # Each sample finishes independently; a finished sample must neither stop
        # the other samples of the batch nor receive further tokens itself.
        finished = [False] * batch_size

        # The image feature is the first LSTM input
        inputs = features.unsqueeze(1)

        for _ in range(max_length):
            hiddens, states = self.lstm(inputs, states)
            outputs = self.linear(hiddens.squeeze(1))
            predicted = outputs.argmax(1)

            for j, token_id in enumerate(predicted.tolist()):
                if finished[j]:
                    continue
                if token_id == eos_index:
                    finished[j] = True
                else:
                    sampled_ids[j].append(token_id)

            if all(finished):
                break

            # The predicted token is the input of the next time step
            inputs = self.embed(predicted).unsqueeze(1)

        return sampled_ids

    def init_weights(self):
        """
        Initializes the output layer: weights uniform in [-0.5, 0.5], bias zero.
        """
        self.linear.weight.data.uniform_(-0.5, 0.5)
        self.linear.bias.data.fill_(0)


In [None]:
# The decoder concatenates the encoder's image feature vector with the word embeddings
# along the time axis (torch.cat in DecoderLSTM.forward), so both must have the same
# size. Take the value from the encoder instead of choosing it independently.
embedding_dim = encoder.embedding_dim
hidden_size = 256
vocab_size = len(caption_processor.vocabulary)
num_layers = 1
decoder = DecoderLSTM(embedding_dim, hidden_size, vocab_size, num_layers)
decoder.to(device_setup.device)

In [None]:
# test the decoder
# Smoke test: reuses `features` and `caption_train` from the encoder test cell above.
features = features.to(device_setup.device)
captions = caption_train.to(device_setup.device)
# Teacher-forced forward pass; prints the shape and the raw values of the decoder output
outputs = decoder(features, captions)
print("outputs", outputs.shape)
print("outputs", outputs)
# test the decoder sample
# Greedy decoding returns one list of token indices per image in the batch
sampled_ids = decoder.greedy_sample(features)
print("sampled_ids", sampled_ids)

# Encoder-Decoder 2


In [None]:

class EncoderDecoder(nn.Module):
    """
    Couples an image encoder with a caption decoder and offers single-batch
    train/validation/test steps plus checkpoint helpers.

    Args:
        encoder (nn.Module): Maps a batch of images to feature vectors.
        decoder (nn.Module): Maps (features, captions) to flattened logits and
                             provides ``greedy_sample(features)``.
    """

    def __init__(self, encoder, decoder):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images, captions):
        """Return the decoder output for a batch of images and their captions."""
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

    def greedy_sample(self, images):
        """Encode the images and return the decoder's greedily sampled token ids."""
        features = self.encoder(images)
        sampled_ids = self.decoder.greedy_sample(features)
        return sampled_ids

    def train_step(self, images, captions, criterion, optimizer):
        """
        Run one optimisation step on a batch and return the loss as a float.
        The targets are the captions flattened batch-major, matching the decoder output.
        """
        optimizer.zero_grad()
        outputs = self.forward(images, captions)
        loss = criterion(outputs, captions.reshape(-1))
        loss.backward()
        optimizer.step()
        return loss.item()

    def val_step(self, images, captions, criterion):
        """
        Compute the loss of a batch without updating the model and return it as a float.
        """
        # No gradients are needed for validation; skipping them saves memory and time
        with torch.no_grad():
            outputs = self.forward(images, captions)
            loss = criterion(outputs, captions.reshape(-1))
        return loss.item()

    def test_step(self, images):
        """Return greedily sampled token ids for a batch of images (no gradients)."""
        with torch.no_grad():
            sampled_ids = self.greedy_sample(images)
        return sampled_ids

    def save(self, path):
        """Write the model's state dict to ``path``."""
        torch.save(self.state_dict(), path)

    def load(self, path, map_location=None):
        """
        Load a state dict written by ``save``.
        Args:
            path (str): Checkpoint file.
            map_location: Forwarded to ``torch.load``; pass e.g. "cpu" to load a
                          checkpoint that was saved on a GPU on a CPU-only machine.
        """
        self.load_state_dict(torch.load(path, map_location=map_location))

# define the properties
hidden_size = 512
vocab_size = len(caption_processor.vocabulary)
num_layers = 1
# define the decoder; its embedding size is taken from the encoder because the image
# feature vector is concatenated with the word embeddings in DecoderLSTM.forward
decoder = DecoderLSTM(encoder.embedding_dim, hidden_size, vocab_size, num_layers)
decoder.to(device_setup.device)
# define the model
model = EncoderDecoder(encoder, decoder)
model.to(device_setup.device)
print(model)

# test the model: one forward pass on a single training batch
image_train, caption_train = next(iter(train_loader))
image_train = image_train.to(device_setup.device)
caption_train = caption_train.to(device_setup.device)
print("image_train", image_train.shape)
print("caption_train", caption_train.shape)
outputs = model(image_train, caption_train)
print("outputs", outputs.shape)

# define the loss function and optimizer; padding positions do not contribute to the loss
criterion = nn.CrossEntropyLoss(ignore_index=caption_processor.token_to_index[caption_processor.padding_token])
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# define the number of epochs
num_epochs = 1
# define the path to save the model
model_path = "models/encoder_decoder.pt"
# torch.save does not create missing directories
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# train the model
for epoch in range(num_epochs):
    # train the model
    model.train()
    train_loss = 0.0
    for i, (images, captions) in enumerate(train_loader):
        # move images and captions to gpu if available
        images = images.to(device_setup.device)
        captions = captions.to(device_setup.device)
        # train step
        loss = model.train_step(images, captions, criterion, optimizer)
        train_loss += loss
        # print statistics
        if i % 100 == 0:
            print(f"Epoch: {epoch}, Batch: {i}, Loss: {loss}")
    # validation the model
    model.eval()
    val_loss = 0.0
    # validation needs no gradients; disabling them saves memory and time
    with torch.no_grad():
        for i, (images, captions) in enumerate(val_loader):
            # move images and captions to gpu if available
            images = images.to(device_setup.device)
            captions = captions.to(device_setup.device)
            # validation step
            loss = model.val_step(images, captions, criterion)
            val_loss += loss
            # print statistics
            if i % 100 == 0:
                print(f"Epoch: {epoch}, Batch: {i}, Loss: {loss}")
    # print statistics (mean loss per batch)
    print(f"Epoch: {epoch}, Train Loss: {train_loss/len(train_loader)}, Val Loss: {val_loss/len(val_loader)}")
    # save the model after every epoch (overwrites the previous checkpoint)
    model.save(model_path)

---

# 9 Encoder-Decoder Architektur

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence
from tqdm import tqdm

class EncoderDecoder(nn.Module):
    """
    Couples an image encoder with a caption decoder and provides epoch-level
    training and validation loops.

    Args:
        encoder (nn.Module): Maps a batch of images to feature vectors.
        decoder (nn.Module): Maps (features, captions) to logits flattened
                             batch-major, i.e. shape (batch * seq_len, vocab_size).
    """

    def __init__(self, encoder, decoder):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images, captions):
        """Return the decoder output for a batch of images and their captions."""
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

    def train_one_epoch(self, train_loader, criterion, optimizer, device):
        """
        Train for one pass over ``train_loader``.
        Returns:
            float: Mean loss per batch.
        """
        self.train()
        total_loss = 0
        for images, captions in tqdm(train_loader, desc="Training"):
            images, captions = images.to(device), captions.to(device)
            optimizer.zero_grad()
            outputs = self(images, captions)
            # Flatten the targets batch-major so that row k of `outputs` is compared
            # with token k of the same caption. (Packing the captions instead would
            # order the targets time-major and misalign them with the outputs.)
            targets = captions.reshape(-1)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        return total_loss / len(train_loader)

    def validate(self, val_loader, criterion, device):
        """
        Evaluate on ``val_loader`` without gradient tracking.
        Returns:
            float: Mean loss per batch.
        """
        self.eval()
        total_loss = 0
        with torch.no_grad():
            for images, captions in tqdm(val_loader, desc="Validation"):
                images, captions = images.to(device), captions.to(device)
                outputs = self(images, captions)
                # Same batch-major target layout as in train_one_epoch
                targets = captions.reshape(-1)
                loss = criterion(outputs, targets)
                total_loss += loss.item()

        return total_loss / len(val_loader)


In [None]:
# Instantiate the model
# Wraps the already created `encoder` and `decoder` objects; no new weights are created,
# so any training done on these two modules so far is carried over.
encoder_decoder = EncoderDecoder(encoder, decoder)

# Move the model to a device
device = device_setup.device
encoder_decoder.to(device)

## 9.1 Training

In [None]:
# Loss Function and Optimizer
# Captions are right-padded to a fixed length; ignore the padding index so that the
# loss only measures real tokens instead of rewarding the prediction of <pad>.
criterion = nn.CrossEntropyLoss(ignore_index=caption_processor.token_to_index[caption_processor.padding_token])
optimizer = optim.Adam(encoder_decoder.parameters(), lr=1e-3)

# Training Loop: one training pass and one validation pass per epoch
num_epochs = 1
for epoch in range(num_epochs):
    train_loss = encoder_decoder.train_one_epoch(train_loader, criterion, optimizer, device)
    val_loss = encoder_decoder.validate(val_loader, criterion, device)
    print(f"Epoch: {epoch+1}/{num_epochs}, Training Loss: {train_loss}, Validation Loss: {val_loss}")


---
# 10 Export Model

In [None]:
# save encoder_decoder model
# torch.save does not create missing directories, so make sure the target folder exists
os.makedirs("models", exist_ok=True)
torch.save(encoder_decoder.state_dict(), "models/encoder_decoder_model_5_epoch_train_val_set.pth")

In [None]:
# Reload the exported weights. map_location places the tensors on the current device,
# so a checkpoint written on a GPU can also be loaded on a CPU-only machine.
encoder_decoder.load_state_dict(
    torch.load("models/encoder_decoder_model_5_epoch_train_val_set.pth", map_location=device_setup.device)
)
encoder_decoder.to(device_setup.device)

---

# 11 Evaluierung

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import os

def visualize_results(data_loader, encoder_decoder, caption_processor, device, image_path, num_images=8):
    """
    Show up to ``num_images`` distinct images with the generated and the true captions.

    Args:
        data_loader: Loader whose ``dataset`` exposes a ``dataframe`` attribute with the
            columns "image" and "caption" and yields ``(image_tensor, caption)`` samples.
        encoder_decoder: Model with ``encoder`` and ``decoder`` sub-modules; the decoder
            must provide ``greedy_sample(features)``.
        caption_processor: Converts sampled indices back to text.
        device: Device on which the model lives.
        image_path (str): Directory that contains the original image files.
        num_images (int): Maximum number of images to show.
    """
    encoder_decoder.eval()
    dataset = data_loader.dataset
    dataframe = dataset.dataframe

    # Positions of the first row of every distinct image. Samples are taken from the
    # dataset by position instead of from a loader batch, so the prediction, the
    # displayed image and the true captions always refer to the same row — regardless
    # of batch size or shuffling.
    # assumes dataset[pos] corresponds to dataframe.iloc[pos] — TODO confirm in src/flickerdataset.py
    is_first = ~dataframe["image"].duplicated().to_numpy()
    positions = [pos for pos, first in enumerate(is_first) if first][:num_images]
    if not positions:
        return  # nothing to show (empty dataset or num_images <= 0)

    # Generate one caption per selected image
    predicted_captions = []
    with torch.no_grad():
        for pos in positions:
            image, _ = dataset[pos]
            features = encoder_decoder.encoder(image.unsqueeze(0).to(device))
            sampled_ids = encoder_decoder.decoder.greedy_sample(features)
            tokens = caption_processor.indices_to_tokens(sampled_ids[0])
            predicted_captions.append(caption_processor.tokens_to_caption(tokens))

    # squeeze=False keeps `axs` two-dimensional even for a single row
    fig, axs = plt.subplots(len(positions), 1, figsize=(10, 20), squeeze=False)

    for ax, pos, predicted in zip(axs[:, 0], positions, predicted_captions):
        image_name = dataframe.iloc[pos]["image"]
        image_file = os.path.join(image_path, image_name)
        ax.axis('off')
        if not os.path.exists(image_file):
            # Leave the panel empty but say why
            ax.set_title(f'Image not found: {image_file}')
            continue

        with Image.open(image_file) as img:
            ax.imshow(np.asarray(img))

        true_captions = dataframe.loc[dataframe["image"] == image_name, "caption"].tolist()
        true_captions_text = "\n".join(true_captions)

        ax.set_title(f'Predicted: {predicted}\nTrue: {true_captions_text}')

    plt.show()

In [None]:
# Visualize results
# Generated vs. true captions for images of the held-out test split
visualize_results(test_loader, encoder_decoder, caption_processor, device, image_path, num_images=8)