In [None]:
# Hugging Face client libraries: `huggingface_hub` provides `login` and
# `hf_hub_download` used below.
# NOTE(review): `datasets` is not imported anywhere in this notebook.
!pip install datasets
!pip install huggingface_hub

In [None]:
# PyTorch, OpenAI CLIP (installed straight from GitHub) and its extra
# dependencies; tqdm is also used for the progress bars below.
!pip3 install torch torchvision
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

In [None]:
from huggingface_hub import login
from huggingface_hub import hf_hub_download

import zipfile
import os

import pandas as pd
import numpy as np
import pickle as pkl
import torch
import clip
from os.path import isfile, join
from PIL import Image

import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

# Download and Prepare Data

In [None]:
# Authenticate with the Hugging Face Hub. The token is read from the HF_TOKEN
# environment variable instead of being hard-coded here: the previous
# empty-string placeholder could not authenticate and invited pasting a real
# secret into the notebook.
# NOTE(review): with token=None, huggingface_hub's login() falls back to an
# interactive prompt — confirm against the installed version.
login(token=os.environ.get("HF_TOKEN"))
# !huggingface-cli login

In [None]:
# Repo-relative paths of every dataset file this notebook needs: the split
# CSVs, the parquet table and the zipped image shards.
files_to_download = [
    "data/emo3d_data.parquet",
    "data/train_data.csv",
    "data/val_data.csv",
    "data/test_data.csv",
    "data/anger_images.zip",
    "data/contempt_images.zip",
    "data/disgust_images.zip",
    "data/fear_images.zip",
    "data/sadness_images.zip",
    # "data/primitive_emotions.csv",
    # "data/primitive_emotions.zip",
    "data/prompt1_images_0_1000.zip",
    "data/prompt1_images_1000_2000.zip",
    "data/prompt1_images_2000_3000.zip",
    "data/prompt1_images_3000_4000.zip",
    "data/prompt1_images_4000_5000.zip",
    "data/prompt1_images_5000_6000.zip",
    "data/prompt1_images_6000_7000.zip",
    # NOTE(review): unlike its neighbours this shard spans 7000-9000 —
    # confirm the file name in the Hub repo.
    "data/prompt1_images_7000_9000.zip",
    "data/prompt1_images_9000_10000.zip",
    "data/prompt2_images_0_1000.zip",
    "data/prompt2_images_1000_2000.zip",
    "data/prompt2_images_2000_3000.zip",
    "data/prompt2_images_3000_4000.zip",
    "data/prompt2_images_4000_5000.zip",
    "data/prompt2_images_5000_6000.zip",
]

repo_id = "llm-lab/Emo3D"
local_dir = "/content"

# Each file lands at <local_dir>/<repo-relative path>.
for file_path in files_to_download:
    print(f"Downloading {file_path}...")
    hf_hub_download(
        repo_id=repo_id,
        repo_type='dataset',
        filename=file_path,
        local_dir=local_dir,
    )
    print(f"{file_path} downloaded successfully.")

In [None]:
# Extract each downloaded archive into a folder of the same name next to it
# (e.g. <local_dir>/data/anger_images.zip -> <local_dir>/data/anger_images/),
# then delete the archive.
for file_path in files_to_download:
    local_file_path = os.path.join(local_dir, file_path)  # Full path to the file
    print(local_file_path)
    if not file_path.endswith(".zip"):
        continue
    if not os.path.exists(local_file_path):
        # Archives are deleted after extraction, so re-running this cell would
        # otherwise crash on the first archive that was already processed.
        print(f"{file_path} not found (already extracted?), skipping.")
        continue

    print(f"Unzipping {file_path}...")
    # Anchor the target at local_dir: the bare relative `file_path[:-4]` was
    # resolved against the current working directory, which only coincides
    # with local_dir when the notebook runs from /content.
    unzip_dir = os.path.join(local_dir, file_path[:-len(".zip")])
    with zipfile.ZipFile(local_file_path, 'r') as zip_ref:
        zip_ref.extractall(unzip_dir)
    os.remove(local_file_path)
    print(f"{file_path} extracted successfully.")

