In [17]:
from google.colab import drive
drive.mount('/content/drive')

import os
path = '/content/drive/My Drive/NLPProject'
os.chdir(path)

import sys
sys.path.append('/content/drive/My Drive/NLPProject')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install datasets
!pip install transformers

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyarrow-hotfix, dill, multiprocess, datasets
Successfully installed datasets-2.15.0 dill-0.3.7 multiprocess-0.70.15 pyarrow-hotfix-0.6


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset
import collections
from transformers import ElectraTokenizer, ElectraForTokenClassification
from torch.utils.data.sampler import SubsetRandomSampler
from datasets import load_dataset
from data_loader import EventSentenceLoader
from tqdm import tqdm
import re
import numpy as np
import argparse

tokenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

In [4]:
use_freeze_model = False
validation_run = True
learning_rate = 0.00001

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = ElectraTokenizer.from_pretrained("google/electra-base-discriminator")

def tokenize(batched_text):
    return tokenizer(batched_text['sentence'], padding=True, truncation=True, max_length=256)


tokenizer_config.json:   0%|          | 0.00/27.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

In [6]:
class TextClassificationDataset(Dataset):
    def __init__(self, data):
            self.data_points = data

    def __len__(self):
        return len(self.data_points)

    def __getitem__(self, idx):
        data = self.data_points[idx]
        return {'tokens': data['tokens'].flatten(), 'attention': data['attention'].flatten(), 'labels': torch.tensor(data['labels'])}

In [7]:
def load_electra_data():
    filepath = "events.txt"
    tokenizer_name = "google/electra-base-discriminator"  # Use Electra tokenizer
    loader = EventSentenceLoader(filepath, tokenizer_name)
    train_data = TextClassificationDataset(loader.load_data())

    train_size = int(0.9 * len(train_data))
    test_size = len(train_data) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(train_data, [train_size, test_size])
    train_data = list(map(lambda d: (torch.tensor(d['tokens']), d['labels'], d['attention']), train_dataset))
    test_data = list(map(lambda d: (torch.tensor(d['tokens']), d['labels'], d['attention']), test_dataset))

    return train_data, test_data

In [8]:
def test(model, data):
    confusion_matrix_size = 2
    confusion_matrix = []

    total = 0
    correct = 0
    total_real_words = 0

    for i in range(confusion_matrix_size):
        row = []
        for j in range(confusion_matrix_size):
            row.append(0)
        confusion_matrix.append(row)

    with torch.no_grad():
        for i, batch in tqdm(enumerate(data, 0), total=len(data), leave=False):
            inputs, labels, attention = batch
            inputs, labels, attention = inputs.to(device), labels.to(device), attention.to(device)

            outputs = model(inputs).logits.to(device)

            predicted = torch.argmax(outputs, 2).flatten()
            labels = labels.flatten()
            attention = attention.flatten()

            for j in range(len(predicted)):
                if attention[j]:
                    confusion_matrix[labels[j].item()][predicted[j].item()] += 1

            total += len(predicted)
            correct += predicted.eq(labels).sum().item()
    print(f"Total Real Words: {total_real_words}")
    print("Test Accuracy: {:.3f}".format(correct/total))
    print(confusion_matrix)
    return confusion_matrix

In [9]:
def train(model, data, epochs):
    n = len(data)

    # Define metrics to monitor change in performance during execution
    accuracy_history_epoch = []
    accuracy_history_step = []

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    loss_func = nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.9])).to(device)

    for epoch in range(1, epochs + 1):
        correct = 0
        total = 0
        for i, batch in tqdm(enumerate(data, 0), total=len(data), leave=False):
            inputs, labels, attention = batch
            inputs, labels, attention = inputs.to(device), labels.to(device), attention.to(device)

            outputs = model(inputs).logits.to(device)
            predicted = torch.argmax(outputs, 2)

            optimizer.zero_grad()

            loss = loss_func(outputs.flatten(start_dim=0, end_dim=1), F.one_hot(labels).float().flatten(start_dim=0, end_dim=1))
            loss.backward()
            optimizer.step()

            total += len(predicted)
            correct += predicted.eq(labels).sum().item()
            accuracy_history_step.append((i+1, correct/total))

        accuracy_history_epoch.append(correct / total)
        print("Epoch: {:>3d} Accuracy: {:.3f}".format(epoch, accuracy_history_epoch[-1]))

    return accuracy_history_epoch, accuracy_history_step

In [10]:
filepath = "events.txt"
tokenizer_name = "google/electra-base-discriminator"
loader = EventSentenceLoader(filepath, tokenizer_name)
train_data = TextClassificationDataset(loader.load_data())

