In [1]:
import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import os
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

class TripletDataset(Dataset):
    """Triplets of (anchor, positive, negative) images drawn from two folders.

    Both folders are listed (sorted, regular files only) and each list is split 80/20 into
    train/test partitions with a fixed seed. Even indices yield a bona fide anchor/positive
    with a random morphed negative; odd indices yield a morphed anchor/positive with a random
    bona fide negative. Each item is ``(anchor, positive, negative, anchor_name)`` where
    ``anchor_name`` is the anchor's file name.

    Args:
        bonafide_dir: folder of bona fide images (only regular files directly inside are used).
        morphed_dir: folder of morphed images (only regular files directly inside are used).
        transform: optional callable applied to every loaded RGB ``PIL.Image``.
        split: ``'train'`` (default) or ``'test'`` -- which partition ``__len__`` and
            ``__getitem__`` sample from.

    Raises:
        ValueError: if ``split`` is not ``'train'`` or ``'test'``.
    """

    def __init__(self, bonafide_dir, morphed_dir, transform=None, split='train'):
        if split not in ('train', 'test'):
            raise ValueError(f"split must be 'train' or 'test', got {split!r}")
        self.bonafide_dir = bonafide_dir
        self.morphed_dir = morphed_dir
        self.transform = transform
        self.split = split

        # Sorted so the seeded split below is reproducible: os.listdir order is arbitrary.
        self.bonafide_files = self._list_files(bonafide_dir)
        self.morphed_files = self._list_files(morphed_dir)

        # Kept for backward compatibility; the sampler does not use them.
        self.bonafide_indices = list(range(len(self.bonafide_files)))
        self.morphed_indices = list(range(len(self.morphed_files)))

        # Split data into train and test (80/20, fixed seed).
        self.train_bonafide_files, self.test_bonafide_files = train_test_split(
            self.bonafide_files, test_size=0.2, random_state=42
        )
        self.train_morphed_files, self.test_morphed_files = train_test_split(
            self.morphed_files, test_size=0.2, random_state=42
        )

    @staticmethod
    def _list_files(directory):
        """Return the sorted full paths of the regular files directly inside ``directory``."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    def _active_files(self):
        """Return ``(bonafide_files, morphed_files)`` for the configured split."""
        # Looked up on every call (not cached) so reassigning the partition lists takes effect.
        if getattr(self, 'split', 'train') == 'test':
            return self.test_bonafide_files, self.test_morphed_files
        return self.train_bonafide_files, self.train_morphed_files

    def __len__(self):
        # Two triplets (one bona fide-anchored, one morphed-anchored) per bona fide file.
        return len(self._active_files()[0]) * 2

    def __getitem__(self, idx):
        bonafide_files, morphed_files = self._active_files()
        if idx % 2 == 0:
            # Anchor and positive: both bona fide; negative: morphed.
            same_class, other_class = bonafide_files, morphed_files
        else:
            # Anchor and positive: both morphed; negative: bona fide.
            same_class, other_class = morphed_files, bonafide_files

        # idx // 2 advances one file per pair of indices. Using idx % n directly meant that
        # even idx only hit even positions (and odd idx only odd ones) whenever n was even.
        anchor_idx = (idx // 2) % len(same_class)
        # Positive is the next file of the same class (the anchor itself if the class has one file).
        positive_idx = (anchor_idx + 1) % len(same_class)
        anchor_path = same_class[anchor_idx]
        positive_path = same_class[positive_idx]
        negative_path = np.random.choice(other_class)

        # Load images
        anchor = self.load_image(anchor_path)
        anchor_name = os.path.basename(anchor_path)
        positive = self.load_image(positive_path)
        negative = self.load_image(negative_path)

        return anchor, positive, negative, anchor_name

    def load_image(self, path):
        """Open ``path`` as an RGB image and apply ``self.transform`` if one is set."""
        # The context manager closes the underlying file; convert() returns a fully loaded copy.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SiameseNetwork(nn.Module):
    """Shared-weight embedding network applied to anchor, positive and negative inputs.

    A conv/ReLU/max-pool trunk (``self.inception``) followed by an MLP head (``self.fc``)
    maps each input to a 128-dimensional embedding. The first conv takes 1024 input channels
    and the head takes a flattened 256 x 3 x 3 map, so inputs must have 1024 channels and a
    spatial size that two 2x2 max-pools reduce to 3x3 (i.e. 12-15 pixels per side).
    NOTE(review): presumably these are pre-extracted feature maps -- confirm; 3-channel RGB
    tensors will not fit the first conv.
    """

    def __init__(self):
        super().__init__()
        # Layer order fixes the Sequential indices (state_dict keys), so saved weights stay loadable.
        trunk_layers = [
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
        ]
        self.inception = nn.Sequential(*trunk_layers)
        self.fc = nn.Sequential(
            nn.Linear(256 * 3 * 3, 1024),
            nn.ReLU(),
            nn.Linear(1024, 128),
        )

    def forward_one(self, x):
        """Embed one batch: trunk, then head."""
        return self.fc(self.inception(x))

    def forward(self, input1, input2, input3):
        """Embed all three inputs with the same weights; return the embeddings in input order."""
        embeddings = tuple(self.forward_one(batch) for batch in (input1, input2, input3))
        return embeddings

class TripletLoss(nn.Module):
    """Hinge triplet loss on Euclidean embedding distances, averaged over the batch.

    Per sample the loss is ``max(0, d(a, p) - d(a, n) + margin)``, where ``d`` is
    ``F.pairwise_distance`` with p=2 (and its default eps).

    Args:
        margin: required gap between the anchor-negative and anchor-positive distances.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        gap = (
            F.pairwise_distance(anchor, positive, 2)
            - F.pairwise_distance(anchor, negative, 2)
            + self.margin
        )
        return F.relu(gap).mean()


