In [1]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from PIL import Image

from dataset import FacialLandmarkDataset

In [2]:
class FLDModel(nn.Module):
    def __init__(self, vit_model_name='vit_base_patch16_224'):
        super(FLDModel, self).__init__()
        # Load pre-trained ViT model
        self.vit = timm.create_model(vit_model_name, pretrained=True)
        
        # Remove the classification head
        self.vit.head = nn.Identity()
        
        # Custom head for facial landmark detection
        self.fc1 = nn.Linear(self.vit.num_features, 256)  # Example dimensions
        self.fc2 = nn.Linear(256, 68 * 2)  # 68 landmarks with x and y coordinates
    
    def forward(self, x):
        # Extract features from ViT
        features = self.vit(x)
        
        # Pass through the custom head
        x = torch.relu(self.fc1(features))
        landmarks = self.fc2(x)
        
        return landmarks

In [3]:
# Define the transformations for ViT input
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Example function to preprocess an image
def preprocess_image(image):
    image = Image.fromarray(image)
    image = preprocess(image)
    return image.unsqueeze(0)  # Add batch dimension

In [4]:
dataset = FacialLandmarkDataset(root_dir='archive/ibug_300W_large_face_landmark_dataset/afw', transform=preprocess)

In [5]:
# Split dataset into training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Define DataLoaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)

In [6]:
# Initialize the model
model = FLDModel()

# Define loss function and optimizer
criterion = nn.MSELoss()  # For regression tasks
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Define learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

FLDModel(
  (vit): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
     

In [7]:
def mean_absolute_error(predictions, targets):
    return torch.mean(torch.abs(predictions - targets))

In [8]:
# Training function for one epoch
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for images, landmarks in dataloader:
        images = images.to(device)
        landmarks = landmarks.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(images)
        
        # Compute loss
        loss = criterion(outputs, landmarks)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * images.size(0)
    
    epoch_loss = running_loss / len(dataloader.dataset)
    return epoch_loss

# Evaluation function
def evaluate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for images, landmarks in dataloader:
            images = images.to(device)
            landmarks = landmarks.to(device)
            
            # Forward pass
            outputs = model(images)
            
            # Compute loss
            loss = criterion(outputs, landmarks)
            
            running_loss += loss.item() * images.size(0)
    
    epoch_loss = running_loss / len(dataloader.dataset)
    return epoch_loss

# Training loop with early stopping
num_epochs = 10
best_val_loss = float('inf')
patience = 5
counter = 0

for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss = evaluate(model, val_loader, criterion, device)
    
    scheduler.step()
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        # Save the model checkpoint
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        counter += 1
    
    if counter >= patience:
        print(f'Early stopping at epoch {epoch+1}')
        break
    
    print(f'Epoch {epoch+1}/{num_epochs}')
    print(f'Training Loss: {train_loss:.4f}')
    print(f'Validation Loss: {val_loss:.4f}')

Epoch 1/20
Training Loss: 11618.5310
Validation Loss: 11772.4215
Epoch 2/20
Training Loss: 10943.7318
Validation Loss: 10898.9764
Epoch 3/20
Training Loss: 9840.6596
Validation Loss: 9482.5204
Epoch 4/20
Training Loss: 8219.1955
Validation Loss: 7618.0295
Epoch 5/20
Training Loss: 6320.1807
Validation Loss: 5602.7119
Epoch 6/20
Training Loss: 4855.2679
Validation Loss: 4691.2571
Epoch 7/20
Training Loss: 4034.8001
Validation Loss: 3902.5067
Epoch 8/20
Training Loss: 3337.5858
Validation Loss: 3259.9332
Epoch 9/20
Training Loss: 2796.0760
Validation Loss: 2747.3434
Epoch 10/20
Training Loss: 2369.1809
Validation Loss: 2376.8478
Epoch 11/20
Training Loss: 2133.7253
Validation Loss: 2230.2627
Epoch 12/20
Training Loss: 2008.4122
Validation Loss: 2115.0342
Epoch 13/20
Training Loss: 1908.8226
Validation Loss: 2016.6244
Epoch 14/20
Training Loss: 1825.6532
Validation Loss: 1934.1001
Epoch 15/20
Training Loss: 1757.0931
Validation Loss: 1866.0919
Epoch 16/20
Training Loss: 1716.3059
Validati

In [10]:
# Save the best model
torch.save(model.state_dict(), 'best_model.pth')

# Load the model later
model = FLDModel()
model.load_state_dict(torch.load('best_model.pth'))
model.to(device)

  model.load_state_dict(torch.load('best_model.pth'))


FLDModel(
  (vit): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
     