train_data, test_data = load_electra_data()

  return {'tokens': data['tokens'].flatten(), 'attention': data['attention'].flatten(), 'labels': torch.tensor(data['labels'])}
  test_data = list(map(lambda d: (torch.tensor(d['tokens']), d['labels'], d['attention']), test_dataset))


In [11]:
model = ElectraForTokenClassification.from_pretrained("google/electra-base-discriminator").to(device)

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of ElectraForTokenClassification were not initialized from the model checkpoint at google/electra-base-discriminator and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# file_path = "model_state.pth"
# model.load_state_dict(torch.load(file_path))

In [12]:
indices = list(range(len(train_data)))

train_dataloader = DataLoader(train_data, batch_size=32, num_workers=0, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, num_workers=0, shuffle=0)

In [13]:
train(model, train_dataloader, 50)

  0%|          | 0/94 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Epoch:   1 Accuracy: 237.897




Epoch:   2 Accuracy: 238.711




Epoch:   3 Accuracy: 239.035




Epoch:   4 Accuracy: 239.246




Epoch:   5 Accuracy: 239.139




Epoch:   6 Accuracy: 239.411




Epoch:   7 Accuracy: 240.217




Epoch:   8 Accuracy: 240.408




Epoch:   9 Accuracy: 240.722




Epoch:  10 Accuracy: 241.416




Epoch:  11 Accuracy: 242.273




Epoch:  12 Accuracy: 243.140




Epoch:  13 Accuracy: 244.136




Epoch:  14 Accuracy: 244.784




Epoch:  15 Accuracy: 245.361




Epoch:  16 Accuracy: 246.131




Epoch:  17 Accuracy: 246.457




Epoch:  18 Accuracy: 246.926




Epoch:  19 Accuracy: 247.408




Epoch:  20 Accuracy: 247.636




Epoch:  21 Accuracy: 248.122




Epoch:  22 Accuracy: 248.301




Epoch:  23 Accuracy: 248.548




Epoch:  24 Accuracy: 248.749




Epoch:  25 Accuracy: 248.992




Epoch:  26 Accuracy: 249.348




Epoch:  27 Accuracy: 249.449




Epoch:  28 Accuracy: 249.686




Epoch:  29 Accuracy: 250.018




Epoch:  30 Accuracy: 250.244




Epoch:  31 Accuracy: 250.328




Epoch:  32 Accuracy: 250.498




Epoch:  33 Accuracy: 250.650




Epoch:  34 Accuracy: 250.962




Epoch:  35 Accuracy: 251.088




Epoch:  36 Accuracy: 251.261




Epoch:  37 Accuracy: 251.459




Epoch:  38 Accuracy: 251.636




Epoch:  39 Accuracy: 251.809




Epoch:  40 Accuracy: 251.911




Epoch:  41 Accuracy: 251.844




Epoch:  42 Accuracy: 252.120




Epoch:  43 Accuracy: 252.200




Epoch:  44 Accuracy: 252.379




Epoch:  45 Accuracy: 252.423




Epoch:  46 Accuracy: 252.607




Epoch:  47 Accuracy: 252.694




Epoch:  48 Accuracy: 252.687




Epoch:  49 Accuracy: 252.790


                                               

Epoch:  50 Accuracy: 253.002




([237.89687292082502,
  238.71057884231536,
  239.03526280771788,
  239.24584165003327,
  239.13872255489022,
  239.4105123087159,
  240.21723220226215,
  240.40785096473718,
  240.72155688622755,
  241.41550232867598,
  242.27345309381238,
  243.13972055888223,
  244.13639387890885,
  244.78443113772454,
  245.36094477711245,
  246.13140385894877,
  246.4570858283433,
  246.92614770459082,
  247.40818363273453,
  247.63572854291417,
  248.12208915502327,
  248.30073186959416,
  248.54757152361944,
  248.74850299401197,
  248.99234863606122,
  249.34830339321357,
  249.44876912840985,
  249.68629407850966,
  250.01763140385896,
  250.24417831004658,
  250.3280106453759,
  250.49767132401863,
  250.65003326679974,
  250.96240851630074,
  251.0878243512974,
  251.26114437791085,
  251.45908183632736,
  251.63639387890885,
  251.80871590153026,
  251.9111776447106,
  251.84397870924818,
  252.11976047904193,
  252.19993346640052,
  252.37924151696606,
  252.42348636061212,
  252.607119095

In [14]:
test(model, test_dataloader)

                                               

Total Real Words: 0
Test Accuracy: 0.958
[[844, 821], [361, 2449]]




[[844, 821], [361, 2449]]

In [21]:
torch.save(model.state_dict(), '/content/drive/My Drive/NLPProject/model_state_electra.pth')