In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc, confusion_matrix, accuracy_score
import numpy as np
from PIL import Image
import os
from torch.optim import Adam
import copy

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class SignatureDataset(Dataset):
    """Pairwise signature dataset built from a ``<root_dir>/<subset>`` folder.

    Every sub-directory whose name does not end in ``_forg`` is treated as a
    folder of genuine signatures; its sibling ``<name>_forg`` directory holds
    the matching forgeries. Folders without a ``_forg`` sibling are skipped
    with a warning.

    Each item is ``(img1, img2, label)`` where ``label`` is a float tensor of
    shape ``(1,)`` holding ``1`` for a genuine/genuine pair and ``0`` for a
    genuine/forged pair.

    Args:
        root_dir: Directory that contains the ``subset`` folder.
        subset: Name of the sub-folder to read (``'train'`` by default).
        transform: Optional callable applied to both images of a pair.

    Raises:
        ValueError: If the subset directory does not exist or yields no pairs.
    """

    # File extensions (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, subset='train', transform=None):
        self.root_dir = os.path.join(root_dir, subset)
        self.transform = transform

        # Verify root directory exists
        if not os.path.exists(self.root_dir):
            raise ValueError(f"Directory not found: {self.root_dir}")

        print(f"\nAnalyzing directory: {self.root_dir}")
        self.pairs = self._create_pairs()
        print(f"Total pairs created: {len(self.pairs)}")

        if not self.pairs:
            # Dump the tree so the user can see why nothing matched.
            print(f"\nDirectory structure for {self.root_dir}:")
            self._print_directory_structure(self.root_dir)
            raise ValueError(f"No valid pairs found in {self.root_dir}")

    @classmethod
    def _is_image(cls, filename):
        """Return True when ``filename`` has a recognised image extension."""
        return filename.lower().endswith(cls.IMAGE_EXTENSIONS)

    @classmethod
    def _list_images(cls, folder):
        """Return the image file names in ``folder`` in sorted order."""
        # os.listdir order is arbitrary; sorting keeps pair order reproducible.
        return sorted(f for f in os.listdir(folder) if cls._is_image(f))

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file right away."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def _print_directory_structure(self, startpath):
        """Print an indented tree of ``startpath`` listing only image files."""
        for root, dirs, files in os.walk(startpath):
            dirs.sort()  # walk sub-directories in a stable order
            level = root.replace(startpath, '').count(os.sep)
            indent = ' ' * 4 * level
            print(f'{indent}{os.path.basename(root)}/')
            subindent = ' ' * 4 * (level + 1)
            for f in sorted(files):
                if self._is_image(f):
                    print(f'{subindent}{f}')

    def _create_pairs(self):
        """Build the list of ``(path1, path2, label)`` tuples for this subset.

        For each genuine folder, every unordered pair of genuine images is
        labelled ``1`` and every genuine x forged combination is labelled ``0``.
        """
        pairs = []

        # Get all subdirectories that are not forgery folders
        genuine_folders = sorted(
            item for item in os.listdir(self.root_dir)
            if os.path.isdir(os.path.join(self.root_dir, item))
            and not item.endswith('_forg')
        )

        print(f"Found genuine folders: {genuine_folders}")

        for folder in genuine_folders:
            genuine_path = os.path.join(self.root_dir, folder)
            forged_path = os.path.join(self.root_dir, f"{folder}_forg")

            if not os.path.isdir(forged_path):
                print(f"Warning: No forged folder found for {folder}")
                continue

            # Get genuine and forged images
            genuine_images = self._list_images(genuine_path)
            forged_images = self._list_images(forged_path)

            print(f"\nFolder {folder}:")
            print(f"  Genuine images: {len(genuine_images)}")
            print(f"  Forged images: {len(forged_images)}")

            genuine_files = [os.path.join(genuine_path, f) for f in genuine_images]
            forged_files = [os.path.join(forged_path, f) for f in forged_images]

            # Create genuine pairs (each unordered pair exactly once)
            for i, first in enumerate(genuine_files):
                for second in genuine_files[i + 1:]:
                    pairs.append((first, second, 1))

            # Create forged pairs (every genuine image against every forgery)
            for genuine_file in genuine_files:
                for forged_file in forged_files:
                    pairs.append((genuine_file, forged_file, 0))

        return pairs

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(img1, img2, label_tensor)``.

        Loading errors are reported with both paths and then re-raised.
        """
        img1_path, img2_path, label = self.pairs[idx]

        try:
            img1 = self._load_rgb(img1_path)
            img2 = self._load_rgb(img2_path)
        except Exception as e:
            print(f"Error loading images:\nPath 1: {img1_path}\nPath 2: {img2_path}\nError: {str(e)}")
            raise

        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2, torch.tensor([label], dtype=torch.float32)

class SiameseNetwork(nn.Module):
    """Siamese pair classifier with a shared ResNet-18 encoder.

    Both images go through the same ResNet-18, whose final layer is replaced
    by a linear projection to ``embedding_dim`` features. The two embeddings
    are concatenated and fed to a small MLP ending in a sigmoid, so the
    forward pass yields one score in (0, 1) per pair.

    Args:
        embedding_dim: Size of the per-image embedding (default 256).
        pretrained: Load ImageNet weights into the ResNet-18 backbone
            (default True).
    """

    def __init__(self, embedding_dim=256, pretrained=True):
        super().__init__()
        # Attribute names `cnn` and `fc1` are kept so saved state_dict keys
        # remain loadable.
        self.cnn = self._build_backbone(pretrained)
        self.cnn.fc = nn.Linear(self.cnn.fc.in_features, embedding_dim)

        self.fc1 = nn.Sequential(
            nn.Linear(embedding_dim * 2, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _build_backbone(pretrained):
        """Create a ResNet-18, optionally with ImageNet weights."""
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            # Older torchvision without the `weights=` API.
            return models.resnet18(pretrained=pretrained)
        # `pretrained=` is deprecated in newer torchvision; IMAGENET1K_V1 is
        # the checkpoint that `pretrained=True` used to load.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.resnet18(weights=weights)

    def forward(self, img1, img2):
        """Return the sigmoid similarity score for each ``(img1, img2)`` pair."""
        feat1 = self.cnn(img1)
        feat2 = self.cnn(img2)
        # Concatenate along the feature dimension before the classifier head.
        combined = torch.cat((feat1, feat2), 1)
        return self.fc1(combined)

def train_model(model, train_loader, test_loader, criterion, optimizer, num_epochs, patience,
                checkpoint_path='/kaggle/working/best_model.pth'):
    """Train ``model`` with early stopping and restore the best epoch's weights.

    After every epoch the model is evaluated on ``test_loader``. The epoch
    with the highest accuracy (predictions thresholded at 0.5) is remembered:
    its weights are copied, saved to ``checkpoint_path`` and its confusion
    matrix is kept. Training stops once ``patience`` consecutive epochs fail
    to improve on the best accuracy.

    NOTE(review): model selection uses ``test_loader`` itself, so the reported
    best accuracy is optimistic; a separate validation split would be cleaner.

    Args:
        model: Network called as ``model(img1, img2)``.
        train_loader: Loader yielding ``(img1, img2, labels)`` training batches.
        test_loader: Loader yielding ``(img1, img2, labels)`` evaluation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating ``model``'s parameters.
        num_epochs: Maximum number of epochs.
        patience: Number of non-improving epochs tolerated before stopping.
        checkpoint_path: Where the best weights are saved; a falsy value
            disables saving.

    Returns:
        ``(train_losses, test_losses, best_accuracy, cm)`` where the two loss
        lists hold one per-sample average per epoch, ``best_accuracy`` is a
        fraction in [0, 1] and ``cm`` is the 2x2 confusion matrix of the best
        epoch (``None`` if no epoch was run).
    """
    best_accuracy = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())
    best_cm = None
    train_losses = []
    test_losses = []
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for img1, img2, labels in train_loader:
            img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(img1, img2)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            running_loss += loss.item() * img1.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_loss)

        # Validate
        model.eval()
        test_loss = 0.0
        corrects = 0
        all_labels = []
        all_preds = []

        with torch.no_grad():
            for img1, img2, labels in test_loader:
                img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
                outputs = model(img1, img2)
                loss = criterion(outputs, labels)
                test_loss += loss.item() * img1.size(0)
                preds = (outputs > 0.5).float()
                corrects += (preds == labels).sum().item()
                # Flatten to plain ints so the confusion matrix gets 1-D input.
                all_labels.extend(labels.reshape(-1).long().tolist())
                all_preds.extend(preds.reshape(-1).long().tolist())

        epoch_test_loss = test_loss / len(test_loader.dataset)
        test_losses.append(epoch_test_loss)
        accuracy = corrects / len(test_loader.dataset)
        print(f'Epoch {epoch}/{num_epochs - 1}, '
              f'Train Loss: {epoch_loss:.4f}, '
              f'Test Loss: {epoch_test_loss:.4f}, '
              f'Accuracy: {accuracy:.4f}')

        # Compute the matrix before any early-stop break so every epoch is
        # reported; labels=[0, 1] pins the shape to 2x2 even if a class is absent.
        cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
        print(f'Confusion Matrix (Epoch {epoch}):\n{cm}')

        # Check if this is the best model so far (the first epoch always is,
        # so the returned matrix always matches the restored weights).
        if best_cm is None or accuracy > best_accuracy:
            best_accuracy = accuracy
            best_model_wts = copy.deepcopy(model.state_dict())
            best_cm = cm
            if checkpoint_path:
                torch.save(best_model_wts, checkpoint_path)
            patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= patience:
            print("Early stopping")
            break

    # Load the best model weights
    model.load_state_dict(best_model_wts)

    return train_losses, test_losses, best_accuracy, best_cm

def main(dataset_root='/kaggle/input/signature-verification-dataset/sign_data',
         output_dir='/kaggle/working'):
    """Train the Siamese signature verifier and save the model and plots.

    Builds the ``train`` and ``test`` pair datasets under ``dataset_root``,
    trains a ``SiameseNetwork`` with BCE loss and Adam (lr=1e-4, up to 50
    epochs, patience 7), then writes ``final_model.pth``,
    ``training_test_loss_plot.png`` and ``confusion_matrix.png`` into
    ``output_dir``.

    Args:
        dataset_root: Directory containing the ``train`` and ``test`` folders.
        output_dir: Existing directory for the files written by this function.
    """
    # Data transforms: ResNet-sized input with the usual ImageNet
    # mean/std normalisation.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    print("\nInitializing training dataset...")
    train_dataset = SignatureDataset(dataset_root, subset='train', transform=transform)

    print("\nInitializing testing dataset...")
    test_dataset = SignatureDataset(dataset_root, subset='test', transform=transform)

    # Create data loaders with num_workers=0 for debugging
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

    # Initialize model
    model = SiameseNetwork().to(device)
    criterion = nn.BCELoss()
    optimizer = Adam(model.parameters(), lr=0.0001)

    # Training parameters
    num_epochs = 50
    patience = 7

    # Train model
    train_losses, test_losses, best_accuracy, cm = train_model(
        model, train_loader, test_loader, criterion, optimizer, num_epochs, patience
    )

    # best_accuracy is a fraction in [0, 1]; scale it before printing with '%'.
    print(f"\nBest accuracy achieved: {best_accuracy * 100:.2f}%")

    # Save the final model
    torch.save(model.state_dict(), os.path.join(output_dir, 'final_model.pth'))

    # Plot and save the loss curves
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Test Loss')
    plt.legend()
    plt.savefig(os.path.join(output_dir, 'training_test_loss_plot.png'))
    plt.show()
    plt.close()

    # Plot confusion matrix (rows: actual, columns: predicted; 0=Forged, 1=Genuine)
    class_names = ['Forged', 'Genuine']
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix')
    plt.savefig(os.path.join(output_dir, 'confusion_matrix.png'))
    plt.show()
    plt.close()

# Run the training pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()



Initializing training dataset...

Analyzing directory: /kaggle/input/signature-verification-dataset/sign_data/train
Found genuine folders: ['057', '061', '048', '053', '051', '018', '044', '016', '009', '012', '029', '025', '001', '056', '006', '042', '055', '027', '041', '036', '035', '026', '065', '062', '034', '058', '060', '068', '033', '049', '023', '020', '013', '050', '052', '066', '002', '067', '022', '043', '054', '047', '004', '021', '015', '059', '014', '039', '040', '064', '063', '031', '017', '003', '019', '024', '069', '037', '046', '045', '028', '038', '032', '030']

Folder 057:
  Genuine images: 12
  Forged images: 12

Folder 061:
  Genuine images: 12
  Forged images: 12

Folder 048:
  Genuine images: 12
  Forged images: 8

Folder 053:
  Genuine images: 12
  Forged images: 16

Folder 051:
  Genuine images: 12
  Forged images: 8

Folder 018:
  Genuine images: 12
  Forged images: 12

Folder 044:
  Genuine images: 12
  Forged images: 12

Folder 016:
  Genuine images: 23
 

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 203MB/s]


Epoch 0/49, Train Loss: 0.0290, Test Loss: 0.0417, Accuracy: 0.9812
Confusion Matrix (Epoch 0):
[[2976    0]
 [  82 1304]]
Epoch 1/49, Train Loss: 0.0034, Test Loss: 0.0467, Accuracy: 0.9876
Confusion Matrix (Epoch 1):
[[2976    0]
 [  54 1332]]
Epoch 2/49, Train Loss: 0.0001, Test Loss: 0.0158, Accuracy: 0.9966
Confusion Matrix (Epoch 2):
[[2976    0]
 [  15 1371]]
Epoch 3/49, Train Loss: 0.0044, Test Loss: 0.2082, Accuracy: 0.9541
Confusion Matrix (Epoch 3):
[[2976    0]
 [ 200 1186]]
