```
conda init
conda create -n train-models python=3.10 -y
conda activate train-models
pip install torch==2.6.0+cu118 torchaudio==2.6.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
pip install pandas ipykernel tqdm
```

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import torch.nn.functional as F

In [2]:
# Paths
# Dataset root, written as a relative path (resolved against the current working
# directory when the files are opened).
DATA_DIR = Path("../VoxCeleb2")

# Locations of the precomputed embedding files for the train and val splits.
# "jpg_*" files are used as model inputs and "wav_*" files as targets by the
# dataset cell further down (presumably face-image and speech embeddings,
# going by the model's name -- confirm against the extraction pipeline).
JPG_EMB_PATH = DATA_DIR / "train/jpg_train_embeddings.pt"
JPG_VAL_EMB_PATH = DATA_DIR / "val/jpg_val_embeddings.pt"
WAV_EMB_PATH = DATA_DIR / "train/wav_train_embeddings.pt"
WAV_VAL_EMB_PATH = DATA_DIR / "val/wav_val_embeddings.pt"

In [3]:
# Load embeddings
# map_location="cpu" makes the load succeed on machines without a GPU even if the
# tensors were saved from a CUDA device; each batch is moved to the selected
# device inside the training loop, so keeping the full tensors on CPU is enough.
jpg_train_emb = torch.load(JPG_EMB_PATH, map_location="cpu")
jpg_val_emb   = torch.load(JPG_VAL_EMB_PATH, map_location="cpu")

wav_train_emb = torch.load(WAV_EMB_PATH, map_location="cpu")
wav_val_emb   = torch.load(WAV_VAL_EMB_PATH, map_location="cpu")

# Row i of the JPG tensor is paired with row i of the WAV tensor downstream, so
# fail early with a readable message if the files are out of sync.
assert jpg_train_emb.shape[0] == wav_train_emb.shape[0], "train JPG/WAV embeddings have different row counts"
assert jpg_val_emb.shape[0] == wav_val_emb.shape[0], "val JPG/WAV embeddings have different row counts"

print("Train JPG:", jpg_train_emb.shape, "Train WAV:", wav_train_emb.shape)
print("Val JPG  :", jpg_val_emb.shape, "Val WAV  :", wav_val_emb.shape)

Train JPG: torch.Size([29498, 512]) Train WAV: torch.Size([29498, 512])
Val JPG  : torch.Size([3381, 512]) Val WAV  : torch.Size([3381, 512])


In [None]:
# Define MLP model
class FaceToSpeechMLP(nn.Module):
    """Feed-forward network mapping one embedding vector to another.

    The network is Linear -> BatchNorm1d -> ReLU -> Dropout, repeated twice,
    followed by a final Linear projection with no activation.

    Args:
        dropout_rate: Drop probability used by both Dropout layers.
        in_dim: Size of each input vector (default 512).
        hidden_dim: Width of the two hidden layers (default 1024).
        out_dim: Size of each output vector (default 512).

    With the default arguments the layer layout, and therefore the
    ``state_dict`` keys and shapes, are identical to the original fixed
    512 -> 1024 -> 1024 -> 512 architecture.
    """

    def __init__(self, dropout_rate=0.3, in_dim=512, hidden_dim=1024, out_dim=512):
        super().__init__()
        # Keep the exact module order inside the Sequential: checkpoints are
        # saved via state_dict(), whose keys are the positional indices.
        self.model = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        """Apply the MLP to a batch ``x`` of shape (batch, in_dim); returns (batch, out_dim)."""
        return self.model(x)

In [11]:
# Dataset / DataLoader
# Number of (input, target) pairs per mini-batch for both loaders.
BATCH_SIZE = 128

# Pair the tensors row by row: element i of each dataset is
# (jpg embedding i, wav embedding i). TensorDataset requires both tensors to
# have the same size along the first dimension.
train_ds = TensorDataset(jpg_train_emb, wav_train_emb)
val_ds   = TensorDataset(jpg_val_emb,   wav_val_emb)

# Training batches are reshuffled every epoch; validation keeps the stored
# order (shuffle is left at its default of False).
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE)

