In [None]:
!pip install datasets
!pip install huggingface_hub
!pip3 install torch torchvision
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

In [None]:
from huggingface_hub import login
# from datasets import load_dataset
from huggingface_hub import hf_hub_download

import zipfile
import os
from os.path import isfile, join
import re

import pandas as pd
import numpy as np
import pickle as pkl

import torch
import clip
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

# Download and Prepare Data

In [None]:
# Authenticate with the Hugging Face Hub.
# The access token is taken from the HF_TOKEN environment variable so that no
# credential is ever stored in the notebook. When the variable is unset (or
# empty) `token=None` is passed, which makes `login` ask for the token
# interactively instead.
login(token=os.environ.get("HF_TOKEN") or None)
# Alternative: run `!huggingface-cli login` in a cell.

In [None]:
# Paths (relative to the root of the dataset repository) of every file this
# notebook downloads: the parquet/CSV tables and the zipped image archives.
# Entries ending in ".zip" are unpacked and deleted by the extraction cell.
# NOTE(review): the prompt1 archives jump from 6000_7000 to 7000_9000 — confirm
# this matches the file names that actually exist in the repository.
files_to_download = [
    "data/emo3d_data.parquet",
    "data/train_data.csv",
    "data/val_data.csv",
    "data/test_data.csv",
    "data/anger_images.zip",
    "data/contempt_images.zip",
    "data/disgust_images.zip",
    "data/fear_images.zip",
    "data/sadness_images.zip",
    # # "data/primitive_emotions.csv",
    # # "data/primitive_emotions.zip",
    "data/prompt1_images_0_1000.zip",
    "data/prompt1_images_1000_2000.zip",
    "data/prompt1_images_2000_3000.zip",
    "data/prompt1_images_3000_4000.zip",
    "data/prompt1_images_4000_5000.zip",
    "data/prompt1_images_5000_6000.zip",
    "data/prompt1_images_6000_7000.zip",
    "data/prompt1_images_7000_9000.zip",
    "data/prompt1_images_9000_10000.zip",
    "data/prompt2_images_0_1000.zip",
    "data/prompt2_images_1000_2000.zip",
    "data/prompt2_images_2000_3000.zip",
    "data/prompt2_images_3000_4000.zip",
    "data/prompt2_images_4000_5000.zip",
    "data/prompt2_images_5000_6000.zip",
]


repo_id = "llm-lab/Emo3D"
local_dir = "/content"

# Arguments shared by every download; only the file name changes per call.
download_options = dict(repo_type='dataset', repo_id=repo_id, local_dir=local_dir)

# Fetch each listed file from the dataset repository into `local_dir`,
# reporting progress before and after every transfer.
for file_path in files_to_download:
    print(f"Downloading {file_path}...")
    hf_hub_download(filename=file_path, **download_options)
    print(f"{file_path} downloaded successfully.")

In [None]:
# Unpack every downloaded ".zip" archive into a sibling directory named after
# the archive (e.g. <local_dir>/data/anger_images.zip -> <local_dir>/data/anger_images),
# then delete the archive.
for file_path in files_to_download:
    local_file_path = os.path.join(local_dir, file_path)  # Full path to the file
    print(local_file_path)

    if not file_path.endswith(".zip"):
        continue

    if not os.path.isfile(local_file_path):
        # The archive is removed after extraction, so a second run of this cell
        # finds nothing to unpack; skip instead of raising FileNotFoundError.
        print(f"{file_path} not found, assuming it was already extracted.")
        continue

    print(f"Unzipping {file_path}...")

    # Anchor the destination under local_dir. Deriving it from the relative
    # repository path would extract relative to the current working directory,
    # which only coincides with local_dir by accident.
    unzip_dir = os.path.splitext(local_file_path)[0]
    with zipfile.ZipFile(local_file_path, 'r') as zip_ref:
        zip_ref.extractall(unzip_dir)
    os.remove(local_file_path)
    print(f"{file_path} extracted successfully.")

# Utilities

