In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch.utils.data import DataLoader

### Generator (G) Structure

In [20]:
class CAdaIN(nn.Module):
    """Conditional adaptive instance normalisation.

    Every channel of ``x`` is standardised over its spatial dimensions and then
    modulated by a per-channel scale and shift predicted from ``noise`` by two
    linear layers; a learnable per-channel offset is added on top.

    Args:
        input_dim: Number of channels of the feature map being normalised.
        noise_dim: Length of the conditioning vector given to ``forward``.
    """

    def __init__(self, input_dim, noise_dim=128):
        super(CAdaIN, self).__init__()
        # NOTE: despite the attribute names, ``fc_mean`` predicts the
        # multiplicative scale and ``fc_var`` the additive shift (see forward).
        # The names are kept so existing state_dict keys stay valid.
        self.fc_mean = nn.Linear(noise_dim, input_dim)
        self.fc_var = nn.Linear(noise_dim, input_dim)
        self.noise_layer = nn.Parameter(torch.zeros(1, input_dim, 1, 1))

    def forward(self, x, noise):
        """Normalise ``x`` (N, C, H, W) and modulate it with ``noise`` (N, noise_dim).

        Returns:
            Tensor with the same shape as ``x``.
        """
        scale = self.fc_mean(noise).view(-1, x.size(1), 1, 1)
        shift = self.fc_var(noise).view(-1, x.size(1), 1, 1)

        centered = x - x.mean(dim=[2, 3], keepdim=True)
        if x.size(2) * x.size(3) > 1:
            spread = x.std(dim=[2, 3], keepdim=True)
        else:
            # The unbiased std of a single value is NaN; a 1x1 map has no
            # spatial variation, so treat its spread as zero instead.
            spread = torch.zeros_like(centered)
        x_normalized = centered / (spread + 1e-5)

        return scale * x_normalized + shift + self.noise_layer

class DownBlock(nn.Module):
    """Down-sampling stage of the generator.

    Applies a stride-2 3x3 convolution, LeakyReLU(0.2), batch normalisation and
    finally noise-conditioned CAdaIN, in that order.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        noise_dim: Length of the conditioning vector forwarded to CAdaIN.
    """

    def __init__(self, in_channels, out_channels, noise_dim=128):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)
        self.caadain = CAdaIN(out_channels, noise_dim)

    def forward(self, x, noise):
        """Return the down-sampled, noise-modulated feature map."""
        activated = self.lrelu(self.conv(x))
        return self.caadain(self.bn(activated), noise)

class UpBlock(nn.Module):
    """Decoder stage of the generator with a skip connection.

    The input goes through a size-preserving transposed convolution, is resized
    to the skip tensor's spatial size when the two differ, concatenated with the
    skip tensor, fused by a 3x3 convolution, and then passed through ReLU,
    batch normalisation and noise-conditioned CAdaIN.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        noise_dim: Length of the conditioning vector forwarded to CAdaIN.
        skip_channels: Channels of the skip tensor. Defaults to
            ``out_channels``, which reproduces the previous fixed
            ``out_channels * 2`` fusion width.
    """

    def __init__(self, in_channels, out_channels, noise_dim=128, skip_channels=None):
        super(UpBlock, self).__init__()
        if skip_channels is None:
            skip_channels = out_channels

        # stride=1 / padding=1 with a 3x3 kernel keeps the spatial size; the
        # actual up-sampling happens in forward() when the skip is larger.
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, output_padding=0)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.caadain = CAdaIN(out_channels, noise_dim)

        # Fuses the concatenated [deconv output, skip] tensor back to out_channels
        self.conv = nn.Conv2d(out_channels + skip_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x, noise, skip_connection):
        """Fuse ``x`` with ``skip_connection`` and modulate the result with ``noise``.

        Returns:
            Tensor with ``out_channels`` channels at the spatial size of
            ``skip_connection``.
        """
        x = self.deconv(x)

        # torch.cat needs identical spatial sizes; the deconv does not change
        # resolution, so bring x up (or down) to the skip tensor's size here.
        if x.shape[-2:] != skip_connection.shape[-2:]:
            x = F.interpolate(x, size=skip_connection.shape[-2:], mode="nearest")

        x = torch.cat([x, skip_connection], dim=1)
        x = self.conv(x)
        x = self.relu(x)
        x = self.bn(x)
        x = self.caadain(x, noise)

        return x

