# I-LoRA Implementation for Face Intrinsics
Uses the FFHQ dataset and generates pseudo-ground truth with pretrained models.

In [None]:
# Install the third-party packages used by this notebook (IPython shell escape).
# NOTE(review): assumes `zoedepth` is installable from the package index under
# this name — TODO confirm; the relative config path
# "zoedepth/configs/zoedepth_nk.yaml" used further down suggests a local checkout.
!pip install torch torchvision diffusers transformers timm omegaconf zoedepth

In [None]:
import glob

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torchvision.transforms as transforms
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from zoedepth.models.builder import build_model
from zoedepth.utils.config import get_config

## 1. Setup Pseudo Ground Truth Models

In [None]:
class IntrinsicEstimator:
    """Produce pseudo-ground-truth intrinsic maps for an image with pretrained models.

    Only depth estimation (ZoeDepth) is implemented. Surface-normal
    estimation is a placeholder that still has to be connected to an
    Omnidata model.
    """

    def __init__(self):
        """Load the ZoeDepth model and build the input preprocessing pipeline."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize depth estimator (ZoeDepth)
        PRETRAINED_ZOEDEPTH = "https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_N.pt"
        # NOTE(review): the config is named "zoedepth_nk" while the checkpoint
        # file name ends in "_N" — confirm that the two actually belong together.
        conf = OmegaConf.load("zoedepth/configs/zoedepth_nk.yaml")
        self.depth_model = build_model(conf)
        # torch.load() only reads local paths / file objects, so a URL has to be
        # fetched through torch.hub (which also caches the download).
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        checkpoint = torch.hub.load_state_dict_from_url(
            PRETRAINED_ZOEDEPTH, map_location=self.device
        )
        # Assumes the weights may be nested under a "model" key, as some
        # checkpoints are — a flat state dict is used unchanged. TODO confirm.
        if isinstance(checkpoint, dict) and "model" in checkpoint:
            checkpoint = checkpoint["model"]
        self.depth_model.load_state_dict(checkpoint)
        self.depth_model.to(self.device).eval()

        # Surface normals: an Omnidata model still has to be downloaded and
        # initialized here (not implemented yet; see estimate_normals).

        # Standard image transforms: fixed 256x256 resize, then PIL -> tensor
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])

    @torch.no_grad()
    def estimate_depth(self, image):
        """Estimate depth for a single image using ZoeDepth.

        Args:
            image: Input image accepted by ``self.transform`` (a PIL image in
                this notebook). Images whose ``mode`` is not "RGB" are
                converted to RGB first.

        Returns:
            NumPy array holding the prediction for this image, i.e. the model
            output with the batch dimension removed.
        """
        # Palette / grayscale / RGBA inputs would give the model the wrong
        # number of channels; objects without a `mode` attribute pass through.
        if getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        prediction = self.depth_model(image_tensor)
        # NOTE(review): ZoeDepth's forward pass is assumed to return a dict
        # keyed by "metric_depth" — confirm; plain tensors are used as-is.
        if isinstance(prediction, dict):
            prediction = prediction["metric_depth"]
        return prediction.cpu().numpy()[0]

    @torch.no_grad()
    def estimate_normals(self, image):
        """Estimate surface normals using Omnidata (placeholder).

        Raises:
            NotImplementedError: Always; no normal model is loaded yet.
        """
        # Fail loudly instead of silently handing None to the caller.
        raise NotImplementedError(
            "Surface-normal estimation (Omnidata) is not implemented yet"
        )

    def visualize_estimates(self, image):
        """Plot the original image next to its estimated depth map.

        Args:
            image: Image to display and to run depth estimation on.
        """
        # imshow needs a 2-D array for a colormap; drop singleton dimensions.
        depth = np.squeeze(self.estimate_depth(image))

        plt.figure(figsize=(15, 5))
        plt.subplot(131)
        plt.imshow(image)
        plt.title('Original Image')

        plt.subplot(132)
        plt.imshow(depth, cmap='magma')
        plt.title('Estimated Depth')
        # The third column of the 1x3 grid is left unused (presumably reserved
        # for further intrinsics such as normals).

        plt.tight_layout()
        plt.show()

## 2. Create Face Dataset with Pseudo Ground Truth

In [None]:
class FaceIntrinsicDataset(Dataset):
    """Face images paired with pre-computed pseudo-ground-truth intrinsics.

    All intrinsics are estimated once in ``__init__`` and cached in memory,
    keyed by image path; ``__getitem__`` only loads and transforms the image.
    """

    def __init__(self, root_dir, transform=None, num_samples=250, estimator=None):
        """Dataset for faces with estimated intrinsics

        Args:
            root_dir: Directory containing FFHQ images (``*.png`` files
                directly inside it are used)
            transform: Optional transforms applied to each image in
                ``__getitem__``
            num_samples: Number of samples to use (paper uses as few as 250)
            estimator: Optional object exposing ``estimate_depth(image)``.
                When omitted, a new ``IntrinsicEstimator`` is created, which
                loads the depth model.
        """
        self.root_dir = root_dir
        self.transform = transform
        # Allow an already-loaded estimator to be reused across datasets.
        self.estimator = IntrinsicEstimator() if estimator is None else estimator

        # Get all image paths (sorted for a deterministic subset) and limit
        # to num_samples
        self.image_paths = sorted(glob.glob(f"{root_dir}/*.png"))[:num_samples]
        if not self.image_paths:
            # An empty dataset otherwise only fails later, inside the DataLoader.
            print(f"Warning: no .png images found in {root_dir!r}")

        # Pre-compute pseudo ground truth
        print("Generating pseudo ground truth...")
        self.cached_intrinsics = {}
        total = len(self.image_paths)
        for idx, path in enumerate(self.image_paths):
            image = self._load_rgb(path)
            self.cached_intrinsics[path] = {
                'depth': self.estimator.estimate_depth(image),
                # Add other intrinsics as needed
            }
            if idx % 10 == 0:
                print(f"Processed {idx+1}/{total} images")

    @staticmethod
    def _load_rgb(path):
        """Open an image, convert it to RGB and release the file handle."""
        # convert() returns a separate, fully loaded image, so the file can
        # be closed as soon as the with-block ends.
        with Image.open(path) as handle:
            return handle.convert("RGB")

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the (transformed) image and its cached depth for ``idx``.

        Returns:
            Dict with keys ``'image'`` (the transformed image, or the RGB PIL
            image when no transform was given) and ``'depth'`` (a tensor built
            from the cached depth array).
        """
        image_path = self.image_paths[idx]
        image = self._load_rgb(image_path)

        if self.transform:
            image = self.transform(image)

        # Get cached intrinsics
        intrinsics = self.cached_intrinsics[image_path]

        return {
            'image': image,
            'depth': torch.from_numpy(intrinsics['depth']),
            # Add other intrinsics as needed
        }

