In [14]:
!pip3 install -r requirements.txt

Collecting scikit-learn (from -r requirements.txt (line 3))
  Downloading scikit_learn-1.4.1.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting transformers (from -r requirements.txt (line 4))
  Downloading transformers-4.37.2-py3-none-any.whl.metadata (129 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m129.4/129.4 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
Collecting scipy>=1.6.0 (from scikit-learn->-r requirements.txt (line 3))
  Downloading scipy-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.4/60.4 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting joblib>=1.2.0 (from scikit-learn->-r requirements.txt (line 3))
  Using cached joblib-1.3.2-py3-none-any.whl.metadata (5.4 kB)
Collecting threadpoolctl>=2.0.0 (from scikit-learn->-r requirements.txt (line 3))
  Downloading threadpoolct

In [17]:
import json
import datetime
import os
import pickle
import h3
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import BertTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, matthews_corrcoef
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.optim import AdamW
from torch.nn.parallel import DistributedDataParallel


In [22]:

def truncation_rows(df, nb_rows):
    return df[:nb_rows]

def add_tokenization_column(df, config):
    """Add a column with the tokenization of the POLYLINE column
    /!\ in that case, for the json file, the trajectories are in the form given by the kaggle dataset /!\ 
    /!\ ie longitude, latitude instead of latitude, longitude.                                        /!\ 
    /!\ This is why we have to reverse the order of the coordinates for the tokenization              /!\ """
    
    df['Tokenization_2'] = df['POLYLINE'].apply(lambda x: [h3.geo_to_h3(x[i][1], x[i][0], config) for i in range(len(x))])
    return df

def extract_time_info(df):
    """Add columns with the day, hour and week of the year knowing the timestamp"""
    df['DATE'] = df['TIMESTAMP'].apply(lambda x: datetime.datetime.fromtimestamp(x).strftime('%Y-%m-%d %H:%M:%S'))
    df['DAY'] = df['DATE'].apply(lambda x: str(datetime.datetime.strptime(x.split(' ')[0],'%Y-%m-%d').isocalendar()[2]))
    df['HOUR'] = df['DATE'].apply(lambda x: x.split(' ')[1].split(':')[0])
    df['WEEK'] = df['DATE'].apply(lambda x: str(datetime.datetime.strptime(x.split(' ')[0],'%Y-%m-%d').isocalendar()[1]))
    return df

def formatting_to_str(df, column):
    """Transform the column to string type"""
    if isinstance(df[column][0], str):
        return df
    df[column] = df[column].astype(str)
    return df

def call_type_to_nb(df):
    """Transform the column CALL_TYPE to a number"""
    df['CALL_TYPE'] = df['CALL_TYPE'].apply(lambda x: 0 if x == 'A' else (1 if x == 'B' else 2))
    return df


def add_geo_and_context_tokens_tokenizer(tokenizer, data_format):
    # Add geo tokens to the tokenizer
    liste_token_geo = {token for sublist in data_format['Tokenization_2'] for token in sublist}
    nb_token_geo = len(liste_token_geo)
    tokenizer.add_tokens(list(liste_token_geo))  # Convert set to list and add to tokenizer

    # Add contextual info tokens to the tokenizer
    contextual_info_token = {str(data_format['CALL_TYPE'][i])
                             for i in range(len(data_format))}
    contextual_info_token.update(str(data_format['TAXI_ID'][i])
                                 for i in range(len(data_format)))
    contextual_info_token.update(data_format['DAY'][i]
                                 for i in range(len(data_format)))
    contextual_info_token.update(data_format['HOUR'][i]
                                 for i in range(len(data_format)))
    contextual_info_token.update(data_format['WEEK'][i]
                                 for i in range(len(data_format)))

    tokenizer.add_tokens(list(contextual_info_token))  # Convert set to list and add to tokenizer

    return tokenizer, nb_token_geo

def add_spaces_for_concat(data_format, column):
    """Add spaces before and after the values of the column""" 

    #We add space before and after the values of the column because we want to separate the tokens (words) with spaces like that : [CLS] 0 1 2 3 4 5 6 7 8 9 10 [SEP]
    data_format[column]=data_format[column].apply(lambda x: ' '+x)
    return data_format

def get_deb_traj(data_format, len_context_info):
    """the DEB_TRAJ column will be the tokenization column without the last token and the target token"""
    data_format['DEB_TRAJ']=data_format['Tokenization_2'].apply(lambda x: x[:-2])

    # we manage the length of the CONTEXT_INPUT column so that after the concatenation, it does not exceed 512 tokens
    # the -2 corresponds to the two special tokens [CLS] and [SEP]
    # for exemple here, if the trajectory input is too long, we keep the 512-6-2=504 last tokens
    data_format['DEB_TRAJ']=data_format['DEB_TRAJ'].apply(lambda x: x[-(512-len_context_info-2):] if len(x)>512-len_context_info-2 else x)    

    #then we keep the column in form of a string with spaces between the tokens (the space replaces the comma)
    data_format['DEB_TRAJ']=data_format['DEB_TRAJ'].apply(lambda x: ' '.join(x))

    return data_format



def get_deb_traj_and_target(data_format):
    """Get the DEB_TRAJ and TARGET columns well formatted but without the special tokens [CLS] and [SEP]"""

    #adding spaces for the concatenation after : we want to sperarate the tokens (words) with spaces
    data_format = add_spaces_for_concat(data_format, 'HOUR')
    data_format = add_spaces_for_concat(data_format, 'WEEK')
    data_format = add_spaces_for_concat(data_format, 'CALL_TYPE')
    data_format = add_spaces_for_concat(data_format, 'TAXI_ID')
    data_format = add_spaces_for_concat(data_format, 'DAY')

    #the column CONTEXT_INPUT will be the concatenation of the last token of the tokenization column + the day + the hour + the week + the call type + the taxi id
    data_format['CONTEXT_INPUT'] =data_format['Tokenization_2'].apply(lambda x: x[-1]) + data_format['DAY'] + data_format['HOUR'] + data_format['WEEK'] + data_format['CALL_TYPE'] + data_format['TAXI_ID']
    #we get the length of the containing information of the CONTEXT_INPUT column
    len_context_info = len(data_format['CONTEXT_INPUT'][0].split(' '))

    #we get the DEB_TRAJ column
    data_format=get_deb_traj(data_format, len_context_info)
    
    #we get the TARGET column
    data_format['TARGET']=data_format['Tokenization_2'].apply(lambda x: x[-2])

    return data_format


In [23]:


def formatting_to_train(data_format, tokenizer):
    """
    Format the data to train the model : 
    ------------------------------------

    1) format the input

        a) get the full_inputs
    - we concatenate the context input and the beginning of the trajectory which is the sequence we want to give to the model 
    - at the beginning, we add the CLS token and the end of the input the SEP token

        b) get the input_ids
    - we use the tokenizer to get the ids of the tokens that will be the input_ids thatthe model will take as input
    - we pad the input to the maximum length of 512

    2) and we create the attention masks

    - the attention mask is a list of 0 and 1, 0 for the padded tokens and 1 for the other tokens

    """
    
    #we remove the useless columns
    if 'Tokenization' in data_format.columns:
        data_format.drop(['Tokenization'],axis=1,inplace=True)
    if 'CALL_TYPE' in data_format.columns:
        data_format.drop(['CALL_TYPE'],axis=1,inplace=True)
    if 'TAXI_ID' in data_format.columns:
        data_format.drop(['TAXI_ID'],axis=1,inplace=True)
    if 'DAY' in data_format.columns:
        data_format.drop(['DAY'],axis=1,inplace=True)
    if 'HOUR' in data_format.columns:
        data_format.drop(['HOUR'],axis=1,inplace=True)
    if 'WEEK' in data_format.columns:
        data_format.drop(['WEEK'],axis=1,inplace=True)
    if 'Nb_points_token' in data_format.columns:
        data_format.drop(['Nb_points_token'],axis=1,inplace=True)


    #we get the columns CONTEXT_INPUT, DEB_TRAJ and TARGET
    c_inputs=data_format.CONTEXT_INPUT.values
    traj_inputs=data_format.DEB_TRAJ.values
    targets=data_format.TARGET.values

    print("concaténation des inputs, padding etc")

    #we create the input_ids, the attention_masks and the full_inputs
    input_ids = []
    full_inputs = []
    attention_masks = []
    for i in tqdm(range(len(c_inputs))):
        #no truncation is needed because we managed it before

        #we concatenate the context input and the trajectory input adding manually the CLS token and the SEP token
        full_input = '[CLS] ' + c_inputs[i] + ' ' + traj_inputs[i] + ' [SEP]'
        full_inputs.append(full_input)

        # we use the tokenizer to get the ids of the tokens that will be the input_ids that the model will take as input
        # the format of the input_ids would be : [101] + encoded_c_input + encoded_traj_input + [102]
        #the[101] token is the CLS token and the [102] token is the SEP token
        # TODO : test adding an additional SEP token between the context input and the trajectory input so that the format of the input_ids would be : [101] + encoded_c_input + [102] + encoded_traj_input + [102]
        encoded_full_input=tokenizer.encode(full_input, add_special_tokens=False)

        #we pad the input to the maximum length of 512
        encoded_full_input=encoded_full_input + [0]*(512-len(encoded_full_input))
        #we add the input_ids to the list
        input_ids.append(encoded_full_input)

        #we create the attention mask
        att_mask = [float(i>0) for i in encoded_full_input]
        #we add the attention mask to the list
        attention_masks.append(att_mask)

    return input_ids, attention_masks, targets, full_inputs



In [24]:

WORLD_S=2 #The world size is the number of processes we want to use
h3_config_size = 10
nb_rows = 60


In [25]:
#we load the data
with open('/home/daril_kw/data/02.06.23/train_clean.json', 'r') as openfile:
    json_loaded = json.load(openfile)

In [28]:
json_loaded


{'TRIP_ID': {'0': 1372636858620000589,
  '1': 1372637303620000596,
  '2': 1372636951620000320,
  '3': 1372636854620000520,
  '4': 1372637091620000337,
  '5': 1372636965620000231,
  '6': 1372637210620000456,
  '7': 1372637299620000011,
  '8': 1372637274620000403,
  '9': 1372637905620000320,
  '10': 1372636875620000233,
  '11': 1372637984620000520,
  '12': 1372637343620000571,
  '13': 1372638595620000233,
  '14': 1372638151620000231,
  '15': 1372637610620000497,
  '16': 1372638481620000403,
  '17': 1372639135620000570,
  '18': 1372637482620000005,
  '19': 1372639181620000089,
  '20': 1372638161620000423,
  '21': 1372637254620000657,
  '22': 1372638502620000320,
  '23': 1372639960620000309,
  '24': 1372637658620000596,
  '25': 1372639092620000233,
  '26': 1372639535620000161,
  '27': 1372640499620000596,
  '28': 1372639635620000178,
  '29': 1372640555620000235,
  '30': 1372639871620000653,
  '31': 1372639875620000009,
  '32': 1372637453620000648,
  '33': 1372640399620000320,
  '34': 13726

In [27]:
data_format = pd.DataFrame(data=json_loaded)

#we keep only nb_rows rows
data_format = truncation_rows(data_format, nb_rows)

In [29]:
data_format.info()

<class 'pandas.core.frame.DataFrame'>
Index: 60 entries, 0 to 60
Data columns (total 12 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   TRIP_ID          60 non-null     int64  
 1   CALL_TYPE        60 non-null     object 
 2   ORIGIN_CALL      9 non-null      float64
 3   ORIGIN_STAND     13 non-null     float64
 4   TAXI_ID          60 non-null     int64  
 5   TIMESTAMP        60 non-null     int64  
 6   DAY_TYPE         60 non-null     object 
 7   MISSING_DATA     60 non-null     bool   
 8   POLYLINE         60 non-null     object 
 9   Tokenization     60 non-null     object 
 10  Nb_points        60 non-null     int64  
 11  Nb_points_token  60 non-null     int64  
dtypes: bool(1), float64(2), int64(5), object(4)
memory usage: 5.7+ KB


In [30]:
data_format

Unnamed: 0,TRIP_ID,CALL_TYPE,ORIGIN_CALL,ORIGIN_STAND,TAXI_ID,TIMESTAMP,DAY_TYPE,MISSING_DATA,POLYLINE,Tokenization,Nb_points,Nb_points_token
0,1372636858620000589,C,,,20000589,1372636858,A,False,"[[-8.618643, 41.141412], [-8.618499, 41.141376...","[8a7b63adb347fff, 8a7b63adb35ffff, 8a7b63adbad...",23,23
1,1372637303620000596,B,,7.0,20000596,1372637303,A,False,"[[-8.639847, 41.159826], [-8.640351, 41.159871...","[8a7b63ad8047fff, 8a7b63ad8047fff, 8a7b63ad814...",19,19
2,1372636951620000320,C,,,20000320,1372636951,A,False,"[[-8.612964, 41.140359], [-8.613378, 41.14035]...","[8a7b63adb2cffff, 8a7b63adb2cffff, 8a7b63adb2e...",65,65
3,1372636854620000520,C,,,20000520,1372636854,A,False,"[[-8.574678, 41.151951], [-8.574705, 41.151942...","[8a7b63370727fff, 8a7b63370727fff, 8a7b6337072...",43,43
4,1372637091620000337,C,,,20000337,1372637091,A,False,"[[-8.645994, 41.18049], [-8.645949, 41.180517]...","[8a7b63adea37fff, 8a7b63adea37fff, 8a7b63adea0...",29,29
5,1372636965620000231,C,,,20000231,1372636965,A,False,"[[-8.615502, 41.140674], [-8.614854, 41.140926...","[8a7b63adb257fff, 8a7b63adb257fff, 8a7b63adb2c...",26,26
6,1372637210620000456,C,,,20000456,1372637210,A,False,"[[-8.57952, 41.145948], [-8.580942, 41.145039]...","[8a7b63370337fff, 8a7b63370a8ffff, 8a7b63370a1...",36,36
7,1372637299620000011,C,,,20000011,1372637299,A,False,"[[-8.617563, 41.146182], [-8.617527, 41.145849...","[8a7b63adb3a7fff, 8a7b63adb3affff, 8a7b63adb23...",34,34
8,1372637274620000403,C,,,20000403,1372637274,A,False,"[[-8.611794, 41.140557], [-8.611785, 41.140575...","[8a7b63375967fff, 8a7b63375967fff, 8a7b6337596...",38,38
9,1372637905620000320,C,,,20000320,1372637905,A,False,"[[-8.615907, 41.140557], [-8.614449, 41.141088...","[8a7b63adb257fff, 8a7b63adb2effff, 8a7b63adb2c...",19,19


In [31]:
#we count the number of rows for which teh column NB_POINTS is equal to 0 : there are 0 rows
#>>>print("nombre de lignes pour lesquelles le nombre de points est inférieur à 3 : ", len(data_format[data_format['Nb_points']<3]))
#   nombre de lignes pour lesquelles le nombre de points est inférieur à 3 :  0


#we add the tokenization column
data_format = add_tokenization_column(data_format, h3_config_size)

In [32]:
#we add the time info columns
data_format = extract_time_info(data_format)

In [33]:

#we remove the useless columns
data_format = data_format.drop(['MISSING_DATA','DATE','ORIGIN_CALL', 'DAY_TYPE', 'ORIGIN_CALL', 'ORIGIN_STAND', 'Nb_points', 'TIMESTAMP'], axis=1)

#we transform the columns TAXI_ID which was a number to string type
data_format = formatting_to_str(data_format, 'TAXI_ID')

#we transform the column CALL_TYPE to a number instead of a letter
data_format = call_type_to_nb(data_format)

#we transform the column CALL_TYPE which was a number to string type
data_format = formatting_to_str(data_format, 'CALL_TYPE')


  if isinstance(df[column][0], str):
  if isinstance(df[column][0], str):


In [35]:

#we remove the useless columns
data_format = data_format.drop(['MISSING_DATA','DATE','ORIGIN_CALL', 'DAY_TYPE', 'ORIGIN_CALL', 'ORIGIN_STAND', 'Nb_points', 'TIMESTAMP'], axis=1, errors='ignore') 
# to avoid the error while dropping, we can add the argument errors='ignore'

#we transform the columns TAXI_ID which was a number to string type
data_format = formatting_to_str(data_format, 'TAXI_ID')

#we transform the column CALL_TYPE to a number instead of a letter
data_format = call_type_to_nb(data_format)

#we transform the column CALL_TYPE which was a number to string type
data_format = formatting_to_str(data_format, 'CALL_TYPE')


  if isinstance(df[column][0], str):
  if isinstance(df[column][0], str):


In [36]:

#we get the tokenizer from the HuggingFace library, this one is the tokenizer of the model bert-base-cased but we could have taken the non trained tokenizer (TODO :test it with the non trained weights of the model as well)
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
#we add the geographical and contextual tokens to the tokenizer so that the vocabulary of the tokenizer is adapted to our data
tokenizer, nb_token_geo = add_geo_and_context_tokens_tokenizer(tokenizer, data_format)
#we get the number of labels which is the number of geographical tokens + 1 (the +1 is for the [SEP] token which is for the end of the sequence and the prediction)
nb_labels = nb_token_geo + 1

#we get the model from the HuggingFace library, this one is the model bert-base-cased but we could have taken the non trained model (if we want to train it from scratch)
model = BertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=nb_labels)

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

  contextual_info_token = {str(data_format['CALL_TYPE'][i])
  contextual_info_token.update(str(data_format['TAXI_ID'][i])
  contextual_info_token.update(data_format['DAY'][i]
  contextual_info_token.update(data_format['HOUR'][i]
  contextual_info_token.update(data_format['WEEK'][i]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [40]:
len(tokenizer)


29921

In [41]:
# the model embedding is get by the next line
model.get_input_embeddings() # the output is : Embedding(28996, 768) 28996 is the number of tokens in the vocabulary and 768 is the dimension of the embeddings

Embedding(28996, 768, padding_idx=0)

In [42]:

#we add the geographical and contextual tokens to the model so that the size of the model`s embedding is adapted to our data
model.resize_token_embeddings(len(tokenizer))

Embedding(29921, 768)

In [43]:


#save the model, the tokenizer and the data in different files
model.save_pretrained(f"/home/daril_kw/data/savings_for_60_rows/model_before_training_opti_full_for_para_60")
data_format.to_json(f"/home/daril_kw/data/savings_for_60_rows/data_with_time_info_ok_opti_full_for_para_60.json")
tokenizer.save_pretrained(f"/home/daril_kw/data/savings_for_60_rows/tokenizer_final_opti_full_for_para_60")


('/home/daril_kw/data/savings_for_60_rows/tokenizer_final_opti_full_for_para_60/tokenizer_config.json',
 '/home/daril_kw/data/savings_for_60_rows/tokenizer_final_opti_full_for_para_60/special_tokens_map.json',
 '/home/daril_kw/data/savings_for_60_rows/tokenizer_final_opti_full_for_para_60/vocab.txt',
 '/home/daril_kw/data/savings_for_60_rows/tokenizer_final_opti_full_for_para_60/added_tokens.json')

In [44]:

#we get the DEB_TRAJ and TARGET columns well formatted but without the special tokens [CLS] and [SEP]
#this is because we will add them later
data_format = get_deb_traj_and_target(data_format)

#we get the input_ids, the attention_masks, the targets and the full_inputs
input_ids, attention_masks, targets, full_inputs = formatting_to_train(data_format, tokenizer)


  len_context_info = len(data_format['CONTEXT_INPUT'][0].split(' '))


concaténation des inputs, padding etc


100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:00<00:00, 297.32it/s]


In [49]:
#size of input_ids
len(input_ids)
input_ids

[[101,
  29048,
  128,
  1406,
  1744,
  123,
  29899,
  29803,
  29803,
  29309,
  29877,
  29238,
  29864,
  29753,
  29007,
  29292,
  29871,
  29037,
  29122,
  29122,
  29122,
  29788,
  29788,
  29089,
  29859,
  29048,
  29048,
  29048,
  102,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,

In [116]:
attention_masks

[[1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,

In [None]:

#save the lists full_inputs, inputs_ids, attention_masks and the targets in different files
with open(f"/home/daril_kw/data/savings_for_60_rows/input_ids_full_opti_for_para_60.pkl", 'wb') as fp:
    pickle.dump(input_ids, fp)
with open(f"/home/daril_kw/data/savings_for_60_rows/attention_masks_full_opti_for_para_60.pkl", 'wb') as fp:
    pickle.dump(attention_masks, fp)
with open(f"/home/daril_kw/data/savings_for_60_rows/targets_full_opti_for_para_60.pkl", 'wb') as fp:
    pickle.dump(targets, fp)
with open(f"/home/daril_kw/data/savings_for_60_rows/full_inputs_full_opti_for_para_60.pkl", 'wb') as fp:
    pickle.dump(full_inputs, fp)



