In [28]:
import sys
import os
import json
import torch
import pickle
sys.path.append('/home/tom/uni_work/legaleval-subtask-a')

from collections import Counter
from models.CNN_BiLSTM import CNN_BiLSTM
from utils import sent2embeddings, label_encode
import torch.nn as nn

# Data

In [20]:
# Load the annotated training set. The encoding is pinned to UTF-8 so the
# result does not depend on the platform's default locale encoding.
with open("../data/train.json", encoding="utf-8") as json_file:
    data = json.load(json_file)

In [21]:
# Report how many documents were loaded from the training file.
document_total = len(data)
print(f"Number of documents : {document_total}")

Number of documents : 247


#### Looping through each document

In [22]:
# Walk every document once: report its sentence count, accumulate the running
# total, and remember the document's law group for the next cell.
group = []
total_sentences = 0
for index, doc in enumerate(data):
    print(f"Document number : {index}")
    sentence_count = len(doc['annotations'][0]['result'])
    print(f"Number of sentences : {sentence_count}")
    total_sentences += sentence_count
    group.append(doc['meta']['group'])

# Mean number of annotated sentences per document.
print(f"Avg sentences = {total_sentences/len(data):.2f}")

Document number : 0
Number of sentences : 91
Document number : 1
Number of sentences : 72
Document number : 2
Number of sentences : 200
Document number : 3
Number of sentences : 119
Document number : 4
Number of sentences : 184
Document number : 5
Number of sentences : 211
Document number : 6
Number of sentences : 140
Document number : 7
Number of sentences : 87
Document number : 8
Number of sentences : 228
Document number : 9
Number of sentences : 99
Document number : 10
Number of sentences : 62
Document number : 11
Number of sentences : 213
Document number : 12
Number of sentences : 111
Document number : 13
Number of sentences : 199
Document number : 14
Number of sentences : 188
Document number : 15
Number of sentences : 271
Document number : 16
Number of sentences : 43
Document number : 17
Number of sentences : 82
Document number : 18
Number of sentences : 171
Document number : 19
Number of sentences : 149
Document number : 20
Number of sentences : 95
Document number : 21
Number of 

In [10]:
# Show the distinct law groups collected from the documents.
distinct_groups = set(group)
print(f"Number of law groups : \n{distinct_groups}")

Number of law groups : 
{'Criminal', 'Tax'}


#### Looping through each document + sentence

In [11]:
# Collect the first label of every annotated sentence and report, per
# document, its law type and the average sentence length in characters.
labels = []
for index, doc in enumerate(data):
    char_count = 0
    # Use a dedicated name for the law type: re-binding `group` here would
    # overwrite the list of groups that other cells read.
    law_type = doc['meta']['group']
    print(f"Document number : {index:<6} Law Type : {law_type}")
    sentences = doc['annotations'][0]['result']
    if sentences:
        for sentence_data in sentences:
            char_count += len(sentence_data['value']['text'])
            # Only the first label of each sentence is kept.
            labels.append(sentence_data['value']['labels'][0])
        print(f"Avg number of chars : {char_count/len(sentences):.2f}")
    else:
        # No annotated sentences: skip the average to avoid dividing by zero.
        print(f"{'Document is empty':-^60}")

Document number : 0      Law Type : Criminal
Avg number of chars : 180.81
Document number : 1      Law Type : Tax
Avg number of chars : 153.39
Document number : 2      Law Type : Criminal
Avg number of chars : 193.93
Document number : 3      Law Type : Tax
Avg number of chars : 222.87
Document number : 4      Law Type : Tax
Avg number of chars : 183.04
Document number : 5      Law Type : Criminal
Avg number of chars : 143.89
Document number : 6      Law Type : Tax
Avg number of chars : 202.29
Document number : 7      Law Type : Criminal
Avg number of chars : 118.03
Document number : 8      Law Type : Criminal
Avg number of chars : 140.69
Document number : 9      Law Type : Criminal
Avg number of chars : 181.81
Document number : 10     Law Type : Criminal
Avg number of chars : 153.18
Document number : 11     Law Type : Tax
Avg number of chars : 177.47
Document number : 12     Law Type : Tax
Avg number of chars : 156.68
Document number : 13     Law Type : Criminal
Avg number of chars : 1

