In [13]:
import sys
import os
from tqdm import tqdm
import random
# Add the parent directory to the sys.path
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

sys.path.append(parent_dir)

# Now you can import file1
#import simtrain
from simtrain.sim_models_new import User_simmulation_Model
import torch
import torch.nn as nn


In [14]:
num_items = 7
num_items_per_recom = 2
num_interaction_types = 2
recom_dim = 1
num_users = 11
min_inter = 2
max_inter = 4
state_size = 2

In [15]:
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, data):
        """
        Args:
            data (list of dicts): Each dict contains 'timestamps', 'items', and 'labels' for a user.
        """
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        user_data = self.data[idx]
        timestamps = user_data['timestamps']
        items =  user_data['item_recom']
        labels = user_data['labels']
        return timestamps, items, labels, user_data["means"], user_data["log_var"], idx
    def Update_user_params(self, means_list, logvar_list, idx_list):
        means_list.requires_grad = True
        logvar_list.requires_grad = True
        
        for means, logvar, idx in zip(means_list, logvar_list, idx_list):
            self.data[idx]["means"] = means
            self.data[idx]["log_var"] = logvar

# Example data for multiple users

data = []
for user in range(num_users):
    num_interactions_now = random.randint(a=min_inter, b=max_inter)
    new = {
        'item_recom': torch.randint(low=1, high=num_items, size=(num_interactions_now, 
        num_items_per_recom, recom_dim)).to(torch.float32),
        'timestamps': torch.sort(torch.FloatTensor(num_interactions_now).uniform_(0, 1.))[0].to(torch.float32),
        'labels': torch.randint(low=0, high=num_interaction_types-1, 
        size=(num_interactions_now, num_items_per_recom)),
        "means": torch.randn((state_size), requires_grad=True).to(torch.float32),
        "log_var": torch.randn((state_size), requires_grad=True).to(torch.float32),
    }
    data.append(new)


# Create the dataset
dataset = CustomDataset(data)

# Example usage with DataLoader
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)# can only do batchsize 1

for batch in dataloader:
    timestamps, items, labels, means, var, idx = batch
    print('Timestamps:', timestamps, "\n dtype: ", timestamps.dtype)
    print('item_recom:', items, "\n dtype: ", items.dtype)
    print('Labels:', labels, "\n dtype: ", labels.dtype)
    print('means:', means, "\n dtype: ", means.dtype)
    print('log_var:', var, "\n dtype: ", var.dtype)


Timestamps: tensor([[0.0822, 0.4975, 0.5484]]) 
 dtype:  torch.float32
item_recom: tensor([[[[4.],
          [4.]],

         [[1.],
          [3.]],

         [[2.],
          [3.]]]]) 
 dtype:  torch.float32
Labels: tensor([[[0, 0],
         [0, 0],
         [0, 0]]]) 
 dtype:  torch.int64
means: tensor([[-0.3341, -0.9852]], grad_fn=<StackBackward0>) 
 dtype:  torch.float32
log_var: tensor([[0.9213, 0.8606]], grad_fn=<StackBackward0>) 
 dtype:  torch.float32
Timestamps: tensor([[0.0483, 0.8681]]) 
 dtype:  torch.float32
item_recom: tensor([[[[3.],
          [1.]],

         [[4.],
          [2.]]]]) 
 dtype:  torch.float32
Labels: tensor([[[0, 0],
         [0, 0]]]) 
 dtype:  torch.int64
means: tensor([[-0.4827, -0.7810]], grad_fn=<StackBackward0>) 
 dtype:  torch.float32
log_var: tensor([[-0.7085,  0.9509]], grad_fn=<StackBackward0>) 
 dtype:  torch.float32
Timestamps: tensor([[0.4861, 0.4927, 0.7249]]) 
 dtype:  torch.float32