In [None]:
# Blendshape coefficient names, in the order used by the score arrays.
blendshapes_names = [
    '_neutral',
    'browDownLeft', 'browDownRight', 'browInnerUp', 'browOuterUpLeft', 'browOuterUpRight',
    'cheekPuff', 'cheekSquintLeft', 'cheekSquintRight',
    'eyeBlinkLeft', 'eyeBlinkRight',
    'eyeLookDownLeft', 'eyeLookDownRight', 'eyeLookInLeft', 'eyeLookInRight',
    'eyeLookOutLeft', 'eyeLookOutRight', 'eyeLookUpLeft', 'eyeLookUpRight',
    'eyeSquintLeft', 'eyeSquintRight', 'eyeWideLeft', 'eyeWideRight',
    'jawForward', 'jawLeft', 'jawOpen', 'jawRight',
    'mouthClose', 'mouthDimpleLeft', 'mouthDimpleRight',
    'mouthFrownLeft', 'mouthFrownRight', 'mouthFunnel', 'mouthLeft',
    'mouthLowerDownLeft', 'mouthLowerDownRight', 'mouthPressLeft', 'mouthPressRight',
    'mouthPucker', 'mouthRight', 'mouthRollLower', 'mouthRollUpper',
    'mouthShrugLower', 'mouthShrugUpper', 'mouthSmileLeft', 'mouthSmileRight',
    'mouthStretchLeft', 'mouthStretchRight', 'mouthUpperUpLeft', 'mouthUpperUpRight',
    'noseSneerLeft', 'noseSneerRight',
]


def _counterpart_index(position, name):
    """Return the index of the opposite-side blendshape, or `position` if `name` has no side."""
    # "Left" is checked before "Right", matching the precedence of an if/elif chain.
    for side, other_side in (("Left", "Right"), ("Right", "Left")):
        if side in name:
            return blendshapes_names.index(name.replace(side, other_side))
    return position


# bs_mirror_idx[k] is the index whose value must be read to obtain entry k of
# the left/right-mirrored blendshape array.
bs_mirror_idx = [
    _counterpart_index(position, name)
    for position, name in enumerate(blendshapes_names)
]


In [None]:
# Lateral terms and their mirror-image counterparts, in the two capitalisations
# that `mirror_text` recognises.
replacements = {
    "left": "right",
    "right": "left",
    "Left": "Right",
    "Right": "Left",
}

# A term is matched only when it is not embedded in a longer alphabetic word.
# Without the lookarounds, words that merely contain a term would be corrupted
# ("bright" -> "bleft", "upright" -> "upleft").
regex = re.compile(
    r"(?<![A-Za-z])(%s)(?![A-Za-z])" % "|".join(map(re.escape, replacements.keys()))
)


def mirror_text(text):
    """Swap every standalone "left"/"right" (or "Left"/"Right") in `text`.

    All replacements happen in a single pass, so a term that has just been
    swapped is never swapped back. Terms that are part of a longer alphabetic
    word (including camelCase identifiers such as "browDownLeft") are left
    unchanged.

    Args:
        text: The description to mirror.

    Returns:
        A copy of `text` with the lateral terms exchanged.
    """
    return regex.sub(lambda mo: replacements[mo.group()], text)

# Custom Dataset

In [None]:
# Choose the compute device: the GPU when CUDA is available, the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Load the CLIP "ViT-B/32" model onto `device`. `clip_model` is used below to
# encode texts and images; `preprocess` is applied to each PIL image before it
# is passed to the image encoder.
clip_model, preprocess = clip.load("ViT-B/32", device=device)

In [None]:
# Directory that the image paths stored in the dataframe are relative to.
image_dir = "/content/data"


