In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn, matmul
# from torch.nn.functional import softmaxm
from torchmetrics import Accuracy
import numpy as np
from sklearn.metrics import classification_report


import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

# Load Dataset

In [42]:
import pandas as pd
from pathlib import Path

usedcols = ['sentence', 'term1', 'term2']

df_caus = pd.read_csv(
    Path('..', 'data', 'crowd_truth_cause.csv'),
    sep=',', quotechar='"',
    skipinitialspace=True,
    encoding='utf-8',
    on_bad_lines='skip',
    usecols=usedcols
)
df_caus["is_cause"] = 1
df_caus["is_treat"] = 0
df_treat = pd.read_csv(
    Path('..', 'data', 'crowd_truth_treat.csv'),
    sep=',', quotechar='"',
    skipinitialspace=True,
    encoding='utf-8',
    on_bad_lines='skip',
    usecols=usedcols
)
df_treat["is_treat"] = 1
df_treat["is_cause"] = 0
df = df_caus.append(df_treat, ignore_index=True)
df


  df = df_caus.append(df_treat, ignore_index=True)


Unnamed: 0,term1,term2,sentence,is_cause,is_treat
0,AUTISM,TANTRUM,"The limited data suggest that, in children wit...",1,0
1,SLEEP PROBLEM,FAMILY STRESS,SLEEP PROBLEMs are associated with difficult b...,1,0
2,CEREBELLAR ATAXIA,DYSFUNCTION OF THE CEREBELLUM,The term CEREBELLAR ATAXIA is employed to indi...,1,0
3,CEREBELLAR DEGENERATION,CHRONIC ETHANOL ABUSE,Non hereditary causes of cerebellar degenerati...,1,0
4,HEART PROBLEM,ARTHRITIS,The disorder can present with a migratory ture...,1,0
...,...,...,...,...,...
7963,PARKINSON'S DISEASE,AMANTADINE,A 61 year old man with PARKINSON'S DISEASE (PD...,0,1
7964,DEPRESSION,IMIPRAMINE,With successful treatment of the patient's dep...,0,1
7965,ANGI,BEPRIDIL,Five of 15 patients receiving bepridil did not...,0,1
7966,HEMOPHILIA A,FACTOR VIII,The development of antibodies to factor VIII i...,0,1


In [162]:
# df.columns.values
# print(os.getcwd())
# file = "../data/crowd_truth_combined.csv"
# df.to_csv(file)


# Preprocessing

In [25]:
# Make case insensitive (no loss because emphasis on words does not play a role)
df['sentence'] = df['sentence'].map(lambda x: x.lower())

# Replace entities in sentence with placeholder tokens (may be useful for generalization when using n-grams)
df['sentence'] = df.apply(lambda x: x['sentence'].replace(x['term1'].lower(), 'TERMONE').replace('TERMONEs', 'TERMONE'), axis=1)
df['sentence'] = df.apply(lambda x: x['sentence'].replace(x['term2'].lower(), 'TERMTWO').replace('TERMTWOs', 'TERMTWO'), axis=1)

for i in range(5):
    print(df['sentence'][i])

df = df[df['sentence'].apply(lambda x: 'TERMONE' in x and 'TERMTWO' in x)]

print(f"Number of docs: {len(df)}")

the limited data suggest that, in children with mental retardation, TERMONE is associated with aggression, destruction of property, and TERMTWO.
TERMONE are associated with difficult behaviors and TERMTWO, and are often a focus of clinical attention over and above the primary asd diagnosis.
the term TERMONE is employed to indicate ataxia that is due to TERMTWO
non hereditary causes of TERMONE include TERMTWO, paraneoplastic TERMONE, high altitude cerebral oedema, coeliac disease, normal pressure hydrocephalus and cerebellitis.
the disorder can present with a migratory ture of TERMTWO with many other features like TERMONE, skin rash, gait abnormality and skin nodules.
Number of docs: 7821


In [26]:
# Convert labels to right dtype
label_cols = ['is_cause', 'is_treat']
df['is_cause'] = df['is_cause'].astype(float).astype(int)
df['is_treat'] = df['is_treat'].astype(float).astype(int)
df[label_cols].head()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['is_cause'] = df['is_cause'].astype(float).astype(int)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['is_treat'] = df['is_treat'].astype(float).astype(int)


Unnamed: 0,is_cause,is_treat
0,1,0
1,1,0
2,1,0
3,1,0
4,1,0


In [27]:
import nltk 
from nltk import RegexpTokenizer
nltk.download('punkt') # for tokanization
nltk.download('stopwords') # for stopword removal

# Tokenize the sentences
tokenizer = RegexpTokenizer(r'\w+')
df['tokens'] = df['sentence'].apply(lambda x: tokenizer.tokenize(x))
# Remove stop words and tokens with length smaller than 2 (i.e. punctuations)
df['tokens'] = df['tokens'].apply(lambda x: [token for token in x if token not in nltk.corpus.stopwords.words('english') and len(token) > 1])
# Perform stemming
porter = nltk.PorterStemmer()
df['tokens_stem'] = df['tokens'].apply(lambda x: [porter.stem(token) for token in x])
for i in range(5):
    print(df['tokens_stem'][i])

[nltk_data] Downloading package punkt to /Users/holu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /Users/holu/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['tokens'] = df['sentence'].apply(lambda x: tokenizer.tokenize(x))
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['tokens'] = df['tokens'].apply(lambda x: [token for token in x if token not in nltk.corpus.stopwords.words('english') and len(token) > 1])


['limit', 'data', 'suggest', 'children', 'mental', 'retard', 'termon', 'associ', 'aggress', 'destruct', 'properti', 'termtwo']
['termon', 'associ', 'difficult', 'behavior', 'termtwo', 'often', 'focu', 'clinic', 'attent', 'primari', 'asd', 'diagnosi']
['term', 'termon', 'employ', 'indic', 'ataxia', 'due', 'termtwo']
['non', 'hereditari', 'caus', 'termon', 'includ', 'termtwo', 'paraneoplast', 'termon', 'high', 'altitud', 'cerebr', 'oedema', 'coeliac', 'diseas', 'normal', 'pressur', 'hydrocephalu', 'cerebel']
['disord', 'present', 'migratori', 'ture', 'termtwo', 'mani', 'featur', 'like', 'termon', 'skin', 'rash', 'gait', 'abnorm', 'skin', 'nodul']


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['tokens_stem'] = df['tokens'].apply(lambda x: [porter.stem(token) for token in x])


In [28]:
# Dependencies for WorNetLemmatizer
nltk.download('wordnet')
nltk.download('omw-1.4')

# Perform lemmatization
lemmatizer = nltk.stem.WordNetLemmatizer()
df['tokens_lemma'] = df['tokens_stem'].apply(lambda x: [lemmatizer.lemmatize(token) for token in x])
for i in range(10):
    print(df['tokens_lemma'][i])

['limit', 'data', 'suggest', 'child', 'mental', 'retard', 'termon', 'associ', 'aggress', 'destruct', 'properti', 'termtwo']
['termon', 'associ', 'difficult', 'behavior', 'termtwo', 'often', 'focu', 'clinic', 'attent', 'primari', 'asd', 'diagnosi']
['term', 'termon', 'employ', 'indic', 'ataxia', 'due', 'termtwo']
['non', 'hereditari', 'caus', 'termon', 'includ', 'termtwo', 'paraneoplast', 'termon', 'high', 'altitud', 'cerebr', 'oedema', 'coeliac', 'diseas', 'normal', 'pressur', 'hydrocephalu', 'cerebel']
['disord', 'present', 'migratori', 'ture', 'termtwo', 'mani', 'featur', 'like', 'termon', 'skin', 'rash', 'gait', 'abnorm', 'skin', 'nodul']
['featur', 'termtwo', 'includ', 'skin', 'rash', 'extrem', 'photosensit', 'hair', 'loss', 'kidney', 'problem', 'emotiol', 'labil', 'lung', 'fibrosi', 'termon']
['mani', 'individu', 'termtwo', 'also', 'suffer', 'termon', 'high', 'cholesterol', 'heart', 'diseas']
['traditiolli', 'termon', 'suggest', 'total', 'impair', 'languag', 'abil', 'termtwo', 'de

[nltk_data] Downloading package wordnet to /Users/holu/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /Users/holu/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['tokens_lemma'] = df['tokens_stem'].apply(lambda x: [lemmatizer.lemmatize(token) for token in x])


In [29]:
import spacy
import networkx as nx

nlp = spacy.load("en_core_web_sm")
    
doc = nlp(df['sentence'][0])

def shortest_dep_path(sentence):
    doc = nlp(sentence)
    edges = []
    for token in doc:
        for child in token.children:
            edges.append((
                '{0}'.format(token.lemma_),
                '{0}'.format(child.lemma_)))
    graph = nx.Graph(edges)
    entity1 = 'TERMONE'
    entity2 = 'TERMTWO'
    try:
        return nx.shortest_path(graph, source=entity1, target=entity2)
    except:
        return []

def remove_stop_words(tokens):
    return [x for x in tokens if x not in nltk.corpus.stopwords.words('english') and len(x) > 1]

df['sdp_tokens_lemma'] = df['sentence'].apply(lambda x: remove_stop_words(shortest_dep_path(x)))
# df['sdp_tokens_lemma'] = df['sdp_tokens_lemma'].map(lambda x: x.lower())
for i in range(10):
    print(f"Index: {i}")
    print(df['sentence'][i])
    print(df['tokens_lemma'][i])
    print(df['sdp_tokens_lemma'][i])

Index: 0
the limited data suggest that, in children with mental retardation, TERMONE is associated with aggression, destruction of property, and TERMTWO.
['limit', 'data', 'suggest', 'child', 'mental', 'retard', 'termon', 'associ', 'aggress', 'destruct', 'properti', 'termtwo']
['TERMONE', 'associate', 'destruction', 'TERMTWO']
Index: 1
TERMONE are associated with difficult behaviors and TERMTWO, and are often a focus of clinical attention over and above the primary asd diagnosis.
['termon', 'associ', 'difficult', 'behavior', 'termtwo', 'often', 'focu', 'clinic', 'attent', 'primari', 'asd', 'diagnosi']
['TERMONE', 'associate', 'behavior', 'TERMTWO']
Index: 2
the term TERMONE is employed to indicate ataxia that is due to TERMTWO
['term', 'termon', 'employ', 'indic', 'ataxia', 'due', 'termtwo']
['TERMONE', 'employ', 'indicate', 'TERMTWO']
Index: 3
non hereditary causes of TERMONE include TERMTWO, paraneoplastic TERMONE, high altitude cerebral oedema, coeliac disease, normal pressure hydro

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['sdp_tokens_lemma'] = df['sentence'].apply(lambda x: remove_stop_words(shortest_dep_path(x)))


### Relations Tokens Length

In [170]:
from collections import Counter
# relation token list lengths
Counter(df['sdp_tokens_lemma'].map(lambda x: len(x)))

Counter({4: 1706,
         2: 1469,
         0: 1472,
         3: 1944,
         6: 295,
         5: 840,
         8: 24,
         7: 69,
         9: 2})

In [171]:
# treat
treat_df = df.loc[df['is_treat'] == 1]
Counter(treat_df['sdp_tokens_lemma'].map(lambda x: len(x)))

Counter({4: 852, 2: 730, 0: 727, 3: 980, 5: 423, 6: 152, 8: 12, 7: 33, 9: 1})

In [172]:
# cause
cause_df = df.loc[df['is_cause'] == 1]
Counter(cause_df['sdp_tokens_lemma'].map(lambda x: len(x)))

Counter({4: 854, 2: 739, 0: 745, 3: 964, 6: 143, 5: 417, 8: 12, 7: 36, 9: 1})

In [14]:
# X = df['tokens_lemma'].to_numpy()
y = df[['is_cause', 'is_treat']].to_numpy()


In [13]:
unique_words = set()
longest_sentence = 0
for sentence in df['tokens']:
    current_sentence = 0
    for word in sentence:
        current_sentence += 1
        if word not in unique_words:
            unique_words.add(word)
        if current_sentence > longest_sentence:
            longest_sentence = current_sentence
print(len(unique_words))
print(longest_sentence)

11212
110


In [15]:
from keras.preprocessing.text import one_hot
import copy
from keras_preprocessing.sequence import pad_sequences

X_tmp = []
for sentence in df['tokens_lemma']:
    sen_tmp = []
    for token in sentence:
        sen_tmp.append(one_hot(token, len(unique_words)))
    X_tmp.append(sen_tmp)

X_tmp = pad_sequences(X_tmp, longest_sentence, padding='post')

X = copy.deepcopy(X_tmp)

In [17]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y,  test_size=0.8, random_state=1)
X_test, X_val, y_test, y_val= train_test_split(X_test, y_test, test_size=0.5, random_state=1) 


In [18]:
class TorchDataset(Dataset):
    def __init__(self, x, y):
        super().__init__()
        self.x = x
        self.y = y

    def get_x(self):
        return self.x

    def get_y(self):
        return self.y
            
    def __getitem__(self, idx):
        x = self.x[idx]
        y = self.y[idx]
        return (torch.tensor(x).float(), torch.tensor(y).long())
    
    def __len__(self):
        return len(self.x)

    def get_dataloader(self, batch_size=128, num_workers=0, shuffle=False):
        return DataLoader(self, batch_size=batch_size, drop_last=True, pin_memory=True, num_workers=num_workers, shuffle=shuffle)

In [19]:
X_train = TorchDataset(X_train, y_train)
X_test = TorchDataset(X_test, y_test)
X_val = TorchDataset(X_val, y_val)

dataloaders = { 
    'train': X_train.get_dataloader(batch_size=256, shuffle=True), 
    'test': X_test.get_dataloader(batch_size=128, shuffle=False), 
    'val': X_val.get_dataloader(batch_size=128, shuffle=False)
}

# Multi Head Attention Model

Based on read papers we will try to implement a model that uses multi head attention and positional encoding.

In [20]:
device = torch.device("cpu")
from torch_position_embedding import PositionEmbedding

class TorchModel(LightningModule):
    def __init__(self, learning_rate=1e-2) -> None:
        super().__init__()
        self.save_hyperparameters('learning_rate')

        self.wordEmbeddings = nn.Embedding(11212,110)
        self.positionEmbeddings = nn.Embedding(110,40)
        # self.positionEmbeddings = PositionEmbedding(num_embeddings=11212, embedding_dim=110, mode=PositionEmbedding.MODE_ADD)
        self.transformerLayer = nn.TransformerEncoderLayer(150,15) #this transofrmer contains muti head attention
        self.linear1 = nn.Linear(150, 64)
        self.linear2 = nn.Linear(64, 1)
        self.linear3 = nn.Linear(110,  16)
        self.linear4 = nn.Linear(16, 2)
           
    def forward(self, x):
        positions = (torch.arange(0,110).reshape(1,110) + torch.zeros(x.shape[0],110)).to(device)
        sentence = torch.cat((self.wordEmbeddings(x.long()).squeeze(2),self.positionEmbeddings(positions.long())),axis=2)
        attended = self.transformerLayer(sentence)
        linear1 = F.relu(self.linear1(attended))
        linear2 = torch.sigmoid(self.linear2(linear1))
        linear2 = linear2.view(-1,110) # reshaping the layer as the transformer outputs a 2d tensor (or 3d considering the batch size)
        linear3 = F.relu(self.linear3(linear2))
        out = torch.sigmoid(self.linear4(linear3))
        return out
    
    def _loss_fn(self, out, y):
        loss = F.binary_cross_entropy(out, y) # Multiclass classification
        return loss
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        out = out.squeeze()
        loss = self._loss_fn(out, y.float())
        self.log('loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)      
        return loss

    def test_step(self, batch, batch_idx):
        print("TEST DATA")
        with torch.no_grad():
            x, y = batch
            out = self(x)
            out = out.squeeze()
            loss = self._loss_fn(out, y.float())
            report = classification_report(np.argmax(y, axis=1),np.argmax(out, axis=1),target_names=['is_cause', 'is_treat'])
            print(report)
            
    
    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            x, y = batch
            out = self(x)
            out = out.squeeze()
            loss = self._loss_fn(out, y.float())
            self.log('val_loss', loss, on_epoch=True, prog_bar=True, logger=True)
    
    def configure_optimizers(self):
        return torch.optim.Adagrad(
            self.parameters(), lr=self.hparams.learning_rate)
    

        

In [21]:
class TorchTrainer():
    def __init__(self, model, name, dirpath, dataloaders, max_epochs=50) -> None:
        self.model = model
        self.name = name
        self.dirpath = dirpath
        self.max_epochs = max_epochs
        self.dataloaders = dataloaders

    def run(self):
        logger = TensorBoardLogger(f"{self.dirpath}/tensorboard", name=self.name)
        callbacks = [
            ModelCheckpoint(dirpath=Path(self.dirpath, self.name), monitor="val_loss"),
            EarlyStopping(monitor='loss')
            ]
        trainer = Trainer(deterministic=True, logger=logger, callbacks=callbacks, max_epochs=self.max_epochs)
        trainer.fit(self.model, self.dataloaders['train'], self.dataloaders['val'])
        return trainer



In [22]:
model = TorchModel()

trainer = TorchTrainer(model, 'test', "../tuwnlpie/milestone2/lightning_logs/version_0/checkpoints/" , dataloaders, max_epochs=10)

  rank_zero_deprecation(


In [182]:
if not torch.backends.mps.is_available():
    print("MPS not available")
else:
    mps_device = torch.device("mps")
    model.to(mps_device)

In [23]:
the_trainer = trainer.run()

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name               | Type                    | Params
---------------------------------------------------------------
0 | wordEmbeddings     | Embedding               | 1.2 M 
1 | positionEmbeddings | Embedding               | 4.4 K 
2 | transformerLayer   | TransformerEncoderLayer | 707 K 
3 | linear1            | Linear                  | 9.7 K 
4 | linear2            | Linear                  | 65    
5 | linear3            | Linear                  | 1.8 K 
6 | linear4            | Linear                  | 34    
---------------------------------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
7.828     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


In [184]:
the_trainer.test(model, dataloaders['test'], ckpt_path="best")


Restoring states from the checkpoint path at /Users/holu/Documents/project-1div7/tuwnlpie/milestone2/lightning_logs/version_0/checkpoints/test/epoch=2-step=18-v2.ckpt
Loaded model weights from checkpoint at /Users/holu/Documents/project-1div7/tuwnlpie/milestone2/lightning_logs/version_0/checkpoints/test/epoch=2-step=18-v2.ckpt
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


TEST DATA
              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        59
    is_treat       0.54      1.00      0.70        69

    accuracy                           0.54       128
   macro avg       0.27      0.50      0.35       128
weighted avg       0.29      0.54      0.38       128

TEST DATA


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        59
    is_treat       0.54      1.00      0.70        69

    accuracy                           0.54       128
   macro avg       0.27      0.50      0.35       128
weighted avg       0.29      0.54      0.38       128

TEST DATA
              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        59
    is_treat       0.54      1.00      0.70        69

    accuracy                           0.54       128
   macro avg       0.27      0.50      0.35       128
weighted avg       0.29      0.54      0.38       128

TEST DATA


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        64
    is_treat       0.50      1.00      0.67        64

    accuracy                           0.50       128
   macro avg       0.25      0.50      0.33       128
weighted avg       0.25      0.50      0.33       128

TEST DATA
              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        65
    is_treat       0.49      1.00      0.66        63

    accuracy                           0.49       128
   macro avg       0.25      0.50      0.33       128
weighted avg       0.24      0.49      0.32       128

TEST DATA


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        65
    is_treat       0.49      1.00      0.66        63

    accuracy                           0.49       128
   macro avg       0.25      0.50      0.33       128
weighted avg       0.24      0.49      0.32       128

TEST DATA
              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        63
    is_treat       0.51      1.00      0.67        65

    accuracy                           0.51       128
   macro avg       0.25      0.50      0.34       128
weighted avg       0.26      0.51      0.34       128

TEST DATA


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        54
    is_treat       0.58      1.00      0.73        74

    accuracy                           0.58       128
   macro avg       0.29      0.50      0.37       128
weighted avg       0.33      0.58      0.42       128

TEST DATA
              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        57
    is_treat       0.55      1.00      0.71        71

    accuracy                           0.55       128
   macro avg       0.28      0.50      0.36       128
weighted avg       0.31      0.55      0.40       128

TEST DATA


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        64
    is_treat       0.50      1.00      0.67        64

    accuracy                           0.50       128
   macro avg       0.25      0.50      0.33       128
weighted avg       0.25      0.50      0.33       128

TEST DATA
              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        62
    is_treat       0.52      1.00      0.68        66

    accuracy                           0.52       128
   macro avg       0.26      0.50      0.34       128
weighted avg       0.27      0.52      0.35       128

TEST DATA


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        55
    is_treat       0.57      1.00      0.73        73

    accuracy                           0.57       128
   macro avg       0.29      0.50      0.36       128
weighted avg       0.33      0.57      0.41       128

TEST DATA
              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        61
    is_treat       0.52      1.00      0.69        67

    accuracy                           0.52       128
   macro avg       0.26      0.50      0.34       128
weighted avg       0.27      0.52      0.36       128

TEST DATA


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        58
    is_treat       0.55      1.00      0.71        70

    accuracy                           0.55       128
   macro avg       0.27      0.50      0.35       128
weighted avg       0.30      0.55      0.39       128

TEST DATA
              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        63
    is_treat       0.51      1.00      0.67        65

    accuracy                           0.51       128
   macro avg       0.25      0.50      0.34       128
weighted avg       0.26      0.51      0.34       128

TEST DATA


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        64
    is_treat       0.50      1.00      0.67        64

    accuracy                           0.50       128
   macro avg       0.25      0.50      0.33       128
weighted avg       0.25      0.50      0.33       128

TEST DATA
              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        73
    is_treat       0.43      1.00      0.60        55

    accuracy                           0.43       128
   macro avg       0.21      0.50      0.30       128
weighted avg       0.18      0.43      0.26       128

TEST DATA


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        70
    is_treat       0.45      1.00      0.62        58

    accuracy                           0.45       128
   macro avg       0.23      0.50      0.31       128
weighted avg       0.21      0.45      0.28       128

TEST DATA
              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        66
    is_treat       0.48      1.00      0.65        62

    accuracy                           0.48       128
   macro avg       0.24      0.50      0.33       128
weighted avg       0.23      0.48      0.32       128

TEST DATA


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        66
    is_treat       0.48      1.00      0.65        62

    accuracy                           0.48       128
   macro avg       0.24      0.50      0.33       128
weighted avg       0.23      0.48      0.32       128

TEST DATA
              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        67
    is_treat       0.48      1.00      0.65        61

    accuracy                           0.48       128
   macro avg       0.24      0.50      0.32       128
weighted avg       0.23      0.48      0.31       128

TEST DATA


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        72
    is_treat       0.44      1.00      0.61        56

    accuracy                           0.44       128
   macro avg       0.22      0.50      0.30       128
weighted avg       0.19      0.44      0.27       128

TEST DATA
              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        72
    is_treat       0.44      1.00      0.61        56

    accuracy                           0.44       128
   macro avg       0.22      0.50      0.30       128
weighted avg       0.19      0.44      0.27       128

TEST DATA
              precision    recall  f1-score   support

    is_cause       0.00      0.00      0.00        73
    is_treat       0.43      1.00      0.60        55

    accuracy                           0.43       128
   macro avg       0.21      0.50      0.30       128
weighted avg       0.18      0.43      0.26       128


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[{}]

# Validation

In [185]:
the_trainer.validate(model=model, dataloaders=dataloaders['val'])

  rank_zero_warn(


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        val_loss            0.6935379505157471
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss': 0.6935379505157471}]

When doing this milestone we have tried different architectures and ideas based on other papers in nlp however they all ended up being quite insuffucient at predicting our labels. We propose looking at more complex models like BERT for further improvement.