In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from facenet_pytorch import InceptionResnetV1
from PIL import Image

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# FaceNet (InceptionResnetV1) with VGGFace2 weights ('casia-webface' is the other
# published checkpoint), switched to inference mode, moved to `device`, and with
# every parameter frozen. Gradients are re-enabled selectively later for fine-tuning.
model = (
    InceptionResnetV1(pretrained='vggface2')
    .eval()
    .to(device)
    .requires_grad_(False)
)


In [None]:
# Unfreeze the top of the network to fine-tune on game photos.
# Per the facenet_pytorch source, InceptionResnetV1 ends in
#   block8 -> avgpool_1a -> dropout -> last_linear -> last_bn
# and has no `conv2d_7b` layer (that name comes from Inception-ResNet-v2), so
# matching on it unfroze nothing beyond last_linear.
TRAINABLE_LAYER_PREFIXES = ('block8', 'last_linear', 'last_bn')

for name, param in model.named_parameters():
    if name.startswith(TRAINABLE_LAYER_PREFIXES):
        param.requires_grad = True

# Fail loudly if a prefix matches no parameter (e.g. a typo or a different backbone),
# instead of silently fine-tuning fewer layers than intended.
param_names = [name for name, _ in model.named_parameters()]
unmatched_prefixes = [
    prefix for prefix in TRAINABLE_LAYER_PREFIXES
    if not any(name.startswith(prefix) for name in param_names)
]
assert not unmatched_prefixes, f'No parameters matched layer prefixes {unmatched_prefixes}; check the layer names'