In [12]:
# Training Setup
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# The model and every batch are moved to this device in later cells.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [13]:
# Training parameters
# Seed PyTorch's RNGs (CPU and CUDA) before the model is built so that weight
# initialisation, dropout masks and DataLoader shuffling are repeatable when
# the notebook is re-run from a fresh kernel. Some GPU kernels may still be
# non-deterministic, so tiny run-to-run differences are possible on CUDA.
SEED = 42
torch.manual_seed(SEED)

model = FaceToSpeechMLP().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

NUM_EPOCHS = 50
# Lowest validation loss seen so far; starts at +inf so the first epoch always
# produces a checkpoint.
best_val_loss = float('inf')
# Destination of the best-performing weights (state_dict) written by the
# training loop.
MODEL_PATH = DATA_DIR / "face2speech_best.pt"

In [14]:
# Define loss function
def cosine_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return one minus the batch-mean cosine similarity of ``pred`` and ``target``.

    Both inputs are L2-normalised along dim 1, the row-wise dot products are
    taken, and their mean over the batch is subtracted from 1. The result is a
    scalar tensor: 0 when every row pair points the same way, 2 when every
    pair is exactly opposite.
    """
    unit_pred, unit_target = (F.normalize(t, dim=1) for t in (pred, target))
    row_cosine = (unit_pred * unit_target).sum(dim=1)
    return 1 - row_cosine.mean()

In [15]:
# Training loop with validation
# Each epoch: one optimisation pass over train_loader, one gradient-free pass
# over val_loader, then a checkpoint if the validation loss improved.
# NOTE(review): in the run recorded below the validation loss is lowest at
# epoch 2 (0.594339) while the training loss keeps decreasing through epoch 50,
# so the saved checkpoint comes from epoch 2 -- consider early stopping or
# stronger regularisation.
for epoch in range(NUM_EPOCHS):
    # Training mode: Dropout is active and BatchNorm uses batch statistics.
    model.train()
    train_loss = 0
    for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        x, y = x.to(device), y.to(device)
        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        out = model(x)

        loss = cosine_loss(out, y)
        loss.backward()
        optimizer.step()
        # cosine_loss is a batch mean; weight it by the batch size so the
        # epoch figure is a true per-sample average even when the last batch
        # is smaller than BATCH_SIZE.
        train_loss += loss.item() * x.size(0)
    
    train_loss /= len(train_loader.dataset)
    
    # Validation
    # Eval mode: Dropout is disabled and BatchNorm uses its running statistics.
    model.eval()
    val_loss = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for x_val, y_val in val_loader:
            x_val, y_val = x_val.to(device), y_val.to(device)
            val_out = model(x_val)

            loss = cosine_loss(val_out, y_val)
            # Same per-sample weighting as in the training pass.
            val_loss += loss.item() * x_val.size(0)
            
    val_loss /= len(val_loader.dataset)
    
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Train Loss: {train_loss:.6f} - Val Loss: {val_loss:.6f}")
    
    # Save best model
    # Only the weights (state_dict) are written, and only on a strict
    # improvement, so MODEL_PATH always holds the best epoch seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), MODEL_PATH)
        print("Saved best model")


Epoch 1/50: 100%|██████████| 231/231 [00:00<00:00, 267.38it/s]


Epoch 1/50 - Train Loss: 0.338370 - Val Loss: 0.596779
Saved best model


Epoch 2/50: 100%|██████████| 231/231 [00:00<00:00, 377.75it/s]


Epoch 2/50 - Train Loss: 0.262117 - Val Loss: 0.594339
Saved best model


Epoch 3/50: 100%|██████████| 231/231 [00:00<00:00, 379.19it/s]


Epoch 3/50 - Train Loss: 0.247436 - Val Loss: 0.598952


Epoch 4/50: 100%|██████████| 231/231 [00:00<00:00, 382.29it/s]


Epoch 4/50 - Train Loss: 0.238951 - Val Loss: 0.613108


Epoch 5/50: 100%|██████████| 231/231 [00:00<00:00, 373.65it/s]


Epoch 5/50 - Train Loss: 0.232892 - Val Loss: 0.609696


Epoch 6/50: 100%|██████████| 231/231 [00:00<00:00, 346.90it/s]


Epoch 6/50 - Train Loss: 0.228027 - Val Loss: 0.614111


Epoch 7/50: 100%|██████████| 231/231 [00:00<00:00, 380.88it/s]


Epoch 7/50 - Train Loss: 0.224244 - Val Loss: 0.623288


Epoch 8/50: 100%|██████████| 231/231 [00:00<00:00, 366.51it/s]


Epoch 8/50 - Train Loss: 0.220476 - Val Loss: 0.622814


Epoch 9/50: 100%|██████████| 231/231 [00:00<00:00, 369.48it/s]


Epoch 9/50 - Train Loss: 0.217753 - Val Loss: 0.637574


Epoch 10/50: 100%|██████████| 231/231 [00:00<00:00, 366.45it/s]


Epoch 10/50 - Train Loss: 0.215040 - Val Loss: 0.633902


Epoch 11/50: 100%|██████████| 231/231 [00:00<00:00, 356.20it/s]


Epoch 11/50 - Train Loss: 0.212665 - Val Loss: 0.634957


Epoch 12/50: 100%|██████████| 231/231 [00:00<00:00, 357.72it/s]


Epoch 12/50 - Train Loss: 0.210650 - Val Loss: 0.637900


Epoch 13/50: 100%|██████████| 231/231 [00:00<00:00, 372.85it/s]


Epoch 13/50 - Train Loss: 0.208331 - Val Loss: 0.635914


Epoch 14/50: 100%|██████████| 231/231 [00:00<00:00, 371.22it/s]


Epoch 14/50 - Train Loss: 0.206749 - Val Loss: 0.643960


Epoch 15/50: 100%|██████████| 231/231 [00:00<00:00, 343.38it/s]


Epoch 15/50 - Train Loss: 0.205229 - Val Loss: 0.645503


Epoch 16/50: 100%|██████████| 231/231 [00:00<00:00, 374.67it/s]


Epoch 16/50 - Train Loss: 0.203562 - Val Loss: 0.648273


Epoch 17/50: 100%|██████████| 231/231 [00:00<00:00, 367.83it/s]


Epoch 17/50 - Train Loss: 0.202113 - Val Loss: 0.643527


Epoch 18/50: 100%|██████████| 231/231 [00:00<00:00, 362.48it/s]


Epoch 18/50 - Train Loss: 0.200840 - Val Loss: 0.646158


Epoch 19/50: 100%|██████████| 231/231 [00:00<00:00, 377.85it/s]


Epoch 19/50 - Train Loss: 0.199611 - Val Loss: 0.646544


Epoch 20/50: 100%|██████████| 231/231 [00:00<00:00, 363.88it/s]


Epoch 20/50 - Train Loss: 0.198392 - Val Loss: 0.647218


Epoch 21/50: 100%|██████████| 231/231 [00:00<00:00, 364.97it/s]


Epoch 21/50 - Train Loss: 0.197280 - Val Loss: 0.637606


Epoch 22/50: 100%|██████████| 231/231 [00:00<00:00, 371.88it/s]


Epoch 22/50 - Train Loss: 0.196076 - Val Loss: 0.648072


Epoch 23/50: 100%|██████████| 231/231 [00:00<00:00, 381.27it/s]


Epoch 23/50 - Train Loss: 0.195292 - Val Loss: 0.639217


Epoch 24/50: 100%|██████████| 231/231 [00:00<00:00, 384.29it/s]


Epoch 24/50 - Train Loss: 0.194501 - Val Loss: 0.648960


Epoch 25/50: 100%|██████████| 231/231 [00:00<00:00, 373.09it/s]


Epoch 25/50 - Train Loss: 0.193862 - Val Loss: 0.649954


Epoch 26/50: 100%|██████████| 231/231 [00:00<00:00, 383.33it/s]


Epoch 26/50 - Train Loss: 0.193023 - Val Loss: 0.650716


Epoch 27/50: 100%|██████████| 231/231 [00:00<00:00, 383.32it/s]


Epoch 27/50 - Train Loss: 0.192195 - Val Loss: 0.643922


Epoch 28/50: 100%|██████████| 231/231 [00:00<00:00, 373.48it/s]


Epoch 28/50 - Train Loss: 0.191536 - Val Loss: 0.645145


Epoch 29/50: 100%|██████████| 231/231 [00:00<00:00, 384.04it/s]


Epoch 29/50 - Train Loss: 0.190773 - Val Loss: 0.648673


Epoch 30/50: 100%|██████████| 231/231 [00:00<00:00, 373.21it/s]


Epoch 30/50 - Train Loss: 0.190139 - Val Loss: 0.644216


Epoch 31/50: 100%|██████████| 231/231 [00:00<00:00, 377.13it/s]


Epoch 31/50 - Train Loss: 0.189412 - Val Loss: 0.651130


Epoch 32/50: 100%|██████████| 231/231 [00:00<00:00, 375.31it/s]


Epoch 32/50 - Train Loss: 0.188632 - Val Loss: 0.651763


Epoch 33/50: 100%|██████████| 231/231 [00:00<00:00, 386.78it/s]


Epoch 33/50 - Train Loss: 0.188262 - Val Loss: 0.650233


Epoch 34/50: 100%|██████████| 231/231 [00:00<00:00, 382.09it/s]


Epoch 34/50 - Train Loss: 0.187416 - Val Loss: 0.652798


Epoch 35/50: 100%|██████████| 231/231 [00:00<00:00, 371.96it/s]


Epoch 35/50 - Train Loss: 0.186981 - Val Loss: 0.651688


Epoch 36/50: 100%|██████████| 231/231 [00:00<00:00, 387.99it/s]


Epoch 36/50 - Train Loss: 0.186615 - Val Loss: 0.650304


Epoch 37/50: 100%|██████████| 231/231 [00:00<00:00, 339.09it/s]


Epoch 37/50 - Train Loss: 0.186123 - Val Loss: 0.651840


Epoch 38/50: 100%|██████████| 231/231 [00:00<00:00, 348.03it/s]


Epoch 38/50 - Train Loss: 0.185538 - Val Loss: 0.650379


Epoch 39/50: 100%|██████████| 231/231 [00:00<00:00, 382.47it/s]


Epoch 39/50 - Train Loss: 0.185245 - Val Loss: 0.648449


Epoch 40/50: 100%|██████████| 231/231 [00:00<00:00, 373.59it/s]


Epoch 40/50 - Train Loss: 0.184701 - Val Loss: 0.652681


Epoch 41/50: 100%|██████████| 231/231 [00:00<00:00, 366.66it/s]


Epoch 41/50 - Train Loss: 0.184305 - Val Loss: 0.646647


Epoch 42/50: 100%|██████████| 231/231 [00:00<00:00, 378.26it/s]


Epoch 42/50 - Train Loss: 0.183853 - Val Loss: 0.650185


Epoch 43/50: 100%|██████████| 231/231 [00:00<00:00, 378.19it/s]


Epoch 43/50 - Train Loss: 0.183319 - Val Loss: 0.650321


Epoch 44/50: 100%|██████████| 231/231 [00:00<00:00, 369.85it/s]


Epoch 44/50 - Train Loss: 0.182930 - Val Loss: 0.653945


Epoch 45/50: 100%|██████████| 231/231 [00:00<00:00, 373.99it/s]


Epoch 45/50 - Train Loss: 0.182627 - Val Loss: 0.654084


Epoch 46/50: 100%|██████████| 231/231 [00:00<00:00, 378.92it/s]


Epoch 46/50 - Train Loss: 0.182194 - Val Loss: 0.655859


Epoch 47/50: 100%|██████████| 231/231 [00:00<00:00, 359.94it/s]


Epoch 47/50 - Train Loss: 0.182099 - Val Loss: 0.658717


Epoch 48/50: 100%|██████████| 231/231 [00:00<00:00, 377.69it/s]


Epoch 48/50 - Train Loss: 0.181619 - Val Loss: 0.655232


Epoch 49/50: 100%|██████████| 231/231 [00:00<00:00, 376.48it/s]


Epoch 49/50 - Train Loss: 0.181165 - Val Loss: 0.653200


Epoch 50/50: 100%|██████████| 231/231 [00:00<00:00, 366.39it/s]

Epoch 50/50 - Train Loss: 0.181008 - Val Loss: 0.655941



