In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet50

from sklearn.metrics import accuracy_score, precision_score, recall_score

In [None]:
# Input directories (Kaggle dataset paths) for the three image sets used to
# build pairs: anchors, positives and negatives.
ANC_PATH = '/kaggle/input/lfw-positive-negative/content/data/anchor/face'
POS_PATH = '/kaggle/input/lfw-positive-negative/content/data/positive/face'
NEG_PATH = '/kaggle/input/lfw-positive-negative/content/data/negative/face'

In [None]:
def _list_jpg_files(directory):
    """Return the full paths of the ``.jpg`` files in ``directory``, sorted by name.

    ``os.listdir`` returns entries in arbitrary order, so the listing is sorted
    to make any later slicing and index-wise pairing of the three lists
    reproducible across runs and machines.
    """
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith('.jpg')
    )


anchor_files = _list_jpg_files(ANC_PATH)
positive_files = _list_jpg_files(POS_PATH)
negative_files = _list_jpg_files(NEG_PATH)

In [None]:
# Cap every image list at its first 3000 entries (shorter lists are kept whole).
anchor_files, positive_files, negative_files = [
    file_list[:3000]
    for file_list in (anchor_files, positive_files, negative_files)
]

In [None]:
# Shared image pipeline: resize to 224x224, then convert the PIL image to a
# tensor. No normalization step is applied here.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

In [None]:
def preprocess(file_path):
    """Load the image at ``file_path`` as RGB and run it through ``transform``.

    Args:
        file_path: Path of the image file to open.

    Returns:
        Whatever the module-level ``transform`` pipeline produces for the
        RGB-converted image.
    """
    return transform(Image.open(file_path).convert("RGB"))

In [None]:
# Visual sanity check of the preprocessing on one hard-coded anchor image.
img = preprocess('/kaggle/input/lfw-positive-negative/content/data/anchor/face/13eaa162-ba39-11ef-b1cd-0242ac1c000c.jpg')
plt.imshow(img.permute(1, 2, 0))  # Convert from (C, H, W) to (H, W, C) for visualization
plt.axis('off')
plt.show()

