In [1]:
import pandas as pd
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

from flair.data import Corpus
from flair.datasets import ColumnCorpus
from flair.data import Sentence
import numpy as np

from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


# Load data

In [19]:
df = pd.read_csv("outputs/gold_aspects_data.csv")

In [20]:
df.head(30)

Unnamed: 0.1,Unnamed: 0,words,sentence,sent_id,labels_y
0,0,'','' CHAPTER III THE CUMÆAN SIBYL A part of the ...,0,O
1,0,CHAPTER,'' CHAPTER III THE CUMÆAN SIBYL A part of the ...,0,O
2,0,III,'' CHAPTER III THE CUMÆAN SIBYL A part of the ...,0,O
3,0,THE,'' CHAPTER III THE CUMÆAN SIBYL A part of the ...,0,O
4,0,CUMÆAN,'' CHAPTER III THE CUMÆAN SIBYL A part of the ...,0,B-aspect
5,0,SIBYL,'' CHAPTER III THE CUMÆAN SIBYL A part of the ...,0,B-aspect
6,0,A,'' CHAPTER III THE CUMÆAN SIBYL A part of the ...,0,O
7,0,part,'' CHAPTER III THE CUMÆAN SIBYL A part of the ...,0,O
8,0,of,'' CHAPTER III THE CUMÆAN SIBYL A part of the ...,0,O
9,0,the,'' CHAPTER III THE CUMÆAN SIBYL A part of the ...,0,B-aspect


In [8]:
#df.drop(columns=["Caption", "Filename"], inplace=True)

In [4]:
df.rename(columns={"Aspect": "text", "labels_y": "bio"}, inplace=True)

# Make 5 train-test split files

In [6]:
from sklearn.model_selection import KFold

In [7]:
kf = KFold(n_splits= 5, shuffle = False)

In [8]:
splits = []
for split in kf.split(df):
    splits.append(split[0])

In [67]:
def to_file(filename, sents_df):
    with open(f"/home/tess/experiments/train_test_splits/{filename}.txt", "w") as f:
        for el in sents_df: #for every sentence
            for text, label in zip(el[1]["text"], el[1]["bio"]):
                print("{}\t{}".format(text, label), file=f)
            print("\n", file = f)

In [68]:
for num, split in enumerate(kf.split(df)):
    num += 1
    
    train = df.iloc[split[0]]
    train_filename = "train_" + str(num) #make filename for train split

    train_sentences = train.groupby("Filename")
    to_file(train_filename, train_sentences)
    
    test = df.iloc[split[1]]
    test_filename = "test_" + str(num)
    
    test_sentences = test.groupby("Filename")
    to_file(test_filename, test_sentences)

# Train sequence classifier 
* BERT
* MACBERTH

In [2]:
data_folder = '/home/tess/experiments/train_test_splits/'

In [3]:
columns = {0: "text", 1: "bio" }

In [16]:
# init a corpus using column format, data folder and the names of the train, dev and test files
corpus: Corpus = ColumnCorpus(data_folder, columns,
                              train_file='train_5.txt',
                              test_file='test_5.txt')

2023-12-08 15:13:05,857 Reading data from /home/tess/experiments/train_test_splits
2023-12-08 15:13:05,859 Train: /home/tess/experiments/train_test_splits/train_5.txt
2023-12-08 15:13:05,860 Dev: None
2023-12-08 15:13:05,861 Test: /home/tess/experiments/train_test_splits/test_5.txt
2023-12-08 15:13:07,564 No dev split found. Using 0% (i.e. 174 samples) of the train split as dev data


In [17]:
len(corpus.train) #sentences

1565

In [18]:
# 2. what label do we want to predict?
label_type = 'bio'

# 3. make the label dictionary from the corpus
label_dict = corpus.make_label_dictionary(label_type=label_type, add_unk=False)
print(label_dict)

# 4. initialize fine-tuneable transformer embeddings WITH document context
embeddings = TransformerWordEmbeddings(model='emanjavacas/MacBERTh',
                                       layers="-1", #ONLY USE THE LAST LAYER (embeddings)
                                       subtoken_pooling="first",
                                       fine_tune=True, #adapt model to data
                                       use_context=True, #document context is considered during the embedding process (surrounding words, ...)
                                       )

# 5. initialize bare-bones sequence tagger (no CRF, no RNN, no reprojection)
# EMBEDDINGS ARE DIRECTLY USED without any linear projection
tagger = SequenceTagger(hidden_size=256,
                        embeddings=embeddings,
                        tag_dictionary=label_dict,
                        tag_type='bio',
                        use_crf=False,
                        use_rnn=False,
                        reproject_embeddings=False,
                        )