class CustomDataset(Dataset):
    """Pre-computed (CLIP embedding, blendshape target) pairs built from a dataframe.

    Each row may hold up to four blendshape score strings
    (``blenshape_score_1`` .. ``blenshape_score_4``), each paired by position
    with an image path (``img_1`` .. ``img_4``), plus three text descriptions
    (``text_1`` .. ``text_3``). For every non-missing blendshape score the
    dataset stores one sample per available text description and one sample for
    the matching image (when present). With ``augmentation`` enabled, each
    sample is immediately followed by its left/right-mirrored counterpart.

    All embeddings are computed eagerly in ``__init__`` using the module-level
    ``clip_model``, ``preprocess`` and ``device``; mirroring relies on the
    module-level ``bs_mirror_idx`` and ``mirror_text``.

    Each item is a tuple ``(embedding, target)`` where ``embedding`` is the
    encoder output for a single input and ``target`` is the blendshape array
    without its first entry, as a tensor on ``device``.
    """

    def __init__(self, dataframe, augmentation=True):
        """Encode every text/image of `dataframe` and pair it with its target.

        Args:
            dataframe: Table with the text, image and blendshape score columns
                described in the class docstring.
            augmentation: When truthy, also store a mirrored copy of every sample.
        """
        self.dataframe = dataframe
        self.data = []

        text_cols = ['text_1', 'text_2', 'text_3']
        image_cols = ['img_1', 'img_2', 'img_3', 'img_4']
        bs_cols = ['blenshape_score_1', 'blenshape_score_2', 'blenshape_score_3', 'blenshape_score_4']

        for i in tqdm(range(len(dataframe))):
            row = dataframe.iloc[i]
            # The text embeddings of a row do not depend on the blendshape
            # column, so they are encoded lazily and at most once per row
            # instead of once per blendshape score.
            text_embeddings = None

            for img_col, bs_col in zip(image_cols, bs_cols):
                if pd.isna(row[bs_col]):
                    continue

                bs_arr = self._parse_blendshape_scores(row[bs_col])
                # The first entry is dropped from every target.
                target = torch.from_numpy(bs_arr[1:]).to(device)
                mirror_target = None
                if augmentation:
                    mirror_bs_arr = self.mirror_blenshape_score(bs_arr)
                    mirror_target = torch.from_numpy(mirror_bs_arr[1:]).to(device)

                # Text samples
                if text_embeddings is None:
                    text_embeddings = self._encode_row_texts(row, text_cols, augmentation)
                for text_embedding, mirror_text_embedding in text_embeddings:
                    self.data.append((text_embedding, target))
                    if augmentation:
                        self.data.append((mirror_text_embedding, mirror_target))

                # Image sample
                if not pd.isna(row[img_col]):
                    image_path = os.path.join(image_dir, row[img_col])
                    self.data.append((self.get_image_embedding(image_path), target))
                    if augmentation:
                        self.data.append((self.gent_mirror_image_embedding(image_path), mirror_target))

    @staticmethod
    def _parse_blendshape_scores(raw):
        """Parse a whitespace-separated, optionally bracketed score string into a float array."""
        return np.array(raw.strip("[]").split(), dtype=float)

    def _encode_row_texts(self, row, text_cols, augmentation):
        """Encode the text descriptions of one row, skipping missing ones.

        Returns:
            A list of ``(embedding, mirrored_embedding)`` tuples in column
            order; ``mirrored_embedding`` is ``None`` when `augmentation` is falsy.
        """
        encoded = []
        for text_col in text_cols:
            text = row[text_col]
            if pd.isna(text):
                # A missing description cannot be tokenised; skip it rather than crash.
                continue
            embedding = self.get_text_embedding(text)
            mirrored = self.get_text_embedding(mirror_text(text)) if augmentation else None
            encoded.append((embedding, mirrored))
        return encoded

    def get_text_embedding(self, text):
        """Return the CLIP text embedding of `text` (computed without gradients)."""
        # truncate=True clips descriptions longer than CLIP's context length
        # instead of raising an error.
        inputs = clip.tokenize(text, truncate=True).to(device)
        with torch.no_grad():
            outputs = clip_model.encode_text(inputs)
        return outputs

    def _encode_image(self, image_path, mirror):
        """Return the CLIP embedding of the image at `image_path`, optionally flipped horizontally."""
        # The context manager closes the underlying file once the pixels have
        # been preprocessed, so building a large dataset does not leak handles.
        with Image.open(image_path) as source:
            picture = source.transpose(Image.FLIP_LEFT_RIGHT) if mirror else source
            inputs = preprocess(picture).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = clip_model.encode_image(inputs)
        return outputs

    def get_image_embedding(self, image_path):
        """Return the CLIP image embedding of the image stored at `image_path`."""
        return self._encode_image(image_path, mirror=False)

    def gent_mirror_image_embedding(self, image_path):
        """Return the CLIP image embedding of the horizontally flipped image.

        The misspelled method name is kept for backward compatibility.
        """
        return self._encode_image(image_path, mirror=True)

    def mirror_blenshape_score(self, bs):
        """Mirrors the blendshape array based on predefined symmetry indices."""
        return bs[bs_mirror_idx]

    def __len__(self):
        """Return the number of stored (embedding, target) samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(embedding, blendshape_score)`` pair at position `idx`."""
        embedding, blendshape_score = self.data[idx]
        return embedding, blendshape_score



