In [1]:
import torch
import torch.nn as nn
# from transformers.modeling_bert import BertPreTrainedModel, BertModel, BertConfig
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertConfig
# from torchcrf import CRF
from TorchCRF import CRF
from module import IntentClassifier, SlotClassifier, IntentTokenClassifier, MultiIntentClassifier, TagIntentClassifier
import logging


import argparse
import random
from datetime import datetime
import time
import argparse
from utils import init_logger, load_tokenizer, read_prediction_text, set_seed, MODEL_CLASSES, MODEL_PATH_MAP, get_intent_labels, get_slot_labels
from seqeval.metrics.sequence_labeling import get_entities

# logger = logging.getLogger()
logging.basicConfig(filename='log.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class JointBERTMultiIntent(BertPreTrainedModel):
    # multi_intent: 1,
    # intent_seq: 1,
    # tag_intent: 1,
    # bi_tag: 1,
    # cls_token_cat: 1,
    # intent_attn: 1,
    # num_mask: 4
    def __init__(self, config, args, intent_label_lst, slot_label_lst):
        super().__init__(config)
        self.args = args
        self.max_seq_len = args.max_seq_len
        self.num_intent_labels = len(intent_label_lst)
        self.num_slot_labels = len(slot_label_lst)
        # load pretrain bert
        self.bert = BertModel(config=config)
        
        self.slot_label_lst = slot_label_lst
        self.intent_label_lst = intent_label_lst
        
        self.slot_label_map = {i: label for i, label in enumerate(self.slot_label_lst)}
        
        # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
        self.pad_token_label_id = args.ignore_index

        # self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)
        self.multi_intent_classifier = MultiIntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)
        self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels, args.dropout_rate)
        if args.intent_seq:
            self.intent_token_classifier = IntentTokenClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)
        
        if args.tag_intent:
            if args.cls_token_cat:
                self.tag_intent_classifier = TagIntentClassifier(2 * config.hidden_size, self.num_intent_labels, args.dropout_rate)
            else:
                self.tag_intent_classifier = TagIntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)
        
        if args.use_crf:
            self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)
            
    def slot_pred(self, slot_preds, slot_logits):
        if slot_preds is None:
            if self.args.use_crf:
                # decode() in `torchcrf` returns list with best index directly
                slot_preds = np.array(self.model.crf.decode(slot_logits))
            else:
                slot_preds = slot_logits.detach().cpu().numpy()

        else:
            if self.args.use_crf:
                slot_preds = np.append(slot_preds, np.array(self.model.crf.decode(slot_logits)), axis=0)
            else:
                slot_preds = np.append(slot_preds, slot_logits.detach().cpu().numpy(), axis=0)

        if not self.args.use_crf:
            slot_preds = np.argmax(slot_preds, axis=2)
            
        return slot_preds

    
    def intent_pred(self,intent_preds,intent_logits):
        if intent_preds is None:
            intent_preds = intent_logits.detach().cpu().numpy()
        else:
            intent_preds = np.append(intent_preds, intent_logits.detach().cpu().numpy(), axis=0)
            
        return  intent_preds
    
    def intent_token_pred(self,intent_token_preds, intent_token_logits):
        if self.args.intent_seq:
            if intent_token_preds is None:
                if self.args.use_crf:
                    intent_token_preds = np.array(self.model.crf.decode(intent_token_logits))
                else:
                    intent_token_preds = intent_token_logits.detach().cpu().numpy()

            else:
                if self.args.use_crf:
                    intent_token_preds = np.append(intent_token_preds, np.array(self.model.crf.decode(intent_token_logits)), axis=0)
                else:
                    intent_token_preds = np.append(intent_token_preds, intent_token_logits.detach().cpu().numpy(), axis=0)
                
        return intent_token_preds
        

    def forward(self,
                input_ids,
                attention_mask,
                token_type_ids,
                B_tag_mask = None,
                BI_tag_mask = None):#,
                # B_tag_mask,
                # BI_tag_mask,
                # tag_intent_label):
        """
            Args: B: batch_size; L: sequence length; I: the number of intents; M: number of mask; D: the output dim of Bert
            input_ids: B * L
            token_type_ids: B * L
            token_type_ids: B * L
            intent_label_ids: B * I
            slot_labels_ids: B * L
            intent_token_ids: B * L
            B_tag_mask: B * M * L
            BI_tag_mask: B * M * L
            tag_intent_label: B * M
        """
        # input_ids:  torch.Size([32, 50])
        # attention_mask:  torch.Size([32, 50])
        # token_type_ids:  torch.Size([32, 50])
        # intent_label_ids:  torch.Size([32, 10])
        # slot_labels_ids:  torch.Size([32, 50])
        # intent_token_ids:  torch.Size([32, 50])
        # B_tag_mask:  torch.Size([32, 4, 50])
        # BI_tag_mask:  torch.Size([32, 4, 50])
        # tag_intent_label:  torch.Size([32, 4])
        
        # (len_seq, batch_size, hidden_dim), (batch_size, hidden_dim)
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)  # sequence_output, pooled_output, (hidden_states), (attentions)
        # B * L * D
        sequence_output = outputs[0]
        # B * D
        pooled_output = outputs[1]  # [CLS]


        #====================================== get logits for 3 classifiers ======================================
        # (batch_size, num_intents)
        intent_logits = self.multi_intent_classifier(pooled_output)
        
        # (batch_size, seq_len, num_slots)
        slot_logits = self.slot_classifier(sequence_output)

        # (batch_size, seq_len, num_intents)
        intent_token_logits = self.intent_token_classifier(sequence_output)
        
        # ====================================== generate mask ======================================

        # 1.-------------------------------------- get preds from logits for mask ------------------------------------------
        intent_preds = None
        intent_preds = self.intent_pred(intent_preds,intent_logits)
        # intent_preds = torch.as_tensor(intent_preds > 0.5, dtype=torch.int32)
        
        slot_preds = None
        slot_preds= self.slot_pred(slot_preds, slot_logits) # 64x35

        intent_token_preds = None
        intent_token_preds = self.intent_token_pred(intent_token_preds, intent_token_logits)

        # 2.-------------------------------------- get masks from preds or true labels(training) ------------------------------------------
        # Slot result
        # (batch_size, seq_len)
        out_slot_label_list = [[] for _ in range(slot_preds.shape[0])]
        slot_preds_list = [[] for _ in range(slot_preds.shape[0])]
        
        slot_label_map = self.slot_label_map
                
        #mask
        B_tag_mask_pred = []
        BI_tag_mask_pred = []
        
        out_slot_labels_ids = slot_preds
        for i in range(out_slot_labels_ids.shape[0]): # for all samples
            # record the padding position
            pos_offset = [0 for _ in range(out_slot_labels_ids.shape[1])] # in shape of max_seq_len
            pos_cnt = 0
            padding_recording = [0 for _ in range(out_slot_labels_ids.shape[1])] # in shape of max_seq_len
            
            for j in range(out_slot_labels_ids.shape[1]): # for all token in seq
                if out_slot_labels_ids[i, j] != self.pad_token_label_id:
                    out_slot_label_list[i].append(slot_label_map[out_slot_labels_ids[i][j]]) # append word label if not pad
                    slot_preds_list[i].append(slot_label_map[slot_preds[i][j]]) # append word label if not pad
                    pos_offset[pos_cnt+1] = pos_offset[pos_cnt]
                    pos_cnt += 1
                else:
                    pos_offset[pos_cnt] = pos_offset[pos_cnt] + 1
                    padding_recording[j] = 1
                    
                # print('pos_offset: ',pos_offset)
                # print('pos_cnt; ',pos_cnt)
                # print('padding_recording: ',padding_recording)
                    

            entities = get_entities(slot_preds_list[i])
            entities = [tag for entity_idx, tag in enumerate(entities) if slot_preds_list[i][tag[1]].startswith('B')]
            #print(entities)
            
            if len(entities) > self.args.num_mask:
                entities = entities[:self.args.num_mask]
            
            entity_masks = []
            
            
            for entity_idx, entity in enumerate(entities):
                entity_mask = [0 for _ in range(out_slot_labels_ids.shape[1])]
                start_idx = entity[1] + pos_offset[entity[1]]
                end_idx = entity[2] + pos_offset[entity[2]] + 1
                if self.args.BI_tag:
                    entity_mask[start_idx:end_idx] = [1] * (end_idx - start_idx)
                    for padding_idx in range(start_idx, end_idx):
                        
                        #print('len(padding_recording): ',len(padding_recording))
                        #print('padding_idx: ',padding_idx)
                        
                        if padding_recording[padding_idx]:
                            entity_mask[padding_idx] = 0
                else:
                    entity_mask[start_idx] = 1
                    
                entity_masks.append(entity_mask)
            
            for extra_idx in range(self.args.num_mask - len(entity_masks)):
                entity_masks.append([
                    0 for _ in range(out_slot_labels_ids.shape[1])
                ])

            #print('entity mask: ', entity_masks)
            if self.args.BI_tag:
                BI_tag_mask_pred.append(entity_masks)
            else:
                B_tag_mask_pred.append(entity_masks)
                
        if self.args.BI_tag:
            BI_tag_mask_pred_tensor = torch.FloatTensor(BI_tag_mask_pred)
        else:
            B_tag_mask_pred_tensor = torch.FloatTensor(B_tag_mask_pred)
        
        BI_tag_mask_pred_input = None
        B_tag_mask_pred_input = None
        
        #============================================================================
        
        # after softmax
        tag_intent_logits = self.tag_intent_classifier(tag_intent_vec)
        
        
        total_loss = 0  
        
        # ==================================== 1. Intent Softmax ========================================
        # (batch_size, num_intents)
        intent_logits = self.multi_intent_classifier(pooled_output)
        intent_logits_cpu = intent_logits.data.cpu().numpy()
        
        if intent_label_ids is not None:
            if self.num_intent_labels == 1:
                intent_loss_fct = nn.MSELoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1), intent_label_ids.view(-1, self.num_intent_labels))
            else:
                # intent_loss_fct = nn.CrossEntropyLoss()
                # default reduction is mean
                intent_loss_fct = nn.BCELoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1, self.num_intent_labels) + 1e-10, intent_label_ids.view(-1, self.num_intent_labels))
            # Question: do we need to add weight here
            total_loss += intent_loss
            
        if intent_label_ids.type() != torch.cuda.FloatTensor:
            intent_label_ids = intent_label_ids.type(torch.cuda.FloatTensor)
            
        # ==================================== 2. Slot Softmax ========================================
        # (batch_size, seq_len, num_slots)
        print('len sequence_output: ',sequence_output.size)
        slot_logits = self.slot_classifier(sequence_output)
        print('len slot_logits: ', slot_logits.size)
        
        if slot_labels_ids is not None:
            if self.args.use_crf:
                slot_loss = self.crf(slot_logits, slot_labels_ids, mask=attention_mask.byte(), reduction='mean')
                slot_loss = -1 * slot_loss  # negative log-likelihood
            else:
                slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
                # Only keep active parts of the loss
                if attention_mask is not None:
                    try:
                        active_loss = attention_mask.view(-1) == 1
                        attention_mask_cpu = attention_mask.data.cpu().numpy()
                        active_loss_cpu = active_loss.data.cpu().numpy()
                        active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
                        active_labels = slot_labels_ids.view(-1)[active_loss]
                        slot_loss = slot_loss_fct(active_logits, active_labels)
                    except:
                        print('intent_logits: ', intent_logits_cpu)
                        print('attention_mask: ', attention_mask_cpu)
                        print('active_loss: ', active_loss_cpu)
                        logger.info('intent_logits: ', intent_logits_cpu)
                        logger.info('attention_mask: ', attention_mask_cpu)
                        logger.info('active_loss: ', active_loss_cpu)
                else:
                    slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))

            total_loss += self.args.slot_loss_coef * slot_loss
        
        
        # ==================================== 3. Intent Token Softmax ========================================
        intent_token_loss = 0.0
        if self.args.intent_seq:
            # (batch_size, seq_len, num_intents)
            intent_token_logits = self.intent_token_classifier(sequence_output)

            if intent_token_ids is not None:
                if self.args.use_crf:
                    intent_token_loss = self.crf(intent_token_logits, intent_token_ids, mask=attention_mask.byte, reduction='mean')
                    intent_token_loss = -1 * intent_token_loss
                else:
                    intent_token_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
                    if attention_mask is not None:
                        active_intent_loss = attention_mask.view(-1) == 1
                        active_intent_logits = intent_token_logits.view(-1, self.num_intent_labels)[active_intent_loss]
                        active_intent_tokens = intent_token_ids.view(-1)[active_intent_loss]
                        intent_token_loss = intent_token_loss_fct(active_intent_logits, active_intent_tokens)
                    else:
                        intent_token_loss = intent_token_loss_fct(intent_token_logits.view(-1, self.num_intent_labels), intent_token_ids.view(-1))
                
                total_loss += self.args.slot_loss_coef * intent_token_loss
        
        # convert the sequence_out to long
        if BI_tag_mask != None and  BI_tag_mask.type() != torch.cuda.FloatTensor:
            BI_tag_mask = BI_tag_mask.type(torch.cuda.FloatTensor)

        if B_tag_mask != None and B_tag_mask.type() != torch.cuda.FloatTensor:
            B_tag_mask = B_tag_mask.type(torch.cuda.FloatTensor)
        
        tag_intent_loss = 0.0
        if self.args.tag_intent:
            # B * M * D
            if self.args.BI_tag:
                tag_intent_vec = torch.einsum('bml,bld->bmd', BI_tag_mask, sequence_output)
            else:
                tag_intent_vec = torch.einsum('bml,bld->bmd', B_tag_mask, sequence_output)
            
            if self.args.cls_token_cat:
                cls_token = pooled_output.unsqueeze(1)
                # B * M * D
                cls_token = cls_token.repeat(1, self.args.num_mask, 1)
                # B * M * 2D
                tag_intent_vec = torch.cat((cls_token, tag_intent_vec), dim=2)
            
            tag_intent_vec = tag_intent_vec.view(tag_intent_vec.size(0) * tag_intent_vec.size(1), -1)
            
            # after softmax
            tag_intent_logits = self.tag_intent_classifier(tag_intent_vec)

            if self.args.intent_attn:
                # (batch_size, num_intent) => (batch_size * num_mask, num_intent) sigmoid [0, 1]
                intent_probs = intent_logits.unsqueeze(1)
                intent_probs = intent_probs.repeat(1, self.args.num_mask, 1)
                intent_probs = intent_probs.view(intent_probs.size(0) * intent_probs.size(1), -1)
                # (batch_size * num_mask, num_intent)
                tag_intent_logits = tag_intent_logits * intent_probs
                tag_intent_logits = tag_intent_logits.div(tag_intent_logits.sum(dim=1, keepdim=True))
            
            # tag_intent_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)

            # tag_intent_loss = tag_intent_loss_fct(tag_intent_logits, tag_intent_label.view(-1))

            nll_fct = nn.NLLLoss(ignore_index=self.args.ignore_index)
            
            tag_intent_loss = nll_fct(torch.log(tag_intent_logits + 1e-10), tag_intent_label.view(-1))
            
            total_loss += self.args.tag_intent_coef * tag_intent_loss
            
        if self.args.intent_seq and self.args.tag_intent:
            outputs = ((intent_logits, slot_logits, intent_token_logits, tag_intent_logits),) + outputs[2:]  # add hidden states and attention if they are here
        elif self.args.intent_seq:
            outputs = ((intent_logits, slot_logits, intent_token_logits),) + outputs[2:]
        elif self.args.tag_intent:
            outputs = ((intent_logits, slot_logits, tag_intent_logits),) + outputs[2:]
        else:
            outputs = ((intent_logits, slot_logits),) + outputs[2:]
        
        outputs = ([total_loss, intent_loss, slot_loss, intent_token_loss, tag_intent_loss],) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions) # Logits is a tuple of intent and slot logits
    