# 6. initialize trainer
trainer = ModelTrainer(tagger, corpus)

# 7. run fine-tuning
trainer.fine_tune('resources/taggers/sota-ner-flert',
                  learning_rate=5.0e-6,
                  mini_batch_size=4,
                  mini_batch_chunk_size=1,  # remove this parameter to speed up computation if you have a big GPU
                  )

2023-12-08 15:13:08,940 Computing label dictionary. Progress:


0it [00:00, ?it/s]
1565it [00:00, 22372.56it/s]

2023-12-08 15:13:09,019 Dictionary created for label 'bio' with 1 values: aspect (seen 5565 times)
Dictionary with 1 tags: aspect





2023-12-08 15:13:12,488 SequenceTagger predicts: Dictionary with 5 tags: O, S-aspect, B-aspect, E-aspect, I-aspect
2023-12-08 15:13:12,496 ----------------------------------------------------------------------------------------------------
2023-12-08 15:13:12,498 Model: "SequenceTagger(
  (embeddings): TransformerWordEmbeddings(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30001, 768)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True

100%|██████████| 11/11 [00:01<00:00,  7.39it/s]

2023-12-08 15:14:25,318 DEV : loss 0.42506787180900574 - f1-score (micro avg)  0.4
2023-12-08 15:14:25,329 ----------------------------------------------------------------------------------------------------





2023-12-08 15:14:32,608 epoch 2 - iter 39/392 - loss 0.52966440 - time (sec): 7.28 - samples/sec: 783.44 - lr: 0.000005 - momentum: 0.000000
2023-12-08 15:14:39,620 epoch 2 - iter 78/392 - loss 0.50749951 - time (sec): 14.29 - samples/sec: 759.80 - lr: 0.000005 - momentum: 0.000000
2023-12-08 15:14:46,677 epoch 2 - iter 117/392 - loss 0.51507718 - time (sec): 21.34 - samples/sec: 762.53 - lr: 0.000005 - momentum: 0.000000
2023-12-08 15:14:53,865 epoch 2 - iter 156/392 - loss 0.52085682 - time (sec): 28.53 - samples/sec: 769.57 - lr: 0.000005 - momentum: 0.000000
2023-12-08 15:15:01,019 epoch 2 - iter 195/392 - loss 0.52403391 - time (sec): 35.69 - samples/sec: 762.77 - lr: 0.000005 - momentum: 0.000000
2023-12-08 15:15:08,121 epoch 2 - iter 234/392 - loss 0.52181147 - time (sec): 42.79 - samples/sec: 760.76 - lr: 0.000005 - momentum: 0.000000
2023-12-08 15:15:15,165 epoch 2 - iter 273/392 - loss 0.51955783 - time (sec): 49.83 - samples/sec: 770.03 - lr: 0.000005 - momentum: 0.000000
20

100%|██████████| 11/11 [00:01<00:00,  5.77it/s]

2023-12-08 15:15:38,970 DEV : loss 0.40169888734817505 - f1-score (micro avg)  0.5084
2023-12-08 15:15:38,980 ----------------------------------------------------------------------------------------------------





2023-12-08 15:15:46,028 epoch 3 - iter 39/392 - loss 0.44310533 - time (sec): 7.05 - samples/sec: 823.75 - lr: 0.000004 - momentum: 0.000000
2023-12-08 15:15:53,336 epoch 3 - iter 78/392 - loss 0.45103962 - time (sec): 14.35 - samples/sec: 767.99 - lr: 0.000004 - momentum: 0.000000
2023-12-08 15:16:00,614 epoch 3 - iter 117/392 - loss 0.46333794 - time (sec): 21.63 - samples/sec: 749.11 - lr: 0.000004 - momentum: 0.000000
2023-12-08 15:16:07,819 epoch 3 - iter 156/392 - loss 0.43512092 - time (sec): 28.84 - samples/sec: 753.27 - lr: 0.000004 - momentum: 0.000000
2023-12-08 15:16:15,073 epoch 3 - iter 195/392 - loss 0.43604739 - time (sec): 36.09 - samples/sec: 761.75 - lr: 0.000004 - momentum: 0.000000
2023-12-08 15:16:22,307 epoch 3 - iter 234/392 - loss 0.44692189 - time (sec): 43.32 - samples/sec: 773.28 - lr: 0.000004 - momentum: 0.000000
2023-12-08 15:16:29,167 epoch 3 - iter 273/392 - loss 0.43573674 - time (sec): 50.19 - samples/sec: 769.41 - lr: 0.000004 - momentum: 0.000000
20

