# In [None]:
import time
import torch
from www.utils import format_time
import numpy as np
from transformers import RobertaForMultipleChoice
import progressbar
from www.model.eval import evaluate_tiered
from sklearn.metrics import accuracy_score, f1_score



# Training utilities: a weighted loss-combination helper, followed by a routine that trains a PyTorch model for one epoch
def ComputeLoss(out1, out2, loss_weights=None, attribute_loss_scale=20):
    """Combine the per-task losses of the two pipeline stages into one value.

    Args:
        out1: Mapping providing 'loss_preconditions' and 'loss_effects'.
        out2: Mapping providing 'loss_conflicts' and 'loss_stories'.
        loss_weights: Optional sequence of at least five weights. Indices 1-4
            weight the precondition, effect, conflict and story losses
            respectively; index 0 is not read here. Defaults to
            [0.0, 0.4, 0.4, 0.2, 0.0].
        attribute_loss_scale: Divisor applied to the precondition and effect
            losses. Defaults to 20.
            NOTE(review): 20 is presumably the number of attributes - confirm.

    Returns:
        The weighted sum of the four losses.

    Raises:
        ValueError: If ``loss_weights`` has fewer than five entries.
    """
    if loss_weights is None:
        loss_weights = (0.0, 0.4, 0.4, 0.2, 0.0)
    elif len(loss_weights) < 5:
        raise ValueError('loss_weights must contain at least 5 entries')

    total_loss = 0.0
    total_loss += loss_weights[1] * out1['loss_preconditions'] / attribute_loss_scale
    total_loss += loss_weights[2] * out1['loss_effects'] / attribute_loss_scale
    total_loss += loss_weights[3] * out2['loss_conflicts']
    total_loss += loss_weights[4] * out2['loss_stories']

    return total_loss
def train_epoch(model,
                optimizer,
                train_dataloader,
                device,
                list_output=False,
                num_outputs=1,
                span_mode=False,
                seg_mode=False,
                classifier=None,
                multitask_idx=None):
    """Train ``model`` for one pass over ``train_dataloader``.

    Each batch is indexed positionally: ``batch[0]`` input ids, ``batch[1]``
    attention mask, ``batch[2]`` labels and, when ``span_mode`` or
    ``seg_mode`` is set, ``batch[3]`` spans or segment ids respectively
    (``span_mode`` takes precedence).

    Args:
        model: Module called as ``model(input_ids, token_type_ids=...,
            attention_mask=..., labels=...)`` plus ``spans`` / ``task_idx``
            when applicable. Unless ``classifier`` is given, ``out[0]`` of
            its output is used as the loss.
        optimizer: Optimizer stepped once per batch.
        train_dataloader: Sized iterable of batches exposing ``batch_size``.
        device: Device the batch tensors are moved to.
        list_output: If True, the loss is indexable and each of its
            ``num_outputs`` entries is back-propagated separately.
        num_outputs: Number of loss entries when ``list_output`` is True.
        span_mode: Pass ``batch[3]`` to the model as ``spans``.
        seg_mode: Pass ``batch[3]`` to the model as ``token_type_ids``.
        classifier: Optional callable applied to the model output to obtain
            logits; the loss is then computed here instead of read from the
            model output.
        multitask_idx: If not None, forwarded to the model as ``task_idx``.

    Returns:
        ``(average_loss, model)``; ``average_loss`` is a list of per-output
        averages when ``list_output`` is True.
    """
    t0 = time.time()

    total_loss = [0 for _ in range(num_outputs)] if list_output else 0

    # Training mode
    model.train()

    num_batches = len(train_dataloader)
    # Only print progress for larger datasets.
    progress_update = num_batches * train_dataloader.batch_size >= 2500

    for step, batch in enumerate(train_dataloader):
        if progress_update and step % 50 == 0 and step != 0:
            elapsed = format_time(time.time() - t0)
            print('\t(%s) Starting batch %s of %s.' %
                  (elapsed, str(step), str(num_batches)))

        input_ids = batch[0].to(device)
        input_mask = batch[1].to(device)
        labels = batch[2].to(device)

        # Assemble the keyword arguments once instead of repeating the call
        # for every span/seg/multitask combination.
        forward_kwargs = {
            'token_type_ids': None,
            'attention_mask': input_mask,
            'labels': labels,
        }
        if span_mode:
            # A span per sequence lets the model classify only part of the input.
            forward_kwargs['spans'] = batch[3].to(device)
        elif seg_mode:
            forward_kwargs['token_type_ids'] = batch[3].to(device)
        if multitask_idx is not None:
            forward_kwargs['task_idx'] = multitask_idx

        # Forward pass
        model.zero_grad()
        out = model(input_ids, **forward_kwargs)

        if classifier is not None:
            # NOTE(review): the classifier receives the full model output, as
            # in the original code (which extracted out[0] but never used it)
            # - confirm that this is the intended input.
            logits = classifier(out)
            # The original read `self.num_labels`, which does not exist in
            # this function. Prefer an explicit attribute on the classifier
            # and otherwise infer the label count from the logits' last
            # dimension (assumes logits are (..., num_labels)).
            num_labels = getattr(classifier, 'num_labels', logits.shape[-1])
            if num_labels == 1:
                # Regression
                loss = torch.nn.MSELoss()(logits.view(-1), labels.view(-1))
            else:
                loss = torch.nn.CrossEntropyLoss()(
                    logits.view(-1, num_labels), labels.view(-1))
        else:
            loss = out[0]

        # Backward pass
        if list_output:
            for o in range(num_outputs):
                total_loss[o] += loss[o].item()
                loss[o].backward()
        else:
            total_loss += loss.item()
            loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                       1.0)  # Gradient clipping

        optimizer.step()

    if list_output:
        return list(np.array(total_loss) / num_batches), model
    return total_loss / num_batches, model


