In [2]:
import criteo_search2
from net import Net_embedding
import torch
from tqdm.auto import tqdm
from utils import AverageMeter

import numpy as np
import math

In [3]:
# Seed consumed by the train/test split so that the partition is repeatable.
seed = 2024

# Parse the raw Criteo Search log (path is relative to the notebook directory).
dataset = criteo_search2.CriteoSearchDataset("../data/" + 'Criteo_Search.txt')

In [4]:
# 80/20 train/test partition; the seeded generator makes the split repeatable.
train_length = int(0.8 * len(dataset))
test_length = len(dataset) - train_length
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_length, test_length],
    generator=torch.Generator().manual_seed(seed),
)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('train_dataset length: ', len(train_dataset), 'test_dataset length: ', len(test_dataset), 'device: ', device)

train_dataset length:  1316781 test_dataset length:  329196 device:  cuda


In [5]:
def compute_optimal_interval2(interval_freq, node, epsilon, delta):
    """Pick the interval [A1, A2] that maximises the RPWithPrior objective.

    For every pair of histogram bins ``i < j`` the objective

        (h + f_j * A2 - f_i * A1) / (2 * delta + exp(-epsilon) * (A2 - A1))

    is evaluated at a fixed list of candidate end points (the bin edges, plus
    the roots ``e2``/``e1`` of the partial derivatives when they fall strictly
    inside bin ``i`` / bin ``j``).  The candidate with the strictly largest
    value wins; ties keep the earliest candidate.

    Args:
        interval_freq: 1-D tensor with the relative frequency of each of the
            ``k`` histogram bins.
        node: 1-D tensor with the ``k + 1`` bin edges; assumed non-decreasing
            (the only visible caller builds it that way).
        epsilon: parameter entering the objective through ``exp(-epsilon)``.
        delta: parameter entering the objective through ``2 * delta``;
            expected to be positive.

    Returns:
        Tuple ``(A1, A2)`` of 0-d tensors: the chosen lower and upper end point.

    Raises:
        ValueError: if no candidate reaches a strictly positive objective,
            e.g. when the histogram has fewer than two bins.
    """
    # Step 3: RPWithPrior i.e. Algorithm 1 in paper
    k = len(interval_freq)
    decay = math.exp(-epsilon)  # loop invariant, hoisted out of the pair loop
    fmax = 0  # best objective value seen so far
    A1 = A2 = None

    for i in range(k):
        for j in range(i + 1, k):
            f_i = interval_freq[i]
            f_j = interval_freq[j]
            # Contribution of everything between the two end-point bins.
            h = f_i * node[i+1] + \
                torch.sum(interval_freq[i+1:j] * (node[i+2:j+1] - node[i+1:j])) - \
                    f_j * node[j]
            slope = decay * (f_j - f_i)
            c1 = 2 * delta * f_i - decay * h
            c2 = 2 * delta * f_j - decay * h

            # Corner candidates, in the original evaluation order so that
            # tie-breaking (strict '<') is unchanged.
            candidates = [
                (node[i], node[j]),
                (node[i], node[j+1]),
                (node[i+1], node[j+1]),
                (node[i+1], node[j]),
            ]

            # d/dA2 of the objective vanishes at A1 = e2.  A sign change of
            # (c2 - slope * A1) across bin i means e2 lies strictly inside it,
            # which also guarantees slope != 0 for the division.
            if (-slope * node[i] + c2) * (-slope * node[i+1] + c2) < 0:
                e2 = c2 / slope
                candidates.append((e2, node[j]))
                candidates.append((e2, node[j+1]))

            # d/dA1 of the objective vanishes at A2 = e1; same reasoning for bin j.
            if (slope * node[j] - c1) * (slope * node[j+1] - c1) < 0:
                e1 = c1 / slope
                candidates.append((node[i], e1))
                candidates.append((node[i+1], e1))

            # The joint stationary point (e2, e1) is deliberately not evaluated:
            # e1 - e2 == -2 * delta * exp(epsilon), so the denominator is exactly
            # zero there, and for delta > 0 with non-decreasing nodes the two
            # roots cannot both lie inside their bins (that would need e2 < e1).

            for left, right in candidates:
                value = (h + f_j * right - f_i * left) / (
                    2 * delta + decay * (right - left))
                if fmax < value:
                    fmax = value
                    A1 = left
                    A2 = right

    if A1 is None:
        raise ValueError(
            "compute_optimal_interval2 found no interval with a positive objective "
            "(got {} histogram bin(s))".format(k))
    return A1, A2

