In [1]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import torch
from coherenceModel import *
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

In [2]:
paragraph_df = pd.read_csv('moreAviationPerms.csv')
paragraph_df

Unnamed: 0,paragraph,is_coherent
0,The preflight inspection of the fuel tanks by ...,1
1,"The pilot reported that he was cleared to 4,00...",1
2,The instrument-rated private pilot lost contro...,1
3,The non-instrument rated private pilot was rec...,1
4,The commercial pilot reported a partial power ...,1
...,...,...
41995,The engine ran for approximately 30 minutes at...,0
41996,Examination of the wreckage revealed that ther...,0
41997,"The engine then lost all power, and the pilot ...",0
41998,"During the approach, the pilot was able to res...",0


In [3]:
X_train, X_test, y_train, y_test = train_test_split(
    paragraph_df.paragraph.values, 
    paragraph_df.is_coherent.values,
    stratify = paragraph_df.is_coherent.values,
    test_size = 0.1, 
    random_state = 487
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, 
    y_train,
    stratify = y_train,
    test_size = 0.2, 
    random_state = 487
)
print(X_train[0])
print(y_train[0])

It had been consumed by fire after a collision with trees. The coroner was an FAA DME (designated medical examiner). The airplane was found three days after it was declared missing. He said he did not call the authorities because he did not hear a crash or see smoke. Examination of wreckage revealed no mechanical anomaly that would have prevented normal flight. A toxicology test and autopsy were inconclusive due to inadequate or unsuitable specimens. He said he had denied the accident pilot a medical certificate due to heart disease, but the pilot appealed his decision and was given a special medical certificate. Also, he said the weather was clear and sunny. He said it was at an altitude equal to other airplanes departing and arriving at the airport. A witness said he observed the airplane flying level near the departure airport. The witness said the airplane suddenly pitched down, descended below trees, and did not come back up.
0


In [7]:
import gensim.downloader
embed = gensim.downloader.load("glove-wiki-gigaword-50")

In [5]:
train_data = WindowedParDataset(X_train, y_train, embed, 5)
dev_data = WindowedParDataset(X_val, y_val, embed, 5)
test_data = WindowedParDataset(X_test, y_test, embed, 5)

Number of coherent windows: 9912
Number of incoherent windows: 200742
Number of coherent windows: 2520
Number of incoherent windows: 50217
Number of coherent windows: 1399
Number of incoherent windows: 27895


In [6]:
train_loader = DataLoader(train_data, batch_size=25, collate_fn=basic_collate_fn, shuffle=True)
dev_loader = DataLoader(dev_data, batch_size=25, collate_fn=basic_collate_fn, shuffle=True)
test_loader = DataLoader(test_data, batch_size=25, collate_fn=basic_collate_fn, shuffle=False)

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [5]:
num_pos = len(paragraph_df[paragraph_df["is_coherent"] == 1])
num_neg = len(paragraph_df[paragraph_df["is_coherent"] == 0])
dampen = 1
pos_weight = torch.Tensor([num_neg / num_pos / dampen]).to(device)
pos_weight

tensor([20.], device='cuda:0')

In [12]:
# test on validation to see if overfit is possible
ffnn = FFNN(5, device)
ffnn.to(device)
optim = get_optimizer(ffnn, lr=1e-2, weight_decay=0)
best_model, stats = train_model(ffnn, dev_loader, dev_loader, optim, pos_weight=pos_weight,
                                num_epoch=15, collect_cycle=20, device=device, patience=None)
plot_loss(stats)

------------------------ Start Training ------------------------
Epoch No. 1--Iteration No. 2110-- batch loss = 1.8297
Validation UAR: 0.5000
Validation accuracy: 0.0478
Validation loss: 1.3249


KeyboardInterrupt: 

In [6]:
from torch.utils.data import WeightedRandomSampler
def get_sampler(X, y):
    counts = [0, 0]
    for l in y:
        counts[l] += 1
    sample_weights = [1 / counts[0], 1 / counts[1]]
    print(sample_weights)
    return WeightedRandomSampler(num_samples=len(X), weights=sample_weights, replacement=True)

