## Reservoir Sampling  
Feature: after n elements, the sample contains each element seen so far with probability $\frac{s}{n}$  
* Store all the first s elements in S
* Suppose we have seen $n-1$ elements, and now the $n^{th}$ element arrives ($n > s$)  
  - with probability $\frac{s}{n}$ we keep the $n^{th}$ element, otherwise we discard it
  - if we keep the $n^{th}$ element, we must delete one element from the sample S; we pick that element uniformly at random

### proof:  
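* Induction hypothesis: we assume the claim "after n elements, the sample contains each element seen so far with probability $\frac{s}{n}$" holds (base case $n = s$: the first s elements are all stored, so each is in S with probability $\frac{s}{s} = 1$)  
* we need to prove that after seeing the $(n+1)^{th}$ element and applying the sampling method above, each element seen so far is in S with 
  probability $\frac{s}{n+1}$  
* for the $(n+1)^{th}$ element the requirement is met directly, since it is kept with probability $\frac{s}{n+1}$. 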
* for an element already in S, the probability that it stays in S after seeing the $(n+1)^{th}$ element:  
    if the $(n+1)^{th}$ element is discarded (this happens with probability $1 - \frac{s}{n+1}$), the element certainly stays  
    if the $(n+1)^{th}$ element is kept (this happens with probability $\frac{s}{n+1}$), the element stays with probability $\frac{C_{s-1}^{1}}{C_{s}^{1}} = \frac{s-1}{s}$  
    so the probability is:  
$$ (1 - \frac{s}{n+1}) + \frac{s}{n+1}\frac{s-1}{s} = \frac{n}{n+1} $$  
* so the probability that an element seen among the first n is in S after $n+1$ elements is:
$$\frac{s}{n}\frac{n}{n+1} = \frac{s}{n+1}$$

In [None]:
import random
import numpy as np
from collections import defaultdict
import time
from IPython.display import clear_output, display

## TRIEST  
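Use the reservoir sampling method to estimate the number of triangles in a graph given as a stream of edges  
t: total number of edges seen so far  
M: sample size, i.e. the maximum number of edges kept in the sample S  
The probability that an edge seen so far is in the sample (for $t > M$): $\frac{M}{t}$  
So the probability that a triangle (a, b, c) of $G^{(t)}$ is in $G_S$ at time t, i.e. that all three of its edges are sampled, is  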
$$\pi_t = \frac{C_{t-3}^{M-3}}{C_{t}^{M}} = \frac{M(M-1)(M-2)}{t(t-1)(t-2)}$$

In [None]:
class TRIEST_BASE:
    """Triangle-count estimator over an edge stream using a fixed-size sample.

    A reservoir sample ``S`` of at most ``M`` edges is maintained.  ``tau`` and
    ``local_tau`` count the triangles (globally and per node) that are fully
    contained in the current sample; they are incremented when an edge enters
    the sample and decremented when an edge is evicted.  ``get_estimate``
    scales these counts by ``t(t-1)(t-2) / (M(M-1)(M-2))``.

    The counters assume each edge appears at most once in the stream; feeding
    the same edge twice counts its triangles twice.

    Attributes:
        t: number of edges seen so far.
        tau: number of triangles currently inside the sample.
        local_tau: node -> number of sampled triangles containing that node.
        S: list of sampled edges.
        neighbors: node -> set of adjacent nodes in the sampled graph.
        M: maximum number of edges kept in the sample.
        local_estimate: per-node estimates from the last ``get_estimate`` call.
    """

    def __init__(self, M):
        self.t = 0
        self.tau = 0
        self.local_tau = {}
        self.S = []
        self.neighbors = defaultdict(set)
        self.local_estimate = {}

        self.M = M

    def SampleEdge(self, edge, t):
        """Decide whether the t-th edge enters the sample.

        While ``t <= M`` every edge is accepted.  Afterwards the edge is
        accepted with probability ``M / t``; on acceptance one uniformly
        chosen sampled edge is evicted and the counters are decremented for
        it.  The caller is responsible for inserting ``edge`` itself.

        Returns:
            True if ``edge`` should be added to the sample, False otherwise.
        """
        # Use the ``t`` argument for both the fill-phase test and the
        # acceptance probability so the two can never disagree.
        if t <= self.M:
            return True
        if not np.random.binomial(1, self.M / t):
            return False

        rm_edge = random.choice(self.S)
        self.S.remove(rm_edge)
        self.UpdateCounters('-', rm_edge)
        return True

    def UpdateCounters(self, flag, edge):
        """Add (``flag == '+'``) or remove (``flag == '-'``) ``edge`` from the
        sampled graph and adjust the triangle counters accordingly.

        The adjacency sets are updated first; the common neighbours of the
        two endpoints are then exactly the third vertices of the triangles
        that ``edge`` closes (or used to close).
        """
        u, v = edge[0], edge[1]
        self.update_neighbor(flag, edge)

        # ``.get`` instead of indexing: indexing a defaultdict would insert an
        # empty set for every missing node and undo the cleanup performed by
        # ``update_neighbor('-')``.
        empty = frozenset()
        shared = self.neighbors.get(u, empty) & self.neighbors.get(v, empty)

        if flag == '+':
            for w in shared:
                self.tau += 1
                for node in (w, u, v):
                    self.local_tau[node] = self.local_tau.get(node, 0) + 1

        elif flag == '-':
            for w in shared:
                self.tau -= 1
                for node in (w, u, v):
                    self.local_tau[node] -= 1
                    # Drop exhausted counters so local_tau only holds nodes
                    # that are part of at least one sampled triangle.
                    if self.local_tau[node] == 0:
                        del self.local_tau[node]

    def get_estimate(self):
        """Return ``(global_estimate, local_estimates)``.

        Both are the sample counters multiplied by
        ``max(1, t(t-1)(t-2) / (M(M-1)(M-2)))`` and truncated to ``int``.
        The per-node dictionary is also stored in ``self.local_estimate``.
        """
        denom = self.M * (self.M - 1) * (self.M - 2)
        if denom > 0:
            xee = max(1, (self.t * (self.t - 1) * (self.t - 2)) / denom)
        else:
            # With fewer than three sampled edges no triangle can be in the
            # sample, so the counters are zero; avoid dividing by zero.
            xee = 1

        estimate = int(xee * self.tau)
        self.local_estimate = {
            node: int(count * xee) for node, count in self.local_tau.items()
        }
        return estimate, self.local_estimate

    def triest(self, edge):
        """Process the next edge of the stream."""
        self.t += 1
        if self.SampleEdge(edge, self.t):
            self.S.append(edge)
            self.UpdateCounters('+', edge)

    def update_neighbor(self, flag, edge):
        """Insert (``'+'``) or delete (``'-'``) ``edge`` in the adjacency sets.

        Deleting an edge that is not present is a no-op.  Nodes left without
        neighbours are removed from ``self.neighbors``.
        """
        u, v = edge[0], edge[1]
        if flag == '+':
            self.neighbors[u].add(v)
            self.neighbors[v].add(u)
        elif flag == '-':
            for node, other in ((u, v), (v, u)):
                adjacent = self.neighbors.get(node)
                if adjacent is None:
                    continue
                adjacent.discard(other)
                if not adjacent:
                    del self.neighbors[node]

In [None]:
class TRIEST_IMPR:
    """Improved TRIEST estimator: weighted counters that are never decremented.

    Every incoming edge first updates the counters against the current
    sample, with each closed triangle weighted by
    ``max(1, (t-1)(t-2) / (M(M-1)))``; only then is the edge considered for
    the reservoir sample.  Evicting an edge updates the adjacency sets only,
    the counters are left untouched.

    Attributes:
        t: number of edges seen so far.
        tau: running (weighted) global triangle estimate.
        local_tau: node -> running (weighted) local triangle estimate.
        S: list of sampled edges.
        neighbors: node -> set of adjacent nodes in the sampled graph.
        M: maximum number of edges kept in the sample.
    """

    def __init__(self, M):
        self.t = 0
        self.tau = 0
        self.local_tau = {}
        self.S = []
        self.neighbors = defaultdict(set)

        self.M = M

    def SampleEdge(self, edge, t):
        """Decide whether the t-th edge enters the sample.

        While ``t <= M`` every edge is accepted.  Afterwards the edge is
        accepted with probability ``M / t``; on acceptance one uniformly
        chosen sampled edge is evicted from ``S`` and from the adjacency
        sets.  The caller is responsible for inserting ``edge`` itself.

        Returns:
            True if ``edge`` should be added to the sample, False otherwise.
        """
        # Use the ``t`` argument for both the fill-phase test and the
        # acceptance probability so the two can never disagree.
        if t <= self.M:
            return True
        if not np.random.binomial(1, self.M / t):
            return False

        rm_edge = random.choice(self.S)
        self.S.remove(rm_edge)
        self.update_neighbor('-', rm_edge)
        return True

    def UpdateCounters(self, edge):
        """Credit every triangle that ``edge`` closes with the sampled graph."""
        u, v = edge[0], edge[1]

        # ``.get`` instead of indexing: this runs for every stream edge, and
        # indexing a defaultdict would store an empty set for every node ever
        # seen, even if none of its edges is sampled.
        empty = frozenset()
        shared = self.neighbors.get(u, empty) & self.neighbors.get(v, empty)
        if not shared:
            # Nothing to credit; also skips the weight computation, whose
            # denominator M(M-1) is zero for samples of fewer than two edges.
            return

        ita = max(1, (self.t - 1) * (self.t - 2) / (self.M * (self.M - 1)))

        for w in shared:
            self.tau += ita
            for node in (w, u, v):
                self.local_tau[node] = self.local_tau.get(node, 0) + ita

    def get_estimate(self):
        """Return ``(global_estimate, local_estimates)`` truncated to ``int``.

        The running counters themselves are left untouched, so querying the
        estimate in the middle of a stream does not discard the fractional
        weight accumulated so far.
        """
        local_estimate = {node: int(w) for node, w in self.local_tau.items()}
        return int(self.tau), local_estimate

    def triest(self, edge):
        """Process the next edge of the stream."""
        self.t += 1
        # Counters are updated before sampling and regardless of its outcome.
        self.UpdateCounters(edge)
        if self.SampleEdge(edge, self.t):
            self.S.append(edge)
            self.update_neighbor('+', edge)

    def update_neighbor(self, flag, edge):
        """Insert (``'+'``) or delete (``'-'``) ``edge`` in the adjacency sets.

        Deleting an edge that is not present is a no-op.  Nodes left without
        neighbours are removed from ``self.neighbors``.
        """
        u, v = edge[0], edge[1]
        if flag == '+':
            self.neighbors[u].add(v)
            self.neighbors[v].add(u)
        elif flag == '-':
            for node, other in ((u, v), (v, u)):
                adjacent = self.neighbors.get(node)
                if adjacent is None:
                    continue
                adjacent.discard(other)
                if not adjacent:
                    del self.neighbors[node]