In [21]:
class Generator(nn.Module):
    """U-Net style generator conditioned on a noise vector.

    A crease image is encoded by an initial convolution plus seven stride-2
    ``DownBlock`` stages and decoded by seven ``UpBlock`` stages, each fused
    with the encoder feature map of matching resolution. The noise vector is
    first mapped through two fully connected layers and the result conditions
    every block.

    Args:
        noise_dim: Length of the noise vector expected by ``forward``.
    """

    def __init__(self, noise_dim=128):
        super(Generator, self).__init__()
        self.initial_conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)

        # Down blocks
        self.down1 = DownBlock(64, 128, noise_dim)
        self.down2 = DownBlock(128, 256, noise_dim)
        self.down3 = DownBlock(256, 256, noise_dim)
        self.down4 = DownBlock(256, 256, noise_dim)
        self.down5 = DownBlock(256, 256, noise_dim)
        self.down6 = DownBlock(256, 256, noise_dim)
        self.down7 = DownBlock(256, 256, noise_dim)

        # Up blocks. UpBlock concatenates its skip tensor with `out_channels`
        # feature maps and fuses them with a conv expecting `out_channels * 2`
        # inputs, so each block's out_channels must equal its skip's channels.
        self.up1 = UpBlock(256, 256, noise_dim)  # skip: d6 (256 ch)
        self.up2 = UpBlock(256, 256, noise_dim)  # skip: d5 (256 ch)
        self.up3 = UpBlock(256, 256, noise_dim)  # skip: d4 (256 ch)
        self.up4 = UpBlock(256, 256, noise_dim)  # skip: d3 (256 ch)
        self.up5 = UpBlock(256, 256, noise_dim)  # skip: d2 (256 ch)
        self.up6 = UpBlock(256, 128, noise_dim)  # skip: d1 (128 ch)
        self.up7 = UpBlock(128, 64, noise_dim)   # skip: initial conv output (64 ch)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)

        # FC layers for noise encoding
        self.fc1 = nn.Linear(noise_dim, noise_dim)
        self.fc2 = nn.Linear(noise_dim, noise_dim)

    @staticmethod
    def _decode(block, x, w, skip):
        """Resize ``x`` to the spatial size of ``skip`` and run one up block."""
        # UpBlock's transposed conv keeps the resolution, so the 2x up-sampling
        # that undoes the matching DownBlock has to happen before the call.
        x = F.interpolate(x, size=skip.shape[-2:], mode="nearest")
        return block(x, w, skip)

    def forward(self, creases, noise):
        """Generate an image from ``creases`` (N, 3, H, W) and ``noise`` (N, noise_dim).

        Returns:
            Tensor of shape (N, 3, H, W) squashed into [-1, 1] by tanh.
        """
        # Encode noise
        w = F.relu(self.fc1(noise))
        w = F.relu(self.fc2(w))

        # Initial convolution (keeps the input resolution)
        x0 = F.relu(self.initial_conv(creases))

        # Down-sampling: every stage roughly halves the resolution, so a
        # 256x256 input reaches a 2x2 map at d7.
        d1 = self.down1(x0, w)
        d2 = self.down2(d1, w)
        d3 = self.down3(d2, w)
        d4 = self.down4(d3, w)
        d5 = self.down5(d4, w)
        d6 = self.down6(d5, w)
        d7 = self.down7(d6, w)

        # Up-sampling with skip connections
        u1 = self._decode(self.up1, d7, w, d6)
        u2 = self._decode(self.up2, u1, w, d5)
        u3 = self._decode(self.up3, u2, w, d4)
        u4 = self._decode(self.up4, u3, w, d3)
        u5 = self._decode(self.up5, u4, w, d2)
        u6 = self._decode(self.up6, u5, w, d1)
        u7 = self._decode(self.up7, u6, w, x0)

        # Final convolution
        gen = torch.tanh(self.final_conv(u7))
        return gen

### Encoder (E) Structure