## 3. Example Usage with FFHQ

In [None]:
# Preprocessing applied to every image returned by the dataset:
# resize, center-crop to 256x256, then convert to a tensor.
transform = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()]
)

# Build the dataset from 250 samples, the count used in the paper
dataset = FaceIntrinsicDataset(
    "path_to_ffhq_dataset",  # Replace with your FFHQ path
    transform,
    250,
)

# Shuffled loader yielding batches of four samples
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# Visualize some examples
def show_batch(batch):
    """Plot every image of a batch above its depth map.

    Args:
        batch: Mapping with keys ``'image'`` and ``'depth'``. Each image is
            indexed along the first axis and permuted with ``(1, 2, 0)``
            before display, so channel-first image tensors are expected.
    """
    images = batch['image']
    depths = batch['depth']

    # Size the grid from the batch itself: a final, smaller batch (or a
    # different batch_size) must not index past the end.
    count = len(images)
    if count == 0:
        return

    # squeeze=False keeps `axes` two-dimensional even for a single column.
    fig, axes = plt.subplots(2, count, figsize=(4 * count, 8), squeeze=False)
    for col, (image, depth) in enumerate(zip(images, depths)):
        axes[0, col].imshow(image.permute(1, 2, 0))
        axes[0, col].set_title('Original')
        # Drop singleton dimensions (e.g. a leading channel axis) so imshow
        # receives a 2-D map for the colormap.
        axes[1, col].imshow(depth.squeeze(), cmap='magma')
        axes[1, col].set_title('Depth')
    plt.tight_layout()
    plt.show()

# Pull one batch from the (shuffled) loader and plot it
batch_iterator = iter(dataloader)
batch = next(batch_iterator)
show_batch(batch)