# notebook for 0905 final dataset

In [1]:
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
tqdm.pandas()
import torch
from torch import nn
from torch.nn import Transformer
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
import numpy as np

from sklearn.metrics import average_precision_score


In [2]:
import torch
print(torch.__version__)
print(torch.version.cuda)

# Check if CUDA (GPU support) is available
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("Using CPU")

# Now, you can use `device` to send your tensors to the GPU or CPU.


2.0.1
11.7
Using GPU: NVIDIA A100-SXM4-40GB


In [3]:
import torch

# Print available GPUs
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")


GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB


In [4]:
device

device(type='cuda')

In [5]:
import os
import psutil
def memory_usage():
    process = psutil.Process(os.getpid())
    total_memory = psutil.virtual_memory().total / (1024**3)  # total RAM in GB
    used_memory = total_memory - process.memory_info().rss / (1024**3)  # used RAM in GB
    available_memory = psutil.virtual_memory().available / (1024**3)  # available RAM in GB
    return f"Total Memory: {total_memory:.2f} GB | Used Memory: {used_memory:.2f} GB | Available Memory: {available_memory:.2f} GB"

print(memory_usage())

Total Memory: 1007.45 GB | Used Memory: 1007.07 GB | Available Memory: 988.72 GB


In [6]:

# Custom Dataset
class MyDataset(Dataset):
    def __init__(self, src_data, tgt_data):
        self.src_data = [torch.tensor(seq, dtype=torch.float32) for seq in src_data]
        self.tgt_data = [torch.tensor(tgt, dtype=torch.float32) for tgt in tgt_data]

    def __len__(self):
        return len(self.src_data)

    def __getitem__(self, idx):
        return self.src_data[idx], self.tgt_data[idx]

# Custom Collate function to handle variable sequence length
def collate_fn(batch):
    src, tgt = zip(*batch)
    src = pad_sequence(src, batch_first=True, padding_value=0)
    tgt = pad_sequence(tgt, batch_first=True, padding_value=-1)  # Special padding value for target data
    return src, tgt


def create_masks(src):
    src_key_padding_mask = (src == 0).all(axis=-1)
    
    # future_mask shape should be (seq_len, seq_len)
    future_mask = torch.triu(torch.ones((src.shape[1], src.shape[1])), diagonal=1).bool()

    return src_key_padding_mask, future_mask