# Train a state classification pipeline for one epoch
def _learning_curve_record(losses, epoch, iteration, num_attributes):
    """Build one learning-curve entry from a mapping of loss tensors."""

    def as_float(key):
        return float(losses[key].detach().cpu().numpy())

    return {
        'epoch': epoch,
        'iteration': iteration,
        'loss_preconditions': as_float('loss_preconditions') / num_attributes,
        'loss_effects': as_float('loss_effects') / num_attributes,
        'loss_conflicts': as_float('loss_conflicts'),
        'loss_stories': as_float('loss_stories'),
        'loss_total': as_float('total_loss'),
    }


def train_epoch_tiered(model1,
                       model2,
                       optimizer,
                       train_dataloader,
                       device,
                       seg_mode=False,
                       return_losses=False,
                       build_learning_curves=False,
                       val_dataloader=None,
                       train_lc_data=None,
                       val_lc_data=None):
    """Train the two-stage state classification pipeline for one epoch.

    ``model1`` is run first and provides 'out', 'out_preconditions_softmax'
    and 'out_effects_softmax'; these are fed to ``model2`` together with the
    conflict and story labels. The two output dicts are merged and their
    losses combined with ``ComputeLoss``.

    NOTE(review): this function reads a name ``model`` that is NOT one of its
    parameters (``model.zero_grad()``, gradient clipping, ``num_attributes``,
    the validation call and the return value). It therefore depends on a
    module/notebook-level global of that name. It looks like a leftover from
    a single-model version; decide what it should refer to now that the
    pipeline is split into ``model1``/``model2``.

    Args:
        model1: First-stage module; must expose ``precondition_classifiers``
            and ``effect_classifiers``.
        model2: Second-stage module.
        optimizer: Optimizer stepped once per batch.
        train_dataloader: Sized iterable of batches indexed positionally
            (0 input ids, 1 lengths, 2 entities, 3 attention mask,
            4 attributes, 5 preconditions, 6 effects, 7 conflicts, 8 labels,
            9 segment ids when ``seg_mode`` is set).
        device: Device the batch tensors are moved to.
        seg_mode: If True, pass segment ids as ``token_type_ids``.
        return_losses: Accepted for interface compatibility; not used.
        build_learning_curves: If True, append a training record every step
            and a validation record roughly five times per epoch.
        val_dataloader: Validation data; required for learning curves.
        train_lc_data: List of per-epoch record lists; a new epoch list is
            appended on entry when not None.
        val_lc_data: Same as ``train_lc_data`` for validation records.

    Returns:
        ``(average_training_loss, model)``.

    Raises:
        ValueError: If ``build_learning_curves`` is True but the validation
            dataloader or one of the record lists is missing.
    """
    if build_learning_curves and (train_lc_data is None
                                  or val_lc_data is None
                                  or val_dataloader is None):
        # Fail before training rather than partway through the first epoch.
        raise ValueError('build_learning_curves=True requires val_dataloader, '
                         'train_lc_data and val_lc_data')

    def set_train_mode():
        model1.train()
        model2.train()
        for layer in model1.precondition_classifiers:
            layer.train()
        for layer in model1.effect_classifiers:
            layer.train()

    total_loss = 0
    set_train_mode()

    num_batches = len(train_dataloader)
    bar = progressbar.ProgressBar(max_value=num_batches,
                                  widgets=[
                                      progressbar.Bar('#', '[', ']'), ' ',
                                      progressbar.Percentage()
                                  ])
    bar.start()

    if train_lc_data is not None:
        train_lc_data.append([])
    if val_lc_data is not None:
        val_lc_data.append([])

    # Validate about five times per epoch. Clamp to 1 so that an epoch with
    # fewer than five batches does not hit a modulo-by-zero.
    chunk_size = max(1, num_batches // 5)

    for step, batch in enumerate(train_dataloader):
        input_ids = batch[0].long().to(device)
        input_lengths = batch[1].to(device)
        input_entities = batch[2].to(device)
        input_mask = batch[3].to(device)
        attributes = batch[4].long().to(device)
        preconditions = batch[5].long().to(device)
        effects = batch[6].long().to(device)
        conflicts = batch[7].long().to(device)
        labels = batch[8].long().to(device)

        # Slot 8 holds the labels, so segment ids are read from slot 9, the
        # same slot evaluate_tiered uses.
        # NOTE(review): confirm the training dataloader provides slot 9.
        segment_ids = batch[9].to(device) if seg_mode else None

        # Forward pass
        model.zero_grad()
        out_1 = model1(input_ids,
                       input_lengths,
                       input_entities,
                       attention_mask=input_mask,
                       token_type_ids=segment_ids,
                       attributes=attributes,
                       preconditions=preconditions,
                       effects=effects,
                       training=True)

        out_2 = model2(input_ids,
                       input_lengths,
                       input_entities,
                       out=out_1['out'],
                       attention_mask=input_mask,
                       token_type_ids=segment_ids,
                       attributes=attributes,
                       out_preconditions_softmax=out_1['out_preconditions_softmax'],
                       out_effects_softmax=out_1['out_effects_softmax'],
                       conflicts=conflicts,
                       labels=labels,
                       training=True)

        # Second-stage entries win on key collisions, as before.
        out = {**out_1, **out_2}
        out['total_loss'] = ComputeLoss(out_1, out_2)
        loss = out['total_loss']

        # Backward pass
        total_loss += loss.item()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                       1.0)  # Gradient clipping

        optimizer.step()

        # Build learning curve data if needed
        if build_learning_curves:
            train_lc_data[-1].append(
                _learning_curve_record(
                    out,
                    epoch=len(train_lc_data) - 1,
                    iteration=(len(train_lc_data) - 1) * num_batches + step,
                    num_attributes=model.num_attributes))

            if (num_batches - step - 1) % chunk_size == 0:
                # NOTE(review): this positional call (model, dataloader,
                # device, metrics) does not match the evaluate_tiered defined
                # later in this file, which expects (MaxStoryLength,
                # tslm_model, trip_model, eval_dataloader, device, metrics).
                validation_results = evaluate_tiered(
                    model,
                    val_dataloader,
                    device, [(accuracy_score, 'accuracy'), (f1_score, 'f1')],
                    seg_mode=False,
                    return_explanations=True,
                    return_losses=True,
                    verbose=False)
                # With return_losses=True the aggregated losses are the last
                # element of the result (index 16 for this flag combination).
                val_losses = validation_results[-1]

                val_lc_data[-1].append(
                    _learning_curve_record(
                        val_losses,
                        epoch=len(val_lc_data) - 1,
                        iteration=(len(val_lc_data) - 1) * num_batches + step,
                        num_attributes=model.num_attributes))

                # Evaluation switches modules to eval mode; switch back so the
                # remaining batches are trained in training mode.
                set_train_mode()

        bar.update(step + 1)

    bar.finish()

    return total_loss / num_batches, model

