In [1]:
import requests
import time
import json
import csv
from typing import List, Dict

class WikipediaPeopleCollector:
    """Collect people names from English Wikipedia categories via the MediaWiki API.

    All HTTP traffic goes through a single ``requests.Session`` that carries a
    descriptive User-Agent header.
    """

    # Substrings that mark a category member as something other than a single
    # person's article (list pages, prefixed non-article titles).
    _SKIP_KEYWORDS = ('List of', 'Category:', 'Template:', 'File:', 'Wikipedia:', 'User:')

    def __init__(self):
        self.base_url = "https://en.wikipedia.org/w/api.php"
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'PeopleDatasetCollector/1.0 (https://example.com/contact)'
        })

    def verify_category_exists(self, category_name: str) -> bool:
        """Return True when ``Category:<category_name>`` exists on Wikipedia.

        Network failures and non-JSON responses are reported on stdout and
        treated as "does not exist" (False) rather than raised.
        """
        params = {
            'action': 'query',
            'format': 'json',
            'titles': f'Category:{category_name}',
            'prop': 'info'
        }

        try:
            response = self.session.get(self.base_url, params=params, timeout=10)
            data = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # Only transport/decoding problems are swallowed; KeyboardInterrupt
            # and programming errors still propagate.
            print(f"Could not verify category '{category_name}': {e}")
            return False

        pages = data.get('query', {}).get('pages', {}) if isinstance(data, dict) else {}
        if not pages:
            return False

        page_info = next(iter(pages.values()))
        # The API flags non-existent pages with 'missing' and malformed titles
        # with 'invalid'; neither is a usable category.
        return 'missing' not in page_info and 'invalid' not in page_info

    def get_people_from_category(self, category_name: str, max_results: int = 1000) -> List[str]:
        """Return up to ``max_results`` unique page titles from a category.

        Titles containing any of ``_SKIP_KEYWORDS`` are dropped and a trailing
        parenthesised disambiguation ("Name (actor)") is stripped.  Names are
        returned in the order the API listed them, so the result is
        deterministic for a given API response.  Failed requests are retried;
        the method gives up after three failures in total.
        """

        # First verify the category exists
        if not self.verify_category_exists(category_name):
            print(f"Category '{category_name}' does not exist or is inaccessible")
            return []

        # dict keys serve as an insertion-ordered set: duplicates are dropped as
        # they arrive and therefore never count toward max_results.
        names: Dict[str, None] = {}
        continue_token = None
        attempts = 0
        max_attempts = 3

        while len(names) < max_results and attempts < max_attempts:
            params = {
                'action': 'query',
                'format': 'json',
                'list': 'categorymembers',
                'cmtitle': f'Category:{category_name}',
                'cmlimit': min(500, max_results - len(names)),
                'cmtype': 'page',
                'cmnamespace': '0'  # Main namespace only
            }

            if continue_token:
                params['cmcontinue'] = continue_token

            try:
                response = self.session.get(self.base_url, params=params, timeout=15)

                # Check if response is empty
                if not response.text.strip():
                    print(f"Empty response for {category_name}")
                    break

                # Check if response is valid JSON (JSON decode errors are ValueErrors)
                try:
                    data = response.json()
                except ValueError as e:
                    print(f"JSON decode error for {category_name}: {e}")
                    print(f"Response text: {response.text[:200]}...")
                    attempts += 1
                    time.sleep(2)
                    continue

                if 'query' in data and 'categorymembers' in data['query']:
                    for member in data['query']['categorymembers']:
                        title = member['title']

                        # Skip non-person pages
                        if any(keyword in title for keyword in self._SKIP_KEYWORDS):
                            continue

                        # Clean the name (remove disambiguation)
                        if '(' in title:
                            title = title.split('(')[0].strip()

                        if title and len(title) > 1:
                            names.setdefault(title, None)

                # Check for continuation; a 'continue' block without a
                # 'cmcontinue' token means there is nothing more to fetch.
                next_token = data['continue'].get('cmcontinue') if 'continue' in data else None
                if next_token and len(names) < max_results:
                    continue_token = next_token
                else:
                    break

            except requests.exceptions.RequestException as e:
                print(f"Request error for {category_name}: {e}")
                attempts += 1
                time.sleep(2)
                continue
            except Exception as e:
                print(f"Unexpected error for {category_name}: {e}")
                attempts += 1
                time.sleep(2)
                continue

            time.sleep(1)  # Rate limiting - increased delay

        return list(names)[:max_results]

    def get_working_categories(self) -> List[str]:
        """Return the candidate categories that the API confirms exist.

        Each candidate is checked with ``verify_category_exists`` and the
        outcome is printed; a short pause separates consecutive checks.
        """
        potential_categories = [
            # Candidate category names; each one is checked against the API
            # below because some of them do not exist.
            "Living_people",
            "American_film_actors",
            "American_television_actors",
            "British_male_film_actors",
            "British_female_film_actors",
            "American_male_singers",
            "American_female_singers",
            "British_male_singers",
            "British_female_singers",
            "Members_of_the_United_States_House_of_Representatives",
            "United_States_senators",
            "Prime_Ministers_of_the_United_Kingdom",
            "20th-century_American_writers",
            "21st-century_American_writers",
            "English_writers",
            "American_Nobel_Prize_laureates",
            "British_Nobel_Prize_laureates",
            "Olympic_gold_medalists_for_the_United_States",
            "American_basketball_players",
            "English_footballers",
            "American_baseball_players",
            "Tennis_players_from_the_United_States",
            "American_physicists",
            "American_biologists",
            "British_scientists",
            "American_entrepreneurs",
            "Chief_executive_officers",
            "American_journalists",
            "British_journalists"
        ]

        working_categories = []
        print("Verifying categories...")

        for category in potential_categories:
            if self.verify_category_exists(category):
                working_categories.append(category)
                print(f"✓ {category}")
            else:
                print(f"✗ {category} (not found)")
            time.sleep(0.5)  # Small delay between checks

        return working_categories

    def collect_diverse_people(self, target_count: int = 25000) -> List[str]:
        """Collect unique people names spread across the working categories.

        ``target_count`` is split evenly between categories (at least 100 per
        category).  The result is de-duplicated while keeping first-seen order,
        so a later seeded shuffle of the list is reproducible.
        """

        working_categories = self.get_working_categories()

        if not working_categories:
            print("No working categories found!")
            return []

        per_category = max(100, target_count // len(working_categories))
        collected: Dict[str, None] = {}  # insertion-ordered set of names

        print(f"\nCollecting {per_category} people from each of {len(working_categories)} categories...")

        for i, category in enumerate(working_categories, 1):
            print(f"[{i}/{len(working_categories)}] Collecting from {category}...")
            people = self.get_people_from_category(category, per_category)
            collected.update(dict.fromkeys(people))
            print(f"Got {len(people)} people from {category}")

            # Progress update
            if i % 5 == 0:
                print(f"Progress: {len(collected)} unique people collected so far")

        unique_people = list(collected)
        print(f"\nFinal result: {len(unique_people)} unique people collected")
        return unique_people

    def save_to_csv(self, people_names: List[str], filename: str = 'wikipedia_people.csv'):
        """Write the names to ``filename`` as CSV rows ``name,type`` with type 'person'."""
        with open(filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['name', 'type'])
            writer.writerows([name, 'person'] for name in people_names)
        print(f"Saved {len(people_names)} names to {filename}")

# Example run: harvest names from Wikipedia, then persist them when any came back.
if __name__ == "__main__":
    collector = WikipediaPeopleCollector()
    people_names = collector.collect_diverse_people(target_count=20000)

    if not people_names:
        # Nothing collected: report and skip writing a file.
        print("No people names were collected. Check your internet connection and try again.")
    else:
        collector.save_to_csv(people_names, 'wikipedia_people_fixed.csv')
        print(f"Successfully collected {len(people_names)} unique people names!")

Verifying categories...
✓ Living_people
✓ American_film_actors
✓ American_television_actors
✓ British_male_film_actors
✗ British_female_film_actors (not found)
✓ American_male_singers
✗ American_female_singers (not found)
✓ British_male_singers
✗ British_female_singers (not found)
✓ Members_of_the_United_States_House_of_Representatives
✓ United_States_senators
✗ Prime_Ministers_of_the_United_Kingdom (not found)
✓ 20th-century_American_writers
✓ 21st-century_American_writers
✓ English_writers
✗ American_Nobel_Prize_laureates (not found)
✗ British_Nobel_Prize_laureates (not found)
✓ Olympic_gold_medalists_for_the_United_States
✓ American_basketball_players
✓ English_footballers
✓ American_baseball_players
✗ Tennis_players_from_the_United_States (not found)
✓ American_physicists
✓ American_biologists
✓ British_scientists
✓ American_entrepreneurs
✓ Chief_executive_officers
✓ American_journalists
✓ British_journalists

Collecting 909 people from each of 22 categories...
[1/22] Collecting fr

In [2]:
# Sanity check: show the container type returned by collect_diverse_people
type(people_names)

list

In [5]:
import requests
import json
import csv
import random
import time
from typing import List, Set
from urllib.parse import quote

class RandomObjectsCollector:
    """Assemble a list of non-person words (objects, animals, foods, concepts, ...).

    Words come from hand-written lists, optionally from WordNet (through NLTK),
    and, when more entries are needed, from generated "adjective noun" pairs.
    """

    def __init__(self):
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'ObjectDatasetCollector/1.0'
        })

    def get_wordnet_nouns(self, limit: int = 2000) -> List[str]:
        """Return up to ``limit`` lower-cased WordNet noun lemmas in shuffled order.

        Lemmas shorter than 2 or longer than 20 characters, and lemmas that
        start with an upper-case letter, are skipped.  Returns an empty list
        when NLTK is not installed or the WordNet corpus cannot be loaded.
        """
        try:
            import nltk
            from nltk.corpus import wordnet as wn
        except ImportError:
            print("NLTK not available, skipping WordNet collection")
            return []

        # Download required data
        try:
            nltk.data.find('corpora/wordnet')
        except LookupError:
            print("Downloading WordNet...")
            nltk.download('wordnet')
            nltk.download('omw-1.4')

        # Get all noun synsets
        all_nouns = set()
        try:
            for synset in wn.all_synsets('n'):  # 'n' for nouns
                # Get lemma names (the actual words)
                for lemma in synset.lemmas():
                    word = lemma.name().replace('_', ' ')
                    # Filter out very long compounds and proper nouns
                    if 2 <= len(word) <= 20 and not word[0].isupper():
                        all_nouns.add(word.lower())
        except LookupError as e:
            # The corpus is still missing (e.g. the download failed offline).
            print(f"WordNet data unavailable, skipping WordNet collection: {e}")
            return []

        # Sort before shuffling: set iteration order depends on string hashing,
        # so shuffling the raw set would not be reproducible under random.seed().
        noun_list = sorted(all_nouns)
        random.shuffle(noun_list)
        selected = noun_list[:limit]
        print(f"Collected {len(selected)} nouns from WordNet")
        return selected

    def get_common_objects(self) -> List[str]:
        """Get common household and everyday objects"""
        objects = [
            # Household items
            "chair", "table", "lamp", "bed", "sofa", "desk", "mirror", "clock",
            "vase", "cushion", "curtain", "carpet", "bookshelf", "drawer", "closet",
            "wardrobe", "pillow", "blanket", "towel", "sheet", "mattress", "frame",

            # Kitchen items
            "spoon", "fork", "knife", "plate", "bowl", "cup", "mug", "glass",
            "pot", "pan", "kettle", "toaster", "blender", "microwave", "oven",
            "refrigerator", "dishwasher", "sink", "counter", "cabinet", "drawer",

            # Electronics
            "phone", "computer", "laptop", "tablet", "television", "radio",
            "speaker", "headphones", "camera", "printer", "keyboard", "mouse",
            "monitor", "charger", "cable", "remote", "battery", "flashlight",

            # Tools
            "hammer", "screwdriver", "wrench", "pliers", "drill", "saw", "nail",
            "screw", "bolt", "ladder", "rope", "tape", "glue", "scissors", "ruler",

            # Clothing
            "shirt", "pants", "dress", "skirt", "jacket", "coat", "hat", "shoes",
            "socks", "underwear", "belt", "tie", "scarf", "gloves", "boots",

            # Vehicles
            "car", "truck", "bus", "bicycle", "motorcycle", "airplane", "boat",
            "train", "taxi", "van", "helicopter", "ship", "submarine", "rocket",

            # School/Office
            "pen", "pencil", "paper", "book", "notebook", "eraser", "stapler",
            "folder", "envelope", "stamp", "calculator", "ruler", "compass",

            # Sports
            "ball", "bat", "racket", "net", "goal", "helmet", "glove", "shoe"
        ]

        return objects

    def get_animals(self) -> List[str]:
        """Get animal names"""
        animals = [
            "dog", "cat", "bird", "fish", "horse", "cow", "pig", "sheep", "goat",
            "chicken", "duck", "goose", "turkey", "rabbit", "mouse", "rat", "hamster",
            "elephant", "lion", "tiger", "bear", "wolf", "fox", "deer", "moose",
            "zebra", "giraffe", "hippo", "rhino", "monkey", "ape", "gorilla", "chimpanzee",
            "snake", "lizard", "turtle", "frog", "toad", "salamander", "crocodile",
            "shark", "whale", "dolphin", "octopus", "squid", "crab", "lobster",
            "butterfly", "bee", "ant", "spider", "fly", "mosquito", "beetle",
            "eagle", "hawk", "owl", "parrot", "penguin", "flamingo", "peacock"
        ]
        return animals

    def get_food_items(self) -> List[str]:
        """Get food and drink items"""
        foods = [
            # Fruits
            "apple", "banana", "orange", "grape", "strawberry", "blueberry", "cherry",
            "peach", "pear", "plum", "mango", "pineapple", "watermelon", "melon",
            "lemon", "lime", "grapefruit", "kiwi", "papaya", "coconut", "avocado",

            # Vegetables
            "carrot", "potato", "tomato", "onion", "garlic", "pepper", "cucumber",
            "lettuce", "spinach", "broccoli", "cauliflower", "cabbage", "celery",
            "corn", "peas", "beans", "mushroom", "eggplant", "zucchini", "radish",

            # Grains and breads
            "bread", "rice", "pasta", "noodles", "cereal", "oats", "wheat", "flour",
            "bagel", "muffin", "cookie", "cake", "pie", "pizza", "sandwich",

            # Proteins
            "chicken", "beef", "pork", "fish", "salmon", "tuna", "shrimp", "crab",
            "egg", "cheese", "milk", "yogurt", "butter", "cream", "nuts", "beans",

            # Beverages
            "water", "juice", "coffee", "tea", "soda", "beer", "wine", "milk",
            "smoothie", "cocktail", "lemonade", "soup", "broth"
        ]
        return foods

    def get_abstract_concepts(self) -> List[str]:
        """Get abstract concepts and ideas"""
        concepts = [
            "love", "hate", "joy", "sadness", "anger", "fear", "hope", "dream",
            "memory", "thought", "idea", "belief", "faith", "trust", "doubt",
            "freedom", "justice", "peace", "war", "truth", "lie", "beauty", "art",
            "music", "dance", "song", "story", "poem", "play", "game", "sport",
            "work", "job", "career", "business", "money", "wealth", "poverty",
            "health", "disease", "medicine", "cure", "treatment", "therapy",
            "education", "knowledge", "wisdom", "intelligence", "skill", "talent",
            "friendship", "family", "relationship", "marriage", "divorce", "birth",
            "death", "life", "time", "space", "energy", "power", "strength"
        ]
        return concepts

    def get_nature_elements(self) -> List[str]:
        """Get natural elements and phenomena"""
        nature = [
            "tree", "flower", "grass", "leaf", "branch", "root", "seed", "fruit",
            "mountain", "hill", "valley", "river", "lake", "ocean", "sea", "pond",
            "forest", "desert", "beach", "island", "cave", "volcano", "glacier",
            "rock", "stone", "sand", "dirt", "mud", "clay", "crystal", "mineral",
            "sun", "moon", "star", "planet", "galaxy", "universe", "earth", "sky",
            "cloud", "rain", "snow", "ice", "wind", "storm", "thunder", "lightning",
            "fire", "flame", "smoke", "ash", "dust", "mist", "fog", "dew"
        ]
        return nature

    def get_random_adjectives(self) -> List[str]:
        """Get common adjectives that can work as descriptive words"""
        adjectives = [
            "big", "small", "large", "tiny", "huge", "massive", "mini", "giant",
            "hot", "cold", "warm", "cool", "freezing", "boiling", "mild", "extreme",
            "fast", "slow", "quick", "rapid", "speedy", "sluggish", "swift",
            "loud", "quiet", "silent", "noisy", "soft", "gentle", "harsh", "rough",
            "smooth", "bumpy", "sharp", "dull", "bright", "dark", "light", "heavy",
            "old", "new", "fresh", "stale", "young", "ancient", "modern", "vintage",
            "clean", "dirty", "pure", "messy", "neat", "tidy", "organized", "chaotic",
            "happy", "sad", "angry", "calm", "excited", "bored", "tired", "energetic",
            "beautiful", "ugly", "pretty", "handsome", "cute", "adorable", "gorgeous",
            "strong", "weak", "powerful", "fragile", "sturdy", "delicate", "tough"
        ]
        return adjectives

    def get_random_verbs(self) -> List[str]:
        """Get common verbs in base form"""
        verbs = [
            "run", "walk", "jump", "climb", "swim", "fly", "drive", "ride", "travel",
            "eat", "drink", "cook", "bake", "fry", "boil", "mix", "stir", "pour",
            "read", "write", "draw", "paint", "sing", "dance", "play", "perform",
            "work", "study", "learn", "teach", "explain", "understand", "remember",
            "sleep", "wake", "rest", "relax", "exercise", "stretch", "breathe",
            "talk", "speak", "listen", "hear", "see", "watch", "look", "observe",
            "think", "imagine", "dream", "hope", "wish", "want", "need", "have",
            "give", "take", "buy", "sell", "trade", "exchange", "share", "help",
            "build", "create", "make", "fix", "repair", "break", "destroy", "clean"
        ]
        return verbs

    def get_colors_and_materials(self) -> List[str]:
        """Get colors and materials"""
        items = [
            # Colors
            "red", "blue", "green", "yellow", "orange", "purple", "pink", "brown",
            "black", "white", "gray", "silver", "gold", "bronze", "copper",

            # Materials
            "wood", "metal", "plastic", "glass", "paper", "cloth", "fabric", "leather",
            "rubber", "concrete", "brick", "stone", "marble", "granite", "steel",
            "aluminum", "iron", "bronze", "copper", "silver", "gold", "diamond",
            "cotton", "silk", "wool", "linen", "denim", "velvet", "satin", "canvas"
        ]
        return items

    def collect_all_objects(self, target_count: int = 6615) -> List[str]:
        """Return up to ``target_count`` unique words drawn from every source.

        The hand-written lists and up to 5000 WordNet nouns are merged,
        lower-cased, stripped, de-duplicated and shuffled.  If that is still
        short of ``target_count``, unique "adjective noun" pairs that are not
        already present are appended.  A warning is printed when the target
        cannot be reached.
        """
        print("Collecting random objects and words...")

        # (label used in the progress message, getter) for each curated list
        sources = [
            ("common objects", self.get_common_objects),
            ("animals", self.get_animals),
            ("food items", self.get_food_items),
            ("abstract concepts", self.get_abstract_concepts),
            ("nature elements", self.get_nature_elements),
            ("adjectives", self.get_random_adjectives),
            ("verbs", self.get_random_verbs),
            ("colors and materials", self.get_colors_and_materials),
        ]

        all_objects = []
        for label, getter in sources:
            print(f"- Collecting {label}...")
            all_objects.extend(getter())

        # Try to get WordNet nouns if available
        all_objects.extend(self.get_wordnet_nouns(5000))

        # Remove duplicates and clean.  dict.fromkeys keeps first-seen order, so
        # the shuffle below is reproducible under random.seed().
        unique_objects = list(dict.fromkeys(
            obj.lower().strip() for obj in all_objects if obj.strip()
        ))

        # Shuffle for randomness
        random.shuffle(unique_objects)

        # If we need more, generate compound words that are not already present
        if len(unique_objects) < target_count:
            print("- Generating compound words...")
            compound_words = self.generate_compound_words(
                target_count - len(unique_objects), exclude=unique_objects
            )
            unique_objects.extend(compound_words)

        if len(unique_objects) < target_count:
            print(f"Warning: only {len(unique_objects)} unique entries available "
                  f"(requested {target_count})")

        # Trim to target count
        final_objects = unique_objects[:target_count]

        print(f"Collected {len(final_objects)} unique objects/words")
        return final_objects

    def generate_compound_words(self, count: int, exclude=None) -> List[str]:
        """Return up to ``count`` distinct "adjective noun" strings.

        Args:
            count: number of compounds wanted; values <= 0 give an empty list.
            exclude: optional iterable of strings that must not be returned.

        Pairs are first drawn from a small built-in vocabulary.  When that
        cannot supply ``count`` distinct pairs, the vocabulary is widened with
        ``get_random_adjectives()`` and ``get_common_objects()``.  Fewer than
        ``count`` items are returned only if even the widened pool runs out.
        """
        adjectives = ["red", "blue", "big", "small", "hot", "cold", "old", "new",
                     "fast", "slow", "bright", "dark", "heavy", "light", "soft", "hard"]
        nouns = ["box", "ball", "book", "car", "house", "tree", "stone", "door",
                "window", "chair", "table", "bag", "cup", "pen", "phone", "watch"]

        if count <= 0:
            return []

        taken = set(exclude) if exclude is not None else set()

        def candidates(adjs, nns):
            # Every adjective/noun pairing that is not excluded, in a stable order.
            return [f"{adj} {noun}" for adj in adjs for noun in nns
                    if f"{adj} {noun}" not in taken]

        pool = candidates(adjectives, nouns)
        if len(pool) < count:
            # The small vocabulary is exhausted; widen it (dict.fromkeys drops
            # repeated words while keeping order).
            wide_adjectives = list(dict.fromkeys(adjectives + self.get_random_adjectives()))
            wide_nouns = list(dict.fromkeys(nouns + self.get_common_objects()))
            pool = candidates(wide_adjectives, wide_nouns)

        # Sampling without replacement guarantees distinct results.
        return random.sample(pool, min(count, len(pool)))

    def save_to_csv(self, objects: List[str], filename: str = 'random_objects.csv'):
        """Write the words to ``filename`` as CSV rows ``name,type`` with type 'object'."""
        with open(filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['name', 'type'])
            writer.writerows([obj, 'object'] for obj in objects)
        print(f"Saved {len(objects)} objects to {filename}")

    def load_from_csv(self, filename: str) -> List[str]:
        """Return the first column of every non-empty data row in ``filename``.

        The first row is treated as a header and skipped.  A missing file is
        reported on stdout and yields an empty list; an empty file also yields
        an empty list.
        """
        objects = []
        try:
            with open(filename, 'r', encoding='utf-8', newline='') as f:
                reader = csv.reader(f)
                next(reader, None)  # Skip header (tolerates an empty file)
                objects.extend(row[0] for row in reader if row)
        except FileNotFoundError:
            print(f"File {filename} not found")
        return objects