In [22]:
class ResidualBlock(nn.Module):
    """Pre-activation residual block: (BN -> LeakyReLU -> Conv) x 2 plus a shortcut.

    With ``stride == 2`` the main path is down-sampled by an average pool after
    the second convolution; for any other stride the first convolution is
    strided instead. The shortcut is a strided 1x1 convolution followed by
    batch normalisation whenever the stride or channel count changes, and the
    identity otherwise.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        stride: Spatial down-sampling factor of the block.
        negative_slope: Slope of both LeakyReLU activations.
    """

    def __init__(self, in_channels, out_channels, stride=1, negative_slope=0.2):
        super(ResidualBlock, self).__init__()

        # For stride 2 the average pool does the down-sampling. Striding the
        # first conv as well would shrink the main path twice and make it
        # smaller than the shortcut, so the residual add would fail.
        uses_pool = stride == 2
        conv1_stride = 1 if uses_pool else stride

        # First sequence: BN -> LReLU -> Conv
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.lrelu1 = nn.LeakyReLU(negative_slope, inplace=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=conv1_stride, padding=1)

        # Second sequence: BN -> LReLU -> Conv
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.lrelu2 = nn.LeakyReLU(negative_slope, inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # AvgPool layer. ceil_mode makes odd sizes round up, matching the
        # output size of the strided 1x1 convolution on the shortcut.
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True) if uses_pool else nn.Identity()

        # Identity (skip connection) to match dimensions if stride > 1 or channel mismatch
        self.identity = nn.Identity()
        if stride != 1 or in_channels != out_channels:
            self.identity = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Return ``main_path(x) + shortcut(x)``."""
        identity = self.identity(x)  # Skip connection

        # Forward path
        out = self.bn1(x)
        out = self.lrelu1(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.lrelu2(out)
        out = self.conv2(out)

        # Apply AvgPool to the output
        out = self.avg_pool(out)

        # Add the skip connection (out-of-place so no tensor autograd may still
        # need is overwritten)
        out = out + identity
        return out

In [23]:
class Encoder(nn.Module):
    """Convolutional encoder mapping an RGB image to a short feature vector.

    A stride-2 convolution is followed by four stride-2 residual blocks; the
    resulting 256-channel map is pooled to 8x8, flattened and projected by a
    fully connected layer.

    Args:
        out_features: Length of the output vector (8 by default).
    """

    def __init__(self, out_features=8):
        super(Encoder, self).__init__()

        # Initial convolutional layer
        self.initial_conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),  # 256x256 input -> 128x128x64
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Residual blocks with downsampling (sizes quoted for a 256x256 input)
        self.residual_block1 = ResidualBlock(64, 128, stride=2)   # -> 64x64x128
        self.residual_block2 = ResidualBlock(128, 192, stride=2)  # -> 32x32x192
        self.residual_block3 = ResidualBlock(192, 256, stride=2)  # -> 16x16x256
        self.residual_block4 = ResidualBlock(256, 256, stride=2)  # -> 8x8x256

        # Pools whatever spatial size arrives to the 8x8 grid the FC layer was
        # sized for. It is an exact no-op when the map is already 8x8, and lets
        # the encoder accept inputs that are not 256x256.
        self.pool = nn.AdaptiveAvgPool2d((8, 8))

        # Fully connected layer to produce the output vector
        self.fc = nn.Linear(256 * 8 * 8, out_features)

    def forward(self, x):
        """Encode ``x`` (N, 3, H, W) into a tensor of shape (N, out_features)."""
        x = self.initial_conv(x)              # Initial convolution
        x = self.residual_block1(x)           # Residual block 1
        x = self.residual_block2(x)           # Residual block 2
        x = self.residual_block3(x)           # Residual block 3
        x = self.residual_block4(x)           # Residual block 4

        x = self.pool(x)                      # Fix the spatial size at 8x8
        x = torch.flatten(x, start_dim=1)     # Flatten to (N, 256*8*8)
        enc = self.fc(x)                      # Project to (N, out_features)
        return enc

### Discriminator (D) Structure

In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import spectral_norm

# Attention mechanism (from RLOC Discriminator)
class AttentionModule(nn.Module):
    """Element-wise gating: scales the input by a sigmoid mask computed from it.

    The mask comes from a 1x1 convolution that keeps the channel count,
    followed by a sigmoid, so every gate value lies in (0, 1).

    Args:
        in_channels: Channels of the feature map to gate.
    """

    def __init__(self, in_channels):
        super().__init__()
        gate_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.attention = nn.Sequential(gate_conv, nn.Sigmoid())

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its attention mask."""
        gate = self.attention(x)
        return x * gate

