In [11]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
import os
from collections import Counter
from datetime import datetime

import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader

from src import utils
from src.data_loader import MyDataset, load_from_csv
from src.model import CharacterLevelCNN
from train import train, evaluate

In [13]:
df = pd.read_csv('data/training.1600000.processed.noemoticon.csv', encoding='latin1')
df.shape

(1600000, 6)

In [11]:
sample_df = df.sample(n=1000)

In [13]:
sample_df[['target', 'text']].to_csv('data/sample.csv', index=False)

In [27]:
# Parameters

ALPHABET = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+ =<>()[]{}"
MAX_LEN = 150
CLASS_WEIGHTS = False
training_params = {
    "batch_size":  128,
    "shuffle":     True,
    "num_workers": 1,
    "drop_last":   True,
}

validation_params = {
    "batch_size":  128,
    "shuffle":     False,
    "num_workers": 1,
    "drop_last":   True,
}
OPTIMIZER = 'sgd'
scheduler = 'clr'
learning_rate = 0.01
stepsize = 4.0
min_lr, max_lr = 1.7e-3, 1e-2
epochs = 10

now = datetime.now()
logdir = './log' + now.strftime("%Y%m%d-%H%M%S") + "/"
os.makedirs(logdir)
log_file = logdir + "log.txt"
writer = SummaryWriter(logdir)

best_f1 = 0
best_epoch = 0
checkpoint = 1
output = "./models/"
model_name = "cnn"
early_stopping = 0
patience = 3


In [28]:
sample_df = pd.read_csv('data/sample.csv')
sample_df.shape

(1000, 2)

In [29]:
texts, labels, number_of_classes, sample_weights = load_from_csv(sample_df)

data loaded successfully with 1000 rows and 2 labels
Distribution of the classes Counter({0: 505, 1: 495})


In [30]:
class_names = sorted(list(set(labels)))
class_names = [str(class_name) for class_name in class_names]
(
    train_texts,
    val_texts,
    train_labels,
    val_labels,
    train_sample_weights,
    _,
) = train_test_split(
    texts,
    labels,
    sample_weights,
    test_size=0.2,
    random_state=42,
    stratify=labels,
)
training_set = MyDataset(train_texts, train_labels, vocabulary=ALPHABET, max_length=MAX_LEN)
validation_set = MyDataset(val_texts, val_labels, vocabulary=ALPHABET)
training_generator = DataLoader(training_set, **training_params)
validation_generator = DataLoader(validation_set, **validation_params)

model = CharacterLevelCNN(number_of_classes, num_chars=len(ALPHABET), max_length=MAX_LEN)
if torch.cuda.is_available():
    model.cuda()

In [31]:
if bool(CLASS_WEIGHTS):
    class_counts = dict(Counter(train_labels))
    m = max(class_counts.values())
    for c in class_counts:
        class_counts[c] = m / class_counts[c]
    weights = []
    for k in sorted(class_counts.keys()):
        weights.append(class_counts[k])

    weights = torch.Tensor(weights)
    if torch.cuda.is_available():
        weights = weights.cuda()
        print(f"passing weights to CrossEntropyLoss : {weights}")
        criterion = nn.CrossEntropyLoss(weight=weights)
else:
    criterion = nn.CrossEntropyLoss()

if OPTIMIZER == "sgd":
    if scheduler == "clr":
        optimizer = torch.optim.SGD(
            model.parameters(), lr=1, momentum=0.9, weight_decay=0.00001
        )
    else:
        optimizer = torch.optim.SGD(
            model.parameters(), lr=learning_rate, momentum=0.9
        )
elif OPTIMIZER == "adam":
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

if scheduler == "clr":
    stepsize = int(stepsize * len(training_generator))
    clr = utils.cyclical_lr(stepsize, min_lr, max_lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, [clr])
else:
    scheduler = None