# Utilities

In [None]:
# Names of the 52 blendshape coefficients, presumably in the same order as the
# score vectors stored in the CSVs (bs_mirror_idx below relies on that order).
blendshapes_names = ['_neutral', 'browDownLeft', 'browDownRight', 'browInnerUp', 'browOuterUpLeft', 'browOuterUpRight', 'cheekPuff',
 'cheekSquintLeft', 'cheekSquintRight', 'eyeBlinkLeft', 'eyeBlinkRight', 'eyeLookDownLeft', 'eyeLookDownRight', 'eyeLookInLeft',
 'eyeLookInRight', 'eyeLookOutLeft', 'eyeLookOutRight', 'eyeLookUpLeft', 'eyeLookUpRight', 'eyeSquintLeft', 'eyeSquintRight',
 'eyeWideLeft', 'eyeWideRight', 'jawForward', 'jawLeft', 'jawOpen', 'jawRight', 'mouthClose', 'mouthDimpleLeft', 'mouthDimpleRight',
 'mouthFrownLeft', 'mouthFrownRight', 'mouthFunnel', 'mouthLeft', 'mouthLowerDownLeft', 'mouthLowerDownRight', 'mouthPressLeft',
 'mouthPressRight', 'mouthPucker', 'mouthRight', 'mouthRollLower', 'mouthRollUpper', 'mouthShrugLower', 'mouthShrugUpper', 'mouthSmileLeft',
 'mouthSmileRight', 'mouthStretchLeft', 'mouthStretchRight', 'mouthUpperUpLeft', 'mouthUpperUpRight', 'noseSneerLeft', 'noseSneerRight']


# bs_mirror_idx[k] is the index of blendshape k's Left/Right counterpart, or k
# itself for symmetric shapes, so `scores[bs_mirror_idx]` mirrors a score vector.
bs_mirror_idx = [
    blendshapes_names.index(bs_name.replace("Left", "Right")) if "Left" in bs_name
    else blendshapes_names.index(bs_name.replace("Right", "Left")) if "Right" in bs_name
    else position
    for position, bs_name in enumerate(blendshapes_names)
]


In [None]:
import re

# Words swapped when a sample is mirrored horizontally.
replacements = {
    "left": "right",
    "right": "left",
    "Left": "Right",
    "Right": "Left",
}

# Match whole words only. Without the \b anchors, words that merely contain
# "right"/"left" were corrupted, e.g. "bright" -> "bleft" and
# "frightened" -> "fleftened".
regex = re.compile(r"\b(%s)\b" % "|".join(map(re.escape, replacements.keys())))


def mirror_text(text):
    """Return ``text`` with the words left/right (and Left/Right) swapped.

    Used to build the text for horizontally mirrored training samples. Only
    whole words are swapped; other words containing those substrings are
    left untouched.
    """
    return regex.sub(lambda mo: replacements[mo.group()], text)

# Create Dataset

In [None]:
# Device string shared by CLIP, the cached dataset tensors and the model.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Load CLIP ViT-B/32. `clip_model` encodes texts/images and `preprocess` turns a
# PIL image into the model's input tensor; both are used as notebook globals by
# CustomDataset and VAE.generate.
clip_model, preprocess = clip.load("ViT-B/32", device=device)

In [None]:
import torch
from torch.utils.data import Dataset
from PIL import Image
import pandas as pd
from tqdm import tqdm

image_dir = "/content/data"