# RLOC Discriminator
class Discriminator(nn.Module):
    """RLOC discriminator: three spectrally-normalised convs, attention, linear head.

    Each convolution is 3x3 with stride 2 and is followed by LeakyReLU(0.2).
    The linear head is sized for a 32x32x256 feature map, i.e. for 256x256
    RGB inputs. The output is a probability (sigmoid already applied).
    """

    def __init__(self):
        super().__init__()
        # Spatial sizes quoted for a 256x256x3 input
        self.conv1 = spectral_norm(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
        )  # -> 128x128x64
        self.conv2 = spectral_norm(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        )  # -> 64x64x128
        self.conv3 = spectral_norm(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        )  # -> 32x32x256
        self.attention = AttentionModule(256)  # Gates the 256 feature maps
        self.fc = nn.Linear(256 * 32 * 32, 1)  # Flattened 32x32x256 -> 1 score

    def forward(self, x):
        """Return a (N, 1) tensor of real/fake probabilities for images ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.leaky_relu(conv(x), 0.2)

        # Attention-modulated features
        attended = self.attention(x)

        # Flatten to (batch_size, 256*32*32) for the linear head
        flat = attended.view(attended.size(0), -1)

        # Sigmoid is applied here, so callers receive probabilities, not logits
        return torch.sigmoid(self.fc(flat))

### ID-Aware Loss

In [25]:
def id_aware_loss(DMB, B0, B00):
    """Identity-consistency loss between two image batches.

    Both batches are passed through the feature extractor ``DMB`` (``B0``
    first, then ``B00``) and compared with cosine similarity along dim 1.

    Args:
        DMB: Callable mapping an image batch to a feature tensor.
        B0: First image batch.
        B00: Second image batch.

    Returns:
        ``1 - mean cosine similarity``; 0 when all feature pairs are aligned.
    """
    features = [DMB(batch) for batch in (B0, B00)]
    similarity = F.cosine_similarity(features[0], features[1], dim=1)
    return 1 - similarity.mean()

### Total Loss

In [26]:
def total_loss(B, B0, mean, log_var, DMB, adversarial_loss, B00=None):
    """Weighted sum of the L1, adversarial, KL and ID-aware loss terms.

    Args:
        B: Image batch compared against ``B0`` by the L1 and adversarial terms.
        B0: Image batch compared against ``B``.
        mean: Mean of the encoder's latent distribution.
        log_var: Log-variance of the encoder's latent distribution.
        DMB: Feature extractor handed to ``id_aware_loss``.
        adversarial_loss: Callable ``(B, B0) -> scalar tensor``.
        B00: Optional second batch for the ID-aware term. When omitted, ``B``
            is used, so the term compares the features of ``B0`` and ``B``.

    Returns:
        Scalar tensor
        ``10 * L1 + 1 * adversarial + 0.01 * KL + 5 * ID``.
    """
    lambda_1, lambda_2, lambda_kl, lambda_id = 10.0, 1.0, 0.01, 5.0
    # L1 loss to ensure numerical similarity between B and B0
    l1_loss = F.l1_loss(B, B0)
    # KL divergence loss for regularizing the encoder to Gaussian noise
    # (summed over every element, so it grows with the batch size)
    kl_div = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    # Adversarial loss using GAN for realistic image generation
    adv_loss = adversarial_loss(B, B0)
    # ID-aware loss to enforce identity consistency in generated images.
    # id_aware_loss needs two image batches; the previous two-argument call
    # raised a TypeError.
    id_reference = B if B00 is None else B00
    id_loss_value = id_aware_loss(DMB, B0, id_reference)

    # Total loss
    return lambda_1 * l1_loss + lambda_2 * adv_loss + lambda_kl * kl_div + lambda_id * id_loss_value

### Training the Model

In [None]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F

# Define the custom dataset class
class IITDPalmprintDataset(Dataset):
    """Dataset over the ``.bmp`` images in ``<root_dir>/segmented/{left,right}``.

    Args:
        root_dir: Dataset root; only its ``segmented`` sub-folder is read.
        transform: Optional callable applied to every loaded RGB image.

    Attributes are ``root_dir`` (the ``segmented`` folder), ``transform`` and
    ``image_paths`` (file paths, ``left`` before ``right``, sorted by name
    within each folder). Missing sub-folders are skipped silently, so a wrong
    ``root_dir`` yields an empty dataset.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = os.path.join(root_dir, 'segmented')  # Use only 'segmented' folder
        self.transform = transform
        self.image_paths = []

        # Collect image paths from 'segmented' folder (left and right subfolders)
        for subfolder in ['left', 'right']:
            folder_path = os.path.join(self.root_dir, subfolder)
            if not os.path.isdir(folder_path):
                continue
            # os.listdir order is arbitrary; sorting keeps sample indices
            # stable across runs and machines.
            for img_file in sorted(os.listdir(folder_path)):
                # Case-insensitive so '.BMP' files are not dropped
                if img_file.lower().endswith('.bmp'):
                    self.image_paths.append(os.path.join(folder_path, img_file))

    def __len__(self):
        """Number of images found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` and return it twice as ``(real_palmprint, palm_creases)``."""
        img_path = self.image_paths[idx]
        # The context manager closes the file handle as soon as the pixel data
        # has been converted, instead of waiting for garbage collection.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Return the same image twice for training purposes
        return image, image  # (real_palmprint, palm_creases)