In [8]:
################ SET THIS TO CHANGE WINDOW SIZE OF THINGS BELOW:
wsize = 5
################
train_data = WindowedParDataset(X_train, y_train, embed, wsize)
dev_data = WindowedParDataset(X_val, y_val, embed, wsize)
test_data = WindowedParDataset(X_test, y_test, embed, wsize)
train_loader = DataLoader(train_data, batch_size=25, collate_fn=basic_collate_fn, shuffle=True)
dev_loader = DataLoader(dev_data, batch_size=25, collate_fn=basic_collate_fn, shuffle=True)
test_loader = DataLoader(test_data, batch_size=25, collate_fn=basic_collate_fn, shuffle=False)

Number of coherent windows: 9938
Number of incoherent windows: 200790
Number of coherent windows: 2520
Number of incoherent windows: 50221
Number of coherent windows: 1399
Number of incoherent windows: 27896


In [9]:
import itertools
from tqdm.notebook import tqdm

torch.cuda.empty_cache()

def search_param_utterance(wsize):
    """Experiemnt on different hyper parameters."""
    learning_rate, weight_decay = get_hyper_parameters()
    window_sizes = [wsize]
    print("learning rate from: {}\nweight_decay from: {}\nwindow from: {}".format(
        learning_rate, weight_decay, window_sizes
    ))
    best_model, best_stats = None, None
    best_accuracy, best_lr, best_wd, best_window_size = 0, 0, 0, 0
    for lr, wd, window_size in tqdm(itertools.product(learning_rate, weight_decay, window_sizes),
                           total=len(learning_rate) * len(weight_decay) * len(window_sizes)):
        net = FFNN(window_size, device).to(device)
        optim = get_optimizer(net, lr=lr, weight_decay=wd)
        model, stats = train_model(net, train_loader, dev_loader, optim, pos_weight=pos_weight, 
                                   num_epoch=100, collect_cycle=500, device=device, 
                                   verbose=True, patience=5, stopping_criteria='accuracy')
        # print accuracy
        print(f"{(lr, wd, window_size)}: {stats['accuracy']}")
        # update best parameters if needed
        if stats['accuracy'] > best_accuracy:
            best_accuracy = stats['accuracy']
            best_model, best_stats = model, stats
            best_lr, best_wd, best_window_size = lr, wd, window_size
            torch.save(best_model.state_dict(), 'best_rnn.pt')
    print("\n\nBest learning rate: {}, best weight_decay: {}, best window: {}".format(
        best_lr, best_wd, best_window_size))
    print("Accuracy: {:.4f}".format(best_accuracy))
    plot_loss(best_stats)
    return best_model
basic_model = search_param_utterance(wsize)

learning rate from: [0.01]
weight_decay from: [0.0002, 0.002, 0.005, 0.01, 0.02, 0.025, 0.04, 0.05, 0.1]
window from: [5]


  0%|          | 0/9 [00:00<?, ?it/s]

------------------------ Start Training ------------------------
Epoch No. 1--Iteration No. 8430-- batch loss = 0.5011
Validation UAR: 0.6831
Validation accuracy: 0.6610
Validation loss: 1.1516
Epoch No. 2--Iteration No. 16860-- batch loss = 0.4586
Validation UAR: 0.6941
Validation accuracy: 0.7197
Validation loss: 1.1206
Epoch No. 3--Iteration No. 25290-- batch loss = 0.3654
Validation UAR: 0.7113
Validation accuracy: 0.7126
Validation loss: 1.0954
Epoch No. 4--Iteration No. 33720-- batch loss = 0.1224
Validation UAR: 0.7217
Validation accuracy: 0.7233
Validation loss: 1.0623
Epoch No. 5--Iteration No. 42150-- batch loss = 0.5738
Validation UAR: 0.7360
Validation accuracy: 0.7277
Validation loss: 1.0376
Epoch No. 6--Iteration No. 50580-- batch loss = 0.4922
Validation UAR: 0.7415
Validation accuracy: 0.7418
Validation loss: 1.0135
Epoch No. 7--Iteration No. 59010-- batch loss = 0.3734
Validation UAR: 0.7402
Validation accuracy: 0.7587
Validation loss: 1.0033
Epoch No. 8--Iteration No.

