In [1]:
!python -m pip install transformers torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
[K     |████████████████████████████████| 5.8 MB 25.0 MB/s 
[?25hCollecting torchmetrics
  Downloading torchmetrics-0.11.0-py3-none-any.whl (512 kB)
[K     |████████████████████████████████| 512 kB 99.4 MB/s 
[?25hCollecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 29.6 MB/s 
[?25hCollecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 23.9 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers, torchmetrics
Successfully installed huggingface-hub-0.11.1 tokenizers-0.13.2 torchmetrics-0.11.0 transformers-4.25.1


In [None]:
# !python -m pip install torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [1]:
import torchmetrics

In [2]:
import transformers

In [3]:
import traceback
import csv

import pandas as pd


def write_tsv_dataframe(filepath, dataframe):
    """
        Stores `DataFrame` as tsv file

        Parameters
        ----------
        filepath : str
            Path to tsv file
        dataframe : pd.DataFrame
            DataFrame to store

        Raises
        ------
        IOError
            if the file can't be opened
    """
    try:
        dataframe.to_csv(filepath, encoding='utf-8', sep='\t', index=False, header=True, quoting=csv.QUOTE_NONE)
    except IOError:
        traceback.print_exc()


In [4]:
def combine_columns(df_arguments, df_labels):
    """Combines the two `DataFrames` on column `Argument ID`"""
    return pd.merge(df_arguments, df_labels, on='Argument ID')


In [5]:
def split_arguments(df_arguments):
    """Splits `DataFrame` by column `Usage` into `train`-, `validation`-, and `test`-arguments"""
    train_arguments = df_arguments.loc[df_arguments['Usage'] == 'train'].drop(['Usage'], axis=1).reset_index(drop=True)
    valid_arguments = df_arguments.loc[df_arguments['Usage'] == 'validation'].drop(['Usage'], axis=1).reset_index(drop=True)
    test_arguments = df_arguments.loc[df_arguments['Usage'] == 'test'].drop(['Usage'], axis=1).reset_index(drop=True)
    
    return train_arguments, valid_arguments, test_arguments


In [6]:
def create_dataframe_head(argument_ids, model_name):
    """
        Creates `DataFrame` usable to append predictions to it

        Parameters
        ----------
        argument_ids : list[str]
            First column of the resulting DataFrame
        model_name : str
            Second column of DataFrame will contain the given model name

        Returns
        -------
        pd.DataFrame
            prepared DataFrame
    """
    df_model_head = pd.DataFrame(argument_ids, columns=['Argument ID'])
    df_model_head['Method'] = [model_name] * len(argument_ids)

    return df_model_head


In [7]:
import json
class MissingColumnError(AttributeError):
    """Error indicating that an imported DataFrame lacks necessary columns"""
    pass


In [8]:
def load_json_file(filepath):
    """Load content of json-file from `filepath`"""
    with open(filepath, 'r') as  json_file:
        return json.load(json_file)


In [9]:
def load_values_from_json(filepath):
    """Load values per level from json-file from `filepath`"""
    json_values = load_json_file(filepath)
    values = { "1":set(), "2":set(), "3":set(), "4a":set(), "4b":set() }
    for value in json_values["values"]:
        values["1"].add(value["name"])
        values["2"].add(value["level2"])
        for valueLevel3 in value["level3"]:
            values["3"].add(valueLevel3)
        for valueLevel4a in value["level4a"]:
            values["4a"].add(valueLevel4a)
        for valueLevel4b in value["level4b"]:
            values["4b"].add(valueLevel4b)
    values["1"] = sorted(values["1"])
    values["2"] = sorted(values["2"])
    values["3"] = sorted(values["3"])
    values["4a"] = sorted(values["4a"])
    values["4b"] = sorted(values["4b"])
    return values


In [10]:
def load_arguments_from_tsv(filepath, default_usage='test'):
    """
        Reads arguments from tsv file

        Parameters
        ----------
        filepath : str
            The path to the tsv file
        default_usage : str, optional
            The default value if the column "Usage" is missing

        Returns
        -------
        pd.DataFrame
            the DataFrame with all arguments

        Raises
        ------
        MissingColumnError
            if the required columns "Argument ID" or "Premise" are missing in the read data
        IOError
            if the file can't be read
        """
    try:
        dataframe = pd.read_csv(filepath, encoding='utf-8', sep='\t', header=0)
        if not {'Argument ID', 'Premise'}.issubset(set(dataframe.columns.values)):
            raise MissingColumnError('The argument "%s" file does not contain the minimum required columns [Argument ID, Premise].' % filepath)
        if 'Usage' not in dataframe.columns.values:
            dataframe['Usage'] = [default_usage] * len(dataframe)
        return dataframe
    except IOError:
        traceback.print_exc()
        raise


In [11]:
def load_labels_from_tsv(filepath, label_order):
    """
        Reads label annotations from tsv file

        Parameters
        ----------
        filepath : str
            The path to the tsv file
        label_order : list[str]
            The listing and order of the labels to use from the read data

        Returns
        -------
        pd.DataFrame
            the DataFrame with the annotations

        Raises
        ------
        MissingColumnError
            if the required columns "Argument ID" or names from `label_order` are missing in the read data
        IOError
            if the file can't be read
        """
    try:
        dataframe = pd.read_csv(filepath, encoding='utf-8', sep='\t', header=0)
        dataframe = dataframe[['Argument ID'] + label_order]
        return dataframe
    except IOError:
        traceback.print_exc()
        raise
    except KeyError:
        raise MissingColumnError('The file "%s" does not contain the required columns for its level.' % filepath)


In [12]:
import sys
import getopt
import os

In [13]:
model_dir = 'models'
data_dir = 'data'

In [14]:
if not os.path.exists(model_dir):
    os.makedirs(model_dir)


In [15]:
argument_filepath = os.path.join(data_dir, 'arguments.tsv')
value_json_filepath = os.path.join(data_dir, 'values.json')


In [16]:
df_arguments = load_arguments_from_tsv(argument_filepath, default_usage='train')

In [17]:
values = load_values_from_json(value_json_filepath)
num_labels_Lv2 = len(values['2'])


In [18]:
df_arguments.keys()

Index(['Argument ID', 'Part', 'Usage', 'Conclusion', 'Stance', 'Premise'], dtype='object')

In [19]:
for ip in df_arguments['Argument ID']:
  #print(df_arguments['Stance'][ip])
  
  print(ip)


A01001
A01002
A01003
A01004
A01005
A01006
A01007
A01008
A01009
A01010
A01011
A01012
A01013
A01014
A01015
A01016
A01017
A01018
A01019
A01020
A02001
A02002
A02003
A02004
A02005
A02006
A02007
A02008
A02009
A02010
A02011
A02012
A02013
A02014
A02015
A02016
A02017
A02018
A02019
A02020
A03001
A03002
A03003
A03004
A03005
A03006
A03007
A03008
A03009
A03010
A03011
A03012
A03013
A03014
A03015
A03016
A03017
A03018
A03019
A03020
A04001
A04002
A04003
A04004
A04005
A04006
A04007
A04008
A04009
A04010
A04011
A04012
A04013
A04014
A04015
A04016
A04017
A04018
A04019
A04020
A05001
A05002
A05003
A05004
A05005
A05006
A05007
A05008
A05009
A05010
A05011
A05012
A05013
A05014
A05015
A05016
A05017
A05018
A05019
A05020
A05021
A05022
A05023
A05024
A05025
A05026
A05027
A05028
A05029
A05030
A05031
A05032
A05033
A05034
A05035
A05036
A05037
A05038
A05043
A05045
A05048
A05049
A05050
A05052
A05053
A05054
A05055
A05056
A05057
A05058
A05059
A05062
A05063
A05064
A05065
A05066
A05067
A05069
A05071
A05072
A05073
A05074
A05075

A19296
A19297
A19298
A19299
A19301
A19302
A19303
A19304
A19305
A19306
A19307
A19308
A19309
A19310
A19311
A19312
A19313
A19314
A19315
A19316
A19317
A19318
A19319
A19320
A19321
A19322
A19323
A19324
A19325
A19326
A19328
A19331
A19332
A19333
A19334
A19335
A19336
A19338
A19339
A19342
A19343
A19344
A19345
A19346
A19347
A19348
A19349
A19350
A19351
A19352
A19353
A19355
A19356
A19358
A19359
A19362
A19363
A19365
A19366
A19367
A19369
A19370
A19371
A19372
A19373
A19375
A19376
A19378
A19379
A19383
A19384
A19385
A19386
A19387
A19388
A19389
A19390
A19391
A19392
A19393
A19394
A19395
A19396
A19397
A19398
A19399
A19400
A19401
A19403
A19405
A19406
A19408
A19409
A19410
A19411
A19413
A19414
A19415
A19416
A19417
A19418
A19419
A19421
A19422
A19423
A19424
A19425
A19426
A19427
A19428
A19429
A19430
A19431
A19432
A19433
A19434
A19435
A19436
A19437
A19438
A19439
A19440
A19441
A19442
A19443
A19444
A19445
A19446
A19447
A19448
A19449
A19451
A19452
A19453
A19454
A19455
A19456
A19457
A19458
A19460
A19462
A19463
A19464

A23170
A23172
A23173
A23175
A23177
A23178
A23180
A23181
A23182
A23183
A23184
A23185
A23186
A23187
A23188
A23190
A23192
A23193
A23194
A23195
A23196
A23197
A23198
A23199
A23200
A23202
A23203
A23204
A23205
A23206
A23208
A23209
A23212
A23213
A23214
A23215
A23216
A23217
A23218
A23220
A23221
A23223
A23224
A23226
A23227
A23228
A23229
A23230
A23231
A23232
A23233
A23234
A23235
A23236
A23237
A23238
A23239
A23242
A23243
A23246
A23247
A23248
A23250
A23252
A23254
A23255
A23256
A23257
A23258
A23259
A23260
A23261
A23262
A23263
A23264
A23265
A23266
A23268
A23270
A23271
A23272
A23273
A23274
A23275
A23276
A23277
A23278
A23279
A23280
A23281
A23282
A23283
A23284
A23285
A23286
A23287
A23288
A23289
A23290
A23291
A23292
A23293
A23294
A23295
A23296
A23297
A23298
A23300
A23301
A23303
A23305
A23306
A23307
A23308
A23309
A23310
A23311
A23312
A23313
A23314
A23315
A23317
A23318
A23319
A23320
A23322
A23323
A23324
A23326
A23327
A23328
A23329
A23330
A23332
A23333
A23334
A23335
A23336
A23337
A23338
A23339
A23340
A23342

A23395
A23420
A23424
A23428
A23434
A23457
A23464
A23480
A23487
A24010
A24012
A24028
A24060
A24061
A24072
A24076
A24090
A24107
A24111
A24116
A24125
A24137
A24141
A24144
A24152
A24155
A24159
A24166
A24181
A24212
A24239
A24246
A24256
A24275
A24276
A24281
A24282
A24296
A24299
A24310
A24317
A24319
A24325
A24342
A24347
A24348
A24351
A24360
A24361
A24376
A24377
A24388
A24389
A24417
A24422
A24428
A24433
A24476
A24484
A24500
A25019
A25024
A25037
A25049
A25053
A25061
A25068
A25071
A25088
A25089
A25091
A25103
A25104
A25119
A25138
A25146
A25157
A25178
A25187
A25190
A25192
A25194
A25197
A25201
A25211
A25214
A25226
A25239
A25250
A25291
A25297
A25305
A25309
A25323
A25355
A25356
A25362
A25376
A25390
A25398
A25430
A25432
A25436
A25437
A25465
A25466
A25469
A25476
A25480
A25486
A25493
A25494
B28001
B28002
B28003
B28004
B28005
B28006
B28007
B28008
B28009
B28010
B28011
B28012
B28013
B28014
B28015
B28016
B28017
B28018
B28019
B28020
B28021
B28022
B28023
B28024
B28025
B28026
B28027
B28028
B28029
B28030
B28031

In [20]:
level =2
label_filepath = os.path.join(data_dir, 'labels-level{}.tsv'.format(str(level)))
df_labels = load_labels_from_tsv(label_filepath, values[str(level)])

In [21]:
a = df_labels.keys()
for key in df_labels.keys():
  print(len(df_labels[key]),key)

5270 Argument ID
5270 Achievement
5270 Benevolence: caring
5270 Benevolence: dependability
5270 Conformity: interpersonal
5270 Conformity: rules
5270 Face
5270 Hedonism
5270 Humility
5270 Power: dominance
5270 Power: resources
5270 Security: personal
5270 Security: societal
5270 Self-direction: action
5270 Self-direction: thought
5270 Stimulation
5270 Tradition
5270 Universalism: concern
5270 Universalism: nature
5270 Universalism: objectivity
5270 Universalism: tolerance


In [22]:

df_labels['Achievement'][0]

0

In [23]:
from typing import Dict, List

In [24]:
def generate_pairwise_input(dataset, labels, datatype):
    """
    TODO: group all premises and corresponding hypotheses and labels of the datapoints
    a datapoint as seen earlier is a dict of premis, hypothesis and label
    """
    #raise NotImplementedError
    premise=[]
    conclusion=[]
    stance=[]
    n_labels =labels.keys()
    n_labels = n_labels[1:]
    print(n_labels)
    label=[]
    
    n = len(dataset['Argument ID'])
    m = len(labels['Argument ID'])
    arguments = []
    print(n,m)
    for i in range(n):
        if dataset['Usage'][i]==datatype:
          premise.append(dataset['Premise'][i])
          conclusion.append(dataset['Conclusion'][i])
          stance.append(dataset['Stance'][i])
          arguments.append(dataset['Argument ID'][i])
    for i in range(m):
        if (labels['Argument ID'][i] in arguments):
          sent_label = []
          #print(i)
          for l in range(len(n_labels)):
              #print(n_labels[l])
              sent_label.append(int(labels[n_labels[l]][i]))
          label.append(sent_label)

    return premise, conclusion, stance, label

In [55]:
train_premises, train_conclusion, train_stance, train_labels = generate_pairwise_input(df_arguments, df_labels, 'train')
val_premises, val_conclusion, val_stance, val_labels = generate_pairwise_input(df_arguments, df_labels, 'validation')
test_premises, test_conclusion, test_stance, test_labels = generate_pairwise_input(df_arguments, df_labels, 'test')

Index(['Achievement', 'Benevolence: caring', 'Benevolence: dependability',
       'Conformity: interpersonal', 'Conformity: rules', 'Face', 'Hedonism',
       'Humility', 'Power: dominance', 'Power: resources',
       'Security: personal', 'Security: societal', 'Self-direction: action',
       'Self-direction: thought', 'Stimulation', 'Tradition',
       'Universalism: concern', 'Universalism: nature',
       'Universalism: objectivity', 'Universalism: tolerance'],
      dtype='object')
5270 5270
Index(['Achievement', 'Benevolence: caring', 'Benevolence: dependability',
       'Conformity: interpersonal', 'Conformity: rules', 'Face', 'Hedonism',
       'Humility', 'Power: dominance', 'Power: resources',
       'Security: personal', 'Security: societal', 'Self-direction: action',
       'Self-direction: thought', 'Stimulation', 'Tradition',
       'Universalism: concern', 'Universalism: nature',
       'Universalism: objectivity', 'Universalism: tolerance'],
      dtype='object')
5270 5

In [26]:
# Nothing to do for this class!
import torch
from transformers import BertModel
from transformers import AutoTokenizer
from typing import Dict, List

class BatchTokenizer:
    """Tokenizes and pads a batch of input sentences."""

    def __init__(self):
        """Initializes the tokenizer

        Args:
            pad_symbol (Optional[str], optional): The symbol for a pad. Defaults to "<P>".
        """
        self.hf_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    
    def get_sep_token(self,):
        return self.hf_tokenizer.sep_token
    
    def __call__(self, prem_batch: List[str], hyp_batch: List[str], stance_batch: List[str]) -> List[List[str]]:
        """Uses the huggingface tokenizer to tokenize and pad a batch.

        We return a dictionary of tensors per the huggingface model specification.

        Args:
            batch (List[str]): A List of sentence strings

        Returns:
            Dict: The dictionary of token specifications provided by HuggingFace
        """
        # The HF tokenizer will PAD for us, and additionally combine 
        # The two sentences deimited by the [SEP] token.
        batch_len = len(prem_batch)
        #spaces = [" "]*batch_len
        conc_batch = [stance_batch[i]+" "+hyp_batch[i] for i in range(batch_len)]
        enc = self.hf_tokenizer(
            prem_batch,
            conc_batch,
            padding=True,
            return_token_type_ids=False,
            return_tensors='pt'
        )

        return enc
    

# HERE IS AN EXAMPLE OF HOW TO USE THE BATCH TOKENIZER
tokenizer = BatchTokenizer()
a = [["this is the premise.", "This is also a premise"], ["this is the hypothesis", "This is a second hypothesis"],["in favour of", "against"]]
x = tokenizer(*a)
print(x)
tokenizer.hf_tokenizer.batch_decode(x["input_ids"])



{'input_ids': tensor([[  101,  2023,  2003,  1996, 18458,  1012,   102,  1999,  7927,  1997,
          2023,  2003,  1996, 10744,   102],
        [  101,  2023,  2003,  2036,  1037, 18458,   102,  2114,  2023,  2003,
          1037,  2117, 10744,   102,     0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]])}


['[CLS] this is the premise. [SEP] in favour of this is the hypothesis [SEP]',
 '[CLS] this is also a premise [SEP] against this is a second hypothesis [SEP] [PAD]']

In [27]:
def chunk(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[:][i:i + n]

def chunk_multi(lst1, lst2, lst3, n):
    for i in range(0, len(lst1), n):
        yield lst1[i: i + n], lst2[i: i + n], lst3[i: i + n]
        


In [28]:
sum=0
import numpy as np
# for i in range(5270):
#   sum += np.sum(np.array(train_labels[:][i]))
print(np.sum(np.array(train_labels)))
#print(sum)

14675


In [56]:
# Notice that since we use huggingface, we tokenize and
# encode in all at once!
batch_size=64
tokenizer = BatchTokenizer()
train_input_batches = [b for b in chunk_multi(train_premises, train_conclusion, train_stance, batch_size)]
# Tokenize + encode
train_input_batches = [tokenizer(*batch) for batch in train_input_batches]

In [57]:
val_input_batches = [b for b in chunk_multi(val_premises, val_conclusion, val_stance, batch_size)]
# Tokenize + encode
val_input_batches = [tokenizer(*batch) for batch in val_input_batches]


In [31]:
len(val_labels)

277

In [32]:
label_ids = ['Achievement', 'Benevolence: caring', 'Benevolence: dependability',
       'Conformity: interpersonal', 'Conformity: rules', 'Face', 'Hedonism',
       'Humility', 'Power: dominance', 'Power: resources',
       'Security: personal', 'Security: societal', 'Self-direction: action',
       'Self-direction: thought', 'Stimulation', 'Tradition',
       'Universalism: concern', 'Universalism: nature',
       'Universalism: objectivity', 'Universalism: tolerance']
l_ids = []
for l in label_ids:
    a =[l]
    l_ids.append(a)


In [33]:
from transformers import BertTokenizer, BertModel
import torch
import numpy as np
embeddings = []
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
for i in range(20):
    input_ids = tokenizer(l_ids[i], return_tensors="pt")
    output = model(**input_ids)
    final_layer = output.last_hidden_state[:,0,:]
    #print(final_layer.shape)
    #hidden_shape = final_layer.shape
    #embed = torch.reshape(final_layer,(1,hidden_shape[1]) )
    embed = final_layer[0]
    embeddings.append(embed.detach().numpy())
embeddings = np.array(embeddings)
embeddings.shape

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


(20, 768)

In [34]:
from nltk.cluster import KMeansClusterer
import nltk

def clustering_question(data,NUM_CLUSTERS = 5):
    kclusterer = KMeansClusterer(
        NUM_CLUSTERS, distance=nltk.cluster.util.cosine_distance,
        repeats=1000,avoid_empty_clusters=True)

    assigned_clusters = kclusterer.cluster(data, assign_clusters=True)

    return assigned_clusters

In [35]:
clusters = clustering_question(embeddings)
clusters

[1, 3, 0, 0, 3, 1, 2, 1, 4, 4, 2, 2, 3, 3, 1, 1, 2, 0, 0, 0]

In [36]:
cluster_labels =[[],[],[],[],[]]
cluster_ids = [[],[],[],[],[]]
for i in range(20):
    cluster_labels[clusters[i]].append(label_ids[i])
    cluster_ids[clusters[i]].append(i)
cluster_ids, cluster_labels

([[2, 3, 17, 18, 19],
  [0, 5, 7, 14, 15],
  [6, 10, 11, 16],
  [1, 4, 12, 13],
  [8, 9]],
 [['Benevolence: dependability',
   'Conformity: interpersonal',
   'Universalism: nature',
   'Universalism: objectivity',
   'Universalism: tolerance'],
  ['Achievement', 'Face', 'Humility', 'Stimulation', 'Tradition'],
  ['Hedonism',
   'Security: personal',
   'Security: societal',
   'Universalism: concern'],
  ['Benevolence: caring',
   'Conformity: rules',
   'Self-direction: action',
   'Self-direction: thought'],
  ['Power: dominance', 'Power: resources']])

In [37]:
for i in range(20):
  print(i, label_ids[i])
print(np.nonzero(train_labels[0]), train_premises[0], train_labels[0])

0 Achievement
1 Benevolence: caring
2 Benevolence: dependability
3 Conformity: interpersonal
4 Conformity: rules
5 Face
6 Hedonism
7 Humility
8 Power: dominance
9 Power: resources
10 Security: personal
11 Security: societal
12 Self-direction: action
13 Self-direction: thought
14 Stimulation
15 Tradition
16 Universalism: concern
17 Universalism: nature
18 Universalism: objectivity
19 Universalism: tolerance
(array([11], dtype=int64),) if entrapment can serve to more easily capture wanted criminals, then why shouldn't it be legal? [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]


In [38]:
# 0 Self-direction: thought
# 1 Self-direction: action	
# 2 Stimulation	
# 3 Hedonism	
# 4 Achievement	
# 5 Power: dominance	
# 6 Power: resources	
# 7 Face	
# 8 Security: personal	
# 9 Security: societal	
# 10 Tradition	
# 11 Conformity: rules	
# 12 Conformity: interpersonal	
# 13 Humility	
# 14 Benevolence: caring	
# 15 Benevolence: dependability	
# 16 Universalism: concern	
# 17 Universalism: nature	
# 18 Universalism: tolerance	
# 19 Universalism: objectivity

In [39]:
# # Achievement, Face, Power: dominance, Power: resources [4, 7, 5, 6]

# #Benevolence: caring, Benevolence: dependability, Humility, Universalism: concern [14, 15, 13, 16]

# # Stimulation, Tradition, Self-direction: action, Self-direction: thought [2, 10, 1, 0]

# # Conformity: interpersonal, Conformity: rules, Security: personal, Security: societal [12, 11, 8, 9]

# # Hedonism, Universalism: nature, Universalism: objectivity, Universalism: tolerance [3, 17,18, 19]
# # Note: This is just one possible way to group these elements. There may be other valid ways to do so.


# clusters = [
#     [4, 7, 5, 6],
#     [14, 15, 13, 16],
#     [2, 10, 1, 0],
#     [12, 11, 8, 9],
#     [3, 17,18, 19]
# ]


In [40]:
# Achievement, Face, Power: dominance, Power: resources [0, 5, 8, 9]


#Benevolence: caring, Benevolence: dependability, Humility, Universalism: concern [1,2,7, 16]


# Stimulation, Tradition, Self-direction: action, Self-direction: thought [14, 15, 12, 13]


# Conformity: interpersonal, Conformity: rules, Security: personal, Security: societal [3, 4, 10, 11]


# Hedonism, Universalism: nature, Universalism: objectivity, Universalism: tolerance [6, 17,18, 19]
# Note: This is just one possible way to group these elements. There may be other valid ways to do so.


# clusters = [
#     [0, 5, 8, 9],
#     [1,2,7, 16],
#     [14, 15, 12, 13],
#     [3, 4, 10, 11],
#     [6, 17,18, 19]
# ]


In [41]:
def encode_labels(labels: List[List[int]]) -> torch.FloatTensor:
    """Turns the batch of labels into a tensor

    Args:
        labels (List[List[int]]): List of all labels in the batch

    Returns:
        torch.FloatTensor: Tensor of all labels in the batch
    """
    
    return torch.LongTensor(labels)


In [42]:
check_labels = encode_labels(train_labels[0:10])
np.nonzero(train_labels[9]), np.nonzero(check_labels[9])

((array([12, 13, 15, 16], dtype=int64),),
 tensor([[12],
         [13],
         [15],
         [16]]))

In [43]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [44]:
train_label_batches = [b for b in chunk(train_labels, batch_size)]
train_label_batches = [encode_labels(batch) for batch in train_label_batches]

In [45]:
val_label_batches = [b for b in chunk(val_labels, batch_size)]
val_label_batches = [encode_labels(batch) for batch in val_label_batches]

In [46]:
val_label_batches[0][0]

tensor([1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1])

In [47]:
class GroupedClassifier(torch.nn.Module):
    def __init__(self, output_size: int, hidden_size: int):
        super().__init__()
        self.output_size = output_size
        self.hidden_size = hidden_size
        # Initialize BERT, which we use instead of a single embedding layer.
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        # TODO [OPTIONAL]: Updating all BERT parameters can be slow and memory intensive. 
        # Freeze them if training is too slow. Notice that the learning
        # rate should probably be smaller in this case.
        # Uncommenting out the below 2 lines means only our classification layer will be updated.
        for param in self.bert.parameters():
            param.requires_grad = False
        self.bert_hidden_dimension = self.bert.config.hidden_size
        print(self.bert_hidden_dimension)
        # TODO: Add an extra hidden layer in the classifier, projecting
        #      from the BERT hidden dimension to hidden size.
        # TODO: Add a relu nonlinearity to be used in the forward method
        #      https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
        #self.middle_layer1 = torch.nn.Linear(self.bert_hidden_dimension, 64)
        #self.middle_layer2 = torch.nn.Linear(self.bert_hidden_dimension, 64)
        #self.middle_layer3 = torch.nn.Linear(self.bert_hidden_dimension, 64)
        #self.middle_layer4 = torch.nn.Linear(self.bert_hidden_dimension, 64)
        #self.middle_layer5 = torch.nn.Linear(self.bert_hidden_dimension, 64)
        #self.middle_layer6 = torch.nn.Linear(self.bert_hidden_dimension, 64)
        
        self.hidden_layer = torch.nn.Linear(self.bert_hidden_dimension, 512)
        self.hidden_layer1 = torch.nn.Linear(512, 32)
        self.hidden_layer2 = torch.nn.Linear(512, 32)
        self.hidden_layer3 = torch.nn.Linear(512, 32)
        self.hidden_layer4 = torch.nn.Linear(512, 32)
        self.hidden_layer5 = torch.nn.Linear(512, 32)
        self.hidden_layer6 = torch.nn.Linear(512, 32)
        self.hidden_layer7 = torch.nn.Linear(512, 32)
        self.hidden_layer8 = torch.nn.Linear(512, 32)
        self.hidden_layer9 = torch.nn.Linear(512, 32)
        self.hidden_layer10 = torch.nn.Linear(512, 32)
        self.hidden_layer11 = torch.nn.Linear(512, 32)
        self.hidden_layer12 = torch.nn.Linear(512, 32)
        self.hidden_layer13 = torch.nn.Linear(512, 32)
        self.hidden_layer14 = torch.nn.Linear(512, 32)
        self.hidden_layer15 = torch.nn.Linear(512, 32)
        self.hidden_layer16 = torch.nn.Linear(512, 32)
        self.hidden_layer17 = torch.nn.Linear(512, 32)
        self.hidden_layer18 = torch.nn.Linear(512, 32)
        self.hidden_layer19 = torch.nn.Linear(512, 32)
        self.hidden_layer20 = torch.nn.Linear(512, 32)
        #self.hidden_layer2 = torch.nn.Linear(self.hidden_size, 32)
        #self.hidden_layer3 = torch.nn.Linear(128, 32)
        #self.hidden_layer4 = torch.nn.Linear(32, 8)
        self.relu = torch.nn.ReLU()
        self.classifier1 = torch.nn.Linear(32, 1)
        self.classifier2 = torch.nn.Linear(32, 1)
        self.classifier3 = torch.nn.Linear(32, 1)
        self.classifier4 = torch.nn.Linear(32, 1)
        self.classifier5 = torch.nn.Linear(32, 1)
        self.classifier6 = torch.nn.Linear(32, 1)
        self.classifier7 = torch.nn.Linear(32, 1)
        self.classifier8 = torch.nn.Linear(32, 1)
        self.classifier9 = torch.nn.Linear(32, 1)
        self.classifier10 = torch.nn.Linear(32, 1)
        self.classifier11 = torch.nn.Linear(32, 1)
        self.classifier12 = torch.nn.Linear(32, 1)
        self.classifier13 = torch.nn.Linear(32, 1)
        self.classifier14 = torch.nn.Linear(32, 1)
        self.classifier15 = torch.nn.Linear(32, 1)
        self.classifier16 = torch.nn.Linear(32, 1)
        self.classifier17 = torch.nn.Linear(32, 1)
        self.classifier18 = torch.nn.Linear(32, 1)
        self.classifier19 = torch.nn.Linear(32, 1)
        self.classifier20 = torch.nn.Linear(32, 1)
        
        #self.classifier_middle1 = torch.nn.Linear(64, 1)
        #self.classifier_middle2 = torch.nn.Linear(64, 1)
        #self.classifier_middle3 = torch.nn.Linear(64, 1)
        #self.classifier_middle4 = torch.nn.Linear(64, 1)
        #self.classifier_middle5 = torch.nn.Linear(64, 1)
        #self.classifier = torch.nn.Linear(self.hidden_size, self.output_size)
        self.log_softmax = torch.nn.LogSoftmax(dim=2)

    def encode_text(
        self,
        symbols: Dict
    ) -> torch.Tensor:
        """Encode the (batch of) sequence(s) of token symbols with an LSTM.
            Then, get the last (non-padded) hidden state for each symbol and return that.

        Args:
            symbols (Dict): The Dict of token specifications provided by the HuggingFace tokenizer

        Returns:
            torch.Tensor: The final hiddens tate of the LSTM, which represents an encoding of
                the entire sentence
        """
        # First we get the contextualized embedding for each input symbol
        # We no longer need an LSTM, since BERT encodes context and 
        # gives us a single vector describing the sequence in the form of the [CLS] token.
        embedded = self.bert(**symbols)
        #print(embedded)
        #print("Embedded", embedded.pooler_output.shape, embedded.last_hidden_state.shape)
        # TODO: Get the [CLS] token using the `pooler_output` from 
        #      The BertModel output. See here: https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel
        #      and check the returns for the forward method.
        # We want to return a tensor of the form batch_size x 1 x bert_hidden_dimension
        #raise NotImplementedError
        
        #pool_output_shape = embedded.pooler_output.shape
        #return torch.reshape(embedded.pooler_output,(pool_output_shape[0],1,pool_output_shape[1]) )
        last_hidden_state = embedded.last_hidden_state[:,0,:]
        hidden_shape = last_hidden_state.shape
        return torch.reshape(last_hidden_state,(hidden_shape[0],1,hidden_shape[1]) )

    def forward(
        self,
        symbols: Dict,
    ) -> torch.Tensor:
        """_summary_

        Args:
            symbols (Dict): The Dict of token specifications provided by the HuggingFace tokenizer

        Returns:
            torch.Tensor: _description_
        """
        encoded_sents = self.encode_text(symbols)
        #output = self.hidden_layer1(encoded_sents)
        #output = self.relu(output)
        #outputs = [self.hidden_layers[i](encoded_sents) for i in range(self.output_size)]
        #outputs = [self.relu(outputs[i].to(device)) for i in range(self.output_size)]
        #outputs = [self.classifiers[i](outputs[i].to(device)) for i in range(self.output_size)]
        # outputs = []
        # for i in range(self.output_size):
        #     output = self.hidden_layers[i](encoded_sents)
        #     output = self.relu(output)
        #     output = self.classifiers[i](output)
        #     output = torch.nn.Sigmoid()(output)
        #     outputs.append(output)
        
        #middle1 = self.middle_layer1(encoded_sents)
        #middle1 = self.relu(middle1)
        #middle_output1 = self.classifier_middle1(middle1)
        #middle_output1 = torch.nn.Sigmoid()(middle_output1)
        
        #middle2 = self.middle_layer2(encoded_sents)
        #middle2 = self.relu(middle2)
        #middle_output2 = self.classifier_middle2(middle2)
        #middle_output2 = torch.nn.Sigmoid()(middle_output2)
        
        #middle3 = self.middle_layer3(encoded_sents)
        #middle3 = self.relu(middle3)
        #middle_output3 = self.classifier_middle3(middle3)
        #middle_output3 = torch.nn.Sigmoid()(middle_output3)
        
        #middle4 = self.middle_layer4(encoded_sents)
        #middle4 = self.relu(middle4)
        #middle_output4 = self.classifier_middle4(middle4)
        #middle_output4 = torch.nn.Sigmoid()(middle_output4)
        
        #middle5 = self.middle_layer5(encoded_sents)
        #middle5 = self.relu(middle5)
        #middle_output5 = self.classifier_middle5(middle5)
        #middle_output5 = torch.nn.Sigmoid()(middle_output5)
        enc = self.hidden_layer(encoded_sents)
        output1 = self.hidden_layer1(enc)
        output1 = self.relu(output1)
        output1 = self.classifier1(output1)
        output1 = torch.nn.Sigmoid()(output1)
        
        output2 = self.hidden_layer2(enc)
        output2 = self.relu(output2)
        output2 = self.classifier2(output2)
        output2 = torch.nn.Sigmoid()(output2)
        
        output3 = self.hidden_layer3(enc)
        output3 = self.relu(output3)
        output3 = self.classifier3(output3)
        output3 = torch.nn.Sigmoid()(output3)
        
        output4 = self.hidden_layer4(enc)
        output4 = self.relu(output4)
        output4 = self.classifier4(output4)
        output4 = torch.nn.Sigmoid()(output4)
        
        output5 = self.hidden_layer5(enc)
        output5 = self.relu(output5)
        output5 = self.classifier5(output5)
        output5 = torch.nn.Sigmoid()(output5)
        
        
        output6 = self.hidden_layer6(enc)
        output6 = self.relu(output6)
        output6 = self.classifier6(output6)
        output6 = torch.nn.Sigmoid()(output6)
        
        
        output7 = self.hidden_layer7(enc)
        output7 = self.relu(output7)
        output7 = self.classifier7(output7)
        output7 = torch.nn.Sigmoid()(output7)
        
        
        output8 = self.hidden_layer8(enc)
        output8 = self.relu(output8)
        output8 = self.classifier8(output8)
        output8 = torch.nn.Sigmoid()(output8)
        
        output9 = self.hidden_layer9(enc)
        output9 = self.relu(output9)
        output9 = self.classifier9(output9)
        output9 = torch.nn.Sigmoid()(output9)
        
        output10 = self.hidden_layer10(enc)
        output10 = self.relu(output10)
        output10 = self.classifier10(output10)
        output10 = torch.nn.Sigmoid()(output10)
        
        output11 = self.hidden_layer11(enc)
        output11 = self.relu(output11)
        output11 = self.classifier11(output11)
        output11 = torch.nn.Sigmoid()(output11)
        
        output12 = self.hidden_layer12(enc)
        output12 = self.relu(output12)
        output12 = self.classifier12(output12)
        output12 = torch.nn.Sigmoid()(output12)
        
        output13 = self.hidden_layer13(enc)
        output13 = self.relu(output13)
        output13 = self.classifier13(output13)
        output13 = torch.nn.Sigmoid()(output13)
        
        output14 = self.hidden_layer14(enc)
        output14 = self.relu(output14)
        output14 = self.classifier14(output14)
        output14 = torch.nn.Sigmoid()(output14)
        
        output15 = self.hidden_layer15(enc)
        output15 = self.relu(output15)
        output15 = self.classifier15(output15)
        output15 = torch.nn.Sigmoid()(output15)
        
        output16 = self.hidden_layer16(enc)
        output16 = self.relu(output16)
        output16 = self.classifier16(output16)
        output16 = torch.nn.Sigmoid()(output16)
        
        output17 = self.hidden_layer17(enc)
        output17 = self.relu(output17)
        output17 = self.classifier17(output17)
        output17 = torch.nn.Sigmoid()(output17)
        
        output18 = self.hidden_layer18(enc)
        output18 = self.relu(output18)
        output18 = self.classifier18(output18)
        output18 = torch.nn.Sigmoid()(output18)
        
        output19 = self.hidden_layer19(enc)
        output19 = self.relu(output19)
        output19 = self.classifier19(output19)
        output19 = torch.nn.Sigmoid()(output19)
        
        output20 = self.hidden_layer20(enc)
        output20 = self.relu(output20)
        output20 = self.classifier20(output20)
        output20 = torch.nn.Sigmoid()(output20)
        outputs = torch.cat((output1, output2, output3, output4, output5, output6, output7, output8, output9, output10, output11, output12, output13, output14, output15, output16, output17, output18, output19, output20),2)
        #group_outputs = torch.cat((middle_output1, middle_output2, middle_output3, middle_output4, middle_output5),2)
        return outputs

In [48]:
def predict(model: torch.nn.Module, sents: torch.Tensor) -> List:
    sents = sents.to(device)
    logits = model(sents)
    res = []
    logitslen = len(logits)
    #print(logits[0].shape)
    for i in range(logitslen):
        datares = []
        for j in range(20):
            datares.append(logits[i][0][j] > 0.5)
        res.append(datares)
    return res


In [49]:
import numpy as np

from numpy import logical_and, sum as t_sum
def precision(predicted_labels, true_labels, which_label=1):
    """
    Precision is True Positives / All Positives Predictions
    """
    pred_which = np.array([pred == which_label for pred in predicted_labels])
    true_which = np.array([lab == which_label for lab in true_labels])
    denominator = t_sum(pred_which)
    if denominator:
        return t_sum(logical_and(pred_which, true_which))/denominator
    else:
        return 0.


def recall(predicted_labels, true_labels, which_label=1):
    """
    Recall is True Positives / All Positive Labels
    """
    pred_which = np.array([pred == which_label for pred in predicted_labels])
    true_which = np.array([lab == which_label for lab in true_labels])
    denominator = t_sum(true_which)
    if denominator:
        return t_sum(logical_and(pred_which, true_which))/denominator
    else:
        return 0.


def f1_score(
    predicted_labels: List[int],
    true_labels: List[int],
    which_label: int
):
    """
    F1 score is the harmonic mean of precision and recall
    """
    P = precision(predicted_labels, true_labels, which_label=which_label)
    R = recall(predicted_labels, true_labels, which_label=which_label)
    if P and R:
        return 2*P*R/(P+R)
    else:
        return 0.


def macro_f1(
    predicted_labels: List[int],
    true_labels: List[int],
    possible_labels: List[int]
):
    scores = [f1_score(predicted_labels, true_labels, l) for l in possible_labels]
    # Macro, so we take the uniform avg.
    print(scores)
    return sum(scores) / len(scores)

In [50]:
def f1Score_multiLabel(preds, labels):
    nLabels = 20
    relevants = [0]*20
    positives = [0]*20
    truePositives = [0]*20
    correct = [0]*20
    for i in range(len(preds)):
        for j in range(nLabels):
            if(preds[i][j]==1):
                positives[j] += 1
                if(labels[i][j]==1):
                    truePositives[j] += 1
    for i in range(len(preds)):
        for j in range(nLabels):
            if(preds[i][j]==labels[i][j]):
                correct[j] += 1
    
    for i in range(len(labels)):
        for j in range(nLabels):
            if(labels[i][j]==1):
                relevants[j] += 1
    
    precisions = []*nLabels
    recalls = []*nLabels
    f1Scores = []*nLabels
    accuracies = []*nLabels
    
    precision =0
    recall = 0
    f1 = 0
    #print(truePositives, positives, relevants)
    for i in range(nLabels):
        if(positives[i]>0):
            precision = truePositives[i]/positives[i]
        precisions.append(precision)
        if(relevants[i]>0):
            recall = truePositives[i]/relevants[i]
        recalls.append(recall)
        #print(precision,recall,i)
        if(precision>0 and recall>0):
            f1 = 2 * precision * recall / (precision + recall)
        f1Scores.append(f1)
        accuracies.append(correct[i]/len(preds))
    precision_mean = np.mean(precisions)
    recall_mean = np.mean(recalls)
    f1_mean = np.mean(f1Scores)
    accuracy = np.mean(accuracies)
    return f1_mean, precision_mean, recall_mean, accuracy, f1Scores, precisions, recalls, accuracies
    


In [51]:
import random
from tqdm import tqdm_notebook as tqdm
from torchmetrics.classification import MultilabelHammingDistance
def training_loop(
    num_epochs,
    train_features,
    train_labels,
    dev_sents,
    dev_labels,
    optimizer,
    scheduler,
    model,
):
    print("Training...")
    all_f1 = []
    all_P = []
    all_R = []
    all_L = []
    all_CELoss = []
    all_HMLoss = []
    all_acc = []
    loss_func = torch.nn.CrossEntropyLoss()
    hammingLoss = MultilabelHammingDistance(num_labels=20)
    batches = list(zip(train_features, train_labels))
    random.shuffle(batches)
    for i in range(num_epochs):
        losses = []
        #CELosses = []
        #HMLosses = []
        for features, labels in tqdm(batches):
            # Empty the dynamic computation graph
            features = features.to(device)
            labels = labels.float()
            labels = labels.to(device)
            optimizer.zero_grad()
            preds = model(features)
            loss = loss_func(preds.squeeze(1),labels[:,:20])
            #group_loss = loss_func(preds[1].squeeze(1),labels[:,20:25])
            #loss = local_loss + 4*group_loss
            """
            loss0 = loss_func(preds[0].squeeze(1).squeeze(1), labels[:,0])
            #loss = loss_func(preds0.squeeze(1), labels)
            loss1 = loss_func(preds[1].squeeze(1).squeeze(1), labels[:,1]) 
            loss2 = loss_func(preds[2].squeeze(1).squeeze(1), labels[:,2]) 
            loss3 = loss_func(preds[3].squeeze(1).squeeze(1), labels[:,3]) 
            loss4 = loss_func(preds[4].squeeze(1).squeeze(1), labels[:,4]) 
            loss5 = loss_func(preds[5].squeeze(1).squeeze(1), labels[:,5]) 
            loss6 = loss_func(preds[6].squeeze(1).squeeze(1), labels[:,6]) 
            loss7 = loss_func(preds[7].squeeze(1).squeeze(1), labels[:,7]) 
            loss8 = loss_func(preds[8].squeeze(1).squeeze(1), labels[:,8]) 
            loss9 = loss_func(preds[9].squeeze(1).squeeze(1), labels[:,9]) 
            loss10 = loss_func(preds[10].squeeze(1).squeeze(1), labels[:,10])
            loss11 = loss_func(preds[11].squeeze(1).squeeze(1), labels[:,11]) 
            loss12 = loss_func(preds[12].squeeze(1).squeeze(1), labels[:,12]) 
            loss13 = loss_func(preds[13].squeeze(1).squeeze(1), labels[:,13]) 
            loss14 = loss_func(preds[14].squeeze(1).squeeze(1), labels[:,14]) 
            loss15 = loss_func(preds[15].squeeze(1).squeeze(1), labels[:,15]) 
            loss16 = loss_func(preds[16].squeeze(1).squeeze(1), labels[:,16]) 
            loss17 = loss_func(preds[17].squeeze(1).squeeze(1), labels[:,17]) 
            loss18 = loss_func(preds[18].squeeze(1).squeeze(1), labels[:,18]) 
            loss19 = loss_func(preds[19].squeeze(1).squeeze(1), labels[:,19])
            loss_group1 = loss_func(preds[20].squeeze(1).squeeze(1), labels[:,20])
            loss_group2 = loss_func(preds[21].squeeze(1).squeeze(1), labels[:,21])
            loss_group3 = loss_func(preds[22].squeeze(1).squeeze(1), labels[:,22])
            loss_group4 = loss_func(preds[23].squeeze(1).squeeze(1), labels[:,23])
            loss_group5 = loss_func(preds[24].squeeze(1).squeeze(1), labels[:,24])
            
            local_loss = local_loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7 + loss8 + loss9 + loss10 + loss11 + loss12 + loss13 + loss14 + loss15 + loss16 + loss17 + loss18 + loss19
            group_loss = loss_group1 + loss_group2 + loss_group3 + loss_group4 + loss_group5
            loss = local_loss + 4*group_loss
            """
            #print("Preds ",preds.shape)
            #print("Labels ", labels.shape)
            #loss = loss_func(preds, labels)
            #loss2 = 20*hammingLoss(preds, labels)
            
            #loss = loss1 + loss2
            # Backpropogate the loss through our model
            loss.backward()
            optimizer.step()
            #CELosses.append(loss1.item())
            #HMLosses.append(loss2.item())
            losses.append(loss.item())
        
        #print(f"epoch {i}, loss: {np.sum(losses)/len(losses)}, HM loss: {np.sum(HMLosses)/len(HMLosses)}, CE loss: {np.sum(CELosses)/len(CELosses)}")
        print(f"epoch {i}, loss: {np.sum(losses)/len(losses)}")
        # Estimate the f1 score for the development set
        print("Evaluating dev...")
        all_preds = []
        all_labels = []
        for sents, labels in tqdm(zip(dev_sents, dev_labels), total=len(dev_sents)):
            sents = sents.to(device)
            pred = predict(model, sents)
            all_preds.extend(pred)
            all_labels.extend(list(labels))
        # #print(range(len(set(train_labels))))

        dev_f1, dev_P, dev_R, dev_acc, dev_all_f1, dev_all_P, dev_all_R, dev_all_acc = f1Score_multiLabel(all_preds, all_labels)
        print(f"Dev F1 {dev_f1},  Dev Precision {dev_P}, Dev Recall {dev_R}, Dev Accuracy {dev_acc}")
        all_f1.append(dev_all_f1)
        all_P.append(dev_all_P)
        all_R.append(dev_all_R)
        all_L.append(losses)
        #all_CELoss.append(CELosses)
        #all_HMLoss.append(HMLosses)
        all_acc.append(dev_all_acc)
        scheduler.step()
        #print(optimizer)
    # Return the trained model
    with open("all_f1_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_f1,
               delimiter =", ", 
               fmt ='%s')
    with open("all_P_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_P,
               delimiter =", ", 
               fmt ='%s')
    with open("all_R_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_R,
               delimiter =", ", 
               fmt ='%s')
    with open("all_acc_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_acc,
               delimiter =", ", 
               fmt ='%s')
    with open("all_L_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_L,
               delimiter =", ", 
               fmt ='%s')
    """
    with open("all_CELoss_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_CELoss,
               delimiter =", ", 
               fmt ='%s')
    with open("all_HMLoss_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_CELoss,
               delimiter =", ", 
               fmt ='%s')
    """
    return model

In [52]:
from transformers.optimization import get_linear_schedule_with_warmup
epochs = 200
epoch_warmup = 40
# TODO: Find a good learning rate
LR = 1e-4

possible_labels = 20
model = GroupedClassifier(output_size=possible_labels, hidden_size=512)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), LR)
scheduler = get_linear_schedule_with_warmup(optimizer, epoch_warmup,epochs)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


768


In [58]:
model =training_loop(
    epochs,
    train_input_batches,
    train_label_batches,
    val_input_batches,
    val_label_batches,
    optimizer,
    scheduler,
    model,
)

Training...


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for features, labels in tqdm(batches):


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 0, loss: 10.371476358442164
Evaluating dev...


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for sents, labels in tqdm(zip(dev_sents, dev_labels), total=len(dev_sents)):


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.2185782514080099,  Dev Precision 0.141015165603893, Dev Recall 0.5004992934580523, Dev Accuracy 0.48646209386281586


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 1, loss: 10.361774942768154
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.22695520340838599,  Dev Precision 0.13672827590325184, Dev Recall 0.4436054663650809, Dev Accuracy 0.5602888086642598


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 2, loss: 10.331963482187755
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.25292042015578514,  Dev Precision 0.2056866600763539, Dev Recall 0.36647875911913236, Dev Accuracy 0.6317689530685919


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 3, loss: 10.279778010809599
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.38842145470909306,  Dev Precision 0.21856700042737978, Dev Recall 0.4438655462184874, Dev Accuracy 0.6315884476534295


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 4, loss: 10.207405296724234
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.389618236559178,  Dev Precision 0.24705893894208134, Dev Recall 0.45, Dev Accuracy 0.6308664259927798


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 5, loss: 10.1279672793488
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.39491662621309176,  Dev Precision 0.25354135903456704, Dev Recall 0.4381707317073171, Dev Accuracy 0.6438628158844766


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 6, loss: 10.055112995318513
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4093915941614819,  Dev Precision 0.27468775207944884, Dev Recall 0.41951219512195126, Dev Accuracy 0.6653429602888087


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 7, loss: 9.992742958353526
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.41846730239286245,  Dev Precision 0.29569591058760736, Dev Recall 0.40963414634146345, Dev Accuracy 0.6749097472924188


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 8, loss: 9.942564437638469
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4089566603574917,  Dev Precision 0.2991018070167596, Dev Recall 0.3989837398373984, Dev Accuracy 0.6759927797833936


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 9, loss: 9.902786005788775
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4086414438158058,  Dev Precision 0.2980860454215625, Dev Recall 0.398760162601626, Dev Accuracy 0.6763537906137185


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 10, loss: 9.869995416100345
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4107722291751849,  Dev Precision 0.2937402370299921, Dev Recall 0.4024186991869919, Dev Accuracy 0.6756317689530686


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 11, loss: 9.843163269669262
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4104392551758892,  Dev Precision 0.28832480946650624, Dev Recall 0.40546747967479674, Dev Accuracy 0.6747292418772564


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 12, loss: 9.820578475496662
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.41669567029108484,  Dev Precision 0.29225304386528855, Dev Recall 0.4095731707317073, Dev Accuracy 0.6758122743682311


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 13, loss: 9.80042799195247
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4189285251555187,  Dev Precision 0.2781066756755378, Dev Recall 0.4118495934959349, Dev Accuracy 0.6754512635379062


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 14, loss: 9.78169417025438
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.41989933161831594,  Dev Precision 0.2781017859622327, Dev Recall 0.4126829268292683, Dev Accuracy 0.6754512635379062


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 15, loss: 9.763164925931106
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.40682208885743487,  Dev Precision 0.30166137714346913, Dev Recall 0.41637340301974446, Dev Accuracy 0.6754512635379062


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 16, loss: 9.744373577744213
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.3811009189282686,  Dev Precision 0.33251301505177067, Dev Recall 0.42119165332579966, Dev Accuracy 0.6750902527075813


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 17, loss: 9.725065387896638
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.3765470383639821,  Dev Precision 0.37060269963956566, Dev Recall 0.42754300467715095, Dev Accuracy 0.6754512635379062


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 18, loss: 9.705146910539314
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.38619248841973186,  Dev Precision 0.3912014330602218, Dev Recall 0.4546004742346206, Dev Accuracy 0.6779783393501805


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 19, loss: 9.684993622907951
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4016872234817904,  Dev Precision 0.3741338275381694, Dev Recall 0.46851971473389886, Dev Accuracy 0.6796028880866426


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 20, loss: 9.664809895985162
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.41364470178522394,  Dev Precision 0.3537370491915802, Dev Recall 0.4824454747518919, Dev Accuracy 0.682129963898917


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 21, loss: 9.644743136505582
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4207382392373182,  Dev Precision 0.3424176974604523, Dev Recall 0.4881276666825668, Dev Accuracy 0.6844765342960288


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 22, loss: 9.625201196812872
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4193497529901573,  Dev Precision 0.32363089439613635, Dev Recall 0.4916218790815122, Dev Accuracy 0.6851985559566787


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 23, loss: 9.60628134457033
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4223981928999331,  Dev Precision 0.32046565060007454, Dev Recall 0.4942562910894074, Dev Accuracy 0.6889891696750902


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 24, loss: 9.587926722284573
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4250573586617811,  Dev Precision 0.30033997684135705, Dev Recall 0.4954965226391049, Dev Accuracy 0.6929602888086643


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 25, loss: 9.570021252133953
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.43056550082428435,  Dev Precision 0.29970381634672977, Dev Recall 0.5009480134015787, Dev Accuracy 0.6927797833935019


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 26, loss: 9.552797118229652
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4326000217342429,  Dev Precision 0.2988592825500446, Dev Recall 0.5033038290386443, Dev Accuracy 0.6945848375451265


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 27, loss: 9.536403051063196
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.43403124785225244,  Dev Precision 0.2985404144499497, Dev Recall 0.5086483842254326, Dev Accuracy 0.6967509025270757


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 28, loss: 9.520756600508049
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.43749404723127167,  Dev Precision 0.300883157315023, Dev Recall 0.5131698154640967, Dev Accuracy 0.7


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 29, loss: 9.505860890915145
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4388940536592757,  Dev Precision 0.3015942467347854, Dev Recall 0.5117840809518452, Dev Accuracy 0.7025270758122744


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 30, loss: 9.49172870436711
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.442014083824595,  Dev Precision 0.30394551995345154, Dev Recall 0.5146760943110917, Dev Accuracy 0.7055956678700361


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 31, loss: 9.478194820347117
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.44652371773864924,  Dev Precision 0.3074720995651265, Dev Recall 0.5157051694026668, Dev Accuracy 0.7086642599277979


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 32, loss: 9.464832405545819
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4484002422194088,  Dev Precision 0.30852824984580346, Dev Recall 0.5165292563189867, Dev Accuracy 0.7110108303249097


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 33, loss: 9.450958963650375
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.43991963910420545,  Dev Precision 0.325040156816272, Dev Recall 0.5385735813305448, Dev Accuracy 0.7135379061371842


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 34, loss: 9.436884872948946
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.43877881901461013,  Dev Precision 0.32326209716097604, Dev Recall 0.5375319146638781, Dev Accuracy 0.7173285198555958


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 35, loss: 9.423486396447936
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4392689077957712,  Dev Precision 0.3236286639614492, Dev Recall 0.5364674888935699, Dev Accuracy 0.7196750902527076


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 36, loss: 9.41000985387546
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.44482397334335644,  Dev Precision 0.33043149701056773, Dev Recall 0.5346738964565951, Dev Accuracy 0.7236462093862815


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 37, loss: 9.395932944852914
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.44643834198865606,  Dev Precision 0.33264490833102633, Dev Recall 0.5338808681986426, Dev Accuracy 0.7287003610108304


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 38, loss: 9.382844953394647
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4477341204139183,  Dev Precision 0.3339038457737743, Dev Recall 0.5338808681986426, Dev Accuracy 0.7317689530685921


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 39, loss: 9.37084742446444
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.44610852640256937,  Dev Precision 0.3324748748123492, Dev Recall 0.5297988659901234, Dev Accuracy 0.731768953068592


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 40, loss: 9.35967530065508
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.44836599087738405,  Dev Precision 0.3230210910142931, Dev Recall 0.5309657407492312, Dev Accuracy 0.7355595667870036


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 41, loss: 9.349341463686814
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45382817335750414,  Dev Precision 0.3275270714838525, Dev Recall 0.5322795267695721, Dev Accuracy 0.740072202166065


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 42, loss: 9.33988922033737
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4540038578423536,  Dev Precision 0.32784034780013493, Dev Recall 0.5297200029600482, Dev Accuracy 0.7415162454873646


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 43, loss: 9.331185369349237
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.454693823210863,  Dev Precision 0.32872452267116115, Dev Recall 0.528401232733511, Dev Accuracy 0.7435018050541516


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 44, loss: 9.323102502680536
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4556530413745462,  Dev Precision 0.3298202027956952, Dev Recall 0.5270962012869702, Dev Accuracy 0.7447653429602888


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 45, loss: 9.315568988002948
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45623172279909047,  Dev Precision 0.33096171184574363, Dev Recall 0.5287796753332492, Dev Accuracy 0.7471119133574008


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 46, loss: 9.308487287208216
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45801382939149915,  Dev Precision 0.332815521278956, Dev Recall 0.5285555856973947, Dev Accuracy 0.7492779783393503


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 47, loss: 9.301789881578133
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45945854693987176,  Dev Precision 0.3340555527563332, Dev Recall 0.5299069370487461, Dev Accuracy 0.7505415162454875


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 48, loss: 9.295431037447345
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45632400945078694,  Dev Precision 0.3323297227697397, Dev Recall 0.5278053369747543, Dev Accuracy 0.7509025270758124


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 49, loss: 9.289395019189636
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.44565512010392966,  Dev Precision 0.34575545069960356, Dev Recall 0.5308913286529789, Dev Accuracy 0.7519855595667871


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 50, loss: 9.28365245505945
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.44808571731210023,  Dev Precision 0.34770533602513676, Dev Recall 0.5309758603132342, Dev Accuracy 0.7530685920577619


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 51, loss: 9.278169446916722
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.44886779589443,  Dev Precision 0.3488736536017655, Dev Recall 0.5304996698370437, Dev Accuracy 0.754332129963899


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 52, loss: 9.272928116926506
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4495102284071527,  Dev Precision 0.3495291443888519, Dev Recall 0.5304996698370437, Dev Accuracy 0.7559566787003611


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 53, loss: 9.267902851104736
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4523458052818389,  Dev Precision 0.3520749560795733, Dev Recall 0.5317817211190949, Dev Accuracy 0.7583032490974729


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 54, loss: 9.263009804398266
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45391997358579295,  Dev Precision 0.35384377182846954, Dev Recall 0.5317817211190949, Dev Accuracy 0.759927797833935


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 55, loss: 9.25814792291442
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4550043696640594,  Dev Precision 0.3549033028645994, Dev Recall 0.5317817211190949, Dev Accuracy 0.7611913357400721


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 56, loss: 9.253283365448908
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4555609777910079,  Dev Precision 0.35519060454865026, Dev Recall 0.5326150544524283, Dev Accuracy 0.7620938628158844


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 57, loss: 9.248673275335511
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45520488147858085,  Dev Precision 0.35272837878412405, Dev Recall 0.5315733877857616, Dev Accuracy 0.7624548736462093


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 58, loss: 9.244245564759668
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4555803926522298,  Dev Precision 0.3534101108022397, Dev Recall 0.5310971973095711, Dev Accuracy 0.7629963898916967


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 59, loss: 9.239882924663487
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4555843478358078,  Dev Precision 0.35186752274828736, Dev Recall 0.5315733877857616, Dev Accuracy 0.7629963898916967


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 60, loss: 9.235438019482057
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.44352744705105784,  Dev Precision 0.3878666870981427, Dev Recall 0.5351844988968727, Dev Accuracy 0.7635379061371841


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 61, loss: 9.230983783949666
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4440978442397284,  Dev Precision 0.3885126100664712, Dev Recall 0.5351844988968727, Dev Accuracy 0.7642599277978339


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 62, loss: 9.226528011151213
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.44460101815181946,  Dev Precision 0.3890473030886555, Dev Recall 0.5351844988968727, Dev Accuracy 0.7651624548736462


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 63, loss: 9.22205145679303
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4440696425707237,  Dev Precision 0.36322419420886753, Dev Recall 0.5351844988968727, Dev Accuracy 0.7651624548736462


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 64, loss: 9.217598032595506
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.44305466260592236,  Dev Precision 0.3498654374286493, Dev Recall 0.5341428322302061, Dev Accuracy 0.7644404332129964


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 65, loss: 9.21322520099469
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4434943834634685,  Dev Precision 0.34792458726851827, Dev Recall 0.5341428322302061, Dev Accuracy 0.7653429602888087


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 66, loss: 9.209063643839821
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45158108040675815,  Dev Precision 0.35874768755882314, Dev Recall 0.5400066016862086, Dev Accuracy 0.766245487364621


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 67, loss: 9.205140612018642
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4544031393512067,  Dev Precision 0.35780051592609924, Dev Recall 0.5431063321444296, Dev Accuracy 0.766245487364621


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 68, loss: 9.201447814258177
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4548079610054736,  Dev Precision 0.35798668381184007, Dev Recall 0.5431063321444296, Dev Accuracy 0.7666064981949459


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 69, loss: 9.197937815936644
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45432005768730993,  Dev Precision 0.35757713641310307, Dev Recall 0.542634634031222, Dev Accuracy 0.7666064981949459


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 70, loss: 9.194587657700724
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45444061261130714,  Dev Precision 0.357735648848255, Dev Recall 0.5420691578407458, Dev Accuracy 0.7669675090252708


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 71, loss: 9.191385560960912
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.450592734879382,  Dev Precision 0.3525444291851085, Dev Recall 0.5383168684634565, Dev Accuracy 0.765884476534296


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 72, loss: 9.188289443058753
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4504503539375649,  Dev Precision 0.35215597149707045, Dev Recall 0.5383168684634565, Dev Accuracy 0.7660649819494585


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 73, loss: 9.185325174189325
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4504489433590685,  Dev Precision 0.35245932016142034, Dev Recall 0.5377286331693388, Dev Accuracy 0.7664259927797834


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 74, loss: 9.182441433863854
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45209731031513034,  Dev Precision 0.3538173341983867, Dev Recall 0.5376411991572725, Dev Accuracy 0.7671480144404332


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 75, loss: 9.179660533791157
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45385114932450865,  Dev Precision 0.3547588607129535, Dev Recall 0.5386828658239391, Dev Accuracy 0.7680505415162455


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 76, loss: 9.17695520884955
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45336610094364904,  Dev Precision 0.35345925308820547, Dev Recall 0.5386828658239391, Dev Accuracy 0.7680505415162455


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 77, loss: 9.174340767646903
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45333946017816046,  Dev Precision 0.35378703110852716, Dev Recall 0.5371650086810821, Dev Accuracy 0.7684115523465704


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 78, loss: 9.171791311520249
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45323481738593585,  Dev Precision 0.35394803108501455, Dev Recall 0.5362078736746706, Dev Accuracy 0.7684115523465704


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 79, loss: 9.169325010100408
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4533174989631566,  Dev Precision 0.35423224526004105, Dev Recall 0.535736175561463, Dev Accuracy 0.7684115523465704


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 80, loss: 9.166914128545505
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4528696136359245,  Dev Precision 0.35319004309982643, Dev Recall 0.535736175561463, Dev Accuracy 0.768231046931408


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 81, loss: 9.164562360564275
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4530015008317655,  Dev Precision 0.3535739113820752, Dev Recall 0.5349549255614631, Dev Accuracy 0.7685920577617329


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 82, loss: 9.162264339959444
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4556258740766726,  Dev Precision 0.35580142048272584, Dev Recall 0.5377451041328916, Dev Accuracy 0.7691335740072203


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 83, loss: 9.160036257843473
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4552191511574553,  Dev Precision 0.3554734428561766, Dev Recall 0.536703437466225, Dev Accuracy 0.7689530685920578


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 84, loss: 9.157850172982288
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45539561527550204,  Dev Precision 0.3557067694958803, Dev Recall 0.536703437466225, Dev Accuracy 0.7689530685920578


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 85, loss: 9.155713650717663
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.455128624630259,  Dev Precision 0.3549610633297168, Dev Recall 0.536703437466225, Dev Accuracy 0.7689530685920578


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 86, loss: 9.153632213820272
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4553219432803798,  Dev Precision 0.3551396822558205, Dev Recall 0.536703437466225, Dev Accuracy 0.7693140794223827


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 87, loss: 9.151591222677657
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4548977752269021,  Dev Precision 0.3544633281530704, Dev Recall 0.5362317393530174, Dev Accuracy 0.7689530685920578


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 88, loss: 9.149601274461888
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4545559962617868,  Dev Precision 0.35374473757897407, Dev Recall 0.5362317393530174, Dev Accuracy 0.7689530685920578


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 89, loss: 9.147652953418334
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4543269296147995,  Dev Precision 0.3533965531547027, Dev Recall 0.5357600412398098, Dev Accuracy 0.7687725631768954


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 90, loss: 9.145748145544706
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45480017113915927,  Dev Precision 0.3537430669348784, Dev Recall 0.5357600412398098, Dev Accuracy 0.7689530685920578


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 91, loss: 9.143876154031327
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4545302656060959,  Dev Precision 0.35340973360154504, Dev Recall 0.5357600412398098, Dev Accuracy 0.7687725631768954


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 92, loss: 9.142050885442478
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45434564192957216,  Dev Precision 0.353166952584978, Dev Recall 0.5357600412398098, Dev Accuracy 0.7687725631768954


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 93, loss: 9.140265094700144
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45418068779866727,  Dev Precision 0.35314039643054096, Dev Recall 0.5352746043466059, Dev Accuracy 0.7687725631768954


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 94, loss: 9.13850402120334
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4541062221105971,  Dev Precision 0.35325928817848185, Dev Recall 0.5344933543466059, Dev Accuracy 0.7687725631768954


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 95, loss: 9.136792282559979
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45401452635008954,  Dev Precision 0.3529862047277904, Dev Recall 0.5344933543466059, Dev Accuracy 0.7687725631768954


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 96, loss: 9.13510297661397
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45394539441739923,  Dev Precision 0.3529341343996851, Dev Recall 0.5344933543466059, Dev Accuracy 0.7691335740072203


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 97, loss: 9.133451262516761
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45330881242672183,  Dev Precision 0.35254543744303557, Dev Recall 0.5331238690524882, Dev Accuracy 0.7687725631768954


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 98, loss: 9.13183109084172
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4518056326588912,  Dev Precision 0.3515844101280592, Dev Recall 0.5318418177704369, Dev Accuracy 0.7689530685920578


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 99, loss: 9.130241522148474
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4519141555126967,  Dev Precision 0.35168631601474465, Dev Recall 0.5318418177704369, Dev Accuracy 0.7691335740072203


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 100, loss: 9.128677467801678
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45226937354307395,  Dev Precision 0.3520558369117885, Dev Recall 0.5318418177704369, Dev Accuracy 0.7693140794223827


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 101, loss: 9.127148500129358
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45226937354307395,  Dev Precision 0.3520558369117885, Dev Recall 0.5318418177704369, Dev Accuracy 0.7693140794223827


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 102, loss: 9.125645139324131
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4518386383361984,  Dev Precision 0.351893662070303, Dev Recall 0.5310084844371036, Dev Accuracy 0.7691335740072203


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 103, loss: 9.124169862092431
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4526906999466136,  Dev Precision 0.35294643674256154, Dev Recall 0.5310084844371036, Dev Accuracy 0.7700361010830326


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 104, loss: 9.12272140161315
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4527332038280085,  Dev Precision 0.3529354297235419, Dev Recall 0.5310084844371036, Dev Accuracy 0.770216606498195


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 105, loss: 9.121303266553737
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4527398000297855,  Dev Precision 0.35280250225817567, Dev Recall 0.5310084844371036, Dev Accuracy 0.7700361010830326


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 106, loss: 9.119900810184763
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4526297235001976,  Dev Precision 0.353042378659763, Dev Recall 0.5295799130085321, Dev Accuracy 0.770216606498195


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 107, loss: 9.118524985526925
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4526297235001976,  Dev Precision 0.353042378659763, Dev Recall 0.5295799130085321, Dev Accuracy 0.770216606498195


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 108, loss: 9.117175116467832
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45274954547897195,  Dev Precision 0.3531779802376723, Dev Recall 0.5295799130085321, Dev Accuracy 0.7703971119133575


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 109, loss: 9.115848833055638
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4530570002552218,  Dev Precision 0.35334173751291864, Dev Recall 0.5296014338119756, Dev Accuracy 0.7703971119133575


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 110, loss: 9.1145488397399
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45344736035896294,  Dev Precision 0.35386993616611734, Dev Recall 0.5296014338119756, Dev Accuracy 0.7707581227436824


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 111, loss: 9.113271008676557
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4536793560090923,  Dev Precision 0.35414350809226336, Dev Recall 0.5296014338119756, Dev Accuracy 0.7711191335740073


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 112, loss: 9.112015062303685
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45346704123384984,  Dev Precision 0.3538235320157083, Dev Recall 0.5296014338119756, Dev Accuracy 0.7707581227436824


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 113, loss: 9.110777385199247
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45179777965643264,  Dev Precision 0.35265255497794346, Dev Recall 0.5283193825299243, Dev Accuracy 0.7705776173285199


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 114, loss: 9.109567756083473
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4522748586607448,  Dev Precision 0.35317404175610323, Dev Recall 0.5283193825299243, Dev Accuracy 0.7709386281588448


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 115, loss: 9.10837649587375
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45237590046055826,  Dev Precision 0.3534499634227451, Dev Recall 0.5278476844167168, Dev Accuracy 0.7712996389891698


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 116, loss: 9.107209575710012
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45249605476397947,  Dev Precision 0.35358703866984126, Dev Recall 0.5278476844167168, Dev Accuracy 0.7714801444043322


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 117, loss: 9.106057465966067
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45249605476397947,  Dev Precision 0.35358703866984126, Dev Recall 0.5278476844167168, Dev Accuracy 0.7714801444043322


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 118, loss: 9.104930379497471
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4526805108193838,  Dev Precision 0.3537855205100021, Dev Recall 0.5278476844167168, Dev Accuracy 0.7718411552346571


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 119, loss: 9.103823362891355
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4526805108193838,  Dev Precision 0.3537855205100021, Dev Recall 0.5278476844167168, Dev Accuracy 0.7718411552346571


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 120, loss: 9.102735270315142
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4520009107756495,  Dev Precision 0.35336898125344507, Dev Recall 0.5262331010833835, Dev Accuracy 0.7714801444043322


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 121, loss: 9.10166288489726
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4519290504221254,  Dev Precision 0.35350194892373055, Dev Recall 0.5254518510833834, Dev Accuracy 0.7714801444043322


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 122, loss: 9.10061034871571
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4519290504221254,  Dev Precision 0.35350194892373055, Dev Recall 0.5254518510833834, Dev Accuracy 0.7714801444043322


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 123, loss: 9.099579547768208
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4519290504221254,  Dev Precision 0.35350194892373055, Dev Recall 0.5254518510833834, Dev Accuracy 0.7714801444043322


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 124, loss: 9.098566802579965
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4521453258712131,  Dev Precision 0.3537695921940912, Dev Recall 0.5254518510833834, Dev Accuracy 0.7714801444043322


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 125, loss: 9.097569266361976
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4521453258712131,  Dev Precision 0.3537695921940912, Dev Recall 0.5254518510833834, Dev Accuracy 0.7714801444043322


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 126, loss: 9.09659086768307
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4526383924283529,  Dev Precision 0.3544933911476349, Dev Recall 0.5254518510833834, Dev Accuracy 0.7716606498194947


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 127, loss: 9.095629969639564
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4526383924283529,  Dev Precision 0.3544933911476349, Dev Recall 0.5254518510833834, Dev Accuracy 0.7716606498194947


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 128, loss: 9.09468735509844
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45303432499216384,  Dev Precision 0.3550613307412519, Dev Recall 0.5254518510833834, Dev Accuracy 0.7720216606498196


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 129, loss: 9.093756077894524
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.454030677825122,  Dev Precision 0.3561449344636752, Dev Recall 0.5254518510833834, Dev Accuracy 0.7727436823104693


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 130, loss: 9.09284442218382
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.454030677825122,  Dev Precision 0.3561449344636752, Dev Recall 0.5254518510833834, Dev Accuracy 0.7727436823104693


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 131, loss: 9.091953185067249
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.454030677825122,  Dev Precision 0.3561449344636752, Dev Recall 0.5254518510833834, Dev Accuracy 0.7727436823104693


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 132, loss: 9.091077761863595
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.454030677825122,  Dev Precision 0.3561449344636752, Dev Recall 0.5254518510833834, Dev Accuracy 0.7727436823104693


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 133, loss: 9.090217924829739
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4543037701945999,  Dev Precision 0.3563010242968687, Dev Recall 0.5259372879765873, Dev Accuracy 0.7729241877256318


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 134, loss: 9.089371674096407
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45442684313150883,  Dev Precision 0.3564431760251344, Dev Recall 0.5259372879765873, Dev Accuracy 0.7731046931407942


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 135, loss: 9.088542240769115
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45402075403046693,  Dev Precision 0.3560322754556014, Dev Recall 0.5259372879765873, Dev Accuracy 0.7727436823104693


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 136, loss: 9.0877318168754
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45402075403046693,  Dev Precision 0.3560322754556014, Dev Recall 0.5259372879765873, Dev Accuracy 0.7727436823104693


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 137, loss: 9.08693263068128
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4540927408671596,  Dev Precision 0.35609108712383636, Dev Recall 0.5259372879765873, Dev Accuracy 0.7729241877256318


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 138, loss: 9.086148418597322
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45402075403046693,  Dev Precision 0.3560322754556014, Dev Recall 0.5259372879765873, Dev Accuracy 0.7727436823104693


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 139, loss: 9.085379201974442
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45403628969380944,  Dev Precision 0.3560791570427898, Dev Recall 0.5259372879765873, Dev Accuracy 0.7727436823104693


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 140, loss: 9.0846278916544
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.453943441365667,  Dev Precision 0.355951633307298, Dev Recall 0.5259372879765873, Dev Accuracy 0.7727436823104693


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 141, loss: 9.083889057387166
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.453943441365667,  Dev Precision 0.355951633307298, Dev Recall 0.5259372879765873, Dev Accuracy 0.7727436823104693


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 142, loss: 9.083163296998437
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45381523623746184,  Dev Precision 0.3557351830908478, Dev Recall 0.5259372879765873, Dev Accuracy 0.7725631768953068


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 143, loss: 9.082452055233627
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45381523623746184,  Dev Precision 0.3557351830908478, Dev Recall 0.5259372879765873, Dev Accuracy 0.7725631768953068


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 144, loss: 9.081754470939067
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45219934219772673,  Dev Precision 0.3546786439703452, Dev Recall 0.5248956213099207, Dev Accuracy 0.7723826714801444


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 145, loss: 9.081073639997795
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45219934219772673,  Dev Precision 0.3546786439703452, Dev Recall 0.5248956213099207, Dev Accuracy 0.7723826714801444


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 146, loss: 9.080404302967128
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45205685276952307,  Dev Precision 0.35447618837995865, Dev Recall 0.5248956213099207, Dev Accuracy 0.7722021660649819


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 147, loss: 9.079753256555813
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4517831774229755,  Dev Precision 0.3543203074163308, Dev Recall 0.5244101844167168, Dev Accuracy 0.7720216606498195


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 148, loss: 9.079112771731705
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4517831774229755,  Dev Precision 0.3543203074163308, Dev Recall 0.5244101844167168, Dev Accuracy 0.7720216606498195


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 149, loss: 9.078484983586554
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4531680569538997,  Dev Precision 0.3552668374995839, Dev Recall 0.5266720891786215, Dev Accuracy 0.7725631768953067


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 150, loss: 9.077872247838263
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45342434844687485,  Dev Precision 0.3554033028239687, Dev Recall 0.5271482796548119, Dev Accuracy 0.7727436823104692


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 151, loss: 9.077269660892771
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45372612150536373,  Dev Precision 0.35587365908979945, Dev Recall 0.5271482796548119, Dev Accuracy 0.7729241877256318


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 152, loss: 9.076679443245503
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45367482004521903,  Dev Precision 0.3558382466736231, Dev Recall 0.5270317424739017, Dev Accuracy 0.7729241877256318


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 153, loss: 9.076105003926292
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45379825245363203,  Dev Precision 0.3559820239913135, Dev Recall 0.5270317424739017, Dev Accuracy 0.7731046931407942


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 154, loss: 9.07554292678833
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45379825245363203,  Dev Precision 0.3559820239913135, Dev Recall 0.5270317424739017, Dev Accuracy 0.7731046931407942


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 155, loss: 9.074994044517403
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45379825245363203,  Dev Precision 0.3559820239913135, Dev Recall 0.5270317424739017, Dev Accuracy 0.7731046931407942


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 156, loss: 9.074457666767177
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4541011039777225,  Dev Precision 0.3563559556152452, Dev Recall 0.5270317424739017, Dev Accuracy 0.7732851985559567


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 157, loss: 9.073933494624807
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4541011039777225,  Dev Precision 0.3563559556152452, Dev Recall 0.5270317424739017, Dev Accuracy 0.7732851985559567


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 158, loss: 9.073422040512312
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4541011039777225,  Dev Precision 0.3563559556152452, Dev Recall 0.5270317424739017, Dev Accuracy 0.7732851985559567


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 159, loss: 9.072920991413628
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4541011039777225,  Dev Precision 0.3563559556152452, Dev Recall 0.5270317424739017, Dev Accuracy 0.7732851985559567


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 160, loss: 9.072432710163628
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.454091538019309,  Dev Precision 0.35652972531419647, Dev Recall 0.5259900758072351, Dev Accuracy 0.7732851985559567


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 161, loss: 9.07195846358342
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4543562178287622,  Dev Precision 0.35674735034390204, Dev Recall 0.5263472186643781, Dev Accuracy 0.7732851985559567


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 162, loss: 9.071495589925282
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4544280903185098,  Dev Precision 0.35680652194153517, Dev Recall 0.5263472186643781, Dev Accuracy 0.773465703971119


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 163, loss: 9.07104500727867
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4527896125172413,  Dev Precision 0.35580782459025423, Dev Recall 0.5250651673823268, Dev Accuracy 0.773465703971119


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 164, loss: 9.070606402496793
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45311792989070226,  Dev Precision 0.3563628662183764, Dev Recall 0.5250651673823268, Dev Accuracy 0.7736462093862815


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 165, loss: 9.070179163520017
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45311792989070226,  Dev Precision 0.3563628662183764, Dev Recall 0.5250651673823268, Dev Accuracy 0.7736462093862815


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 166, loss: 9.069765724352937
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.453504772997285,  Dev Precision 0.3565978122678021, Dev Recall 0.5261068340489934, Dev Accuracy 0.7738267148014439


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 167, loss: 9.069361466080395
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.453504772997285,  Dev Precision 0.3565978122678021, Dev Recall 0.5261068340489934, Dev Accuracy 0.7738267148014439


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 168, loss: 9.068970089528099
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.453504772997285,  Dev Precision 0.3565978122678021, Dev Recall 0.5261068340489934, Dev Accuracy 0.7738267148014439


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 169, loss: 9.068588641152454
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.453504772997285,  Dev Precision 0.3565978122678021, Dev Recall 0.5261068340489934, Dev Accuracy 0.7738267148014439


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 170, loss: 9.068221775453482
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.453504772997285,  Dev Precision 0.3565978122678021, Dev Recall 0.5261068340489934, Dev Accuracy 0.7738267148014439


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 171, loss: 9.0678642329885
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.453504772997285,  Dev Precision 0.3565978122678021, Dev Recall 0.5261068340489934, Dev Accuracy 0.7738267148014439


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 172, loss: 9.067519842688718
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.453504772997285,  Dev Precision 0.3565978122678021, Dev Recall 0.5261068340489934, Dev Accuracy 0.7738267148014439


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 173, loss: 9.067186561983023
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.453504772997285,  Dev Precision 0.3565978122678021, Dev Recall 0.5261068340489934, Dev Accuracy 0.7738267148014439


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 174, loss: 9.066864689784264
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.453504772997285,  Dev Precision 0.3565978122678021, Dev Recall 0.5261068340489934, Dev Accuracy 0.7738267148014439


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 175, loss: 9.066553827541977
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45336099467731283,  Dev Precision 0.3564007267446296, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 176, loss: 9.066254544613967
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 177, loss: 9.06596516851169
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 178, loss: 9.065688339631949
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 179, loss: 9.065421794777485
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 180, loss: 9.065166181592799
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 181, loss: 9.064922830951748
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 182, loss: 9.064690205588269
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 183, loss: 9.064468967380808
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 184, loss: 9.064258319228443
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 185, loss: 9.064059328677049
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 186, loss: 9.06387072890552
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 187, loss: 9.063693345482669
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 188, loss: 9.063527071653906
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 189, loss: 9.063372170747217
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 190, loss: 9.063227681971307
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 191, loss: 9.063094808094537
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 192, loss: 9.062972510038916
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45335603521681733,  Dev Precision 0.35640740800725274, Dev Recall 0.5261068340489934, Dev Accuracy 0.7736462093862816


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 193, loss: 9.062861606256286
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45369216967059883,  Dev Precision 0.35698711815218026, Dev Recall 0.5261068340489934, Dev Accuracy 0.773826714801444


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 194, loss: 9.062761249826915
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45369216967059883,  Dev Precision 0.35698711815218026, Dev Recall 0.5261068340489934, Dev Accuracy 0.773826714801444


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 195, loss: 9.062672002991633
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45369216967059883,  Dev Precision 0.35698711815218026, Dev Recall 0.5261068340489934, Dev Accuracy 0.773826714801444


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 196, loss: 9.062593218105942
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45369216967059883,  Dev Precision 0.35698711815218026, Dev Recall 0.5261068340489934, Dev Accuracy 0.773826714801444


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 197, loss: 9.062524653192776
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.45369216967059883,  Dev Precision 0.35698711815218026, Dev Recall 0.5261068340489934, Dev Accuracy 0.773826714801444


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 198, loss: 9.062467133820947
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4538301958955816,  Dev Precision 0.3571893612115662, Dev Recall 0.5261068340489934, Dev Accuracy 0.7740072202166065


  0%|          | 0/67 [00:00<?, ?it/s]

epoch 199, loss: 9.062420567469811
Evaluating dev...


  0%|          | 0/5 [00:00<?, ?it/s]

Dev F1 0.4538301958955816,  Dev Precision 0.3571893612115662, Dev Recall 0.5261068340489934, Dev Accuracy 0.7740072202166065


In [59]:
torch.save(model, 'LClassifiers_v3.pt')

In [53]:
model = torch.load('LClassifiers_v3.pt')

In [58]:
print("Evaluating test...")
all_preds = []
all_labels = []

test_input_batches = [b for b in chunk_multi(test_premises, test_conclusion, test_stance, batch_size)]
# Tokenize + encode
test_input_batches = [tokenizer(*batch) for batch in test_input_batches]


test_label_batches = [b for b in chunk(test_labels, batch_size)]
test_label_batches = [encode_labels(batch) for batch in test_label_batches]
for sents, labels in tqdm(zip(test_input_batches, test_label_batches), total=len(test_input_batches)):
    pred = predict(model, sents)
    all_preds.extend(pred)
    all_labels.extend(list(labels.numpy()))
# #print(range(len(set(train_labels))))

test_f1, test_P, test_R, test_acc, test_all_f1, test_all_P, test_all_R, test_all_acc = f1Score_multiLabel(all_preds, all_labels)
print(f"test F1 {test_f1},  test Precision {test_P}, test Recall {test_R}, test Accuracy {test_acc}")

with open("all_f1_test_v3.csv", 'ab') as abc:
    np.savetxt(abc, 
           test_all_f1,
           delimiter =", ", 
           fmt ='%s')

Evaluating test...


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for sents, labels in tqdm(zip(test_input_batches, test_label_batches), total=len(test_input_batches)):


  0%|          | 0/12 [00:00<?, ?it/s]

test F1 0.3457838201190367,  test Precision 0.30096408909264705, test Recall 0.44515184166404553, test Accuracy 0.7677954847277557


0

tensor(0)

tensor([[ 0.9030,  1.0262, -0.0431,  0.4149,  1.3941]], requires_grad=True) tensor([[0.5257, 0.1721, 0.2647, 0.0207, 0.0167]])


tensor(1.7929, grad_fn=<DivBackward1>)


tensor([[-0.1021,  0.0259,     nan, -0.8798,  0.3323]], grad_fn=<LogBackward0>)

tensor([[-0.0537,  0.0045,     nan, -0.0182,  0.0056]], grad_fn=<MulBackward0>)

In [None]:
import random
from tqdm import tqdm_notebook as tqdm
from torchmetrics.classification import MultilabelHammingDistance
def training_loop(
    num_epochs,
    train_features,
    train_labels,
    dev_sents,
    dev_labels,
    optimizer,
    scheduler,
    model,
):
    print("Training...")
    all_f1 = []
    all_P = []
    all_R = []
    all_L = []
    all_CELoss = []
    all_HMLoss = []
    all_acc = []
    loss_func = torch.nn.CrossEntropyLoss()
    hammingLoss = MultilabelHammingDistance(num_labels=20)
    batches = list(zip(train_features, train_labels))
    random.shuffle(batches)
    for i in range(num_epochs):
        losses = []
        #CELosses = []
        #HMLosses = []
        for features, labels in tqdm(batches):
            # Empty the dynamic computation graph
            features = features.to(device)
            labels = labels.float()
            labels = labels.to(device)
            optimizer.zero_grad()
            preds = model(features)
            
            loss0 = loss_func(preds[0].squeeze(1).squeeze(1), labels[:,0])
            #loss = loss_func(preds0.squeeze(1), labels)
            loss1 = loss_func(preds[1].squeeze(1).squeeze(1), labels[:,1]) 
            loss2 = loss_func(preds[2].squeeze(1).squeeze(1), labels[:,2]) 
            loss3 = loss_func(preds[3].squeeze(1).squeeze(1), labels[:,3]) 
            loss4 = loss_func(preds[4].squeeze(1).squeeze(1), labels[:,4]) 
            loss5 = loss_func(preds[5].squeeze(1).squeeze(1), labels[:,5]) 
            loss6 = loss_func(preds[6].squeeze(1).squeeze(1), labels[:,6]) 
            loss7 = loss_func(preds[7].squeeze(1).squeeze(1), labels[:,7]) 
            loss8 = loss_func(preds[8].squeeze(1).squeeze(1), labels[:,8]) 
            loss9 = loss_func(preds[9].squeeze(1).squeeze(1), labels[:,9]) 
            loss10 = loss_func(preds[10].squeeze(1).squeeze(1), labels[:,10])
            loss11 = loss_func(preds[11].squeeze(1).squeeze(1), labels[:,11]) 
            loss12 = loss_func(preds[12].squeeze(1).squeeze(1), labels[:,12]) 
            loss13 = loss_func(preds[13].squeeze(1).squeeze(1), labels[:,13]) 
            loss14 = loss_func(preds[14].squeeze(1).squeeze(1), labels[:,14]) 
            loss15 = loss_func(preds[15].squeeze(1).squeeze(1), labels[:,15]) 
            loss16 = loss_func(preds[16].squeeze(1).squeeze(1), labels[:,16]) 
            loss17 = loss_func(preds[17].squeeze(1).squeeze(1), labels[:,17]) 
            loss18 = loss_func(preds[18].squeeze(1).squeeze(1), labels[:,18]) 
            loss19 = loss_func(preds[19].squeeze(1).squeeze(1), labels[:,19])
            loss_group1 = loss_func(preds[20].squeeze(1).squeeze(1), labels[:,20])
            loss_group2 = loss_func(preds[21].squeeze(1).squeeze(1), labels[:,21])
            loss_group3 = loss_func(preds[22].squeeze(1).squeeze(1), labels[:,22])
            loss_group4 = loss_func(preds[23].squeeze(1).squeeze(1), labels[:,23])
            loss_group5 = loss_func(preds[24].squeeze(1).squeeze(1), labels[:,24])
            
            local_loss = local_loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7 + loss8 + loss9 + loss10 + loss11 + loss12 + loss13 + loss14 + loss15 + loss16 + loss17 + loss18 + loss19
            group_loss = loss_group1 + loss_group2 + loss_group3 + loss_group4 + loss_group5
            loss = local_loss + 4*group_loss
            #print("Preds ",preds.shape)
            #print("Labels ", labels.shape)
            #loss = loss_func(preds, labels)
            #loss2 = 20*hammingLoss(preds, labels)
            
            #loss = loss1 + loss2
            # Backpropogate the loss through our model
            loss.backward()
            optimizer.step()
            #CELosses.append(loss1.item())
            #HMLosses.append(loss2.item())
            losses.append(loss.item())
        
        #print(f"epoch {i}, loss: {np.sum(losses)/len(losses)}, HM loss: {np.sum(HMLosses)/len(HMLosses)}, CE loss: {np.sum(CELosses)/len(CELosses)}")
        print(f"epoch {i}, loss: {np.sum(losses)/len(losses)}")
        # Estimate the f1 score for the development set
        print("Evaluating dev...")
        all_preds = []
        all_labels = []
        for sents, labels in tqdm(zip(dev_sents, dev_labels), total=len(dev_sents)):
            sents = sents.to(device)
            pred = predict(model, sents)
            all_preds.extend(pred)
            all_labels.extend(list(labels))
        # #print(range(len(set(train_labels))))

        dev_f1, dev_P, dev_R, dev_acc, dev_all_f1, dev_all_P, dev_all_R, dev_all_acc = f1Score_multiLabel(all_preds, all_labels)
        print(f"Dev F1 {dev_f1},  Dev Precision {dev_P}, Dev Recall {dev_R}, Dev Accuracy {dev_acc}")
        all_f1.append(dev_all_f1)
        all_P.append(dev_all_P)
        all_R.append(dev_all_R)
        all_L.append(losses)
        #all_CELoss.append(CELosses)
        #all_HMLoss.append(HMLosses)
        all_acc.append(dev_all_acc)
        scheduler.step()
        #print(optimizer)
    # Return the trained model
    with open("all_f1_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_f1,
               delimiter =", ", 
               fmt ='%s')
    with open("all_P_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_P,
               delimiter =", ", 
               fmt ='%s')
    with open("all_R_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_R,
               delimiter =", ", 
               fmt ='%s')
    with open("all_acc_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_acc,
               delimiter =", ", 
               fmt ='%s')
    with open("all_L_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_L,
               delimiter =", ", 
               fmt ='%s')
    """
    with open("all_CELoss_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_CELoss,
               delimiter =", ", 
               fmt ='%s')
    with open("all_HMLoss_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_CELoss,
               delimiter =", ", 
               fmt ='%s')
    """
    return model

tensor(1.5930, grad_fn=<NllLossBackward0>)


epoch 1, loss: 10.0


torch.float32

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
[0, 9, 7, 1, 5, 3, 6, 2, 4, 8]
None


In [None]:
import random
from tqdm import tqdm_notebook as tqdm
from torchmetrics.classification import MultilabelHammingDistance
def training_loop(
    num_epochs,
    train_features,
    train_labels,
    dev_sents,
    dev_labels,
    optimizer,
    scheduler,
    model,
):
    print("Training...")
    all_f1 = []
    all_P = []
    all_R = []
    all_L = []
    all_CELoss = []
    all_HMLoss = []
    all_acc = []
    loss_func = torch.nn.CrossEntropyLoss()
    hammingLoss = MultilabelHammingDistance(num_labels=20)
    batches = list(zip(train_features, train_labels))
    random.shuffle(batches)
    for i in range(num_epochs):
        losses = []
        CELosses = []
        HMLosses = []
        for features, labels in tqdm(batches):
            # Empty the dynamic computation graph
            features = features.to(device)
            labels = labels.float()
            labels = labels.to(device)
            optimizer.zero_grad()
            preds = model(features).squeeze(1)
            #print("Preds ",preds.shape)
            #print("Labels ", labels.shape)
            loss1 = loss_func(preds, labels)
            loss2 = 20*hammingLoss(preds, labels)
            
            loss = loss1 + loss2
            # Backpropogate the loss through our model
            loss.backward()
            optimizer.step()
            CELosses.append(loss1.item())
            HMLosses.append(loss2.item())
            losses.append(loss.item())
        
        print(f"epoch {i}, loss: {np.sum(losses)/len(losses)}, HM loss: {np.sum(HMLosses)/len(HMLosses)}, CE loss: {np.sum(CELosses)/len(CELosses)}")
        # Estimate the f1 score for the development set
        print("Evaluating dev...")
        all_preds = []
        all_labels = []
        for sents, labels in tqdm(zip(dev_sents, dev_labels), total=len(dev_sents)):
            sents = sents.to(device)
            pred = predict(model, sents)
            all_preds.extend(pred)
            all_labels.extend(list(labels))
        # #print(range(len(set(train_labels))))

        dev_f1, dev_P, dev_R, dev_acc, dev_all_f1, dev_all_P, dev_all_R, dev_all_acc = f1Score_multiLabel(all_preds, all_labels)
        print(f"Dev F1 {dev_f1},  Dev Precision {dev_P}, Dev Recall {dev_R}, Dev Accuracy {dev_acc}")
        all_f1.append(dev_all_f1)
        all_P.append(dev_all_P)
        all_R.append(dev_all_R)
        all_L.append(losses)
        all_CELoss.append(CELosses)
        all_HMLoss.append(HMLosses)
        all_acc.append(dev_all_acc)
        scheduler.step()
        #print(optimizer)
    # Return the trained model
    with open("all_f1_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_f1,
               delimiter =", ", 
               fmt ='%s')
    with open("all_P_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_P,
               delimiter =", ", 
               fmt ='%s')
    with open("all_R_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_R,
               delimiter =", ", 
               fmt ='%s')
    with open("all_acc_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_acc,
               delimiter =", ", 
               fmt ='%s')
    with open("all_L_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_L,
               delimiter =", ", 
               fmt ='%s')
    with open("all_CELoss_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_CELoss,
               delimiter =", ", 
               fmt ='%s')
    with open("all_HMLoss_base.csv", 'ab') as abc:
        np.savetxt(abc, 
               all_CELoss,
               delimiter =", ", 
               fmt ='%s')
    return model