In [9]:
import requests
import time
import json
import csv
from typing import List, Dict, Tuple
import random

class WikipediaEntityCollector:
    def __init__(self):
        self.base_url = "https://en.wikipedia.org/w/api.php"
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'EntityDatasetCollector/1.0 (https://example.com/contact)'
        })

    def verify_category_exists(self, category_name: str) -> bool:
        """Check if a category exists on Wikipedia"""
        params = {
            'action': 'query',
            'format': 'json',
            'titles': f'Category:{category_name}',
            'prop': 'info'
        }

        try:
            response = self.session.get(self.base_url, params=params, timeout=10)
            data = response.json()

            if 'query' in data and 'pages' in data['query']:
                page_info = list(data['query']['pages'].values())[0]
                return 'missing' not in page_info
        except:
            pass
        return False

    def get_entities_from_category(self, category_name: str, entity_type: str, max_results: int = 2000) -> List[Tuple[str, str]]:
        """Get entity names from a Wikipedia category with improved collection"""

        # First verify the category exists
        if not self.verify_category_exists(category_name):
            print(f"Category '{category_name}' does not exist or is inaccessible")
            return []

        all_entities = []
        continue_token = None
        attempts = 0
        max_attempts = 3
        pages_fetched = 0

        while len(all_entities) < max_results and attempts < max_attempts and pages_fetched < 30:
            # Increase batch size significantly
            batch_size = min(500, max_results - len(all_entities))
            params = {
                'action': 'query',
                'format': 'json',
                'list': 'categorymembers',
                'cmtitle': f'Category:{category_name}',
                'cmlimit': batch_size,
                'cmtype': 'page',
                'cmnamespace': '0'  # Main namespace only
            }

            if continue_token:
                params['cmcontinue'] = continue_token

            try:
                response = self.session.get(self.base_url, params=params, timeout=20)

                if not response.text.strip():
                    print(f"Empty response for {category_name}")
                    break

                try:
                    data = response.json()
                except json.JSONDecodeError as e:
                    print(f"JSON decode error for {category_name}: {e}")
                    attempts += 1
                    time.sleep(2)
                    continue

                if 'query' in data and 'categorymembers' in data['query']:
                    for member in data['query']['categorymembers']:
                        title = member['title']

                        # More lenient filtering - only skip obvious meta pages
                        strict_skip = [
                            'Category:', 'Template:', 'File:', 'Wikipedia:', 
                            'User:', 'Portal:', 'Help:', 'Draft:', 'Talk:'
                        ]
                        if any(title.startswith(keyword) for keyword in strict_skip):
                            continue

                        # Less restrictive content filtering based on entity type
                        if entity_type == 'media_products':
                            # For media products, be very permissive
                            skip_phrases = ['Category:', 'Template:', 'List of lists']
                            if any(phrase in title for phrase in skip_phrases):
                                continue
                        elif entity_type in ['events', 'organizations']:
                            # Only skip obvious list pages for these types
                            if title.startswith('List of ') and 'by' not in title.lower():
                                continue
                        else:
                            # For other types, be more permissive with lists
                            skip_phrases = ['List of lists', 'Index of', 'Outline of']
                            if any(phrase in title for phrase in skip_phrases):
                                continue

                        # Clean the name but keep disambiguation info in parentheses
                        cleaned_title = title.strip()
                        
                        # Skip very short names or obvious junk
                        if len(cleaned_title) < 2 or cleaned_title.isdigit():
                            continue

                        all_entities.append((cleaned_title, entity_type))

                # Check for continuation
                if 'continue' in data and len(all_entities) < max_results:
                    continue_token = data['continue']['cmcontinue']
                    pages_fetched += 1
                else:
                    break

                attempts = 0  # Reset attempts on successful request

            except requests.exceptions.RequestException as e:
                print(f"Request error for {category_name}: {e}")
                attempts += 1
                time.sleep(2)
                continue
            except Exception as e:
                print(f"Unexpected error for {category_name}: {e}")
                attempts += 1
                time.sleep(2)
                continue

            time.sleep(0.1)  # Smaller delay

        # Remove duplicates while preserving entity type
        seen = set()
        unique_entities = []
        for entity, etype in all_entities:
            entity_key = entity.lower().strip()
            if entity_key not in seen and len(entity_key) > 1:
                seen.add(entity_key)
                unique_entities.append((entity, etype))

        return unique_entities[:max_results]

    def get_massive_categories(self) -> Dict[str, List[str]]:
        """Return extensive categories optimized for maximum entity yield"""
        entity_categories = {
            'people': [
                "Living_people",
                "American_actors", "British_actors", "Canadian_actors", "Australian_actors",
                "French_actors", "German_actors", "Italian_actors", "Spanish_actors", "Indian_actors",
                "Japanese_actors", "Chinese_actors", "Mexican_actors", "Brazilian_actors",
                "American_musicians", "British_musicians", "Canadian_musicians", "German_musicians",
                "American_writers", "British_writers", "English_writers", "French_writers",
                "20th-century_American_writers", "21st-century_American_writers",
                "American_scientists", "British_scientists", "German_scientists", "French_scientists",
                "American_businesspeople", "British_businesspeople", "Canadian_businesspeople",
                "American_film_directors", "British_film_directors", "French_film_directors",
                "American_politicians", "British_politicians", "Canadian_politicians", "French_politicians",
                "American_athletes", "Olympic_athletes", "American_baseball_players",
                "American_basketball_players", "American_football_players", "Association_football_players",
                "American_singers", "British_singers", "Pop_singers", "Rock_singers",
                "American_rappers", "Hip_hop_musicians", "Jazz_musicians", "Classical_musicians",
                "Actors_by_nationality", "Musicians_by_nationality", "Writers_by_nationality",
                "Scientists_by_nationality", "Politicians_by_nationality",
                "American_journalists", "British_journalists", "American_lawyers",
                "American_physicians", "21st-century_American_musicians", "Deaths_in_2023",
                "Deaths_in_2022", "Deaths_in_2021", "Births_in_1990", "Births_in_1985"
            ],
            'places': [
                "World_Heritage_Sites", "National_monuments_of_the_United_States",
                "Buildings_and_structures_in_New_York_City", "Buildings_and_structures_in_London",
                "Buildings_and_structures_in_Paris", "Buildings_and_structures_in_Rome",
                "Buildings_and_structures_in_Tokyo", "Buildings_and_structures_in_Berlin",
                "Museums_in_the_United_States", "Museums_in_the_United_Kingdom",
                "Museums_in_France", "Museums_in_Germany", "Museums_in_Italy",
                "Castles_in_England", "Castles_in_France", "Castles_in_Germany", "Castles_in_Scotland",
                "Cities_in_the_United_States", "Cities_in_Texas", "Cities_in_Florida",
                "Cities_in_California", "Cities_in_New_York_(state)", "Cities_in_Pennsylvania",
                "Cities_in_France", "Cities_in_Germany", "Cities_in_Japan", "Cities_in_Italy",
                "Cities_in_China", "Cities_in_India", "Cities_in_Mexico", "Cities_in_Brazil",
                "Cities_in_Spain", "Cities_in_Canada", "Cities_in_Australia", "Cities_in_Russia",
                "Populated_places_in_California", "Populated_places_in_Texas", "Populated_places_in_Florida",
                "Tourist_attractions_in_Paris", "Tourist_attractions_in_London", "Towers", "Bridges",
                "Churches", "Cathedrals", "Archaeological_sites", "National_parks",
                "Skyscrapers", "Monuments_and_memorials", "Historic_sites",
                "Villages_in_England", "Towns_in_England", "Neighborhoods_in_New_York_City",
                "Airports", "Railway_stations", "Universities_and_colleges_in_California"
            ],
            'companies': [
                "Public_companies", "Publicly_traded_companies",
                "Technology_companies_of_the_United_States", "Technology_companies",
                "Software_companies_of_the_United_States", "Software_companies",
                "Pharmaceutical_companies_of_the_United_States", "Pharmaceutical_companies",
                "Automotive_companies_of_the_United_States", "Automotive_companies",
                "Retail_companies_of_the_United_States", "Retail_companies",
                "Food_and_drink_companies_of_the_United_States", "Food_companies",
                "Entertainment_companies_of_the_United_States", "Entertainment_companies",
                "Publishing_companies_of_the_United_States", "Publishing_companies",
                "Telecommunications_companies_of_the_United_States", "Telecommunications_companies",
                "Energy_companies_of_the_United_States", "Energy_companies",
                "Aerospace_companies_of_the_United_States", "Aerospace_companies",
                "Chemical_companies_of_the_United_States", "Chemical_companies",
                "Multinational_companies", "Technology_companies_by_country",
                "Banks_of_the_United_States", "Insurance_companies_of_the_United_States",
                "Manufacturing_companies", "Construction_companies", "Mining_companies",
                "Biotechnology_companies", "Video_game_companies", "Film_production_companies",
                "Television_production_companies", "Clothing_companies", "Toy_companies",
                "Electronics_companies", "Semiconductor_companies", "Internet_companies",
                "E-commerce_companies", "Consulting_firms", "Investment_companies"
            ],
            'media_products': [
                "2024_films", "2023_films", "2022_films", "2021_films", "2020_films",
                "2019_films", "2018_films", "2017_films", "2016_films", "2015_films",
                "American_comedy_films", "American_drama_films", "American_action_films",
                "English-language_films", "Films_based_on_novels", "Animated_films",
                "Documentary_films", "Horror_films", "Science_fiction_films",
                "2024_television_series", "2023_television_series", "2022_television_series",
                "2021_television_series", "2020_television_series", "2019_television_series",
                "American_comedy_television_series", "American_drama_television_series",
                "British_television_programmes", "Television_shows", "Web_series",
                "2024_video_games", "2023_video_games", "2022_video_games", "2021_video_games",
                "2020_video_games", "2019_video_games", "2018_video_games", "2017_video_games",
                "Action_video_games", "Adventure_video_games", "Role-playing_video_games",
                "2024_albums", "2023_albums", "2022_albums", "2021_albums", "2020_albums",
                "Rock_albums", "Pop_albums", "Hip_hop_albums", "Electronic_albums",
                "Novels", "Fiction_books", "Non-fiction_books", "Children's_books",
                "Consumer_electronics", "Video_game_consoles", "Smartphones", "Computers",
                "Software", "Operating_systems", "Web_browsers", "Mobile_apps",
                "Social_networking_services", "Websites", "Magazines", "Newspapers",
                "Podcast_series", "YouTube_channels", "Streaming_television_series"
            ],
            'countries': [
                "Member_states_of_the_United_Nations", "Sovereign_states",
                "Countries_in_Europe", "Countries_in_Asia", "Countries_in_Africa",
                "Countries_in_North_America", "Countries_in_South_America", "Countries_in_Oceania",
                "Island_countries", "Landlocked_countries", "Commonwealth_realms",
                "Former_countries", "Territories", "Dependencies"
            ],
            'organizations': [
                "International_organizations", "Non-governmental_organizations",
                "Universities_and_colleges", "American_universities_and_colleges",
                "Universities_in_the_United_Kingdom", "Universities_in_Canada",
                "Universities_in_Australia", "Universities_in_Germany", "Universities_in_France",
                "Universities_in_Japan", "Universities_in_China", "Universities_in_India",
                "Sports_teams", "Football_clubs", "Basketball_teams", "Baseball_teams",
                "American_football_teams", "Soccer_clubs", "Sports_organizations",
                "Major_League_Baseball_teams", "National_Football_League_teams",
                "National_Basketball_Association_teams", "Premier_League_clubs",
                "Record_labels", "Music_organizations", "Religious_organizations",
                "Political_parties", "Trade_unions", "Professional_associations",
                "Military_units_and_formations", "Hospitals", "Charities",
                "Think_tanks", "Research_institutes", "Cultural_organizations",
                "Museums", "Libraries", "Educational_organizations", "Labor_organizations"
            ],
            'events': [
                "20th-century_conflicts", "21st-century_conflicts", "Wars", "Battles",
                "Natural_disasters", "Earthquakes", "Hurricanes", "Volcanic_eruptions",
                "Olympic_Games", "Summer_Olympic_Games", "Winter_Olympic_Games",
                "Festivals", "Music_festivals", "Film_festivals", "Cultural_festivals",
                "Conferences", "Political_events", "Elections", "Historical_events",
                "Space_missions", "NASA_missions", "Historical_periods",
                "Revolutions", "Treaties", "Agreements", "Summits", "Ceremonies"
            ],
            'products': [  # Enhanced with more specific product categories
                "Brands", "Consumer_brands", "Fashion_brands", "Luxury_brands",
                "Automotive_brands", "Technology_brands", "Food_brands",
                "Toys", "Board_games", "Card_games", "Sports_equipment",
                "Musical_instruments", "Tools", "Appliances", "Furniture",
                "Perfumes", "Cosmetics_brands", "Watch_brands", "Shoe_brands",
                "Aircraft", "Ships", "Trains", "Motorcycles", "Bicycles",
                "Weapons", "Firearms", "Software_products", "Mobile_phones",
                "Laptops", "Tablets", "Gaming_consoles", "Cameras"
            ],
            'animals': [  # New high-yield category
                "Mammals", "Birds", "Reptiles", "Amphibians", "Fish",
                "Insects", "Dog_breeds", "Cat_breeds", "Horse_breeds",
                "Cattle_breeds", "Endangered_species", "Extinct_animals",
                "Marine_mammals", "Primates", "Carnivores", "Herbivores"
            ],
            'foods': [  # New category for food items
                "Dishes", "Foods", "Beverages", "Alcoholic_beverages",
                "Italian_cuisine", "French_cuisine", "Chinese_cuisine", "Japanese_cuisine",
                "Indian_cuisine", "Mexican_cuisine", "American_cuisine",
                "Desserts", "Fruits", "Vegetables", "Cheeses", "Breads"
            ]
        }
        return entity_categories

    def collect_entities_massively(self, target_count: int = 15000) -> List[Tuple[str, str]]:
        """Collect entities with aggressive strategy to reach target count"""
        
        print("Starting massive entity collection...")
        
        # Get extensive categories
        all_categories = self.get_massive_categories()
        
        # Quick verification of top categories only (to save time)
        print("Quickly verifying key categories...")
        working_categories = {}
        
        for entity_type, categories in all_categories.items():
            working_categories[entity_type] = []
            verified_count = 0
            
            # Shuffle categories to get variety
            shuffled_categories = categories.copy()
            random.shuffle(shuffled_categories)
            
            for category in shuffled_categories:
                if verified_count < 8:  # Only verify first 8 categories per type
                    if self.verify_category_exists(category):
                        working_categories[entity_type].append(category)
                        verified_count += 1
                        print(f"✓ {category}")
                    else:
                        print(f"✗ {category}")
                else:
                    # Assume remaining categories exist to speed up process
                    working_categories[entity_type].append(category)
                
                if len(working_categories[entity_type]) >= 20:  # More categories per type
                    break

        # Remove empty entity types
        working_categories = {k: v for k, v in working_categories.items() if v}
        
        print(f"\nCollecting from {len(working_categories)} entity types...")
        
        all_entities = []
        base_entities_per_type = target_count // len(working_categories)
        
        # Adjust targets based on entity type productivity and your previous results
        type_targets = {
            'people': int(base_entities_per_type * 1.8),  # Boost people (was most productive)
            'companies': int(base_entities_per_type * 1.4),  # Boost companies
            'places': int(base_entities_per_type * 1.3),  # Boost places (did well)
            'media_products': int(base_entities_per_type * 1.2),  # Slight boost
            'products': int(base_entities_per_type * 1.1),  # Slight boost
            'organizations': int(base_entities_per_type * 1.2),  # Boost organizations
            'events': base_entities_per_type,
            'countries': min(350, base_entities_per_type),  # Limited by actual country count
            'animals': int(base_entities_per_type * 0.8),  # New category, conservative
            'foods': int(base_entities_per_type * 0.8)   # New category, conservative
        }
        
        for entity_type, categories in working_categories.items():
            target_for_type = type_targets.get(entity_type, base_entities_per_type)
            print(f"\n--- Collecting {entity_type.upper()} (target: {target_for_type}) ---")
            
            type_entities = []
            entities_per_category = max(300, target_for_type // len(categories))
            
            for i, category in enumerate(categories, 1):
                if len(type_entities) >= target_for_type:
                    break
                    
                print(f"[{i}/{len(categories)}] {category}...")
                
                # Collect more entities per category with higher limits
                remaining_needed = target_for_type - len(type_entities)
                collect_count = min(entities_per_category * 3, remaining_needed + 200)
                
                entities = self.get_entities_from_category(category, entity_type, collect_count)
                
                # Filter out duplicates within this type
                existing_names = {e[0].lower() for e in type_entities}
                new_entities = [(name, etype) for name, etype in entities 
                               if name.lower() not in existing_names]
                
                type_entities.extend(new_entities)
                print(f"  → Got {len(entities)} total, {len(new_entities)} new (running total: {len(type_entities)})")
                
                # If we have enough for this type, move on
                if len(type_entities) >= target_for_type:
                    break
            
            all_entities.extend(type_entities)
            print(f"Final {entity_type}: {len(type_entities)} entities")

        # Final deduplication across all types
        print(f"\n--- FINAL DEDUPLICATION ---")
        seen_names = set()
        final_entities = []
        
        for name, entity_type in all_entities:
            name_key = name.lower().strip()
            if name_key not in seen_names and len(name_key) > 1:
                seen_names.add(name_key)
                final_entities.append((name, entity_type))

        print(f"After deduplication: {len(final_entities)} unique entities")
        
        # Show distribution
        type_counts = {}
        for _, entity_type in final_entities:
            type_counts[entity_type] = type_counts.get(entity_type, 0) + 1
        
        print("\nFinal distribution:")
        for entity_type, count in sorted(type_counts.items()):
            print(f"  {entity_type}: {count}")

        return final_entities

    def save_to_csv(self, entities: List[Tuple[str, str]], filename: str = 'wikipedia_entities_10k.csv'):
        """Save collected entities to CSV file"""
        with open(filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['name', 'type'])
            for name, entity_type in entities:
                writer.writerow([name, entity_type])
        print(f"Saved {len(entities)} entities to {filename}")

    def save_summary_report(self, entities: List[Tuple[str, str]], filename: str = 'massive_entity_collection_report.txt'):
        """Generate a summary report of collected entities"""
        type_counts = {}
        examples = {}
        
        for name, entity_type in entities:
            type_counts[entity_type] = type_counts.get(entity_type, 0) + 1
            if entity_type not in examples:
                examples[entity_type] = []
            if len(examples[entity_type]) < 15:
                examples[entity_type].append(name)

        with open(filename, 'w', encoding='utf-8') as f:
            f.write("Wikipedia Massive Entity Collection Report\n")
            f.write("=" * 50 + "\n\n")
            f.write(f"Total entities collected: {len(entities)}\n\n")
            
            f.write("Distribution by type:\n")
            f.write("-" * 30 + "\n")
            for entity_type, count in sorted(type_counts.items()):
                percentage = (count / len(entities)) * 100
                f.write(f"{entity_type.capitalize()}: {count} ({percentage:.1f}%)\n")
            
            f.write("\nExamples by type:\n")
            f.write("-" * 30 + "\n")
            for entity_type, example_list in examples.items():
                f.write(f"\n{entity_type.capitalize()}:\n")
                for example in example_list:
                    f.write(f"  - {example}\n")

        print(f"Summary report saved to {filename}")

# Usage Example
if __name__ == "__main__":
    collector = WikipediaEntityCollector()

    # Collect entities with massive strategy
    print("Starting massive-scale entity collection...")
    entities = collector.collect_entities_massively(target_count=15000)

    # Save results
    if entities:
        # collector.save_to_csv(entities, 'wikipedia_entities_10k_plus.csv')
        # collector.save_summary_report(entities, 'massive_entity_collection_report.txt')
        print(f"\n🎉 Successfully collected {len(entities)} diverse entities!")
        
        if len(entities) >= 10000:
            print(f"✅ SUCCESS: Target achieved! ({len(entities)} >= 10,000)")
        else:
            print(f"⚠️  Progress: {len(entities)} entities collected")
            print("Consider running the script again or adjusting category selection.")
    else:
        print("❌ No entities were collected. Check your internet connection and try again.")

Starting massive-scale entity collection...
Starting massive entity collection...
Quickly verifying key categories...
✓ British_journalists
✗ Olympic_athletes
✓ American_musicians
✓ American_journalists
✓ 21st-century_American_musicians
✓ Jazz_musicians
✓ German_scientists
✓ British_film_directors
✓ French_actors
✓ Cities_in_Italy
✓ Buildings_and_structures_in_Berlin
✓ Skyscrapers
✓ Populated_places_in_Florida
✓ Museums_in_Italy
✓ Neighborhoods_in_New_York_City
✓ Castles_in_Germany
✓ Buildings_and_structures_in_Rome
✓ Multinational_companies
✓ Aerospace_companies
✓ Investment_companies
✓ Video_game_companies
✓ Television_production_companies
✓ Clothing_companies
✓ Retail_companies
✓ Semiconductor_companies
✓ Magazines
✗ Podcast_series
✓ Horror_films
✓ Adventure_video_games
✓ YouTube_channels
✗ Hip_hop_albums
✓ 2021_video_games
✓ Web_series
✓ 2023_video_games
✓ British_television_programmes
✓ Former_countries
✓ Member_states_of_the_United_Nations
✓ Countries_in_North_America
✓ Countries

In [10]:
type(people_names)

list

In [12]:
import requests
import json
import csv
import random
import time
from typing import List, Set
from urllib.parse import quote

class RandomObjectsCollector:
    def __init__(self):
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'ObjectDatasetCollector/1.0'
        })

    def get_wordnet_nouns(self, limit: int = 2000) -> List[str]:
        """Get nouns from Princeton WordNet via NLTK data"""
        try:
            import nltk
            from nltk.corpus import wordnet as wn

            # Download required data
            try:
                nltk.data.find('corpora/wordnet')
            except LookupError:
                print("Downloading WordNet...")
                nltk.download('wordnet')
                nltk.download('omw-1.4')

            # Get all noun synsets
            all_nouns = set()

            for synset in wn.all_synsets('n'):  # 'n' for nouns
                # Get lemma names (the actual words)
                for lemma in synset.lemmas():
                    word = lemma.name().replace('_', ' ')
                    # Filter out very long compounds and proper nouns
                    if 2 <= len(word) <= 20 and not word[0].isupper():
                        all_nouns.add(word.lower())

            noun_list = list(all_nouns)
            random.shuffle(noun_list)
            print(f"Collected {len(noun_list[:limit])} nouns from WordNet")
            return noun_list[:limit]

        except ImportError:
            print("NLTK not available, skipping WordNet collection")
            return []

    def get_common_objects(self) -> List[str]:
        """Get common household and everyday objects"""
        objects = [
            # Household items
            "chair", "table", "lamp", "bed", "sofa", "desk", "mirror", "clock",
            "vase", "cushion", "curtain", "carpet", "bookshelf", "drawer", "closet",
            "wardrobe", "pillow", "blanket", "towel", "sheet", "mattress", "frame",

            # Kitchen items
            "spoon", "fork", "knife", "plate", "bowl", "cup", "mug", "glass",
            "pot", "pan", "kettle", "toaster", "blender", "microwave", "oven",
            "refrigerator", "dishwasher", "sink", "counter", "cabinet", "drawer",

            # Electronics
            "phone", "computer", "laptop", "tablet", "television", "radio",
            "speaker", "headphones", "camera", "printer", "keyboard", "mouse",
            "monitor", "charger", "cable", "remote", "battery", "flashlight",

            # Tools
            "hammer", "screwdriver", "wrench", "pliers", "drill", "saw", "nail",
            "screw", "bolt", "ladder", "rope", "tape", "glue", "scissors", "ruler",

            # Clothing
            "shirt", "pants", "dress", "skirt", "jacket", "coat", "hat", "shoes",
            "socks", "underwear", "belt", "tie", "scarf", "gloves", "boots",

            # Vehicles
            "car", "truck", "bus", "bicycle", "motorcycle", "airplane", "boat",
            "train", "taxi", "van", "helicopter", "ship", "submarine", "rocket",

            # School/Office
            "pen", "pencil", "paper", "book", "notebook", "eraser", "stapler",
            "folder", "envelope", "stamp", "calculator", "ruler", "compass",

            # Sports
            "ball", "bat", "racket", "net", "goal", "helmet", "glove", "shoe"
        ]

        return objects

    def get_animals(self) -> List[str]:
        """Get animal names"""
        animals = [
            "dog", "cat", "bird", "fish", "horse", "cow", "pig", "sheep", "goat",
            "chicken", "duck", "goose", "turkey", "rabbit", "mouse", "rat", "hamster",
            "elephant", "lion", "tiger", "bear", "wolf", "fox", "deer", "moose",
            "zebra", "giraffe", "hippo", "rhino", "monkey", "ape", "gorilla", "chimpanzee",
            "snake", "lizard", "turtle", "frog", "toad", "salamander", "crocodile",
            "shark", "whale", "dolphin", "octopus", "squid", "crab", "lobster",
            "butterfly", "bee", "ant", "spider", "fly", "mosquito", "beetle",
            "eagle", "hawk", "owl", "parrot", "penguin", "flamingo", "peacock"
        ]
        return animals

    def get_food_items(self) -> List[str]:
        """Get food and drink items"""
        foods = [
            # Fruits
            "apple", "banana", "orange", "grape", "strawberry", "blueberry", "cherry",
            "peach", "pear", "plum", "mango", "pineapple", "watermelon", "melon",
            "lemon", "lime", "grapefruit", "kiwi", "papaya", "coconut", "avocado",

            # Vegetables
            "carrot", "potato", "tomato", "onion", "garlic", "pepper", "cucumber",
            "lettuce", "spinach", "broccoli", "cauliflower", "cabbage", "celery",
            "corn", "peas", "beans", "mushroom", "eggplant", "zucchini", "radish",

            # Grains and breads
            "bread", "rice", "pasta", "noodles", "cereal", "oats", "wheat", "flour",
            "bagel", "muffin", "cookie", "cake", "pie", "pizza", "sandwich",

            # Proteins
            "chicken", "beef", "pork", "fish", "salmon", "tuna", "shrimp", "crab",
            "egg", "cheese", "milk", "yogurt", "butter", "cream", "nuts", "beans",

            # Beverages
            "water", "juice", "coffee", "tea", "soda", "beer", "wine", "milk",
            "smoothie", "cocktail", "lemonade", "soup", "broth"
        ]
        return foods

    def get_abstract_concepts(self) -> List[str]:
        """Get abstract concepts and ideas"""
        concepts = [
            "love", "hate", "joy", "sadness", "anger", "fear", "hope", "dream",
            "memory", "thought", "idea", "belief", "faith", "trust", "doubt",
            "freedom", "justice", "peace", "war", "truth", "lie", "beauty", "art",
            "music", "dance", "song", "story", "poem", "play", "game", "sport",
            "work", "job", "career", "business", "money", "wealth", "poverty",
            "health", "disease", "medicine", "cure", "treatment", "therapy",
            "education", "knowledge", "wisdom", "intelligence", "skill", "talent",
            "friendship", "family", "relationship", "marriage", "divorce", "birth",
            "death", "life", "time", "space", "energy", "power", "strength"
        ]
        return concepts

    def get_nature_elements(self) -> List[str]:
        """Get natural elements and phenomena"""
        nature = [
            "tree", "flower", "grass", "leaf", "branch", "root", "seed", "fruit",
            "mountain", "hill", "valley", "river", "lake", "ocean", "sea", "pond",
            "forest", "desert", "beach", "island", "cave", "volcano", "glacier",
            "rock", "stone", "sand", "dirt", "mud", "clay", "crystal", "mineral",
            "sun", "moon", "star", "planet", "galaxy", "universe", "earth", "sky",
            "cloud", "rain", "snow", "ice", "wind", "storm", "thunder", "lightning",
            "fire", "flame", "smoke", "ash", "dust", "mist", "fog", "dew"
        ]
        return nature

    def get_random_adjectives(self) -> List[str]:
        """Get common adjectives that can work as descriptive words"""
        adjectives = [
            "big", "small", "large", "tiny", "huge", "massive", "mini", "giant",
            "hot", "cold", "warm", "cool", "freezing", "boiling", "mild", "extreme",
            "fast", "slow", "quick", "rapid", "speedy", "sluggish", "swift",
            "loud", "quiet", "silent", "noisy", "soft", "gentle", "harsh", "rough",
            "smooth", "bumpy", "sharp", "dull", "bright", "dark", "light", "heavy",
            "old", "new", "fresh", "stale", "young", "ancient", "modern", "vintage",
            "clean", "dirty", "pure", "messy", "neat", "tidy", "organized", "chaotic",
            "happy", "sad", "angry", "calm", "excited", "bored", "tired", "energetic",
            "beautiful", "ugly", "pretty", "handsome", "cute", "adorable", "gorgeous",
            "strong", "weak", "powerful", "fragile", "sturdy", "delicate", "tough"
        ]
        return adjectives

    def get_random_verbs(self) -> List[str]:
        """Get common verbs in base form"""
        verbs = [
            "run", "walk", "jump", "climb", "swim", "fly", "drive", "ride", "travel",
            "eat", "drink", "cook", "bake", "fry", "boil", "mix", "stir", "pour",
            "read", "write", "draw", "paint", "sing", "dance", "play", "perform",
            "work", "study", "learn", "teach", "explain", "understand", "remember",
            "sleep", "wake", "rest", "relax", "exercise", "stretch", "breathe",
            "talk", "speak", "listen", "hear", "see", "watch", "look", "observe",
            "think", "imagine", "dream", "hope", "wish", "want", "need", "have",
            "give", "take", "buy", "sell", "trade", "exchange", "share", "help",
            "build", "create", "make", "fix", "repair", "break", "destroy", "clean"
        ]
        return verbs

    def get_colors_and_materials(self) -> List[str]:
        """Get colors and materials"""
        items = [
            # Colors
            "red", "blue", "green", "yellow", "orange", "purple", "pink", "brown",
            "black", "white", "gray", "silver", "gold", "bronze", "copper",

            # Materials
            "wood", "metal", "plastic", "glass", "paper", "cloth", "fabric", "leather",
            "rubber", "concrete", "brick", "stone", "marble", "granite", "steel",
            "aluminum", "iron", "bronze", "copper", "silver", "gold", "diamond",
            "cotton", "silk", "wool", "linen", "denim", "velvet", "satin", "canvas"
        ]
        return items

    def collect_all_objects(self, target_count: int = 6615) -> List[str]:
        """Collect objects from all sources"""
        print("Collecting random objects and words...")

        all_objects = []

        # Get from different sources
        print("- Collecting common objects...")
        all_objects.extend(self.get_common_objects())

        print("- Collecting animals...")
        all_objects.extend(self.get_animals())

        print("- Collecting food items...")
        all_objects.extend(self.get_food_items())

        print("- Collecting abstract concepts...")
        all_objects.extend(self.get_abstract_concepts())

        print("- Collecting nature elements...")
        all_objects.extend(self.get_nature_elements())

        print("- Collecting adjectives...")
        all_objects.extend(self.get_random_adjectives())

        print("- Collecting verbs...")
        all_objects.extend(self.get_random_verbs())

        print("- Collecting colors and materials...")
        all_objects.extend(self.get_colors_and_materials())

        # Try to get WordNet nouns if available
        wordnet_nouns = self.get_wordnet_nouns(5000)
        all_objects.extend(wordnet_nouns)

        # Remove duplicates and clean
        unique_objects = list(set([obj.lower().strip() for obj in all_objects if obj.strip()]))

        # Shuffle for randomness
        random.shuffle(unique_objects)

        # If we need more, generate some compound words
        if len(unique_objects) < target_count:
            print("- Generating compound words...")
            compound_words = self.generate_compound_words(target_count - len(unique_objects))
            unique_objects.extend(compound_words)

        # Trim to target count
        final_objects = unique_objects[:target_count]

        print(f"Collected {len(final_objects)} unique objects/words")
        return final_objects

    def generate_compound_words(self, count: int) -> List[str]:
        """Generate compound words by combining adjectives with nouns"""
        adjectives = ["red", "blue", "big", "small", "hot", "cold", "old", "new",
                     "fast", "slow", "bright", "dark", "heavy", "light", "soft", "hard"]
        nouns = ["box", "ball", "book", "car", "house", "tree", "stone", "door",
                "window", "chair", "table", "bag", "cup", "pen", "phone", "watch"]

        compounds = []
        for i in range(count):
            adj = random.choice(adjectives)
            noun = random.choice(nouns)
            compounds.append(f"{adj} {noun}")

        return compounds

    def save_to_csv(self, objects: List[str], filename: str = 'random_objects.csv'):
        """Save objects to CSV file"""
        with open(filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['name', 'type'])
            for obj in objects:
                writer.writerow([obj, 'object'])
        print(f"Saved {len(objects)} objects to {filename}")

    def load_from_csv(self, filename: str) -> List[str]:
        """Load objects from CSV file"""
        objects = []
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                reader = csv.reader(f)
                next(reader)  # Skip header
                for row in reader:
                    if row:
                        objects.append(row[0])
        except FileNotFoundError:
            print(f"File {filename} not found")
        return objects

