In [1]:
import sys
import os
from tqdm import tqdm
import random
import numpy as np
import math
# Add the parent directory to the sys.path
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

sys.path.append(parent_dir)

#import simtrain
from simtrain.sim_models_new import User_simmulation_Model
from simtrain import SETTINGS_POLIMI as SETTINGS, process_dat

import torch
import torch.nn as nn

import ast

import paths
from os.path import join

In [2]:
num_items = 7
num_items_per_recom = 2
num_interaction_types = 2
recom_dim = 1
num_users = 11
min_inter = 2
max_inter = 4
state_size = SETTINGS.STATE_SIZE

In [3]:
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, data):
        """
        Args:
            data (list of dicts): Each dict contains 'timestamps', 'items', and 'labels' for a user.
        """
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        user_data = self.data[idx]
        timestamps = user_data['timestamps']
        items =  user_data['item_ids']
        labels = user_data['interaction_types']
        return timestamps, items, labels, user_data["user_means"], user_data["user_vars_log"], idx
    
    def Update_user_params(self, means_list, logvar_list, idx_list):
        #means_list.requires_grad = True
        #logvar_list.requires_grad = True
        
        #print(means_list)
        
        self.data[idx_list[0]]["user_means"] = means_list.tolist()
        self.data[idx_list[0]]["user_vars_log"] = logvar_list.tolist()
        #for means, logvar, idx in zip(means_list, logvar_list, idx_list):
        #    self.data[idx]["user_means"] = means
        #    self.data[idx]["user_vars_log"] = logvar


In [4]:
# generate random data
data = []
for user in range(num_users):
    num_interactions_now = random.randint(a=min_inter, b=max_inter)
    new = {
        'item_ids': torch.randint(low=1, high=num_items, size=(num_interactions_now, 
        num_items_per_recom, recom_dim)).to(torch.float32),
        'timestamps': torch.sort(torch.FloatTensor(num_interactions_now).uniform_(0, 1.))[0].to(torch.float32),
        'interaction_types': torch.randint(low=0, high=num_interaction_types-1, 
        size=(num_interactions_now, num_items_per_recom)),
        "user_means": torch.randn((state_size), requires_grad=True).to(torch.float32),
        "user_vars_log": torch.randn((state_size), requires_grad=True).to(torch.float32),
    }
    data.append(new)


# Create the dataset
dataset = CustomDataset(data)

# Example usage with DataLoader
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)# can only do batchsize 1

In [5]:
train_dat, stg = process_dat.load_dat(paths.cw_stages['output_new']['train'], new_data=True)

print(stg)

def convert_string_to_double_list(s):
    return ast.literal_eval(s)


# Apply the custom function
train_dat['item_ids'] = train_dat['item_ids'].apply(convert_string_to_double_list)
train_dat['user_means'] = train_dat['user_means'].apply(convert_string_to_double_list)
train_dat['user_vars_log'] = train_dat['user_vars_log'].apply(convert_string_to_double_list)
train_dat['timestamps'] = train_dat['timestamps'].apply(convert_string_to_double_list)
train_dat['interaction_types'] = train_dat['interaction_types'].apply(convert_string_to_double_list)

print("len: ", len(train_dat))
train_dat.head()

{'NI': 14, 'NU': 328, 'T': '[9.708333333333334, 12.4375, 12.604166666666666, 13.395833333333334, 15.458333333333334, 15.979166666666666, 17.333333333333332, 18.6875, 20.58333333333333, 21.354166666666668, 21.375, 21.39583333333333, 21.416666666666668, 21.4375, 21.45833333333333, 21.479166666666668, 22.52083333333333, 24.39583333333333, 24.416666666666668, 24.4375, 24.45833333333333, 24.479166666666668, 24.5, 24.52083333333333, 25.479166666666668, 25.5, 25.52083333333333, 25.541666666666668, 25.58333333333333, 25.64583333333333, 29.33333333333333, 30.14583333333333, 30.58333333333333, 30.604166666666668, 30.625, 30.979166666666668, 32.0, 34.354166666666664, 36.66666666666666, 36.6875, 37.5625, 37.583333333333336, 38.583333333333336, 43.10416666666666, 43.333333333333336]', 'NS': 100, 'INF_TIME': 1000}
len:  328