# Define transformations
transform = transforms.Compose([
    transforms.Resize((256,256)),  # Resize to match model input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Normalize to [-1, 1] range
])

# Hyperparameters (declared first so the loader and the models below use them
# instead of repeating literals)
num_epochs = 60
batch_size = 16
latent_dim = 100
learning_rate_G = 0.0002
learning_rate_D = 0.0001
L1_weight = 10.0
LD_weight = 0.1
LKL_weight = 0.01
LID_weight = 5.0

# Load dataset
dataset_path = "IITD Palmprint V1"  # Replace with actual path to dataset
dataset = IITDPalmprintDataset(root_dir=dataset_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Instantiate the models on the available device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The generator must be built for the same noise length that is sampled during
# training; its default (128) does not match latent_dim.
gen = Generator(noise_dim=latent_dim).to(device)
enc = Encoder().to(device)
D = Discriminator().to(device)

# Optimizers with independent learning rates for G and D.
# The encoder is trained together with the generator, so its parameters are
# handed to the same optimizer; otherwise the encoder would never be updated.
optimizer_G = optim.Adam(
    list(gen.parameters()) + list(enc.parameters()),
    lr=learning_rate_G,
    betas=(0.5, 0.99),
)
optimizer_D = optim.Adam(D.parameters(), lr=learning_rate_D, betas=(0.5, 0.99))

# Define linear decay function for learning rate
def linear_decay(epoch, start_epoch=30, end_epoch=60, start_lr=learning_rate_G, end_lr=1e-8):
    """Learning-rate multiplier: constant, then a linear ramp down to a floor.

    Args:
        epoch: Current epoch index.
        start_epoch: First epoch at which the decay applies.
        end_epoch: Epoch at which the linear ramp reaches zero.
        start_lr: Base learning rate used to express the floor as a ratio.
        end_lr: Smallest absolute learning rate allowed for ``start_lr``.

    Returns:
        1.0 before ``start_epoch``; afterwards the linear factor, never lower
        than ``end_lr / start_lr``.
    """
    if epoch >= start_epoch:
        progress = (epoch - start_epoch) / (end_epoch - start_epoch)
        floor = end_lr / start_lr
        return max(1 - progress, floor)
    return 1.0

# Assign the linear decay scheduler to both optimizers
scheduler_G = optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=linear_decay)
scheduler_D = optim.lr_scheduler.LambdaLR(optimizer_D, lr_lambda=linear_decay)

# Loss functions
criterion_L1 = nn.L1Loss()
# The discriminator already ends in a sigmoid, so its outputs are probabilities.
# BCEWithLogitsLoss would apply a second sigmoid; plain BCE is the matching loss.
criterion_BCE = nn.BCELoss()
criterion_KL = lambda mean, log_var: -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
criterion_ID = nn.L1Loss()  # Placeholder for ID-aware loss

