## Installing Dependencies

The following cell installs all the dependencies required by the code.  

The code was developed on **Python 3.13.2**. It may not run correctly on older versions.


In [None]:
!pip install pandas 
!pip install pyarrow
!pip install fastparquet
!pip install requests beautifulsoup4
!pip install fuzzywuzzy
!pip install cairosvg
!pip uninstall -y pillow
!pip install --upgrade pillow
!pip install torch torchvision
!pip install scikit-learn
!pip install tqdm


Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.
Found existing installation: pillow 11.1.0
Uninstalling pillow-11.1.0:
  Successfully uninstalled pillow-11.1.0
Note: you may need to re

`MAX_WORKERS` sets the number of threads used for parallel processing.

On more powerful computers, this number can be set higher.

In [None]:
MAX_WORKERS = 10
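
One way to pick a starting value, sketched below, is to derive it from the CPU count. The formula follows the default that Python's `ThreadPoolExecutor` documents for the case where `max_workers` is omitted; `suggested_workers` is a hypothetical name that does not appear in the original notebook.

```python
import os

# Hypothetical starting point: scale the thread count with the machine.
# The downloads are I/O-bound, so a few more threads than cores is reasonable.
cpu_count = os.cpu_count() or 1
suggested_workers = min(32, cpu_count + 4)

print(f"CPU cores: {cpu_count}, suggested MAX_WORKERS: {suggested_workers}")
```

The value is only a suggestion; remote servers may throttle clients that open too many connections at once.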

## Processing CSV File

1. **Load the Parquet file** and save its contents as a CSV file.  

2. **Add an `extracted` column** to keep track of the domains whose logos have already been extracted.  


In [28]:

import pandas as pd

df = pd.read_parquet('logos.snappy.parquet')

df.to_csv('data.csv', index=False)

df['extracted'] = False

print(df)

                                     domain  extracted
0                         stanbicbank.co.zw      False
1                            astrazeneca.ua      False
2               autosecuritas-ct-seysses.fr      False
3                                    ovb.ro      False
4     mazda-autohaus-hellwig-hoyerswerda.de      False
...                                     ...        ...
4379                              synlab.ec      False
4380                            ccusa.co.za      False
4381               aamcolawrencevillega.com      False
4382     mazda-autohaus-born-ludwigslust.de      False
4383                     savethechildren.ca      False

[4384 rows x 2 columns]


## First Step Logo Extraction

- **Used Clearbit** for the initial logo extraction.  

- **Achieved an 84.7% initial extraction success rate** for logos.  

- **Implemented `ThreadPoolExecutor`** for parallel processing of the domains.  


In [None]:
import requests
import os

def proc_domain(data):
    index, domain = data
    
    url = f"https://logo.clearbit.com/{domain}"
    try:
        response = requests.get(url, timeout=10)
        
        if response.status_code == 200:
            # Create the images directory if needed; exist_ok avoids a race between threads
            os.makedirs('images', exist_ok=True)
                
            with open(f"images/{index}.png", 'wb') as f:
                f.write(response.content)
            return (index, True)
        else:
            return (index, False)
    except Exception as e:
        print(f"Error processing {domain}: {str(e)}")
        return (index, False)
    

In [None]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

# Prepare the data for processing
domains = list(enumerate(df['domain']))
total = len(domains)

# Initialize a progress bar
progress_bar = tqdm(total=total, desc="Downloading logos", unit="logo")

# Track successful downloads
successful = 0

# Using ThreadPoolExecutor for parallel requests
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    # Submit all tasks and get futures
    futures = [executor.submit(proc_domain, data) for data in domains]
    
    # Process results as they complete (as_completed yields futures in finishing order)
    for future in as_completed(futures):
        try:
            index, success = future.result()
            if success:
                df.at[index, 'extracted'] = True
                successful += 1
        except Exception as e:
            print(f"Error processing task: {str(e)}")
        finally:
            progress_bar.update(1)
            progress_bar.set_postfix({"success": f"{successful}/{total}"})

progress_bar.close()

Downloading logos: 100%|██████████| 4384/4384 [02:49<00:00, 30.92logo/s, success=3713/4384]
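
The 84.7% success rate quoted for this step follows directly from the counts in the progress bar. A minimal check, using only those two numbers:

```python
# Counts taken from the progress bar output above
successful, total = 3713, 4384

success_rate = successful / total
print(f"Success rate: {success_rate:.1%}")
```

With the DataFrame in memory, `df['extracted'].mean()` should give the same fraction, since the mean of a boolean column is the share of `True` values.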

In [31]:
copy_df = df.copy(deep=True)
copy_df

                                     domain  extracted
0                         stanbicbank.co.zw       True
1                            astrazeneca.ua       True
2               autosecuritas-ct-seysses.fr       True
3                                    ovb.ro       True
4     mazda-autohaus-hellwig-hoyerswerda.de       True
...                                     ...        ...
4379                              synlab.ec      False
4380                            ccusa.co.za       True
4381               aamcolawrencevillega.com       True
4382     mazda-autohaus-born-ludwigslust.de      False


## Second Step of the Logo Extraction

- **Used Gemini API** to find the logo's URL in the page's header.  

- **Also used `ThreadPoolExecutor`** for the Gemini API calls.  

- **Downloaded** `.jpg`, `.png`, and `.svg` files.  

- **Converted `.svg` files to `.png`** for later use in feature extraction.  

- **Managed to achieve an additional 9% extraction rate** with this approach.  


In [None]:
from bs4 import BeautifulSoup, Comment

def clean_html(html: str) -> str:
    """
    Extracts and cleans only the content inside the <header> tag by removing
    <script> and <style> tags and HTML comments.

    Args:
        html (str): Raw HTML content to clean.

    Returns:
        str: Cleaned content inside <header>, or an empty string if <header> is missing.
    """
    try:
        # Parse HTML
        soup = BeautifulSoup(html, 'html.parser')

        # Extract the header tag
        header = soup.header
        if not header:
            return ""  # Return empty string if no <header> exists

        # Remove <script> and <style> tags inside the header
        for tag in header(["script", "style"]):
            tag.decompose()  # Removes tag and its content

        # Remove comments inside the header
        for comment in header.find_all(string=lambda text: isinstance(text, Comment)):
            comment.extract()

        # Return cleaned header HTML
        return header.prettify()

    except Exception:
        return ""  # Return empty string if cleaning fails


In [None]:
import requests

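def download_png(url, save_path):
    # Send GET request to the URL (the timeout keeps a stalled server from blocking a worker)
    response = requests.get(url, stream=True, timeout=10)
    
    # Check if the request was successful
    response.raise_for_status()
    
    # Save the file
    with open(save_path, 'wb') as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
            
    return True

def proc_data(url, idx):
    # Choose the save path from the file type that appears in the URL
    if '.svg' in url:
        save_path = f'./images/{idx}.svg'
    elif '.png' in url:
        save_path = f'./images/{idx}.png'
    elif '.jpg' in url:
        save_path = f'./images/{idx}.jpg'
    else:
        # Without this branch, save_path would be undefined for any other file type
        raise ValueError(f"Unsupported image type in URL: {url}")
        
    download_png(url, save_path)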
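
The substring checks in `proc_data` can be misled by query strings. A possible alternative, sketched here under the assumption that the file type is always visible in the URL path, parses the URL first; `guess_extension` and `SUPPORTED_EXTENSIONS` are hypothetical names, not part of the original notebook.

```python
import os
from urllib.parse import urlparse

# The three file types the notebook downloads
SUPPORTED_EXTENSIONS = {'.svg', '.png', '.jpg'}

def guess_extension(url):
    """Return the lowercase extension of the URL path, or None if it is unsupported."""
    path = urlparse(url).path              # drops the query string and fragment
    ext = os.path.splitext(path)[1].lower()
    return ext if ext in SUPPORTED_EXTENSIONS else None

# A substring test would match '.png' in the query string here
print(guess_extension("https://www.example.com/logo.jpg?fallback=icon.png"))
print(guess_extension("https://www.example.com/assets/LOGO.SVG"))
print(guess_extension("https://www.example.com/logo.webp"))
```

Returning `None` for unknown types lets the caller decide whether to skip the domain or report it.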

In [None]:
from api_clients import call_gemini

def process_nth_domain(n):
    # Skip domains whose logo was already extracted in the first step
    if copy_df.at[n, 'extracted']:
        return

    domain = copy_df.at[n, 'domain']
    url = f"https://www.{domain}"
    try:
        response = requests.get(url, timeout=10)
        if response.status_code != 200:
            return

        header = clean_html(response.text)

        # Strip surrounding whitespace (such as a trailing newline) from the model's answer
        logo_url = call_gemini(f""" Tell me what is the logo url for this header: {header}. 
                                  The domain is {domain}. 
                                  I want you to find only the url, no additional information.
                                  The url should look like this: https://www.example.com/logo.png.
                                  If you can't find 'https' add it.
                                  If you can't find the logo url, just type 'NotFound'.
                                  
                                  It is MANDATORY that if a url is found it is returned with 'https://' in the beginning
                                  It is MANDATORY that if a url is not found the result is 'NotFound' exactly like this""").strip()

        if logo_url.lower() == 'notfound':
            return

        proc_data(logo_url, n)

        copy_df.loc[n, 'extracted'] = True

    except Exception as e:
        print(f"Error processing {url} - {str(e)}")
        return
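
The prompt asks the model for an absolute `https://` URL or the literal `NotFound`, but a language model may not always comply. The sketch below shows one way the answer could be normalized before downloading. `normalize_logo_url` is a hypothetical helper, and resolving relative paths against `https://www.<domain>/` is an assumption based on how the page URL is built above.

```python
from urllib.parse import urljoin

def normalize_logo_url(answer, domain):
    """Turn the model's raw answer into an absolute https URL, or None if no logo was found."""
    answer = answer.strip().strip('"\'')
    if not answer or answer.lower() == 'notfound':
        return None
    if answer.startswith('//'):
        # Protocol-relative URL: prepend the scheme
        return 'https:' + answer
    if answer.startswith('http://'):
        # Upgrade plain http to https, as the prompt requests
        return 'https://' + answer[len('http://'):]
    # urljoin resolves relative paths against the site root
    # and returns absolute URLs unchanged
    return urljoin(f"https://www.{domain}/", answer)

print(normalize_logo_url("NotFound\n", "example.com"))
print(normalize_logo_url("/img/logo.svg", "example.com"))
print(normalize_logo_url("//cdn.example.com/logo.png", "example.com"))
```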


In [None]:
# Using ThreadPoolExecutor for parallel requests
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    executor.map(process_nth_domain, range(len(copy_df['domain'])))
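
The cell above relies on `process_nth_domain` updating `copy_df` as a side effect, so the return values of `executor.map` are never read. If the worker returned a status instead, `map` would hand the results back in input order. A toy illustration with a stand-in function (`square` is hypothetical):

```python
from concurrent.futures import ThreadPoolExecutor

def square(n):
    """Stand-in for a worker that returns a result instead of mutating shared state."""
    return n * n

with ThreadPoolExecutor(max_workers=4) as executor:
    # map yields results in input order, regardless of which task finishes first
    results = list(executor.map(square, range(5)))

print(results)
```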

In [37]:
import os
import re
import html
import cairosvg

def clean_svg_for_parsing(svg_file_path):
    """
    Clean an SVG file by replacing HTML entities with their Unicode characters.
    
    Parameters:
    svg_file_path (str): Path to the input SVG file
    
    Returns:
    str: Path to the cleaned SVG file (the input path with a `_cleaned` suffix),
         or None if cleaning fails
    """
    base, ext = os.path.splitext(svg_file_path)
    output_file_path = f"{base}_cleaned{ext}"
    
    try:
        # Read the SVG file
        with open(svg_file_path, 'r', encoding='utf-8') as file:
            svg_content = file.read()
        
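        # The five entities predefined by XML must stay escaped,
        # otherwise the SVG is no longer well-formed
        xml_entities = {'amp', 'lt', 'gt', 'quot', 'apos'}
        
        # Find and replace HTML entities
        def replace_entity(match):
            entity = match.group(1)
            if entity in xml_entities:
                return match.group(0)
            # Convert HTML entity to its Unicode character
            return html.unescape(f"&{entity};")
        
        # Replace entities like &aacute; with their Unicode equivalents
        cleaned_content = re.sub(r'&([a-zA-Z][a-zA-Z0-9]*);', replace_entity, svg_content)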
        
        # Write the cleaned content to a new file
        with open(output_file_path, 'w', encoding='utf-8') as file:
            file.write(cleaned_content)
        
        return output_file_path
        
    except Exception as e:
        print(f"Error cleaning SVG file {svg_file_path}: {e}")
        return None

# Get all files in the images directory
filenames = os.listdir('./images')

for filename in filenames:
    if filename.endswith('.svg'):
        prefix = os.path.splitext(filename)[0]
        svg_path = f'./images/{prefix}.svg'
        png_path = f'./images/{prefix}.png'
        cleaned_svg = None
        
        try:
            # Clean the SVG file first
            cleaned_svg = clean_svg_for_parsing(svg_path)
            
            if cleaned_svg:
                # Convert the cleaned SVG to PNG
                cairosvg.svg2png(url=cleaned_svg, write_to=png_path)
                
                # Remove the temporary cleaned file
                os.remove(cleaned_svg)
                
                # Remove the original SVG file
                os.remove(svg_path)
            else:
                print(f"Failed to clean {svg_path}")
                
        except Exception as e:
            print(f"Error processing {svg_path}: {e}")
            # Remove whichever of the two SVG files still exists
            for path in (svg_path, cleaned_svg):
                if path and os.path.exists(path):
                    os.remove(path)

Error processing ./images/3866.svg: The SVG size is undefined
Error processing ./images/934.svg: not well-formed (invalid token): line 507, column 54
Error processing ./images/495.svg: mismatched tag: line 7, column 9


## Feature Extraction

- **Used EfficientNetB5** for feature extraction.  

- **Removed the last layer** to take only the feature vectors, without classification.  

- **Used `ThreadPoolExecutor`** for faster feature processing.  


In [52]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from PIL import Image
import os
from torchvision import transforms

class LogoNet(nn.Module):
    """
    A specialized neural network for logo recognition with:
    1. Enhanced local feature extraction
    2. Multi-scale processing
    3. Shape-aware attention
    4. Rotation and scale invariance
    """
    def __init__(self, embedding_dim=512):
        super(LogoNet, self).__init__()
        
        # Base feature extractor (EfficientNet is good for logos due to better edge detection)
        # Could use ResNet50 or other backbones too
        self.backbone = models.efficientnet_b5(weights=models.EfficientNet_B5_Weights.DEFAULT)
        backbone_out_features = 2048  # EfficientNet-B5 output features
        
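        # Remove the classifier head. The remaining children are `features` and `avgpool`,
        # so the backbone outputs a pooled tensor of shape (N, 2048, 1, 1)
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])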
        
        # Multi-scale feature extraction (helps with logos of different scales)
        self.conv_1x1 = nn.Conv2d(backbone_out_features, 256, kernel_size=1)
        self.conv_3x3 = nn.Conv2d(backbone_out_features, 256, kernel_size=3, padding=1)
        self.conv_5x5 = nn.Conv2d(backbone_out_features, 256, kernel_size=5, padding=2)
        
        # Shape-aware attention module
        self.attention = nn.Sequential(
            nn.Conv2d(768, 128, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(128, 768, kernel_size=1),
            nn.Sigmoid()
        )
        
        # Global context module
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        
        # Feature embedding layer
        self.embedding = nn.Sequential(
            nn.Linear(768, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
            nn.ReLU()
        )
    
    def forward(self, x):
        # Extract base features
        x = self.backbone(x)
        
        # Multi-scale feature extraction
        feat_1x1 = self.conv_1x1(x)
        feat_3x3 = self.conv_3x3(x)
        feat_5x5 = self.conv_5x5(x)
        
        # Concatenate multi-scale features
        multi_scale_features = torch.cat([feat_1x1, feat_3x3, feat_5x5], dim=1)
        
        # Apply attention
        attention_weights = self.attention(multi_scale_features)
        attended_features = multi_scale_features * attention_weights
        
        # Global pooling and flatten
        x = self.global_pool(attended_features)
        x = x.view(x.size(0), -1)
        
        # Get embedding
        embedding = self.embedding(x)
        
        return embedding


In [53]:
# Initialize model
model = LogoNet()
model.eval()

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Example transform for inference
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [54]:
# Function to extract logo features using pretrained model
def extract_logo_features(image_path, model=model, transform=inference_transform):
    """Extract features from a logo image using a specialized logo model"""
    try:
        # Open image
        image = Image.open(f"./images/{image_path}").convert('RGB')
        
        # Apply transformations
        image = transform(image).unsqueeze(0)
        
        # Move to the same device as model
        device = next(model.parameters()).device
        image = image.to(device)
        
        # Extract features
        with torch.no_grad():
            features = model(image)
        
        # Return normalized features
        return F.normalize(features, p=2, dim=1).view(1, -1), image_path
    
    except Exception as e:
        print(f"Error processing {image_path}: {str(e)}")
        return None
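
Because `extract_logo_features` returns L2-normalized vectors, the cosine similarity between two of them equals their plain dot product. The pure-Python sketch below illustrates this on a toy 2-D example; the helper names are hypothetical and the vectors are made up for illustration.

```python
import math

def l2_normalize(vector):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

a = [3.0, 4.0]
b = [4.0, 3.0]

# After normalization, the dot product alone gives the cosine similarity
a_unit, b_unit = l2_normalize(a), l2_normalize(b)
dot_of_units = sum(x * y for x, y in zip(a_unit, b_unit))

print(cosine_similarity(a, b), dot_of_units)
```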

In [None]:
from concurrent.futures import ThreadPoolExecutor

filenames = os.listdir('./images')
features = []

with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    parallel_features = executor.map(extract_logo_features, filenames)
    
features = [f for f in parallel_features if f is not None]

Error processing 3111.png: cannot identify image file './images/3111.png'
Error processing 1865.png: cannot identify image file './images/1865.png'
Error processing 4375.png: cannot identify image file './images/4375.png'
Error processing 2006.png: cannot identify image file './images/2006.png'
Error processing 253.png: cannot identify image file './images/253.png'
Error processing 2981.png: cannot identify image file './images/2981.png'
Error processing 3382.png: cannot identify image file './images/3382.png'
Error processing 2803.png: cannot identify image file './images/2803.png'
Error processing 517.png: cannot identify image file './images/517.png'
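
The notebook does not show why these files could not be opened. One possibility, stated here as an assumption, is that the saved bytes are not image data at all, for example an HTML error page stored under a `.png` name. A quick check of the leading bytes can flag such files; `detect_image_type` is a hypothetical helper.

```python
# Leading bytes ("magic numbers") of the raster formats used in this notebook
SIGNATURES = {
    'png': b'\x89PNG\r\n\x1a\n',
    'jpg': b'\xff\xd8\xff',
}

def detect_image_type(data):
    """Return 'png' or 'jpg' if the bytes start with a known signature, else None."""
    for name, signature in SIGNATURES.items():
        if data.startswith(signature):
            return name
    return None

print(detect_image_type(b'\x89PNG\r\n\x1a\n' + b'\x00' * 8))
print(detect_image_type(b'\xff\xd8\xff\xe0' + b'\x00' * 8))
print(detect_image_type(b'<!DOCTYPE html><html>'))
```

To apply it to a file, read the first few bytes, for example with `open(path, 'rb').read(8)`, and pass them to the function.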


## Brute-Force Clustering

- **Defined a `SIMILARITY_THRESHOLD`** that can be adjusted based on preference.  

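- **A higher `SIMILARITY_THRESHOLD` results in more similar logos per cluster but also increases the number of clusters.**  

- **Brute-force approach:** Each logo is compared against every logo in every existing cluster. It joins the cluster with the highest share of similar members, provided that share exceeds 80%; otherwise it starts a new cluster.  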


In [56]:
SIMILARITY_THRESHOLD = 0.7  # Adjust this threshold based on your needs

In [57]:
import torch.nn.functional as F
from tqdm import tqdm

clusters = []
image_clusters = []

# Total number of features to process
total_features = len(features)

# Iterate through features with a progress bar
for feature in tqdm(features, desc="Clustering Images", unit="image", total=total_features):
    if feature is None:
        continue
    feature, filename = feature
    
    best_cluster = -1
    best_cluster_raport = 0

    for i in range(len(clusters)):
        cluster = clusters[i]
        close = 0
        far = 0

        for image_feature in cluster:
            similarity = F.cosine_similarity(feature, image_feature, dim=1).item()
            
            if similarity < SIMILARITY_THRESHOLD:
                far += 1
            else:
                close += 1

        if close / (close + far) > 0.8 and close / (close + far) > best_cluster_raport:
            best_cluster = i
            best_cluster_raport = close / (close + far)

    if best_cluster == -1:
        clusters.append([feature])
        image_clusters.append([filename])
    else:
        clusters[best_cluster].append(feature)
        image_clusters[best_cluster].append(filename)
        
print(f"Number of clusters: {len(clusters)}")

Clustering Images: 100%|██████████| 3949/3949 [00:47<00:00, 84.01image/s]

Number of clusters: 204
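
The assignment rule used in the loop above can be isolated in a small function and tried on toy numbers. This is a simplified sketch that takes precomputed similarities instead of feature tensors; `pick_cluster` and `MIN_CLOSE_RATIO` are hypothetical names for the logic and for the `0.8` constant.

```python
SIMILARITY_THRESHOLD = 0.7
MIN_CLOSE_RATIO = 0.8  # hypothetical name for the 0.8 constant in the loop above

def pick_cluster(similarities_per_cluster):
    """Return the index of the best cluster, or -1 if a new cluster is needed.

    similarities_per_cluster[i] holds the similarities between the new logo
    and every member of cluster i.
    """
    best_cluster, best_ratio = -1, 0.0
    for i, similarities in enumerate(similarities_per_cluster):
        # A member counts as "close" when its similarity reaches the threshold
        close = sum(1 for s in similarities if s >= SIMILARITY_THRESHOLD)
        ratio = close / len(similarities)
        if ratio > MIN_CLOSE_RATIO and ratio > best_ratio:
            best_cluster, best_ratio = i, ratio
    return best_cluster

# Cluster 0: 2 of 3 members are close (ratio below 0.8, rejected)
# Cluster 1: 5 of 5 members are close (accepted)
print(pick_cluster([[0.9, 0.8, 0.2], [0.75, 0.9, 0.71, 0.8, 0.95]]))

# Only half of the members are close, so a new cluster is needed
print(pick_cluster([[0.9, 0.1]]))
```

Since clusters are built greedily, the result depends on the order in which the logos are processed.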





## Grouping and JSON Generation

- **Grouped all images** based on the clusters formed in the previous step.  

- **Generated a JSON file** mapping domains to their respective clusters.  


In [58]:
import pandas as pd
import json

for i in range(len(image_clusters)):
    os.makedirs(f"./clusters/cluster{i}", exist_ok=True)
    for image_path in image_clusters[i]:
        os.rename(f"./images/{image_path}", f"./clusters/cluster{i}/{image_path}")
        
df = pd.read_csv('data.csv')
# Each filename is "<row index>.<extension>", so its stem maps back to the domain's row
domain_clusters = [[df['domain'][int(y.split('.')[0])] for y in x] for x in image_clusters]
with open('clusters.json', 'w') as f:
    json.dump(domain_clusters, f, indent=4)
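
The mapping from image filenames back to domains can be seen on a toy example. The domain names and cluster contents below are made up for illustration; only the structure matches the cell above.

```python
import json

# Toy stand-ins for the notebook's data (values are illustrative only)
domains = ["a.example", "b.example", "c.example", "d.example"]
image_clusters = [["0.png", "3.jpg"], ["1.png"], ["2.png"]]

# The filename stem is the row index of the domain
domain_clusters = [
    [domains[int(name.split('.')[0])] for name in cluster]
    for cluster in image_clusters
]

print(json.dumps(domain_clusters, indent=4))
```

The resulting JSON is a list of lists, one inner list of domains per cluster.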

In [59]:
# import os
# import shutil

# # Define source and destination folders
# source_folder = "./clusters"  # Folder containing cluster subfolders
# destination_folder = "./images"  # Target folder

# # Ensure destination folder exists
# os.makedirs(destination_folder, exist_ok=True)

# # Loop through all subfolders in 'clusters'
# for subfolder in os.listdir(source_folder):
#     subfolder_path = os.path.join(source_folder, subfolder)

#     # Check if it's a directory
#     if os.path.isdir(subfolder_path):
#         # Move each file inside the subfolder
#         for filename in os.listdir(subfolder_path):
#             file_path = os.path.join(subfolder_path, filename)
            
#             if os.path.isfile(file_path):  # Ensure it's a file
#                 shutil.move(file_path, os.path.join(destination_folder, filename))
                
#         os.rmdir(subfolder_path)

# print("✅ All files moved successfully to ./images!")