100%|██████████| 11/11 [00:02<00:00,  5.24it/s]

2023-12-08 15:16:52,652 DEV : loss 0.3584291338920593 - f1-score (micro avg)  0.5891
2023-12-08 15:16:52,666 ----------------------------------------------------------------------------------------------------





2023-12-08 15:16:59,990 epoch 4 - iter 39/392 - loss 0.35856894 - time (sec): 7.32 - samples/sec: 744.91 - lr: 0.000004 - momentum: 0.000000
2023-12-08 15:17:07,291 epoch 4 - iter 78/392 - loss 0.36588293 - time (sec): 14.62 - samples/sec: 744.90 - lr: 0.000004 - momentum: 0.000000
2023-12-08 15:17:14,525 epoch 4 - iter 117/392 - loss 0.36684325 - time (sec): 21.86 - samples/sec: 755.42 - lr: 0.000004 - momentum: 0.000000
2023-12-08 15:17:21,649 epoch 4 - iter 156/392 - loss 0.37224226 - time (sec): 28.98 - samples/sec: 755.17 - lr: 0.000004 - momentum: 0.000000
2023-12-08 15:17:28,712 epoch 4 - iter 195/392 - loss 0.39283967 - time (sec): 36.04 - samples/sec: 758.27 - lr: 0.000004 - momentum: 0.000000
2023-12-08 15:17:35,706 epoch 4 - iter 234/392 - loss 0.39920439 - time (sec): 43.04 - samples/sec: 767.27 - lr: 0.000004 - momentum: 0.000000
2023-12-08 15:17:42,584 epoch 4 - iter 273/392 - loss 0.40691627 - time (sec): 49.91 - samples/sec: 780.29 - lr: 0.000004 - momentum: 0.000000
20

100%|██████████| 11/11 [00:01<00:00,  5.79it/s]

2023-12-08 15:18:05,515 DEV : loss 0.3526793420314789 - f1-score (micro avg)  0.5663
2023-12-08 15:18:05,525 ----------------------------------------------------------------------------------------------------





2023-12-08 15:18:12,518 epoch 5 - iter 39/392 - loss 0.32597997 - time (sec): 6.99 - samples/sec: 750.55 - lr: 0.000003 - momentum: 0.000000
2023-12-08 15:18:19,511 epoch 5 - iter 78/392 - loss 0.34093989 - time (sec): 13.98 - samples/sec: 763.71 - lr: 0.000003 - momentum: 0.000000
2023-12-08 15:18:26,483 epoch 5 - iter 117/392 - loss 0.39150981 - time (sec): 20.96 - samples/sec: 773.93 - lr: 0.000003 - momentum: 0.000000
2023-12-08 15:18:33,437 epoch 5 - iter 156/392 - loss 0.37178075 - time (sec): 27.91 - samples/sec: 785.61 - lr: 0.000003 - momentum: 0.000000
2023-12-08 15:18:40,339 epoch 5 - iter 195/392 - loss 0.37053194 - time (sec): 34.81 - samples/sec: 812.83 - lr: 0.000003 - momentum: 0.000000
2023-12-08 15:18:47,227 epoch 5 - iter 234/392 - loss 0.36432367 - time (sec): 41.70 - samples/sec: 809.35 - lr: 0.000003 - momentum: 0.000000
2023-12-08 15:18:54,190 epoch 5 - iter 273/392 - loss 0.35983767 - time (sec): 48.66 - samples/sec: 798.80 - lr: 0.000003 - momentum: 0.000000
20

100%|██████████| 11/11 [00:02<00:00,  4.17it/s]

2023-12-08 15:19:18,409 DEV : loss 0.32897722721099854 - f1-score (micro avg)  0.6027
2023-12-08 15:19:18,419 ----------------------------------------------------------------------------------------------------