class CustomDataset(Dataset):
    """(text embedding, image embedding, blendshape) triples precomputed with CLIP.

    For each row of ``dataframe`` and each slot j = 1..4 whose blendshape score
    and image path are both present, one sample is added per text column
    (``text_1``..``text_3``). It pairs that text with the slot's image and its
    blendshape vector minus index 0 (``_neutral`` in ``blendshapes_names``).

    With ``augmentation`` enabled, a mirrored copy of each sample is added too:
    - the image is flipped horizontally;
    - left/right words in the text are swapped with ``mirror_text``;
    - the blendshape vector is permuted with ``bs_mirror_idx``.

    All embeddings are computed eagerly in ``__init__`` with the notebook-global
    ``clip_model``/``preprocess`` on ``device``; ``__getitem__`` is a list lookup.

    Args:
        dataframe: Table with the text, image and blendshape columns listed in
            ``TEXT_COLS``/``IMAGE_COLS``/``BS_COLS``. Image paths are relative
            to ``image_dir``. Blendshape scores are strings such as
            ``"[0.1 0.2 ...]"``.
        augmentation: Whether to also add the mirrored samples.
    """

    TEXT_COLS = ('text_1', 'text_2', 'text_3')
    IMAGE_COLS = ('img_1', 'img_2', 'img_3', 'img_4')
    # Spelled "blenshape" (sic): must match the CSV column names exactly.
    BS_COLS = ('blenshape_score_1', 'blenshape_score_2', 'blenshape_score_3', 'blenshape_score_4')

    def __init__(self, dataframe, augmentation=True):
        self.dataframe = dataframe
        self.data = []

        for i in tqdm(range(len(dataframe))):
            row = dataframe.iloc[i]
            # The texts belong to the row, not to an image slot. Encode them at
            # most once per row, instead of once per slot. Encoding is lazy, so
            # rows with no usable slot cost nothing.
            text_embeddings = None
            mirror_text_embeddings = None

            for bs_col, img_col in zip(self.BS_COLS, self.IMAGE_COLS):
                if pd.isna(row[bs_col]) or pd.isna(row[img_col]):
                    continue

                bs_arr = self._parse_blendshape_string(row[bs_col])
                image_path = os.path.join(image_dir, row[img_col])
                # Decode the image once and embed both orientations from it.
                image_embedding, mirror_image_embedding = self._embed_image_file(
                    image_path, with_mirror=augmentation)

                if text_embeddings is None:
                    texts = [row[col] for col in self.TEXT_COLS]
                    text_embeddings = [self.get_text_embedding(t) for t in texts]
                    if augmentation:
                        mirror_text_embeddings = [
                            self.get_text_embedding(mirror_text(t)) for t in texts]

                # Index 0 of the score vector is dropped; the model sees the rest.
                bs_tensor = torch.from_numpy(bs_arr[1:]).to(device)
                for text_embedding in text_embeddings:
                    self.data.append((text_embedding, image_embedding, bs_tensor))

                if augmentation:
                    mirror_bs_tensor = torch.from_numpy(
                        self.mirror_blenshape_score(bs_arr)[1:]).to(device)
                    for mirror_text_embedding in mirror_text_embeddings:
                        self.data.append(
                            (mirror_text_embedding, mirror_image_embedding, mirror_bs_tensor))

    @staticmethod
    def _parse_blendshape_string(blendshape_score):
        """Parse a bracketed float list such as "[0.1 0.2]" (comma separators also accepted)."""
        return np.array(blendshape_score.strip("[]").replace(",", " ").split(), dtype=float)

    def _encode_pil_image(self, image):
        """CLIP-embed one PIL image (encoded as a batch of one)."""
        inputs = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            return clip_model.encode_image(inputs)

    def _embed_image_file(self, image_path, with_mirror):
        """Open ``image_path`` once and embed it and, optionally, its horizontal mirror.

        Returns:
            ``(embedding, mirrored_embedding)``; ``mirrored_embedding`` is
            ``None`` when ``with_mirror`` is false.
        """
        with Image.open(image_path) as image:  # close the file handle promptly
            embedding = self._encode_pil_image(image)
            mirrored = (self._encode_pil_image(image.transpose(Image.FLIP_LEFT_RIGHT))
                        if with_mirror else None)
        return embedding, mirrored

    def get_text_embedding(self, text):
        """CLIP-embed ``text`` (tokenized and encoded as a batch of one)."""
        inputs = clip.tokenize(text).to(device)
        with torch.no_grad():
            outputs = clip_model.encode_text(inputs)
        return outputs

    def get_image_embedding(self, image_path):
        """CLIP-embed the image stored at ``image_path``."""
        return self._embed_image_file(image_path, with_mirror=False)[0]

    def get_mirror_image_embedding(self, image_path):
        """CLIP-embed the horizontally flipped image stored at ``image_path``."""
        with Image.open(image_path) as image:
            return self._encode_pil_image(image.transpose(Image.FLIP_LEFT_RIGHT))

    # Backwards-compatible alias for the original (misspelled) method name.
    gent_mirror_image_embedding = get_mirror_image_embedding

    def mirror_blenshape_score(self, bs):
        """Mirrors the blendshape array based on predefined symmetry indices."""
        return bs[bs_mirror_idx]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text_embedding, image_embedding, blendshape_score = self.data[idx]
        return text_embedding, image_embedding, blendshape_score