In [None]:
# Read the train / validation / test split tables from the downloaded CSV files.
train_df, val_df, test_df = [
    pd.read_csv(csv_path)
    for csv_path in (
        "/content/data/train_data.csv",
        "/content/data/val_data.csv",
        "/content/data/test_data.csv",
    )
]

In [None]:
# Build the three datasets. Mirror augmentation is enabled for the training
# split only; validation and test keep just the original samples.
train_dataset = CustomDataset(dataframe=train_df, augmentation=True)
val_dataset = CustomDataset(dataframe=val_df, augmentation=False)
test_dataset = CustomDataset(dataframe=test_df, augmentation=False)

# Model

In [None]:
class ClipMLP(nn.Module):
    """Multi-layer perceptron with two hidden ReLU layers and a sigmoid output.

    Args:
        input_dim: Size of the input vector.
        hidden_dim1: Width of the first hidden layer.
        hidden_dim2: Width of the second hidden layer.
        output_dim: Number of output values, each in the range [0, 1].
    """

    def __init__(self, input_dim=512, hidden_dim1=256, hidden_dim2=128, output_dim=51):
        super().__init__()

        # Hidden stack: Linear -> ReLU for each consecutive pair of widths.
        widths = [input_dim, hidden_dim1, hidden_dim2]
        layers = []
        for in_features, out_features in zip(widths, widths[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())

        # Output head; the sigmoid keeps every prediction between 0 and 1.
        layers.append(nn.Linear(hidden_dim2, output_dim))
        layers.append(nn.Sigmoid())

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to `x` and return its output."""
        scores = self.network(x)
        return scores

# Train

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over `loader` and return the mean of the per-batch losses.

    The pass updates the weights (train mode, gradients enabled) when an
    `optimizer` is given and is a pure evaluation pass otherwise.

    Raises:
        ValueError: If `loader` yields no batches.
    """
    training = optimizer is not None
    model.train(training)

    running_loss = 0.0
    num_batches = 0
    with torch.set_grad_enabled(training):
        for embeddings, blendshape_scores in loader:
            # Move data to device
            embeddings = embeddings.to(device).float()
            blendshape_scores = blendshape_scores.to(device).float()

            if training:
                optimizer.zero_grad()

            # Align the prediction with the target's shape explicitly. A bare
            # squeeze() would also drop the batch axis of a size-1 batch and
            # leave the loss to rely on broadcasting.
            outputs = model(embeddings).reshape(blendshape_scores.shape)
            loss = criterion(outputs, blendshape_scores)

            if training:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("cannot compute a loss over an empty data loader")
    return running_loss / num_batches


def train_model(model, train_loader, val_loader, criterion, optimizer, device,
                num_epochs=50, checkpoint_path='best_model.pth'):
    """Train `model` and checkpoint the weights with the lowest validation loss.

    Every epoch performs one training pass over `train_loader` followed by one
    evaluation pass over `val_loader`, and prints both mean per-batch losses.
    Whenever the validation loss is less than or equal to the best seen so far,
    the model's ``state_dict`` is written to `checkpoint_path`.

    Args:
        model: Network to optimise; must already live on `device`.
        train_loader: Iterable of ``(embeddings, blendshape_scores)`` training batches.
        val_loader: Iterable of ``(embeddings, blendshape_scores)`` validation batches.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer bound to the parameters of `model`.
        device: Device every batch is moved to.
        num_epochs: Number of passes over the training data.
        checkpoint_path: File that receives the best ``state_dict``.

    Returns:
        The same `model` object, holding the weights of the LAST epoch and left
        in evaluation mode; the best weights are only available on disk.

    Raises:
        ValueError: If either loader yields no batches.
    """
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        train_loss = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss = _run_epoch(model, val_loader, criterion, device)

        # Print epoch statistics
        print(f'Epoch [{epoch+1}/{num_epochs}], '
              f'Train Loss: {train_loss:.4f}, '
              f'Val Loss: {val_loss:.4f}')

        # Save best model
        if val_loss <= best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)

    return model

