In [1]:
!pip install datasets
!pip install transformers

Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/share/apps/python/3.8.6/intel/bin/python -m pip install --upgrade pip' command.[0m
Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/share/apps/python/3.8.6/intel/bin/python -m pip install --upgrade pip' command.[0m


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset
import collections
from transformers import RobertaTokenizer, RobertaForMaskedLM
from torch.utils.data.sampler import SubsetRandomSampler
from datasets import load_dataset
from data_loader_masked import EventSentenceLoader
from tqdm import tqdm
import re
import numpy as np
import argparse

In [3]:
use_freeze_model = False
validation_run = True
learning_rate = 0.00001

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')


def tokenize(batched_text):
    return tokenizer(batched_text['sentence'], padding=True, truncation=True, max_length=256)

In [5]:
class TextClassificationDataset(Dataset):
    def __init__(self, data):
            self.data_points = data

    def __len__(self):
        return len(self.data_points)

    def __getitem__(self, idx):
        data = self.data_points[idx]
        mask_id = tokenizer.mask_token_id
        new_tokens = torch.clone(data['tokens'])
        new_tokens = (new_tokens - (torch.ones(new_tokens.shape) * mask_id)) * (torch.ones(data['maskable'].shape) - data['maskable']) + (torch.ones(new_tokens.shape) * mask_id)
        new_tokens = new_tokens.long()
        old_tokens = data['tokens'].long()
        return {'labels': old_tokens.flatten(), 'attention': data['attention'].flatten(), 'tokens': new_tokens.flatten(), 'masked': data['maskable']}

In [6]:
def load_roberta_data():
    filepath = "new_sentences.txt"
    tokenizer_name = "roberta-base"
    loader = EventSentenceLoader(filepath, tokenizer_name)
    train_data = TextClassificationDataset(loader.load_data())

    train_size = int(0.9 * len(train_data))
    test_size = len(train_data) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(train_data, [train_size, test_size])
    train_data = list(map(lambda d: (torch.tensor(d['tokens']), d['labels'], d['attention'], d['masked']), train_dataset))
    test_data = list(map(lambda d: (torch.tensor(d['tokens']), d['labels'], d['attention'], d['masked']), test_dataset))

    return train_data, test_data

In [7]:
model = RobertaForMaskedLM.from_pretrained("roberta-base").to(device)

In [23]:
def test(model, data):
    confusion_matrix_size = 2
    confusion_matrix = []

    total = 0
    correct = 0
    total_real_words = 0

    for i in range(confusion_matrix_size):
        row = []
        for j in range(confusion_matrix_size):
            row.append(0)
        confusion_matrix.append(row)

    with torch.no_grad():
        total_loss = 0
        total_masked = 0
        wrong = 0
        for i, batch in tqdm(enumerate(data, 0), total=len(data), leave=False):
            inputs, labels, attention, mask = batch
            inputs, labels, attention, mask = inputs.to(device), labels.to(device), attention.to(device), mask.long().to(device)
            outputs = model(input_ids = inputs, attention_mask = attention, labels = labels)
            total_masked += mask.sum()
            predicted = outputs.logits.argmax(dim=-1)
            correct = labels.eq(predicted).sum()
            incorrect = labels.ne(predicted).sum()
            wrong += incorrect
            print(correct)
            
            loss = outputs.loss
            total_loss += loss.item()
        print((total_masked - wrong)/total_masked)
        
    return confusion_matrix

In [9]:
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
loss_func = nn.CrossEntropyLoss().to(device)

def train(model, data, epochs):
    n = len(data)

    # Define metrics to monitor change in performance during execution
    accuracy_history_epoch = []
    accuracy_history_step = []

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    loss_func = nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.9])).to(device)

    for epoch in range(1, epochs + 1):
        correct = 0
        total = 0
        total_loss = 0
        for i, batch in tqdm(enumerate(data, 0), total=len(data), leave=False):
            inputs, labels, attention, mask = batch
            inputs, labels, attention, mask = inputs.to(device), labels.to(device), attention.to(device), mask.long().to(device)

            outputs = model(input_ids = inputs, attention_mask = attention, labels = labels)
            loss = outputs.loss
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(total_loss)

    return accuracy_history_epoch, accuracy_history_step

In [10]:
filepath = "new_sentences.txt"
tokenizer_name = "roberta-base"
loader = EventSentenceLoader(filepath, tokenizer_name)
train_data = TextClassificationDataset(loader.load_data())

