In [1]:
import os
import glob
import time
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.utils
from tqdm import tqdm  # progress bar

# ---------------------------
# Parameters (hard-coded)
# ---------------------------
# Directory scanned (non-recursively) for '*.jpg' training images.
INPUT_DIR = "/kaggle/input/processed-faces/processed_faces"

# Training hyperparameters; all of them are consumed by main().
BATCH_SIZE = 22  # batch size given to the DataLoader
EPOCHS = 26  # number of training epochs passed to train()
LR = 0.0002  # learning rate shared by the generator and discriminator Adam optimizers
CHECKPOINT_DIR = "checkpoints"  # train() writes one checkpoint file per epoch here
# Optional inference parameters:
TEST_IMAGE = ""  # Set to a valid image path for testing (e.g., "sample.jpg"), or leave as empty string
TIME_PARAM = 0.75  # time value handed to test() for the sample image
OUTPUT_IMAGE = "123.jpg"  # path where test() saves the generated image

# =============================================================================
# CUSTOM DATASET
# =============================================================================
class TimeVariableDeepfakesDataset(Dataset):
    """
    Custom Dataset for Time-Variable Deepfakes.

    Collects every ``*.jpg`` file directly inside ``root_dir`` (no recursion).
    Filenames are expected to start with an integer age followed by an
    underscore (e.g., "23_image.jpg"). The age is mapped linearly from
    ``[min_age, max_age]`` to a time parameter in ``[0, 1]``.

    Args:
        root_dir: Directory containing the images.
        transform: Optional callable applied to the loaded RGB PIL image.
        min_age: Age mapped to time 0. Also used as the fallback age when a
            filename has no parsable integer prefix.
        max_age: Age mapped to time 1. Must be greater than ``min_age``.

    Raises:
        ValueError: If ``max_age`` is not greater than ``min_age``.
    """
    def __init__(self, root_dir, transform=None, min_age=0, max_age=100):
        # Reject a degenerate range here instead of hitting a division by
        # zero later inside a DataLoader worker.
        if max_age <= min_age:
            raise ValueError(
                f"max_age ({max_age}) must be greater than min_age ({min_age})"
            )
        self.root_dir = root_dir
        # glob.escape keeps directory names containing '[', '*' or '?' from
        # being interpreted as patterns; sorting makes index -> file stable
        # across runs because glob returns results in arbitrary order.
        pattern = os.path.join(glob.escape(root_dir), '*.jpg')
        self.image_paths = sorted(glob.glob(pattern))
        self.transform = transform
        self.min_age = min_age
        self.max_age = max_age

    def __len__(self):
        """Return the number of images found in ``root_dir``."""
        return len(self.image_paths)

    def _extract_age(self, filename):
        """Return the integer prefix of the file's basename (text before the first '_').

        Falls back to ``self.min_age`` when that prefix is not an integer.
        """
        base = os.path.basename(filename)
        age_str = base.split('_')[0]
        try:
            age = int(age_str)
        except ValueError:
            # Best-effort: unlabeled files are treated as the youngest age.
            age = self.min_age
        return age

    def __getitem__(self, idx):
        """Return ``(image, time_param)`` for the image at ``idx``.

        ``image`` is the RGB image after ``transform`` (if any);
        ``time_param`` is a float32 tensor of shape ``(1,)`` in ``[0, 1]``.
        """
        img_path = self.image_paths[idx]
        # The context manager releases the file handle deterministically;
        # convert() returns an image that stays valid after the file closes.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        age = self._extract_age(img_path)
        # Normalize age to [0, 1]; ages outside [min_age, max_age] are clamped
        # so the documented range always holds.
        time_param = (age - self.min_age) / (self.max_age - self.min_age)
        time_param = min(max(time_param, 0.0), 1.0)
        time_param = torch.tensor([time_param], dtype=torch.float32)
        return image, time_param