In [None]:
class SiameseDataset(Dataset):
    """Image pairs labelled 1 (anchor/positive) or 0 (anchor/negative).

    With ``n`` usable triplets, indices ``0 .. n-1`` yield
    ``(anchor[i], positive[i], 1)`` and indices ``n .. 2n-1`` yield
    ``(anchor[i-n], negative[i-n], 0)``.

    Args:
        anchor_files: Paths of the anchor images.
        positive_files: Paths of the images paired with anchors as label 1.
        negative_files: Paths of the images paired with anchors as label 0.
    """

    def __init__(self, anchor_files, positive_files, negative_files):
        self.anchor_files = anchor_files
        self.positive_files = positive_files
        self.negative_files = negative_files
        # Pairs are formed index-wise, so only indices present in all three
        # lists are usable. Sizing by the anchors alone would raise IndexError
        # mid-epoch whenever another list is shorter.
        self.num_pairs = min(
            len(anchor_files), len(positive_files), len(negative_files)
        )
        # One label per item, in the same order __getitem__ serves them.
        self.labels = torch.cat([
            torch.ones(self.num_pairs),
            torch.zeros(self.num_pairs)
        ])

    def __len__(self):
        """Return the number of pairs: one positive and one negative per triplet."""
        return self.num_pairs * 2

    def __getitem__(self, idx):
        """Return ``(anchor_tensor, paired_tensor, label)`` for item ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            # Python-style negative indexing; previously a negative index was
            # silently served as a positive pair.
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"index {idx} is out of range for {total} pairs")

        if idx < self.num_pairs:
            paired_path = self.positive_files[idx]
            label = 1  # Positive pair
        else:
            idx -= self.num_pairs
            paired_path = self.negative_files[idx]
            label = 0  # Negative pair

        return preprocess(self.anchor_files[idx]), preprocess(paired_path), label

In [None]:
# Full pair dataset, plus a shuffled loader over all of it (batches of 32).
dataset = SiameseDataset(anchor_files, positive_files, negative_files)
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)

In [None]:
# Sanity check: print the tensor shapes of the first batch only, then stop.
for anchor, pair_img, label in dataloader:
    print(anchor.shape, pair_img.shape, label.shape)
    break

In [None]:
# Show the first dataset item as one side-by-side figure. Both subplots are
# drawn in the same cell: with the inline backend a figure is rendered and
# closed when its cell finishes, so drawing the second subplot in a later cell
# would land in a new, half-empty figure.
sample_anchor, sample_pair_img, sample_label = dataset[0]
plt.figure(figsize=(8, 4))

plt.subplot(1, 2, 1)
plt.imshow(sample_anchor.permute(1, 2, 0))  # (C, H, W) -> (H, W, C)
plt.title("Anchor")
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(sample_pair_img.permute(1, 2, 0))
plt.title("Positive" if sample_label == 1 else "Negative")
plt.axis('off')

plt.show()

In [None]:
# 70% / 15% / 15% split sizes.
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size  # remainder, so the three sizes sum to len(dataset)

In [None]:
# Seed the split so train/val/test membership is identical on every run;
# otherwise the reported test metrics are computed on a different subset each
# time the notebook is executed.
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
class VGGFaceEmbedding(nn.Module):
    """Image embedding built on a pre-trained torchvision ResNet-50.

    Despite the class name, the backbone is ``resnet50``. Its last two child
    modules are dropped (presumably the network's own pooling and classifier
    head - confirm against torchvision) and replaced by an adaptive average
    pool to 1x1 followed by a flatten, so ``forward`` returns one flat vector
    per image.
    """

    def __init__(self):
        super().__init__()
        # Stored for reference only: nothing in this class reads ``meta``, so
        # these mean/std values are NOT applied to the inputs.
        self.meta = {
            'mean': [129.186279296875, 104.76238250732422, 93.59396362304688],
            'std': [1, 1, 1],
            'imageSize': [224, 224, 3]
        }
        # Keep every child of the pre-trained network except the final two.
        backbone = resnet50(pretrained=True)
        feature_layers = list(backbone.children())[:-2]
        self.base_model = nn.Sequential(*feature_layers)
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Run ``x`` through the backbone, pool to 1x1 and flatten."""
        return self.flatten(self.pooling(self.base_model(x)))

In [None]:
class L1Dist(nn.Module):
    """Element-wise absolute difference between two embeddings."""

    def __init__(self):
        super().__init__()

    def forward(self, input_embedding, validation_embedding):
        """Return ``|input_embedding - validation_embedding|`` element-wise."""
        difference = input_embedding - validation_embedding
        return difference.abs()

In [None]:
def make_siamese_model():
    """Build and return a fresh Siamese network.

    The network embeds both images with one shared ``VGGFaceEmbedding``, takes
    the element-wise ``L1Dist`` of the two embeddings and maps it through two
    linear layers and a sigmoid, so the output is already a value in (0, 1)
    with one column per pair.
    """

    class SiameseNetwork(nn.Module):
        def __init__(self):
            super().__init__()
            self.embedding = VGGFaceEmbedding()
            self.distance = L1Dist()
            # 2048 is the assumed width of the embedding output.
            self.fc1 = nn.Linear(2048, 512)
            self.fc2 = nn.Linear(512, 1)
            self.sigmoid = nn.Sigmoid()

        def forward(self, input_image, validation_image):
            # The same embedding module (shared weights) encodes both images.
            embedded_input = self.embedding(input_image)
            embedded_validation = self.embedding(validation_image)
            # NOTE: there is no activation between fc1 and fc2.
            hidden = self.fc1(self.distance(embedded_input, embedded_validation))
            return self.sigmoid(self.fc2(hidden))

    network = SiameseNetwork()
    return network

In [None]:
# Instantiate the network and print its module tree for inspection.
siamese_model = make_siamese_model()
print(siamese_model)