train_data, test_data = load_roberta_data()

  train_data = list(map(lambda d: (torch.tensor(d['tokens']), d['labels'], d['attention'], d['masked']), train_dataset))
  test_data = list(map(lambda d: (torch.tensor(d['tokens']), d['labels'], d['attention'], d['masked']), test_dataset))


In [11]:
indices = list(range(len(train_data)))

train_dataloader = DataLoader(train_data, batch_size=32, num_workers=0, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, num_workers=0, shuffle=0)

In [12]:
test(model, test_dataloader)

                                             

tensor(906, device='cuda:0')
tensor(37, device='cuda:0')
54.10102653503418




[[0, 0], [0, 0]]

In [15]:
train(model, train_dataloader, 250)

                                             

29.711336612701416


                                             

26.135920763015747


                                             

24.40906524658203


                                             

23.138280153274536


                                             

22.077983140945435


                                             

21.070541620254517


                                             

20.117915630340576


                                             

19.190329551696777


                                             

18.28682279586792


                                             

17.39822781085968


                                             

16.526934385299683


                                             

15.674032807350159


                                             

14.839802742004395


                                             

14.029088020324707


                                             

13.241135001182556


                                             

12.482263445854187


                                             

11.749879717826843


                                             

11.04584264755249


                                             

10.37120771408081


                                             

9.732224941253662


                                             

9.118555128574371


                                             

8.541335344314575


                                             

7.99760764837265


                                             

7.48728084564209


                                             

7.004127204418182


                                             

6.554677605628967


                                             

6.137194991111755


                                             

5.7463348507881165


                                             

5.381930947303772


                                             

5.044079899787903


                                             

4.732141315937042


                                             

4.443610429763794


                                             

4.176486700773239


                                             

3.9305151998996735


                                             

3.7026806473731995


                                             

3.490837872028351


                                             

3.2962223291397095


                                             

3.1154843866825104


                                             

2.9482697546482086


                                             

2.7986108660697937


                                             

2.6550019681453705


                                             

2.5181517601013184


                                             

2.3946366608142853


                                             

2.2802925556898117


                                             

2.173291325569153


                                             

2.0773893296718597


                                             

1.9900087863206863


                                             

1.89817613363266


                                             

1.8118488490581512


                                             

1.7339410334825516


                                             

1.6614978164434433


                                             

1.5942596197128296


                                             

1.5304213464260101


                                             

1.4710438549518585


                                             

1.4150051921606064


                                             

1.3623643964529037


                                             

1.3124489933252335


                                             

1.265441432595253


                                             

1.2211984544992447


                                             

1.1791123449802399


                                             

1.1392382606863976


                                             

1.1013981476426125


                                             

1.0655389577150345


                                             

1.0315297693014145


                                             

0.9991165995597839


                                             

0.968226969242096


                                             

0.9388623833656311


                                             

0.9108464866876602


                                             

0.8841251134872437


                                             

0.8586027398705482


                                             

0.8342795595526695


                                             

0.8109105080366135


                                             

0.7885412946343422


                                             

0.7672731950879097


                                             

0.746734119951725


                                             

0.7270984798669815


                                             

0.7083316668868065


                                             

0.6901979222893715


                                             

0.6727723404765129


                                             

0.6561450883746147


                                             

0.6400412768125534


                                             

0.6246257424354553


                                             

0.609674908220768


                                             

0.5954856425523758


                                             

0.5815936289727688


                                             

0.5682200118899345


                                             

0.5553411468863487


                                             

0.5429240092635155


                                             

0.5309958606958389


                                             

0.5194193944334984


                                             

0.5081970989704132


                                             

0.4975043572485447


                                             

0.4869340881705284


                                             

0.47680703923106194


                                             

0.4669746197760105


                                             

0.4574982337653637


                                             

0.44829118624329567


                                             

0.4394230581820011


                                             

0.4308379366993904


                                             

0.42236923053860664


                                             

0.41430773213505745


                                             

0.406345222145319


                                             

0.39869193732738495


                                             

0.3912259377539158


                                             

0.3840673640370369


                                             

0.37695325911045074


                                             

0.3702802248299122


                                             

0.3635435700416565


                                             

0.3570100776851177


                                             

0.35075176134705544


                                             

0.34466907009482384


                                             

0.33871160447597504


                                             

0.33288028091192245


                                             

0.3272661156952381


                                             

0.321723286062479


                                             

0.3164125420153141


                                             

0.31111928075551987


                                             

0.30606045573949814


                                             

0.3010888881981373


                                             