# =============================================================================
# MODEL DEFINITIONS
# =============================================================================
class SelfAttention(nn.Module):
    """
    Self-attention over the spatial positions of a feature map (SAGAN style).

    Queries and keys are 1x1-conv projections to ``in_dim // 8`` channels,
    values keep ``in_dim`` channels. The attended features are scaled by the
    learnable scalar ``gamma`` (initialised to zero) and added to the input.
    """
    def __init__(self, in_dim):
        super().__init__()
        reduced_dim = in_dim // 8
        self.query_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x`` (4-D input)."""
        batch, channels, dim_a, dim_b = x.size()
        positions = dim_a * dim_b

        # (batch, positions, reduced) @ (batch, reduced, positions) gives pairwise scores.
        queries = self.query_conv(x).view(batch, -1, positions).permute(0, 2, 1)
        keys = self.key_conv(x).view(batch, -1, positions)
        attention = self.softmax(torch.bmm(queries, keys))

        # Each output position is a weighted sum of the values at all positions.
        values = self.value_conv(x).view(batch, -1, positions)
        attended = torch.bmm(values, attention.permute(0, 2, 1))
        attended = attended.view(batch, channels, dim_a, dim_b)
        return self.gamma * attended + x

class UNetGenerator(nn.Module):
    """
    U-Net style generator conditioned on a continuous time variable.

    The time parameter goes through a small MLP; the resulting vector is
    broadcast over the image plane and concatenated to the input channels.
    Four stride-2 encoder stages feed a bottleneck with self-attention, and
    four stride-2 transposed-conv decoder stages consume the encoder outputs
    as skip connections. The output passes through tanh.
    """
    def __init__(self, in_channels=3, out_channels=3, time_dim=1, feature_channels=64):
        super().__init__()
        self.feature_channels = feature_channels
        fc = feature_channels

        # MLP mapping the time parameter to a vector of `fc` values.
        self.time_embedding = nn.Sequential(
            nn.Linear(time_dim, fc),
            nn.ReLU(True),
            nn.Linear(fc, fc),
        )

        # Encoder; the first stage also receives the broadcast time embedding.
        self.enc1 = self.conv_block(in_channels + fc, fc)
        self.enc2 = self.conv_block(fc, fc * 2)
        self.enc3 = self.conv_block(fc * 2, fc * 4)
        self.enc4 = self.conv_block(fc * 4, fc * 8)

        # Bottleneck followed by self-attention.
        self.bottleneck = self.conv_block(fc * 8, fc * 16)
        self.attention = SelfAttention(fc * 16)

        # Decoder; each stage takes the previous output concatenated with a skip.
        self.dec4 = self.deconv_block(fc * 16 + fc * 8, fc * 8)
        self.dec3 = self.deconv_block(fc * 8 + fc * 4, fc * 4)
        self.dec2 = self.deconv_block(fc * 4 + fc * 2, fc * 2)
        self.dec1 = self.deconv_block(fc * 2 + fc, fc)

        self.final_conv = nn.Conv2d(fc, out_channels, kernel_size=1)
        self.tanh = nn.Tanh()

    def conv_block(self, in_channels, out_channels):
        """Build a downsampling stage: stride-2 Conv2d -> BatchNorm2d -> LeakyReLU(0.2)."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        return nn.Sequential(*layers)

    def deconv_block(self, in_channels, out_channels):
        """Build an upsampling stage: stride-2 ConvTranspose2d -> BatchNorm2d -> ReLU."""
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x, time_param):
        """Generate an image from ``x`` (4-D tensor) conditioned on ``time_param``."""
        _, _, height, width = x.size()

        # Broadcast the time embedding over the spatial grid and append it as channels.
        embedded = self.time_embedding(time_param)
        time_planes = embedded[..., None, None].expand(-1, -1, height, width)
        features = torch.cat([x, time_planes], dim=1)

        # Encoder: every stage halves the resolution; keep outputs for the skips.
        skips = []
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            features = encoder(features)
            skips.append(features)

        # Bottleneck halves the resolution once more, then attention is applied.
        features = self.attention(self.bottleneck(features))

        # Upsample back to the resolution of the deepest skip so they can be concatenated.
        features = nn.functional.interpolate(
            features, scale_factor=2, mode='bilinear', align_corners=True
        )

        # Decoder: pair each stage with the encoder output of matching depth.
        decoders = (self.dec4, self.dec3, self.dec2, self.dec1)
        for decoder, skip in zip(decoders, reversed(skips)):
            features = decoder(torch.cat([features, skip], dim=1))

        return self.tanh(self.final_conv(features))

