In [1]:
from os import makedirs
from os.path import join
import logging
import numpy as np
import torch
import random
import sys

sys.path.append('../')
from torchsummary import summary
import torch as T
from torch import nn
from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

from importlib import reload


In [2]:
from args import define_main_parser

#from dataset.prsa import PRSADataset
from data.card import TransactionDataset
from models.modules import TabFormerBertLM
from scripts.utils import random_split_dataset
from data.datacollator import TransDataCollatorForLanguageModeling
from models.lstm_classifier import LSTM
import models

In [3]:
from argparse import Namespace
config = vars(Namespace(cached=False, checkpoint=0, data_extension='', data_fname='card_transaction.v2', data_root='./data/credit_card/', data_type='card', do_eval=False, do_train=True, field_ce=True, field_hs=64, flatten=False, jid=1, lm_type='bert', log_dir='sam/logs', mlm=True, mlm_prob=0.15, nrows=None, num_train_epochs=3, output_dir='sam', save_steps=500, seed=9, skip_user=False, stride=5, user_ids=None, vocab_file='vocab.nb'))
config['data_root'] = "../dataset/credit_card/"
config['output_dir'] = "sample"
config['log_dir'] = "sample/logs"
makedirs(config['output_dir'], exist_ok=True)
makedirs(config['log_dir'], exist_ok=True)


In [4]:
seed = config['seed']
random.seed(seed)  # python
np.random.seed(seed)  # numpy
torch.manual_seed(seed)  # torch
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)  # torch.cuda

In [5]:
dataset = TransactionDataset(root=config['data_root'],
                            fname=config['data_fname'],
                            fextension="",
                            vocab_dir=config['output_dir'],
                            nrows=None,
                            user_ids=None,
                            seq_len=20,
                            mlm=config['mlm'],
                            cached=config['cached'],
                            stride=10,
                            flatten=config['flatten'],
                            return_labels=False,
                            skip_user=True)

100%|██████████| 7/7 [00:00<00:00, 354.04it/s]
  int)
100%|██████████| 1/1 [00:00<00:00,  1.45it/s]
100%|██████████| 1/1 [00:00<00:00,  8.88it/s]


In [6]:
vocab = dataset.vocab
custom_special_tokens = vocab.get_special_tokens()

totalN = len(dataset)
totalN = len(dataset)
trainN = int(0.6 * totalN)

valtestN = totalN - trainN
valN = int(valtestN * 0.5)
testN = valtestN - valN

In [7]:
lengths = [trainN, valN, testN]

In [8]:
print(f"# lengths: train [{trainN}]  valid [{valN}]  test [{testN}]")
print("# lengths: train [{:.2f}]  valid [{:.2f}]  test [{:.2f}]".format(trainN / totalN, valN / totalN,
                                                                               testN / totalN))

# lengths: train [599]  valid [200]  test [200]
# lengths: train [0.60]  valid [0.20]  test [0.20]


In [9]:
train_dataset, eval_dataset, test_dataset = random_split_dataset(dataset, lengths)

In [10]:
tab_net = TabFormerBertLM(custom_special_tokens,
                                  vocab=vocab,
                                  field_ce=config['field_ce'],
                                  flatten=config['flatten'],
                                  ncols=dataset.ncols,
                                  field_hidden_size=config['field_hs']
                                  )

In [11]:
collactor_cls = "TransDataCollatorForLanguageModeling"
data_collator = eval(collactor_cls)(
        tokenizer=tab_net.tokenizer, mlm=config['mlm'], mlm_probability=config['mlm_prob']
    )

In [12]:
model = tab_net.model
train_dataloader = DataLoader(
            train_dataset,
            batch_size=10,
            collate_fn=data_collator)
optim = torch.optim.AdamW(model.parameters(), lr=0.01)

In [13]:
for inps in train_dataloader:
    #print(inps.keys())
    #print(inps['input_ids'].shape)
    #print(inps['masked_lm_labels'].shape)
    model.train()
    outputs =model(**inps)
    #labels = inps.pop("labels")
    loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
    loss.backward()
    optim.step()
    break

In [14]:
model.save_pretrained('sample_model')

In [15]:
weight = torch.load('sample_model/pytorch_model.bin')
model.load_state_dict(weight)

<All keys matched successfully>

In [49]:
reload(models.lstm_classifier)

<module 'models.lstm_classifier' from '../models/lstm_classifier.py'>

In [66]:
class Classifier(nn.Module):
    def __init__(self, custom_special_tokens,
                 vocab,
                 field_ce,
                 flatten,
                 ncols,
                 field_hidden_size,
                 bert_feature_size
                 ):
        super(Classifier, self).__init__()
        bert_model = TabFormerBertLM(custom_special_tokens,
                                        vocab=vocab,
                                        field_ce=field_ce,
                                        flatten=flatten,
                                        ncols=ncols,
                                        field_hidden_size=field_hidden_size)

        bert_model = bert_model.model
        bert_model.load_state_dict(weight)
        self.field_transformer = bert_model.tab_embeddings
        self.bert = bert_model.tb_model


        #self.classifier = LSTM(emb_inp_size=bert_feature_size)
        self.classifier = models.lstm_classifier.ClinicalClassifier(batch_size=10, 
                                                                    hidden_dim=1062, 
                                                                    lstm_layers=3, 
                                                                    max_words=220)
    
    def forward(self, input_ids ,input_args):
        field_embeddings = self.field_transformer(input_ids)
        #print(f"Field Embedding shape - {field_embeddings.shape}")
        #input_args['input_ids'] = input_ids
        bert_features = self.bert(inputs_embeds=field_embeddings, **input_args)
        
        bert_features = bert_features[1]
        #print(f"Bert Features - {bert_features.shape}")
        #bert_features = bert_features.reshape((10, 20, 11, 1062))
        #print(f"Bert Features - {bert_features.shape}")
        #bert_features = bert_features.reshape((200, 11, 1062))
        #print(f"Bert Features - {bert_features.shape}")
        cls_out = self.classifier(bert_features)
        
        return cls_out
    

In [65]:
del classif
del Classifier

In [67]:
classif = Classifier(custom_special_tokens,
                 vocab=vocab,
                    field_ce=config['field_ce'],
                    flatten=config['flatten'],
                    ncols=dataset.ncols,
                    field_hidden_size=config['field_hs'],
                 bert_feature_size=1062)

In [68]:
#classif = ClassifierNew(batch_size=10, hidden_dim=1064, lstm_layers=3, max_words=220)

In [69]:
for inps in train_dataloader:
    input_ids = inps.pop('input_ids')
    #print(input_ids.shape)
    #print(inps['masked_lm_labels'].shape)
    #print(inps['masked_lm_labels'], )
    #print(inps.keys())
    
    class_out =classif(input_ids, inps)
    
    print(f"Class out - {class_out.shape}")
    #print(class_out)

Class out - torch.Size([10, 1])
Class out - torch.Size([10, 1])
Class out - torch.Size([10, 1])


KeyboardInterrupt: 

In [29]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


In [24]:
aa = T.rand((10, 220, 1064))

In [26]:
aa = aa.float()

In [70]:
#aa

In [43]:
pad_r = pack_padded_sequence(aa, T.as_tensor([10]), batch_first=True, enforce_sorted=False)

In [44]:
pad_r.data.shape

torch.Size([10, 1064])

In [72]:
#pack_padded_sequence