<a href="https://colab.research.google.com/github/amien1410/amien-scrapers/blob/main/Business_Classifier_with_Bart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd
import aiohttp
import asyncio
import logging
from typing import List, Dict, Optional
from transformers import pipeline
from dataclasses import dataclass
import os
import time
from tqdm import tqdm

# Logging setup: obtain this module's logger, then configure the root
# handler to emit INFO-and-above records with timestamp, logger name and level.
logger = logging.getLogger(__name__)
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

@dataclass
class BraveSearchConfig:
    """Configuration for Brave Search API.

    Bundles the credential, endpoint and pacing settings that
    ``BusinessDescriptionFetcher`` reads when issuing search requests.
    """
    # API key; the fetcher sends it as a request header on every call
    api_key: str
    # Web-search endpoint the queries are issued against
    base_url: str = "https://api.search.brave.com/res/v1/web/search"
    # The fetcher waits 1 / rate_limit seconds between requests, so a value
    # of 0 would raise ZeroDivisionError there
    rate_limit: int = 1  # requests per second
    # Upper bound on attempts made for a single business lookup
    max_retries: int = 3

class BusinessDescriptionFetcher:
    """Fetches business descriptions using the Brave Search API.

    One web-search request is issued per business; the description of the
    first search result (if any) is used as the business description.
    Requests are paced according to ``config.rate_limit`` and retried with
    exponential backoff on HTTP 429 and on request errors.
    """

    def __init__(self, config: "BraveSearchConfig"):
        """
        Initialize the fetcher with configuration

        Args:
            config (BraveSearchConfig): Configuration for Brave Search API
        """
        self.config = config
        self.session = None
        # -inf guarantees the very first request is never delayed, whatever
        # the monotonic clock's (arbitrary) starting value happens to be.
        self._last_request_time = float("-inf")

    async def _init_session(self):
        """Initialize aiohttp session if not already created"""
        if self.session is None:
            self.session = aiohttp.ClientSession(
                headers={
                    "Accept": "application/json",
                    # Brave authenticates with the subscription-token header
                    "X-Subscription-Token": self.config.api_key,
                }
            )

    async def _close_session(self):
        """Close aiohttp session"""
        if self.session:
            await self.session.close()
            self.session = None

    async def _rate_limit_delay(self):
        """Sleep as needed so requests are at least 1 / rate_limit seconds apart"""
        min_interval = 1.0 / self.config.rate_limit
        # Monotonic clock: immune to wall-clock adjustments
        elapsed = time.monotonic() - self._last_request_time
        if elapsed < min_interval:
            await asyncio.sleep(min_interval - elapsed)
        self._last_request_time = time.monotonic()

    @staticmethod
    def _extract_description(data) -> Optional[str]:
        """Return the first web result's description from a response payload, or None"""
        if not isinstance(data, dict):
            return None
        web = data.get("web")
        results = web.get("results") if isinstance(web, dict) else None
        if not results:
            return None
        first = results[0]
        return first.get("description") if isinstance(first, dict) else None

    async def _fetch_description(self, business_name: str, country: str) -> Optional[str]:
        """
        Fetch description for a single business

        Args:
            business_name (str): Name of the business
            country (str): Country where the business operates

        Returns:
            Optional[str]: Business description if found, None otherwise
        """
        await self._init_session()

        params = {
            "q": f"{business_name} {country} company description",
            "count": 1
        }
        last_attempt = self.config.max_retries - 1

        for attempt in range(self.config.max_retries):
            # Pace every request, retries included
            await self._rate_limit_delay()
            try:
                async with self.session.get(self.config.base_url, params=params) as response:
                    if response.status == 200:
                        # A successful response is final even when it holds
                        # no results: asking again would only burn quota.
                        return self._extract_description(await response.json())
                    if response.status != 429:
                        logger.error("Error fetching description: %s", response.status)
                        return None
                    # 429 (rate limit exceeded): fall through to the backoff
            except Exception as e:
                # Deliberately broad: one failing lookup must not abort the batch
                logger.error("Error during API request: %s", e)

            # Exponential backoff, skipped when no further attempt follows
            if attempt < last_attempt:
                await asyncio.sleep(2 ** attempt)

        return None

    async def fetch_descriptions(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Fetch descriptions for all businesses in the DataFrame

        Args:
            df (pd.DataFrame): Input DataFrame with 'business_name' and
                'country' columns

        Returns:
            pd.DataFrame: Copy of the input with an added
                'business_description' column (None where nothing was found)
        """
        await self._init_session()

        descriptions = []
        try:
            pairs = zip(df['business_name'], df['country'])
            for business_name, country in tqdm(pairs, total=len(df)):
                descriptions.append(
                    await self._fetch_description(business_name, country)
                )
        finally:
            # Always release the HTTP session, even if a lookup raised
            await self._close_session()

        result_df = df.copy()
        result_df['business_description'] = descriptions
        return result_df

In [None]:
# Root logger: INFO level, records rendered as "time - name - level - message"
logger = logging.getLogger(__name__)
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class BusinessClassifier:
    """Classifies businesses into subcategories using BART-large-MNLI with enhanced accuracy.

    For every candidate subcategory the zero-shot classifier score is blended
    with a keyword-match score derived from ``CATEGORY_DEFINITIONS``; the
    subcategory with the highest blended score is returned.
    """

    CATEGORY_DEFINITIONS = {
        'Consumer Goods': {
            'description': 'Companies that distribute consumer products, retail goods, and personal items',
            'keywords': ['retail', 'consumer', 'products', 'goods', 'personal', 'clothing', 'fashion', 'accessories']
        },
        'Food and Beverage': {
            'description': 'Companies involved in food products, beverages, and related distribution',
            'keywords': ['food', 'beverage', 'drink', 'restaurant', 'catering', 'grocery', 'meal']
        },
        'Healthcare and Pharmaceuticals': {
            'description': 'Companies in medical supplies, pharmaceuticals, and healthcare product distribution',
            'keywords': ['health', 'medical', 'pharmaceutical', 'medicine', 'drug', 'healthcare', 'clinical']
        },
        'Automotive': {
            'description': 'Companies in vehicle parts, automotive supplies, and related distribution',
            'keywords': ['auto', 'car', 'vehicle', 'automotive', 'parts', 'motor', 'transportation']
        },
        'Logistics and Transportation': {
            'description': 'Companies focused on shipping, freight, and transportation services',
            'keywords': ['logistics', 'shipping', 'freight', 'transport', 'delivery', 'supply chain']
        },
        'Technology and Electronics': {
            'description': 'Companies in software, hardware, IT solutions, and electronic equipment',
            'keywords': ['technology', 'software', 'hardware', 'IT', 'computer', 'digital', 'electronic', 'tech', 'solutions']
        },
        'Industrial and Manufacturing': {
            'description': 'Companies in industrial equipment, manufacturing supplies, and heavy machinery',
            'keywords': ['industrial', 'manufacturing', 'machinery', 'equipment', 'production', 'factory']
        },
        'Energy and Utilities': {
            'description': 'Companies in power generation, energy distribution, and utility services',
            'keywords': ['energy', 'power', 'utility', 'electricity', 'renewable', 'solar', 'gas']
        },
        'Hospitality and Services': {
            'description': 'Companies in hospitality, service industry, and related distribution',
            'keywords': ['hospitality', 'service', 'hotel', 'tourism', 'entertainment', 'leisure']
        }
    }

    CATEGORY_MAPPING = {
        'distribution_centre': list(CATEGORY_DEFINITIONS.keys())
    }

    # Blend weights for the final score (70% classifier, 30% keywords)
    CLASSIFIER_WEIGHT = 0.7
    KEYWORD_WEIGHT = 0.3

    def __init__(self):
        """Initialize the classifier with the BART-large-MNLI model"""
        self.classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli",
            device=-1  # CPU
        )

    def _get_subcategories(self, parent_category: str) -> List[str]:
        """Get subcategories for a given parent category (empty list if unknown)"""
        return self.CATEGORY_MAPPING.get(parent_category, [])

    def _calculate_keyword_score(self, description: str, category: str) -> float:
        """
        Calculate a score based on keyword matches

        Keywords are matched as whole words (an optional plural "s"/"es" is
        tolerated), so 'car' no longer fires inside "healthcare". Lower-case
        keywords match case-insensitively; keywords containing capitals
        (e.g. the acronym 'IT') match case-sensitively so they are not
        confused with ordinary words such as "it".

        Args:
            description (str): Business description
            category (str): Category to check against

        Returns:
            float: Fraction of the category's keywords found, between 0 and 1
        """
        import re  # function-scope import: `re` is not imported at file level

        keywords = self.CATEGORY_DEFINITIONS[category]['keywords']
        if not keywords:
            return 0

        matches = 0
        for keyword in keywords:
            flags = re.IGNORECASE if keyword == keyword.lower() else 0
            pattern = rf"\b{re.escape(keyword)}(?:e?s)?\b"
            if re.search(pattern, description, flags):
                matches += 1
        return matches / len(keywords)

    def classify_business(self, description: str, parent_category: str) -> Optional[str]:
        """
        Classify a single business description with improved accuracy

        Args:
            description (str): Business description
            parent_category (str): Parent category of the business

        Returns:
            Optional[str]: Predicted subcategory, or None when the description
                is missing/blank, the parent category has no subcategories,
                or classification fails
        """
        # Missing values coming out of pandas may be None or NaN (a float),
        # so test the type instead of relying on truthiness.
        if not isinstance(description, str) or not description.strip():
            return None

        subcategories = self._get_subcategories(parent_category)
        if not subcategories:
            logger.warning("No subcategories found for parent category: %s", parent_category)
            return None

        try:
            # Get base zero-shot classification
            result = self.classifier(
                description,
                candidate_labels=subcategories,
                hypothesis_template="This business belongs to the {} sector.",
                multi_label=False
            )
            classifier_scores = dict(zip(result['labels'], result['scores']))

            # Blend classifier and keyword evidence per subcategory
            final_scores = {
                category: (classifier_scores.get(category, 0) * self.CLASSIFIER_WEIGHT)
                + (self._calculate_keyword_score(description, category) * self.KEYWORD_WEIGHT)
                for category in subcategories
            }

            # Return category with highest combined score (first one wins ties)
            return max(final_scores, key=final_scores.get)

        except Exception as e:
            # Deliberately broad: one bad row must not abort a batch run
            logger.error("Error during classification: %s", e)
            return None

    def classify_businesses(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Classify all businesses in the DataFrame

        Args:
            df (pd.DataFrame): Input DataFrame with 'business_description'
                and 'business_parent_category' columns

        Returns:
            pd.DataFrame: Copy of the input with an added
                'business_subcategory' column
        """
        result_df = df.copy()
        rows = zip(df['business_description'], df['business_parent_category'])
        result_df['business_subcategory'] = [
            self.classify_business(description, parent_category)
            for description, parent_category in tqdm(rows, total=len(df))
        ]
        return result_df

In [None]:
# load modules
import os
import pandas as pd
from tqdm import tqdm
from typing import List, Dict, Optional
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer

class BusinessClassifier:
    """Classifies businesses into subcategories using BART-large-MNLI with enhanced accuracy.

    The model is cached under ``model_dir`` after the first download. For
    every candidate subcategory the zero-shot classifier score is blended
    with a keyword-match score derived from ``CATEGORY_DEFINITIONS``; the
    subcategory with the highest blended score is returned.
    """

    CATEGORY_DEFINITIONS = {
        'Consumer Goods': {
            'description': 'Companies that distribute consumer products, retail goods, and personal items',
            'keywords': ['retail', 'consumer', 'products', 'goods', 'personal', 'clothing', 'fashion', 'accessories']
        },
        'Food and Beverage': {
            'description': 'Companies involved in food products, beverages, and related distribution',
            'keywords': ['food', 'beverage', 'drink', 'restaurant', 'catering', 'grocery', 'meal']
        },
        'Healthcare and Pharmaceuticals': {
            'description': 'Companies in medical supplies, pharmaceuticals, and healthcare product distribution',
            'keywords': ['health', 'medical', 'pharmaceutical', 'medicine', 'drug', 'healthcare', 'clinical']
        },
        'Automotive': {
            'description': 'Companies in vehicle parts, automotive supplies, and related distribution',
            'keywords': ['auto', 'car', 'vehicle', 'automotive', 'parts', 'motor', 'transportation']
        },
        'Logistics and Transportation': {
            'description': 'Companies focused on shipping, freight, and transportation services',
            'keywords': ['logistics', 'shipping', 'freight', 'transport', 'delivery', 'supply chain']
        },
        'Technology and Electronics': {
            'description': 'Companies in software, hardware, IT solutions, and electronic equipment',
            'keywords': ['technology', 'software', 'hardware', 'IT', 'computer', 'digital', 'electronic', 'tech', 'solutions']
        },
        'Industrial and Manufacturing': {
            'description': 'Companies in industrial equipment, manufacturing supplies, and heavy machinery',
            'keywords': ['industrial', 'manufacturing', 'machinery', 'equipment', 'production', 'factory']
        },
        'Energy and Utilities': {
            'description': 'Companies in power generation, energy distribution, and utility services',
            'keywords': ['energy', 'power', 'utility', 'electricity', 'renewable', 'solar', 'gas']
        },
        'Hospitality and Services': {
            'description': 'Companies in hospitality, service industry, and related distribution',
            'keywords': ['hospitality', 'service', 'hotel', 'tourism', 'entertainment', 'leisure']
        }
    }

    CATEGORY_MAPPING = {
        'distribution_centre': list(CATEGORY_DEFINITIONS.keys())
    }

    # Blend weights for the final score (70% classifier, 30% keywords)
    CLASSIFIER_WEIGHT = 0.7
    KEYWORD_WEIGHT = 0.3

    def __init__(self, model_dir="model"):
        """Initialize the classifier with the BART-large-MNLI model, checking if the model is stored locally.

        Args:
            model_dir (str): Directory under which the model is cached
        """
        self.model_name = "facebook/bart-large-mnli"
        self.model_dir = model_dir
        self.classifier = self.load_model()

    def load_model(self):
        """Load the BART-large-MNLI model from local storage if available, else download and save it.

        The cache counts as present only when both the model config and the
        tokenizer config exist. The tokenizer is saved last, so its config
        file marks a completed save; a directory left behind by an
        interrupted download is therefore refreshed instead of being
        mistaken for a valid cache.

        Returns:
            A zero-shot-classification pipeline running on CPU
        """
        model_path = os.path.join(self.model_dir, self.model_name)
        required_files = ("config.json", "tokenizer_config.json")
        cached = all(
            os.path.isfile(os.path.join(model_path, name)) for name in required_files
        )

        if not cached:
            print(f"Model not found locally in {model_path}. Downloading and saving model...")
            # Download first; the cache directory is only touched once both
            # artefacts are in memory.
            model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            os.makedirs(model_path, exist_ok=True)
            model.save_pretrained(model_path)
            tokenizer.save_pretrained(model_path)
        else:
            print(f"Model found locally in {model_path}. Loading model...")

        return pipeline("zero-shot-classification", model=model_path, tokenizer=model_path, device=-1)

    def _get_subcategories(self, parent_category: str) -> List[str]:
        """Get subcategories for a given parent category (empty list if unknown)"""
        return self.CATEGORY_MAPPING.get(parent_category, [])

    def _calculate_keyword_score(self, description: str, category: str) -> float:
        """
        Calculate a score based on keyword matches

        Keywords are matched as whole words (an optional plural "s"/"es" is
        tolerated), so 'car' no longer fires inside "healthcare". Lower-case
        keywords match case-insensitively; keywords containing capitals
        (e.g. the acronym 'IT') match case-sensitively so they are not
        confused with ordinary words such as "it".

        Args:
            description (str): Business description
            category (str): Category to check against

        Returns:
            float: Fraction of the category's keywords found, between 0 and 1
        """
        import re  # function-scope import: `re` is not imported at file level

        keywords = self.CATEGORY_DEFINITIONS[category]['keywords']
        if not keywords:
            return 0

        matches = 0
        for keyword in keywords:
            flags = re.IGNORECASE if keyword == keyword.lower() else 0
            pattern = rf"\b{re.escape(keyword)}(?:e?s)?\b"
            if re.search(pattern, description, flags):
                matches += 1
        return matches / len(keywords)

    def classify_business(self, description: str, parent_category: str) -> Optional[str]:
        """
        Classify a single business description with improved accuracy

        Args:
            description (str): Business description
            parent_category (str): Parent category of the business

        Returns:
            Optional[str]: Predicted subcategory, or None when the description
                is missing/blank, the parent category has no subcategories,
                or classification fails
        """
        # Missing values coming out of pandas may be None or NaN (a float),
        # so test the type instead of relying on truthiness.
        if not isinstance(description, str) or not description.strip():
            return None

        subcategories = self._get_subcategories(parent_category)
        if not subcategories:
            print(f"Warning: No subcategories found for parent category: {parent_category}")
            return None

        try:
            # Get base zero-shot classification
            result = self.classifier(
                description,
                candidate_labels=subcategories,
                hypothesis_template="This business belongs to the {} sector.",
                multi_label=False
            )
            classifier_scores = dict(zip(result['labels'], result['scores']))

            # Blend classifier and keyword evidence per subcategory
            final_scores = {
                category: (classifier_scores.get(category, 0) * self.CLASSIFIER_WEIGHT)
                + (self._calculate_keyword_score(description, category) * self.KEYWORD_WEIGHT)
                for category in subcategories
            }

            # Return category with highest combined score (first one wins ties)
            return max(final_scores, key=final_scores.get)

        except Exception as e:
            # Deliberately broad: one bad row must not abort a batch run
            print(f"Error during classification: {str(e)}")
            return None

    def classify_businesses(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Classify all businesses in the DataFrame

        Args:
            df (pd.DataFrame): Input DataFrame with 'business_description'
                and 'business_parent_category' columns

        Returns:
            pd.DataFrame: Copy of the input with an added
                'business_subcategory' column
        """
        result_df = df.copy()
        rows = zip(df['business_description'], df['business_parent_category'])
        result_df['business_subcategory'] = [
            self.classify_business(description, parent_category)
            for description, parent_category in tqdm(rows, total=len(df))
        ]
        return result_df

In [None]:
# Demo batch: twenty (business name, description) pairs, all filed under the
# same parent category.
sample_businesses = [
    ('Tech Innovators Inc.', 'A company that develops cutting-edge software solutions for businesses.'),
    ('Green Energy Solutions', 'A company specializing in renewable energy solutions and services.'),
    ('Smart Home Devices Co.', 'A manufacturer of devices for smart homes and automation.'),
    ('Eco-Friendly Packaging', 'A provider of environmentally friendly packaging solutions for businesses.'),
    ('HealthCare United', 'A healthcare services provider for patients and medical professionals.'),
    ('Auto Parts Express', 'A supplier of automotive parts for retail and wholesale markets.'),
    ('Fast Logistics', 'A logistics company providing fast delivery solutions worldwide.'),
    ('Smart Grid Solutions', 'A provider of smart grid technology for energy management.'),
    ('Clean Energy Hub', 'A company focused on clean energy and renewable resources.'),
    ('Hospitality Network', 'A network offering hospitality services to hotels and resorts.'),
    ('Consumer Tech Lab', 'A tech lab focusing on the development of consumer electronics.'),
    ('Industrial Machines Corp.', 'A company manufacturing industrial machines for various sectors.'),
    ('Fresh Foods Ltd.', 'A distributor of fresh and organic food products.'),
    ('Pharma Health Inc.', 'A pharmaceutical company providing healthcare and wellness products.'),
    ('Transport Connect', 'A transportation solutions provider for cargo and passengers.'),
    ('Digital Ware Inc.', 'A digital warehousing solutions provider for e-commerce.'),
    ('Green Manufacturing', 'A company specializing in green manufacturing processes.'),
    ('Luxury Hotel Group', 'A luxury hotel group with properties in major cities worldwide.'),
    ('Smart Electronics Ltd.', 'A provider of smart electronic devices for personal use.'),
    ('Industrial Power Systems', 'A company offering industrial power systems for factories and plants.'),
]

# Column-oriented view of the sample, with uuids numbered from 1
input_data = {
    'uuid': list(range(1, len(sample_businesses) + 1)),
    'business_name': [name for name, _ in sample_businesses],
    'business_description': [text for _, text in sample_businesses],
    'business_parent_category': ['distribution_centre'] * len(sample_businesses),
}

df = pd.DataFrame(input_data)

# Run the classifier over the whole batch; the trailing bare expression makes
# the notebook render the resulting frame.
classifier = BusinessClassifier()
final_df = classifier.classify_businesses(df)
final_df

Model not found locally in model/facebook/bart-large-mnli. Downloading and saving model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Non-default generation parameters: {'forced_eos_token_id': 2}
100%|██████████| 20/20 [01:53<00:00,  5.67s/it]


Unnamed: 0,uuid,business_name,business_description,business_parent_category,business_subcategory
0,1,Tech Innovators Inc.,A company that develops cutting-edge software ...,distribution_centre,Technology and Electronics
1,2,Green Energy Solutions,A company specializing in renewable energy sol...,distribution_centre,Energy and Utilities
2,3,Smart Home Devices Co.,A manufacturer of devices for smart homes and ...,distribution_centre,Consumer Goods
3,4,Eco-Friendly Packaging,A provider of environmentally friendly packagi...,distribution_centre,Industrial and Manufacturing
4,5,HealthCare United,A healthcare services provider for patients an...,distribution_centre,Healthcare and Pharmaceuticals
5,6,Auto Parts Express,A supplier of automotive parts for retail and ...,distribution_centre,Automotive
6,7,Fast Logistics,A logistics company providing fast delivery so...,distribution_centre,Logistics and Transportation
7,8,Smart Grid Solutions,A provider of smart grid technology for energy...,distribution_centre,Energy and Utilities
8,9,Clean Energy Hub,A company focused on clean energy and renewabl...,distribution_centre,Energy and Utilities
9,10,Hospitality Network,A network offering hospitality services to hot...,distribution_centre,Hospitality and Services


In [None]:
import pandas as pd
import requests
import time

class BusinessDescriptionFetcher:
    """Synchronous fetcher of business descriptions via the Brave web-search API.

    Issues one search per business and uses the description of the first
    web result as the business description.
    """

    # Brave web-search endpoint (the bare API root does not serve searches)
    DEFAULT_BASE_URL = "https://api.search.brave.com/res/v1/web/search"

    def __init__(self, api_key, base_url=DEFAULT_BASE_URL, timeout=10):
        """
        Args:
            api_key (str): Brave Search API subscription token
            base_url (str): Search endpoint to query
            timeout (float): Per-request timeout in seconds, so a stalled
                connection cannot hang the whole batch
        """
        self.api_key = api_key
        self.base_url = base_url
        self.timeout = timeout

    @staticmethod
    def _extract_description(result):
        """Return the first web result's description from a response payload, or ""."""
        if not isinstance(result, dict):
            return ""
        web = result.get("web")
        results = web.get("results") if isinstance(web, dict) else None
        if not results:
            return ""
        first = results[0]
        if not isinstance(first, dict):
            return ""
        return first.get("description") or ""

    def fetch_description(self, business_name, country):
        """
        Look up a single business and return its description.

        Args:
            business_name (str): Name of the business
            country (str): Country where the business operates

        Returns:
            str: Description of the first search result, or "" when the
                request fails or yields no usable result
        """
        query = f"{business_name} {country}"
        headers = {
            "Accept": "application/json",
            # Brave authenticates with the subscription-token header
            "X-Subscription-Token": self.api_key,
        }
        params = {"q": query, "count": 1}

        try:
            response = requests.get(
                self.base_url, headers=headers, params=params, timeout=self.timeout
            )
            response.raise_for_status()
            result = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers a non-JSON body on older requests versions
            print(f"Error fetching description for {business_name}: {e}")
            return ""

        return self._extract_description(result)

    def fetch_descriptions(self, df):
        """
        Fill a 'business_description' column for every row of ``df``.

        The frame is modified in place (and also returned). It must provide
        'business_name' and 'country' columns.

        Args:
            df (pd.DataFrame): Businesses to look up

        Returns:
            pd.DataFrame: The same frame, with 'business_description' set
        """
        df["business_description"] = ""
        last_position = len(df) - 1

        for position, (i, row) in enumerate(df.iterrows()):
            description = self.fetch_description(row["business_name"], row["country"])
            df.at[i, "business_description"] = description
            # Rate-limiting between requests; nothing to wait for after the last
            if position < last_position:
                time.sleep(1)  # adjust as per API limits

        return df


In [None]:
# Two-row smoke test for the description fetcher: (uuid, name, country)
sample_rows = [
    (1, 'Tech Innovators Inc.', 'USA'),
    (2, 'Green Energy Solutions', 'Canada'),
]
ids, names, countries = zip(*sample_rows)
data = {
    'uuid': list(ids),
    'business_name': list(names),
    'country': list(countries),
}

input_df = pd.DataFrame(data)

# Placeholder key: replace with a real API key before running
fetcher = BusinessDescriptionFetcher(api_key="YOUR_API_KEY")
output_df = fetcher.fetch_descriptions(input_df)
print(output_df)

In [None]:
import pandas as pd
from transformers import pipeline

class BusinessClassifier:
    """Zero-shot subcategory classifier built on facebook/bart-large-mnli.

    Each description is scored against the subcategories registered for its
    parent category; the best label is kept only if its score reaches
    ``threshold``.
    """

    def __init__(self, threshold=0.5):
        """
        Args:
            threshold (float): Minimum score the best label must reach;
                below it the business is labelled "Other"
        """
        self.threshold = threshold
        self.classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
        self.subcategories = {
            "distribution_centre": [
                "Consumer Goods",
                "Food and Beverage",
                "Healthcare and Pharmaceuticals",
                "Automotive",
                "Logistics and Transportation",
                "Technology and Electronics",
                "Industrial and Manufacturing",
                "Energy and Utilities",
                "Hospitality and Services"
            ]
            # Additional categories and subcategories can be added here
        }

    def classify_business(self, description, parent_category):
        """
        Classify one business description.

        Args:
            description (str): Business description
            parent_category (str): Key into ``self.subcategories``

        Returns:
            str: The best-scoring subcategory; "Other" when its score is
                below the threshold; "Unknown" when the parent category has
                no subcategories or the description is missing/blank
        """
        labels = self.subcategories.get(parent_category, [])
        if not labels:
            return "Unknown"

        # Missing values coming out of pandas may be None or NaN (a float);
        # the pipeline cannot score those, nor an empty string.
        if not isinstance(description, str) or not description.strip():
            return "Unknown"

        # multi_label=True scores every label independently. With the
        # default (scores normalised to sum to 1 across all labels) the top
        # score among nine candidates rarely reaches 0.5, which pushed
        # virtually every business into "Other".
        classification = self.classifier(description, labels, multi_label=True)
        top_label = classification["labels"][0]
        top_score = classification["scores"][0]
        return top_label if top_score >= self.threshold else "Other"

    def classify_businesses(self, df):
        """
        Fill a 'business_subcategory' column for every row of ``df``.

        The frame is modified in place (and also returned). It must provide
        'business_description' and 'business_parent_category' columns.

        Args:
            df (pd.DataFrame): Businesses to classify

        Returns:
            pd.DataFrame: The same frame, with 'business_subcategory' set
        """
        df["business_subcategory"] = ""

        for i, row in df.iterrows():
            df.at[i, "business_subcategory"] = self.classify_business(
                row["business_description"],
                row["business_parent_category"],
            )

        return df


In [None]:
# Two-row smoke test for the classifier: (uuid, name, description)
sample_rows = [
    (1, 'Tech Innovators Inc.',
     'A company that develops cutting-edge software solutions for businesses.'),
    (2, 'Green Energy Solutions',
     'A company specializing in renewable energy solutions and services.'),
]
ids, names, descriptions = zip(*sample_rows)
input_data = {
    'uuid': list(ids),
    'business_name': list(names),
    'business_description': list(descriptions),
    'business_parent_category': ['distribution_centre'] * len(sample_rows),
}

input_df = pd.DataFrame(input_data)

# Classify the sample and show the resulting frame
classifier = BusinessClassifier()
output_df = classifier.classify_businesses(input_df)
print(output_df)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



   uuid           business_name  \
0     1    Tech Innovators Inc.   
1     2  Green Energy Solutions   

                                business_description business_parent_category  \
0  A company that develops cutting-edge software ...      distribution_centre   
1  A company specializing in renewable energy sol...      distribution_centre   

  business_subcategory  
0                Other  
1                Other  
