In [1]:
from pipeline import run_pipeline

result = run_pipeline(
    urls=['https://histology.siu.edu/ssb/neuron.htm'],
    output_dir='./brain_dataset_test1',
    threshold=0.50
)

2025-12-03 11:18:33 - INFO - BRAIN IMAGE SCRAPER PIPELINE
2025-12-03 11:18:33 - INFO - URLs to process: 1
2025-12-03 11:18:33 - INFO - Output directory: brain_dataset_test1
2025-12-03 11:18:33 - INFO - AI threshold: 0.5
2025-12-03 11:18:33 - INFO - 
2025-12-03 11:18:33 - INFO - STAGE 1: EXTRACTION
2025-12-03 11:18:36 - INFO - Extracted 64 images from https://histology.siu.edu/ssb/neuron.htm
2025-12-03 11:18:36 - INFO - Total extracted: 64 images from 1 sources
2025-12-03 11:18:36 - INFO - Extracted 64 images from 1 sources
2025-12-03 11:18:36 - INFO - 
2025-12-03 11:18:36 - INFO - STAGE 2: FILTERING
2025-12-03 11:18:38 - INFO - Running rule-based filtering...
2025-12-03 11:18:38 - INFO - Passed rules: 61, Failed: 3
2025-12-03 11:18:38 - INFO - Downloading images...
2025-12-03 11:19:46 - INFO - Download progress: 50/61
2025-12-03 11:20:02 - INFO - Downloaded: 61, Failed: 0
2025-12-03 11:20:02 - INFO - Running AI classification...
  from .autonotebook import tqdm as notebook_tqdm
Disabli

In [4]:
import random
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [39]:
class SignatureTripletDataset(Dataset):
    def __init__(self, base_data_dir, triplets_per_user=100, transform=None, signature_map=None, user_ids=None):
        self.base_data_dir = base_data_dir
        self.triplets_per_user = triplets_per_user
        self.transform = transform
        
        if signature_map is None:
            self.signature_map = self._create_signature_map()
        else:
            self.signature_map = signature_map
            
        if user_ids is None:
            self.user_ids = sorted(list(self.signature_map.keys()))
        else:
            self.user_ids = list(user_ids)
    
    def __len__(self):
        return len(self.user_ids) * self.triplets_per_user
    
    def __getitem__(self, index):
        user_id = random.choice(self.user_ids)
        
        real_paths = self.signature_map[user_id]["real"]
        anchor_path, positive_path = random.sample(real_paths, 2)
        
        fake_paths = self.signature_map[user_id]["fake"]
        negative_path = random.choice(fake_paths)
        
        anchor_img = self._load_image(anchor_path)
        positive_img = self._load_image(positive_path)
        negative_img = self._load_image(negative_path)
        
        return anchor_img, positive_img, negative_img
    
    def _create_signature_map(self):
        signature_map = {}
        
        real_dir = os.path.join(self.base_data_dir, "Real")
        fake_dir = os.path.join(self.base_data_dir, "Fake")
        
        for user_id in os.listdir(real_dir):
            user_real_dir = os.path.join(real_dir, user_id)
            if not os.path.isdir(user_real_dir):
                continue
            
            real_paths = [
                os.path.join(user_real_dir, f)
                for f in os.listdir(user_real_dir)
                if f.lower().endswith((".png",".jpg",".jpeg"))
            ]
            
            user_fake_dir = os.path.join(fake_dir, user_id)
            fake_paths = []
            if os.path.isdir(user_fake_dir):
                fake_paths = [
                    os.path.join(user_fake_dir, f)
                    for f in os.listdir(user_fake_dir)
                    if f.lower().endswith((".png",".jpg","jpeg"))
                ]
            if len(real_paths) >= 2 and len(fake_paths) >= 1:
                signature_map[user_id] = {
                    "real" : real_paths,
                    "fake" : fake_paths
                }
        return signature_map
    
    def _load_image(self, path):
        with Image.open(path) as img:
            img = img.convert("RGB")
            if self.transform is not None:
                img = self.transform(img)
        return img

In [40]:
mean = [0.861, 0.861, 0.861]
std = [0.274, 0.274, 0.274]