# Training loop
for epoch in range(num_epochs):
    # Make sure training-mode layers (BatchNorm) are active even if an
    # inference cell switched a model to eval() before this cell is re-run.
    gen.train()
    enc.train()
    D.train()

    for i, (real_palmprint, palm_creases) in enumerate(dataloader):
        real_palmprint = real_palmprint.to(device)
        palm_creases = palm_creases.to(device)
        current_batch = real_palmprint.size(0)

        # =================== Train Discriminator ===================
        optimizer_D.zero_grad()

        real_labels = torch.ones((current_batch, 1), device=device)
        fake_labels = torch.zeros((current_batch, 1), device=device)

        # Real images
        real_output = D(real_palmprint)
        d_real_loss = criterion_BCE(real_output, real_labels)

        # Fake images. Generator.forward takes (creases, noise); no gradients
        # are needed for the generator while the discriminator is updated.
        latent_noise = torch.randn(current_batch, latent_dim, device=device)
        with torch.no_grad():
            fake_palmprint = gen(palm_creases, latent_noise)
        fake_output = D(fake_palmprint)
        d_fake_loss = criterion_BCE(fake_output, fake_labels)

        # Total Discriminator loss
        d_loss = (d_real_loss + d_fake_loss) * LD_weight
        d_loss.backward()
        optimizer_D.step()

        # =================== Train Generator and Encoder ===================
        optimizer_G.zero_grad()

        # Generate fake palmprint
        latent_noise = torch.randn(current_batch, latent_dim, device=device)
        encoded_noise = enc(real_palmprint)  # Encode real palmprint to a latent vector
        fake_palmprint = gen(palm_creases, latent_noise)  # Generate palmprint from creases + noise

        # Compute various losses
        L1_loss = L1_weight * criterion_L1(fake_palmprint, real_palmprint)
        adv_loss = LD_weight * criterion_BCE(D(fake_palmprint), real_labels)  # Encourage gen to fool D

        # The encoder returns one tensor, not an object with `mean`/`log_var`
        # fields. Its columns are split in half here to obtain the two
        # statistics the KL term needs.
        # NOTE(review): this half/half interpretation of the encoder output,
        # and the fact that the encoding is not fed into the generator, should
        # be confirmed against the intended design.
        enc_mean, enc_log_var = torch.chunk(encoded_noise, 2, dim=1)
        KL_loss = LKL_weight * criterion_KL(enc_mean, enc_log_var)

        # ID-aware loss: Ensure ID consistency for generated samples
        second_noise = torch.randn(current_batch, latent_dim, device=device)
        fake_palmprint_2 = gen(palm_creases, second_noise)
        ID_loss = LID_weight * criterion_ID(fake_palmprint, fake_palmprint_2)

        # Total generator loss
        total_loss_G = L1_loss + adv_loss + KL_loss + ID_loss
        total_loss_G.backward()
        optimizer_G.step()

        if i % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{i}], D Loss: {d_loss.item():.4f}, G Loss: {total_loss_G.item():.4f}, "
                  f"L1: {L1_loss.item():.4f}, Adv: {adv_loss.item():.4f}, KL: {KL_loss.item():.4f}, ID: {ID_loss.item():.4f}")

    # Update learning rates
    scheduler_G.step()
    scheduler_D.step()

print("Training complete.")

In [None]:
def generate_pseudo_palmprint(synthetic_palm_creases, batch_size=None):
    """Generate pseudo palmprints for a batch of crease images.

    Uses the module-level generator ``gen`` and ``latent_dim``. The generator
    is switched to evaluation mode and left in that mode.

    Args:
        synthetic_palm_creases: Crease image batch fed to the generator.
        batch_size: Number of noise vectors to sample. Defaults to the batch
            size of ``synthetic_palm_creases``, which is what the generator
            needs to pair one noise vector with every image.

    Returns:
        The generator output for the batch, computed without gradients.
    """
    if batch_size is None:
        batch_size = synthetic_palm_creases.size(0)

    gen.eval()  # Set the generator to evaluation mode
    with torch.no_grad():
        # Sample the noise on the same device as the input images
        latent_noise = torch.randn(batch_size, latent_dim, device=synthetic_palm_creases.device)
        # Generator.forward takes (creases, noise), in that order
        generated_palmprint = gen(synthetic_palm_creases, latent_noise)
    return generated_palmprint

# Example usage: generate one batch of pseudo palmprints per batch of creases.
# NOTE(review): `synthetic_creases_loader` is not defined in the cells shown
# here; it must be created (e.g. as a DataLoader yielding crease image tensors)
# before this cell is run, otherwise the loop raises a NameError.
for synthetic_palm_creases in synthetic_creases_loader:
    # Move the batch to the device the generator lives on
    synthetic_palm_creases = synthetic_palm_creases.to(device)
    # Sample exactly one noise vector per crease image in the batch
    pseudo_palmprints = generate_pseudo_palmprint(synthetic_palm_creases, batch_size=synthetic_palm_creases.size(0))
    
    # Save or further process `pseudo_palmprints` as needed