In [32]:
for epoch in range(epochs):
    training_loss, training_accuracy, train_f1 = train(
        model,
        training_generator,
        optimizer,
        criterion,
        epoch,
        writer,
        log_file,
        scheduler,
        class_names,
    )

    validation_loss, validation_accuracy, validation_f1 = evaluate(
        model,
        validation_generator,
        criterion,
        epoch,
        writer,
        log_file,
    )

    print(
        "[Epoch: {} / {}]\ttrain_loss: {:.4f} \ttrain_acc: {:.4f} \tval_loss: {:.4f} \tval_acc: {:.4f}".format(
            epoch + 1,
            epochs,
            training_loss,
            training_accuracy,
            validation_loss,
            validation_accuracy,
        )
    )
    print("=" * 50)

    # learning rate scheduling

    if scheduler == "step":
        if optimizer == "sgd" and ((epoch + 1) % 3 == 0) and epoch > 0:
            current_lr = optimizer.state_dict()["param_groups"][0]["lr"]
            current_lr /= 2
            print("Decreasing learning rate to {0}".format(current_lr))
            for param_group in optimizer.param_groups:
                param_group["lr"] = current_lr

    # model checkpoint

    if validation_f1 > best_f1:
        best_f1 = validation_f1
        best_epoch = epoch
        if checkpoint == 1:
            torch.save(
                model.state_dict(),
                output
                + "model_{}_epoch_{}_maxlen_{}_lr_{}_loss_{}_acc_{}_f1_{}.pth".format(
                    model_name,
                    epoch,
                    MAX_LEN,
                    optimizer.state_dict()["param_groups"][0]["lr"],
                    round(validation_loss, 4),
                    round(validation_accuracy, 4),
                    round(validation_f1, 4),
                ),
            )

    if bool(early_stopping):
        if epoch - best_epoch > patience > 0:
            print(
                "Stop training at epoch {}. The lowest loss achieved is {} at epoch {}".format(
                    epoch, validation_loss, best_epoch
                )
            )
            break

100%|██████████| 6/6 [00:05<00:00,  1.04it/s]


              precision    recall  f1-score   support

           0       0.49      0.55      0.52       390
           1       0.47      0.41      0.44       378

    accuracy                           0.48       768
   macro avg       0.48      0.48      0.48       768
weighted avg       0.48      0.48      0.48       768



100%|██████████| 1/1 [00:00<00:00,  1.37it/s]


              precision    recall  f1-score   support

           0       0.41      0.30      0.35        66
           1       0.42      0.53      0.47        62

    accuracy                           0.41       128
   macro avg       0.41      0.42      0.41       128
weighted avg       0.41      0.41      0.41       128

[Epoch: 1 / 10]	train_loss: 0.8229 	train_acc: 0.4805 	val_loss: 0.7045 	val_acc: 0.4141


100%|██████████| 6/6 [00:05<00:00,  1.00it/s]


              precision    recall  f1-score   support

           0       0.51      0.51      0.51       388
           1       0.50      0.50      0.50       380

    accuracy                           0.51       768
   macro avg       0.51      0.51      0.51       768
weighted avg       0.51      0.51      0.51       768



100%|██████████| 1/1 [00:00<00:00,  1.32it/s]


              precision    recall  f1-score   support

           0       0.51      0.92      0.66        66
           1       0.38      0.05      0.09        62

    accuracy                           0.50       128
   macro avg       0.44      0.49      0.37       128
weighted avg       0.44      0.50      0.38       128

[Epoch: 2 / 10]	train_loss: 0.7590 	train_acc: 0.5052 	val_loss: 0.6908 	val_acc: 0.5000


100%|██████████| 6/6 [00:05<00:00,  1.18it/s]


              precision    recall  f1-score   support

           0       0.51      0.59      0.55       387
           1       0.50      0.42      0.46       381

    accuracy                           0.51       768
   macro avg       0.50      0.50      0.50       768
weighted avg       0.50      0.51      0.50       768



100%|██████████| 1/1 [00:00<00:00,  1.65it/s]


              precision    recall  f1-score   support

           0       0.53      0.92      0.67        66
           1       0.58      0.11      0.19        62

    accuracy                           0.53       128
   macro avg       0.55      0.52      0.43       128
weighted avg       0.55      0.53      0.44       128

[Epoch: 3 / 10]	train_loss: 0.7377 	train_acc: 0.5052 	val_loss: 0.6874 	val_acc: 0.5312


100%|██████████| 6/6 [00:04<00:00,  1.22it/s]


              precision    recall  f1-score   support

           0       0.51      0.54      0.53       386
           1       0.51      0.49      0.50       382

    accuracy                           0.51       768
   macro avg       0.51      0.51      0.51       768
weighted avg       0.51      0.51      0.51       768



100%|██████████| 1/1 [00:00<00:00,  1.29it/s]


              precision    recall  f1-score   support

           0       0.53      0.59      0.56        66
           1       0.51      0.45      0.48        62

    accuracy                           0.52       128
   macro avg       0.52      0.52      0.52       128
weighted avg       0.52      0.52      0.52       128

[Epoch: 4 / 10]	train_loss: 0.7079 	train_acc: 0.5130 	val_loss: 0.6881 	val_acc: 0.5234


100%|██████████| 6/6 [00:03<00:00,  1.54it/s]


              precision    recall  f1-score   support

           0       0.54      0.49      0.52       383
           1       0.54      0.59      0.56       385

    accuracy                           0.54       768
   macro avg       0.54      0.54      0.54       768