class MultiScaleDiscriminator(nn.Module):
    """
    Bank of identical convolutional discriminators applied at successive scales.

    The first discriminator sees the input as given; each following one sees
    the previous input after a stride-2 average pooling.
    """
    def __init__(self, in_channels=3, feature_channels=64, num_scales=3):
        super().__init__()
        self.num_scales = num_scales
        self.discriminators = nn.ModuleList(
            [self.make_discriminator(in_channels, feature_channels) for _ in range(num_scales)]
        )
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def make_discriminator(self, in_channels, feature_channels):
        """Build one discriminator: three stride-2 conv stages and a 1-channel output conv."""
        fc = feature_channels

        def strided_conv(c_in, c_out):
            return nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)

        # The first stage has no normalisation; the next two use BatchNorm.
        layers = [strided_conv(in_channels, fc), nn.LeakyReLU(0.2, inplace=True)]
        for c_in, c_out in ((fc, fc * 2), (fc * 2, fc * 4)):
            layers += [
                strided_conv(c_in, c_out),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.Conv2d(fc * 4, 1, kernel_size=4, padding=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return a list with one prediction map per scale, finest scale first."""
        scale_outputs = []
        current = x
        for discriminator in self.discriminators:
            scale_outputs.append(discriminator(current))
            current = self.downsample(current)
        return scale_outputs

class VGGFeatureExtractor(nn.Module):
    """
    Frozen VGG16 convolutional trunk used to compute the perceptual loss.

    Loads torchvision's VGG16 with the IMAGENET1K_V1 weights, keeps only its
    ``features`` sub-network in eval mode and disables gradients for all of
    its parameters.
    """
    def __init__(self):
        super().__init__()
        backbone = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        self.features = backbone.features.eval()
        # Freeze the trunk: it only provides features and is never trained.
        self.features.requires_grad_(False)

    def forward(self, x):
        """Return the output of the full VGG16 feature stack for ``x``."""
        return self.features(x)

# =============================================================================
# LOSS FUNCTIONS
# =============================================================================
def adversarial_loss(outputs, target_is_real):
    """
    Sum the mean-squared-error adversarial loss over all discriminator outputs.

    Each tensor in ``outputs`` is compared against a same-shaped target filled
    with 1.0 (``target_is_real`` true) or 0.0 (false). Returns 0 when
    ``outputs`` is empty.
    """
    label = 1.0 if target_is_real else 0.0
    mse = nn.MSELoss()
    return sum(mse(prediction, torch.full_like(prediction, label)) for prediction in outputs)

def identity_loss(input_image, generated_image):
    """
    Mean absolute (L1) difference between the generated image and the input.

    Keeping this small preserves the subject's identity.
    """
    l1 = nn.L1Loss()
    return l1(generated_image, input_image)

def perceptual_loss(vgg, input_image, generated_image, layers=('8', '15')):
    """
    Compute perceptual loss based on intermediate VGG features.

    Both images are pushed through ``vgg.features`` layer by layer; after
    every layer whose name is in ``layers`` the L1 distance between the two
    feature maps is added to the loss. Evaluation stops after the last
    requested layer, so deeper layers that cannot affect the result are
    never run.

    Args:
        vgg: Object exposing a ``features`` module with named child layers.
        input_image: Reference image batch.
        generated_image: Generated image batch to compare against it.
        layers: Names (as strings) of the child layers whose outputs are compared.

    Returns:
        The summed L1 feature distance, or ``0.0`` if no layer name matched.

    NOTE(review): torchvision's pretrained VGG weights are normally used with
    ImageNet mean/std normalised inputs - confirm that callers' image
    normalisation is what is intended here.
    """
    criterion = nn.L1Loss()
    # _modules is kept (rather than named_children) to preserve the exact
    # layer enumeration of the original implementation.
    named_layers = list(vgg.features._modules.items())
    # Index of the deepest requested layer; -1 means nothing to evaluate.
    last_needed = max(
        (index for index, (name, _) in enumerate(named_layers) if name in layers),
        default=-1,
    )

    loss = 0.0
    x = input_image
    y = generated_image
    for name, layer in named_layers[:last_needed + 1]:
        x = layer(x)
        y = layer(y)
        if name in layers:
            loss += criterion(x, y)
    return loss

def time_consistency_loss(generator, input_image, time1, time2):
    """
    L1 distance between the generator outputs for two time values.

    The same ``input_image`` is generated once with ``time1`` and once with
    ``time2``; penalising the difference encourages smooth transitions for
    nearby times.
    """
    first, second = (generator(input_image, t) for t in (time1, time2))
    l1 = nn.L1Loss()
    return l1(first, second)

# =============================================================================
# TRAINING & TESTING FUNCTIONS
# =============================================================================
def train(model_G, model_D, dataloader, optimizer_G, optimizer_D, vgg, device, num_epochs=20, checkpoint_dir='checkpoints'):
    """
    Train the conditional GAN.

    For every batch the generator is updated first (adversarial + identity +
    perceptual + time-consistency losses, weighted 1 / 10 / 5 / 2), then the
    discriminator is updated on the real batch and the detached fake batch.
    A checkpoint with both models' and both optimizers' state is written to
    ``checkpoint_dir`` after every epoch.

    Args:
        model_G: Generator, called as ``model_G(images, time_params)``.
        model_D: Discriminator returning a list of prediction tensors.
        dataloader: Iterable yielding ``(images, time_params)`` batches.
        optimizer_G: Optimizer for the generator parameters.
        optimizer_D: Optimizer for the discriminator parameters.
        vgg: Feature extractor handed to ``perceptual_loss``.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``dataloader``.
        checkpoint_dir: Directory for checkpoint files (created if missing).
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(checkpoint_dir, exist_ok=True)

    l1 = nn.L1Loss()
    # Offset used to sample a nearby time value for the consistency term.
    delta = 0.05

    for epoch in range(num_epochs):
        epoch_start = time.time()
        # Make sure both networks are in training mode; a previous inference
        # call may have left them in eval mode.
        model_G.train()
        model_D.train()
        # Wrap dataloader with tqdm for live progress.
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
        for real_images, time_params in progress_bar:
            real_images = real_images.to(device)
            time_params = time_params.to(device)

            # ---------------------
            # Train Generator
            # ---------------------
            optimizer_G.zero_grad()
            fake_images = model_G(real_images, time_params)
            disc_outputs = model_D(fake_images)
            loss_adv = adversarial_loss(disc_outputs, True)
            loss_id = identity_loss(real_images, fake_images)
            loss_perc = perceptual_loss(vgg, real_images, fake_images)
            # Time consistency: compare against the output for a nearby time
            # (clamped to [0, 1]). fake_images already is G(x, t), so only the
            # shifted-time output needs an additional generator pass.
            time_plus = torch.clamp(time_params + delta, 0, 1)
            fake_images_plus = model_G(real_images, time_plus)
            loss_time = l1(fake_images, fake_images_plus)

            loss_G = loss_adv + 10.0 * loss_id + 5.0 * loss_perc + 2.0 * loss_time
            loss_G.backward()
            optimizer_G.step()

            # ---------------------
            # Train Discriminator
            # ---------------------
            # zero_grad also discards the discriminator gradients accumulated
            # by the generator's backward pass above.
            optimizer_D.zero_grad()
            real_outputs = model_D(real_images)
            loss_real = adversarial_loss(real_outputs, True)
            # detach() keeps the discriminator update from reaching the generator.
            fake_outputs = model_D(fake_images.detach())
            loss_fake = adversarial_loss(fake_outputs, False)
            loss_D = (loss_real + loss_fake) * 0.5
            loss_D.backward()
            optimizer_D.step()

            # Update progress bar with current losses.
            progress_bar.set_postfix({
                "Loss_G": loss_G.item(),
                "Loss_D": loss_D.item()
            })

        # Save a checkpoint after each epoch.
        # NOTE: the stored 'epoch' value is zero-based while the file name is one-based.
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
        torch.save({
            'epoch': epoch,
            'generator_state_dict': model_G.state_dict(),
            'discriminator_state_dict': model_D.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict()
        }, checkpoint_path)
        print(f"Epoch [{epoch+1}/{num_epochs}] completed in {time.time()-epoch_start:.2f} seconds. Checkpoint saved to {checkpoint_path}.")

def test(generator, input_image, time_param, device, output_path='output.jpg'):
    """
    Run inference: generate a deepfake image for a given time parameter.

    Args:
        generator: Generator, called as ``generator(images, time_tensor)``.
        input_image: Either a PIL image (resized to 256x256 and normalised
            to [-1, 1] here) or an already preprocessed tensor. A 3-D tensor
            is treated as a single image and given a batch dimension.
        time_param: Scalar time value applied to every image in the batch.
        device: Device on which inference runs.
        output_path: File the generated image(s) are written to.

    The generator is switched to eval mode for the forward pass and put back
    into whatever mode it was in before the call.
    """
    if isinstance(input_image, Image.Image):
        # Same preprocessing as used for the training data.
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
        ])
        input_tensor = transform(input_image).unsqueeze(0).to(device)
    else:
        input_tensor = input_image.to(device)
        if input_tensor.dim() == 3:
            # Single unbatched image: add the batch dimension the generator expects.
            input_tensor = input_tensor.unsqueeze(0)

    # One time value per image so batched tensor inputs work as well.
    time_tensor = torch.full(
        (input_tensor.size(0), 1), float(time_param), dtype=torch.float32, device=device
    )

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            fake = generator(input_tensor, time_tensor)
    finally:
        # Do not leave the caller's model silently stuck in eval mode.
        generator.train(was_training)

    fake = (fake + 1) / 2  # Denormalize from [-1,1] to [0,1]
    torchvision.utils.save_image(fake, output_path)
    print(f"Generated image saved to {output_path}")