# Usage Example
if __name__ == "__main__":
    collector = RandomObjectsCollector()

    # Collect 6615 random objects/words
    random_objects = collector.collect_all_objects(target_count=6615)

    # Show some examples
    print(f"\nFirst 20 objects: {random_objects[:20]}")
    # random.sample raises ValueError when asked for more items than exist,
    # so cap the sample size at the number of collected objects.
    print(f"Random sample: {random.sample(random_objects, min(10, len(random_objects)))}")

    print(f"\nSuccessfully collected {len(random_objects)} random objects!")

Collecting random objects and words...
- Collecting common objects...
- Collecting animals...
- Collecting food items...
- Collecting abstract concepts...
- Collecting nature elements...
- Collecting adjectives...
- Collecting verbs...
- Collecting colors and materials...
Downloading WordNet...


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


Collected 5000 nouns from WordNet
- Generating compound words...
Collected 6615 unique objects/words

First 20 objects: ['oil painting', 'doodly-squat', 'nonsense verse', 'woodgrain', '1920s', 'downhill', 'sublet', 'stringybark', 'cankerweed', 'fine-tooth comb', 'strap fern', 'ortolan', 'inheritress', 'indirect tax', 'swung dash', 'common nardoo', 'conveyor belt', 'military uniform', 'white admiral', 'palma christi']
Random sample: ['combretum', 'lanugo', 'fleck', 'floozie', 'new bag', 'blue bag', 'pink fivecorner', 'big bag', 'blunt trauma', 'blue']