In [None]:
# Training configuration: BCE loss on the model's sigmoid output, Adam over all
# model parameters, and GPU when one is available.
binary_cross_entropy_loss = nn.BCELoss()
optimizer = optim.Adam(siamese_model.parameters(), lr=1e-4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_step(batch, model, loss_fn, optimizer):
    """Run one optimization step on a single batch and return its loss.

    Args:
        batch: ``(input_image, validation_image, labels)`` as yielded by the
            data loader.
        model: Network called as ``model(input_image, validation_image)``.
        loss_fn: Loss called as ``loss_fn(predictions, labels)``.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        The batch loss as a Python float.
    """
    model.train()
    input_image, validation_image, labels = batch

    # Send the data to wherever the model lives instead of relying on a global
    # device variable; fall back to CPU for a parameter-free model.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")
    input_image = input_image.to(target_device)
    validation_image = validation_image.to(target_device)
    # Labels become a flat float vector, as the loss expects.
    labels = labels.float().to(target_device).reshape(-1)

    # Forward pass. Flatten to one value per pair with reshape(-1): a bare
    # squeeze() turns a batch of one into a 0-d tensor, whose shape no longer
    # matches the labels and makes the loss raise.
    predictions = model(input_image, validation_image).reshape(-1)
    loss = loss_fn(predictions, labels)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

def validate(data_loader, model, loss_fn):
    """Evaluate ``model`` on ``data_loader`` without updating weights.

    Args:
        data_loader: Iterable of ``(input_image, validation_image, labels)``
            batches.
        model: Network called as ``model(input_image, validation_image)``;
            its outputs are thresholded at 0.5 to compute accuracy.
        loss_fn: Loss called as ``loss_fn(predictions, labels)``.

    Returns:
        ``(mean_batch_loss, accuracy)`` where the loss is averaged over
        batches and accuracy is the fraction of correctly classified pairs.

    Raises:
        ValueError: If ``data_loader`` yields no batches (previously this
            surfaced as an opaque ZeroDivisionError).
    """
    model.eval()

    # Send the data to wherever the model lives instead of relying on a global
    # device variable; fall back to CPU for a parameter-free model.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for input_image, validation_image, labels in data_loader:
            input_image = input_image.to(target_device)
            validation_image = validation_image.to(target_device)
            labels = labels.float().to(target_device).reshape(-1)

            # reshape(-1) keeps a batch of one as shape (1,); a bare squeeze()
            # would make it 0-d and mismatch the labels.
            predictions = model(input_image, validation_image).reshape(-1)
            val_loss += loss_fn(predictions, labels).item()

            predicted = (predictions > 0.5).float()
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            num_batches += 1

    if num_batches == 0:
        raise ValueError("validate() received a data_loader with no batches")

    accuracy = correct / total
    return val_loss / num_batches, accuracy

def _plot_training_curves(train_losses, val_losses, val_accuracies):
    """Plot train/validation loss and validation accuracy, one point per epoch."""
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(range(1, len(train_losses) + 1), train_losses, label="Training Loss")
    plt.plot(range(1, len(val_losses) + 1), val_losses, label="Validation Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Loss vs Epochs")

    plt.subplot(1, 2, 2)
    plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label="Validation Accuracy")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.title("Validation Accuracy vs Epochs")
    plt.show()