2023-12-08 15:19:25,392 epoch 6 - iter 39/392 - loss 0.32966376 - time (sec): 6.97 - samples/sec: 832.64 - lr: 0.000003 - momentum: 0.000000
2023-12-08 15:19:32,547 epoch 6 - iter 78/392 - loss 0.36416378 - time (sec): 14.13 - samples/sec: 788.84 - lr: 0.000003 - momentum: 0.000000
2023-12-08 15:19:39,784 epoch 6 - iter 117/392 - loss 0.38000108 - time (sec): 21.36 - samples/sec: 785.24 - lr: 0.000003 - momentum: 0.000000
2023-12-08 15:19:47,347 epoch 6 - iter 156/392 - loss 0.36307655 - time (sec): 28.93 - samples/sec: 766.38 - lr: 0.000003 - momentum: 0.000000
2023-12-08 15:19:55,253 epoch 6 - iter 195/392 - loss 0.36309102 - time (sec): 36.83 - samples/sec: 747.08 - lr: 0.000003 - momentum: 0.000000
2023-12-08 15:20:02,839 epoch 6 - iter 234/392 - loss 0.36412692 - time (sec): 44.42 - samples/sec: 742.71 - lr: 0.000002 - momentum: 0.000000
2023-12-08 15:20:10,331 epoch 6 - iter 273/392 - loss 0.36196950 - time (sec): 51.91 - samples/sec: 742.86 - lr: 0.000002 - momentum: 0.000000
20

100%|██████████| 11/11 [00:01<00:00,  5.85it/s]

2023-12-08 15:20:34,727 DEV : loss 0.3244566321372986 - f1-score (micro avg)  0.6141
2023-12-08 15:20:34,738 ----------------------------------------------------------------------------------------------------





2023-12-08 15:20:42,268 epoch 7 - iter 39/392 - loss 0.35041957 - time (sec): 7.53 - samples/sec: 752.15 - lr: 0.000002 - momentum: 0.000000
2023-12-08 15:20:50,096 epoch 7 - iter 78/392 - loss 0.35293475 - time (sec): 15.36 - samples/sec: 724.38 - lr: 0.000002 - momentum: 0.000000
2023-12-08 15:20:57,265 epoch 7 - iter 117/392 - loss 0.33603914 - time (sec): 22.52 - samples/sec: 738.14 - lr: 0.000002 - momentum: 0.000000
2023-12-08 15:21:04,634 epoch 7 - iter 156/392 - loss 0.33678257 - time (sec): 29.89 - samples/sec: 750.63 - lr: 0.000002 - momentum: 0.000000
2023-12-08 15:21:12,107 epoch 7 - iter 195/392 - loss 0.33540040 - time (sec): 37.37 - samples/sec: 741.52 - lr: 0.000002 - momentum: 0.000000
2023-12-08 15:21:19,553 epoch 7 - iter 234/392 - loss 0.33001350 - time (sec): 44.81 - samples/sec: 737.16 - lr: 0.000002 - momentum: 0.000000
2023-12-08 15:21:26,904 epoch 7 - iter 273/392 - loss 0.33537618 - time (sec): 52.16 - samples/sec: 737.63 - lr: 0.000002 - momentum: 0.000000
20

100%|██████████| 11/11 [00:01<00:00,  5.60it/s]

2023-12-08 15:21:51,370 DEV : loss 0.32633787393569946 - f1-score (micro avg)  0.612
2023-12-08 15:21:51,381 ----------------------------------------------------------------------------------------------------





2023-12-08 15:21:58,933 epoch 8 - iter 39/392 - loss 0.36503467 - time (sec): 7.55 - samples/sec: 684.79 - lr: 0.000002 - momentum: 0.000000
2023-12-08 15:22:06,501 epoch 8 - iter 78/392 - loss 0.33253061 - time (sec): 15.12 - samples/sec: 720.91 - lr: 0.000002 - momentum: 0.000000
2023-12-08 15:22:14,077 epoch 8 - iter 117/392 - loss 0.31622426 - time (sec): 22.69 - samples/sec: 739.02 - lr: 0.000002 - momentum: 0.000000
2023-12-08 15:22:21,573 epoch 8 - iter 156/392 - loss 0.32177549 - time (sec): 30.19 - samples/sec: 728.09 - lr: 0.000001 - momentum: 0.000000
2023-12-08 15:22:29,116 epoch 8 - iter 195/392 - loss 0.32555810 - time (sec): 37.73 - samples/sec: 731.31 - lr: 0.000001 - momentum: 0.000000
2023-12-08 15:22:36,628 epoch 8 - iter 234/392 - loss 0.33366521 - time (sec): 45.24 - samples/sec: 728.18 - lr: 0.000001 - momentum: 0.000000
2023-12-08 15:22:43,765 epoch 8 - iter 273/392 - loss 0.33470553 - time (sec): 52.38 - samples/sec: 737.40 - lr: 0.000001 - momentum: 0.000000
20

100%|██████████| 11/11 [00:02<00:00,  5.37it/s]

2023-12-08 15:23:07,830 DEV : loss 0.3286391794681549 - f1-score (micro avg)  0.6183
2023-12-08 15:23:07,840 ----------------------------------------------------------------------------------------------------





