In [None]:

import torch
import numpy as np
import torch.nn as nn
from sklearn.utils import shuffle
from torch.autograd import Variable
from torch.nn import functional as F
from torchtext import data
from torchtext import datasets
import os
from torchtext.vocab import Vectors, GloVe
import time
import torch.optim as optim


In [None]:
class SelfAttention(nn.Module):
    """BiLSTM text classifier with structured self-attention pooling.

    Pipeline: frozen pretrained embeddings -> single-layer bidirectional LSTM ->
    r = 30 attention hops of size d_a = 350 -> flattened (30 * 2*hidden_size)
    sentence matrix -> Linear(2000) -> Linear(output_size) logits.
    d_a and r follow the self-attention ICLR paper; that paper's penalization
    term is NOT implemented here.

    Args:
        batch_size: stored for reference; ``forward`` takes the batch size from
            its input unless one is passed explicitly.
        output_size: number of classes (last dimension of the returned logits).
        hidden_size: LSTM hidden size per direction.
        vocab_size: number of rows in the embedding table.
        embedding_length: embedding dimension.
        weights: pretrained embedding matrix, presumably (vocab_size,
            embedding_length) -- TODO confirm; installed as the frozen
            embedding weight.
    """

    def __init__(self, batch_size, output_size, hidden_size, vocab_size, embedding_length, weights):
        super(SelfAttention, self).__init__()

        self.batch_size = batch_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.embedding_length = embedding_length
        self.weights = weights

        self.word_embeddings = nn.Embedding(vocab_size, embedding_length)
        # The embedding attribute is ``weight`` (singular). Assigning to ``.weights`` merely
        # registered a second, unused parameter and left the table random and trainable.
        self.word_embeddings.weight = nn.Parameter(weights, requires_grad=False)
        # NOTE: LSTM dropout only acts between stacked layers, so with a single layer this
        # value has no effect (PyTorch warns about it). Kept for compatibility.
        self.dropout = 0.8
        self.bilstm = nn.LSTM(embedding_length, hidden_size, dropout=self.dropout, bidirectional=True)
        # d_a = 350, r = 30 as given in the self-attention ICLR paper.
        self.W_s1 = nn.Linear(2 * hidden_size, 350)
        self.W_s2 = nn.Linear(350, 30)
        self.fc_layer = nn.Linear(30 * 2 * hidden_size, 2000)
        # Previously hard-coded to 10, silently ignoring ``output_size``.
        self.label = nn.Linear(2000, output_size)

    def attention_net(self, lstm_output):
        """Compute attention weights from (batch, seq_len, 2*hidden) LSTM output.

        Returns a (batch, 30, seq_len) tensor in which each of the 30 rows is a
        softmax distribution over sequence positions.
        """
        attn_weight_matrix = self.W_s2(torch.tanh(self.W_s1(lstm_output)))
        attn_weight_matrix = attn_weight_matrix.permute(0, 2, 1)
        attn_weight_matrix = F.softmax(attn_weight_matrix, dim=2)
        return attn_weight_matrix

    def forward(self, input_sentences, batch_size=None):
        """Return class logits of shape (batch, output_size).

        Args:
            input_sentences: token ids of shape (batch, seq_len).
            batch_size: batch size for the zero initial LSTM state; defaults to
                ``input_sentences.size(0)`` so a smaller final batch works.
        """
        if batch_size is None:
            batch_size = input_sentences.size(0)
        embedded = self.word_embeddings(input_sentences).permute(1, 0, 2)  # (seq_len, batch, emb)
        # Zero initial states on the same device/dtype as the activations.
        h_0 = embedded.new_zeros(2, batch_size, self.hidden_size)
        c_0 = embedded.new_zeros(2, batch_size, self.hidden_size)

        output, (h_n, c_n) = self.bilstm(embedded, (h_0, c_0))
        output = output.permute(1, 0, 2)  # (batch, seq_len, 2*hidden)
        attn_weight_matrix = self.attention_net(output)
        hidden_matrix = torch.bmm(attn_weight_matrix, output)  # (batch, 30, 2*hidden)
        fc_out = self.fc_layer(hidden_matrix.flatten(start_dim=1))
        logits = self.label(fc_out)
        return logits

In [None]:


def load_dataset(data_path='/content/drive/MyDrive/SEBI /Adjudication Orders Annotations JSON/Model Data CSV/'):
    """Build torchtext fields, vocabularies and bucketed iterators from the CSV files.

    Column layout: the first two CSV columns are skipped, the third is the label
    and the fourth is the text; the header row is skipped. Text is tokenized with
    spaCy (case preserved) and padded/truncated to 300 tokens. Vocabularies are
    built from the training file only, with 300-d GloVe 6B vectors attached.

    The training file is further split into train/validation with torchtext's
    default ratio; no random_state is passed, so the split differs between runs.

    Returns:
        (TEXT field, vocabulary size, embedding matrix, train iterator,
         validation iterator, test iterator, label vocab)
    """
    label_field = data.LabelField(lower=True)
    text_field = data.Field(
        sequential=True,
        tokenize='spacy',
        lower=False,
        batch_first=True,
        fix_length=300,
    )
    column_spec = [(None, None), (None, None), ('Label', label_field), ('text', text_field)]

    train_ds, test_ds = data.TabularDataset.splits(
        path=data_path,
        train='model3_train_4.csv',
        test='test_4.csv',
        format='csv',
        fields=column_spec,
        skip_header=True,
    )

    text_field.build_vocab(train_ds, vectors=GloVe(name='6B', dim=300))
    label_field.build_vocab(train_ds)

    embeddings = text_field.vocab.vectors
    print("Length of Text Vocabulary: " + str(len(text_field.vocab)))
    print("Vector size of Text Vocabulary: ", embeddings.size())
    print("Label Length: " + str(len(label_field.vocab)))

    train_data, valid_data = train_ds.split()
    train_iter, valid_iter, test_iter = data.BucketIterator.splits(
        (train_data, valid_data, test_ds),
        batch_size=32,
        sort_key=lambda example: len(example.text),
        repeat=False,
        shuffle=True,
    )

    return (
        text_field,
        len(text_field.vocab),
        embeddings,
        train_iter,
        valid_iter,
        test_iter,
        label_field.vocab,
    )

In [None]:
# Read the CSVs, build the text/label vocabularies (text vocab carries GloVe vectors)
# and create the train / validation / test BucketIterators.
TEXT, vocab_size, word_embeddings, train_iter, valid_iter, test_iter, label_vocab = load_dataset()

In [None]:
# Label string -> class index (stoi) and class index -> label string (itos).
# Jupyter only renders the cell's final expression, so only index_list is shown.
index_list = label_vocab.itos
index_vocab = label_vocab.stoi
index_list

In [None]:

def clip_gradient(model, clip_value):
    """Clamp, in place, every populated gradient of ``model`` to ``[-clip_value, clip_value]``.

    Parameters whose ``.grad`` is ``None`` (e.g. frozen or unused ones) are skipped.
    """
    for param in model.parameters():
        if param.grad is not None:
            param.grad.data.clamp_(-clip_value, clip_value)
    
def train_model(model, train_iter, epoch, optimizer=None, loss_fn=F.cross_entropy):
    """Run one training epoch and return ``(mean loss, accuracy %)`` over all examples seen.

    Args:
        model: classifier whose ``forward(text, batch_size=...)`` returns logits of
            shape (batch, n_classes).
        train_iter: iterator of batches exposing ``.text`` (batch-first token ids)
            and ``.Label`` (class indices).
        epoch: zero-based epoch index, used only for progress logging.
        optimizer: optimizer to step. Pass the SAME instance every epoch so Adam's
            moment estimates persist; if omitted, a fresh Adam over the trainable
            parameters is created (state then resets each call).
        loss_fn: criterion ``(logits, targets) -> batch-mean loss``.

    Returns:
        Per-example mean loss and accuracy (percent); ``(nan, nan)`` for an empty iterator.
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(p for p in model.parameters() if p.requires_grad)
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    for idx, batch in enumerate(train_iter):
        text = batch.text
        target = batch.Label.long()
        n = text.size(0)

        optimizer.zero_grad()
        # Passing the real batch size lets the final, smaller batch train instead of being skipped.
        prediction = model(text, batch_size=n)
        loss = loss_fn(prediction, target)
        loss.backward()
        clip_gradient(model, 1e-1)
        optimizer.step()

        correct = (prediction.argmax(dim=1) == target).sum().item()
        # Weight by batch size so a short last batch doesn't skew the epoch average.
        total_loss += loss.item() * n
        total_correct += correct
        total_seen += n

        if (idx + 1) % 100 == 0:
            print(f'Epoch: {epoch+1}, Idx: {idx+1}, Training Loss: {loss.item():.4f}, Training Accuracy: {100.0 * correct / n:.2f}%')

    if total_seen == 0:
        return float('nan'), float('nan')
    return total_loss / total_seen, 100.0 * total_correct / total_seen

def eval_model(model, val_iter, loss_fn=F.cross_entropy):
    """Evaluate ``model`` on every batch of ``val_iter`` without gradient tracking.

    Returns:
        ``(mean loss, accuracy %)`` averaged per example (the final partial batch
        is included); ``(nan, nan)`` for an empty iterator.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    with torch.no_grad():
        for batch in val_iter:
            text = batch.text
            target = batch.Label.long()
            n = text.size(0)
            prediction = model(text, batch_size=n)
            total_loss += loss_fn(prediction, target).item() * n
            total_correct += (prediction.argmax(dim=1) == target).sum().item()
            total_seen += n

    if total_seen == 0:
        return float('nan'), float('nan')
    return total_loss / total_seen, 100.0 * total_correct / total_seen


# Hyperparameters
SEED = 42
learning_rate = 0.001
batch_size = 32
output_size = len(label_vocab)  # one logit per label class
hidden_size = 256
embedding_length = 300
num_epochs = 11

torch.manual_seed(SEED)  # reproducible weight initialisation
model = SelfAttention(batch_size, output_size, hidden_size, vocab_size, embedding_length, word_embeddings)
loss_fn = F.cross_entropy
# One optimizer for the whole run so Adam's state carries across epochs
# (and learning_rate is actually applied).
optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=learning_rate)

for epoch in range(num_epochs):
    train_loss, train_acc = train_model(model, train_iter, epoch, optimizer=optimizer, loss_fn=loss_fn)
    val_loss, val_acc = eval_model(model, valid_iter, loss_fn=loss_fn)

    print(f'Epoch: {epoch+1:02}, Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%, Val. Loss: {val_loss:.3f}, Val. Acc: {val_acc:.2f}%')

test_loss, test_acc = eval_model(model, test_iter, loss_fn=loss_fn)
print(f'Test Loss: {test_loss:.3f}, Test Acc: {test_acc:.2f}%')


In [None]:
# Predictions for the test data (every batch, including the final partial one).
from sklearn.metrics import confusion_matrix, classification_report

model.eval()
all_targets = []
all_predicted = []
with torch.no_grad():
    for batch in test_iter:
        text = batch.text
        target = batch.Label
        # Passing the actual batch size lets the smaller last batch through instead of dropping it.
        prediction = model(text, batch_size=text.size(0))
        # Store plain Python ints rather than 0-d tensors for sklearn.
        all_targets.extend(target.tolist())
        all_predicted.extend(prediction.argmax(dim=1).tolist())



        

In [None]:
# Pass labels explicitly so target_names stays aligned with class ids even when
# some class never occurs in the targets or predictions of this split.
print(classification_report(all_targets, all_predicted, labels=list(range(len(index_list))), target_names=index_list))

                        precision    recall  f1-score   support

         material fact       0.39      0.57      0.47        87
       defendant claim       0.66      0.36      0.46        76
       procedural fact       0.40      0.68      0.50        47
subjective observation       0.65      0.37      0.47        41
        statutory fact       0.40      0.43      0.41        14
         issues framed       1.00      0.56      0.72        16
          related fact       0.75      0.46      0.57        13
            allegation       1.00      0.23      0.38        13
               penalty       1.00      0.43      0.60         7
             violation       0.10      0.17      0.12         6

              accuracy                           0.48       320
             macro avg       0.64      0.43      0.47       320
          weighted avg       0.57      0.47      0.48       320