Unnamed: 0,user_id,user_means,user_vars_log,item_ids,timestamps,interaction_types
0,188,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[[119, 74, 263, 144, 261, 53, 217, 194, 178, 2...","[5.416666666666667, 12.25, 13.645833333333334,...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
1,491,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[[144, 279, 79, 84, 74, 247, 162, 165, 161, 13...","[54.35416666666666, 57.270833333333336, 57.354...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, ..."
2,561,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[[106, 125, 158, 27, 269, 264, 110, 50, 19, 16...","[20.58333333333333, 26.64583333333333, 28.5625...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
3,670,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[[165, 157, 187, 155, 95, 202, 99, 237, 288, 5...","[20.33333333333333, 26.39583333333333, 35.3125...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
4,749,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]","[[161, 199, 279, 12, 37, 84, 74, 132, 161, 284...","[5.083333333333333, 6.395833333333333, 11.8125...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."


In [6]:
list_of_dicts = train_dat.to_dict(orient='records')


In [7]:
checkpoint = torch.load(join(paths.dat, SETTINGS.filepaths_new['optimized_data']))
list_of_dicts = checkpoint['data']

In [8]:
list_of_dicts[0]

{'user_id': 188,
 'user_means': [0.02557029388844967,
  -0.12405752390623093,
  -0.01584812067449093,
  0.18346340954303741,
  -0.000515084364451468,
  0.001897393842227757,
  -0.05808410048484802,
  -0.19148965179920197],
 'user_vars_log': [-0.030332816764712334,
  0.22534839808940887,
  0.022230209782719612,
  0.1548035591840744,
  -0.0002639472368173301,
  0.0014818812487646937,
  0.04600786417722702,
  -0.10826224088668823],
 'item_ids': [[119,
   74,
   263,
   144,
   261,
   53,
   217,
   194,
   178,
   230,
   243,
   284,
   225,
   177,
   288,
   162,
   202,
   66,
   155,
   22,
   265,
   95,
   193,
   202,
   261,
   297,
   229,
   205,
   203,
   271,
   279,
   247,
   79,
   84,
   162,
   284,
   106,
   230,
   178,
   177,
   55,
   157,
   53,
   155,
   192,
   45,
   81,
   252,
   132,
   161,
   19,
   138,
   249,
   153,
   134,
   201,
   55,
   260,
   217,
   35,
   42,
   247,
   138,
   144,
   290,
   53,
   258,
   194,
   124,
   54,
   5,
   244

In [9]:

dataset = CustomDataset(list_of_dicts[:50])
# Example usage with DataLoader
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)# can only do batchsize 1

In [10]:
def test_timestamps(dataloader):
    smallest = float("inf")
    biggest = -1
    for batch in dataloader:
        timestamps, items, labels, means, var, idx = batch
        last = timestamps[0]
        smallest = min(smallest, last)
        biggest = max(biggest, timestamps[-1])
        for i in range(1,len(timestamps)):
            if timestamps[i] <= last:
                print("error, current: ", timestamps[i], "\tlast", last)
    print(smallest), print(biggest)
    return biggest

max_time = test_timestamps(dataloader)

tensor([0.0417], dtype=torch.float64)
tensor([69.3333], dtype=torch.float64)


In [11]:
for batch in dataloader:
    timestamps, items, labels, means, var, idx = batch
    print('Timestamps:', timestamps#, "\n dtype: ", timestamps.dtype
          )
    print('item_recom:', items#, "\n dtype: ", items.dtype
          )
    print('Labels:', labels#, "\n dtype: ", labels.dtype
          )
    print('means:', means#, "\n dtype: ", means.dtype
          )
    print('log_var:', var#, "\n dtype: ", var.dtype
          )
    break
    

Timestamps: [tensor([1.5000], dtype=torch.float64), tensor([1.6042], dtype=torch.float64), tensor([1.9792], dtype=torch.float64), tensor([51.2083], dtype=torch.float64), tensor([51.2500], dtype=torch.float64), tensor([51.3125], dtype=torch.float64), tensor([51.3333], dtype=torch.float64), tensor([52.4792], dtype=torch.float64), tensor([52.5417], dtype=torch.float64), tensor([52.6875], dtype=torch.float64), tensor([52.7292], dtype=torch.float64), tensor([56.3333], dtype=torch.float64), tensor([56.3958], dtype=torch.float64), tensor([56.4583], dtype=torch.float64), tensor([56.5000], dtype=torch.float64), tensor([56.5625], dtype=torch.float64), tensor([56.6042], dtype=torch.float64), tensor([56.6250], dtype=torch.float64), tensor([56.6875], dtype=torch.float64), tensor([58.1875], dtype=torch.float64), tensor([58.2500], dtype=torch.float64), tensor([58.2917], dtype=torch.float64), tensor([58.4167], dtype=torch.float64), tensor([66.2500], dtype=torch.float64), tensor([66.2917], dtype=torch.

In [12]:
def train_1_path(model, user_state, timestamps, items, labels, loss_func, num_classes, max_time, device, teacher_forcing=True):
    ''' expects batchsize of 1
    '''
    model.init_state(user_state)
    loss_base = 0.
    loss_intensity = 0.
    curr_time = 1e-10# because time 0 is used sadly
    N = len(timestamps)
    max_div_by_N = max_time/N
    for interaction_id in range(N):#[0]
        h = (timestamps[interaction_id] - curr_time).float()
        
        intensity = model.eval_intensity(h)#loss on that
        
        #try:
        model.evolve_state(h)
        #except Exception as e:
        #    print("delta: ",h , "\tnew: ", timestamps[interaction_id], "\told: ", curr_time)
        #    print(e)

        curr_time = h
        # no intensity for now
        y_pred = model.view_recommendations(items[interaction_id])# :,
        y_true = torch.as_tensor(labels[interaction_id])# :,
        y_true_onehot = nn.functional.one_hot(y_true, num_classes=num_classes).float()
        
        if teacher_forcing:
            model.jump(y_true_onehot)
        else:
            model.jump(y_pred)
        #loss += loss_func(y_true, y_pred)# for mse
        y_true = y_true.squeeze(0)
        y_pred = y_pred.squeeze(0)
        y_pred = nn.functional.log_softmax(y_pred)
        #print("true: ",y_true.shape, "\t predicted: ",y_pred.shape)
        #print(torch.unique(y_true))
        loss_base += loss_func(y_pred, y_true) # NLLL
        loss_intensity += -torch.log(intensity) + max_div_by_N*intensity
    
    return loss_base, loss_intensity


In [13]:
# parameter dicts
width= 10
user_state_dict = {"model_hyp": {"layer_width": [width, width, width]}}
intensity_state_dict = {"model_hyp": {"user_model_hyp": {"layer_width": [width, width, 3]},
                                          "global_model_hyp": {"layer_width": [width, 3]}}
                            }
interaction_state_dict = {"model_hyp": {"layer_width": [width, width ,width]}
                            }
jump_state_dict = {"model_hyp": {"layer_width": [width, width]}
                        }

In [14]:
import torch.optim as optim

hyperparameter_dict = {"state_size": state_size, "state_model": user_state_dict, "num_interaction_outcomes": num_interaction_types,
                           "intensity_model": intensity_state_dict, "num_recom" : num_items_per_recom,
                            "recom_dim":recom_dim, "interaction_model": interaction_state_dict,
                            "jump_model": jump_state_dict}
model = User_simmulation_Model(hyperparameter_dict)

In [15]:
# load model
model.load_state_dict(torch.load(join(paths.dat, SETTINGS.filepaths_new['user_model'])))

<All keys matched successfully>

In [16]:
#loss_func = nn.functional.mse_loss
loss_func = nn.NLLLoss()

def kl_divergence(mu1, sigma1, mu2, sigma2):
    """
    Compute the KL divergence between two normal distributions N(mu1, sigma1^2) and N(mu2, sigma2^2).

    Args:
        mu1 (Tensor): Mean of the first distribution.
        sigma1 (Tensor): Standard deviation of the first distribution.
        mu2 (Tensor): Mean of the second distribution.
        sigma2 (Tensor): Standard deviation of the second distribution.

    Returns:
        Tensor: KL divergence.
    """
    kl_div = torch.log(sigma2 / sigma1) + ((sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2)) - 0.5
    return kl_div

def kl_divergence_to_standard_normal(mu, sigma):
    """
    Compute the KL divergence from a normal distribution N(mu, sigma^2) to the standard normal distribution N(0, 1).

    Args:
        mu (Tensor): Mean of the normal distribution.
        sigma (Tensor): Standard deviation of the normal distribution.

    Returns:
        Tensor: KL divergence.
    """
    sigma2 = sigma ** 2
    kl_div = 0.5 * (sigma2 + mu ** 2 - torch.log(sigma2) - 1)
    return kl_div

def kl_loss(mu, sigma):
    return kl_divergence(mu, sigma, 0, 0.1)

In [17]:
def print_user_params(datadataloader, print_var = False, num_examples=10):
    i = 0
    for batch in dataloader:
        timestamps, item_recom, labels, means, logvar, idx = batch
        print("means: ", means)
        if (print_var):
            print("logvar: ", logvar)
        i+=1
        if i >= num_examples:
            return

In [18]:
def train(model, device, dataloader,num_epochs, state_size, loss_func, loss_func_kl, optimizer, num_classes, 
            logger, max_time, kl_weight = 1, user_lr = None, log_step_size = 1):
    model.to(device)
    for epoch in tqdm(range(num_epochs)):  # Example: Number of epochs
        loss_all, loss_base, loss_kl, loss_intensity = 0, 0, 0, 0
        #print_user_params(dataloader)# see if values change
        for batch in dataloader:
            # Zero the gradients
            optimizer.zero_grad()
            
            timestamps, item_recom, labels, means, logvar, idx = batch
            timestamps, means, logvar = torch.as_tensor(timestamps).to(device).float(), \
                 torch.as_tensor(means).to(device).float(), torch.as_tensor(logvar).to(device).float()  #  item_recom, labels, = item_recom.to(device), labels.to(device),
            
            #timestamps, item_recom, labels, means, logvar = torch.as_tensor(timestamps).to(device), \
            #    torch.as_tensor(item_recom).to(device), torch.as_tensor(labels).to(device), \
            #    torch.as_tensor(means).to(device), torch.as_tensor(logvar).to(device)
            means.requires_grad = True
            logvar.requires_grad = True
            #means.retain_grad()
            #logvar.retain_grad()

            variances = torch.exp(logvar)
            user_state = means + variances*torch.randn((1, state_size))
            #delta_from_previous = torch.cat([torch.zeros((timestamps.size(0),1)), timestamps[:,1:] - timestamps[:,:-1]], dim=1)
            
            curr_loss_base, curr_loss_intensity = train_1_path(model=model, user_state=user_state, timestamps=timestamps, items=item_recom, labels=labels,  
                         loss_func=loss_func, max_time=max_time, num_classes=num_classes, device=device)
            curr_loss_kl = kl_weight * torch.sum(loss_func_kl(means, variances))#.view(1,-1)
            
            curr_loss_all = curr_loss_kl + curr_loss_base+curr_loss_intensity
            curr_loss_all.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            #logging
            loss_all += curr_loss_all.item()
            loss_base += curr_loss_base.item()
            loss_kl += curr_loss_kl.item()
            loss_intensity += curr_loss_intensity.item()
            # maybe need to optim user_mean, user_var separate because of torch..
            if user_lr:
                torch.nn.utils.clip_grad_norm_(means, max_norm=1.0)
                torch.nn.utils.clip_grad_norm_(logvar, max_norm=1.0)
                with torch.no_grad():
                    means -= user_lr * means.grad
                    logvar -= user_lr * logvar.grad  
                #print(means)
                means.grad.zero_()
                logvar.grad.zero_()
                dataloader.dataset.Update_user_params(means.detach(), logvar.detach(), idx)
        if epoch % log_step_size == 0:
            logger(loss_all, loss_base, loss_kl, loss_intensity)
    logger(loss_all, loss_base, loss_kl, loss_intensity)# log at the end


In [19]:
def logging_func(loss_all, loss_base, loss_kl, loss_intensity):
    print("loss_all: ", loss_all, "\tloss_base: ",loss_base, "\tloss_kl: ",loss_kl, "\tloss_intensity: ",loss_intensity, "\tlog of the loss: ", math.log(loss_all))

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

optimizer = optim.SGD(model.parameters(), lr=0.003)

Using device: cpu


In [21]:
print_user_params(dataloader, print_var = True)

means:  [tensor([0.], dtype=torch.float64), tensor([0.], dtype=torch.float64), tensor([0.], dtype=torch.float64), tensor([0.], dtype=torch.float64), tensor([0.], dtype=torch.float64), tensor([0.], dtype=torch.float64), tensor([0.], dtype=torch.float64), tensor([0.], dtype=torch.float64)]
logvar:  [tensor([0.], dtype=torch.float64), tensor([0.], dtype=torch.float64), tensor([0.], dtype=torch.float64), tensor([0.], dtype=torch.float64), tensor([0.], dtype=torch.float64), tensor([0.], dtype=torch.float64), tensor([0.], dtype=torch.float64), tensor([0.], dtype=torch.float64)]
means:  [tensor([0.1249], dtype=torch.float64), tensor([-0.1292], dtype=torch.float64), tensor([-0.0440], dtype=torch.float64), tensor([0.1998], dtype=torch.float64), tensor([0.0587], dtype=torch.float64), tensor([0.0197], dtype=torch.float64), tensor([-0.0262], dtype=torch.float64), tensor([-0.1064], dtype=torch.float64)]
logvar:  [tensor([-0.1422], dtype=torch.float64), tensor([-0.1078], dtype=torch.float64), tensor

In [22]:
train(model, dataloader=dataloader, num_epochs=2, device=device, loss_func=loss_func, 
                loss_func_kl=kl_divergence_to_standard_normal, kl_weight=1., user_lr=0.3,
                optimizer=optimizer, num_classes=num_interaction_types, logger=logging_func,
                state_size=state_size,max_time=max_time, log_step_size=1)

  0%|          | 0/2 [00:00<?, ?it/s]

42
1
42




1
61
1
65
3
66
20
66
20
66
27
66
27
88
63
102
63
121
65
121
69
122
38
4
44
4
97
8
127
18
1
18
1
19
121
59
1
65
25
71
25
71
25
107
19
46
22
47
28
48
28
68
28
70
28
87
14
2
14
2
15
2
15
2
15
4
15
5
15
5
28
5
28
8
28
9
34
9
38
9
38
9
74
9
74
31
74
32
75
32
75
32
81
32
85
82
85
82
93
82
93
82
98
82
98
82
121
82
121
82
121
83
122
108
12
26
53
26
145
26
183
27
2
9
2
9
2
9
2
9
56
9
60
34
60
34
61
34
61
35
61
35
61
35
68
35
68
35
68
35
68
35
68
36
68
36
72
36
72
36
72
36
72
36
73
36
73
40
73
40
76
40
83
44
83
44
83
44
83
45
83
45
84
45
84
45
84
45
92
45
92
46
92
46
94
46
94
46
95
46
95
46
95
47
96
47
96
47
96
48
96
48
96
48
96
48
96
48
97
49
97
49
97
49
97
49
97
49
100
49
100
49
102
50
102
50
102
50
102
50
103
55
103
55
103
55
103
55
103
56
103
56
103
65
104
65
104
65
104
65
105
65
105
66
106
66
106
66
106
66
109
66
109
66
109
67
109
67
110
78
110
82
110
82
110
82
110
84
110
84
110
84
111
84
111
84
111
85
111
85
111
85
113
85
114
85
114
85
114
85
114
106
114
106
114
2
1
7
2
8
9
8
16
11
16
37
1

 50%|█████     | 1/2 [09:21<09:21, 561.01s/it]

loss_all:  1.6469632208724544e+34 	loss_base:  2.0660024836963218e+31 	loss_kl:  5.702362060546875 	loss_intensity:  1.644897218388758e+34 	log of the loss:  78.7868262817616
158
1
158
15
158
15
2
26
5
27
9
5
1
7
165
7
165
7
169
7
169
7
181
8
182
8
182
8
182
8
187
8
187
9
213
9
19
3
19
87
19
38
4
44
4
97
8
127
2
9
2
9
2
9
2
9
56
9
60
34
60
34
61
34
61
35
61
35
61
35
68
35
68
35
68
35
68
35
68
36
68
36
72
36
72
36
72
36
72
36
73
36
73
40
73
40
76
40
83
44
83
44
83
44
83
45
83
45
84
45
84
45
84
45
92
45
92
46
92
46
94
46
94
46
95
46
95
46
95
47
96
47
96
47
96
48
96
48
96
48
96
48
96
48
97
49
97
49
97
49
97
49
97
49
100
49
100
49
102
50
102
50
102
50
102
50
103
55
103
55
103
55
103
55
103
56
103
56
103
65
104
65
104
65
104
65
105
65
105
66
106
66
106
66
106
66
109
66
109
66
109
67
109
67
110
78
110
82
110
82
110
82
110
84
110
84
110
84
111
84
111
84
111
85
111
85
111
85
113
85
114
85
114
85
114
85
114
106
114
106
114
4
1
5
11
51
15
51
21
58
21
142
25
162
31
162
47
175
174
1
215
1
224
1
23

100%|██████████| 2/2 [-2:19:11<00:00, -0.00it/s]

loss_all:  7.960807488839941e+32 	loss_base:  9.693664392413413e+29 	loss_kl:  14.185860753059387 	loss_intensity:  7.951113824447528e+32 	log of the loss:  75.75725341384234





TypeError: logging_func() missing 1 required positional argument: 'loss_intensity'

In [None]:
print_user_params(dataloader, print_var = True)

means:  [tensor([0.0004], dtype=torch.float64), tensor([-0.1009], dtype=torch.float64), tensor([0.0252], dtype=torch.float64), tensor([-0.2571], dtype=torch.float64), tensor([-0.0517], dtype=torch.float64), tensor([-0.0831], dtype=torch.float64), tensor([0.0315], dtype=torch.float64), tensor([0.0501], dtype=torch.float64)]
logvar:  [tensor([-0.0007], dtype=torch.float64), tensor([0.1677], dtype=torch.float64), tensor([0.0446], dtype=torch.float64), tensor([-0.1314], dtype=torch.float64), tensor([0.0181], dtype=torch.float64), tensor([0.1845], dtype=torch.float64), tensor([-0.0450], dtype=torch.float64), tensor([0.0787], dtype=torch.float64)]
means:  [tensor([0.0751], dtype=torch.float64), tensor([-0.1105], dtype=torch.float64), tensor([-0.0324], dtype=torch.float64), tensor([0.1352], dtype=torch.float64), tensor([0.0611], dtype=torch.float64), tensor([-0.0762], dtype=torch.float64), tensor([0.0141], dtype=torch.float64), tensor([-0.2076], dtype=torch.float64)]
logvar:  [tensor([-0.1365

In [None]:
SETTINGS.filepaths['user_model']

'saved_models_polimi/accordion/user_model.h5'

In [None]:
#save model
torch.save(model.state_dict(), join(paths.dat, SETTINGS.filepaths_new['user_model']))

In [None]:
# save data(changes during training)
torch.save({
    'data': dataloader.dataset.data,
}, join(paths.dat, SETTINGS.filepaths_new['optimized_data'])# 'saved_models_polimi/data.h5'
)