In [12]:
# Show the distinct labels collected from the sentences.
distinct_labels = set(labels)
print(distinct_labels)

{'ANALYSIS', 'PRE_RELIED', 'RPC', 'RATIO', 'PRE_NOT_RELIED', 'NONE', 'FAC', 'ARG_RESPONDENT', 'RLC', 'PREAMBLE', 'ARG_PETITIONER', 'STA', 'ISSUE'}


In [13]:
# Frequency of each label across all collected sentences.
label_frequencies = Counter(labels)
label_frequencies

Counter({'PREAMBLE': 4167,
         'NONE': 1423,
         'FAC': 5744,
         'ARG_RESPONDENT': 698,
         'RLC': 752,
         'ARG_PETITIONER': 1315,
         'ANALYSIS': 10695,
         'PRE_RELIED': 1431,
         'RATIO': 674,
         'RPC': 1081,
         'ISSUE': 367,
         'STA': 481,
         'PRE_NOT_RELIED': 158})

# Training on a single document

In [23]:
# Build the label encoder from the unique labels. The labels are sorted first
# because the iteration order of a set of strings can differ between
# interpreter sessions (hash randomisation); sorting hands the encoder the
# same class order on every kernel restart.
# NOTE(review): assumes label_encode accepts any list of label strings - confirm in utils.
label_encoder = label_encode(sorted(set(labels)))

In [24]:
# Instantiate the network, attach an Adam optimizer to its parameters, and
# print both so the architecture and hyper-parameters are visible.
learning_rate = 0.005
model = CNN_BiLSTM()
model_optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for component in (model, model_optimizer):
    print(component)

CNN_BiLSTM(
  (word_conv): Conv2d(1, 1, kernel_size=(5, 1), stride=(1, 1))
  (word_max_pool): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
  (sent_conv): Conv2d(3, 1, kernel_size=(1, 1), stride=(1, 1))
  (bilstm): LSTM(768, 256, num_layers=2, batch_first=True, bidirectional=True)
  (fc1): Linear(in_features=512, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=13, bias=True)
  (relu): ReLU()
  (softmax): Softmax(dim=1)
)
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.005
    maximize: False
    weight_decay: 0
)


In [25]:
# Use the sentence annotations of the first document as the training sample
# and display how many sentences it holds.
first_document = data[0]
train_document = first_document['annotations'][0]['result']
len(train_document)

91