In [3]:
if __name__ == '__main__':
    time_wait = random.uniform(0, 10)
    time.sleep(time_wait)
    parser = argparse.ArgumentParser()


    parser.add_argument("--task", default='gpsr', type=str, help="The name of the task to train")

#     parser.add_argument("--model_dir", default='./gpsr_model', required=True, type=str, help="Path to save, load model")
    parser.add_argument("--model_dir", default='./test_model', type=str, help="Path to save, load model")

    parser.add_argument("--data_dir", default="./data", type=str, help="The input data dir")
    parser.add_argument("--intent_label_file", default="intent_label.txt", type=str, help="Intent Label file")
    parser.add_argument("--slot_label_file", default="slot_label.txt", type=str, help="Slot Label file")
    parser.add_argument("--model_type", default="multibert", type=str, help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
#     parser.add_argument("--intent_seq", type=int, default=0, help="whether we use intent seq setting")
    parser.add_argument("--intent_seq", type=int, default=1, help="whether we use intent seq setting")

    parser.add_argument("--multi_intent", type=int, default=1, help="whether we use multi intent setting")
    parser.add_argument("--tag_intent", type=int, default=1, help="whether we can use tag to predict intent")

    parser.add_argument("--BI_tag", type=int, default=1, help='use BI sum or just B')
    parser.add_argument("--cls_token_cat", type=int, default=1, help='whether we cat the cls to the slot output of bert')
    parser.add_argument("--intent_attn", type=int, default=1, help='whether we use attention mechanism on the CLS intent output')
    parser.add_argument("--num_mask", type=int, default=7, help="assumptive number of slot in one sentence")
                                           #max slot num = 7


    parser.add_argument('--seed', type=int, default=25, help="random seed for initialization")
    parser.add_argument("--train_batch_size", default=64, type=int, help="Batch size for training.")
#     parser.add_argument("--train_batch_size", default=64, type=int, help="Batch size for training.")

    parser.add_argument("--eval_batch_size", default=128, type=int, help="Batch size for evaluation.")
    parser.add_argument("--max_seq_len", default=35, type=int, help="The maximum total input sequence length after tokenization.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
#     parser.add_argument("--num_train_epochs", default=10.0, type=float, help="Total number of training epochs to perform.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.")
                                            #####

    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1, type=float, help="Max gradient norm.")
    parser.add_argument("--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument("--dropout_rate", default=0.1, type=float, help="Dropout for fully-connected layers")
    parser.add_argument('--logging_steps', type=int, default=500, help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=300, help="Save checkpoint every X updates steps.")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the test set.")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--ignore_index", default=0, type=int,
                        help='Specifies a target value that is ignored and does not contribute to the input gradient')
    parser.add_argument('--slot_loss_coef', type=float, default=2.0, help='Coefficient for the slot loss.')
    parser.add_argument('--tag_intent_coef', type=float, default=1.0, help='Coefficient for the tag intent loss')

    # CRF option
    parser.add_argument("--use_crf", action="store_true", help="Whether to use CRF")
    parser.add_argument("--slot_pad_label", default="PAD", type=str, help="Pad token for slot label pad (to be ignore when calculate loss)")
    parser.add_argument("--patience", default=0, type=int, help="The initial learning rate for Adam.")

    parser.add_argument('-f')#########################

    args = parser.parse_args()

    args.model_name_or_path = MODEL_PATH_MAP[args.model_type]
    args.model_name_or_path = MODEL_PATH_MAP[args.model_type]

    tokenizer = load_tokenizer(args)


In [4]:
from data_loader import load_and_cache_examples, processors

train_dataset = load_and_cache_examples(args, tokenizer, mode="train")
dev_dataset = load_and_cache_examples(args, tokenizer, mode="dev")
test_dataset = load_and_cache_examples(args, tokenizer, mode="test")

In [5]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm, trange
import numpy as np

config_class, model_class, _ = MODEL_CLASSES[args.model_type]

config = config_class.from_pretrained(args.model_name_or_path, finetuning_task=args.task)

# model = model_class.from_pretrained(args.model_name_or_path,
#                                               config=config,
#                                               args=args,
#                                               intent_label_lst=get_intent_labels(args),
#                                               slot_label_lst=get_slot_labels(args))

model = JointBERTMultiIntent(config, args, get_intent_labels(args), get_slot_labels(args))
model.to('cuda')

JointBERTMultiIntent(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affi

In [6]:
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
train_iterator = trange(int(args.num_train_epochs), desc="Epoch")

for _ in train_iterator:
    #epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=False )
    for step, batch in enumerate(train_dataloader):
        print('train:  ')
        model.train()
        batch = tuple(t.to('cuda') for t in batch)  # GPU or CPU

        inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1]}#,
                  # 'intent_label_ids': batch[3],
                  # 'slot_labels_ids': batch[4],
                  # 'intent_token_ids': batch[5],
                  # 'B_tag_mask': batch[6],
                  # 'BI_tag_mask': batch[7],
                  # 'tag_intent_label': batch[8]}
        if args.model_type != 'distilbert':
            inputs['token_type_ids'] = batch[2]
            outputs = model(**inputs)

            losses = outputs[0]
            loss = losses[0]
            intent_loss = losses[1]
            slot_loss = losses[2]
            intent_token_loss = losses[3]
            tag_intent_loss = losses[4]
            #print(tag_intent_loss)
            break


Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

train:  


Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

(64, 35)





AttributeError: 'NoneType' object has no attribute 'shape'