In [4]:
# Image preprocessing: fixed 224x224 resize, conversion to a [0, 1] float tensor, then
# per-channel normalisation with the commonly used ImageNet mean/std values.
# NOTE(review): this yields 3-channel tensors, while SiameseNetwork's first conv expects
# 1024 input channels -- confirm which input the model is meant to consume.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
import torch.optim as optim
import json
from torch.utils.data import DataLoader
# Per-anchor training losses keyed by anchor file name (filled in by train_model()).
loss_dict = {}


def train_model(
    bonafide_dir='path_to_CASIA_Webface',
    morphed_dir='path_to_FRLL',
    transform=None,
    num_epochs=10,
    batch_size=16,
    lr=0.001,
    margin=1.0,
    model_path='path_to_trained_model.pth',
):
    """Train SiameseNetwork with TripletLoss on triplets sampled by TripletDataset.

    Each sample's own triplet loss is stored in the module-level ``loss_dict`` under its
    anchor file name; later epochs overwrite earlier values. The mean batch loss per epoch
    is printed.

    Args:
        bonafide_dir: folder of bona fide images passed to TripletDataset.
        morphed_dir: folder of morphed images passed to TripletDataset.
        transform: optional per-image transform passed to TripletDataset.
        num_epochs: number of passes over the dataset.
        batch_size: DataLoader batch size.
        lr: Adam learning rate.
        margin: TripletLoss margin.
        model_path: where to save the trained ``state_dict``; ``None`` skips saving.

    Returns:
        The trained model.
    """
    import torch.nn.functional as F  # used for the per-sample losses below

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = SiameseNetwork().to(device)
    criterion = TripletLoss(margin=margin)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # NOTE(review): with transform=None the dataset yields PIL images, which DataLoader's
    # default collate cannot batch; the 3-channel image transform would not match the
    # model's 1024-channel first conv either. Confirm the intended input format.
    dataset = TripletDataset(
        bonafide_dir=bonafide_dir,
        morphed_dir=morphed_dir,
        transform=transform,
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for anchor, positive, negative, anchor_name in dataloader:
            anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)

            optimizer.zero_grad()
            output1, output2, output3 = model(anchor, positive, negative)
            loss = criterion(output1, output2, output3)

            # anchor_name is a batch of names and cannot be a dict key. Record each
            # sample's own hinge loss, using the same formula as TripletLoss.
            with torch.no_grad():
                per_sample = F.relu(
                    F.pairwise_distance(output1, output2, 2)
                    - F.pairwise_distance(output1, output3, 2)
                    + criterion.margin
                )
            for name, value in zip(anchor_name, per_sample.tolist()):
                loss_dict[name] = value

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # max(..., 1) avoids ZeroDivisionError on an empty dataset.
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/max(len(dataloader), 1)}")

    # The evaluation cell reloads the weights from this path.
    if model_path is not None:
        torch.save(model.state_dict(), model_path)
    return model


if __name__ == "__main__":
    train_model()

# Save the collected per-anchor losses (this cell produces loss_data.json).
with open('loss_data.json', 'w') as f:
    json.dump(loss_dict, f, indent=2)

# Testing

In [None]:
import json

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# A concrete device; `from torch import device` bound the torch.device *class*, which
# Module.to() cannot move a model to.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained model (architecture must match the one used in training).
model = SiameseNetwork().to(device)
# map_location lets weights saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load('path_to_trained_model.pth', map_location=device))
model.eval()  # Set the model to evaluation mode

# Per-anchor test losses keyed by anchor file name (filled in by test_model()).
loss_dict = {}

# Define your loss function
criterion = TripletLoss(margin=1.0)

# Use the existing TripletDataset class for testing
test_dataset = TripletDataset(
    bonafide_dir='path_to_CASIA_Webface_features',
    morphed_dir='path_to_FRLL_features'
)
# TripletDataset samples from its 80% training partition; point it at the held-out 20% so
# evaluation does not reuse training files.
# NOTE(review): if these folders are disjoint from the training folders, evaluating on all
# files may be preferable -- confirm.
test_dataset.train_bonafide_files = test_dataset.test_bonafide_files
test_dataset.train_morphed_files = test_dataset.test_morphed_files

test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)


def test_model(model, data_loader, criterion, device):
    """Evaluate ``model`` on ``data_loader`` and return the mean batch triplet loss.

    Each sample's own triplet loss is stored in the module-level ``loss_dict`` under its
    anchor file name. The average loss is printed and returned.
    """
    model.eval()  # Ensure model is in evaluation mode
    total_loss = 0.0

    with torch.no_grad():
        for anchor, positive, negative, anchor_name in data_loader:
            anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)

            # Forward pass
            output1, output2, output3 = model(anchor, positive, negative)
            loss = criterion(output1, output2, output3)
            total_loss += loss.item()

            # anchor_name is a batch of names and cannot be a dict key; record each
            # sample's own hinge loss, using the same formula as TripletLoss.
            per_sample = F.relu(
                F.pairwise_distance(output1, output2, 2)
                - F.pairwise_distance(output1, output3, 2)
                + criterion.margin
            )
            for name, value in zip(anchor_name, per_sample.tolist()):
                loss_dict[name] = value

    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    average_loss = total_loss / max(len(data_loader), 1)
    print(f'Average Test Loss: {average_loss}')
    return average_loss


# TripletDataset draws negatives with np.random; seed it so the test loss is repeatable.
np.random.seed(42)
test_model(model, test_loader, criterion, device)

# Save the collected per-anchor test losses (this cell produces testing_loss_data.json).
with open('testing_loss_data.json', 'w') as f:
    json.dump(loss_dict, f, indent=2)
