In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
## Imports

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence, pad_packed_sequence
import time
import json
import random
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report
import string
from collections import defaultdict
import os

SEED = 42

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)


In [3]:
from model import LSTM_Emitter, CRF, Hier_LSTM_CRF_Classifier
from training import batchify, train_step, val_step, statistics, learn
from data import prepare_data_new

In [4]:
class Args:
    pretrained = True
    data_path = '../../data/bert_sentence_independent_embeddings' ## Input to the pre-trained embedding(should contain 4 sub-folders, IT test and train, CL test and train)
    save_path = './save/' ## path to save the model
    device = 'cpu' ## device to be used
    batch_size = 40 ## batch size
    print_every = 1 ## print loss after these many epochs
    lr = 0.01 ## learning rate
    reg = 0 ## weight decay for Adam Opt
    emb_dim = 768 ## the pre-trained embedding dimension of the sentences
    hidden_dim = 384
    epochs = 300 ## Something between 250-300

In [5]:
args = Args()
#np.warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning) 

## creating a directory to save models and other utility files
!mkdir './save/'

mkdir: ./save/: File exists


## Training Model on IT cases 

In [6]:
import pickle
from pathlib import Path

def load_documents(file_path):
        with open(file_path, "rb") as file:
            data_dict = pickle.load(file)
        return data_dict


## path for the training and testing files
main_input_path = '../../data/bert_sentence_independent_embeddings'

path_cl_train = Path(main_input_path,'CL_train.pkl')
path_cl_dev = Path(main_input_path,'CL_dev.pkl')
path_cl_test = Path(main_input_path,'CL_test.pkl')

path_it_train = Path(main_input_path,'IT_train.pkl')
path_it_dev = Path(main_input_path,'IT_dev.pkl')
path_it_test = Path(main_input_path,'IT_test.pkl')

train_data_dict = load_documents(path_cl_train)
dev_data_dict = load_documents(path_cl_train)
test_data_dict = load_documents(path_cl_train)

In [11]:
## Preparing data and Training Model for IT cases, similarly can be run for IT+CL and CL

print('\nPreparing data ...', end = ' ')

x_it_train, y_it_train, x_it_dev, y_it_dev, x_it_test, y_it_test, word2idx_it, tag2idx_it = prepare_data_new(train_data_dict, dev_data_dict, test_data_dict, args, args.emb_dim)

print('Done')


Preparing data ... Done


In [18]:
len(x_it_train)

40

In [12]:
# use loaded tag2idx
def load_config(path):
    # Load and parse the config file
    with open(path, 'r') as f:
        file = json.load(f)
    return file

#tag2idx_it = load_config('./save/tag2idx.json')

In [13]:
print('#Tags IT:', len(tag2idx_it))
print('Dump word2idx and tag2idx')
#with open(args.save_path + 'word2idx.json', 'w') as fp:
#    json.dump(word2idx_it, fp)
#with open(args.save_path + 'tag2idx.json', 'w') as fp:
#    json.dump(tag2idx_it, fp)

#Tags IT: 14
Dump word2idx and tag2idx


In [14]:
tag2idx_it

defaultdict(<function data.prepare_data_new.<locals>.<lambda>()>,
            {'<pad>': 0,
             '<start>': 1,
             '<end>': 2,
             'Fact': 3,
             'ArgumentPetitioner': 4,
             'ArgumentRespondent': 5,
             'RatioOfTheDecision': 6,
             'RulingByPresentCourt': 7,
             'RulingByLowerCourt': 8,
             'PrecedentNotReliedUpon': 9,
             'PrecedentReliedUpon': 10,
             'Statute': 11,
             'Issue': 12,
             'Dissent': 13})

In [15]:
print('\nInitializing model for IT ...', end = ' ')   
model = Hier_LSTM_CRF_Classifier(len(tag2idx_it), args.emb_dim, args.hidden_dim, tag2idx_it['<start>'], tag2idx_it['<end>'], tag2idx_it['<pad>'],args.device).to(args.device)
print('Done')


Initializing model for IT ... Done


In [16]:
args.epochs

300

In [17]:
print('\nEvaluating on test...')        
learn(model, x_it_train, y_it_train, x_it_dev, y_it_dev, x_it_test, y_it_test, tag2idx_it, args)


Evaluating on test...
  EPOCH     Tr_LOSS   Tr_F1    Val_LOSS  Val_F1
-----------------------------------------------------------
      1   26445.748   0.014   27805.688   0.039
      2   28224.742   0.063   24352.318   0.021
      3   24833.088   0.060   20475.387   0.040
      4   21019.412   0.069   20741.848   0.016
      5   21271.373   0.047   21272.195   0.040
      6   21770.471   0.074   20584.617   0.040
      7   21057.045   0.064   20119.805   0.040
      8   20476.148   0.067   19633.656   0.040
      9   20015.498   0.065   19159.668   0.040
     10   19536.176   0.061   19177.258   0.040
     11   19615.777   0.070   19255.842   0.022
     12   19663.312   0.066   18923.785   0.040
     13   19322.957   0.057   18394.773   0.040
     14   18741.980   0.057   18114.574   0.040
     15   18472.547   0.060   18113.594   0.040
     16   18433.490   0.056   18080.799   0.039
     17   18333.457   0.043   17839.133   0.040
     18   18130.461   0.052   17518.986   0.075
     

    169    7589.035   0.133    7153.213   0.069
    170    7259.565   0.087    6333.454   0.189
    171    6301.384   0.176    6874.673   0.128
    172    7187.184   0.100    7141.773   0.128
    173    7149.940   0.131    6770.953   0.131
    174    6957.550   0.143    7686.480   0.094
    175    7614.180   0.152    8039.624   0.091
    176    8024.770   0.119    6946.123   0.167
    177    7644.292   0.120    7458.808   0.133
    178    7110.850   0.155    6090.658   0.184
    179    6396.981   0.229    6523.886   0.169
    180    6691.164   0.163    6418.642   0.111
    181    6328.525   0.116    6052.525   0.081
    182    5906.412   0.096    5487.561   0.118
    183    5660.809   0.129    5925.110   0.061
    184    5848.146   0.114    5841.479   0.092
    185    5875.590   0.136    5515.171   0.150
    186    5857.179   0.153    5428.770   0.126
    187    6119.575   0.136    5718.940   0.184
    188    6010.889   0.167    5757.693   0.164
    189    6275.945   0.205    5612.311 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                        precision    recall  f1-score   support

                  Fact      0.571     0.577     0.574      2979
    ArgumentPetitioner      0.026     0.002     0.004       976
    ArgumentRespondent      0.107     0.221     0.144       996
    RatioOfTheDecision      0.386     0.521     0.443      2856
  RulingByPresentCourt      0.946     0.310     0.467       113
    RulingByLowerCourt      0.000     0.000     0.000       438
PrecedentNotReliedUpon      0.000     0.000     0.000       209
   PrecedentReliedUpon      0.315     0.321     0.318      1305
               Statute      0.939     0.463     0.620       367
                 Issue      0.000     0.000     0.000        17
               Dissent      0.000     0.000     0.000       284

              accuracy                          0.385     10540
             macro avg      0.299     0.220     0.234     10540
          weighted avg      0.360     0.385     0.362     10540



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