train_transform = transforms.Compose([
    transforms.RandomAffine(
        degrees=0,
        shear=10,
        translate=(0.1,0.1)
    ),
    transforms.RandomPerspective(distortion_scale=0.1, p=0.5),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

In [41]:
signature_data_dir = "/home/siddharth/workspace/src/brain_scraper/Signature_Verification_v5_v11/"
full_dataset = SignatureTripletDataset(
    base_data_dir=signature_data_dir,
    triplets_per_user=100,
    transform=None
)

In [42]:
def create_signature_datasets_splits(full_dataset, train_split=0.8, train_transform=None, val_transform=None):
    all_users = full_dataset.user_ids.copy()
    random.shuffle(all_users)
    
    n_train = int(len(all_users) * train_split)
    train_users = all_users[:n_train]
    val_users = all_users[n_train:]
    
    train_dataset = SignatureTripletDataset(
        base_data_dir=full_dataset.base_data_dir,
        triplets_per_user=full_dataset.triplets_per_user,
        transform=train_transform,
        signature_map=full_dataset.signature_map,
        user_ids=train_users
    )
    
    val_dataset = SignatureTripletDataset(
        base_data_dir=full_dataset.base_data_dir,
        triplets_per_user=full_dataset.triplets_per_user,
        transform=val_transform,
        signature_map=full_dataset.signature_map,
        user_ids=val_users
    )
    
    return train_dataset, val_dataset

In [43]:
train_dataset, val_dataset = create_signature_datasets_splits(
    full_dataset=full_dataset,
    train_split=0.8,
    train_transform=train_transform,
    val_transform=val_transform
)

In [44]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

print("Train Size:", len(train_dataset))
print("Val Size:", len(val_dataset))

Train Size: 4000
Val Size: 1100


In [45]:
class SimpleEmbeddingNetwork(nn.Module):
    def __init__(self, embedding_dim=128):
        super(SimpleEmbeddingNetwork, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.3),
            
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.4),
            
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.fc = nn.Sequential(
            nn.Linear(in_features=128 * 25 * 25, out_features=256),
            nn.ReLU(),
            nn.Dropout(0.6),
            nn.Linear(in_features=256, out_features=embedding_dim)
        )
    
    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [46]:
class SiameseNetwork(nn.Module):
    def __init__(self, embedding_network):
        super().__init__()
        self.embedding_network = embedding_network
        
    def forward(self, *inputs, triplet_bool=True):
        if triplet_bool:
            anchor, positive, negative = inputs
            z_a = self.embedding_network(anchor)
            z_p = self.embedding_network(positive)
            z_n = self.embedding_network(negative)
            return z_a, z_p, z_n
        else:
            img1, img2 = inputs
            z1 = self.embedding_network(img1)
            z2 = self.embedding_network(img2)
            return z1, z2
    
    def get_embedding(self, image):
        return self.embedding_network(image)

In [47]:
embedding_dim=128
embedding_net = SimpleEmbeddingNetwork(embedding_dim=embedding_dim)
siamese_model = SiameseNetwork(embedding_network=embedding_net).to(device)

In [48]:
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
optimizer = optim.Adam(params=siamese_model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.2)

In [51]:
def training_loop_signature(model, train_loader, val_loader, loss_fcn, optimizer, scheduler, num_epochs, threshold, device):
    best_val_acc = 0.0
    best_state_dict = None
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        
        for batch in train_loader:
            anchor, positive, negative = batch
            anchor = anchor.to(device)
            positive = positive.to(device)
            negative = negative.to(device)
            
            optimizer.zero_grad()
            
            z_a, z_p, z_n = model(anchor, positive, negative, triplet_bool=True)
            loss = loss_fcn(z_a, z_p, z_n)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * anchor.size(0)
        scheduler.step()
        avg_train_loss = running_loss / len(train_loader.dataset)
        
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in val_loader:
                anchor, positive, negative = batch
                anchor = anchor.to(device)
                positive = positive.to(device)
                negative = negative.to(device)
                
                z_a, z_p, z_n = model(anchor, positive, negative, triplet_bool=True)
                loss = loss_fcn(z_a, z_p, z_n)
                val_loss += loss.item() * anchor.size(0)
                
                d_ap = F.pairwise_distance(z_a, z_p)
                d_an = F.pairwise_distance(z_a, z_n)
                
                genuine_correct = (d_ap < threshold)
                fake_correct = (d_an >= threshold)
                batch_correct = (genuine_correct & fake_correct).sum().item()
                
                correct += batch_correct
                total += anchor.size(0)
                
        avg_val_loss = val_loss / len(val_loader.dataset)
        val_acc = correct / total if total > 0 else 0.0
        
        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"Train Loss: {avg_train_loss:.4f}  "
              f"Val Loss: {avg_val_loss:.4f}  Val Acc: {val_acc:.4f}")
        
        # save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_state_dict = model.state_dict()
    
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)
    
    print(f"Best validation accuracy: {best_val_acc:.4f}")
    return model

In [52]:
num_epochs = 5
threshold_dist = 0.8  # you can tune this later

trained_siamese = training_loop_signature(
    model=siamese_model,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_fcn=triplet_loss,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=num_epochs,
    threshold=threshold_dist,
    device=device
)

Epoch [1/5] Train Loss: 0.5894  Val Loss: 0.5630  Val Acc: 0.3791
Epoch [2/5] Train Loss: 0.4095  Val Loss: 0.5685  Val Acc: 0.4209
Epoch [3/5] Train Loss: 0.3084  Val Loss: 0.4957  Val Acc: 0.3073
Epoch [4/5] Train Loss: 0.2971  Val Loss: 0.5490  Val Acc: 0.4191
Epoch [5/5] Train Loss: 0.2734  Val Loss: 0.5183  Val Acc: 0.3909
Best validation accuracy: 0.4209


In [54]:
def show_random_triplet(data_loader):
    model_batch = next(iter(data_loader))
    anchor, positive, negative = model_batch
    idx = random.randint(0, anchor.size(0), -1)
    
    imgs = [anchor[idx], positive[idx], negative[idx]]
    titles = ["Anchor (Real)", "Positive (Real)", "Negative (Fake)"]
    plt.figure(figsize=(8, 3))
    for i, (img, title) in enumerate(zip(imgs, titles), 1):
        plt.subplot(1, 3, i)
        img_np = img.permute(1, 2, 0).cpu.numpy()
        img_np = img_np * std[0] + mean[0]
        img_np = img_np.clip(0, 1)
        plt.imshow(img_np, cmap="gray")
        plt.title(title)
        plt.axis("off")
    plt.show()

In [55]:
show_random_triplet(train_loader)

TypeError: Random.randint() takes 3 positional arguments but 4 were given