In [6]:
def RPWithPrior3(train_loader, device, epsilon=0.1, delta=0.1):
    """Replace the labels served by ``train_loader`` with randomised ones.

    The loader is drained once.  A histogram of the labels with bins one
    standard deviation wide is built, ``compute_optimal_interval2`` selects
    the clipping interval [A1, A2], and each label is clipped to that
    interval and then perturbed at random.

    Args:
        train_loader: iterable yielding ``(x, z, y)`` tensor batches, ``y``
            being the regression labels (assumed floating point, since their
            mean and standard deviation are taken).
        device: device every batch is moved to; the returned tensors live there.
        epsilon: enters the sampling probabilities through ``exp(epsilon)``;
            presumably the label-DP privacy budget -- confirm against the paper.
        delta: half-width of the uniform noise; it is shrunk to
            ``(A2 - A1) / 2`` while the selected interval is narrower than
            ``2 * delta``.

    Returns:
        Tuple ``(x_sets, z_sets, y_sets, y_tilde)``: the concatenated features,
        the original labels, and the randomised labels.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    # Step 1: gather the whole training set; its labels serve as the prior.
    x_parts, z_parts, y_parts = [], [], []
    for x, z, y in train_loader:
        x_parts.append(x.to(device))
        z_parts.append(z.to(device))
        y_parts.append(y.to(device))
    if not y_parts:
        raise ValueError("train_loader yielded no batches")
    # Concatenate once instead of re-copying the running tensor on every batch.
    x_sets = torch.cat(x_parts, 0)
    z_sets = torch.cat(z_parts, 0)
    y_sets = torch.cat(y_parts, 0)
    target_sets = y_sets

    # calculate the statistics of prior
    target_mean = target_sets.mean()
    target_std = target_sets.std()
    target_min = torch.min(target_sets)
    target_max = torch.max(target_sets)

    # Step 2: calculate the histogram of prior
    # Bins are one standard deviation wide and indexed by their offset (in
    # standard deviations) from the mean: bin i covers [mean + i*std, mean + (i+1)*std).
    k0 = ((target_min - target_mean) / target_std).floor().int().item()
    k1 = ((target_max - target_mean) / target_std).ceil().int().item()
    k = k1 - k0

    node = torch.zeros(k+1) # node in paper x_0...x_k
    interval_freq = torch.zeros(k) # value in each interval for histogram

    # calculate the relative frequency(probability) of each interval
    centered = target_sets - target_mean
    for i in range(k0, k1):
        # The outermost edges are tightened to the observed min / max.
        if i == k0:
            node[i-k0] = target_min
        else:
            node[i-k0] = target_mean + i * target_std
        above_lower = centered >= i * target_std
        if i < k1 - 1:
            below_upper = centered < (i + 1) * target_std
        else:
            # The last bin is closed on the right so the maximum is counted.
            below_upper = centered <= (i + 1) * target_std
        interval_freq[i-k0] = (above_lower & below_upper).sum().item()
    node[k] = target_max
    interval_freq = interval_freq / len(target_sets)

    # Step 3: RPWithPrior i.e. Algorithm 1 in this paper
    A1, A2 = compute_optimal_interval2(interval_freq, node, epsilon, delta)
    # The noise half-width must fit inside the interval: shrink delta and
    # search again until A2 - A1 >= 2 * delta.
    while (A2 - A1 < 2 * delta):
        delta = (A2 - A1) / 2
        A1, A2 = compute_optimal_interval2(interval_freq, node, epsilon, delta)
    print(torch.min(y_sets),interval_freq, A1, A2, torch.max(y_sets),(y_sets<A1).sum()/len(y_sets), (y_sets>A2).sum()/len(y_sets))

    # Step 4: add noise to target  ##### Algorithm 2 in this paper
    # projection by Equation (3.6)
    y_sets1 = y_sets.clone()
    y_sets1[y_sets1 < A1] = A1
    y_sets1[y_sets1 > A2] = A2

    rate = 1 / (math.exp(epsilon) *2 * delta + (A2 -A1))

    # prob1: probability of the "lower" branch; 1 - prob2: probability of the
    # "upper" branch.  One shared uniform draw per sample selects the branch.
    prob1 = (y_sets1 - A1) * rate
    prob1[prob1 < 0] = 0
    prob2 = (A2 - y_sets1) * rate
    prob2[prob2 < 0] = 0
    prob2 = 1- prob2

    # Branch 2 (plain uniform noise around the clipped label) is the default.
    new_label = torch.full((len(y_sets1),), 2, dtype=torch.int64, device=device)
    random_tensor = torch.rand(len(y_sets1)).to(device)
    new_label[random_tensor - prob1 < 0] = 1
    new_label[random_tensor - prob2 > 0] = 3
    #############################
    y_tilde = y_sets1.clone()

    # Branch 1: sample from [A1 - delta, A1 - delta + (y - A1)].
    index = new_label == 1
    y_tilde[index] = A1 - delta + torch.rand(index.sum()).to(device) * torch.max(
        y_sets1[index] - A1,torch.zeros(index.sum()).to(device))
    # Branch 2: sample from [y - delta, y + delta].
    index = new_label == 2
    y_tilde[index] = y_sets1[index] + delta * torch.empty_like(y_sets1[index]).uniform_(-1, 1).to(device)
    # Branch 3: sample from [A2 + delta - (A2 - y), A2 + delta].
    index = new_label == 3
    y_tilde[index] = A2 + delta - torch.rand(index.sum()).to(device) * torch.max(
        A2 - y_sets1[index],torch.zeros(index.sum()).to(device))

    return x_sets, z_sets, y_sets, y_tilde

In [21]:
# Parameters of the label-randomisation mechanism.
epsilon = 0.05
delta = 27

# Optimisation schedule.
epoch = 50
batch_size = 8192
eval_from_epoch = 40  # clean-label evaluation is only reported for the last epochs

# Seed the global RNG so label noise, weight initialisation and shuffling repeat
# across runs, as far as the backend is deterministic.
torch.manual_seed(seed)

model = Net_embedding(vocab_size=dataset.get_vocab()).to(device)

optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3, weight_decay=1e-4)
# One cosine cycle spanning the whole run.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch, eta_min=5e-6)

loss_func = torch.nn.MSELoss()
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
      shuffle=True, num_workers=8, pin_memory=True)
# Evaluation does not depend on sample order, so the test loader is not shuffled.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
      shuffle=False, num_workers=6, pin_memory=True)

# Randomise the training labels once, up front.
x_sets, z_sets, y_sets, y_tilde = RPWithPrior3(train_loader, device, epsilon= epsilon, delta=delta)

# Train the model with the Label-DP dataset
labeldp_dataset = torch.utils.data.TensorDataset(x_sets.detach().cpu(),
                    z_sets.detach().cpu(), y_tilde.detach().cpu())
labeldp_loader = torch.utils.data.DataLoader(labeldp_dataset,
            batch_size=batch_size, shuffle=True, num_workers=6, pin_memory=True)


def _evaluate(loader):
    """Return an AverageMeter with the loss of ``model`` on ``loader``, each batch weighted by its size."""
    meter = AverageMeter()
    with torch.no_grad():
        for x, z, y in loader:
            x, z, y = x.to(device), z.to(device), y.to(device)
            output = model(x, z)
            meter.update(loss_func(output.view(-1), y).item(), x.shape[0])
    return meter


for i in tqdm(range(epoch)):
    # Training mode matters if the network contains dropout / batch-norm layers.
    model.train()
    losses = AverageMeter()
    for x, z, y in labeldp_loader:
        x, z, y = x.to(device), z.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(x, z)
        loss = loss_func(output.view(-1), y)
        loss.backward()
        optimizer.step()
        losses.update(loss.item(), x.shape[0])
    lr_scheduler.step()

    if i >= eval_from_epoch:
        model.eval()
        # train_loader / test_loader carry the original (clean) labels.
        train_loss = _evaluate(train_loader)
        test_loss = _evaluate(test_loader)
        # `losses` is measured against the randomised labels, so it gets its own name in the log.
        print("Epoch: {:>2}| Noisy-Label Train Loss: {:.2f}| Clean Train Loss: {:.2f}| Test Loss: {:.2f} ".format(i,
                    losses.avg, train_loss.avg, test_loss.avg))


tensor(0., device='cuda:0') tensor([0.6624, 0.1936, 0.0787, 0.0427, 0.0226]) tensor(0.) tensor(78.1321) tensor(400., device='cuda:0') tensor(0., device='cuda:0') tensor(0.3376, device='cuda:0')


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch: 40| Train Loss: 1494.77| Train Loss: 8421.65| Test Loss: 8450.46 
Epoch: 41| Train Loss: 1472.66| Train Loss: 8239.61| Test Loss: 8270.34 
Epoch: 42| Train Loss: 1461.10| Train Loss: 7938.78| Test Loss: 7970.98 
Epoch: 43| Train Loss: 1452.30| Train Loss: 7968.97| Test Loss: 8001.23 
Epoch: 44| Train Loss: 1445.14| Train Loss: 8001.52| Test Loss: 8033.37 
Epoch: 45| Train Loss: 1440.34| Train Loss: 8439.04| Test Loss: 8469.53 
Epoch: 46| Train Loss: 1437.48| Train Loss: 8309.90| Test Loss: 8340.82 
Epoch: 47| Train Loss: 1433.56| Train Loss: 8181.70| Test Loss: 8211.74 
Epoch: 48| Train Loss: 1432.33| Train Loss: 8147.10| Test Loss: 8177.00 
Epoch: 49| Train Loss: 1431.48| Train Loss: 8144.66| Test Loss: 8174.73 
