In [None]:
import torch
import numpy as np
from tqdm.notebook import tqdm

from datasets import load_dataset
from transformers import BertForSequenceClassification, Trainer, TrainingArguments

from bert_utils import TextSentiment, compute_centers_cov, mahalanobis_score, metrics_eval

# Seed torch's global RNG before anything random happens (randomly initialised
# weights such as the sequence-classification head, shuffled DataLoaders) so a
# Restart & Run All reproduces the same model and evaluation subsets.
SEED = 42
torch.manual_seed(SEED)

imdb = load_dataset('imdb')
yelp = load_dataset('yelp_polarity')
tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased')

trainset = TextSentiment(imdb['train'], tokenizer, finetune=True)
testset = TextSentiment(imdb['test'], tokenizer, finetune=True)

model = torch.hub.load('huggingface/pytorch-transformers',
                       'modelForSequenceClassification',
                       'bert-base-cased',
                       output_attentions=False).cuda()
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=1,
    shuffle=True,
    num_workers=2,
    # explicit generator: the shuffled order no longer depends on how much
    # global RNG state earlier cells have consumed
    generator=torch.Generator().manual_seed(SEED),
)

In [2]:
# Fine-tune BERT on the IMDB training split with the HuggingFace Trainer.
# The argument choices follow the example in the HuggingFace library docs.
# Per the logged output, the learning rate warms up linearly over
# `warmup_steps` steps and then decays.

training_args = TrainingArguments(
    output_dir='./models/bert/results',   # where the Trainer writes its results
    logging_dir='./models/bert/logs',     # log directory (TensorBoard, if installed)
    logging_steps=10,                     # emit a loss/lr log line every 10 steps
    num_train_epochs=10,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=trainset,   # IMDB train split
    eval_dataset=testset,     # IMDB test split, registered as the evaluation set
)

trainer.train()

You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it.


