# In [1]:
import torch
from torch.utils.data import Dataset

from embeddings.autoencoder_embeddings.recipe_autoencoder import RecipeAutoencoder


class RecipeDataset(Dataset):
    """Map-style dataset that turns raw recipe dicts into model-ready tensors.

    Each item is the 7-tuple
    ``(name, description, steps, tags, category, ingredients, nutriments)``:

    * ``name`` / ``description`` -- tokenizer output, padded/truncated to
      ``max_length`` with ``return_tensors="pt"``;
    * ``steps`` -- a Python list holding one such tokenizer output per step;
    * ``tags`` / ``ingredients`` -- 1-D ``torch.long`` tensors of vocabulary
      indices (strings missing from the vocab map to 0);
    * ``category`` -- a 0-d ``torch.long`` index (missing maps to 0);
    * ``nutriments`` -- ``recipe["nutriments"]`` as a ``torch.float32`` tensor.

    NOTE(review): steps, tags and ingredients vary in length per recipe, so
    torch's default collate can only batch them when every recipe in a batch
    has the same counts -- confirm a custom collate_fn or fixed-length data.
    """

    def __init__(
            self,
            data,
            tokenizer,
            tag_vocab,
            category_vocab,
            ingredient_vocab,
            max_length=128,
    ):
        self.data = data
        self.tokenizer = tokenizer
        self.tag_vocab = tag_vocab
        self.category_vocab = category_vocab
        self.ingredient_vocab = ingredient_vocab
        self.max_length = max_length

    def encode_strings(self, strings, vocab):
        """Look up each string in ``vocab``; unknown strings become index 0."""
        lookup = vocab.get
        return [lookup(token, 0) for token in strings]

    def __len__(self):
        return len(self.data)

    def _tokenize(self, text):
        """Tokenize one string into a fixed-length PyTorch encoding."""
        return self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        record = self.data[idx]
        encode = self._tokenize

        # Free-text fields: name, description, then every step separately.
        name_enc = encode(record["name"])
        description_enc = encode(record["description"])
        step_encs = list(map(encode, record["steps"]))

        # Multi-valued categorical fields become index tensors.
        tag_ids = torch.tensor(
            self.encode_strings(record["tags"], self.tag_vocab), dtype=torch.long
        )
        ingredient_ids = torch.tensor(
            self.encode_strings(record["ingredients"], self.ingredient_vocab),
            dtype=torch.long,
        )

        # Single category -> scalar index; nutriments -> float vector.
        category_id = torch.tensor(
            self.category_vocab.get(record["category"], 0), dtype=torch.long
        )
        nutriment_values = torch.tensor(record["nutriments"], dtype=torch.float32)

        return (
            name_enc,
            description_enc,
            step_encs,
            tag_ids,
            category_id,
            ingredient_ids,
            nutriment_values,
        )


from collections import Counter


def build_vocab(items, min_freq=1):
    """
    Build a vocabulary from a list of strings or list of lists of strings.

    Indices are contiguous: every kept string gets an index in
    ``1..len(vocab) - 1`` in first-occurrence order, and ``"<UNK>"`` is always
    0. This means ``len(vocab)`` is always a valid ``num_classes`` /
    embedding size, even when ``min_freq`` filters some strings out.

    Args:
        items (list): The input data (e.g., tags, categories, ingredients).
            If the first element is a list, every element is treated as a
            list and the input is flattened before counting.
        min_freq (int): Minimum frequency for a string to be included in the vocab.
    Returns:
        dict: A dictionary mapping each string to a unique index.
    """
    # len() rather than truthiness so array-like inputs keep working.
    if len(items) == 0:
        return {"<UNK>": 0}
    if isinstance(items[0], list):  # Handle lists of lists (e.g., tags or ingredients)
        items = [item for sublist in items for item in sublist]
    counter = Counter(items)
    # Filter BEFORE numbering: numbering first would leave holes for dropped
    # strings, making len(vocab) smaller than the largest index + 1.
    # A literal "<UNK>" in the data is skipped so it cannot claim an index
    # that is later overwritten (which would also leave a hole).
    kept = [
        word
        for word, count in counter.items()
        if count >= min_freq and word != "<UNK>"
    ]
    vocab = {word: idx for idx, word in enumerate(kept, start=1)}  # Start indices from 1
    vocab["<UNK>"] = 0  # Add a token for unknown strings
    return vocab


import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer


def train_minilm_autoencoder(
        train_data,
        val_data,
        epochs=10,
        batch_size=64,
        max_length=128,
        device="cuda" if torch.cuda.is_available() else "cpu",
):
    # Step 1: Build vocabularies
    tag_vocab = build_vocab([recipe["tags"] for recipe in train_data], min_freq=1)
    category_vocab = build_vocab(
        [recipe["category"] for recipe in train_data], min_freq=1
    )
    ingredient_vocab = build_vocab(
        [recipe["ingredients"] for recipe in train_data], min_freq=1
    )

    # Step 2: Initialize tokenizer, dataset, and dataloaders
    tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
    train_dataset = RecipeDataset(
        train_data,
        tokenizer,
        tag_vocab,
        category_vocab,
        ingredient_vocab,
        max_length=max_length,
    )
    val_dataset = RecipeDataset(
        val_data,
        tokenizer,
        tag_vocab,
        category_vocab,
        ingredient_vocab,
        max_length=max_length,
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Step 3: Initialize the model, loss functions, and optimizer
    model = RecipeAutoencoder(
        latent_dim=256, ingredient_vocab_size=len(ingredient_vocab)
    ).to(device)
    criterion_text = nn.MSELoss()
    criterion_tags = nn.BCEWithLogitsLoss()
    criterion_category = nn.CrossEntropyLoss()
    criterion_ingredients = nn.BCEWithLogitsLoss()
    criterion_nutriments = nn.MSELoss()

    optimizer = torch.optim.Adam(
        [
            {
                "params": model.text_encoder.transformer.parameters(),
                "lr": 1e-5,
            },  # Fine-tune MiniLM
            {"params": model.parameters(), "lr": 1e-4},  # Other parts of the model
        ]
    )

    # Step 4: Training loop
    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()

            # Move data to the training device
            name, description, steps, tags, category, ingredients, nutriments = batch
            name = {key: val.to(device) for key, val in name.items()}
            description = {key: val.to(device) for key, val in description.items()}
            steps = [
                {key: val.to(device) for key, val in step.items()} for step in steps
            ]
            tags, category, ingredients, nutriments = (
                tags.to(device),
                category.to(device),
                ingredients.to(device),
                nutriments.to(device),
            )

            # Forward pass
            (
                latent_embedding,
                text_decoded,
                tags_decoded,
                category_decoded,
                ingredients_decoded,
                nutriments_decoded,
            ) = model(name, description, steps, tags, category, ingredients, nutriments)

            # Compute losses
            loss_text = criterion_text(text_decoded, name["input_ids"].float())
            loss_tags = criterion_tags(
                tags_decoded,
                nn.functional.one_hot(tags, num_classes=len(tag_vocab)).float(),
            )
            loss_category = criterion_category(category_decoded, category)
            loss_ingredients = criterion_ingredients(
                ingredients_decoded,
                nn.functional.one_hot(
                    ingredients, num_classes=len(ingredient_vocab)
                ).float(),
            )
            loss_nutriments = criterion_nutriments(nutriments_decoded, nutriments)

            # Combine losses
            loss = (
                    loss_text
                    + loss_tags
                    + loss_category
                    + loss_ingredients
                    + loss_nutriments
            )
            total_train_loss += loss.item()

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

        # Step 5: Validation loop
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                name, description, steps, tags, category, ingredients, nutriments = (
                    batch
                )
                name = {key: val.to(device) for key, val in name.items()}
                description = {key: val.to(device) for key, val in description.items()}
                steps = [
                    {key: val.to(device) for key, val in step.items()} for step in steps
                ]
                tags, category, ingredients, nutriments = (
                    tags.to(device),
                    category.to(device),
                    ingredients.to(device),
                    nutriments.to(device),
                )

                (
                    latent_embedding,
                    text_decoded,
                    tags_decoded,
                    category_decoded,
                    ingredients_decoded,
                    nutriments_decoded,
                ) = model(
                    name, description, steps, tags, category, ingredients, nutriments
                )

                loss_text = criterion_text(text_decoded, name["input_ids"].float())
                loss_tags = criterion_tags(
                    tags_decoded,
                    nn.functional.one_hot(tags, num_classes=len(tag_vocab)).float(),
                )
                loss_category = criterion_category(category_decoded, category)
                loss_ingredients = criterion_ingredients(
                    ingredients_decoded,
                    nn.functional.one_hot(
                        ingredients, num_classes=len(ingredient_vocab)
                    ).float(),
                )
                loss_nutriments = criterion_nutriments(nutriments_decoded, nutriments)

                loss = (
                        loss_text
                        + loss_tags
                        + loss_category
                        + loss_ingredients
                        + loss_nutriments
                )
                total_val_loss += loss.item()

        print(
            f"Epoch [{epoch + 1}/{epochs}] - Train Loss: {total_train_loss:.4f}, Val Loss: {total_val_loss:.4f}"
        )

    # Save the final model
    torch.save(model.state_dict(), "minilm_recipe_autoencoder.pth")
    print("Model training complete and saved as 'minilm_recipe_autoencoder.pth'.")


import pandas as pd
from sklearn.model_selection import train_test_split

# Load the preprocessed recipe table. NOTE: unpickling can execute arbitrary
# code, so only point read_pickle at files you trust.
df = pd.read_pickle('embeddings_train_data.pkl')

# Hold out 20% of the rows for validation; the fixed seed makes the split
# reproducible between runs.
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

# RecipeDataset indexes plain dicts, so turn each split into row records.
train_data, val_data = (
    split.to_dict(orient="records") for split in (train_df, val_df)
)

# Launch fine-tuning; the device falls back to CPU when CUDA is unavailable.
train_minilm_autoencoder(
    train_data,
    val_data,
    epochs=10,
    batch_size=64,
    max_length=128,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Captured notebook output (tqdm warning line, then the run was stopped):
#   from .autonotebook import tqdm as notebook_tqdm
#
#
# KeyboardInterrupt: