## Predicting metadata using Scientific publications
### Use Case: NCBI Disease Corpus 
* The idea is given an abstract the method predicts disease names
* The targets are MeSH unique IDs of the disease names

#### When using Google Colab GPU

In [0]:
!pwd

/content


In [0]:
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [0]:
import os
os.chdir("drive/My Drive/textcnn_updated/textcnn")

In [2]:
!ls

clean_data_abstract_ln_200.pkl	NCBItrainset.csv
clean_data.txt			network.py
components.py			predictions-Copy1.ipynb
dropout_layers.py		predictions.ipynb
first_review.png		__pycache__
glove.42B.300d.txt		random.csv
glove.42B.300d.txt.1		random.gsheet
glove.42B.300d.zip		results_cnn_disease_names.csv
glove.42B.300d.zip.1		results_cnn_superclass.csv
glove.42B.300d.zip.2		text.txt
linked_diseases_abstract.csv	Untitled.ipynb
NCBI_corpus			utils
NCBI_corpus.zip


In [167]:
import tensorflow as tf
tf.test.gpu_device_name()

'/device:GPU:0'

In [168]:
from tensorflow.python.client import device_lib
device_lib.list_local_devices()

[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 13973072041570232803, name: "/device:XLA_CPU:0"
 device_type: "XLA_CPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 61085472504264535
 physical_device_desc: "device: XLA_CPU device", name: "/device:XLA_GPU:0"
 device_type: "XLA_GPU"
 memory_limit: 17179869184
 locality {
 }
 incarnation: 4045528609568371374
 physical_device_desc: "device: XLA_GPU device", name: "/device:GPU:0"
 device_type: "GPU"
 memory_limit: 10791449396
 locality {
   bus_id: 1
   links {
   }
 }
 incarnation: 2439856266245656418
 physical_device_desc: "device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7"]

In [0]:
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.models import Sequential
from keras.layers import Dense, Flatten, LSTM, Conv1D, MaxPooling1D, Dropout, Activation,Input
from keras.layers.embeddings import Embedding

from sklearn.preprocessing import MultiLabelBinarizer
import csv,pandas,re
from keras.layers import Dense, LSTM, Dropout, Bidirectional, Embedding
from keras.layers import Embedding
import numpy as np
from keras.models import Model

import time,os
import components as com

# Torch imports
import torch
import torch.nn as nn
from torch import optim
import network as net
import traceback
from utils import tensor_utils as tu

from tqdm import tqdm

import pickle
import json
import requests
from pprint import pprint




device = torch.device("cuda")
cpu = torch.device("cpu")

In [0]:
debug = True
MAX_SEQUENCE_LENGTH = 200


class BadResponseException(Exception):
    pass


class Timer:
    """ Simple block which can be called as a context, to know the time of a block. """
    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        self.end = time.perf_counter()
        self.interval = self.end - self.start

#### Preprocessing: cleaning string from characters and tokenization