In [7]:
# Transformer KG Model
# first KG_dims are word embedding vectors, will be put into a linear layer first to reduce the dim to KG_compress_dims
# the compressed word embedding will be concated with the other features , and go through transformers 
class TransformerKGModel(nn.Module):
    def __init__(self, KG_dims, KG_compress_dims, input_dim, embedding_size, num_heads, num_layers, dropout):
        super(TransformerKGModel, self).__init__()
        self.compress  = nn.Linear(KG_dims, KG_compress_dims)
        #self.concat_dims = KG_compress_dims
        self.embedding = nn.Linear(input_dim, embedding_size)
        self.transformer = Transformer(
            d_model=embedding_size,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(embedding_size, 1)

    def forward(self, src, src_key_padding_mask=None, future_mask=None):
        
        KG_part = src[:,:,:self.compress.in_features] # first KG_dims features are word embeddings
        
        
        other_parts = src[:,:, self.compress.in_features:] # [seq_len, batch_size, feature_size]
        # Compress KG_part  to KG_compress_dims
        compressed = self.compress(KG_part)  
        # Concatenate with second part
        src = torch.cat((compressed, other_parts), dim=-1)  
        # Reshape for transformer
   
        #src = src.unsqueeze(1)
        

        src = self.embedding(src)

        src = src.transpose(0, 1) 

        output = self.transformer.encoder(src, src_key_padding_mask=src_key_padding_mask, mask = future_mask)

        output = self.fc(output)

        output = torch.sigmoid(output)

        return output.transpose(0, 1).squeeze(-1)


In [8]:
import sklearn.metrics as skmetrics
def plot_roc(y_test, y_pred_prob, pos_label=True):
    """
    Function to plot the ROC curve and display the AUC of 
    this predictor.
    """
    fpr, tpr, _ = skmetrics.roc_curve(y_test, y_pred_prob, pos_label=pos_label)
    roc_auc = skmetrics.auc(fpr, tpr)
    plt.figure()
    lw = 2
    plt.plot(
        fpr[2],
        tpr[2],
        color="darkorange",
        lw=lw,
        label="ROC curve (area = %0.2f)" % roc_auc,
    )
    plt.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
    plt.plot(fpr, tpr)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver operating characteristic example")
    plt.legend(loc="lower right")
    plt.show()
    

## data processing

In [34]:
# read study data
dataset = pd.read_csv("./data/new_data_sample_0905.csv", sep=',', index_col=None)


In [35]:
len(set(dataset.uid))

39931

In [36]:
dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 36452215 entries, 0 to 36452214
Data columns (total 18 columns):
 #   Column                            Dtype  
---  ------                            -----  
 0   uid                               object 
 1   register_time                     object 
 2   wid                               object 
 3   spelling                          object 
 4   difficulty                        int64  
 5   review_time                       object 
 6   response                          int64  
 7   study_method                      int64  
 8   repeat_time_inday                 int64  
 9   real_interval_history_byday       object 
 10  last_real_interval_byday          int64  
 11  real_delta_interval_within_inday  int64  
 12  review_day_th                     int64  
 13  delta_t_between_inday             float64
 14  delta_t_within_inday              float64
 15  repeat_time_history               int64  
 16  delta_interval_within_inday       

In [37]:
 # del dataset['repeat_time_inday']information leakage 

In [38]:
dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 36452215 entries, 0 to 36452214
Data columns (total 18 columns):
 #   Column                            Dtype  
---  ------                            -----  
 0   uid                               object 
 1   register_time                     object 
 2   wid                               object 
 3   spelling                          object 
 4   difficulty                        int64  
 5   review_time                       object 
 6   response                          int64  
 7   study_method                      int64  
 8   repeat_time_inday                 int64  
 9   real_interval_history_byday       object 
 10  last_real_interval_byday          int64  
 11  real_delta_interval_within_inday  int64  
 12  review_day_th                     int64  
 13  delta_t_between_inday             float64
 14  delta_t_within_inday              float64
 15  repeat_time_history               int64  
 16  delta_interval_within_inday       

### only use the words in the book list, there are a lot of phrases in the study logs , but they only take a very small proprotion in the dataset

In [39]:
len(set(dataset.wid))

24640

In [40]:
# Ensure the column is of string type
dataset['spelling'] = dataset['spelling'].astype(str)

# Replace NaNs with a placeholder string
dataset['spelling'] = dataset['spelling'].fillna('')

#sub_dataset = dataset[~dataset.spelling.str.contains(' ')]

In [41]:
word_net_diffculty_dataset = pd.read_csv('./data/tmp_word_net_diffculty_dataset.tsv',sep='\t')
#word_net_diffculty_dataset = word_net_diffculty_dataset.drop_duplicates()
word_net_diffculty_dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 13428 entries, 0 to 13427
Data columns (total 6 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   wid        13428 non-null  object 
 1   spelling   13428 non-null  object 
 2   rsp1_rate  13428 non-null  float64
 3   rsp2_rate  13428 non-null  float64
 4   rsp3_rate  13428 non-null  float64
 5   rsp4_rate  13428 non-null  float64
dtypes: float64(4), object(2)
memory usage: 629.6+ KB


In [42]:
word_net_acknowledge_dataset = pd.read_csv('./data/tmp_word_net_acknowledge_dataset.tsv',sep='\t')
word_net_acknowledge_dataset = word_net_acknowledge_dataset.drop_duplicates()
word_net_acknowledge_dataset.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 6709 entries, 0 to 11647
Data columns (total 2 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   wid               6709 non-null   object 
 1   acknowledge_rate  6709 non-null   float64
dtypes: float64(1), object(1)
memory usage: 157.2+ KB


In [43]:
word_net_acknowledge_dataset.head()

Unnamed: 0,wid,acknowledge_rate
0,57067bb0a172044907c61dc0,0.318209
1,57067ba5a172044907c6195b,0.560077
2,57067ba7a172044907c61a8a,0.342722
3,57067b8da172044907c60feb,0.70794
4,57067ba7a172044907c61afa,0.285558


In [44]:
word_net_diffculty_ackknowledge_dataset = pd.merge(word_net_diffculty_dataset,word_net_acknowledge_dataset,on=['wid'],how='inner')
word_net_diffculty_ackknowledge_dataset = word_net_diffculty_ackknowledge_dataset.drop_duplicates()
word_net_diffculty_ackknowledge_dataset.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 6709 entries, 0 to 13427
Data columns (total 7 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   wid               6709 non-null   object 
 1   spelling          6709 non-null   object 
 2   rsp1_rate         6709 non-null   float64
 3   rsp2_rate         6709 non-null   float64
 4   rsp3_rate         6709 non-null   float64
 5   rsp4_rate         6709 non-null   float64
 6   acknowledge_rate  6709 non-null   float64
dtypes: float64(5), object(2)
memory usage: 419.3+ KB


In [45]:
word_net_diffculty_ackknowledge_dataset.head()

Unnamed: 0,wid,spelling,rsp1_rate,rsp2_rate,rsp3_rate,rsp4_rate,acknowledge_rate
0,57067bb0a172044907c61dc0,preach,0.3624,0.1708,0.4665,0.0003,0.318209
2,57067ba5a172044907c6195b,arise,0.5498,0.1521,0.2977,0.0004,0.560077
4,57067ba7a172044907c61a8a,ingredient,0.3437,0.173,0.4831,0.0003,0.342722
6,57067b8da172044907c60feb,application,0.684,0.1079,0.2077,0.0005,0.70794
8,57067ba7a172044907c61afa,ingenious,0.4271,0.1814,0.3912,0.0003,0.285558


In [46]:
dataset = pd.merge(dataset,word_net_diffculty_ackknowledge_dataset, on=['wid','spelling'],how='inner')
dataset.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 36301382 entries, 0 to 36301381
Data columns (total 23 columns):
 #   Column                            Dtype  
---  ------                            -----  
 0   uid                               object 
 1   register_time                     object 
 2   wid                               object 
 3   spelling                          object 
 4   difficulty                        int64  
 5   review_time                       object 
 6   response                          int64  
 7   study_method                      int64  
 8   repeat_time_inday                 int64  
 9   real_interval_history_byday       object 
 10  last_real_interval_byday          int64  
 11  real_delta_interval_within_inday  int64  
 12  review_day_th                     int64  
 13  delta_t_between_inday             float64
 14  delta_t_within_inday              float64
 15  repeat_time_history               int64  
 16  delta_interval_within_inday       

### add KG embedding and filter thw word sets again

In [47]:
# add KG embedding and filter thw word sets again
KGembedding = pd.read_csv('./data/KGembeddings.csv')
KGembedding.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6685 entries, 0 to 6684
Data columns (total 65 columns):
 #   Column   Non-Null Count  Dtype  
---  ------   --------------  -----  
 0   w        6685 non-null   object 
 1   embed1   6685 non-null   float64
 2   embed2   6685 non-null   float64
 3   embed3   6685 non-null   float64
 4   embed4   6685 non-null   float64
 5   embed5   6685 non-null   float64
 6   embed6   6685 non-null   float64
 7   embed7   6685 non-null   float64
 8   embed8   6685 non-null   float64
 9   embed9   6685 non-null   float64
 10  embed10  6685 non-null   float64
 11  embed11  6685 non-null   float64
 12  embed12  6685 non-null   float64
 13  embed13  6685 non-null   float64
 14  embed14  6685 non-null   float64
 15  embed15  6685 non-null   float64
 16  embed16  6685 non-null   float64
 17  embed17  6685 non-null   float64
 18  embed18  6685 non-null   float64
 19  embed19  6685 non-null   float64
 20  embed20  6685 non-null   float64
 21  embed21  6685 

In [48]:
KGembedding.rename(columns={'w':'spelling'},inplace=True)
KGembedding.head()

Unnamed: 0,spelling,embed1,embed2,embed3,embed4,embed5,embed6,embed7,embed8,embed9,...,embed55,embed56,embed57,embed58,embed59,embed60,embed61,embed62,embed63,embed64
0,a,-0.112005,0.175023,0.081303,0.119429,0.234358,-0.285393,0.08896,-0.084222,-0.146858,...,0.121453,0.046255,0.027163,0.253145,0.250952,0.395841,0.263606,-0.024474,-0.125878,-0.034926
1,abandon,-0.400461,0.727612,0.237638,0.407411,1.023871,-1.093308,0.371002,-0.282474,-0.60378,...,0.57236,0.213719,0.139102,0.971377,0.966684,1.702201,1.043585,-0.034144,-0.49925,-0.070479
2,abdomen,-0.485593,1.103594,0.264719,0.511833,1.521983,-1.432766,0.503473,-0.349383,-0.807522,...,0.910728,0.306264,0.270944,1.298496,1.351456,2.541386,1.477359,0.069351,-0.689769,-0.026549
3,abide,-0.479631,1.098984,0.261993,0.506807,1.512208,-1.418678,0.499232,-0.345371,-0.799597,...,0.907186,0.303954,0.270819,1.286539,1.342442,2.527811,1.467622,0.073551,-0.683922,-0.024235
4,ability,-0.257335,0.465007,0.156881,0.264888,0.649716,-0.698882,0.233898,-0.185201,-0.38173,...,0.361201,0.133533,0.09038,0.622238,0.620555,1.080868,0.667867,-0.02428,-0.319865,-0.049535


In [49]:
dataset = pd.merge(KGembedding,dataset, on=['spelling'],how='inner')
dataset.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 34690422 entries, 0 to 34690421
Data columns (total 87 columns):
 #   Column                            Dtype  
---  ------                            -----  
 0   spelling                          object 
 1   embed1                            float64
 2   embed2                            float64
 3   embed3                            float64
 4   embed4                            float64
 5   embed5                            float64
 6   embed6                            float64
 7   embed7                            float64
 8   embed8                            float64
 9   embed9                            float64
 10  embed10                           float64
 11  embed11                           float64
 12  embed12                           float64
 13  embed13                           float64
 14  embed14                           float64
 15  embed15                           float64
 16  embed16                           

### check the distribution of word review counts per user


In [50]:
data_gp = dataset.groupby(['uid']).count().sort_values(by='register_time')


percentiles_list = list(range(0, 101, 10)) + [1,5,95 ,99]  # includes 1st and 99th percentile
percentiles_list.sort()  # make sure the percentiles are in ascending order

percentiles = {i: data_gp['wid'].quantile(i/100) for i in percentiles_list}

# Now percentiles is a dictionary that contains the 0th, 1st, 10th, 20th, ..., 99th, 100th percentile of 'column_name'

for percentile in percentiles_list:
    print(f'{percentile} percentile: {percentiles[percentile]}')
    

0 percentile: 1.0
1 percentile: 3.0
5 percentile: 20.0
10 percentile: 30.0
20 percentile: 73.0
30 percentile: 148.0
40 percentile: 266.0
50 percentile: 437.0
60 percentile: 639.0
70 percentile: 921.0
80 percentile: 1364.0
90 percentile: 2265.0
95 percentile: 3266.8499999999985
99 percentile: 5761.549999999952
100 percentile: 27652.0


In [51]:
# only valid users : filtering out abnormal users
valid_uids = list(data_gp[(data_gp.register_time<=3000)&(data_gp.register_time>=30)].index)
len(valid_uids)

33700

In [52]:
#subselect the dataset which only has valid users
dataset= dataset[dataset.uid.isin(valid_uids)]
dataset["review_time"] = pd.to_datetime(dataset["review_time"] )
dataset["register_time"] = pd.to_datetime(dataset["register_time"] )
dataset = dataset.sort_values(by=['uid', 'review_time'])

In [53]:
# process Y
dataset['Y'] = dataset.response.progress_apply(lambda x : 1 if x ==1 else 0)

# lower memory
for column in dataset.select_dtypes(include=['int64']):
    dataset[column] = dataset[column].astype('int8')
for column in dataset.select_dtypes(include=['float64']):
    dataset[column] = dataset[column].astype('float32')

100%|██████████| 23656968/23656968 [00:12<00:00, 1924209.04it/s]


In [54]:
valid_word = list(set(dataset['spelling'].values))
len(valid_word)

6180

### create the last response(0/1) as an input feature for next time step. For the first-time review's last response , use the acknowledge rate for the word instead

In [55]:
def get_last_response(x):
    response_history = x[0]
    acknowledge_rate = x[1]
    try:
        last_response = int(response_history[-2])
        if last_response ==1 :
            None 
        else:
            last_response=0
    except : 
        last_response = int(acknowledge_rate>0.5)
    return last_response

In [56]:
dataset['last_response'] = dataset[['response_history','acknowledge_rate']].progress_apply(get_last_response, axis =1)

100%|██████████| 23656968/23656968 [01:28<00:00, 266848.96it/s]


In [57]:
dataset.head()

Unnamed: 0,spelling,embed1,embed2,embed3,embed4,embed5,embed6,embed7,embed8,embed9,...,repeat_time_history,delta_interval_within_inday,response_history,rsp1_rate,rsp2_rate,rsp3_rate,rsp4_rate,acknowledge_rate,Y,last_response
7021226,console,-0.474705,1.094874,0.258717,0.503144,1.507025,-1.408898,0.49611,-0.342881,-0.794755,...,1,0.0,1,0.4154,0.1731,0.4112,0.0003,0.303781,1,0
6836703,consist,-0.470568,0.962991,0.2606,0.480988,1.344495,-1.333904,0.460674,-0.332445,-0.745455,...,1,0.0,1,0.5524,0.1576,0.2895,0.0005,0.563981,1,1
6904211,consistent,-0.404658,0.898675,0.222523,0.423722,1.239092,-1.180395,0.411884,-0.290802,-0.661933,...,1,0.0,1,0.4417,0.1619,0.3961,0.0003,0.457183,1,0
7120550,constant,-0.43827,0.956392,0.241749,0.456376,1.321626,-1.270601,0.442882,-0.313167,-0.712319,...,1,0.0,1,0.447,0.1674,0.3853,0.0003,0.443431,1,0
7233272,constituent,-0.465048,1.06278,0.253318,0.490704,1.463766,-1.374023,0.482891,-0.335,-0.774168,...,1,0.0,1,0.3894,0.2085,0.4019,0.0003,0.367841,1,0


In [58]:
dataset.columns

Index(['spelling', 'embed1', 'embed2', 'embed3', 'embed4', 'embed5', 'embed6',
       'embed7', 'embed8', 'embed9', 'embed10', 'embed11', 'embed12',
       'embed13', 'embed14', 'embed15', 'embed16', 'embed17', 'embed18',
       'embed19', 'embed20', 'embed21', 'embed22', 'embed23', 'embed24',
       'embed25', 'embed26', 'embed27', 'embed28', 'embed29', 'embed30',
       'embed31', 'embed32', 'embed33', 'embed34', 'embed35', 'embed36',
       'embed37', 'embed38', 'embed39', 'embed40', 'embed41', 'embed42',
       'embed43', 'embed44', 'embed45', 'embed46', 'embed47', 'embed48',
       'embed49', 'embed50', 'embed51', 'embed52', 'embed53', 'embed54',
       'embed55', 'embed56', 'embed57', 'embed58', 'embed59', 'embed60',
       'embed61', 'embed62', 'embed63', 'embed64', 'uid', 'register_time',
       'wid', 'difficulty', 'review_time', 'response', 'study_method',
       'repeat_time_inday', 'real_interval_history_byday',
       'last_real_interval_byday', 'real_delta_interval_within

### create a new feature for estimate the user tenure

In [59]:
dataset['user_tenure'] = (dataset["review_time"] -dataset["register_time"]).dt.days

### create the dataframe , which is ready for be transformed to a sequential list

In [60]:
df_encoded = dataset[list(KGembedding.columns[1:]) +['Y','spelling','review_time','uid', 'user_tenure','rsp1_rate',
       'rsp2_rate', 'rsp3_rate', 'rsp4_rate', 'acknowledge_rate', 'difficulty',  'study_method', 'repeat_time_inday',
       'real_delta_interval_within_inday', 'review_day_th',
       'delta_t_between_inday', 'delta_t_within_inday', 'repeat_time_history',
       'delta_interval_within_inday']]
df_encoded.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 23656968 entries, 7021226 to 29463786
Data columns (total 83 columns):
 #   Column                            Dtype         
---  ------                            -----         
 0   embed1                            float32       
 1   embed2                            float32       
 2   embed3                            float32       
 3   embed4                            float32       
 4   embed5                            float32       
 5   embed6                            float32       
 6   embed7                            float32       
 7   embed8                            float32       
 8   embed9                            float32       
 9   embed10                           float32       
 10  embed11                           float32       
 11  embed12                           float32       
 12  embed13                           float32       
 13  embed14                           float32       
 14  embed15   

In [61]:
df_encoded = df_encoded.sort_values(by=['uid', 'review_time'])

In [62]:
#check if any null value in any row
rows_with_null = df_encoded[df_encoded.isnull().any(axis=1)]
rows_with_null

Unnamed: 0,embed1,embed2,embed3,embed4,embed5,embed6,embed7,embed8,embed9,embed10,...,acknowledge_rate,difficulty,study_method,repeat_time_inday,real_delta_interval_within_inday,review_day_th,delta_t_between_inday,delta_t_within_inday,repeat_time_history,delta_interval_within_inday


In [None]:
#df_encoded.to_csv("./data/df_encoded_KG_0905.csv",index=None)

In [64]:
del dataset

#### We have over 24 million study records of 3.3W unique users with over 6K words

### transform the dataframe to a sequncial list, each element is a list of a user's sequential study records, each record has X features
#### Consider the list will be very large, I have to optimize it using deque() and seperate chunks

## data read and transforming

In [9]:
df_encoded = pd.read_csv("./data/df_encoded_KG_0905.csv")

In [10]:
valid_uids= list(set(df_encoded.uid))

In [11]:

from collections import deque
chunks =200 # chunk size 
sub_uids=[valid_uids[i:i+chunks] for i in range (0, len(valid_uids),chunks)]
print(len(sub_uids))

169


In [12]:
import time
import itertools
import sys
import pickle
starttime=time.time()

src_data=[]
tgt_data=[]

for i,sub_uid in enumerate(sub_uids):
    print('loop',i)
    try:
        sub_df_encoded = df_encoded[df_encoded.uid.isin(sub_uid)]
        sub_src_data = [group.drop(columns=['uid',"Y",'spelling','review_time']).apply(list, axis=1).tolist() for _, group in tqdm(sub_df_encoded.groupby('uid'), desc='Processing data')]   
        src_data.append(sub_src_data)

        sub_tgt_data = [[elem[0] for elem in group.drop(columns=['uid','spelling','review_time']).apply(list, axis=1).tolist() ]for _, group in tqdm(sub_df_encoded[["uid",'spelling','review_time',"Y"]].groupby('uid'), desc='Processing data')]
        tgt_data.append(sub_tgt_data)


#         with open(f'./data/sub_src_data_list_{i}.pkl', 'wb') as file1:
#             pickle.dump(sub_src_data, file1)

#         with open(f'./data/sub_tgt_data_list_{i}.pkl', 'wb') as file2:
#             pickle.dump(sub_tgt_data, file2)
        #size_in_bytes = sys.getsizeof(src_data)
        #size_in_gigabytes = size_in_bytes / (1024 * 1024 * 1024)
        #print(f"The list takes up {size_in_gigabytes:.2f} GB.")
        print('total time',time.time() -starttime)
        #print(memory_usage())

    except: 
        print('error happens at loop',i )
        sub_df_encoded.to_csv(f"./data/sub_df_encoded_{i}.csv",index=None)
        
src_data =  list(itertools.chain(*src_data))
tgt_data =list(itertools.chain(*tgt_data))

print(memory_usage())

loop 0


Processing data: 100%|██████████| 200/200 [00:01<00:00, 139.02it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 406.84it/s]


total time 2.368722677230835
loop 1


Processing data: 100%|██████████| 200/200 [00:01<00:00, 162.71it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 443.16it/s]


total time 4.445590019226074
loop 2


Processing data: 100%|██████████| 200/200 [00:01<00:00, 141.13it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 402.05it/s]


total time 6.7803122997283936
loop 3


Processing data: 100%|██████████| 200/200 [00:01<00:00, 129.90it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 383.87it/s]


total time 9.250917434692383
loop 4


Processing data: 100%|██████████| 200/200 [00:01<00:00, 132.78it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 427.77it/s]


total time 11.633249759674072
loop 5


Processing data: 100%|██████████| 200/200 [00:01<00:00, 166.88it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 391.61it/s]


total time 13.752275228500366
loop 6


Processing data: 100%|██████████| 200/200 [00:01<00:00, 125.48it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 406.01it/s]


total time 16.246171236038208
loop 7


Processing data: 100%|██████████| 200/200 [00:01<00:00, 192.11it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 442.09it/s]


total time 18.161577463150024
loop 8


Processing data: 100%|██████████| 200/200 [00:01<00:00, 110.38it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 392.59it/s]


total time 20.9143385887146
loop 9


Processing data: 100%|██████████| 200/200 [00:01<00:00, 178.31it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 420.13it/s]


total time 22.922293424606323
loop 10


Processing data: 100%|██████████| 200/200 [00:01<00:00, 101.76it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 386.02it/s]


total time 25.807850122451782
loop 11


Processing data: 100%|██████████| 200/200 [00:01<00:00, 174.78it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 406.37it/s]


total time 27.86551260948181
loop 12


Processing data: 100%|██████████| 200/200 [00:01<00:00, 183.96it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 427.06it/s]


total time 29.83044695854187
loop 13


Processing data: 100%|██████████| 200/200 [00:02<00:00, 86.74it/s] 
Processing data: 100%|██████████| 200/200 [00:00<00:00, 352.08it/s]


total time 33.11113667488098
loop 14


Processing data: 100%|██████████| 200/200 [00:01<00:00, 198.48it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 451.07it/s]


total time 35.00505018234253
loop 15


Processing data: 100%|██████████| 200/200 [00:01<00:00, 193.78it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 445.87it/s]


total time 36.90705871582031
loop 16


Processing data: 100%|██████████| 200/200 [00:01<00:00, 181.56it/s]
Processing data: 100%|██████████| 200/200 [00:01<00:00, 122.76it/s]


total time 40.05214834213257
loop 17


Processing data: 100%|██████████| 200/200 [00:01<00:00, 172.97it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 400.83it/s]


total time 42.114402770996094
loop 18


Processing data: 100%|██████████| 200/200 [00:01<00:00, 179.96it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 430.33it/s]


total time 44.097957611083984
loop 19


Processing data: 100%|██████████| 200/200 [00:01<00:00, 190.19it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 434.54it/s]


total time 46.01983404159546
loop 20


Processing data: 100%|██████████| 200/200 [00:01<00:00, 184.12it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 423.49it/s]


total time 47.98504424095154
loop 21


Processing data: 100%|██████████| 200/200 [00:02<00:00, 77.57it/s] 
Processing data: 100%|██████████| 200/200 [00:00<00:00, 410.81it/s]


total time 51.45806694030762
loop 22


Processing data: 100%|██████████| 200/200 [00:01<00:00, 177.83it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 416.67it/s]


total time 53.4774169921875
loop 23


Processing data: 100%|██████████| 200/200 [00:01<00:00, 192.04it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 445.04it/s]


total time 55.38919973373413
loop 24


Processing data: 100%|██████████| 200/200 [00:01<00:00, 176.85it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 418.70it/s]


total time 57.40373158454895
loop 25


Processing data: 100%|██████████| 200/200 [00:01<00:00, 178.14it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 408.85it/s]


total time 59.41468954086304
loop 26


Processing data: 100%|██████████| 200/200 [00:01<00:00, 185.69it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 425.02it/s]


total time 61.36269927024841
loop 27


Processing data: 100%|██████████| 200/200 [00:02<00:00, 67.72it/s] 
Processing data: 100%|██████████| 200/200 [00:00<00:00, 397.68it/s]


total time 65.21607565879822
loop 28


Processing data: 100%|██████████| 200/200 [00:01<00:00, 191.01it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 440.86it/s]


total time 67.11084842681885
loop 29


Processing data: 100%|██████████| 200/200 [00:01<00:00, 191.40it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 437.58it/s]


total time 69.00145125389099
loop 30


Processing data: 100%|██████████| 200/200 [00:00<00:00, 205.13it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 462.84it/s]


total time 70.8049259185791
loop 31


Processing data: 100%|██████████| 200/200 [00:01<00:00, 188.52it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 431.99it/s]


total time 72.73319220542908
loop 32


Processing data: 100%|██████████| 200/200 [00:01<00:00, 184.43it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 422.66it/s]


total time 74.68562388420105
loop 33


Processing data: 100%|██████████| 200/200 [00:01<00:00, 191.30it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 444.33it/s]


total time 76.5759539604187
loop 34


Processing data: 100%|██████████| 200/200 [00:03<00:00, 60.75it/s] 
Processing data: 100%|██████████| 200/200 [00:00<00:00, 433.22it/s]


total time 80.75615906715393
loop 35


Processing data: 100%|██████████| 200/200 [00:01<00:00, 182.09it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 417.72it/s]


total time 82.74276900291443
loop 36


Processing data: 100%|██████████| 200/200 [00:01<00:00, 170.43it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 396.58it/s]


total time 84.81818008422852
loop 37


Processing data: 100%|██████████| 200/200 [00:01<00:00, 187.87it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 435.64it/s]


total time 86.75633263587952
loop 38


Processing data: 100%|██████████| 200/200 [00:01<00:00, 188.78it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 441.50it/s]


total time 88.69233870506287
loop 39


Processing data: 100%|██████████| 200/200 [00:01<00:00, 193.95it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 445.62it/s]


total time 90.58519268035889
loop 40


Processing data: 100%|██████████| 200/200 [00:01<00:00, 176.19it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 408.39it/s]


total time 92.61860013008118
loop 41


Processing data: 100%|██████████| 200/200 [00:01<00:00, 170.80it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 398.14it/s]


total time 94.69883751869202
loop 42


Processing data: 100%|██████████| 200/200 [00:03<00:00, 51.94it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 432.47it/s]


total time 99.41123414039612
loop 43


Processing data: 100%|██████████| 200/200 [00:01<00:00, 165.23it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 391.98it/s]


total time 101.56163001060486
loop 44


Processing data: 100%|██████████| 200/200 [00:01<00:00, 177.77it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 418.61it/s]


total time 103.58449125289917
loop 45


Processing data: 100%|██████████| 200/200 [00:01<00:00, 164.94it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 394.85it/s]


total time 105.72409844398499
loop 46


Processing data: 100%|██████████| 200/200 [00:01<00:00, 181.19it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 427.98it/s]


total time 107.70204401016235
loop 47


Processing data: 100%|██████████| 200/200 [00:01<00:00, 186.48it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 438.44it/s]


total time 109.6333544254303
loop 48


Processing data: 100%|██████████| 200/200 [00:01<00:00, 162.38it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 382.61it/s]


total time 111.80030751228333
loop 49


Processing data: 100%|██████████| 200/200 [00:01<00:00, 172.75it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 408.09it/s]


total time 113.85761904716492
loop 50


Processing data: 100%|██████████| 200/200 [00:01<00:00, 168.69it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 402.41it/s]


total time 115.95631003379822
loop 51


Processing data: 100%|██████████| 200/200 [00:01<00:00, 171.52it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 410.96it/s]


total time 118.03614330291748
loop 52


Processing data: 100%|██████████| 200/200 [00:04<00:00, 43.93it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 444.16it/s]


total time 123.45870637893677
loop 53


Processing data: 100%|██████████| 200/200 [00:01<00:00, 176.16it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 418.36it/s]


total time 125.48518705368042
loop 54


Processing data: 100%|██████████| 200/200 [00:01<00:00, 185.23it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 431.13it/s]


total time 127.44250512123108
loop 55


Processing data: 100%|██████████| 200/200 [00:01<00:00, 183.96it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 424.47it/s]


total time 129.400297164917
loop 56


Processing data: 100%|██████████| 200/200 [00:00<00:00, 200.04it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 459.82it/s]


total time 131.24204874038696
loop 57


Processing data: 100%|██████████| 200/200 [00:01<00:00, 173.17it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 400.85it/s]


total time 133.29587769508362
loop 58


Processing data: 100%|██████████| 200/200 [00:01<00:00, 192.04it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 439.77it/s]


total time 135.19502782821655
loop 59


Processing data: 100%|██████████| 200/200 [00:01<00:00, 176.83it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 414.08it/s]


total time 137.22990584373474
loop 60


Processing data: 100%|██████████| 200/200 [00:01<00:00, 183.23it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 420.17it/s]


total time 139.21366214752197
loop 61


Processing data: 100%|██████████| 200/200 [00:01<00:00, 196.45it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 454.38it/s]


total time 141.08641386032104
loop 62


Processing data: 100%|██████████| 200/200 [00:01<00:00, 182.58it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 424.72it/s]


total time 143.05185055732727
loop 63


Processing data: 100%|██████████| 200/200 [00:01<00:00, 172.76it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 400.36it/s]


total time 145.10553002357483
loop 64


Processing data: 100%|██████████| 200/200 [00:00<00:00, 213.31it/s]
Processing data: 100%|██████████| 200/200 [00:04<00:00, 43.19it/s] 


total time 151.0654158592224
loop 65


Processing data: 100%|██████████| 200/200 [00:01<00:00, 185.70it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 426.50it/s]


total time 153.01034879684448
loop 66


Processing data: 100%|██████████| 200/200 [00:01<00:00, 183.38it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 422.61it/s]


total time 154.97523999214172
loop 67


Processing data: 100%|██████████| 200/200 [00:01<00:00, 173.27it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 404.57it/s]


total time 157.02660989761353
loop 68


Processing data: 100%|██████████| 200/200 [00:01<00:00, 195.19it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 451.09it/s]


total time 158.8956172466278
loop 69


Processing data: 100%|██████████| 200/200 [00:01<00:00, 192.24it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 438.58it/s]


total time 160.7951214313507
loop 70


Processing data: 100%|██████████| 200/200 [00:00<00:00, 200.45it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 458.83it/s]


total time 162.6497642993927
loop 71


Processing data: 100%|██████████| 200/200 [00:01<00:00, 173.56it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 402.08it/s]


total time 164.70322847366333
loop 72


Processing data: 100%|██████████| 200/200 [00:01<00:00, 191.61it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 440.80it/s]


total time 166.60644602775574
loop 73


Processing data: 100%|██████████| 200/200 [00:01<00:00, 198.65it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 457.12it/s]


total time 168.46116256713867
loop 74


Processing data: 100%|██████████| 200/200 [00:01<00:00, 185.06it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 432.22it/s]


total time 170.41905522346497
loop 75


Processing data: 100%|██████████| 200/200 [00:01<00:00, 173.31it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 412.16it/s]


total time 172.47154569625854
loop 76


Processing data: 100%|██████████| 200/200 [00:01<00:00, 192.56it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 438.70it/s]


total time 174.37301683425903
loop 77


Processing data: 100%|██████████| 200/200 [00:01<00:00, 178.81it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 414.98it/s]


total time 176.4172706604004
loop 78


Processing data: 100%|██████████| 200/200 [00:00<00:00, 204.10it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 464.54it/s]


total time 178.22554302215576
loop 79


Processing data: 100%|██████████| 200/200 [00:01<00:00, 186.18it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 429.68it/s]


total time 180.16635298728943
loop 80


Processing data: 100%|██████████| 200/200 [00:06<00:00, 31.72it/s] 
Processing data: 100%|██████████| 200/200 [00:00<00:00, 456.29it/s]


total time 187.30933833122253
loop 81


Processing data: 100%|██████████| 200/200 [00:01<00:00, 164.24it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 383.88it/s]


total time 189.45545744895935
loop 82


Processing data: 100%|██████████| 200/200 [00:01<00:00, 177.52it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 413.23it/s]


total time 191.47838854789734
loop 83


Processing data: 100%|██████████| 200/200 [00:01<00:00, 192.86it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 446.77it/s]


total time 193.37535953521729
loop 84


Processing data: 100%|██████████| 200/200 [00:01<00:00, 181.22it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 411.45it/s]


total time 195.38119649887085
loop 85


Processing data: 100%|██████████| 200/200 [00:01<00:00, 190.52it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 440.42it/s]


total time 197.285569190979
loop 86


Processing data: 100%|██████████| 200/200 [00:01<00:00, 192.69it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 439.44it/s]


total time 199.1763870716095
loop 87


Processing data: 100%|██████████| 200/200 [00:01<00:00, 176.54it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 411.41it/s]


total time 201.18614768981934
loop 88


Processing data: 100%|██████████| 200/200 [00:01<00:00, 178.88it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 408.48it/s]


total time 203.19553542137146
loop 89


Processing data: 100%|██████████| 200/200 [00:01<00:00, 177.78it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 407.52it/s]


total time 205.20798516273499
loop 90


Processing data: 100%|██████████| 200/200 [00:00<00:00, 201.61it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 461.36it/s]


total time 207.0219123363495
loop 91


Processing data: 100%|██████████| 200/200 [00:01<00:00, 171.78it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 395.59it/s]


total time 209.09276461601257
loop 92


Processing data: 100%|██████████| 200/200 [00:01<00:00, 175.82it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 403.20it/s]


total time 211.11822772026062
loop 93


Processing data: 100%|██████████| 200/200 [00:01<00:00, 175.77it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 405.11it/s]


total time 213.14703965187073
loop 94


Processing data: 100%|██████████| 200/200 [00:00<00:00, 206.17it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 465.02it/s]


total time 214.93692255020142
loop 95


Processing data: 100%|██████████| 200/200 [00:01<00:00, 171.07it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 392.46it/s]


total time 217.0097062587738
loop 96


Processing data: 100%|██████████| 200/200 [00:01<00:00, 176.16it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 411.88it/s]


total time 219.03227758407593
loop 97


Processing data: 100%|██████████| 200/200 [00:01<00:00, 183.09it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 425.18it/s]


total time 220.9927270412445
loop 98


Processing data: 100%|██████████| 200/200 [00:01<00:00, 186.70it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 435.25it/s]


total time 222.9320616722107
loop 99


Processing data: 100%|██████████| 200/200 [00:08<00:00, 24.77it/s] 
Processing data: 100%|██████████| 200/200 [00:00<00:00, 417.34it/s]


total time 231.90669631958008
loop 100


Processing data: 100%|██████████| 200/200 [00:01<00:00, 171.09it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 407.94it/s]


total time 233.9818253517151
loop 101


Processing data: 100%|██████████| 200/200 [00:01<00:00, 193.38it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 457.98it/s]


total time 235.8613476753235
loop 102


Processing data: 100%|██████████| 200/200 [00:01<00:00, 190.61it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 441.79it/s]


total time 237.75726532936096
loop 103


Processing data: 100%|██████████| 200/200 [00:01<00:00, 164.79it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 402.56it/s]


total time 239.8999183177948
loop 104


Processing data: 100%|██████████| 200/200 [00:01<00:00, 180.74it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 424.62it/s]


total time 241.90177822113037
loop 105


Processing data: 100%|██████████| 200/200 [00:01<00:00, 191.27it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 451.49it/s]


total time 243.79837584495544
loop 106


Processing data: 100%|██████████| 200/200 [00:01<00:00, 176.07it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 423.55it/s]


total time 245.8205235004425
loop 107


Processing data: 100%|██████████| 200/200 [00:01<00:00, 159.80it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 384.23it/s]


total time 248.01560187339783
loop 108


Processing data: 100%|██████████| 200/200 [00:01<00:00, 165.71it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 395.53it/s]


total time 250.14988827705383
loop 109


Processing data: 100%|██████████| 200/200 [00:01<00:00, 186.44it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 439.56it/s]


total time 252.09237599372864
loop 110


Processing data: 100%|██████████| 200/200 [00:01<00:00, 194.05it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 451.52it/s]


total time 253.9798982143402
loop 111


Processing data: 100%|██████████| 200/200 [00:01<00:00, 187.21it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 442.97it/s]


total time 255.91320300102234
loop 112


Processing data: 100%|██████████| 200/200 [00:01<00:00, 188.34it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 439.20it/s]


total time 257.84719157218933
loop 113


Processing data: 100%|██████████| 200/200 [00:01<00:00, 186.79it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 437.59it/s]


total time 259.79329323768616
loop 114


Processing data: 100%|██████████| 200/200 [00:01<00:00, 181.45it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 426.75it/s]


total time 261.7778477668762
loop 115


Processing data: 100%|██████████| 200/200 [00:01<00:00, 187.51it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 439.34it/s]


total time 263.722158908844
loop 116


Processing data: 100%|██████████| 200/200 [00:00<00:00, 205.24it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 474.93it/s]


total time 265.53350019454956
loop 117


Processing data: 100%|██████████| 200/200 [00:01<00:00, 181.95it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 438.75it/s]


total time 267.50134110450745
loop 118


Processing data: 100%|██████████| 200/200 [00:01<00:00, 161.18it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 386.35it/s]


total time 269.68507623672485
loop 119


Processing data: 100%|██████████| 200/200 [00:01<00:00, 197.01it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 463.85it/s]


total time 271.56451416015625
loop 120


Processing data: 100%|██████████| 200/200 [00:01<00:00, 182.69it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 434.62it/s]


total time 273.53939986228943
loop 121


Processing data: 100%|██████████| 200/200 [00:01<00:00, 174.03it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 413.22it/s]


total time 275.59279799461365
loop 122


Processing data: 100%|██████████| 200/200 [00:01<00:00, 163.31it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 398.04it/s]


total time 277.75541377067566
loop 123


Processing data: 100%|██████████| 200/200 [00:11<00:00, 16.99it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 479.02it/s]


total time 290.3748927116394
loop 124


Processing data: 100%|██████████| 200/200 [00:01<00:00, 169.82it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 403.90it/s]


total time 292.4936294555664
loop 125


Processing data: 100%|██████████| 200/200 [00:01<00:00, 184.43it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 439.78it/s]


total time 294.4498414993286
loop 126


Processing data: 100%|██████████| 200/200 [00:01<00:00, 155.74it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 374.96it/s]


total time 296.7239260673523
loop 127


Processing data: 100%|██████████| 200/200 [00:01<00:00, 193.57it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 451.73it/s]


total time 298.60267329216003
loop 128


Processing data: 100%|██████████| 200/200 [00:01<00:00, 180.64it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 425.40it/s]


total time 300.61306285858154
loop 129


Processing data: 100%|██████████| 200/200 [00:01<00:00, 169.39it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 406.22it/s]


total time 302.69371485710144
loop 130


Processing data: 100%|██████████| 200/200 [00:01<00:00, 182.89it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 430.69it/s]


total time 304.66243743896484
loop 131


Processing data: 100%|██████████| 200/200 [00:01<00:00, 196.24it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 458.32it/s]


total time 306.5491065979004
loop 132


Processing data: 100%|██████████| 200/200 [00:01<00:00, 195.57it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 456.31it/s]


total time 308.4379873275757
loop 133


Processing data: 100%|██████████| 200/200 [00:01<00:00, 170.88it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 407.48it/s]


total time 310.52758717536926
loop 134


Processing data: 100%|██████████| 200/200 [00:01<00:00, 178.05it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 416.74it/s]


total time 312.5331025123596
loop 135


Processing data: 100%|██████████| 200/200 [00:01<00:00, 170.48it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 410.03it/s]


total time 314.64603567123413
loop 136


Processing data: 100%|██████████| 200/200 [00:01<00:00, 165.95it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 397.42it/s]


total time 316.7684030532837
loop 137


Processing data: 100%|██████████| 200/200 [00:01<00:00, 174.41it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 411.48it/s]


total time 318.82059717178345
loop 138


Processing data: 100%|██████████| 200/200 [00:01<00:00, 180.54it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 428.55it/s]


total time 320.82845759391785
loop 139


Processing data: 100%|██████████| 200/200 [00:01<00:00, 196.84it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 458.05it/s]


total time 322.7003664970398
loop 140


Processing data: 100%|██████████| 200/200 [00:01<00:00, 192.31it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 456.16it/s]


total time 324.6154406070709
loop 141


Processing data: 100%|██████████| 200/200 [00:01<00:00, 174.27it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 406.67it/s]


total time 326.6791024208069
loop 142


Processing data: 100%|██████████| 200/200 [00:01<00:00, 169.52it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 410.46it/s]


total time 328.79107093811035
loop 143


Processing data: 100%|██████████| 200/200 [00:01<00:00, 183.66it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 432.55it/s]


total time 330.7650496959686
loop 144


Processing data: 100%|██████████| 200/200 [00:00<00:00, 203.47it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 469.78it/s]


total time 332.59498858451843
loop 145


Processing data: 100%|██████████| 200/200 [00:01<00:00, 190.39it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 446.00it/s]


total time 334.5112552642822
loop 146


Processing data: 100%|██████████| 200/200 [00:01<00:00, 189.51it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 444.13it/s]


total time 336.4276297092438
loop 147


Processing data: 100%|██████████| 200/200 [00:01<00:00, 181.39it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 429.19it/s]


total time 338.4313585758209
loop 148


Processing data: 100%|██████████| 200/200 [00:01<00:00, 169.10it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 402.57it/s]


total time 340.53135895729065
loop 149


Processing data: 100%|██████████| 200/200 [00:01<00:00, 179.28it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 422.87it/s]


total time 342.5356502532959
loop 150


Processing data: 100%|██████████| 200/200 [00:01<00:00, 175.58it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 418.95it/s]


total time 344.57234168052673
loop 151


Processing data: 100%|██████████| 200/200 [00:01<00:00, 167.87it/s]
Processing data: 100%|██████████| 200/200 [00:12<00:00, 15.81it/s]


total time 358.83077478408813
loop 152


Processing data: 100%|██████████| 200/200 [00:01<00:00, 187.97it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 433.89it/s]


total time 360.7759885787964
loop 153


Processing data: 100%|██████████| 200/200 [00:01<00:00, 183.29it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 432.02it/s]


total time 362.7446713447571
loop 154


Processing data: 100%|██████████| 200/200 [00:01<00:00, 188.17it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 436.89it/s]


total time 364.68108773231506
loop 155


Processing data: 100%|██████████| 200/200 [00:01<00:00, 184.94it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 428.67it/s]


total time 366.6377160549164
loop 156


Processing data: 100%|██████████| 200/200 [00:01<00:00, 198.85it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 459.69it/s]


total time 368.4794600009918
loop 157


Processing data: 100%|██████████| 200/200 [00:01<00:00, 177.87it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 415.98it/s]


total time 370.49483013153076
loop 158


Processing data: 100%|██████████| 200/200 [00:01<00:00, 185.93it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 433.07it/s]


total time 372.44580698013306
loop 159


Processing data: 100%|██████████| 200/200 [00:01<00:00, 164.87it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 388.41it/s]


total time 374.60196232795715
loop 160


Processing data: 100%|██████████| 200/200 [00:01<00:00, 165.67it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 394.46it/s]


total time 376.73806858062744
loop 161


Processing data: 100%|██████████| 200/200 [00:01<00:00, 189.84it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 443.72it/s]


total time 378.6634941101074
loop 162


Processing data: 100%|██████████| 200/200 [00:00<00:00, 201.43it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 470.99it/s]


total time 380.4976050853729
loop 163


Processing data: 100%|██████████| 200/200 [00:01<00:00, 181.03it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 426.60it/s]


total time 382.50461411476135
loop 164


Processing data: 100%|██████████| 200/200 [00:01<00:00, 173.82it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 409.98it/s]


total time 384.56031250953674
loop 165


Processing data: 100%|██████████| 200/200 [00:01<00:00, 171.59it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 402.13it/s]


total time 386.63916301727295
loop 166


Processing data: 100%|██████████| 200/200 [00:01<00:00, 176.35it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 414.24it/s]


total time 388.6738955974579
loop 167


Processing data: 100%|██████████| 200/200 [00:01<00:00, 197.03it/s]
Processing data: 100%|██████████| 200/200 [00:00<00:00, 456.09it/s]


total time 390.55384278297424
loop 168


Processing data: 100%|██████████| 100/100 [00:00<00:00, 185.95it/s]
Processing data: 100%|██████████| 100/100 [00:00<00:00, 435.03it/s]

total time 391.7118365764618
Total Memory: 1007.45 GB | Used Memory: 917.95 GB | Available Memory: 899.42 GB





In [13]:
print(len(src_data), len(tgt_data))

33700 33700


In [14]:
# Split data into training and testing sets
src_train, src_test, tgt_train, tgt_test = train_test_split(src_data, tgt_data, test_size=0.2, random_state=0)

# Create datasets
train_dataset = MyDataset(src_train, tgt_train)
test_dataset = MyDataset(src_test, tgt_test)



In [15]:
# Create data loaders
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)


### model with KG training 

In [16]:
# hyperparamter tuning for our model
# param_grid = {
#     'embedding_size': [32, 64, 128],
#     'num_heads': [2, 4, 8],
#     'num_layers': [ 2, 3, 4],
#     'dropout': [0.1, 0.2, 0.3],
#     'lr': [1e-3, 1e-4, 1e-5]
# }

# from sklearn.metrics import f1_score
# import itertools

# # Helper function to get all parameter combinations
# def get_param_combinations(param_grid):
#     keys = param_grid.keys()
#     values = (param_grid[key] for key in keys)
#     for instance in itertools.product(*values):
#         yield dict(zip(keys, instance))


In [17]:
import torch.nn as nn
# Transformer model
KG_dims =64
KG_compress_dims = 16
input_dim = len(src_train[0][0]) -KG_dims + KG_compress_dims
embedding_size = 16*2
num_heads =4
num_layers = 2
dropout = 0.1

model= TransformerKGModel(KG_dims, KG_compress_dims, input_dim, embedding_size, num_heads, num_layers, dropout)

# if torch.cuda.device_count() > 1:
#     print(f"Using {torch.cuda.device_count()} GPUs!")
#     model = nn.DataParallel(model)

model.to(device)


criterion = nn.BCELoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
# model = model.to(device)
#next(model.parameters()).device


# Training Loop
num_epochs = 1000
patience = 30  # number of epochs to wait for improvement before stopping
min_delta = 0.001  # minimum improvement required to reset patience counter

#best_loss = None
#patience_counter = 0


best_f1 = None
patience_counter = 0
best_AUC=None

best_model_AUC = 0
best_model = None



In [18]:
# def check_memory(src):
#     bytes_per_element = src.element_size() # Returns the size of a single element in bytes.
#     num_elements = src.numel() # Returns the total number of elements in the tensor.
#     total_memory_bytes = bytes_per_element * num_elements
        
#     total_memory_kB = total_memory_bytes / 1024
#     total_memory_MB = total_memory_kB / 1024
#     total_memory_GB = total_memory_MB / 1024
#     print(total_memory_GB) 


In [19]:

start_time = time.time()
for epoch in range(num_epochs):
    #print("epoch:",epoch)
    epoch_loss = 0
    model.train()

    
    # # # # # # ## # #
    # traning 
    # # # # # # ## # #
    for i,(src, tgt) in enumerate(train_dataloader):
        #print("dataloader:" ,i )
        src = src.to(device)
        tgt = tgt.to(device)
    
        #print("src")
        #check_memory(src)
        # Create masks
        #print("Create masks:" ,i )
        src_key_padding_mask, future_mask = create_masks(src)
        
        #print("src_key_padding_mask")
        #check_memory(src_key_padding_mask)
        #print("future_mask")
       # check_memory(future_mask) 
        
        src_key_padding_mask = src_key_padding_mask.to(device)
        
        future_mask = future_mask.to(device)

        #print("Create masks done:" ,i )
        # Pass your sequences and mask through the model
        output = model(src, src_key_padding_mask, future_mask)
        # Calculate loss, perform backpropagation and update weights
        tgt_mask = (tgt != -1).to(device)
        loss = criterion(output[tgt_mask], tgt[tgt_mask])
        loss = loss.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    
    epoch_loss /= len(train_dataloader)  # calculate mean epoch loss
    
    # Early stopping check
    #if best_loss is None or epoch_loss < best_loss - min_delta:
    #    best_loss = epoch_loss
    #    patience_counter = 0
    #else:
    #    patience_counter += 1
    #    if patience_counter >= patience:
    #        print(f'Early stopping on epoch {epoch}')
    #        break
    
    # # # # # # ## # #
    # Testing Loop
    # # # # # # ## # #
    model.eval()
    with torch.no_grad():
        all_probs, all_preds, all_targets = [], [],[]
        for src, tgt in test_dataloader:
            src = src.to(device)
            
            tgt = tgt.to(device)
            # Create masks
            src_key_padding_mask = (src == 0).all(axis=-1).to(device)

            # Pass your sequences and mask through the model
            output = model(src, src_key_padding_mask)

            # Save predictions and targets
            tgt_mask = (tgt != -1).to(device)
            probs = output[tgt_mask].cpu().numpy()
            preds = (output > 0.5).float()
            all_probs.extend(probs)
            all_preds.extend(preds[tgt_mask].cpu().numpy())
            all_targets.extend(tgt[tgt_mask].cpu().numpy())

        # Compute metrics
        accuracy = accuracy_score(all_targets, all_preds)
        f1 = f1_score(all_targets, all_preds)
        roc_auc = roc_auc_score(all_targets, all_probs) if len(np.unique(all_targets)) > 1 else np.nan
        avg_precision = average_precision_score(all_targets, all_probs)

    epoch_time= time.time() - start_time
    print(f"Epoch: {epoch}, timecost:{epoch_time:.2f} seconds, Loss: {epoch_loss}, Accuracy: {accuracy}, F1 Score: {f1}, ROC AUC: {roc_auc}, Average Precision Score: {avg_precision}")

    # # # # # # # ## # #
        # Early stopping check , find the best AUC 
     # # # # # # # # # # # # #
    
            
    if best_AUC is None or roc_auc > best_AUC + min_delta:
        best_AUC = roc_auc
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f'Early stopping on epoch {epoch}')
            break



    if best_AUC > best_model_AUC:
        best_model_AUC = best_AUC
        best_model = model
        torch.save(best_model.state_dict(), "./data/memory_transformer_KG_v2_0905.pth")
    print("best AUC:",best_model_AUC)

    
    
#     if best_f1 is None or f1 > best_f1 + min_delta:
#         best_f1 = f1
#         patience_counter = 0
#     else:
#         patience_counter += 1
#         if patience_counter >= patience:
#             print(f'Early stopping on epoch {epoch}')
#             break

            



Epoch: 0, timecost:201.07 seconds, Loss: 0.5571009640347002, Accuracy: 0.7339584279275215, F1 Score: 0.8293229022363893, ROC AUC: 0.7206060565306357, Average Precision Score: 0.8560682660220055
best AUC: 0.7206060565306357
Epoch: 1, timecost:396.31 seconds, Loss: 0.5374710150040929, Accuracy: 0.7329229186866422, F1 Score: 0.8203643335820736, ROC AUC: 0.7341519258328222, Average Precision Score: 0.8628554223176953
best AUC: 0.7341519258328222
Epoch: 2, timecost:588.64 seconds, Loss: 0.5317547624295119, Accuracy: 0.7390309653392809, F1 Score: 0.8232074265047301, ROC AUC: 0.7480338738271014, Average Precision Score: 0.8715178712349129
best AUC: 0.7480338738271014
Epoch: 3, timecost:780.47 seconds, Loss: 0.5252390283682226, Accuracy: 0.7393770506558612, F1 Score: 0.8247357848393027, ROC AUC: 0.743585941421181, Average Precision Score: 0.8680181645391933
best AUC: 0.7480338738271014
Epoch: 4, timecost:971.82 seconds, Loss: 0.5236478974981902, Accuracy: 0.7362893273246612, F1 Score: 0.819375

### Save model and data

In [20]:
import torch

# Assume you have a model instance called 'model'
torch.save(best_model.state_dict(), "./data/memory_transformer_KG_v2_0905.pth")


In [47]:
# with open(f'./data/src_data_KG.pkl', 'wb') as file1:
#     pickle.dump(src_data, file1)
    
# with open(f'./data/tgt_data_KG.pkl', 'wb') as file1:
#     pickle.dump(tgt_data, file1)

## release memory in GPU

In [20]:
import torch
import gc

# Function to check if an object is a tensor on the GPU
def is_tensor_on_gpu(obj):
    return torch.is_tensor(obj) and obj.is_cuda

# Inspect all objects and delete tensors on GPU
for obj in gc.get_objects():
    try:
        if is_tensor_on_gpu(obj):
            del obj
    except Exception as e:
        # Handle any exception that occurs during deletion
        print(f"Error while deleting object: {e}")

# Clear up all unused memory
torch.cuda.empty_cache()

# Optional: run garbage collector
gc.collect()




1401

In [48]:
# del src_data
# del tgt_data