2023-12-08 15:23:14,713 epoch 9 - iter 39/392 - loss 0.30917084 - time (sec): 6.87 - samples/sec: 766.99 - lr: 0.000001 - momentum: 0.000000
2023-12-08 15:23:22,161 epoch 9 - iter 78/392 - loss 0.28905659 - time (sec): 14.32 - samples/sec: 751.06 - lr: 0.000001 - momentum: 0.000000
2023-12-08 15:23:29,978 epoch 9 - iter 117/392 - loss 0.29830022 - time (sec): 22.13 - samples/sec: 747.67 - lr: 0.000001 - momentum: 0.000000
2023-12-08 15:23:37,417 epoch 9 - iter 156/392 - loss 0.30091799 - time (sec): 29.57 - samples/sec: 744.98 - lr: 0.000001 - momentum: 0.000000
2023-12-08 15:23:44,847 epoch 9 - iter 195/392 - loss 0.29783788 - time (sec): 37.00 - samples/sec: 737.14 - lr: 0.000001 - momentum: 0.000000
2023-12-08 15:23:52,156 epoch 9 - iter 234/392 - loss 0.31295937 - time (sec): 44.31 - samples/sec: 743.17 - lr: 0.000001 - momentum: 0.000000
2023-12-08 15:23:59,744 epoch 9 - iter 273/392 - loss 0.31267929 - time (sec): 51.90 - samples/sec: 735.91 - lr: 0.000001 - momentum: 0.000000
20

100%|██████████| 11/11 [00:02<00:00,  5.45it/s]

2023-12-08 15:24:23,730 DEV : loss 0.3304169476032257 - f1-score (micro avg)  0.6187
2023-12-08 15:24:23,742 ----------------------------------------------------------------------------------------------------





2023-12-08 15:24:31,458 epoch 10 - iter 39/392 - loss 0.30666172 - time (sec): 7.71 - samples/sec: 729.80 - lr: 0.000001 - momentum: 0.000000
2023-12-08 15:24:39,397 epoch 10 - iter 78/392 - loss 0.28872055 - time (sec): 15.65 - samples/sec: 705.83 - lr: 0.000000 - momentum: 0.000000
2023-12-08 15:24:46,577 epoch 10 - iter 117/392 - loss 0.30685347 - time (sec): 22.83 - samples/sec: 708.73 - lr: 0.000000 - momentum: 0.000000
2023-12-08 15:24:54,219 epoch 10 - iter 156/392 - loss 0.30065594 - time (sec): 30.47 - samples/sec: 715.53 - lr: 0.000000 - momentum: 0.000000
2023-12-08 15:25:01,630 epoch 10 - iter 195/392 - loss 0.31475251 - time (sec): 37.88 - samples/sec: 716.65 - lr: 0.000000 - momentum: 0.000000
2023-12-08 15:25:09,209 epoch 10 - iter 234/392 - loss 0.32372924 - time (sec): 45.46 - samples/sec: 718.81 - lr: 0.000000 - momentum: 0.000000
2023-12-08 15:25:16,962 epoch 10 - iter 273/392 - loss 0.31782745 - time (sec): 53.22 - samples/sec: 722.38 - lr: 0.000000 - momentum: 0.00

100%|██████████| 11/11 [00:01<00:00,  5.65it/s]


2023-12-08 15:25:42,231 DEV : loss 0.3353153169155121 - f1-score (micro avg)  0.6152
2023-12-08 15:25:43,204 ----------------------------------------------------------------------------------------------------
2023-12-08 15:25:43,208 Testing using last state of model ...


100%|██████████| 28/28 [00:04<00:00,  6.63it/s]

2023-12-08 15:25:47,480 
Results:
- F-score (micro) 0.6099
- F-score (macro) 0.6099
- Accuracy 0.4388

By class:
              precision    recall  f1-score   support

      aspect     0.6442    0.5791    0.6099      1466

   micro avg     0.6442    0.5791    0.6099      1466
   macro avg     0.6442    0.5791    0.6099      1466
weighted avg     0.6442    0.5791    0.6099      1466

2023-12-08 15:25:47,482 ----------------------------------------------------------------------------------------------------





{'test_score': 0.6099137931034482}

In [22]:
# make a sentence
sentence = Sentence('I saw a weeping willow while passing down the alleyway. The thorns of the roses reminded me of home, in Alabama.')

# predict aspect tags
tagger.predict(sentence)

In [23]:
print(sentence)

Sentence[24]: "I saw a weeping willow while passing down the alleyway. The thorns of the roses reminded me of home, in Alabama." → ["willow"/aspect, "roses"/aspect, "Alabama"/aspect]