item_recom: tensor([[[[1.],
          [4.]],

         [

In [16]:
def train_1_path(model, user_state, timestamps, items, labels, loss_func, num_classes, teacher_forcing=True):
    ''' expects batchsize of 1
    '''
    model.init_state(user_state)
    loss = 0.
    curr_time = 0.
    
    for interaction_id in range(len(timestamps[0])):
        
        h = timestamps[0][interaction_id] - curr_time
        curr_time = h
        model.evolve_state(h)
        # no intensity for now
        y_pred = model.view_recommendations(items[:,interaction_id])
        y_true = labels[:,interaction_id]
        y_true_onehot = nn.functional.one_hot(y_true, num_classes=num_classes).float()
        if teacher_forcing:
            model.jump(y_true_onehot)
        else:
            model.jump(y_pred)
        #loss += loss_func(y_true, y_pred)# for mse
        y_true = y_true.squeeze(0)
        y_pred = y_pred.squeeze(0)
        #print(y_true, y_pred)
        y_pred = nn.functional.log_softmax(y_pred,dim=-1)
        loss += loss_func(y_pred, y_true) # NLLL
    
    return loss


In [17]:
# parameter dicts
width= 10
user_state_dict = {"model_hyp": {"layer_width": [width, width, width]}}
intensity_state_dict = {"model_hyp": {"user_model_hyp": {"layer_width": [width, width, 3]},
                                          "global_model_hyp": {"layer_width": [width, 3]}}
                            }
interaction_state_dict = {"model_hyp": {"layer_width": [width, width ,width]}
                            }
jump_state_dict = {"model_hyp": {"layer_width": [width, width]}
                        }

In [18]:
import torch.optim as optim

hyperparameter_dict = {"state_size": state_size, "state_model": user_state_dict, "num_interaction_outcomes": num_interaction_types,
                           "intensity_model": intensity_state_dict, "num_recom" : num_items_per_recom,
                            "recom_dim":recom_dim, "interaction_model": interaction_state_dict,
                            "jump_model": jump_state_dict}
model = User_simmulation_Model(hyperparameter_dict)


In [19]:
#loss_func = nn.functional.mse_loss
loss_func = nn.NLLLoss()

def kl_divergence(mu1, sigma1, mu2, sigma2):
    """
    Compute the KL divergence between two normal distributions N(mu1, sigma1^2) and N(mu2, sigma2^2).

    Args:
        mu1 (Tensor): Mean of the first distribution.
        sigma1 (Tensor): Standard deviation of the first distribution.
        mu2 (Tensor): Mean of the second distribution.
        sigma2 (Tensor): Standard deviation of the second distribution.

    Returns:
        Tensor: KL divergence.
    """
    kl_div = torch.log(sigma2 / sigma1) + ((sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2)) - 0.5
    return kl_div

def kl_divergence_to_standard_normal(mu, sigma):
    """
    Compute the KL divergence from a normal distribution N(mu, sigma^2) to the standard normal distribution N(0, 1).

    Args:
        mu (Tensor): Mean of the normal distribution.
        sigma (Tensor): Standard deviation of the normal distribution.

    Returns:
        Tensor: KL divergence.
    """
    sigma2 = sigma ** 2
    kl_div = 0.5 * (sigma2 + mu ** 2 - torch.log(sigma2) - 1)
    return kl_div

def kl_loss (mu, sigma):
    return kl_divergence(mu, sigma, 0, 0.1)

In [20]:
def print_user_params(datadataloader, print_var = False):
    for batch in dataloader:
        timestamps, item_recom, labels, means, logvar, idx = batch
        print("means: ", means)
        if (print_var):
            print("logvar: ", logvar)

In [21]:
def train(model, device, dataloader,num_epochs, state_size, loss_func, loss_func_kl, optimizer, num_classes, 
            logger, kl_weight = 1, user_lr = None, log_step_size = 1):
    model.to(device)
    for epoch in tqdm(range(num_epochs)):  # Example: Number of epochs
        loss_all, loss_base, loss_kl = 0, 0, 0
        #print_user_params(dataloader)# see if values change
        for batch in dataloader:
            # Zero the gradients
            optimizer.zero_grad()
            
            timestamps, item_recom, labels, means, logvar, idx = batch
            timestamps, item_recom, labels, means, logvar = timestamps.to(device), item_recom.to(device), labels.to(device), means.to(device), logvar.to(device)
            
            means.retain_grad()
            logvar.retain_grad()

            variances = torch.exp(logvar)
            user_state = means + variances*torch.randn((1, state_size))
            #delta_from_previous = torch.cat([torch.zeros((timestamps.size(0),1)), timestamps[:,1:] - timestamps[:,:-1]], dim=1)
            
            curr_loss_base = train_1_path(model=model, user_state=user_state, timestamps=timestamps, items=item_recom, labels=labels,  
                         loss_func=loss_func,num_classes= num_classes)
            curr_loss_kl = kl_weight * torch.sum(loss_func_kl(means, variances))#.view(1,-1)

            curr_loss_all = curr_loss_kl + curr_loss_base
            curr_loss_all.backward()
            optimizer.step()

            #logging
            loss_all += curr_loss_all.item()
            loss_base += curr_loss_base.item()
            loss_kl += curr_loss_kl.item()
            # maybe need to optim user_mean, user_var separate because of torch..
            if user_lr:
                with torch.no_grad():
                    means -= user_lr * means.grad
                    logvar -= user_lr * logvar.grad  
                means.grad.zero_()
                logvar.grad.zero_()
                
                dataloader.dataset.Update_user_params(means.detach(), logvar.detach(), idx)
        if epoch % log_step_size == 0:
            logger(loss_all, loss_base, loss_kl)
    logger(loss_all, loss_base, loss_kl)# log at the end


In [22]:
def logging_func(loss_all, loss_base, loss_kl):
    print("loss_all: ", loss_all, "\tloss_base: ",loss_base, "\loss_kl: ",loss_kl)

In [23]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

optimizer = optim.SGD(model.parameters(), lr=0.03)

Using device: cpu


In [24]:
print_user_params(dataloader, print_var = True)

means:  tensor([[-1.6682,  1.0347]], grad_fn=<StackBackward0>)
logvar:  tensor([[-0.5324,  0.0204]], grad_fn=<StackBackward0>)
means:  tensor([[0.9434, 0.3297]], grad_fn=<StackBackward0>)
logvar:  tensor([[0.2264, 0.6793]], grad_fn=<StackBackward0>)
means:  tensor([[-0.4827, -0.7810]], grad_fn=<StackBackward0>)
logvar:  tensor([[-0.7085,  0.9509]], grad_fn=<StackBackward0>)
means:  tensor([[ 0.1913, -0.9979]], grad_fn=<StackBackward0>)
logvar:  tensor([[-0.3750,  0.5932]], grad_fn=<StackBackward0>)
means:  tensor([[-0.3347,  0.5144]], grad_fn=<StackBackward0>)
logvar:  tensor([[ 1.3193, -0.9432]], grad_fn=<StackBackward0>)
means:  tensor([[-0.9807,  0.4182]], grad_fn=<StackBackward0>)
logvar:  tensor([[-1.2250, -0.9406]], grad_fn=<StackBackward0>)
means:  tensor([[-1.1417,  1.6132]], grad_fn=<StackBackward0>)
logvar:  tensor([[-0.3386,  0.7236]], grad_fn=<StackBackward0>)
means:  tensor([[-0.3777,  0.8752]], grad_fn=<StackBackward0>)
logvar:  tensor([[-0.5233, -1.1403]], grad_fn=<Stack

In [25]:
train(model, dataloader=dataloader, num_epochs =50, device = device, loss_func= loss_func, 
                loss_func_kl = kl_loss, kl_weight= 0.01, user_lr = 0.03,
                optimizer=optimizer, num_classes=num_interaction_types, logger = logging_func,
                state_size = state_size, log_step_size= 5)

  2%|▏         | 1/50 [00:00<00:28,  1.75it/s]

loss_all:  56.35666036605835 	loss_base:  20.88695991039276 \loss_kl:  35.46969974040985


 12%|█▏        | 6/50 [00:04<00:32,  1.35it/s]

loss_all:  36.44455337524414 	loss_base:  20.79441547393799 \loss_kl:  15.650138258934021


 22%|██▏       | 11/50 [00:18<01:32,  2.37s/it]

loss_all:  31.621041774749756 	loss_base:  20.79441547393799 \loss_kl:  10.82662644982338


 32%|███▏      | 16/50 [00:24<00:54,  1.59s/it]

loss_all:  28.93181085586548 	loss_base:  20.79441547393799 \loss_kl:  8.137395441532135


 42%|████▏     | 21/50 [00:32<00:47,  1.64s/it]

loss_all:  27.16498589515686 	loss_base:  20.79441547393799 \loss_kl:  6.370570197701454


 52%|█████▏    | 26/50 [00:42<00:44,  1.87s/it]

loss_all:  25.92131769657135 	loss_base:  20.79441547393799 \loss_kl:  5.126902222633362


 62%|██████▏   | 31/50 [00:51<00:31,  1.66s/it]

loss_all:  25.012293100357056 	loss_base:  20.79441547393799 \loss_kl:  4.217877812683582


 72%|███████▏  | 36/50 [00:58<00:21,  1.55s/it]

loss_all:  24.331183075904846 	loss_base:  20.79441547393799 \loss_kl:  3.5367674827575684


 82%|████████▏ | 41/50 [01:09<00:20,  2.24s/it]

loss_all:  23.81112277507782 	loss_base:  20.79441547393799 \loss_kl:  3.0167068913578987


 92%|█████████▏| 46/50 [01:16<00:06,  1.53s/it]

loss_all:  23.407631278038025 	loss_base:  20.79441547393799 \loss_kl:  2.613215833902359


100%|██████████| 50/50 [01:22<00:00,  1.64s/it]

loss_all:  23.147793531417847 	loss_base:  20.79441547393799 \loss_kl:  2.353378064930439





In [26]:
print_user_params(dataloader, print_var = True)

means:  tensor([[0.2057, 0.0719]], grad_fn=<StackBackward0>)
logvar:  tensor([[-0.6439, -0.5942]], grad_fn=<StackBackward0>)
means:  tensor([[ 0.0416, -0.2172]], grad_fn=<StackBackward0>)
logvar:  tensor([[-0.8083, -0.6009]], grad_fn=<StackBackward0>)
means:  tensor([[-0.0824,  0.1908]], grad_fn=<StackBackward0>)
logvar:  tensor([[-0.8737, -1.2611]], grad_fn=<StackBackward0>)
means:  tensor([[-0.2490,  0.3518]], grad_fn=<StackBackward0>)
logvar:  tensor([[-0.7939, -0.5914]], grad_fn=<StackBackward0>)
means:  tensor([[-0.2139,  0.0912]], grad_fn=<StackBackward0>)
logvar:  tensor([[-1.3269, -1.1168]], grad_fn=<StackBackward0>)
means:  tensor([[-0.0728, -0.2148]], grad_fn=<StackBackward0>)
logvar:  tensor([[-0.5818, -0.5843]], grad_fn=<StackBackward0>)
means:  tensor([[-0.0730,  0.1122]], grad_fn=<StackBackward0>)
logvar:  tensor([[-0.5734, -1.1186]], grad_fn=<StackBackward0>)
means:  tensor([[-0.1214,  0.1472]], grad_fn=<StackBackward0>)
logvar:  tensor([[-1.2145, -1.3186]], grad_fn=<Sta