# Compute device: GPU when CUDA is available, CPU otherwise
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Network to train, placed on the chosen device
model = ClipMLP().to(device)

# Mean squared error between predicted and reference blendshape scores,
# minimised with Adam
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Batch iterators; only the training data is reshuffled every epoch
loader_options = {"batch_size": 32}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

# Run the training loop with its default number of epochs
trained_model = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
)


# Evaluate

In [None]:
# Evaluate the model returned by training on the test set
trained_model.eval()
test_loss = 0.0
with torch.no_grad():
    for embeddings, blendshape_scores in test_loader:
        embeddings = embeddings.to(device).float()
        blendshape_scores = blendshape_scores.to(device).float()

        # Align the prediction with the target's shape explicitly. A bare
        # squeeze() would also drop the batch axis of a size-1 batch and leave
        # the loss to rely on broadcasting.
        outputs = trained_model(embeddings).reshape(blendshape_scores.shape)
        loss = criterion(outputs, blendshape_scores)
        test_loss += loss.item()

# Mean of the per-batch losses
test_loss /= len(test_loader)
print(f'Test Loss: {test_loss:.4f}')

# Persist the weights of the evaluated model
torch.save(trained_model.state_dict(), 'final_model.pth')

In [None]:
# Rebuild the network and restore the checkpoint with the best validation loss.
# map_location makes the load work even when the checkpoint was written on a
# different device than the one currently in use.
best_model = ClipMLP().to(device)
best_model.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
best_model.eval()

# Evaluate the restored model on the test set
test_loss = 0.0
with torch.no_grad():
    for embeddings, blendshape_scores in test_loader:
        embeddings = embeddings.to(device).float()
        blendshape_scores = blendshape_scores.to(device).float()

        # Align the prediction with the target's shape explicitly. A bare
        # squeeze() would also drop the batch axis of a size-1 batch and leave
        # the loss to rely on broadcasting.
        outputs = best_model(embeddings).reshape(blendshape_scores.shape)
        loss = criterion(outputs, blendshape_scores)
        test_loss += loss.item()

# Mean of the per-batch losses
test_loss /= len(test_loader)
print(f'Test Loss: {test_loss:.4f}')

# Inference

In [None]:
# Example: predict blendshape scores for a free-form description
text = "The face is expressing a sense of overwhelming joy, radiating happiness and contentment."

# Tokenise the description and encode it with CLIP (no gradients needed)
inputs = clip.tokenize(text)
inputs = inputs.to(device)
with torch.no_grad():
    text_features = clip_model.encode_text(inputs)
embedding = text_features.float()

# Run the regressor and bring the prediction back to NumPy on the CPU
prediction = best_model(embedding)
bs_score_pred = prediction.detach().cpu().numpy()

# Show the scores rounded to two decimals
print(np.around(bs_score_pred, decimals=2))