In [1]:
from spanemo.learner import Trainer
from spanemo.model import SpanEmo
from spanemo.data_loader import DataClass
from torch.utils.data import DataLoader
import torch
import datetime
import numpy as np
import json

# Global RNG seed for the run; the device-setup cell below uses it to seed
# NumPy and PyTorch.
seed = 12345678

In [2]:
# Pick the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed NumPy and PyTorch regardless of the selected device. torch.manual_seed
# seeds torch's default CPU generator, which is what a DataLoader created with
# shuffle=True draws its permutation from, so it is needed on GPU runs as well
# as on CPU runs.
np.random.seed(seed)
torch.manual_seed(seed)

if device.type == 'cuda':
    print("Currently using GPU: {}".format(device))
    # Seed every CUDA device explicitly and ask cuDNN for deterministic
    # kernels; disabling benchmark mode stops cuDNN from auto-tuning its
    # algorithm choice between runs.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
else:
    print("WARNING: USING CPU")

Currently using GPU: cuda:0


In [3]:
def make_loaders(args):
    """Create the training and validation DataLoaders described by ``args``.

    Reads ``args['train_path']``, ``args['val_path']``,
    ``args['train_batch_size']`` and ``args['val_batch_size']``; the whole
    ``args`` mapping is also handed to ``DataClass``. The number of batches in
    each loader is printed as it is built.

    Returns:
        A ``(train_data_loader, val_data_loader)`` tuple. Only the training
        loader shuffles its samples.
    """
    # Training split: shuffled every epoch.
    train_set = DataClass(args, args['train_path'])
    train_bs = int(args['train_batch_size'])
    train_data_loader = DataLoader(train_set, batch_size=train_bs, shuffle=True)
    print('The number of training batches: ', len(train_data_loader))

    # Validation split: fixed order.
    val_set = DataClass(args, args['val_path'])
    val_bs = int(args['val_batch_size'])
    val_data_loader = DataLoader(val_set, batch_size=val_bs, shuffle=False)
    print('The number of validation batches: ', len(val_data_loader))

    return train_data_loader, val_data_loader

In [4]:
def make_model(args):
    """Instantiate a ``SpanEmo`` model from the hyperparameter mapping ``args``.

    Reads ``args['output_dropout']``, ``args['backbone']``,
    ``args['loss_type']`` (forwarded as ``joint_loss``) and
    ``args['alpha_loss']`` (forwarded as ``alpha``).

    Returns:
        The newly constructed ``SpanEmo`` instance.
    """
    model_kwargs = {
        'output_dropout': args['output_dropout'],
        'backbone': args['backbone'],
        'joint_loss': args['loss_type'],
        'alpha': args['alpha_loss'],
    }
    return SpanEmo(**model_kwargs)


In [5]:
def pipeline(args, loaders=None):
    """Record the run configuration, build the model and train it.

    The hyperparameters are written as JSON to ``configs/<timestamp>.json``,
    and the same timestamp string is passed to ``Trainer`` as ``filename``.

    Args:
        args: Hyperparameter mapping. ``args['max_epoch']`` is read here; the
            whole mapping is also forwarded to ``make_loaders``,
            ``make_model`` and ``Trainer.fit``.
        loaders: Optional ``(train_data_loader, val_data_loader)`` pair. When
            ``None``, the loaders are built with ``make_loaders(args)``.
    """
    import os  # local import so this cell does not depend on the import cell

    now = datetime.datetime.now()
    filename = now.strftime("%Y-%m-%d-%H:%M:%S")

    # Create the output directory if it is missing, and close the file right
    # after the dump so the config is on disk before training starts.
    os.makedirs('configs', exist_ok=True)
    with open('configs/' + filename + '.json', 'a') as fw:
        json.dump(args, fw, sort_keys=True, indent=2)

    if loaders is None:
        train_data_loader, val_data_loader = make_loaders(args)
    else:
        train_data_loader, val_data_loader = loaders
    model = make_model(args)

    learn = Trainer(model, train_data_loader, val_data_loader, filename=filename)
    learn.fit(
        num_epochs=int(args['max_epoch']),
        args=args,
        device=device
    )

