In [None]:
%%capture
!pip install openmined_psi

In [None]:
import syft as sy
# Join the Duet session that the PSI client below communicates over.
# NOTE(review): loopback=True presumably targets a Duet launched on the same
# machine — confirm against the syft documentation.
duet = sy.join_duet(loopback=True)

In [None]:
import openmined_psi as psi

In [None]:
class PsiClientDuet:
    """Client side of a private set intersection (PSI) exchange over Duet.

    On construction the client pulls two objects from the Duet store, tagged
    ``"reveal_intersection"`` and ``"setup"``, and creates a PSI client with a
    fresh key.  ``intersect`` then sends a request built from the client's
    items, waits for an object tagged ``"response"`` to appear in the store,
    and turns it into either the intersection or its size.
    """

    # Seconds to pause between polls of the Duet store while waiting for the
    # server's response, so the wait loop does not spin at full speed.
    _POLL_INTERVAL_SECS = 0.1

    def __init__(self, duet, timeout_secs=-1):
        """Fetch the reveal flag and server setup, and create the PSI client.

        Args:
            duet: Duet session object exposing ``store`` and ``requests``.
            timeout_secs: Passed through unchanged to each pointer ``get``
                call made here.
        """
        self.duet = duet

        # get the reveal intersection flag and create a client
        reveal_intersection_ptr = self.duet.store["reveal_intersection"]
        reveal_intersection = reveal_intersection_ptr.get(
            request_block=True,
            name="reveal_intersection",
            reason="Are we revealing or not?",
            timeout_secs=timeout_secs,
            delete_obj=True
        )
        self.reveal_intersection = reveal_intersection
        self.client = psi.client.CreateWithNewKey(reveal_intersection)

        # get the ServerSetup message
        setup_ptr = self.duet.store["setup"]
        self.setup = setup_ptr.get(
            request_block=True,
            name="setup",
            reason="To get the server setup",
            timeout_secs=timeout_secs,
            delete_obj=True
        )

    def intersect(self, client_items, timeout_secs=-1):
        """Run one PSI round for ``client_items`` and return the result.

        Args:
            client_items: Items handed to the PSI client's ``CreateRequest``.
            timeout_secs: Passed through unchanged to the ``get`` call that
                downloads the response.  It does not bound the wait for the
                response to appear in the store.

        Returns:
            The result of ``GetIntersection`` when ``reveal_intersection`` is
            truthy, otherwise the result of ``GetIntersectionSize``.
        """
        # Imported locally so this class stays self-contained in its cell.
        import time

        # send the client request to the server; requests named "request"
        # are accepted automatically so the server can read it
        self.duet.requests.add_handler(
            name="request",
            action="accept"
        )
        request = self.client.CreateRequest(client_items)
        request.tag("request").send(self.duet, pointable=True)

        # block until a response is received from the server
        while True:
            try:
                response_ptr = self.duet.store["response"]
            except Exception:
                # Lookup failed, so the response is not in the store yet.
                # Catching Exception (not a bare except) lets
                # KeyboardInterrupt / SystemExit abort the wait.
                time.sleep(self._POLL_INTERVAL_SECS)
            else:
                break

        # get the response from the server
        response = response_ptr.get(
            request_block=True,
            name="response",
            reason="To get the response",
            timeout_secs=timeout_secs,
            delete_obj=True
        )

        # calculate the intersection
        if self.reveal_intersection:
            return self.client.GetIntersection(self.setup, response)
        return self.client.GetIntersectionSize(self.setup, response)

In [None]:
# The client's input set: 1000 strings, "Element 0" through "Element 999".
client_items = [f"Element {i}" for i in range(1000)]

In [None]:
# Build the PSI client (its constructor fetches the reveal flag and the server
# setup over Duet), then run the exchange for client_items.
client = PsiClientDuet(duet)
intersection = client.intersect(client_items)

In [None]:
# Reveal mode: every even index of client_items must be reported in the
# intersection and every odd index must be absent.
if client.reveal_intersection:
    iset = set(intersection)
    for idx, _ in enumerate(client_items):
        assert (idx in iset) == (idx % 2 == 0)

In [None]:
# Size-only mode: the reported size must be at least half the client set and
# at most 10% above that half.
if not client.reveal_intersection:
    assert (
        len(client_items) / 2.0
        <= intersection
        <= 1.1 * len(client_items) / 2.0
    )