In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
import random

In [2]:
# Define the Siamese Network
class SiameseNetwork(nn.Module):
    def __init__(self, base_network):
        super(SiameseNetwork, self).__init__()
        self.base_network = base_network

    def forward(self, img1, img2):
        output1 = self.base_network(img1)
        output2 = self.base_network(img2)
        return output1, output2

# Define the base CNN network
class BaseNetwork(nn.Module):
    def __init__(self):
        super(BaseNetwork, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * 12 * 12, 512),
            nn.ReLU(),
            nn.Linear(512, 256)
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


In [3]:
# Prepare the data
transform = transforms.Compose([
    transforms.Resize((100, 100)),
    transforms.ToTensor()
])

In [5]:
class HandwritingTripletDataset(Dataset):
    def __init__(self, triplets, transform=None):
        self.triplets = triplets
        self.transform = transform

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        anchor_path, positive_path, negative_path = self.triplets[idx]
        anchor_img = Image.open(anchor_path).convert('L')
        positive_img = Image.open(positive_path).convert('L')
        negative_img = Image.open(negative_path).convert('L')

        if self.transform:
            anchor_img = self.transform(anchor_img)
            positive_img = self.transform(positive_img)
            negative_img = self.transform(negative_img)

        return anchor_img, positive_img, negative_img


In [9]:
# generating dataset for training the network using triplet loss
def generate_triplets(labels_dict, num_triplets):
    triplets = []
    writer_indices = list(labels_dict.keys())
    
    while len(triplets) < num_triplets:
        # Randomly select an anchor writer
        anchor_writer = random.choice(writer_indices)
        anchor_images = labels_dict[anchor_writer]
        
        # Ensure anchor writer has at least 2 images
        if len(anchor_images) < 2:
            continue
        
        # Select anchor and positive image from the same writer
        anchor_img, positive_img = random.sample(anchor_images, 2)
        
        # Select a negative writer different from the anchor writer
        negative_writer = random.choice([w for w in writer_indices if w != anchor_writer])
        negative_img = random.choice(labels_dict[negative_writer])
        
        # Add the triplet to the list
        triplets.append((anchor_img, positive_img, negative_img))
    
    return triplets

In [10]:
class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        positive_distance = nn.functional.pairwise_distance(anchor, positive)
        negative_distance = nn.functional.pairwise_distance(anchor, negative)
        loss = torch.mean(torch.clamp(positive_distance - negative_distance + self.margin, min=0.0))
        return loss

In [21]:
dataset_dir = "dataset/dataset/train"
labels_dict = {}
count = 0
for writer_label, writer_folder in enumerate(os.listdir(dataset_dir)):
    count += 1
    if(count > 30):
        break
    writer_folder_path = os.path.join(dataset_dir, writer_folder)
    if not os.path.isdir(writer_folder_path):
        continue

    labels_dict[writer_label] = []

    for img_name in os.listdir(writer_folder_path):
        img_path = os.path.join(writer_folder_path, img_name)
        labels_dict[writer_label].append(img_path)

In [22]:
triplets = generate_triplets(labels_dict, 1000)

In [23]:
dataset = HandwritingTripletDataset(triplets, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


In [24]:
# Initialize the network
base_network = BaseNetwork()
criterion = TripletLoss()
optimizer = optim.Adam(base_network.parameters(), lr=0.02)

In [25]:
num_epochs = 15
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for epoch in range(num_epochs):
    base_network.train()
    running_loss = 0.0
    for anchor_img, positive_img, negative_img in dataloader:
        anchor_img, positive_img, negative_img = anchor_img.to(device), positive_img.to(device), negative_img.to(device)
        optimizer.zero_grad()
        anchor_output = base_network(anchor_img)
        positive_output = base_network(positive_img)
        negative_output = base_network(negative_img)
        loss = criterion(anchor_output, positive_output, negative_output)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader)}")

print("Training completed.")

Epoch [1/15], Loss: 1.9426118396222591
Epoch [2/15], Loss: 1.0


In [19]:
# Testing the network
dataset_dir = "dataset/dataset/train"
labels_dict_test = {}
count = 0
for writer_label, writer_folder in enumerate(os.listdir(dataset_dir)):
    count += 1
    if(count < 30):
        continue
    if(count > 40):
        break
    writer_folder_path = os.path.join(dataset_dir, writer_folder)
    if not os.path.isdir(writer_folder_path):
        continue

    labels_dict_test[writer_label] = []

    for img_name in os.listdir(writer_folder_path):
        img_path = os.path.join(writer_folder_path, img_name)
        labels_dict_test[writer_label].append(img_path)
    
test_triplets = generate_triplets(labels_dict_test, 1000)
test_dataset = HandwritingTripletDataset(test_triplets, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

base_network.eval()

correct = 0
total = 0
with torch.no_grad():
    for anchor_img, positive_img, negative_img in test_dataloader:
        anchor_img, positive_img, negative_img = anchor_img.to(device), positive_img.to(device), negative_img.to(device)
        anchor_output = base_network(anchor_img)
        positive_output = base_network(positive_img)
        negative_output = base_network(negative_img)
        positive_distance = nn.functional.pairwise_distance(anchor_output, positive_output)
        negative_distance = nn.functional.pairwise_distance(anchor_output, negative_output)
        correct += (positive_distance < negative_distance).sum().item()
        total += anchor_img.size(0)
        print(f"Accuracy: {correct/total * 100:.2f}%")

Accuracy: 75.00%
Accuracy: 68.75%
Accuracy: 68.75%
Accuracy: 64.84%
Accuracy: 66.88%
Accuracy: 66.67%
Accuracy: 66.96%
Accuracy: 67.19%
Accuracy: 67.36%
Accuracy: 65.94%
Accuracy: 64.49%
Accuracy: 62.76%
Accuracy: 63.22%
Accuracy: 63.39%
Accuracy: 63.54%
Accuracy: 63.48%
Accuracy: 63.60%
Accuracy: 63.37%
Accuracy: 63.16%
Accuracy: 62.97%
Accuracy: 63.24%
Accuracy: 62.64%
Accuracy: 62.36%
Accuracy: 62.37%
Accuracy: 62.25%
Accuracy: 62.50%
Accuracy: 62.38%
Accuracy: 62.05%
Accuracy: 61.85%
Accuracy: 62.08%
Accuracy: 62.00%
Accuracy: 62.10%


In [None]:
# Save the model
torch.save(base_network.state_dict(), "siamese_network.pth")
print("Model saved.")