0.29619552567601204


                                             

0.2915178593248129


                                             

0.2869356442242861


                                             

0.28240259177982807


                                             

0.27797610871493816


                                             

0.2736971750855446


                                             

0.26945071294903755


                                             

0.26538744010031223


                                             

0.2613530643284321


                                             

0.25752423144876957


                                             

0.2536883242428303


                                             

0.24989788606762886


                                             

0.24629191495478153


                                             

0.24274104088544846


                                             

0.2392042651772499


                                             

0.23570933565497398


                                             

0.23237435333430767


                                             

0.22913488745689392


                                             

0.22591346874833107


                                             

0.22272220253944397


                                             

0.21969072706997395


                                             

0.21662299521267414


                                             

0.21369541250169277


                                             

0.21083426102995872


                                             

0.20794836431741714


                                             

0.20512838661670685


                                             

0.20247760973870754


                                             

0.19977758824825287


                                             

0.19716637581586838


                                             

0.1945686861872673


                                             

0.19204970635473728


                                             

0.1895897164940834


                                             

0.18713048100471497


                                             

0.1848881132900715


                                             

0.18245123513042927


                                             

0.180338766425848


                                             

0.17792687192559242


                                             

0.17564851231873035


                                             

0.1736651137471199


                                             

0.17150544188916683


                                             

0.16932220570743084


                                             

0.16721546463668346


                                             

0.16521713510155678


                                             

0.16327484138309956


                                             

0.16136290691792965


                                             

0.159376947209239


                                             

0.1574669498950243


                                             

0.1556359026581049


                                             

0.1537852194160223


                                             

0.15201094187796116


                                             

0.15024039708077908


                                             

0.1485196277499199


                                             

0.14680772088468075


                                             

0.14519287645816803


                                             

0.1435074284672737


                                             

0.14191519562155008


                                             

0.14028755575418472


                                             

0.13876054342836142


                                             

0.13720765803009272


                                             

0.13567781541496515


                                             

0.13419813103973866


                                             

0.13272858876734972


                                             

0.1313318870961666


                                             

0.1298193372786045


                                             

0.12844784185290337


                                             

0.12703385576605797


                                             

0.12571386620402336


                                             

0.12447452545166016


                                             

0.12315678875893354


                                             

0.12177153583616018


                                             

0.1205069050192833


                                             

0.11921308562159538


                                             

0.11801767069846392


                                             

0.1167097082361579


                                             

0.11560201272368431


                                             

0.11431475915014744


                                             

0.11316460464149714


                                             

0.11201034486293793


                                             

0.11090518813580275


                                             

0.10973872430622578


                                             

0.10860410239547491


                                             

0.10752711445093155


                                             

0.10643173288553953


                                             

0.10537065844982862


                                             

0.10434712935239077


                                             

0.10329312831163406


                                             

0.10231399070471525


                                             

0.10128540731966496


                                             

0.10030064266175032


                                             

0.09935665223747492


                                             

0.09838381130248308


                                             

0.0974264433607459


                                             

0.09649827051907778


                                             

0.09560628980398178


                                             

0.09468593448400497


                                             

0.09381653647869825


                                             

0.092932497151196


                                             

0.09209473058581352


                                             

0.09123003017157316


                                             

0.09039056114852428


                                             

0.08956308756023645


                                             

0.08870424143970013


                                             

0.08788255788385868


                                             

0.08711112011224031


                                             

0.0862960834056139


                                             

0.0855494188144803


                                             

0.08481200784444809


                                             

0.08403990045189857


                                             

0.0832316791638732


                                             

0.08253340423107147


                                             

0.08179142605513334


                                             

0.081051766872406


                                             

0.08031316474080086


                                             

0.07966537680476904


                                             

0.07899958454072475


                                             

0.0782327326014638


                                             

0.07757341861724854


                                             

0.07690133061259985


                                             

0.07618743926286697


                                             

0.0756043279543519


                                             

0.07491237483918667


                                             

0.07426047883927822


                                             

0.07361730933189392


                                             

0.07305014599114656


                                             

0.07251165620982647


                                             

0.07179631851613522


                                             

0.07123096566647291


                                             

0.07065234938636422


                                             

0.06999180931597948


                                             

0.06943801511079073


                                             

0.06884032860398293




([], [])

In [24]:
test(model, test_dataloader)

                                             

tensor(8188, device='cuda:0')
tensor(256, device='cuda:0')
tensor(0.9259, device='cuda:0')




[[0, 0], [0, 0]]