In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Preprocessing shared by training and inference: resize to FaceNet's 160x160
# input, convert to a tensor in [0, 1], then rescale each channel to [-1, 1].
# TODO(review): the pretrained model was trained on detected/aligned face crops;
# confirm the game photos are already cropped to faces (e.g. via MTCNN).
transform = transforms.Compose(
    [
        transforms.Resize(size=(160, 160)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# One sub-directory per identity; ImageFolder turns each directory into a class label.
game_dataset = datasets.ImageFolder(root='path/to/game_photos', transform=transform)
game_loader = DataLoader(dataset=game_dataset, batch_size=32, shuffle=True)


In [None]:
class TripletLoss(nn.Module):
    """Hinge triplet loss on *squared* Euclidean distances.

    For each row i it computes
        max(0, ||a_i - p_i||^2 - ||a_i - n_i||^2 + margin)
    and returns the mean over all rows.

    Args:
        margin: minimum gap required between the anchor-negative and
            anchor-positive squared distances.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        # Squared L2 distances (no square root), reduced over dim 1.
        dist_anchor_pos = torch.sum((anchor - positive) ** 2, dim=1)
        dist_anchor_neg = torch.sum((anchor - negative) ** 2, dim=1)
        hinge = torch.relu(dist_anchor_pos - dist_anchor_neg + self.margin)
        return hinge.mean()


In [None]:
# Initialize loss and optimizer. Only parameters with requires_grad=True (the
# unfrozen top layers) are handed to Adam.
triplet_loss = TripletLoss(margin=1.0).to(device)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)


def _batch_hard_triplets(embeddings, labels):
    """Select the hardest positive and hardest negative for each anchor in a batch.

    Args:
        embeddings: (B, D) embedding tensor. Pass a detached tensor: the selection
            itself should not be differentiated.
        labels: (B,) tensor of integer identity labels (ImageFolder class indices).

    Returns:
        Tuple ``(anchor_idx, positive_idx, negative_idx)`` of equal-length 1-D index
        tensors. Only anchors with at least one other same-label sample and at least
        one different-label sample are included, so all three may be empty.
    """
    # Pairwise squared L2 distances (the metric TripletLoss uses), shape (B, B).
    sq_dist = (embeddings.unsqueeze(1) - embeddings.unsqueeze(0)).pow(2).sum(-1)

    same_label = labels.unsqueeze(0) == labels.unsqueeze(1)
    not_self = ~torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    positive_mask = same_label & not_self
    negative_mask = ~same_label

    # Hardest positive = farthest same-identity sample; hardest negative = closest
    # other-identity sample. Masked-out entries can never win argmax/argmin.
    positive_idx = sq_dist.masked_fill(~positive_mask, float('-inf')).argmax(dim=1)
    negative_idx = sq_dist.masked_fill(~negative_mask, float('inf')).argmin(dim=1)

    valid = positive_mask.any(dim=1) & negative_mask.any(dim=1)
    anchor_idx = valid.nonzero(as_tuple=True)[0]
    return anchor_idx, positive_idx[anchor_idx], negative_idx[anchor_idx]


def _keep_frozen_batchnorm_in_eval(net):
    """Switch BatchNorm layers that have no trainable parameters to eval mode.

    ``net.train()`` puts every BatchNorm layer into training mode, where it keeps
    updating its running mean/variance from each batch even when its weights are
    frozen. Calling this right after ``net.train()`` preserves the pretrained
    statistics of the frozen layers.
    """
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    for module in net.modules():
        if isinstance(module, batchnorm_types) and not any(
            p.requires_grad for p in module.parameters()
        ):
            module.eval()


# Training loop.
# NOTE(review): batch-hard mining needs several images of the same identity in a
# batch; with plain shuffle=True some batches may contain none and are skipped.
# Consider a P x K identity sampler for game_loader — TODO confirm against the data.
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    _keep_frozen_batchnorm_in_eval(model)
    total_loss = 0.0
    num_used_batches = 0

    for images, labels in game_loader:
        if len(images) < 2:
            # A single image cannot form a triplet, and a trainable BatchNorm in
            # train mode rejects a batch of one.
            continue
        images = images.to(device)
        labels = labels.to(device)

        # One forward pass for the whole batch, then mine label-consistent triplets
        # from it. Splitting the batch by position ignored identities, and with
        # batch_size=32 the 11/11/10 split made every full batch fail the size check.
        embeddings = model(images)
        anchor_idx, positive_idx, negative_idx = _batch_hard_triplets(embeddings.detach(), labels)
        if anchor_idx.numel() == 0:
            continue  # no identity appears twice alongside a different identity

        loss = triplet_loss(embeddings[anchor_idx], embeddings[positive_idx], embeddings[negative_idx])

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_used_batches += 1

    # Average only over batches that actually produced a loss.
    avg_loss = total_loss / num_used_batches if num_used_batches else float('nan')
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f} '
          f'(batches with triplets: {num_used_batches}/{len(game_loader)})')


In [None]:
# Persist only the learned weights; to reuse them, rebuild InceptionResnetV1 and call load_state_dict.
fine_tuned_weights = model.state_dict()
torch.save(fine_tuned_weights, 'fine_tuned_facenet.pth')
    

In [None]:
def get_embedding(model, image_path):
    """Return the model's embedding for a single image file.

    The image goes through the same `transform` used for training, runs as a batch
    of one on `device`, and is evaluated with gradients disabled.

    Args:
        model: the (fine-tuned) embedding network.
        image_path: path to an image file readable by PIL.

    Returns:
        numpy array of the model output for the one-image batch (leading dim 1).
    """
    with Image.open(image_path) as img:
        # ImageFolder converts training images to RGB; do the same here so that
        # grayscale / RGBA / palette files still yield 3-channel tensors.
        rgb_image = img.convert('RGB')
    batch = transform(rgb_image).unsqueeze(0).to(device)

    # The training cell leaves the model in train mode: dropout would make the
    # embedding random, and a train-mode BatchNorm1d fails on a batch of one.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            embedding = model(batch).cpu().numpy()
    finally:
        model.train(was_training)  # restore the caller's mode
    return embedding

# Example usage
studio_embedding = get_embedding(model, 'path/to/studio_photo.jpg')
game_embedding = get_embedding(model, 'path/to/game_photo.jpg')

# Cosine similarity of the two embeddings. get_embedding returns a batch of one,
# so flatten to 1-D vectors to get a scalar rather than a (1, 1) matrix.
from numpy.linalg import norm
studio_vec = studio_embedding.ravel()
game_vec = game_embedding.ravel()
similarity = float(np.dot(studio_vec, game_vec) / (norm(studio_vec) * norm(game_vec)))
print("Similarity:", similarity)