Successfully collected 6615 random objects!


In [6]:
# Echo the collected objects/words so the notebook displays the list
random_objects

['oil painting',
 'doodly-squat',
 'nonsense verse',
 'woodgrain',
 '1920s',
 'downhill',
 'sublet',
 'stringybark',
 'cankerweed',
 'fine-tooth comb',
 'strap fern',
 'ortolan',
 'inheritress',
 'indirect tax',
 'swung dash',
 'common nardoo',
 'conveyor belt',
 'military uniform',
 'white admiral',
 'palma christi',
 'detention',
 'zona',
 'laugher',
 'winter',
 'waiting line',
 'stainless steel',
 'vitalness',
 'bimillennium',
 'phalsa',
 'flexible joint',
 'bad',
 'cairn',
 'genus erwinia',
 'atropine',
 'police officer',
 'tie-in',
 'stun baton',
 'black raspberry',
 'procurement',
 'japan',
 'auxiliary storage',
 'family cicadidae',
 'applied mathematics',
 'live-and-die',
 'mizzen',
 'boiling',
 'genus carlina',
 'cyanite',
 'matzah',
 'soft',
 'lightship',
 'shaving',
 'plenipotentiary',
 'nesting place',
 'hillbilly',
 'fudge sauce',
 'oviraptorid',
 'dejeuner',
 'antagonist',
 'fluorocarbon plastic',
 'foreland',
 'forehand shot',
 'atavism',
 'shirttail',
 'power loading',
 

In [54]:
# Imports this cell needs; without them `train_test_split` and `torch` are
# undefined when the cells are run top to bottom.
import random

import torch
from sklearn.model_selection import train_test_split
from sentence_transformers import SentenceTransformer

# ------------------------------------------------
# balanced word list: k names + k objects
# ------------------------------------------------
k = min(len(people_names), len(random_objects))
random.seed(42)
# NOTE: both lists are shuffled in place.
random.shuffle(people_names)
random.shuffle(random_objects)

all_words = people_names[:k] + random_objects[:k]       # 2k total
is_person = [1]*k + [0]*k                               # ground-truth for eval

# ------------------------------------------------
# 3 · word-level train / test split  (no labels used in training)
# ------------------------------------------------
words_tr, words_te, y_tr, y_te = train_test_split(
    all_words, is_person, test_size=0.4, shuffle=True,
    stratify=is_person, random_state=42)

# ------------------------------------------------
# 4 · build UNSUPERVISED contrast pairs
#     each word gets (affirmative, negative)
# ------------------------------------------------
def make_pair(w):
    """Return the (affirmative, negated) sentence pair for word ``w``."""
    return (f"{w} is a person",        # x⁺
            f"{w} is not a person")    # x⁻

# In this notebook x0 holds the affirmative sentences and x1 the negated ones.
x0_tr_sent, x1_tr_sent = zip(*(make_pair(w) for w in words_tr))
x0_te_sent, x1_te_sent = zip(*(make_pair(w) for w in words_te))

# ------------------------------------------------
# 5 · embed sentences  (local SBERT or API)
# ------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
model  = SentenceTransformer("BAAI/bge-base-en-v1.5", device=device)

def embed(txts, batch=128, desc="embed"):
    """Encode ``txts`` with the sentence model and return a NumPy array.

    ``desc`` is printed before encoding so the otherwise identical progress
    bars can be told apart.  ``txts`` may be any sequence; it is copied into a
    list before being handed to the encoder.
    """
    sentences = list(txts)
    print(f"Encoding {desc} ({len(sentences)} sentences)")
    return model.encode(sentences, batch_size=batch,
                        convert_to_numpy=True,
                        show_progress_bar=True)

x0_train = embed(x0_tr_sent, desc="train x⁺")
x1_train = embed(x1_tr_sent, desc="train x⁻")
x0_test  = embed(x0_te_sent, desc="test  x⁺")
x1_test  = embed(x1_te_sent, desc="test  x⁻")

print("Embeddings ready → shapes",
      x0_train.shape, x1_train.shape, x0_test.shape, x1_test.shape)

# save the auxiliary test labels for pairwise accuracy later
test_words   = list(words_te)      # words in the same order as x0_te / x1_te
test_is_person = y_te             # 1 for names, 0 for objects

Batches:   0%|          | 0/63 [00:00<?, ?it/s]

Batches:   0%|          | 0/63 [00:00<?, ?it/s]

Batches:   0%|          | 0/42 [00:00<?, ?it/s]

Batches:   0%|          | 0/42 [00:00<?, ?it/s]

Embeddings ready → shapes (7938, 768) (7938, 768) (5292, 768) (5292, 768)


In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [56]:
class MLPProbe(nn.Module):
    """Small non-linear probe: d -> 100 hidden units (ReLU) -> 1 sigmoid output."""

    def __init__(self, d):
        super().__init__()
        # Attribute names double as the state_dict keys.
        self.linear1 = nn.Linear(d, 100)
        self.linear2 = nn.Linear(100, 1)

    def forward(self, x):
        """Map inputs of width ``d`` to one value in (0, 1) per row."""
        hidden = F.relu(self.linear1(x))
        logit = self.linear2(hidden)
        return logit.sigmoid()

class CCS(object):
    """Unsupervised probe trained on paired inputs ``x0`` / ``x1`` with the CCS loss.

    The probe maps each row to a value in (0, 1).  The loss pushes the two
    outputs of a pair to sum to one (consistency) and pushes the smaller of the
    two toward zero (informativeness).  Several randomly initialised runs are
    trained and the one with the lowest final loss is kept in ``best_probe``.

    Args:
        x0, x1: arrays of shape (n, d); row i of each forms one pair.
        nepochs: epochs per training run.
        ntries: number of independently initialised runs.
        lr, weight_decay: AdamW settings.
        batch_size: mini-batch size; -1 (or any value outside 1..n) means
            full batch.
        verbose: stored on the instance; not used by the methods in this class.
        device: torch device for the probe and data.
        linear: True for a linear probe, False for ``MLPProbe``.
        var_normalize: also divide each feature by its standard deviation.
    """

    def __init__(self, x0, x1, nepochs=1000, ntries=10, lr=1e-3, batch_size=-1,
                 verbose=False, device="cuda", linear=True, weight_decay=0.01, var_normalize=False):
        # data (var_normalize must be set first: normalize() reads it)
        self.var_normalize = var_normalize
        self.x0 = self.normalize(x0)
        self.x1 = self.normalize(x1)
        self.d = self.x0.shape[-1]

        # training
        self.nepochs = nepochs
        self.ntries = ntries
        self.lr = lr
        self.verbose = verbose
        self.device = device
        self.batch_size = batch_size
        self.weight_decay = weight_decay

        # probe
        self.linear = linear
        self.initialize_probe()
        self.best_probe = copy.deepcopy(self.probe)

    def initialize_probe(self):
        """Create a freshly initialised probe on ``self.device``."""
        if self.linear:
            self.probe = nn.Sequential(nn.Linear(self.d, 1), nn.Sigmoid())
        else:
            self.probe = MLPProbe(self.d)
        self.probe.to(self.device)

    def normalize(self, x):
        """
        Mean-normalizes the data x (of shape (n, d))
        If self.var_normalize, also divides by the standard deviation.

        Features with zero standard deviation are left at zero instead of
        being divided by zero (which would yield NaN).  The input array is
        not modified.
        """
        normalized_x = x - x.mean(axis=0, keepdims=True)
        if self.var_normalize:
            std = normalized_x.std(axis=0, keepdims=True)
            std[std == 0] = 1.0  # constant feature: avoid 0/0
            normalized_x = normalized_x / std

        return normalized_x

    def get_tensor_data(self):
        """
        Returns x0, x1 as appropriate tensors (rather than np arrays)
        """
        x0 = torch.tensor(self.x0, dtype=torch.float, requires_grad=False, device=self.device)
        x1 = torch.tensor(self.x1, dtype=torch.float, requires_grad=False, device=self.device)
        return x0, x1

    def get_loss(self, p0, p1):
        """
        Returns the CCS loss for two probabilities each of shape (n,1) or (n,)
        """
        informative_loss = (torch.min(p0, p1)**2).mean(0)
        consistent_loss = ((p0 - (1-p1))**2).mean(0)
        return informative_loss + consistent_loss

    def get_acc(self, x0_test, x1_test, y_test):
        """
        Computes accuracy of ``best_probe`` on the given test inputs.

        The test inputs are normalized with their own statistics.  Because the
        probe's orientation is arbitrary, the returned value is
        max(acc, 1 - acc).
        """
        x0 = torch.tensor(self.normalize(x0_test), dtype=torch.float, requires_grad=False, device=self.device)
        x1 = torch.tensor(self.normalize(x1_test), dtype=torch.float, requires_grad=False, device=self.device)
        with torch.no_grad():
            p0, p1 = self.best_probe(x0), self.best_probe(x1)
        avg_confidence = 0.5*(p0 + (1-p1))
        predictions = (avg_confidence.detach().cpu().numpy() < 0.5).astype(int)[:, 0]
        acc = (predictions == np.asarray(y_test)).mean()
        acc = max(acc, 1 - acc)

        return acc

    def train(self):
        """
        Does a single training run of nepochs epochs and returns the last loss.

        Raises:
            ValueError: if there is no training data.
        """
        x0, x1 = self.get_tensor_data()
        n = len(x0)
        if n == 0:
            raise ValueError("CCS.train requires at least one contrast pair")
        permutation = torch.randperm(n)
        x0, x1 = x0[permutation], x1[permutation]

        # set up optimizer
        optimizer = torch.optim.AdamW(self.probe.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        # -1, non-positive or oversized batch sizes all fall back to full batch,
        # so at least one batch is always processed per epoch.
        batch_size = self.batch_size if 0 < self.batch_size <= n else n
        nbatches = n // batch_size  # a trailing partial batch is dropped

        loss = None
        for epoch in range(self.nepochs):
            for j in range(nbatches):
                x0_batch = x0[j*batch_size:(j+1)*batch_size]
                x1_batch = x1[j*batch_size:(j+1)*batch_size]

                # probe
                p0, p1 = self.probe(x0_batch), self.probe(x1_batch)

                # get the corresponding loss
                loss = self.get_loss(p0, p1)

                # update the parameters
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        if loss is None:
            # No optimisation step ran (nepochs < 1): report the loss of the
            # untrained probe instead of failing on an unbound variable.
            with torch.no_grad():
                loss = self.get_loss(self.probe(x0), self.probe(x1))

        return loss.detach().cpu().item()

    def repeated_train(self):
        """Train ``ntries`` fresh probes, keep the best in ``best_probe``, return its loss."""
        best_loss = np.inf
        for train_num in range(self.ntries):
            self.initialize_probe()
            loss = self.train()
            if loss < best_loss:
                self.best_probe = copy.deepcopy(self.probe)
                best_loss = loss

        return best_loss

In [57]:
import copy
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import pandas as pd

In [None]:
# Use the GPU when one is present, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Linear probe on the training pairs; batch_size=-1 selects full-batch updates.
ccs = CCS(
    x0_train,
    x1_train,
    nepochs=300,
    ntries=3,
    lr=5e-4,
    batch_size=-1,
    device=device,
    linear=True,
    var_normalize=True,
)
# Trains ntries probes and keeps the best one on ccs.best_probe.
ccs.repeated_train()


In [58]:
# Normalise the test embeddings (using the test split's own statistics) and
# move them to the probe's device.
z0 = torch.tensor(ccs.normalize(x0_test), device=device, dtype=torch.float)
z1 = torch.tensor(ccs.normalize(x1_test), device=device, dtype=torch.float)

# Probe outputs for the two sentences of every test pair, flattened to (n,).
with torch.no_grad():
    p0 = ccs.best_probe(z0).cpu().numpy().ravel()   # output for the affirmative sentence
    p1 = ccs.best_probe(z1).cpu().numpy().ravel()   # output for the negated sentence

# Raw decision per word: 1 when the affirmative sentence scores higher.
pred_raw = (p0 > p1).astype(int)

# Ground truth in the same word order (1 = name, 0 = object).
gold = np.array(test_is_person)

# The probe's polarity is arbitrary, so score both global orientations.
acc_pos = np.mean(pred_raw == gold)          # affirmative ⇒ "person"
acc_neg = np.mean((1 - pred_raw) == gold)    # negative    ⇒ "person"

pair_acc = max(acc_pos, acc_neg)
print(
    f"pairwise accuracy : {pair_acc*100:.2f}%  "
    + f"(picked {'affirmative' if acc_pos>=acc_neg else 'negative'} as true)"
)


pairwise accuracy : 90.78%  (picked affirmative as true)


In [59]:
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report

# ── 1 · score both polarities and keep the better one ────────
pred_raw = (p0 > p1).astype(int)          # 1 when the affirmative sentence scores higher
acc_pos = np.mean(pred_raw == gold)
acc_neg = np.mean((1 - pred_raw) == gold)

# flip == 0 keeps "affirmative = person"; flip == 1 inverts every prediction
if acc_pos >= acc_neg:
    flip, pair_acc = 0, acc_pos
else:
    flip, pair_acc = 1, acc_neg
pred_final = pred_raw ^ flip              # XOR with 1 inverts, with 0 leaves unchanged

print(f"\033[1mPairwise accuracy\033[0m : {pair_acc*100:.2f}%  "
      f"(affirmative sentence interpreted as "
      f"{'person' if flip==0 else 'not-person'})")

# ── 2 · confusion matrix & per-class report ──────────────────
# labels=[1,0] puts "person" first, giving the [[TP, FN], [FP, TN]] layout.
cm = confusion_matrix(gold, pred_final, labels=[1,0])
print("\nConfusion matrix [[TP, FN], [FP, TN]]:\n", cm)

report = classification_report(gold, pred_final,
                               target_names=["object","person"],
                               digits=3)
print("\n" + report)

# ── 3 · first 20 test words with the probe's verdict ─────────
print("\nSample predictions:")
sample_rows = zip(test_words[:20], gold[:20], pred_final[:20], p0[:20], p1[:20])
for word, truth, guess, score_aff, score_neg in sample_rows:
    mark = "✅" if truth==guess else "❌"
    label = "person" if guess else "object"
    print(f"{mark} {word:<25} → {label:<7}  "
          f"(p_aff={score_aff:.3f} | p_neg={score_neg:.3f})")


[1mPairwise accuracy[0m : 90.78%  (affirmative sentence interpreted as person)

Confusion matrix [[TP, FN], [FP, TN]]:
 [[2429  217]
 [ 271 2375]]

              precision    recall  f1-score   support

      object      0.916     0.898     0.907      2646
      person      0.900     0.918     0.909      2646

    accuracy                          0.908      5292
   macro avg      0.908     0.908     0.908      5292
weighted avg      0.908     0.908     0.908      5292


Sample predictions:
❌ DJ A-Tron                 → object   (p_aff=0.599 | p_neg=0.692)
✅ bright house              → object   (p_aff=0.220 | p_neg=0.339)
✅ faith                     → object   (p_aff=0.831 | p_neg=0.898)
✅ assets                    → object   (p_aff=0.283 | p_neg=0.544)
✅ Nabil Aankour             → person   (p_aff=0.439 | p_neg=0.223)
✅ Arthur Zajonc             → person   (p_aff=0.413 | p_neg=0.225)
✅ Huntley Fitzpatrick       → person   (p_aff=0.517 | p_neg=0.240)
✅ Lucy Furman               → per