In [2]:
import numpy as np
import torch
from typing import List, Set, Tuple
import random
from functools import partial
import json
from datetime import datetime
import os
import matplotlib.pyplot as plt

def generate_random_msp(P: int) -> List[Set[int]]:
    sets = []
    num_sets = random.randint(1, P)
    
    size = random.randint(1, min(3, P))
    sets.append(set(random.sample(range(P), size)))
    
    for _ in range(num_sets - 1):
        prev_union = set().union(*sets)
        remaining = set(range(P)) - prev_union
        
        if remaining and random.random() < 0.7:
            new_elem = random.choice(list(remaining))
            base_elems = random.sample(list(prev_union), random.randint(0, len(prev_union)))
            new_set = set(base_elems + [new_elem])
        else:
            size = random.randint(1, len(prev_union))
            new_set = set(random.sample(list(prev_union), size))
        
        sets.append(new_set)
    
    return sets

class MSPFunction:
    def __init__(self, P: int, sets: List[Set[int]]):
        self.P = P
        self.sets = sets
            
        # Verify MSP property
        for i in range(1, len(sets)):
            prev_union = set().union(*sets[:i])
            diff = sets[i] - prev_union
            if len(diff) > 1:
                raise ValueError(f"Not an MSP: Set {sets[i]} adds {len(diff)} new elements: {diff}")
    
    def evaluate(self, z: torch.Tensor) -> torch.Tensor:
        batch_size = z.shape[0]
        result = torch.zeros(batch_size, dtype=torch.float64, device=z.device)
        
        for S in self.sets:
            term = torch.ones(batch_size, dtype=torch.float64, device=z.device)
            for idx in S:
                term = term * z[:, idx]
            result = result + term
            
        return result

In [16]:
sets = generate_random_msp(10)
print(sets)

[{2, 7}, {0, 2, 7}, {0, 7}, {1}, {0, 4}, {3, 7}, {0, 1, 2, 3, 4, 7}, {5, 7}]


In [17]:
msp = MSPFunction(10, sets)
print(msp)

<__main__.MSPFunction object at 0x7f5e907d6680>


In [22]:

X_test = 2 * torch.bernoulli(0.5 * torch.ones((10, 8), dtype=torch.float64)) - 1
y_test = msp.evaluate(X_test)
print(X_test)
print(y_test)

tensor([[ 1., -1.,  1.,  1., -1., -1.,  1., -1.],
        [-1., -1., -1., -1.,  1.,  1., -1.,  1.],
        [-1., -1., -1., -1.,  1., -1., -1., -1.],
        [-1.,  1.,  1., -1.,  1., -1., -1., -1.],
        [ 1.,  1.,  1.,  1., -1., -1., -1.,  1.],
        [ 1.,  1., -1., -1.,  1., -1.,  1., -1.],
        [ 1.,  1.,  1.,  1.,  1.,  1., -1.,  1.],
        [-1.,  1.,  1., -1.,  1.,  1.,  1., -1.],
        [ 1., -1.,  1., -1.,  1.,  1., -1., -1.],
        [ 1., -1.,  1., -1., -1., -1.,  1., -1.]], dtype=torch.float64)
tensor([-6., -2.,  0.,  2.,  2.,  4.,  8.,  0., -4., -2.], dtype=torch.float64)