# Usage Example
if __name__ == "__main__":
    collector = RandomObjectsCollector()

    # Collect 10k random objects/words
    random_objects = collector.collect_all_objects(target_count=12500)



    # Show some examples
    print(f"\nFirst 20 objects: {random_objects[:20]}")
    print(f"Random sample: {random.sample(random_objects, 10)}")

    print(f"\nSuccessfully collected {len(random_objects)} random objects!")

Collecting random objects and words...
- Collecting common objects...
- Collecting animals...
- Collecting food items...
- Collecting abstract concepts...
- Collecting nature elements...
- Collecting adjectives...
- Collecting verbs...
- Collecting colors and materials...
Downloading WordNet...


[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /usr/share/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


Collected 5000 nouns from WordNet
- Generating compound words...
Collected 12500 unique objects/words

First 20 objects: ['phenotype', 'fossa cat', 'family pteridaceae', 'sextillion', 'hot', "hart's-tongue fern", 'beta decay', 'genus giardia', 'offal', 'shegetz', 'patch pocket', 'facing pages', 'disadvantage', 'guava bush', 'celery', 'poem', 'fish', 'credit line', 'audit', 'funicular']
Random sample: ['goby', 'fast phone', 'cusk', 'big cup', 'small pen', 'fast window', 'blue bag', 'excreting', 'small box', 'kippered herring']

Successfully collected 12500 random objects!


In [13]:
k = min(len(people_names), len(random_objects))
random.seed(42)
random.shuffle(people_names)
random.shuffle(random_objects)

all_words = people_names[:k] + random_objects[:k]       # 2k total
is_person = [1]*k + [0]*k                               # ground-truth for eval

# ------------------------------------------------
# 3 · word-level train / test split  (no labels used in training)
# ------------------------------------------------
words_tr, words_te, y_tr, y_te = train_test_split(
    all_words, is_person, test_size=0.4, shuffle=True,
    stratify=is_person, random_state=42)

# ------------------------------------------------
# 4 · build UNSUPERVISED contrast pairs
#     each word gets (affirmative, negative)
# ------------------------------------------------
def make_pair(w):
    return (f"{w} is a famous entity",        # x⁺
            f"{w} is not a famous entity")    # x⁻

x0_tr_sent, x1_tr_sent = zip(*(make_pair(w) for w in words_tr))
x0_te_sent, x1_te_sent = zip(*(make_pair(w) for w in words_te))

# ------------------------------------------------
# 5 · embed sentences  (local SBERT or API)
# ------------------------------------------------
from sentence_transformers import SentenceTransformer
device = "cuda" if torch.cuda.is_available() else "cpu"
model  = SentenceTransformer("BAAI/bge-base-en-v1.5", device=device)

def embed(txts, batch=128, desc="embed"):
    return model.encode(txts, batch_size=batch,
                        convert_to_numpy=True,
                        show_progress_bar=True)

x0_train = embed(x0_tr_sent, desc="train x⁺")
x1_train = embed(x1_tr_sent, desc="train x⁻")
x0_test  = embed(x0_te_sent, desc="test  x⁺")
x1_test  = embed(x1_te_sent, desc="test  x⁻")

print("Embeddings ready → shapes",
      x0_train.shape, x1_train.shape, x0_test.shape, x1_test.shape)

# save the auxiliary test labels for pairwise accuracy later
test_words   = list(words_te)      # words in the same order as x0_te / x1_te
test_is_person = y_te             # 1 for names, 0 for objects

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/94.6k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/777 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/366 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/109 [00:00<?, ?it/s]

Batches:   0%|          | 0/109 [00:00<?, ?it/s]

Batches:   0%|          | 0/73 [00:00<?, ?it/s]

Batches:   0%|          | 0/73 [00:00<?, ?it/s]

Embeddings ready → shapes (13914, 768) (13914, 768) (9276, 768) (9276, 768)


In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [15]:
class MLPProbe(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.linear1 = nn.Linear(d, 100)
        self.linear2 = nn.Linear(100, 1)

    def forward(self, x):
        h = F.relu(self.linear1(x))
        o = self.linear2(h)
        return torch.sigmoid(o)

class CCS(object):
    def __init__(self, x0, x1, nepochs=1000, ntries=10, lr=1e-3, batch_size=-1,
                 verbose=False, device="cuda", linear=True, weight_decay=0.01, var_normalize=False):
        # data
        self.var_normalize = var_normalize
        self.x0 = self.normalize(x0)
        self.x1 = self.normalize(x1)
        self.d = self.x0.shape[-1]

        # training
        self.nepochs = nepochs
        self.ntries = ntries
        self.lr = lr
        self.verbose = verbose
        self.device = device
        self.batch_size = batch_size
        self.weight_decay = weight_decay

        # probe
        self.linear = linear
        self.initialize_probe()
        self.best_probe = copy.deepcopy(self.probe)


    def initialize_probe(self):
        if self.linear:
            self.probe = nn.Sequential(nn.Linear(self.d, 1), nn.Sigmoid())
        else:
            self.probe = MLPProbe(self.d)
        self.probe.to(self.device)


    def normalize(self, x):
        """
        Mean-normalizes the data x (of shape (n, d))
        If self.var_normalize, also divides by the standard deviation
        """
        normalized_x = x - x.mean(axis=0, keepdims=True)
        if self.var_normalize:
            normalized_x /= normalized_x.std(axis=0, keepdims=True)

        return normalized_x


    def get_tensor_data(self):
        """
        Returns x0, x1 as appropriate tensors (rather than np arrays)
        """
        x0 = torch.tensor(self.x0, dtype=torch.float, requires_grad=False, device=self.device)
        x1 = torch.tensor(self.x1, dtype=torch.float, requires_grad=False, device=self.device)
        return x0, x1


    def get_loss(self, p0, p1):
        """
        Returns the CCS loss for two probabilities each of shape (n,1) or (n,)
        """
        informative_loss = (torch.min(p0, p1)**2).mean(0)
        consistent_loss = ((p0 - (1-p1))**2).mean(0)
        return informative_loss + consistent_loss


    def get_acc(self, x0_test, x1_test, y_test):
        """
        Computes accuracy for the current parameters on the given test inputs
        """
        x0 = torch.tensor(self.normalize(x0_test), dtype=torch.float, requires_grad=False, device=self.device)
        x1 = torch.tensor(self.normalize(x1_test), dtype=torch.float, requires_grad=False, device=self.device)
        with torch.no_grad():
            p0, p1 = self.best_probe(x0), self.best_probe(x1)
        avg_confidence = 0.5*(p0 + (1-p1))
        predictions = (avg_confidence.detach().cpu().numpy() < 0.5).astype(int)[:, 0]
        acc = (predictions == y_test).mean()
        acc = max(acc, 1 - acc)

        return acc


    def train(self):
        """
        Does a single training run of nepochs epochs
        """
        x0, x1 = self.get_tensor_data()
        permutation = torch.randperm(len(x0))
        x0, x1 = x0[permutation], x1[permutation]

        # set up optimizer
        optimizer = torch.optim.AdamW(self.probe.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        batch_size = len(x0) if self.batch_size == -1 else self.batch_size
        nbatches = len(x0) // batch_size

        # Start training (full batch)
        for epoch in range(self.nepochs):
            for j in range(nbatches):
                x0_batch = x0[j*batch_size:(j+1)*batch_size]
                x1_batch = x1[j*batch_size:(j+1)*batch_size]

                # probe
                p0, p1 = self.probe(x0_batch), self.probe(x1_batch)

                # get the corresponding loss
                loss = self.get_loss(p0, p1)

                # update the parameters
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        return loss.detach().cpu().item()

    def repeated_train(self):
        best_loss = np.inf
        for train_num in range(self.ntries):
            self.initialize_probe()
            loss = self.train()
            if loss < best_loss:
                self.best_probe = copy.deepcopy(self.probe)
                best_loss = loss

        return best_loss

In [16]:
import copy
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import pandas as pd

In [17]:
device = "cuda" if torch.cuda.is_available() else "cpu"
ccs = CCS(
    x0=x0_train,
    x1=x1_train,
    nepochs=300,
    ntries=3,
    lr=5e-4,
    batch_size=-1,     # full-batch for speed
    device=device,
    linear=True,
    var_normalize=True
)
ccs.repeated_train()           # ← trains and stores best probe


0.13844095170497894

In [18]:
with torch.no_grad():
    # probe confidences for the two sentences of every test pair
    z0 = torch.tensor(ccs.normalize(x0_test), device=device, dtype=torch.float)
    z1 = torch.tensor(ccs.normalize(x1_test), device=device, dtype=torch.float)

    p0 = ccs.best_probe(z0).cpu().numpy().ravel()   # prob “person” for x⁺
    p1 = ccs.best_probe(z1).cpu().numpy().ravel()   # prob “person” for x⁻

# probe’s raw decision for each word: 1 if it trusts the affirmative sentence
pred_raw = (p0 > p1).astype(int)          # shape (n,)

# ground-truth labels for the same word order
gold = np.array(test_is_person)           # 1 = name, 0 = object

# evaluate **both** global orientations
acc_pos = (pred_raw == gold).mean()       # assume affirmative ⇒ “person”
acc_neg = (1 - pred_raw == gold).mean()   # assume negative    ⇒ “person”

pair_acc = max(acc_pos, acc_neg)
print(f"pairwise accuracy : {pair_acc*100:.2f}%  " +
      f"(picked {'affirmative' if acc_pos>=acc_neg else 'negative'} as true)")


pairwise accuracy : 82.85%  (picked affirmative as true)


In [20]:
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report

# ── 1 · evaluate both orientations ───────────────────────────
pred_raw = (p0 > p1).astype(int)          # 1 if probe trusts affirmative
acc_pos  = (pred_raw       == gold).mean()
acc_neg  = ((1 - pred_raw) == gold).mean()

# choose the better polarity
flip        = 0 if acc_pos >= acc_neg else 1       # 0 → affirmative = person
pair_acc    = max(acc_pos, acc_neg)
pred_final  = pred_raw ^ flip                      # XOR applies the flip

print(f"\033[1mPairwise accuracy\033[0m : {pair_acc*100:.2f}%  "
      f"(affirmative sentence interpreted as "
      f"{'person' if flip==0 else 'not-person'})")

# ── 2 · confusion matrix & report ────────────────────────────
cm = confusion_matrix(gold, pred_final, labels=[1,0])
print("\nConfusion matrix [[TP, FN], [FP, TN]]:\n", cm)

print("\n" + classification_report(gold, pred_final,
                                   target_names=["object","person"],
                                   digits=3))

# ── 3 · print 20 sample predictions ──────────────────────────
print("\nSample predictions:")
for w, g, pr, paff, pneg in zip(test_words[:100], gold[:100],
                                pred_final[:100], p0[:100], p1[:100]):
    ok = "✅" if g==pr else "❌"
    lbl = "person" if pr else "object"
    print(f"{ok} {w:<25} → {lbl:<7}  "
          f"(p_aff={paff:.3f} | p_neg={pneg:.3f})")


[1mPairwise accuracy[0m : 82.85%  (affirmative sentence interpreted as person)

Confusion matrix [[TP, FN], [FP, TN]]:
 [[3688  950]
 [ 641 3997]]

              precision    recall  f1-score   support

      object      0.808     0.862     0.834      4638
      person      0.852     0.795     0.823      4638

    accuracy                          0.828      9276
   macro avg      0.830     0.828     0.828      9276
weighted avg      0.830     0.828     0.828      9276


Sample predictions:
✅ slop pail                 → object   (p_aff=0.436 | p_neg=0.566)
✅ Vernon Oswald Marquez     → person   (p_aff=0.719 | p_neg=0.060)
✅ old table                 → object   (p_aff=0.215 | p_neg=0.621)
✅ Nebelung                  → person   (p_aff=0.473 | p_neg=0.292)
✅ Invicta (company)         → person   (p_aff=0.682 | p_neg=0.167)
✅ South Daytona, Florida    → person   (p_aff=0.532 | p_neg=0.182)
✅ Fortnum & Mason           → person   (p_aff=0.694 | p_neg=0.223)
✅ hot cup                   → obj