In [33]:
def load_encoded(directory, filename):
    """Read back one pickled object stored at ``directory/filename``.

    Args:
        directory: Folder that contains the pickle file.
        filename: Name of the pickle file inside ``directory``.

    Returns:
        The object that was unpickled from the file.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    # NOTE: pickle.load can execute arbitrary code; only use on files this
    # notebook produced itself.
    with open(os.path.join(directory, filename), 'rb') as pickle_file:
        return pickle.load(pickle_file)

In [34]:
def save_encoded(target_encoded, directory, filename):
    """Pickle ``target_encoded`` to ``directory/filename``.

    The directory (including missing parents) is created when needed. An
    existing file with the same name is overwritten.

    Args:
        target_encoded: Any picklable object to store.
        directory: Destination folder. An empty string means the current
            working directory.
        filename: Name of the file to write inside ``directory``.
    """
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test. An empty directory means "current directory",
    # for which os.makedirs('') would raise, so creation is skipped.
    if directory:
        os.makedirs(directory, exist_ok=True)

    filepath = os.path.join(directory, filename)

    with open(filepath, 'wb') as f:
        pickle.dump(target_encoded, f)

In [17]:
def train(model : object, model_optimizer : object,
          inp : torch.Tensor, target : torch.Tensor,
          verbose : bool = True) -> float:
    """
    Run a single optimisation step of ``model`` on one (input, target) pair.

    Args:
        model (object): The model to train. It must provide ``init_hidden()``
            returning a ``(hidden, cell)`` pair and be callable as
            ``model(inp, hidden, cell)``, returning ``(output, state)``.
        model_optimizer (object): The optimizer whose ``step()`` updates the
            model's parameters.
        inp (torch.Tensor): The input tensor passed to the model.
        target (torch.Tensor): The target passed to ``nn.CrossEntropyLoss``
            together with the model output.
        verbose (bool): When True (the default) the sizes of the model output
            and of the target are printed, which helps when debugging shape
            mismatches. Pass False to keep training loops quiet.

    Returns:
        float: The cross-entropy loss of this step as a Python float.
    """
    hidden, cell = model.init_hidden()
    # Clear gradients left over from the previous step before backprop.
    model.zero_grad()
    criterion = nn.CrossEntropyLoss()

    # Only the output is needed here; the returned recurrent state is unused.
    output, _ = model(inp, hidden, cell)
    if verbose:
        print(output.size())
        print(target.size())

    # NOTE(review): nn.CrossEntropyLoss expects unnormalised logits. If the
    # model's forward pass already applies a softmax, the scores are
    # normalised twice - confirm against the model definition.
    loss = criterion(output, target)

    loss.backward()
    model_optimizer.step()

    return loss.item()

In [35]:
# Embed every sentence of the training document together with its neighbours
# and cache the resulting (input, target) pair on disk for the training cell.
max_sent_tok_len = 80
all_losses = []
last_index = len(train_document) - 1

for index, entry in enumerate(train_document):
    # Context window = (previous, current, next) sentence. Clamping the
    # neighbour indices duplicates the first sentence at the start of the
    # document and the last sentence at the end, and also copes with a
    # document that holds a single sentence.
    window = (max(index - 1, 0), index, min(index + 1, last_index))

    # Stack the three sentence embeddings along dim 0 (the recorded output
    # shows a resulting size of [3, 80, 768]).
    sent_tensor = torch.cat(
        [
            sent2embeddings(train_document[position]['value']['text'], MAX_LEN = max_sent_tok_len)
            for position in window
        ],
        dim=0,
    )

    # The target must be encoded BEFORE it is saved; saving first would either
    # fail on a fresh kernel or store the previous sentence's target under the
    # current index.
    target_encoded = torch.from_numpy(label_encoder.transform(entry['value']['labels'])).float()

    print(f"input size : {sent_tensor.size()}")
    print(f"target size : {target_encoded.size()}")

    # training tensor --> sent_tensor
    # training target --> target_encoded
    save_encoded(sent_tensor, 'train_document', "embed_" + str(index))
    save_encoded(target_encoded, 'train_document', "target_" + str(index))

input size : torch.Size([3, 80, 768])
target size : torch.Size([1, 13])
input size : torch.Size([3, 80, 768])
target size : torch.Size([1, 13])
input size : torch.Size([3, 80, 768])
target size : torch.Size([1, 13])
input size : torch.Size([3, 80, 768])
target size : torch.Size([1, 13])
input size : torch.Size([3, 80, 768])
target size : torch.Size([1, 13])
input size : torch.Size([3, 80, 768])
target size : torch.Size([1, 13])
input size : torch.Size([3, 80, 768])
target size : torch.Size([1, 13])
input size : torch.Size([3, 80, 768])
target size : torch.Size([1, 13])
input size : torch.Size([3, 80, 768])
target size : torch.Size([1, 13])
input size : torch.Size([3, 80, 768])
target size : torch.Size([1, 13])
input size : torch.Size([3, 80, 768])
target size : torch.Size([1, 13])
input size : torch.Size([3, 80, 768])
target size : torch.Size([1, 13])
input size : torch.Size([3, 80, 768])
target size : torch.Size([1, 13])
input size : torch.Size([3, 80, 768])
target size : torch.Size([

In [None]:
# Train on the cached tensors, one sentence at a time.
for index in range(len(train_document)):
    # training tensor --> sent_tensor
    # training target --> target_encoded

    # load_encoded takes (directory, filename) only; passing the tensor as an
    # extra first argument raises a TypeError.
    sent_tensor = load_encoded('train_document', "embed_" + str(index))
    target_encoded = load_encoded('train_document', "target_" + str(index))
    loss = train(model,model_optimizer,sent_tensor,target_encoded)
    print(loss)
    all_losses.append(loss)