In [None]:
import pandas as pd
# Load the train/val/test split tables downloaded above.
split_csv_paths = (
    "/content/data/train_data.csv",
    "/content/data/val_data.csv",
    "/content/data/test_data.csv",
)
train_df, val_df, test_df = (pd.read_csv(csv_path) for csv_path in split_csv_paths)

In [None]:
# Only the training split gets the horizontal-flip augmentation; val/test stay as-is.
train_dataset = CustomDataset(train_df)
val_dataset, test_dataset = (
    CustomDataset(split_df, augmentation=False) for split_df in (val_df, test_df)
)

# Model

In [4]:
class VAE(nn.Module):
    """Blendshape autoencoder whose latent space is aligned with CLIP embeddings.

    A blendshape vector ``x`` is encoded to a Gaussian (``mu``, ``logvar``) of
    size ``clip_embedding_size``. The decoder maps a latent vector back to
    ``input_size`` blendshape weights in [0, 1] (sigmoid output). Because the
    latent has the CLIP embedding size, :meth:`generate` can decode a CLIP text
    embedding directly.

    Args:
        input_size: Number of blendshape coefficients per sample.
        hidden_size: Width of the hidden layer in the encoder and decoder.
        clip_embedding_size: Latent size; must equal the CLIP embedding size
            for :meth:`generate` to work.
        dropout_prob: Dropout probability applied in the decoder.
    """

    def __init__(self, input_size, hidden_size, clip_embedding_size, dropout_prob=0.1):
        super().__init__()

        # Encoder: blendshapes -> hidden features.
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU()
        )

        # Posterior parameters in the CLIP-sized latent space.
        self.fc_mu = nn.Linear(hidden_size, clip_embedding_size)
        self.fc_logvar = nn.Linear(hidden_size, clip_embedding_size)

        # Decoder: latent -> blendshape weights in [0, 1]. The layer order is
        # unchanged, so existing state_dict checkpoints still load.
        # NOTE(review): `dropout_prob` used to be read from an undefined notebook
        # global (NameError on a fresh kernel). It is now a constructor argument,
        # and 0.1 is a placeholder default — set it to the value used in training.
        self.decoder = nn.Sequential(
            nn.Linear(clip_embedding_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_size, input_size),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent posterior for blendshapes ``x``."""
        hidden = self.encoder(x)
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        NOTE(review): this samples in eval mode too, so evaluation losses are stochastic.
        """
        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)
        return mu + epsilon * std

    def decode(self, z):
        """Map latent vectors ``z`` to blendshape weights in [0, 1]."""
        return self.decoder(z)

    def forward(self, x, text_embeddings=None):
        """Encode, sample and decode ``x``.

        Args:
            x: Blendshape batch.
            text_embeddings: Unused by the model; accepted so training code can
                pass the paired text embedding.

        Returns:
            ``(x_reconstructed, mu, logvar, z)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_reconstructed = self.decode(z)
        return x_reconstructed, mu, logvar, z

    def generate(self, text):
        """Predict blendshape weights for ``text`` by decoding its CLIP text embedding.

        Uses the notebook-global ``clip``, ``clip_model`` and ``device``. Call
        ``eval()`` first so decoder dropout is disabled.

        Returns:
            NumPy array of predicted weights rounded to 3 decimals.
        """
        inputs = clip.tokenize(text).to(device)
        with torch.no_grad():
            text_embedding = clip_model.encode_text(inputs).squeeze(0).float().to(device)
            bs_score_pred = self.decode(text_embedding).cpu().numpy()
        return np.around(bs_score_pred, decimals=3)

# Train

In [5]:
def _vae_batch_losses(model, batch, device, lambda_text, lambda_image):
    """Forward one ``(text_emb, img_emb, blendshape)`` batch and compute its losses.

    Returns:
        ``(total, recon, text_sim, image_sim)`` as scalar tensors, where
        ``total = recon + lambda_text * text_sim + lambda_image * image_sim``.
    """
    text_emb, img_emb, blendshape = batch
    # Flatten the embeddings to (B, D). Each CLIP embedding comes from a
    # one-item encoder batch, so a collated batch is (B, 1, D). `.squeeze(0)`
    # only removed that axis when B == 1. Otherwise cosine_similarity broadcast
    # (B, 1, D) against the (B, D) latent and compared across the batch
    # instead of per sample.
    text_emb = text_emb.to(device).float().reshape(text_emb.size(0), -1)
    img_emb = img_emb.to(device).float().reshape(img_emb.size(0), -1)
    blendshape = blendshape.to(device).float()

    recon, mu, logvar, z = model(blendshape, text_emb)

    recon_loss = F.mse_loss(recon, blendshape, reduction='mean')
    # cosine_similarity normalises internally; dim=-1 pairs each latent with
    # its own sample's embedding.
    text_sim_loss = 1 - F.cosine_similarity(text_emb, z, dim=-1).mean()
    img_sim_loss = 1 - F.cosine_similarity(img_emb, z, dim=-1).mean()

    total = recon_loss + lambda_text * text_sim_loss + lambda_image * img_sim_loss
    return total, recon_loss, text_sim_loss, img_sim_loss


def train_vae(model, train_loader, val_loader, optimizer, device, num_epochs=30,
              lambda_text=0.004, lambda_image=0.002, verbose=True):
    """Train ``model`` and restore the weights with the lowest validation reconstruction loss.

    Each epoch runs one optimisation pass over ``train_loader`` and one
    gradient-free pass over ``val_loader``. The per-batch loss is the
    reconstruction MSE plus ``1 - cosine_similarity`` between the latent ``z``
    and the text/image CLIP embeddings, weighted by ``lambda_text`` and
    ``lambda_image``.

    Args:
        model: Module whose ``forward(blendshape, text_emb)`` returns
            ``(recon, mu, logvar, z)``.
        train_loader, val_loader: Iterables of ``(text_emb, img_emb, blendshape)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device to train on.
        num_epochs: Number of epochs.
        lambda_text, lambda_image: Weights of the alignment terms.
        verbose: Show progress bars and per-epoch summed losses.

    Returns:
        ``(best_loss, best_model_state)``:
        - ``best_loss`` is the lowest validation reconstruction loss, summed
          over validation batches.
        - ``best_model_state`` is a detached copy of the matching
          ``state_dict``. It is ``None`` if no epoch ran or no validation loss
          was comparable, e.g. all NaN.

        ``model`` is left holding those weights.
    """
    model.to(device)
    best_loss = float('inf')
    best_model_state = None

    for epoch in range(num_epochs):
        model.train()
        train_sums = [0.0, 0.0, 0.0, 0.0]  # total, recon, text, image
        for batch in tqdm(train_loader, disable=not verbose):
            losses = _vae_batch_losses(model, batch, device, lambda_text, lambda_image)
            optimizer.zero_grad()
            losses[0].backward()
            optimizer.step()
            for k, value in enumerate(losses):
                train_sums[k] += value.item()

        model.eval()
        val_sums = [0.0, 0.0, 0.0, 0.0]
        with torch.no_grad():  # validation only: don't build autograd graphs
            for batch in tqdm(val_loader, disable=not verbose):
                losses = _vae_batch_losses(model, batch, device, lambda_text, lambda_image)
                for k, value in enumerate(losses):
                    val_sums[k] += value.item()

        if verbose:
            train_total, train_recon, train_text, train_image = train_sums
            val_total, val_recon, val_text, val_image = val_sums
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Train - Total: {train_total:.2f}, Recon: {train_recon:.2f}, Text: {train_text:.2f}, Image: {train_image:.2f}")
            print(f"Val   - Total: {val_total:.2f}, Recon: {val_recon:.2f}, Text: {val_text:.2f}, Image: {val_image:.2f}")
            print("-" * 60)

        # Save best model
        val_recon_loss = val_sums[1]
        if val_recon_loss <= best_loss:
            best_loss = val_recon_loss
            # state_dict() returns tensors that share storage with the live
            # parameters. Clone them, or later optimizer steps overwrite the snapshot.
            best_model_state = {name: tensor.detach().clone()
                                for name, tensor in model.state_dict().items()}

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return best_loss, best_model_state

In [None]:
# Find Best Lambda Values

In [None]:
# Coarse sweep: lambda_text and lambda_image share the same value.
lambda_grid = [0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 2.0]
results = []

input_size = 51
hidden_size = 256
clip_embedding_size = 512

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

for lambda_value in lambda_grid:
    lambda_t = lambda_value
    lambda_i = lambda_value
    print(f"Training with lambda_text={lambda_t}, lambda_image={lambda_i}")

    # Same seed for every configuration: identical initial weights, shuffle
    # order and dropout/sampling noise. Differences in val loss then reflect
    # lambda rather than RNG luck.
    torch.manual_seed(0)
    model = VAE(input_size, hidden_size, clip_embedding_size).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    val_loss, _ = train_vae(model, train_loader, val_loader, optimizer, device,
                            num_epochs=10,
                            lambda_text=lambda_t,
                            lambda_image=lambda_i,
                            verbose=False)

    results.append((lambda_t, lambda_i, val_loss))
    print(f"Loss: {val_loss:.2f}")

# Sort results by the best validation reconstruction loss returned by train_vae
results.sort(key=lambda x: x[2])
print("\nBest configuration:")
print(f"lambda_text = {results[0][0]}, lambda_image = {results[0][1]}, loss = {results[0][2]:.2f}")

In [None]:
# Fine grid: every (lambda_text, lambda_image) combination from lambda_grid.
lambda_grid = [0.001, 0.002, 0.003, 0.004, 0.005]
results = []

input_size = 51
hidden_size = 256
clip_embedding_size = 512

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

for lambda_t in lambda_grid:
    for lambda_i in lambda_grid:
        print(f"Training with lambda_text={lambda_t}, lambda_image={lambda_i}")

        # Same seed for every configuration, so the comparison isolates the
        # effect of the lambdas from initialisation/shuffle randomness.
        torch.manual_seed(0)
        model = VAE(input_size, hidden_size, clip_embedding_size).to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-3)

        val_loss, _ = train_vae(model, train_loader, val_loader, optimizer, device,
                                num_epochs=10,
                                lambda_text=lambda_t,
                                lambda_image=lambda_i,
                                verbose=False)

        results.append((lambda_t, lambda_i, val_loss))
        print(f"Loss: {val_loss:.2f}")

# Sort results by the best validation reconstruction loss returned by train_vae
results.sort(key=lambda x: x[2])
print("\nBest configuration:")
print(f"lambda_text = {results[0][0]}, lambda_image = {results[0][1]}, loss = {results[0][2]:.2f}")

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

input_size = 51
hidden_size = 256
clip_embedding_size = 512

l_text=0.004
l_image=0.002

model = VAE(input_size, hidden_size, clip_embedding_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

best_loss, best_model_state = train_vae(model, train_loader, val_loader, optimizer, device,
                               num_epochs=50,
                               lambda_text=l_text,
                               lambda_image=l_image,
                               verbose=True)

# Evaluate

In [12]:
def evaluate(test_loader, lambda_text=0.004, lambda_image=0.002, model=None, device=None):
    """Evaluate the VAE on ``test_loader`` without updating it.

    Uses the same loss as ``train_vae``: reconstruction MSE plus
    ``1 - cosine_similarity`` between the latent ``z`` and the text/image CLIP
    embeddings, weighted by ``lambda_text`` and ``lambda_image``.

    Args:
        test_loader: Iterable of ``(text_emb, img_emb, blendshape)`` batches.
        lambda_text: Weight of the text-alignment term in the total loss.
        lambda_image: Weight of the image-alignment term in the total loss.
        model: Model to evaluate. Defaults to the notebook-global ``model``.
        device: Device to run on. Defaults to the notebook-global ``device``.

    Returns:
        ``(total, recon, text, image)`` losses, each summed over all batches.
    """
    # Fall back to the notebook globals so existing `evaluate(test_loader)` calls keep working.
    if model is None:
        model = globals()["model"]
    if device is None:
        device = globals()["device"]

    model.eval()
    test_total_loss, test_recon_loss, test_text_loss, test_image_loss = 0, 0, 0, 0
    with torch.no_grad():  # evaluation only: don't build autograd graphs
        for text_emb, img_emb, blendshape in tqdm(test_loader):
            # Flatten (B, 1, D) -> (B, D). `.squeeze(0)` only did this for
            # B == 1. For larger batches cosine_similarity broadcast against
            # the (B, D) latent and compared embeddings across the batch.
            text_emb = text_emb.to(device).float().reshape(text_emb.size(0), -1)
            img_emb = img_emb.to(device).float().reshape(img_emb.size(0), -1)
            blendshape = blendshape.to(device).float()

            recon, mu, logvar, z = model(blendshape, text_emb)

            # Losses
            mse = F.mse_loss(recon, blendshape, reduction='mean')
            # Per-sample cosine similarity; normalisation is built in.
            text_sim_loss = 1 - F.cosine_similarity(text_emb, z, dim=-1).mean()
            img_sim_loss = 1 - F.cosine_similarity(img_emb, z, dim=-1).mean()

            loss = mse + lambda_text * text_sim_loss + lambda_image * img_sim_loss

            test_total_loss += loss.item()
            test_recon_loss += mse.item()
            test_text_loss += text_sim_loss.item()
            test_image_loss += img_sim_loss.item()
    return test_total_loss, test_recon_loss, test_text_loss, test_image_loss

In [None]:
# Summed test-set losses of the current notebook-global `model`.
(test_total_loss, test_recon_loss,
 test_text_loss, test_image_loss) = evaluate(test_loader)

In [None]:
# Test reconstruction loss: sum over batches of the per-batch mean MSE.
test_recon_loss

In [None]:
# Persist `best_model_state` returned by train_vae.
# NOTE(review): this path is on Google Drive, which this notebook never mounts
# (e.g. via google.colab's drive.mount). Mount it first, or the save is likely to fail.
torch.save(best_model_state, "/content/drive/MyDrive/best_vae_model.pt")

# Inference

In [None]:
input_size = 51
hidden_size = 256
clip_embedding_size = 512

# The training section saves the checkpoint to Google Drive. Load it from there
# when available; otherwise fall back to a copy in the working directory.
checkpoint_path = "/content/drive/MyDrive/best_vae_model.pt"
if not os.path.exists(checkpoint_path):
    checkpoint_path = "best_vae_model.pt"

model = VAE(input_size, hidden_size, clip_embedding_size)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only runtime.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device)
model.eval()

In [None]:
# Predict rounded blendshape weights for a free-text description: the text is
# CLIP-encoded and decoded by the VAE decoder.
model.generate("The face is expressing a sense of overwhelming joy, radiating happiness and contentment.")