def train(data_loader, val_loader, model, loss_fn, optimizer, epochs, patience=5):
    """Train ``model`` with early stopping on validation loss.

    Each epoch runs ``train_step`` over ``data_loader`` and ``validate`` over
    ``val_loader``. Whenever the validation loss improves, the weights are
    written to ``best_model.pth``; after ``patience`` consecutive epochs
    without improvement training stops. On normal completion the loss and
    accuracy curves are plotted.

    Args:
        data_loader: Training batches.
        val_loader: Validation batches.
        model: Network to train; moved to the global ``device``.
        loss_fn: Loss function passed to ``train_step`` / ``validate``.
        optimizer: Optimizer passed to ``train_step``.
        epochs: Maximum number of epochs.
        patience: Consecutive non-improving epochs tolerated before stopping.

    Returns:
        ``(train_losses, val_losses, val_accuracies)``, one entry per completed
        epoch. If an exception occurs, it is printed and the metrics gathered
        so far (possibly empty lists) are returned instead of raising.
    """
    # Created before the try block so the except handler can always return
    # them; if they lived inside it, a failure before their assignment (for
    # example while moving the model to the device) would turn into an
    # UnboundLocalError inside the handler.
    train_losses = []
    val_losses = []
    val_accuracies = []

    try:
        print(f"Training on {device}")
        model = model.to(device)

        best_val_loss = float('inf')
        epochs_no_improve = 0

        for epoch in range(epochs):
            print(f"\nEpoch {epoch + 1}/{epochs}")
            epoch_loss = 0.0
            model.train()

            for batch_idx, batch in enumerate(data_loader):
                loss = train_step(batch, model, loss_fn, optimizer)
                epoch_loss += loss
                if (batch_idx + 1) % 10 == 0:
                    print(f"Batch {batch_idx + 1}/{len(data_loader)}, Loss: {loss:.4f}")

            train_loss = epoch_loss / len(data_loader)
            val_loss, val_accuracy = validate(val_loader, model, loss_fn)

            train_losses.append(train_loss)
            val_losses.append(val_loss)
            val_accuracies.append(val_accuracy)

            print(f"Epoch {epoch + 1} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

            # Early stopping logic
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_no_improve = 0
                # Keep a checkpoint of the best weights seen so far.
                torch.save(model.state_dict(), 'best_model.pth')
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    # Stop right away so the message is also shown when
                    # patience runs out on the final epoch.
                    print("Early stopping triggered.")
                    break

        _plot_training_curves(train_losses, val_losses, val_accuracies)

    except Exception as e:
        # Best-effort: report the failure and fall through to return whatever
        # was collected before it.
        print(f"Error during training: {str(e)}")

    return train_losses, val_losses, val_accuracies

In [None]:
# Train for at most EPOCHS epochs (train() may stop earlier; patience keeps its default).
EPOCHS = 20
train_losses, val_losses, val_accuracies = train(
    train_loader, 
    val_loader, 
    siamese_model, 
    binary_cross_entropy_loss, 
    optimizer, 
    EPOCHS,
)

In [None]:
# Move the model to CPU before saving its weights.
# NOTE(review): nn.Module.to() presumably moves the module in place and returns
# the same object, so model_cpu and siamese_model would be one model, not a
# copy - confirm.
model_cpu = siamese_model.to("cpu")

In [None]:
# Save the weights (state_dict only, not the full module) to model_cpu.pth.
torch.save(model_cpu.state_dict(), "model_cpu.pth")

In [None]:
# Save the model's state_dict under a second file name; this file is the one
# reloaded before evaluation.
torch.save(siamese_model.state_dict(), 'siamese_model_weights.pth')
print("Model state_dict saved to siamese_model_weights.pth")

In [None]:
# Reload the saved weights into the existing model and switch to eval mode.
siamese_model.load_state_dict(torch.load('siamese_model_weights.pth'))
siamese_model.eval()

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import numpy as np

# Evaluate the trained model on the held-out test split.
all_y_true = []
all_y_pred = []
all_y_prob = []  # Store probabilities for ROC-AUC
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
siamese_model.to(device)

siamese_model.eval()  # Set model to evaluation mode
with torch.no_grad():
    for anchor_img, paired_img, y_true in test_loader:
        anchor_img = anchor_img.to(device)
        paired_img = paired_img.to(device)

        # The network already ends in a sigmoid, so its output IS the
        # probability. Applying torch.sigmoid a second time squashes every
        # score into (0.5, 0.73), which makes `prob > 0.5` always true and
        # labels every pair as positive.
        y_hat = siamese_model(anchor_img, paired_img)
        # reshape(-1) keeps a final batch of one iterable (squeeze() would
        # produce a 0-d array).
        y_prob = y_hat.reshape(-1).cpu().numpy()  # Probabilities
        y_pred = [1.0 if prob > 0.5 else 0.0 for prob in y_prob]  # Binary predictions

        all_y_true.extend(y_true.cpu().numpy())
        all_y_pred.extend(y_pred)
        all_y_prob.extend(y_prob)

# Print predictions and true labels
print("Predictions:", all_y_pred)
print("True Labels:", all_y_true)

# Calculate metrics
accuracy = accuracy_score(all_y_true, all_y_pred)
precision = precision_score(all_y_true, all_y_pred)
recall = recall_score(all_y_true, all_y_pred)
f1 = f1_score(all_y_true, all_y_pred)
roc_auc = roc_auc_score(all_y_true, all_y_prob)

print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"ROC-AUC: {roc_auc:.4f}")

# Check class distribution
print("Class Distribution:", np.unique(all_y_true, return_counts=True))