weighted avg       0.54      0.54      0.54       768



100%|██████████| 1/1 [00:00<00:00,  1.50it/s]


              precision    recall  f1-score   support

           0       0.53      0.67      0.59        66
           1       0.51      0.37      0.43        62

    accuracy                           0.52       128
   macro avg       0.52      0.52      0.51       128
weighted avg       0.52      0.52      0.51       128

[Epoch: 5 / 10]	train_loss: 0.6970 	train_acc: 0.5404 	val_loss: 0.6886 	val_acc: 0.5234


100%|██████████| 6/6 [00:06<00:00,  1.13s/it]


              precision    recall  f1-score   support

           0       0.50      0.48      0.49       391
           1       0.49      0.51      0.50       377

    accuracy                           0.49       768
   macro avg       0.50      0.50      0.49       768
weighted avg       0.50      0.49      0.49       768



100%|██████████| 1/1 [00:00<00:00,  1.93it/s]


              precision    recall  f1-score   support

           0       0.53      0.91      0.67        66
           1       0.57      0.13      0.21        62

    accuracy                           0.53       128
   macro avg       0.55      0.52      0.44       128
weighted avg       0.55      0.53      0.45       128

[Epoch: 6 / 10]	train_loss: 0.7196 	train_acc: 0.4948 	val_loss: 0.6854 	val_acc: 0.5312


100%|██████████| 6/6 [00:06<00:00,  1.04s/it]


              precision    recall  f1-score   support

           0       0.50      0.59      0.54       394
           1       0.47      0.38      0.42       374

    accuracy                           0.49       768
   macro avg       0.48      0.48      0.48       768
weighted avg       0.48      0.49      0.48       768



100%|██████████| 1/1 [00:00<00:00,  2.02it/s]


              precision    recall  f1-score   support

           0       0.52      0.97      0.68        66
           1       0.67      0.06      0.12        62

    accuracy                           0.53       128
   macro avg       0.60      0.52      0.40       128
weighted avg       0.59      0.53      0.41       128

[Epoch: 7 / 10]	train_loss: 0.7161 	train_acc: 0.4870 	val_loss: 0.6858 	val_acc: 0.5312


100%|██████████| 6/6 [00:04<00:00,  1.29it/s]


              precision    recall  f1-score   support

           0       0.53      0.63      0.57       393
           1       0.51      0.41      0.45       375

    accuracy                           0.52       768
   macro avg       0.52      0.52      0.51       768
weighted avg       0.52      0.52      0.51       768



100%|██████████| 1/1 [00:01<00:00,  1.34s/it]


              precision    recall  f1-score   support

           0       0.56      0.82      0.66        66
           1       0.61      0.31      0.41        62

    accuracy                           0.57       128
   macro avg       0.58      0.56      0.54       128
weighted avg       0.58      0.57      0.54       128

[Epoch: 8 / 10]	train_loss: 0.7051 	train_acc: 0.5208 	val_loss: 0.6858 	val_acc: 0.5703


100%|██████████| 6/6 [00:07<00:00,  1.25s/it]


              precision    recall  f1-score   support

           0       0.55      0.53      0.54       391
           1       0.53      0.55      0.54       377

    accuracy                           0.54       768
   macro avg       0.54      0.54      0.54       768
weighted avg       0.54      0.54      0.54       768



100%|██████████| 1/1 [00:00<00:00,  1.81it/s]


              precision    recall  f1-score   support

           0       0.60      0.56      0.58        66
           1       0.56      0.60      0.58        62

    accuracy                           0.58       128
   macro avg       0.58      0.58      0.58       128
weighted avg       0.58      0.58      0.58       128

[Epoch: 9 / 10]	train_loss: 0.7028 	train_acc: 0.5417 	val_loss: 0.6864 	val_acc: 0.5781


100%|██████████| 6/6 [00:03<00:00,  1.57it/s]


              precision    recall  f1-score   support

           0       0.51      0.53      0.52       388
           1       0.50      0.49      0.49       380

    accuracy                           0.51       768
   macro avg       0.51      0.51      0.51       768
weighted avg       0.51      0.51      0.51       768



100%|██████████| 1/1 [00:00<00:00,  1.98it/s]

              precision    recall  f1-score   support

           0       0.54      0.77      0.63        66
           1       0.55      0.29      0.38        62

    accuracy                           0.54       128
   macro avg       0.54      0.53      0.51       128
weighted avg       0.54      0.54      0.51       128

[Epoch: 10 / 10]	train_loss: 0.7041 	train_acc: 0.5078 	val_loss: 0.6858 	val_acc: 0.5391