Validation UAR: 0.6712
Validation accuracy: 0.6530
Validation loss: 1.1692
Epoch No. 11--Iteration No. 92730-- batch loss = 0.7569
Validation UAR: 0.6840
Validation accuracy: 0.6739
Validation loss: 1.1445
Epoch No. 12--Iteration No. 101160-- batch loss = 0.5702
Validation UAR: 0.6895
Validation accuracy: 0.6710
Validation loss: 1.1357
Epoch No. 13--Iteration No. 109590-- batch loss = 3.9095
Validation UAR: 0.6936
Validation accuracy: 0.6928
Validation loss: 1.1260
Epoch No. 14--Iteration No. 118020-- batch loss = 0.6763
Validation UAR: 0.6879
Validation accuracy: 0.6966
Validation loss: 1.1316
Epoch No. 15--Iteration No. 126450-- batch loss = 0.5887
Validation UAR: 0.6973
Validation accuracy: 0.7021
Validation loss: 1.1128
Epoch No. 16--Iteration No. 134880-- batch loss = 0.8481
Validation UAR: 0.7010
Validation accuracy: 0.6739
Validation loss: 1.1045
Epoch No. 17--Iteration No. 143310-- batch loss = 0.6534
Validation UAR: 0.6980
Validation accuracy: 0.6133
Validation loss: 1.1064
Ep

Epoch No. 23--Iteration No. 193890-- batch loss = 0.6883
Validation UAR: 0.5000
Validation accuracy: 0.9522
Validation loss: 1.3226
Epoch No. 24--Iteration No. 202320-- batch loss = 0.6882
Validation UAR: 0.5000
Validation accuracy: 0.9522
Validation loss: 1.3223
Epoch No. 25--Iteration No. 210750-- batch loss = 0.6882
Validation UAR: 0.5000
Validation accuracy: 0.9522
Validation loss: 1.3223
Epoch No. 26--Iteration No. 219180-- batch loss = 0.6883
Validation UAR: 0.5000
Validation accuracy: 0.9522
Validation loss: 1.3225
Epoch No. 27--Iteration No. 227610-- batch loss = 0.6883
Validation UAR: 0.5000
Validation accuracy: 0.9522
Validation loss: 1.3223
Epoch No. 28--Iteration No. 236040-- batch loss = 0.6883
Validation UAR: 0.5000
Validation accuracy: 0.9522
Validation loss: 1.3225
Epoch No. 29--Iteration No. 244470-- batch loss = 0.6882
Validation UAR: 0.5000
Validation accuracy: 0.9522
Validation loss: 1.3228
Epoch No. 30--Iteration No. 252900-- batch loss = 0.6883
Validation UAR: 0.5

Epoch No. 33--Iteration No. 278190-- batch loss = 0.6883
Validation UAR: 0.5000
Validation accuracy: 0.9522
Validation loss: 1.3223
Epoch No. 34--Iteration No. 286620-- batch loss = 0.6883
Validation UAR: 0.5000
Validation accuracy: 0.9522
Validation loss: 1.3223
Epoch No. 35--Iteration No. 295050-- batch loss = 0.6882
Validation UAR: 0.5000
Validation accuracy: 0.9522
Validation loss: 1.3225
Epoch No. 36--Iteration No. 303480-- batch loss = 0.6882
Validation UAR: 0.5000
Validation accuracy: 0.9522
Validation loss: 1.3223
Epoch No. 37--Iteration No. 311910-- batch loss = 5.1128
Validation UAR: 0.5000
Validation accuracy: 0.9522
Validation loss: 1.3225
Epoch No. 38--Iteration No. 320340-- batch loss = 0.6883
Validation UAR: 0.5000
Validation accuracy: 0.9522
Validation loss: 1.3225
Epoch No. 39--Iteration No. 328770-- batch loss = 0.6883
Validation UAR: 0.5000
Validation accuracy: 0.9522
Validation loss: 1.3225
Epoch No. 40--Iteration No. 337200-- batch loss = 5.1121
Validation UAR: 0.5

KeyboardInterrupt: 

In [11]:
basic_model = FFNN(5, device)
basic_model.load_state_dict(torch.load('best_rnn.pt'))
basic_model.eval()
basic_model.to(device)
uar, accuracy, total_loss = get_validation_performance(
    basic_model, 
    nn.BCEWithLogitsLoss(pos_weight=pos_weight), 
    test_loader, 
    device
)
print("Final selection: window size 5 with Q = 0.0002")
print("Test UAR: {:.4f}".format(uar))
print("Test accuracy: {:.4f}".format(accuracy))
print("Test loss: {:.4f}".format(total_loss))

Final selection: window size 5 with Q = 0.0002
Test UAR: 0.7660
Test accuracy: 0.7768
Test loss: 0.9321