HBox(children=(FloatProgress(value=0.0, description='Epoch', max=10.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1563.0, style=ProgressStyle(description_w…

{'loss': 0.7313716888427735, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.006397952655150352, 'total_flos': 53237420851200, 'step': 10}
{'loss': 0.7247463226318359, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.012795905310300703, 'total_flos': 106474841702400, 'step': 20}
{'loss': 0.7043634414672851, 'learning_rate': 3e-06, 'epoch': 0.019193857965451054, 'total_flos': 159712262553600, 'step': 30}
{'loss': 0.7048545837402344, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.025591810620601407, 'total_flos': 212949683404800, 'step': 40}
{'loss': 0.710145378112793, 'learning_rate': 5e-06, 'epoch': 0.03198976327575176, 'total_flos': 266187104256000, 'step': 50}
{'loss': 0.6929107666015625, 'learning_rate': 6e-06, 'epoch': 0.03838771593090211, 'total_flos': 319424525107200, 'step': 60}
{'loss': 0.6815090179443359, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.044785668586052464, 'total_flos': 372661945958400, 'step': 70}
{'loss': 0.6785926818847656, 'learning_rate': 8.0

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1563.0, style=ProgressStyle(description_w…

{'loss': 0.1586822509765625, 'learning_rate': 4.6463978849966955e-05, 'epoch': 1.0044785668586051, 'total_flos': 8355613202595840, 'step': 1570}
{'loss': 0.119842529296875, 'learning_rate': 4.6430931923331135e-05, 'epoch': 1.0108765195137557, 'total_flos': 8408850623447040, 'step': 1580}
{'loss': 0.119952392578125, 'learning_rate': 4.6397884996695315e-05, 'epoch': 1.017274472168906, 'total_flos': 8462088044298240, 'step': 1590}
{'loss': 0.1265838623046875, 'learning_rate': 4.636483807005949e-05, 'epoch': 1.0236724248240563, 'total_flos': 8515325465149440, 'step': 1600}
{'loss': 0.237249755859375, 'learning_rate': 4.633179114342366e-05, 'epoch': 1.0300703774792066, 'total_flos': 8568562886000640, 'step': 1610}
{'loss': 0.439129638671875, 'learning_rate': 4.629874421678784e-05, 'epoch': 1.036468330134357, 'total_flos': 8621800306851840, 'step': 1620}
{'loss': 0.242950439453125, 'learning_rate': 4.6265697290152015e-05, 'epoch': 1.0428662827895074, 'total_flos': 8675037727703040, 'step': 1

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1563.0, style=ProgressStyle(description_w…

{'loss': 0.179095458984375, 'learning_rate': 4.1308658294778584e-05, 'epoch': 2.00255918106206, 'total_flos': 16657988984340480, 'step': 3130}
{'loss': 0.102685546875, 'learning_rate': 4.1275611368142764e-05, 'epoch': 2.0089571337172103, 'total_flos': 16711226405191680, 'step': 3140}
{'loss': 0.057818603515625, 'learning_rate': 4.1242564441506944e-05, 'epoch': 2.015355086372361, 'total_flos': 16764463826042880, 'step': 3150}
{'loss': 0.010772705078125, 'learning_rate': 4.1209517514871124e-05, 'epoch': 2.0217530390275114, 'total_flos': 16817701246894080, 'step': 3160}
{'loss': 0.179803466796875, 'learning_rate': 4.11764705882353e-05, 'epoch': 2.0281509916826614, 'total_flos': 16870938667745280, 'step': 3170}
{'loss': 0.086737060546875, 'learning_rate': 4.114342366159947e-05, 'epoch': 2.034548944337812, 'total_flos': 16924176088596480, 'step': 3180}
{'loss': 0.24334716796875, 'learning_rate': 4.111037673496365e-05, 'epoch': 2.040946896992962, 'total_flos': 16977413509447680, 'step': 3190

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1563.0, style=ProgressStyle(description_w…

{'loss': 0.229656982421875, 'learning_rate': 3.615333773959022e-05, 'epoch': 3.000639795265515, 'total_flos': 24960364766085120, 'step': 4690}
{'loss': 0.101324462890625, 'learning_rate': 3.61202908129544e-05, 'epoch': 3.0070377479206654, 'total_flos': 25013602186936320, 'step': 4700}
{'loss': 0.043963623046875, 'learning_rate': 3.608724388631857e-05, 'epoch': 3.013435700575816, 'total_flos': 25066839607787520, 'step': 4710}
{'loss': 0.02857666015625, 'learning_rate': 3.605419695968275e-05, 'epoch': 3.019833653230966, 'total_flos': 25120077028638720, 'step': 4720}
{'loss': 0.040313720703125, 'learning_rate': 3.602115003304693e-05, 'epoch': 3.0262316058861165, 'total_flos': 25173314449489920, 'step': 4730}
{'loss': 0.07674560546875, 'learning_rate': 3.5988103106411106e-05, 'epoch': 3.0326295585412666, 'total_flos': 25226551870341120, 'step': 4740}
{'loss': 0.260821533203125, 'learning_rate': 3.595505617977528e-05, 'epoch': 3.039027511196417, 'total_flos': 25279789291192320, 'step': 4750

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1563.0, style=ProgressStyle(description_w…

{'loss': 0.0196044921875, 'learning_rate': 3.096497025776603e-05, 'epoch': 4.00511836212412, 'total_flos': 33315977968680960, 'step': 6260}
{'loss': 0.03768310546875, 'learning_rate': 3.093192333113021e-05, 'epoch': 4.0115163147792705, 'total_flos': 33369215389532160, 'step': 6270}
{'loss': 0.0553955078125, 'learning_rate': 3.089887640449438e-05, 'epoch': 4.017914267434421, 'total_flos': 33422452810383360, 'step': 6280}
{'loss': 0.0177001953125, 'learning_rate': 3.086582947785856e-05, 'epoch': 4.024312220089572, 'total_flos': 33475690231234560, 'step': 6290}
{'loss': 0.0013671875, 'learning_rate': 3.083278255122274e-05, 'epoch': 4.030710172744722, 'total_flos': 33528927652085760, 'step': 6300}
{'loss': 0.001025390625, 'learning_rate': 3.0799735624586915e-05, 'epoch': 4.037108125399872, 'total_flos': 33582165072936960, 'step': 6310}
{'loss': 0.09556884765625, 'learning_rate': 3.076668869795109e-05, 'epoch': 4.043506078055023, 'total_flos': 33635402493788160, 'step': 6320}
{'loss': 0.047

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1563.0, style=ProgressStyle(description_w…

{'loss': 0.07691650390625, 'learning_rate': 2.580964970257766e-05, 'epoch': 5.0031989763275755, 'total_flos': 41618353750425600, 'step': 7820}
{'loss': 0.00125732421875, 'learning_rate': 2.577660277594184e-05, 'epoch': 5.009596928982726, 'total_flos': 41671591171276800, 'step': 7830}
{'loss': 0.00045166015625, 'learning_rate': 2.5743555849306017e-05, 'epoch': 5.015994881637876, 'total_flos': 41724828592128000, 'step': 7840}
{'loss': 0.0013916015625, 'learning_rate': 2.5710508922670197e-05, 'epoch': 5.022392834293027, 'total_flos': 41778066012979200, 'step': 7850}
{'loss': 0.00035400390625, 'learning_rate': 2.567746199603437e-05, 'epoch': 5.028790786948177, 'total_flos': 41831303433830400, 'step': 7860}
{'loss': 0.00096435546875, 'learning_rate': 2.5644415069398547e-05, 'epoch': 5.035188739603327, 'total_flos': 41884540854681600, 'step': 7870}
{'loss': 0.05316162109375, 'learning_rate': 2.5611368142762727e-05, 'epoch': 5.041586692258477, 'total_flos': 41937778275532800, 'step': 7880}
{'

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1563.0, style=ProgressStyle(description_w…

{'loss': 0.02330322265625, 'learning_rate': 2.0654329147389292e-05, 'epoch': 6.00127959053103, 'total_flos': 49920729532170240, 'step': 9380}
{'loss': 0.0651123046875, 'learning_rate': 2.0621282220753472e-05, 'epoch': 6.007677543186181, 'total_flos': 49973966953021440, 'step': 9390}
{'loss': 0.04395751953125, 'learning_rate': 2.058823529411765e-05, 'epoch': 6.014075495841331, 'total_flos': 50027204373872640, 'step': 9400}
{'loss': 0.0038330078125, 'learning_rate': 2.0555188367481825e-05, 'epoch': 6.020473448496481, 'total_flos': 50080441794723840, 'step': 9410}
{'loss': 0.00087890625, 'learning_rate': 2.0522141440846002e-05, 'epoch': 6.026871401151632, 'total_flos': 50133679215575040, 'step': 9420}
{'loss': 0.00146484375, 'learning_rate': 2.048909451421018e-05, 'epoch': 6.033269353806782, 'total_flos': 50186916636426240, 'step': 9430}
{'loss': 0.05294189453125, 'learning_rate': 2.0456047587574356e-05, 'epoch': 6.039667306461932, 'total_flos': 50240154057277440, 'step': 9440}
{'loss': 0

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1563.0, style=ProgressStyle(description_w…

{'loss': 0.04398193359375, 'learning_rate': 1.5465961665565104e-05, 'epoch': 7.005758157389636, 'total_flos': 58276342734766080, 'step': 10950}
{'loss': 0.0007568359375, 'learning_rate': 1.543291473892928e-05, 'epoch': 7.012156110044786, 'total_flos': 58329580155617280, 'step': 10960}
{'loss': 0.01597900390625, 'learning_rate': 1.5399867812293457e-05, 'epoch': 7.018554062699936, 'total_flos': 58382817576468480, 'step': 10970}
{'loss': 0.0195556640625, 'learning_rate': 1.5366820885657634e-05, 'epoch': 7.024952015355086, 'total_flos': 58436054997319680, 'step': 10980}
{'loss': 0.00074462890625, 'learning_rate': 1.533377395902181e-05, 'epoch': 7.031349968010237, 'total_flos': 58489292418170880, 'step': 10990}
{'loss': 0.03416748046875, 'learning_rate': 1.5300727032385988e-05, 'epoch': 7.037747920665387, 'total_flos': 58542529839022080, 'step': 11000}
{'loss': 0.0007080078125, 'learning_rate': 1.5267680105750164e-05, 'epoch': 7.044145873320537, 'total_flos': 58595767259873280, 'step': 1101

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1563.0, style=ProgressStyle(description_w…

{'loss': 0.0463134765625, 'learning_rate': 1.0310641110376736e-05, 'epoch': 8.00383877159309, 'total_flos': 66578718516510720, 'step': 12510}
{'loss': 0.000537109375, 'learning_rate': 1.0277594183740913e-05, 'epoch': 8.01023672424824, 'total_flos': 66631955937361920, 'step': 12520}
{'loss': 0.00054931640625, 'learning_rate': 1.024454725710509e-05, 'epoch': 8.01663467690339, 'total_flos': 66685193358213120, 'step': 12530}
{'loss': 0.00048828125, 'learning_rate': 1.0211500330469268e-05, 'epoch': 8.023032629558541, 'total_flos': 66738430779064320, 'step': 12540}
{'loss': 0.00050048828125, 'learning_rate': 1.0178453403833443e-05, 'epoch': 8.029430582213692, 'total_flos': 66791668199915520, 'step': 12550}
{'loss': 0.00047607421875, 'learning_rate': 1.0145406477197621e-05, 'epoch': 8.035828534868841, 'total_flos': 66844905620766720, 'step': 12560}
{'loss': 0.00047607421875, 'learning_rate': 1.0112359550561798e-05, 'epoch': 8.042226487523992, 'total_flos': 66898143041617920, 'step': 12570}
{'

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=1563.0, style=ProgressStyle(description_w…

{'loss': 0.00037841796875, 'learning_rate': 5.155320555188368e-06, 'epoch': 9.001919385796546, 'total_flos': 74881094298255360, 'step': 14070}
{'loss': 0.00037841796875, 'learning_rate': 5.122273628552545e-06, 'epoch': 9.008317338451695, 'total_flos': 74934331719106560, 'step': 14080}
{'loss': 0.0003662109375, 'learning_rate': 5.089226701916721e-06, 'epoch': 9.014715291106846, 'total_flos': 74987569139957760, 'step': 14090}
{'loss': 0.0003662109375, 'learning_rate': 5.056179775280899e-06, 'epoch': 9.021113243761997, 'total_flos': 75040806560808960, 'step': 14100}
{'loss': 0.0003662109375, 'learning_rate': 5.0231328486450765e-06, 'epoch': 9.027511196417146, 'total_flos': 75094043981660160, 'step': 14110}
{'loss': 0.0003662109375, 'learning_rate': 4.990085922009253e-06, 'epoch': 9.033909149072297, 'total_flos': 75147281402511360, 'step': 14120}
{'loss': 0.00035400390625, 'learning_rate': 4.95703899537343e-06, 'epoch': 9.040307101727446, 'total_flos': 75200518823362560, 'step': 14130}
{'l

TrainOutput(global_step=15630, training_loss=0.08469140656240003)

In [3]:
# Persist the fine-tuned model. torch.save pickles the whole module object, so
# reloading requires torch.load with the same model classes importable.
torch.save(obj=model, f='./models/bert/IMDB_finetuned10')

In [20]:
# Evaluation loader. eval_acc scores only a prefix of roughly 1000 samples, so
# the shuffle picks a random subset; the fixed-seed generator makes that subset
# (and therefore the reported accuracy) reproducible across runs.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=16,
    shuffle=True,
    num_workers=2,
    generator=torch.Generator().manual_seed(42),
)

def eval_acc(model, quantity_target=1000, dataloader=None):
    """Estimate classification accuracy of `model` on (a prefix of) a dataloader.

    Batches are scored until at least `quantity_target` samples have been seen
    (the last batch is not truncated, so slightly more may be used) or the
    loader is exhausted. The model's previous train/eval mode is restored on
    exit, even if an exception is raised.

    Args:
        model: a torch module; the first element of its output is taken as the
            logits, assumed shaped (batch, num_classes).
        quantity_target: minimum number of samples to score; None scores the
            whole loader.
        dataloader: iterable of dict batches holding 'input_ids' and 'labels'
            and, optionally, 'attention_mask'. Defaults to the notebook-level
            `testloader`.

    Returns:
        Accuracy as a fraction in [0, 1]. It is also printed as a percentage.

    Raises:
        ValueError: if the loader yielded no samples.
    """
    if dataloader is None:
        dataloader = testloader

    batch_size = getattr(dataloader, 'batch_size', None)
    if quantity_target is None:
        expected_batches = len(dataloader) if hasattr(dataloader, '__len__') else None
    elif batch_size:
        # ceil division: a partial final batch is still one iteration
        expected_batches = -(-quantity_target // batch_size)
    else:
        expected_batches = None

    # Run wherever the model lives instead of hard-coding CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    total = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, total=expected_batches):
                input_ids = batch['input_ids'].to(device)
                # Flatten so (B,) and (B, 1) labels both compare element-wise;
                # comparing (B, 1) predictions with (B,) labels would broadcast
                # to (B, B) and over-count matches.
                labels = batch['labels'].to(device).view(-1)
                # Trainer feeds every batch field to the model during training;
                # mirror that for the padding mask when the dataset provides one.
                attention_mask = batch.get('attention_mask')
                if attention_mask is not None:
                    logits = model(input_ids, attention_mask=attention_mask.to(device))[0]
                else:
                    logits = model(input_ids)[0]
                predicted = logits.argmax(dim=1)
                total += labels.numel()
                correct += (predicted == labels).sum().item()
                if quantity_target is not None and total >= quantity_target:
                    break
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError('eval_acc: dataloader yielded no samples')
    print(f'Validation accuracy: {100 * correct / total:.2f}')
    return correct / total
eval_acc(model, quantity_target=1000)  # accuracy on ~1000 shuffled test reviews

HBox(children=(FloatProgress(value=0.0, max=62.0), HTML(value='')))


Validation accuracy: 80.75


0.8075396825396826