# =============================================================================
# MAIN FUNCTION
# =============================================================================
def main():
    """
    Script entry point: build the dataset and models, train, optionally run inference.

    Uses the module-level constants for all settings. Inference on
    ``TEST_IMAGE`` only runs when that constant is a non-empty string.

    Raises:
        FileNotFoundError: If ``INPUT_DIR`` contains no matching images.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data transforms for training.
    data_transforms = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
    ])
    dataset = TimeVariableDeepfakesDataset(root_dir=INPUT_DIR, transform=data_transforms)
    # Fail fast with a clear message, before any model is built or pretrained
    # weights are loaded, instead of a cryptic sampler error from DataLoader.
    if len(dataset) == 0:
        raise FileNotFoundError(f"No training images found in {INPUT_DIR!r}")
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

    # Initialize models.
    generator = UNetGenerator(in_channels=3, out_channels=3, time_dim=1, feature_channels=64).to(device)
    discriminator = MultiScaleDiscriminator(in_channels=3, feature_channels=64).to(device)
    vgg = VGGFeatureExtractor().to(device)

    optimizer_G = optim.Adam(generator.parameters(), lr=LR, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=LR, betas=(0.5, 0.999))

    # Train the models.
    train(generator, discriminator, dataloader, optimizer_G, optimizer_D, vgg, device,
          num_epochs=EPOCHS, checkpoint_dir=CHECKPOINT_DIR)

    # Optional inference test.
    if TEST_IMAGE:
        # The context manager closes the file once the RGB copy has been made.
        with Image.open(TEST_IMAGE) as opened:
            sample_image = opened.convert('RGB')
        test(generator, sample_image, time_param=TIME_PARAM, device=device, output_path=OUTPUT_IMAGE)

if __name__ == '__main__':
    main()


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 204MB/s] 
                                                                                          

Epoch [1/26] completed in 1619.27 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_1.pth.


                                                                                          

Epoch [2/26] completed in 1618.15 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_2.pth.


                                                                                          

Epoch [3/26] completed in 1618.17 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_3.pth.


                                                                                          

Epoch [4/26] completed in 1618.16 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_4.pth.


                                                                                          

Epoch [5/26] completed in 1618.11 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_5.pth.


                                                                                          

Epoch [6/26] completed in 1618.23 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_6.pth.


                                                                                          

Epoch [7/26] completed in 1618.19 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_7.pth.


                                                                                          

Epoch [8/26] completed in 1618.12 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_8.pth.


                                                                                          

Epoch [9/26] completed in 1618.10 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_9.pth.


                                                                                           

Epoch [10/26] completed in 1618.24 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_10.pth.


                                                                                           

Epoch [11/26] completed in 1618.26 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_11.pth.


                                                                                           

Epoch [12/26] completed in 1618.07 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_12.pth.


                                                                                           

Epoch [13/26] completed in 1617.87 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_13.pth.


                                                                                           

Epoch [14/26] completed in 1618.13 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_14.pth.


                                                                                           

Epoch [15/26] completed in 1618.11 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_15.pth.


                                                                                           

Epoch [16/26] completed in 1618.21 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_16.pth.


                                                                                           

Epoch [17/26] completed in 1618.20 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_17.pth.


                                                                                           

Epoch [18/26] completed in 1618.08 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_18.pth.


                                                                                           

Epoch [19/26] completed in 1618.22 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_19.pth.


                                                                                           

Epoch [20/26] completed in 1618.06 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_20.pth.


                                                                                           

Epoch [21/26] completed in 1618.05 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_21.pth.


                                                                                           

Epoch [22/26] completed in 1618.00 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_22.pth.


                                                                                           

Epoch [23/26] completed in 1618.00 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_23.pth.


                                                                                           

Epoch [24/26] completed in 1618.13 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_24.pth.


                                                                                           

Epoch [25/26] completed in 1617.96 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_25.pth.


                                                                                            

Epoch [26/26] completed in 1618.11 seconds. Checkpoint saved to checkpoints/checkpoint_epoch_26.pth.