In [0]:
#data helpers - copied from https://github.com/bhaveshoswal/CNN-text-classification-keras
def clean_str(string):
    """
    Tokenization/string cleaning for datasets.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " ", string)
    string = re.sub(r"\)", " ", string)
    string = re.sub(r"\?", " ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()

#### SPARQL Query for getting the super classes of disease terms

In [None]:
def _get_class_for_doi_(doi:str) -> list:
    url = "http://id.nlm.nih.gov/mesh/sparql"
    query = """
                PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
                PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
                PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
                PREFIX owl: <http://www.w3.org/2002/07/owl#>
                PREFIX meshv: <http://id.nlm.nih.gov/mesh/vocab#>
                PREFIX mesh: <http://id.nlm.nih.gov/mesh/>
                PREFIX mesh2015: <http://id.nlm.nih.gov/mesh/2015/>
                PREFIX mesh2016: <http://id.nlm.nih.gov/mesh/2016/>
                PREFIX mesh2017: <http://id.nlm.nih.gov/mesh/2017/>
                PREFIX mesh2018: <http://id.nlm.nih.gov/mesh/2018/>
                PREFIX mesh2019: <http://id.nlm.nih.gov/mesh/2019/>
                SELECT DISTINCT ?p ?label ?uri 
                WHERE { mesh:%s meshv:broaderDescriptor* ?uri .
                    ?uri rdfs:label ?p.
                    FILTER NOT EXISTS{
                    ?uri meshv:broaderDescriptor ?x
                    }
                }
                """ % doi
    querystring = {"query":query, 'format':'json'}
    payload = ""
    headers = {
        'cache-control': "no-cache",
        'Postman-Token': "c9ace615-1b99-4caf-bcad-97ce1060973e"
        }

    response = requests.request("GET", url, data=payload, headers=headers, params=querystring)
    return [x['uri']['value'] for x in response.json()['results']['bindings']]


def get_classes_for_dois(dois:list) -> list:
    cls = []
    for doi in dois:
        print("doi: ", doi)
        cls += _get_class_for_doi_(doi)
       # print("super class": cls)
    return list(set(cls))

#### Getting the MeSH terms of predictions of test data (for checking the correctness of output)

In [0]:
def get_terms_of_classes(list_terms) -> list:
    url = "http://id.nlm.nih.gov/mesh/sparql"
    query = """
            PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
            PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
            PREFIX xsd: <http://www.w3.org/2001/XMLSchema#>
            PREFIX owl: <http://www.w3.org/2002/07/owl#>
            PREFIX meshv: <http://id.nlm.nih.gov/mesh/vocab#>
            PREFIX mesh: <http://id.nlm.nih.gov/mesh/>
            PREFIX mesh2015: <http://id.nlm.nih.gov/mesh/2015/>
            PREFIX mesh2016: <http://id.nlm.nih.gov/mesh/2016/>
            PREFIX mesh2017: <http://id.nlm.nih.gov/mesh/2017/>
            PREFIX mesh2018: <http://id.nlm.nih.gov/mesh/2018/>
            PREFIX mesh2019: <http://id.nlm.nih.gov/mesh/2019/>
            SELECT DISTINCT ?class
            WHERE { mesh:%s rdfs:label ?class . }
            """%list_terms
    querystring = {"query":query, 'format':'json'}
    payload = ""
    headers = {
        'cache-control': "no-cache",
        'Postman-Token': "c9ace615-1b99-4caf-bcad-97ce1060973e"
        }

    response = requests.request("GET", url, data=payload, headers=headers, params=querystring)
    return [x['class']['value'] for x in response.json()['results']['bindings']]


def get_list_for_terms(dois:list) -> list:
  cls = []
  for doi in dois:
    #doi.replace("http://id.nlm.nih.gov/mesh/", "")
    print("doi: ", doi)
    cls += get_terms_of_classes(doi.replace("http://id.nlm.nih.gov/mesh/", ""))
    print("class: ", cls)
  return list(set(cls))

#### Extracting abstracts and labels from the input file and cleaning data

In [0]:
df = pandas.read_csv('linked_diseases_abstract.csv')
data = [x for x in df.to_records(index=False)] 
def read_till_next(data,index):
    temp = []
    temp_loc = 0
    for d in data[index+1:]:
        if 'a' in d[1] or 't' in d[1]:
            return temp,temp_loc
        else:
            temp_loc = temp_loc + 1
            temp.append([d[2],d[3],d[4]])
#             print([d[2],d[3],d[4]])
#             break
    
    return temp,temp_loc

loc = 0
final_data = []
for index,d in enumerate(data):
    if '|t|' in d[1]:
        continue
    elif loc != 0:
        loc = loc - 1
        continue
    else:
        if '|a|' in d[1]:
            temp_holder = {
                'abstract' : d[1],
                'keys' : []
            }
            keys,loc = read_till_next(data,index)
            temp_holder['keys'] = keys
            final_data.append(temp_holder)
clean_data = []
for f in final_data:
    abstarct = clean_str(f['abstract'].split('|')[2])
    disease = []
    for k in f['keys']:
        disease.append(k[-1])
    clean_data.append((abstarct,list(set(disease))))

In [0]:
#Now chanign the hiearchy of the diseases dataset where we only consider Diseases ontology 
clean_data_upper_class = []
us = []
for c in tqdm(clean_data):
    try:
        labels = []
        for l in c[1]:
            if 'D' == l[0]:
                labels = labels + l.split('|')
        #change the class hiearchy 
        new_labels = list(set(get_classes_for_dois(labels)))
        clean_data_upper_class.append((c[0],new_labels))
    except:
        us.append(c)

clean_data = clean_data_upper_class

############################################## 

In [0]:
!pip install -U -q PyDrive

[?25l[K    1% |▎                               | 10kB 22.7MB/s eta 0:00:01[K    2% |▋                               | 20kB 1.5MB/s eta 0:00:01[K    3% |█                               | 30kB 2.3MB/s eta 0:00:01[K    4% |█▎                              | 40kB 1.6MB/s eta 0:00:01[K    5% |█▋                              | 51kB 2.0MB/s eta 0:00:01[K    6% |██                              | 61kB 2.3MB/s eta 0:00:01[K    7% |██▎                             | 71kB 2.7MB/s eta 0:00:01[K    8% |██▋                             | 81kB 3.0MB/s eta 0:00:01[K    9% |███                             | 92kB 3.4MB/s eta 0:00:01[K    10% |███▎                            | 102kB 2.6MB/s eta 0:00:01[K    11% |███▋                            | 112kB 2.7MB/s eta 0:00:01[K    12% |████                            | 122kB 3.9MB/s eta 0:00:01[K    13% |████▎                           | 133kB 3.8MB/s eta 0:00:01[K    14% |████▋                           | 143kB 7.1MB/s eta 0:00:01[

In [0]:
with open(f'clean_data_abstract_ln_{MAX_SEQUENCE_LENGTH}.pkl', 'wb') as fp:
  pickle.dump(clean_data, fp)

In [0]:
with open ('clean_data_abstract_ln_200.pkl', 'rb') as fp:
  clean_data = pickle.load(fp)

#########################################################

#### Preprocessing data to a form usable for neural networks

In [0]:
#creating id map of class that is each label is given a unique id
all_class = list(set([g for f in clean_data for g in f[1]]))
all_class_to_id = {}
for index,value in enumerate(all_class):
    all_class_to_id[value] = index
    

texts = [f[0] for f in clean_data]

#Tokenize,idfy,pad
tokenizer = Tokenizer()
tokenizer.fit_on_texts(texts)
sequences = tokenizer.texts_to_sequences(texts)
word_index = tokenizer.word_index
data = pad_sequences(sequences, maxlen=MAX_SEQUENCE_LENGTH, padding='post')

Storing names of all the labels for superclasses that are extracted using SPARQL query

In [25]:
targets = get_list_for_terms(all_class)

doi:  http://id.nlm.nih.gov/mesh/D012140
class:  ['Respiratory Tract Diseases']
doi:  http://id.nlm.nih.gov/mesh/D055614
class:  ['Respiratory Tract Diseases', 'Genetic Phenomena']
doi:  http://id.nlm.nih.gov/mesh/D005441
class:  ['Respiratory Tract Diseases', 'Genetic Phenomena', 'Fluids and Secretions']
doi:  http://id.nlm.nih.gov/mesh/D017437
class:  ['Respiratory Tract Diseases', 'Genetic Phenomena', 'Fluids and Secretions', 'Skin and Connective Tissue Diseases']
doi:  http://id.nlm.nih.gov/mesh/D002318
class:  ['Respiratory Tract Diseases', 'Genetic Phenomena', 'Fluids and Secretions', 'Skin and Connective Tissue Diseases', 'Cardiovascular Diseases']
doi:  http://id.nlm.nih.gov/mesh/D004066
class:  ['Respiratory Tract Diseases', 'Genetic Phenomena', 'Fluids and Secretions', 'Skin and Connective Tissue Diseases', 'Cardiovascular Diseases', 'Digestive System Diseases']
doi:  http://id.nlm.nih.gov/mesh/D002468
class:  ['Respiratory Tract Diseases', 'Genetic Phenomena', 'Fluids and Se

In [27]:
targets

['Musculoskeletal Diseases',
 'Wounds and Injuries',
 'Hemic and Lymphatic Diseases',
 'Immune System Diseases',
 'Cell Physiological Phenomena',
 'Skin and Connective Tissue Diseases',
 'Physiological Phenomena',
 'Stomatognathic Diseases',
 'Behavioral Disciplines and Activities',
 'Eye Diseases',
 'Reproductive and Urinary Physiological Phenomena',
 'Musculoskeletal and Neural Physiological Phenomena',
 'Digestive System Diseases',
 'Parasitic Diseases',
 'Respiratory Tract Diseases',
 'Fluids and Secretions',
 'Nonsyndromic sensorineural hearing loss',
 'Diagnosis',
 'Genetic Phenomena',
 'Mental Disorders',
 'Nutritional and Metabolic Diseases',
 'Chemically-Induced Disorders',
 'Health Occupations',
 'Investigative Techniques',
 'Behavior and Behavior Mechanisms',
 'Cardiovascular Diseases',
 'Population Characteristics',
 'Endocrine System Diseases',
 'Cells',
 'Pathological Conditions, Signs and Symptoms',
 'Tissues',
 'Neoplasms',
 'Health Care Quality, Access, and Evaluation'

In [0]:
!pip install scikit-multilearn

Collecting scikit-multilearn
[?25l  Downloading https://files.pythonhosted.org/packages/bb/1f/e6ff649c72a1cdf2c7a1d31eb21705110ce1c5d3e7e26b2cc300e1637272/scikit_multilearn-0.2.0-py3-none-any.whl (89kB)
[K    100% |████████████████████████████████| 92kB 3.4MB/s 
[?25hInstalling collected packages: scikit-multilearn
Successfully installed scikit-multilearn-0.2.0


#### Constructing label matrix

In [18]:
#constructing the label matrix 
label_matrix = np.zeros((len(all_class),len(all_class)))
for index,value in enumerate(all_class):
    label_matrix[index][index] = 1
    
#Testing 
label_matrix[all_class_to_id[all_class[1]]]

array([0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [0]:
def create_multi_hot_label(label_matrix,labels):
    '''
    
        label_matrix --> matrix having one hot vectors
        labels --> [idfy]--> [2,5,6]
        return [0,0,1,0,0,1....]
    '''
    vector = np.zeros((label_matrix.shape[0]))
    for i in labels:
        vector = vector + label_matrix[all_class_to_id[all_class[i]]]
    
    return vector

In [0]:
label_vector = []
for f in clean_data:
    label_id = [all_class_to_id[l] for l in f[1]]
    multi_hot_vector = create_multi_hot_label(label_matrix,label_id)
    label_vector.append(multi_hot_vector)

#### Splitting data into taining and validation sets

In [0]:
X = data
Y = label_vector
split_index = int(len(X)*.80)
train_x,val_x = X[:split_index],X[split_index:]
train_y,val_y = Y[:split_index],Y[split_index:]

### Defining the NN Model

In [0]:
class BiLstmDot(net.Model):

    def __init__(self, _parameter_dict, _word_to_id, _device, _pointwise=False, _debug=False):

        self.debug = _debug
        self.parameter_dict = _parameter_dict
        self.device = _device
        self.pointwise = _pointwise
        self.word_to_id = _word_to_id
        
        self.hiddendim = self.parameter_dict['hidden_size'] * (2 * int(self.parameter_dict['bidirectional']))
        self.number_of_labels = self.parameter_dict['number_of_labels']

        if self.debug:
            print("Init Models")

        self.encoder = com.NotSuchABetterEncoder(
            number_of_layer=self.parameter_dict['number_of_layer'],
            bidirectional=self.parameter_dict['bidirectional'],
            embedding_dim=self.parameter_dict['embedding_dim'],
            max_length = self.parameter_dict['max_length'],
            hidden_dim=self.parameter_dict['hidden_size'],
            vocab_size=self.parameter_dict['vocab_size'],
            dropout=self.parameter_dict['dropout'],
            vectors=self.parameter_dict['vectors'],
            enable_layer_norm=False,
            mode = 'LSTM',
            debug = self.debug).to(self.device)
        
        
        self.dense = com.DenseClf(inputdim=self.hiddendim,        # *2 because we have two things concatinated here
                                  hiddendim=self.hiddendim/2,
                                  outputdim=self.number_of_labels).to(self.device)

    def train(self, data, optimizer, loss_fn, device):
    #
        if self.pointwise:
            return self._train_pointwise_(data, optimizer, loss_fn, device)
        else:
            return self._train_pairwise_(data, optimizer, loss_fn, device)

    def _train_pairwise_(self, data, optimizer, loss_fn, device):
        '''
            Given data, passes it through model, inited in constructor, returns loss and updates the weight
            :params data: {batch of question, pos paths, neg paths and dummy y labels}
            :params optimizer: torch.optim object
            :params loss fn: torch.nn loss object
            :params device: torch.device object

            returns loss
        '''

        # Unpacking the data and model from args
        sent_batch, y_label = data['sent_batch'], data['y_label']

        optimizer.zero_grad()

        # Encoding all the data
        hidden = self.encoder.init_hidden(sent_batch.shape[0],self.device)
        _, sent_batch_encoded, _, _ = self.encoder(tu.trim(sent_batch), hidden)

        # Calculating dot score
        pred = self.dense(sent_batch_encoded)


        '''
            If `y == 1` then it assumed the first input should be ranked higher
            (have a larger value) than the second input, and vice-versa for `y == -1`
        '''
        try:
#             print(torch.max(pred), torch.max(y_label))
            loss = loss_fn(pred, y_label.float())
        except RuntimeError:
            traceback.print_exc()
            print(torch.max(pred.to(cpu)), torch.max(y_label.to(cpu)))
#             print(pos_scores.shape, neg_scores.shape, y_label.shape,  ques_batch.shape, pos_batch.shape, neg_batch.shape)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), .5)
        optimizer.step()
        return loss



    def predict(self, sent, device):
        """
            Same code works for both pairwise or pointwise
        """
        with torch.no_grad():

            self.encoder.eval()
            self.dense.eval()
            
            hidden = self.encoder.init_hidden(sent.shape[0], self.device)
            _, sent_encoded, _, _ = self.encoder(tu.trim(sent.long()), hidden)
            pred = self.dense(sent_encoded)


            self.encoder.train()
            self.dense.train()
            return pred

    def prepare_save(self):
        """

            This function is called when someone wants to save the underlying models.
            Returns a tuple of key:model pairs which is to be interpreted within save model.

        :return: [(key, model)]
        """
        return [('encoder', self.encoder), ('dense', self.dense)]

    def load_from(self, location):
        # Pull the data from disk
        if self.debug: print("loading Bilstmdot model from", location)
        self.encoder.load_state_dict(torch.load(location)['encoder'])
        if self.debug: print("model loaded with weights ,", self.get_parameter_sum())

#### Extracting the word embedding file and constructing the embedding matrix

In [None]:
#glove_file = "https://s3.eu-central-1.amazonaws.com/maastrichtuniversity-ids-open/metadata/glove.42B.300d.txt" #'glove.42B.300d.txt'

In [0]:
!wget 'https://s3.eu-central-1.amazonaws.com/maastrichtuniversity-ids-open/metadata/glove.42B.300d.txt'

--2019-02-15 11:12:29--  https://s3.eu-central-1.amazonaws.com/maastrichtuniversity-ids-open/metadata/glove.42B.300d.txt
Resolving s3.eu-central-1.amazonaws.com (s3.eu-central-1.amazonaws.com)... 52.219.72.20
Connecting to s3.eu-central-1.amazonaws.com (s3.eu-central-1.amazonaws.com)|52.219.72.20|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5025028820 (4.7G) [text/plain]
Saving to: ‘glove.42B.300d.txt.2’


2019-02-15 11:21:31 (8.87 MB/s) - ‘glove.42B.300d.txt.2’ saved [5025028820/5025028820]



In [23]:
vecs = np.random.randn(len(word_index),300)
embeddings_index = {}
f = open('glove.42B.300d.txt', encoding = 'utf-8')
for line in f:
    values = line.split()
    word = values[0]
    coefs = np.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs
f.close()

print('Found %s word vectors.' % len(embeddings_index))

Found 1917494 word vectors.


In [0]:
embedding_matrix = np.zeros((len(word_index) + 1, 300))
for word, i in word_index.items():
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is not None:
        # words not found in embedding index will be all-zeros.
        embedding_matrix[i] = embedding_vector

In [0]:
vecs = embedding_matrix

#### Defining the parameter dictionary

In [0]:
parameter_dict = {
    'number_of_layer' : 1,
    'bidirectional' : True,
    'embedding_dim' : 300,
    'max_length' : MAX_SEQUENCE_LENGTH,
    'hidden_size':256,
    'vocab_size':len(word_index),
    'dropout':0.2,
    'vectors':vecs,
    'number_of_labels' : len(all_class)
}

#### Defining the loss and optimization function

In [0]:
loss_fn = nn.BCELoss()
modeler = BiLstmDot( _parameter_dict = parameter_dict,_word_to_id=None,
                                     _device=device,_pointwise=False, _debug=False)

optimizer = optim.Adam(list(filter(lambda p: p.requires_grad, modeler.encoder.parameters()))+
                       list(filter(lambda p: p.requires_grad, modeler.dense.parameters())))

Sampling the data

In [0]:
class SimplestSampler:
    """
        Given X and Y matrices (or lists of lists),
            it returns a batch worth of stuff upon __next__
    :return:
    """

    def __init__(self, data, bs: int = 64):

        try:
            assert len(data["x"]) == len(data["y"])
        except AssertionError:

            raise MismatchedDataError(f"Length of x is {len(data['x'])} while of y is {len(data['y'])}")

        self.x = data["x"]
        self.y = data["y"]
        self.n = len(self.x)
        self.bs = bs  # Batch Size

    def __len__(self):
        return self.n // self.bs - (1 if self.n % self.bs else 0)

    def __iter__(self):
        self.i, self.iter = 0, 0
        return self

    def __next__(self):
        if self.i + self.bs >= self.n:
            raise StopIteration

        _x, _y = self.x[self.i:self.i + self.bs], self.y[self.i:self.i + self.bs]
        self.i += self.bs

        return _x, _y

In [0]:
train_data = {
    'x' : train_x,
    'y': np.asarray(train_y)
}

valid_data = {
    'x' : val_x,
    'y' : np.asarray(val_y)
}

train_fact_iter = SimplestSampler(train_data)
valid_fact_iter = SimplestSampler(valid_data)

In [0]:
(train_x.sum(axis=0) == 0).any()


False

In [0]:
for dat in train_fact_iter:
    break
dat[0].shape, dat[1].shape

((64, 250), (64, 48))

In [0]:
_x = torch.tensor(dat[0], dtype=torch.long, device=device)
_y = torch.tensor(dat[1], dtype=torch.long, device=device)

data = {
    'sent_batch' : _x,
    'y_label' : _y
}

loss = modeler.train(data, optimizer, loss_fn, device)
op = modeler.predict(_x,device)

In [0]:
loss

tensor(0.6884, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>)

In [0]:
def eval(op,_y):
    return torch.mean(op.round()[_y==1])

### Training the network 

In [0]:
def simplest_loop(epochs: int,
                  train_data_iter,
                  valid_data_iter,
                  device: torch.device,
                  model,
                  optimizer,
                  loss_fn,
                  eval_fn= None) -> (list, list, list):
    """
        A fn which can be used to train a language model.
        The model doesn't need to be an nn.Module,
            but have an eval (optional), a train and a predict function.
        Data should be a dict like so:
            {"train":{"x":np.arr, "y":np.arr}, "val":{"x":np.arr, "y":np.arr} }
        Train_fn must return both loss and y_pred
        :param epochs: number of epochs to train for
        :param data: a dict having keys train_x, test_x, train_y, test_y
        :param device: torch device to create new tensor from data
        :param opt: optimizer
        :param loss_fn: loss function
        :param train_fn: function to call with x and y
        :param predict_fn: function to call with x (test)
        :param data_fn: a class to which we can pass X and Y, and get an iterator.
        :param eval_fn: (optional) function which when given pred and true, returns acc
        :return: traces
    """

    train_loss = []
    train_acc = []
    val_acc = []
    lrs = []

    # Epoch level
    for e in range(epochs):

        per_epoch_loss = []
        per_epoch_tr_acc = []

        # Train
        with Timer() as timer:

            # Make data

            for x, y in tqdm(train_data_iter):
                optimizer.zero_grad()

                _x = torch.tensor(x, dtype=torch.long, device=device)
                _y = torch.tensor(y, dtype=torch.long, device=device)
                
                data = {
                    'sent_batch' : _x,
                    'y_label' : _y
                }
                
                loss = model.train(data, optimizer, loss_fn, device)
                y_pred = model.predict(_x,device)

                per_epoch_tr_acc.append(eval_fn(y_pred,_y).item())
                per_epoch_loss.append(loss.item())


        # Val
        with torch.no_grad():

            per_epoch_vl_acc = []
            for x, y in tqdm(valid_data_iter):
                _x = torch.tensor(x, dtype=torch.long, device=device)
                _y = torch.tensor(y, dtype=torch.long, device=device)

                y_pred = model.predict(_x,device)

                per_epoch_vl_acc.append(eval_fn(y_pred, _y).item())

        # Bookkeep
        train_acc.append(np.mean(per_epoch_tr_acc))
        train_loss.append(np.mean(per_epoch_loss))
        val_acc.append(np.mean(per_epoch_vl_acc))

        print("Epoch: %(epo)03d | Loss: %(loss).5f | Tr_c: %(tracc)0.5f | Vl_c: %(vlacc)0.5f | Time: %(time).3f min"
              % {'epo': e,
                 'loss': float(np.mean(per_epoch_loss)),
                 'tracc': float(np.mean(per_epoch_tr_acc)),
                 'vlacc': float(np.mean(per_epoch_vl_acc)),
                 'time': timer.interval / 60.0})

    return train_acc, train_loss, val_acc

In [0]:
traces = simplest_loop(epochs= 75,
                  train_data_iter=train_fact_iter,
                  valid_data_iter=valid_fact_iter,
                  device=device,
                  model = modeler,
                  optimizer=optimizer,
                  loss_fn=loss_fn,
             eval_fn=eval)

9it [00:01,  6.06it/s]                       
2it [00:00, 24.96it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.18it/s]

Epoch: 000 | Loss: 0.46521 | Tr_c: 0.23453 | Vl_c: 0.29936 | Time: 0.027 min


9it [00:01,  6.33it/s]                       
2it [00:00, 24.85it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.36it/s]

Epoch: 001 | Loss: 0.23740 | Tr_c: 0.28793 | Vl_c: 0.29936 | Time: 0.024 min


9it [00:01,  6.38it/s]                       
2it [00:00, 25.37it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.35it/s]

Epoch: 002 | Loss: 0.21371 | Tr_c: 0.24214 | Vl_c: 0.19167 | Time: 0.024 min


9it [00:01,  6.33it/s]                       
2it [00:00, 24.62it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.40it/s]

Epoch: 003 | Loss: 0.20836 | Tr_c: 0.24772 | Vl_c: 0.29936 | Time: 0.024 min


9it [00:01,  6.38it/s]                       
2it [00:00, 24.80it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.33it/s]

Epoch: 004 | Loss: 0.20439 | Tr_c: 0.24286 | Vl_c: 0.19167 | Time: 0.024 min


9it [00:01,  6.39it/s]                       
2it [00:00, 24.89it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.40it/s]

Epoch: 005 | Loss: 0.20213 | Tr_c: 0.22665 | Vl_c: 0.29563 | Time: 0.024 min


9it [00:01,  6.33it/s]                       
2it [00:00, 24.64it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.31it/s]

Epoch: 006 | Loss: 0.19925 | Tr_c: 0.26617 | Vl_c: 0.25651 | Time: 0.024 min


9it [00:01,  6.34it/s]                       
2it [00:00, 24.51it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.25it/s]

Epoch: 007 | Loss: 0.19530 | Tr_c: 0.27163 | Vl_c: 0.28841 | Time: 0.024 min


9it [00:01,  6.36it/s]                       
2it [00:00, 25.02it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.33it/s]

Epoch: 008 | Loss: 0.19008 | Tr_c: 0.29814 | Vl_c: 0.29372 | Time: 0.024 min


9it [00:01,  6.37it/s]                       
2it [00:00, 24.28it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.40it/s]

Epoch: 009 | Loss: 0.18315 | Tr_c: 0.31449 | Vl_c: 0.28267 | Time: 0.024 min


9it [00:01,  6.36it/s]                       
2it [00:00, 24.89it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.29it/s]

Epoch: 010 | Loss: 0.17422 | Tr_c: 0.37094 | Vl_c: 0.25471 | Time: 0.024 min


9it [00:01,  6.46it/s]                       
2it [00:00, 25.04it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.37it/s]

Epoch: 011 | Loss: 0.16582 | Tr_c: 0.40900 | Vl_c: 0.31232 | Time: 0.023 min


9it [00:01,  6.38it/s]                       
2it [00:00, 24.66it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.44it/s]

Epoch: 012 | Loss: 0.15859 | Tr_c: 0.43851 | Vl_c: 0.41087 | Time: 0.024 min


9it [00:01,  6.43it/s]                       
2it [00:00, 25.06it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.47it/s]

Epoch: 013 | Loss: 0.14603 | Tr_c: 0.50000 | Vl_c: 0.44819 | Time: 0.024 min


9it [00:01,  6.42it/s]                       
2it [00:00, 25.05it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.34it/s]

Epoch: 014 | Loss: 0.13423 | Tr_c: 0.57598 | Vl_c: 0.49611 | Time: 0.023 min


9it [00:01,  6.29it/s]                       
2it [00:00, 24.54it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.40it/s]

Epoch: 015 | Loss: 0.12579 | Tr_c: 0.62727 | Vl_c: 0.56659 | Time: 0.024 min


9it [00:01,  6.39it/s]                       
2it [00:00, 24.45it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.33it/s]

Epoch: 016 | Loss: 0.11565 | Tr_c: 0.66540 | Vl_c: 0.59974 | Time: 0.024 min


9it [00:01,  6.33it/s]                       
2it [00:00, 24.72it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.32it/s]

Epoch: 017 | Loss: 0.10661 | Tr_c: 0.70523 | Vl_c: 0.58261 | Time: 0.024 min


9it [00:01,  6.39it/s]                       
2it [00:00, 25.09it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.36it/s]

Epoch: 018 | Loss: 0.10169 | Tr_c: 0.70800 | Vl_c: 0.56073 | Time: 0.024 min


9it [00:01,  6.34it/s]                       
2it [00:00, 24.53it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.34it/s]

Epoch: 019 | Loss: 0.09820 | Tr_c: 0.71242 | Vl_c: 0.60888 | Time: 0.024 min


9it [00:01,  6.38it/s]                       
2it [00:00, 24.56it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.35it/s]

Epoch: 020 | Loss: 0.09516 | Tr_c: 0.72112 | Vl_c: 0.55340 | Time: 0.024 min


9it [00:01,  6.32it/s]                       
2it [00:00, 23.80it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.36it/s]

Epoch: 021 | Loss: 0.09149 | Tr_c: 0.74786 | Vl_c: 0.49656 | Time: 0.024 min


9it [00:01,  6.33it/s]                       
2it [00:00, 24.80it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.24it/s]

Epoch: 022 | Loss: 0.09522 | Tr_c: 0.72986 | Vl_c: 0.52702 | Time: 0.024 min


9it [00:01,  6.33it/s]                       
2it [00:00, 25.05it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.39it/s]

Epoch: 023 | Loss: 0.08933 | Tr_c: 0.74668 | Vl_c: 0.52927 | Time: 0.024 min


9it [00:01,  6.34it/s]                       
2it [00:00, 25.09it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.34it/s]

Epoch: 024 | Loss: 0.08667 | Tr_c: 0.74550 | Vl_c: 0.63514 | Time: 0.024 min


9it [00:01,  6.32it/s]                       
2it [00:00, 24.57it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.35it/s]

Epoch: 025 | Loss: 0.08018 | Tr_c: 0.74824 | Vl_c: 0.60177 | Time: 0.024 min


9it [00:01,  6.39it/s]                       
2it [00:00, 25.34it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.38it/s]

Epoch: 026 | Loss: 0.07437 | Tr_c: 0.78614 | Vl_c: 0.59996 | Time: 0.024 min


9it [00:01,  6.38it/s]                       
2it [00:00, 24.32it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.36it/s]

Epoch: 027 | Loss: 0.06417 | Tr_c: 0.80562 | Vl_c: 0.63367 | Time: 0.024 min


9it [00:01,  6.37it/s]                       
2it [00:00, 24.96it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.23it/s]

Epoch: 028 | Loss: 0.05969 | Tr_c: 0.82438 | Vl_c: 0.63920 | Time: 0.024 min


9it [00:01,  6.35it/s]                       
2it [00:00, 24.97it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.41it/s]

Epoch: 029 | Loss: 0.05486 | Tr_c: 0.83778 | Vl_c: 0.61709 | Time: 0.024 min


9it [00:01,  6.39it/s]                       
2it [00:00, 24.55it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.26it/s]

Epoch: 030 | Loss: 0.05109 | Tr_c: 0.85886 | Vl_c: 0.63367 | Time: 0.024 min


9it [00:01,  6.34it/s]                       
2it [00:00, 24.92it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.35it/s]

Epoch: 031 | Loss: 0.04755 | Tr_c: 0.86949 | Vl_c: 0.64686 | Time: 0.024 min


9it [00:01,  6.38it/s]                       
2it [00:00, 25.06it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.35it/s]

Epoch: 032 | Loss: 0.04441 | Tr_c: 0.88022 | Vl_c: 0.66344 | Time: 0.024 min


9it [00:01,  6.40it/s]                       
2it [00:00, 24.77it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.26it/s]

Epoch: 033 | Loss: 0.04209 | Tr_c: 0.88912 | Vl_c: 0.68193 | Time: 0.024 min


9it [00:01,  6.35it/s]                       
2it [00:00, 24.54it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.30it/s]

Epoch: 034 | Loss: 0.04023 | Tr_c: 0.89448 | Vl_c: 0.67088 | Time: 0.024 min


9it [00:01,  6.35it/s]                       
2it [00:00, 24.69it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.29it/s]

Epoch: 035 | Loss: 0.03755 | Tr_c: 0.90722 | Vl_c: 0.63920 | Time: 0.024 min


9it [00:01,  6.27it/s]                       
2it [00:00, 24.93it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.27it/s]

Epoch: 036 | Loss: 0.03484 | Tr_c: 0.91841 | Vl_c: 0.62612 | Time: 0.024 min


9it [00:01,  6.36it/s]                       
2it [00:00, 24.42it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.37it/s]

Epoch: 037 | Loss: 0.03357 | Tr_c: 0.92218 | Vl_c: 0.62081 | Time: 0.024 min


9it [00:01,  6.36it/s]                       
2it [00:00, 24.94it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.32it/s]

Epoch: 038 | Loss: 0.03087 | Tr_c: 0.93503 | Vl_c: 0.65813 | Time: 0.024 min


9it [00:01,  6.33it/s]                       
2it [00:00, 24.75it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.33it/s]

Epoch: 039 | Loss: 0.02755 | Tr_c: 0.93592 | Vl_c: 0.68757 | Time: 0.024 min


9it [00:01,  6.34it/s]                       
2it [00:00, 25.30it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.31it/s]

Epoch: 040 | Loss: 0.02472 | Tr_c: 0.94458 | Vl_c: 0.69118 | Time: 0.024 min


9it [00:01,  6.39it/s]                       
2it [00:00, 24.38it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.35it/s]

Epoch: 041 | Loss: 0.02238 | Tr_c: 0.95418 | Vl_c: 0.65791 | Time: 0.024 min


9it [00:01,  6.31it/s]                       
2it [00:00, 24.57it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.33it/s]

Epoch: 042 | Loss: 0.02037 | Tr_c: 0.96373 | Vl_c: 0.67843 | Time: 0.024 min


9it [00:01,  6.32it/s]                       
2it [00:00, 24.36it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.33it/s]

Epoch: 043 | Loss: 0.01836 | Tr_c: 0.96433 | Vl_c: 0.70628 | Time: 0.024 min


9it [00:01,  6.33it/s]                       
2it [00:00, 24.90it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.27it/s]

Epoch: 044 | Loss: 0.01690 | Tr_c: 0.97079 | Vl_c: 0.68723 | Time: 0.024 min


9it [00:01,  6.35it/s]                       
2it [00:00, 24.53it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.36it/s]

Epoch: 045 | Loss: 0.01653 | Tr_c: 0.97511 | Vl_c: 0.67438 | Time: 0.024 min


9it [00:01,  6.33it/s]                       
2it [00:00, 24.97it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.36it/s]

Epoch: 046 | Loss: 0.01561 | Tr_c: 0.97468 | Vl_c: 0.66524 | Time: 0.024 min


9it [00:01,  6.37it/s]                       
2it [00:00, 25.46it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.22it/s]

Epoch: 047 | Loss: 0.01333 | Tr_c: 0.98387 | Vl_c: 0.64833 | Time: 0.024 min


9it [00:01,  6.37it/s]                       
2it [00:00, 24.76it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.40it/s]

Epoch: 048 | Loss: 0.01323 | Tr_c: 0.98303 | Vl_c: 0.67788 | Time: 0.024 min


9it [00:01,  6.36it/s]                       
2it [00:00, 24.81it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.33it/s]

Epoch: 049 | Loss: 0.01116 | Tr_c: 0.98694 | Vl_c: 0.71147 | Time: 0.024 min


9it [00:01,  6.32it/s]                       
2it [00:00, 24.69it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.31it/s]

Epoch: 050 | Loss: 0.00934 | Tr_c: 0.99216 | Vl_c: 0.65961 | Time: 0.024 min


9it [00:01,  6.29it/s]                       
2it [00:00, 25.01it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.29it/s]

Epoch: 051 | Loss: 0.00860 | Tr_c: 0.99367 | Vl_c: 0.68746 | Time: 0.024 min


9it [00:01,  6.35it/s]                       
2it [00:00, 24.60it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.30it/s]

Epoch: 052 | Loss: 0.00704 | Tr_c: 0.99581 | Vl_c: 0.68204 | Time: 0.024 min


9it [00:01,  6.38it/s]                       
2it [00:00, 25.18it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.40it/s]

Epoch: 053 | Loss: 0.00600 | Tr_c: 0.99754 | Vl_c: 0.66896 | Time: 0.024 min


9it [00:01,  6.36it/s]                       
2it [00:00, 24.62it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.32it/s]

Epoch: 054 | Loss: 0.00534 | Tr_c: 0.99834 | Vl_c: 0.68204 | Time: 0.024 min


9it [00:01,  6.35it/s]                       
2it [00:00, 24.79it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.33it/s]

Epoch: 055 | Loss: 0.00471 | Tr_c: 0.99951 | Vl_c: 0.68340 | Time: 0.024 min


9it [00:01,  6.33it/s]                       
2it [00:00, 24.95it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.34it/s]

Epoch: 056 | Loss: 0.00436 | Tr_c: 1.00000 | Vl_c: 0.69670 | Time: 0.024 min


9it [00:01,  6.33it/s]                       
2it [00:00, 24.72it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.43it/s]

Epoch: 057 | Loss: 0.00397 | Tr_c: 0.99959 | Vl_c: 0.67405 | Time: 0.024 min


9it [00:01,  6.33it/s]                       
2it [00:00, 24.57it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.31it/s]

Epoch: 058 | Loss: 0.00350 | Tr_c: 1.00000 | Vl_c: 0.67438 | Time: 0.024 min


9it [00:01,  6.35it/s]                       
2it [00:00, 24.13it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.39it/s]

Epoch: 059 | Loss: 0.00317 | Tr_c: 1.00000 | Vl_c: 0.68340 | Time: 0.024 min


9it [00:01,  6.38it/s]                       
2it [00:00, 24.89it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.36it/s]

Epoch: 060 | Loss: 0.00293 | Tr_c: 1.00000 | Vl_c: 0.67596 | Time: 0.024 min


9it [00:01,  6.35it/s]                       
2it [00:00, 24.87it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.36it/s]

Epoch: 061 | Loss: 0.00254 | Tr_c: 1.00000 | Vl_c: 0.68001 | Time: 0.024 min


9it [00:01,  6.37it/s]                       
2it [00:00, 24.76it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.23it/s]

Epoch: 062 | Loss: 0.00230 | Tr_c: 1.00000 | Vl_c: 0.67246 | Time: 0.024 min


9it [00:01,  6.35it/s]                       
2it [00:00, 24.77it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.32it/s]

Epoch: 063 | Loss: 0.00206 | Tr_c: 1.00000 | Vl_c: 0.67224 | Time: 0.024 min


9it [00:01,  6.36it/s]                       
2it [00:00, 25.06it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.31it/s]

Epoch: 064 | Loss: 0.00183 | Tr_c: 1.00000 | Vl_c: 0.67596 | Time: 0.024 min


9it [00:01,  6.31it/s]                       
2it [00:00, 24.66it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.30it/s]

Epoch: 065 | Loss: 0.00175 | Tr_c: 1.00000 | Vl_c: 0.68532 | Time: 0.024 min


9it [00:01,  6.31it/s]                       
2it [00:00, 24.95it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.37it/s]

Epoch: 066 | Loss: 0.00165 | Tr_c: 1.00000 | Vl_c: 0.67607 | Time: 0.024 min


9it [00:01,  6.36it/s]                       
2it [00:00, 24.24it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.34it/s]

Epoch: 067 | Loss: 0.00155 | Tr_c: 1.00000 | Vl_c: 0.67777 | Time: 0.024 min


9it [00:01,  6.39it/s]                       
2it [00:00, 25.42it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.31it/s]

Epoch: 068 | Loss: 0.00136 | Tr_c: 1.00000 | Vl_c: 0.69468 | Time: 0.024 min


9it [00:01,  6.34it/s]                       
2it [00:00, 24.26it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.33it/s]

Epoch: 069 | Loss: 0.00134 | Tr_c: 1.00000 | Vl_c: 0.68532 | Time: 0.024 min


9it [00:01,  6.29it/s]                       
2it [00:00, 24.26it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.27it/s]

Epoch: 070 | Loss: 0.00120 | Tr_c: 1.00000 | Vl_c: 0.67055 | Time: 0.024 min


9it [00:01,  6.30it/s]                       
2it [00:00, 24.81it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.31it/s]

Epoch: 071 | Loss: 0.00107 | Tr_c: 1.00000 | Vl_c: 0.69320 | Time: 0.024 min


9it [00:01,  6.28it/s]                       
2it [00:00, 25.04it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.37it/s]

Epoch: 072 | Loss: 0.00103 | Tr_c: 1.00000 | Vl_c: 0.68160 | Time: 0.024 min


9it [00:01,  6.32it/s]                       
2it [00:00, 24.77it/s]               
 12%|█▎        | 1/8 [00:00<00:01,  6.33it/s]

Epoch: 073 | Loss: 0.00094 | Tr_c: 1.00000 | Vl_c: 0.68362 | Time: 0.024 min


9it [00:01,  6.38it/s]                       
2it [00:00, 24.85it/s]               

Epoch: 074 | Loss: 0.00087 | Tr_c: 1.00000 | Vl_c: 0.69287 | Time: 0.024 min





#### Extracting the predictions of the model for validation set

In [0]:
def get_predictions_given_text(abstract):
    global modeler, all_class_to_id, word_index
    
    text_id = np.asarray([word_index.get(word, 0) for word in abstract.split()])
    _x = torch.tensor(text_id, dtype=torch.long, device=device).unsqueeze(0)
    y_pred = modeler.predict(_x, device)
    classes = [all_class[i] for i,x in enumerate(y_pred.round().int().squeeze()) if x.item() == 1]
    terms = get_list_for_terms(classes)
    return terms

#### Getting the abstracts of validation data

In [0]:
val_data =  tokenizer.sequences_to_texts(val_x)

In [0]:
y = (np.asarray(val_y))

In [0]:
all_class[0]

'http://id.nlm.nih.gov/mesh/D008919'

In [0]:
targets[0]

'Chemically-Induced Disorders'

In [0]:
y_actual[1][0][0]

'http://id.nlm.nih.gov/mesh/D009358'

In [0]:
y_actua = []
for result in val_y:
    idx = np.where(result==1)
    y_actua.append(np.take((all_class), idx))

In [0]:
y_actual = []
#a = np.array(a.tolist())
for nd in y_actua:
  y_actual.append(np.array(nd.tolist()))

In [0]:
labels = [get_list_for_terms(nd.tolist()) for nd in y_actual]

In [0]:
actual_labels = []
for i in y_actual:
  for j in i:
    #print(j[0])
    actual_labels.append(get_list_for_terms(j))
    

In [0]:
len(actual_labels)

155

In [0]:
predictions = [get_predictions_given_text(data_points) for index, data_points in enumerate(val_data)]

##### Saving the predictions in a file

In [0]:
string1 = val_data
string2 = predictions 
string3 = actual_labels

with open ("results_bilstm_superclass.csv","w") as f:
  writer = csv.writer(f)
  writer.writerows(zip(string1,string2,string3))
  
  
  
  

In [0]:
get_predictions_given_text(val_data[0]) 

['Neoplasms',
 'Immune System Diseases',
 'Cardiovascular Diseases',
 'Skin and Connective Tissue Diseases']

In [0]:
import ast



##### Checking the bias in predictions of superclasses

In [0]:
df = pandas.read_csv("results_bilstm_superclass.csv")
df.columns = ["abstract", "preds", "actual_labels"]

In [0]:
dfpred = df["preds"].tolist()
dfactual = df["actual_labels"].tolist()

In [0]:
 dumb = []
for i in y_actual[1]:
  dumb.append(get_list_for_terms(i))
  #for j in i:
   # for k in j:
      #dumb.append(get_list_for_terms(k))
    #  print(k)
print(dumb)
    #print(j[:5])

[['Congenital, Hereditary, and Neonatal Diseases and Abnormalities', 'Eye Diseases', 'Neoplasms']]


In [0]:
temp_list = []
for x in dfpred:
  x = ast.literal_eval(x)
  temp_list.append(x)

In [0]:
temp_list
#dfpred[0]#.strip('"')

In [160]:
preds1 = dict.fromkeys(all_class_dict, 0)
for x in range(len(temp_list)):
  for i in temp_list[x]:
    preds1[i]=preds1[i]+1
print(preds1)

{'Musculoskeletal Diseases': 21, 'Wounds and Injuries': 0, 'Hemic and Lymphatic Diseases': 9, 'Immune System Diseases': 13, 'Cell Physiological Phenomena': 0, 'Skin and Connective Tissue Diseases': 35, 'Physiological Phenomena': 2, 'Stomatognathic Diseases': 0, 'Behavioral Disciplines and Activities': 0, 'Eye Diseases': 11, 'Reproductive and Urinary Physiological Phenomena': 0, 'Musculoskeletal and Neural Physiological Phenomena': 0, 'Digestive System Diseases': 8, 'Parasitic Diseases': 0, 'Respiratory Tract Diseases': 0, 'Fluids and Secretions': 0, 'Nonsyndromic sensorineural hearing loss': 0, 'Diagnosis': 3, 'Genetic Phenomena': 1, 'Mental Disorders': 8, 'Nutritional and Metabolic Diseases': 44, 'Chemically-Induced Disorders': 0, 'Health Occupations': 1, 'Investigative Techniques': 1, 'Behavior and Behavior Mechanisms': 4, 'Cardiovascular Diseases': 10, 'Population Characteristics': 1, 'Endocrine System Diseases': 30, 'Cells': 0, 'Pathological Conditions, Signs and Symptoms': 49, 'Ti

In [0]:
all_class_dict = {k: v for v, k in enumerate(targets)}

In [0]:
temp_list1 = []
for x in dfactual:
  x = ast.literal_eval(x)
  temp_list1.append(x)

In [164]:
real = dict.fromkeys(all_class_dict, 0)
for x in range(len(temp_list1)):
  for i in temp_list1[x]:
    real[i]=real[i]+1
print(real)

{'Musculoskeletal Diseases': 22, 'Wounds and Injuries': 1, 'Hemic and Lymphatic Diseases': 12, 'Immune System Diseases': 21, 'Cell Physiological Phenomena': 3, 'Skin and Connective Tissue Diseases': 43, 'Physiological Phenomena': 6, 'Stomatognathic Diseases': 1, 'Behavioral Disciplines and Activities': 0, 'Eye Diseases': 17, 'Reproductive and Urinary Physiological Phenomena': 0, 'Musculoskeletal and Neural Physiological Phenomena': 0, 'Digestive System Diseases': 17, 'Parasitic Diseases': 0, 'Respiratory Tract Diseases': 1, 'Fluids and Secretions': 0, 'Nonsyndromic sensorineural hearing loss': 0, 'Diagnosis': 4, 'Genetic Phenomena': 4, 'Mental Disorders': 13, 'Nutritional and Metabolic Diseases': 48, 'Chemically-Induced Disorders': 1, 'Health Occupations': 4, 'Investigative Techniques': 4, 'Behavior and Behavior Mechanisms': 6, 'Cardiovascular Diseases': 22, 'Population Characteristics': 4, 'Endocrine System Diseases': 29, 'Cells': 0, 'Pathological Conditions, Signs and Symptoms': 43, 

In [0]:
from collections import Counter

real = Counter(real)
real.most_common()

In [0]:
preds1 = Counter(preds1)
preds1.most_common()