# In [None]:
def _concat_or_none(chunks):
    """Join per-batch arrays along axis 0; return None if none were collected."""
    return np.concatenate(chunks, axis=0) if chunks else None


def evaluate_tiered(MaxStoryLength, tslm_model, trip_model, eval_dataloader, device, metrics, seg_mode=False, return_softmax=False, return_explanations=False, return_losses=False, verbose=True):
    """Evaluate the tiered pipeline on ``eval_dataloader``.

    For every batch the entity classifier is run first to obtain precondition
    and effect predictions plus embeddings; these are passed to
    ``trip_model``, whose outputs are collected together with the ground
    truth and scored with ``compute_metrics`` and ``verifiable_reasoning``.

    NOTE(review): ``tslm_entity_classifier``, ``tslmclassifier``,
    ``att_to_num_classes``, ``tslm_optimizer``, ``compute_metrics`` and
    ``verifiable_reasoning`` are not parameters and are not imported in the
    visible part of the file; they are taken from module/notebook scope.
    Passing an optimizer into an evaluation step looks suspicious - confirm
    that ``tslm_entity_classifier`` does not update weights here.

    Args:
        MaxStoryLength: Not used by this function.
        tslm_model: Module that is zero-grad'ed and put in eval mode.
            NOTE(review): the forward pass uses the global ``tslmclassifier``
            instead - presumably the same object; verify.
        trip_model: Module exposing ``precondition_classifiers``,
            ``effect_classifiers``, ``ablation`` and ``num_attributes``.
        eval_dataloader: Sized iterable of batches indexed positionally
            (0 input ids, 1 lengths, 2 entities, 3 attention mask,
            4 attributes, 5 preconditions, 6 effects, 7 conflicts, 8 labels,
            9 timestep type ids, also used as segment ids in ``seg_mode``).
        device: Device the batch tensors are moved to.
        metrics: Metric specification forwarded to ``compute_metrics``.
        seg_mode: If True, pass ``batch[9]`` as ``token_type_ids``.
        return_softmax: Append story-choice probabilities to the result.
        return_explanations: Append the explanations to the result.
        return_losses: Append the per-batch-averaged losses to the result.
        verbose: Print progress messages and show a progress bar.

    Returns:
        Tuple of (metr_attr, pred_attributes, attributes, metr_prec,
        pred_prec, prec, metr_eff, pred_eff, eff, metr_conflicts,
        pred_conflicts, conflicts, metr_stories, pred_stories, stories),
        followed, in this order and only when requested, by the story
        probabilities, the explanations and the aggregated losses. The
        attribute entries are None when 'attributes' is ablated.
    """
    if verbose:
        print('\tBeginning evaluation...')

    t0 = time.time()

    tslm_model.zero_grad()
    trip_model.zero_grad()
    tslm_model.eval()
    trip_model.eval()
    for layer in trip_model.precondition_classifiers:
        layer.eval()
    for layer in trip_model.effect_classifiers:
        layer.eval()

    use_attributes = 'attributes' not in trip_model.ablation

    # Per-batch arrays are collected in lists and concatenated once after the
    # loop, which avoids re-copying the growing result on every batch.
    pred_attr_chunks, attr_chunks = [], []
    pred_prec_chunks, prec_chunks = [], []
    pred_eff_chunks, eff_chunks = [], []
    pred_conflict_chunks, conflict_chunks = [], []
    pred_story_chunks, story_chunks = [], []
    prob_story_chunks = []

    if verbose:
        print('\t\tRunning prediction...')
        bar = progressbar.ProgressBar(max_value=len(eval_dataloader), widgets=[progressbar.Bar('#', '[', ']'), ' ', progressbar.Percentage()])
        bar.start()

    # Aggregate losses
    agg_losses = {}

    # Get preds from model
    for batch_idx, batch in enumerate(eval_dataloader, start=1):
        # Move to the target device once
        batch = tuple(t.to(device) for t in batch)

        input_ids = batch[0].long()
        input_lengths = batch[1]
        input_entities = batch[2]
        input_mask = batch[3]
        attributes = batch[4].long()
        preconditions = batch[5].long()
        effects = batch[6].long()
        conflicts = batch[7].long()
        labels = batch[8].long()
        timestep_type_ids = batch[9].long()
        # NOTE(review): segment ids and timestep type ids share slot 9.
        segment_ids = batch[9] if seg_mode else None

        # Both forward passes run without gradient tracking; nothing below
        # back-propagates, and the accumulated losses would otherwise keep
        # every batch's graph alive.
        with torch.no_grad():
            (prec_result, effect_result, prec_pred, effect_pred,
             embedding_result, total_loss_pre,
             total_loss_effect) = tslm_entity_classifier(
                 tslmclassifier, input_ids, input_mask, timestep_type_ids,
                 preconditions, effects, att_to_num_classes, tslm_optimizer)

            # NOTE(review): training=True is passed during evaluation -
            # presumably so that losses are computed; confirm.
            out = trip_model(embedding_result,
                             input_ids.shape,
                             input_lengths,
                             input_entities,
                             out_preconditions=prec_pred,
                             out_preconditions_softmax=prec_result,
                             out_effects=effect_pred,
                             out_effects_softmax=effect_result,
                             attention_mask=input_mask,
                             token_type_ids=segment_ids,
                             attributes=attributes,
                             preconditions=preconditions,
                             effects=effects,
                             conflicts=conflicts,
                             labels=labels,
                             training=True)

        out['loss_preconditions'] = total_loss_pre
        out['loss_effects'] = total_loss_effect
        out['total_loss'] = (out['loss_stories'] + out['loss_conflicts']
                             + total_loss_pre + total_loss_effect)

        if return_losses:
            for key, value in out.items():
                if 'loss' in key:
                    # Out-of-place add so the first batch's tensor is not mutated.
                    agg_losses[key] = agg_losses[key] + value if key in agg_losses else value

        # Get gt/predicted attributes
        if use_attributes:
            attr_chunks.append(attributes.view(-1, attributes.shape[-1]).to('cpu').numpy())
            attr_preds = out['out_attributes'].detach().cpu().numpy()
            attr_preds[attr_preds >= 0.5] = 1
            attr_preds[attr_preds < 0.5] = 0
            pred_attr_chunks.append(attr_preds)

        # Get gt/predicted preconditions
        prec_chunks.append(preconditions.view(-1, preconditions.shape[-1]).to('cpu').numpy())
        pred_prec_chunks.append(out['out_preconditions'].detach().cpu().numpy())

        # Get gt/predicted effects
        eff_chunks.append(effects.view(-1, effects.shape[-1]).to('cpu').numpy())
        pred_eff_chunks.append(out['out_effects'].detach().cpu().numpy())

        # Get gt/predicted conflict points (predictions thresholded at 0.5)
        conflict_chunks.append(conflicts.to('cpu').numpy())
        conflict_preds = out['out_conflicts'].detach().cpu().numpy()
        conflict_preds[conflict_preds < 0.5] = 0.
        conflict_preds[conflict_preds >= 0.5] = 1.
        pred_conflict_chunks.append(conflict_preds)

        # Get gt/predicted story choices
        story_chunks.append(labels.to('cpu').numpy())
        pred_story_chunks.append(torch.argmax(out['out_stories'], dim=-1).detach().cpu().numpy())
        if return_softmax:
            prob_story_chunks.append(torch.softmax(out['out_stories'], dim=-1).detach().cpu().numpy())

        if verbose:
            bar.update(batch_idx)
    if verbose:
        bar.finish()

    all_pred_attributes = _concat_or_none(pred_attr_chunks)
    all_attributes = _concat_or_none(attr_chunks)
    all_pred_prec = _concat_or_none(pred_prec_chunks)
    all_prec = _concat_or_none(prec_chunks)
    all_pred_eff = _concat_or_none(pred_eff_chunks)
    all_eff = _concat_or_none(eff_chunks)
    all_pred_conflicts = _concat_or_none(pred_conflict_chunks)
    all_conflicts = _concat_or_none(conflict_chunks)
    all_pred_stories = _concat_or_none(pred_story_chunks)
    all_stories = _concat_or_none(story_chunks)
    if return_softmax:
        all_prob_stories = _concat_or_none(prob_story_chunks)

    # Calculate metrics
    if verbose:
        print('\t\tComputing metrics...')

    # Overall metrics and per-category metrics for attributes, preconditions, and effects
    # NOTE: there are a lot of extra negative examples due to padding along the sentence and entity dimensions. This can't affect F1, but will affect accuracy and make it disproportionately large.
    metr_attr = None
    if use_attributes:
        metr_attr = compute_metrics(all_pred_attributes.flatten(), all_attributes.flatten(), metrics)
        for i in range(trip_model.num_attributes):
            metr_i = compute_metrics(all_pred_attributes[:, i], all_attributes[:, i], metrics)
            for k in metr_i:
                metr_attr['%s_%s' % (str(k), str(i))] = metr_i[k]

    metr_prec = compute_metrics(all_pred_prec.flatten(), all_prec.flatten(), metrics)
    for i in range(trip_model.num_attributes):
        metr_i = compute_metrics(all_pred_prec[:, i], all_prec[:, i], metrics)
        for k in metr_i:
            metr_prec['%s_%s' % (str(k), str(i))] = metr_i[k]

    metr_eff = compute_metrics(all_pred_eff.flatten(), all_eff.flatten(), metrics)
    for i in range(trip_model.num_attributes):
        metr_i = compute_metrics(all_pred_eff[:, i], all_eff[:, i], metrics)
        for k in metr_i:
            metr_eff['%s_%s' % (str(k), str(i))] = metr_i[k]

    # Conflict span metrics
    metr_conflicts = compute_metrics(all_pred_conflicts.flatten(), all_conflicts.flatten(), metrics)

    metr_stories = compute_metrics(all_pred_stories.flatten(), all_stories.flatten(), metrics)

    verifiability, explanations = verifiable_reasoning(all_stories, all_pred_stories, all_conflicts, all_pred_conflicts, all_prec, all_pred_prec, all_eff, all_pred_eff, return_explanations=True)
    metr_stories['verifiability'] = verifiability

    if verbose:
        print('\tFinished evaluation in %ss.' % str(format_time(time.time() - t0)))

    return_base = [metr_attr, all_pred_attributes, all_attributes, metr_prec, all_pred_prec, all_prec, metr_eff, all_pred_eff, all_eff, metr_conflicts, all_pred_conflicts, all_conflicts, metr_stories, all_pred_stories, all_stories]
    if return_softmax:
        return_base += [all_prob_stories]
    if return_explanations:
        return_base += [explanations]
    if return_losses:
        # Report the mean loss per batch.
        num_batches = len(eval_dataloader)
        agg_losses = {k: v / num_batches for k, v in agg_losses.items()}
        return_base += [agg_losses]

    return tuple(return_base)