In [6]:
# Hyperparameters for one training run; this mapping is what pipeline() dumps
# to the JSON config file and passes around as `args`.
hyperparams = dict(
    # Data: CSV paths and batch sizes read by make_loaders().
    train_path='data/train.csv',
    val_path='data/val.csv',
    # Model: forwarded to SpanEmo by make_model() (see the loss keys below too).
    backbone='bert-base-uncased',
    train_batch_size=128,
    val_batch_size=128,
    output_dropout=0.1,
    loss_type='joint',
    alpha_loss=0.2,
    # Training: max_epoch becomes num_epochs in Trainer.fit().
    max_epoch=20,
    # Not read directly in this notebook; presumably consumed by DataClass /
    # Trainer.fit through `args` -- TODO confirm.
    max_length=128,
    ffn_lr=0.001,
    bert_lr=2e-5,
)

In [7]:
# Build both DataLoaders once, up front, so they can be handed to pipeline()
# through its `loaders` argument instead of being rebuilt inside it.
loaders = make_loaders(hyperparams)

  self.tok = re.compile(r"({})".format("|".join(pipeline)))


Reading twitter_2018 - 1grams ...
Reading twitter_2018 - 2grams ...


  regexes = {k.lower(): re.compile(self.expressions[k]) for k, v in


Reading twitter_2018 - 1grams ...


PreProcessing dataset ...: 100%|██████████| 43410/43410 [00:39<00:00, 1088.52it/s]


The number of training batches:  340
Reading twitter_2018 - 1grams ...
Reading twitter_2018 - 2grams ...
Reading twitter_2018 - 1grams ...


PreProcessing dataset ...: 100%|██████████| 5426/5426 [00:05<00:00, 1041.13it/s]

The number of validation batches:  43





In [8]:
# Launch training, passing the pre-built loaders so pipeline() does not call
# make_loaders() again.
pipeline(hyperparams, loaders=loaders)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Train_Loss,Val_Loss,F1-Macro,F1-Micro,JS,Time
0.3004,0.2076,0.4563,0.5601,0.5038,03:01
0.2026,0.1931,0.5118,0.6024,0.5571,03:02
0.1901,0.1903,0.5403,0.6152,0.5721,03:00
0.1823,0.1884,0.537,0.6232,0.587,03:01
0.1757,0.192,0.5343,0.6113,0.5739,03:01
0.1683,0.1917,0.5435,0.6106,0.5762,03:02
0.1613,0.1961,0.5335,0.6001,0.5622,03:02
0.1539,0.2019,0.5299,0.5905,0.5563,03:02
0.146,0.2061,0.5286,0.5927,0.5603,03:02
0.1391,0.2145,0.5169,0.5718,0.5386,03:03


epoch#:  1
Validation loss decreased (inf --> 0.207630).  Saving model ...
epoch#:  2
Validation loss decreased (0.207630 --> 0.193123).  Saving model ...
epoch#:  3
Validation loss decreased (0.193123 --> 0.190322).  Saving model ...
epoch#:  4
Validation loss decreased (0.190322 --> 0.188408).  Saving model ...
epoch#:  5
EarlyStopping counter: 1 out of 10
epoch#:  6
EarlyStopping counter: 2 out of 10
epoch#:  7
EarlyStopping counter: 3 out of 10
epoch#:  8
EarlyStopping counter: 4 out of 10
epoch#:  9
EarlyStopping counter: 5 out of 10
epoch#:  10
EarlyStopping counter: 6 out of 10
epoch#:  11
EarlyStopping counter: 7 out of 10
epoch#:  12
EarlyStopping counter: 8 out of 10
epoch#:  13
EarlyStopping counter: 9 out of 10
epoch#:  14
EarlyStopping counter: 10 out of 10
Early stopping
