In [None]:
import torch

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!git clone https://github.com/anayap0/strategyqa_v2.git

Cloning into 'strategyqa_v2'...
remote: Enumerating objects: 70, done.[K
remote: Counting objects: 100% (70/70), done.[K
remote: Compressing objects: 100% (50/50), done.[K
remote: Total 70 (delta 33), reused 52 (delta 18), pack-reused 0[K
Receiving objects: 100% (70/70), 25.82 MiB | 15.84 MiB/s, done.
Resolving deltas: 100% (33/33), done.


In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device:', DEVICE)

Device: cuda


In [None]:
from strategyqa_v2.src.SQP1Dataset import initialize_datasets, SQP1Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5Tokenizer, T5Model, T5ForConditionalGeneration
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.optim import Optimizer, AdamW
from tqdm.notebook import tqdm

# Load tokenizer
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained('t5-base').to(DEVICE)

datasets = initialize_datasets('strategyqa_v2/data/train.json', 'strategyqa_v2/data/dev.json', tokenizer)
print(datasets['train'][0])
train_dataloader = DataLoader(datasets['train'],
                                   batch_size=32,
                                   shuffle=False,
                                   collate_fn=SQP1Dataset.collate_fn)

validation_dataloader = DataLoader(datasets['dev'],
                                   batch_size=32,
                                   shuffle=False,
                                   collate_fn=SQP1Dataset.collate_fn)

# print(validation_dataloader)
# for data in validation_dataloader:
#   print(data)
batch = next(iter(validation_dataloader))

print(f"{len(datasets['dev'])}")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


SQP1Example(question='Are more people today related to Genghis Khan than Julius Caesar?', decompositions=['How many kids did Julius Caesar have?', 'How many kids did Genghis Khan have?', 'Is #2 greater than #1?'])
229


In [None]:
#### THIS WORKS WITHOUT ERRORS
input_question = "Are more people today related to Genghis Khan than Julius Caesar?"
decompositions = [
            "How many kids did Julius Caesar have?",
            "How many kids did Genghis Khan have?",
            "Is #2 greater than #1?"
        ]

inputs = tokenizer(input_question, return_tensors="pt", padding=True, truncation=True).input_ids.to(DEVICE)
outputs = tokenizer("<SEP>".join(decompositions), return_tensors="pt", padding=True, truncation=True).input_ids.to(DEVICE)
print(inputs)
print(outputs)

model.train()
# Fine-tuning
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
for epoch in range(5):
    optimizer.zero_grad()
    oputs = model(input_ids=inputs, labels=outputs)
    loss = oputs.loss
    loss.backward()
    optimizer.step()

# evaluate model after training on one example 5 times
model.eval()
predictions = model.generate(input_ids=inputs, max_length=512)
print(tokenizer.decode(predictions[0], skip_special_tokens=True))


tensor([[ 1521,    72,   151,   469,  1341,    12,  5945,  5649,     7, 14420,
           145,  9983,   302, 26218,    58,     1]], device='cuda:0')
tensor([[  571,   186,  1082,   410,  9983,   302, 26218,    43,    58,     2,
           134,  8569,  3155,  7825,   186,  1082,   410,  5945,  5649,     7,
         14420,    43,    58,     2,   134,  8569,  3155,   196,     7, 15493,
          2123,   145,  7172,    58,     1]], device='cuda:0')
Are more people related to Genghis Khan than Julius Caesar?


In [None]:
def train_one_epoch(model: nn.Module, train_dataloader: DataLoader, optimizer: Optimizer, epoch: int):

    model.train()
    with tqdm(train_dataloader, desc=f"Train Ep {epoch}", total=len(train_dataloader)) as tq:
        for batch in tq:
            inputs = batch['input_ids'].input_ids.to(model.device)
            outputs = batch['target_ids'].input_ids.to(model.device)

            loss = model(input_ids=inputs, labels=outputs).loss
            print(f"loss at epoch {epoch}: {loss}")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


In [None]:
def evaluate(model: nn.Module, dataloader: DataLoader):
  model.eval()
  all_predictions = []
  with torch.no_grad():
    with tqdm(dataloader, desc=f"Train Ep {epoch}", total=len(train_dataloader)) as tq:
      for batch in tq:
        inputs = batch['input_ids'].input_ids.to(DEVICE)
        predictions = model.generate(input_ids=inputs, max_length=512)
        all_predictions += predictions
        # text predictions

        # print(tokenizer.decode(pred, skip_special_tokens=True) for pred in predictions)

  return all_predictions


In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
for i in range(1, 11):
  train_one_epoch(model, train_dataloader, optimizer, i)

Train Ep 1:   0%|          | 0/65 [00:00<?, ?it/s]

loss at epoch 1: 9.07058334350586
loss at epoch 1: 7.698611259460449
loss at epoch 1: 5.025505542755127
loss at epoch 1: 4.127199649810791
loss at epoch 1: 3.6759033203125
loss at epoch 1: 4.319543838500977
loss at epoch 1: 3.559028387069702
loss at epoch 1: 2.4993107318878174
loss at epoch 1: 2.776299476623535
loss at epoch 1: 2.445887565612793
loss at epoch 1: 2.60064959526062
loss at epoch 1: 2.539801836013794
loss at epoch 1: 2.5269269943237305
loss at epoch 1: 2.420578718185425
loss at epoch 1: 2.1620712280273438
loss at epoch 1: 2.3524487018585205
loss at epoch 1: 2.583688974380493
loss at epoch 1: 2.518007755279541
loss at epoch 1: 2.2333505153656006
loss at epoch 1: 1.8560748100280762
loss at epoch 1: 1.88676917552948
loss at epoch 1: 2.6253092288970947
loss at epoch 1: 2.1314871311187744
loss at epoch 1: 2.212197780609131
loss at epoch 1: 2.075395107269287
loss at epoch 1: 1.967165231704712
loss at epoch 1: 2.0579047203063965
loss at epoch 1: 1.681607723236084
loss at epoch 1:

Train Ep 2:   0%|          | 0/65 [00:00<?, ?it/s]

loss at epoch 2: 1.1811189651489258
loss at epoch 2: 0.9946280717849731
loss at epoch 2: 1.0160892009735107
loss at epoch 2: 1.3274177312850952
loss at epoch 2: 1.2187143564224243
loss at epoch 2: 1.2638212442398071
loss at epoch 2: 1.2163491249084473
loss at epoch 2: 0.887388288974762
loss at epoch 2: 1.1568522453308105
loss at epoch 2: 1.0500355958938599
loss at epoch 2: 1.3594316244125366
loss at epoch 2: 1.0485914945602417
loss at epoch 2: 1.123901128768921
loss at epoch 2: 1.2122315168380737
loss at epoch 2: 0.9155861735343933
loss at epoch 2: 1.2840981483459473
loss at epoch 2: 1.3373384475708008
loss at epoch 2: 1.3405290842056274
loss at epoch 2: 1.1055160760879517
loss at epoch 2: 1.0900847911834717
loss at epoch 2: 0.9669348001480103
loss at epoch 2: 1.6096649169921875
loss at epoch 2: 1.270062804222107
loss at epoch 2: 1.3169846534729004
loss at epoch 2: 1.2320020198822021
loss at epoch 2: 1.099848747253418
loss at epoch 2: 1.2948907613754272
loss at epoch 2: 1.0319170951843

Train Ep 3:   0%|          | 0/65 [00:00<?, ?it/s]

loss at epoch 3: 0.9167695045471191
loss at epoch 3: 0.806212842464447
loss at epoch 3: 0.8240734338760376
loss at epoch 3: 1.0481407642364502
loss at epoch 3: 1.0264559984207153
loss at epoch 3: 0.973621666431427
loss at epoch 3: 1.0026417970657349
loss at epoch 3: 0.7062773704528809
loss at epoch 3: 0.9526234865188599
loss at epoch 3: 0.8777744174003601
loss at epoch 3: 1.177148461341858
loss at epoch 3: 0.8559752702713013
loss at epoch 3: 0.9506554007530212
loss at epoch 3: 0.9975485801696777
loss at epoch 3: 0.7309376001358032
loss at epoch 3: 1.0623164176940918
loss at epoch 3: 1.082418441772461
loss at epoch 3: 1.1511629819869995
loss at epoch 3: 0.9192901253700256
loss at epoch 3: 0.9424111843109131
loss at epoch 3: 0.8138907551765442
loss at epoch 3: 1.3693103790283203
loss at epoch 3: 1.075346827507019
loss at epoch 3: 1.1303597688674927
loss at epoch 3: 1.0957472324371338
loss at epoch 3: 0.9375500082969666
loss at epoch 3: 1.1421232223510742
loss at epoch 3: 0.90096759796142

Train Ep 4:   0%|          | 0/65 [00:00<?, ?it/s]

loss at epoch 4: 0.8602725863456726
loss at epoch 4: 0.749823808670044
loss at epoch 4: 0.7522008419036865
loss at epoch 4: 0.9753217101097107
loss at epoch 4: 0.943522036075592
loss at epoch 4: 0.8815158605575562
loss at epoch 4: 0.9291713237762451
loss at epoch 4: 0.6572243571281433
loss at epoch 4: 0.889306902885437
loss at epoch 4: 0.7967936992645264
loss at epoch 4: 1.0900992155075073
loss at epoch 4: 0.8083282709121704
loss at epoch 4: 0.8629299402236938
loss at epoch 4: 0.9144310355186462
loss at epoch 4: 0.687751054763794
loss at epoch 4: 0.9913691282272339
loss at epoch 4: 1.0116337537765503
loss at epoch 4: 1.081495761871338
loss at epoch 4: 0.8804141283035278
loss at epoch 4: 0.8564864993095398
loss at epoch 4: 0.7062185406684875
loss at epoch 4: 1.3309599161148071
loss at epoch 4: 1.0019235610961914
loss at epoch 4: 1.063246250152588
loss at epoch 4: 1.0229448080062866
loss at epoch 4: 0.8946078419685364
loss at epoch 4: 1.0358504056930542
loss at epoch 4: 0.841669738292694

Train Ep 5:   0%|          | 0/65 [00:00<?, ?it/s]

loss at epoch 5: 0.802742063999176
loss at epoch 5: 0.7145888805389404
loss at epoch 5: 0.7059168815612793
loss at epoch 5: 0.9125088453292847
loss at epoch 5: 0.9176064729690552
loss at epoch 5: 0.8513725996017456
loss at epoch 5: 0.8916241526603699
loss at epoch 5: 0.6324765086174011
loss at epoch 5: 0.8436988592147827
loss at epoch 5: 0.7784612774848938
loss at epoch 5: 1.0209935903549194
loss at epoch 5: 0.7619491219520569
loss at epoch 5: 0.8413702249526978
loss at epoch 5: 0.8747541904449463
loss at epoch 5: 0.6238012313842773
loss at epoch 5: 0.9572700262069702
loss at epoch 5: 0.9876272678375244
loss at epoch 5: 1.0133616924285889
loss at epoch 5: 0.8040834069252014
loss at epoch 5: 0.8521555662155151
loss at epoch 5: 0.6714790463447571
loss at epoch 5: 1.2681745290756226
loss at epoch 5: 0.9690830111503601
loss at epoch 5: 0.9688498377799988
loss at epoch 5: 0.9614940285682678
loss at epoch 5: 0.838154673576355
loss at epoch 5: 1.0129765272140503
loss at epoch 5: 0.80402809381

Train Ep 6:   0%|          | 0/65 [00:00<?, ?it/s]

loss at epoch 6: 0.761059582233429
loss at epoch 6: 0.6949347853660583
loss at epoch 6: 0.6819258332252502
loss at epoch 6: 0.8937819004058838
loss at epoch 6: 0.8796984553337097
loss at epoch 6: 0.8111498951911926
loss at epoch 6: 0.8559000492095947
loss at epoch 6: 0.6157733798027039
loss at epoch 6: 0.8299384713172913
loss at epoch 6: 0.7354724407196045
loss at epoch 6: 0.9972361326217651
loss at epoch 6: 0.7368214726448059
loss at epoch 6: 0.8115547895431519
loss at epoch 6: 0.8254120349884033
loss at epoch 6: 0.6192417144775391
loss at epoch 6: 0.9296342730522156
loss at epoch 6: 0.9572039246559143
loss at epoch 6: 0.9985159635543823
loss at epoch 6: 0.7914084792137146
loss at epoch 6: 0.8160430788993835
loss at epoch 6: 0.6202391982078552
loss at epoch 6: 1.1625112295150757
loss at epoch 6: 0.9413320422172546
loss at epoch 6: 0.9677362442016602
loss at epoch 6: 0.9194352030754089
loss at epoch 6: 0.7939737439155579
loss at epoch 6: 0.9614111185073853
loss at epoch 6: 0.7763103842

Train Ep 7:   0%|          | 0/65 [00:00<?, ?it/s]

loss at epoch 7: 0.7436460256576538
loss at epoch 7: 0.6755452156066895
loss at epoch 7: 0.6676321625709534
loss at epoch 7: 0.8546637892723083
loss at epoch 7: 0.8555888533592224
loss at epoch 7: 0.7876704931259155
loss at epoch 7: 0.83155357837677
loss at epoch 7: 0.5959253907203674
loss at epoch 7: 0.7860724329948425
loss at epoch 7: 0.7451503872871399
loss at epoch 7: 0.9559929370880127
loss at epoch 7: 0.7128891348838806
loss at epoch 7: 0.7903802990913391
loss at epoch 7: 0.8119418025016785
loss at epoch 7: 0.577892541885376
loss at epoch 7: 0.8695035576820374
loss at epoch 7: 0.9271377921104431
loss at epoch 7: 0.973103404045105
loss at epoch 7: 0.7565430998802185
loss at epoch 7: 0.8007345795631409
loss at epoch 7: 0.6169314384460449
loss at epoch 7: 1.1796824932098389
loss at epoch 7: 0.9060149788856506
loss at epoch 7: 0.9451143741607666
loss at epoch 7: 0.8913444876670837
loss at epoch 7: 0.7756515741348267
loss at epoch 7: 0.9593495726585388
loss at epoch 7: 0.7499563097953

Train Ep 8:   0%|          | 0/65 [00:00<?, ?it/s]

loss at epoch 8: 0.7336783409118652
loss at epoch 8: 0.662996232509613
loss at epoch 8: 0.6432072520256042
loss at epoch 8: 0.839701235294342
loss at epoch 8: 0.8355404734611511
loss at epoch 8: 0.7619725465774536
loss at epoch 8: 0.8161048293113708
loss at epoch 8: 0.5889711380004883
loss at epoch 8: 0.7636610865592957
loss at epoch 8: 0.7205092906951904
loss at epoch 8: 0.9526306986808777
loss at epoch 8: 0.68963623046875
loss at epoch 8: 0.7595347762107849
loss at epoch 8: 0.7771540284156799
loss at epoch 8: 0.584252119064331
loss at epoch 8: 0.8465445041656494
loss at epoch 8: 0.8928321003913879
loss at epoch 8: 0.9335652589797974
loss at epoch 8: 0.7350932955741882
loss at epoch 8: 0.7676220536231995
loss at epoch 8: 0.5728126764297485
loss at epoch 8: 1.1413342952728271
loss at epoch 8: 0.8691628575325012
loss at epoch 8: 0.9313428401947021
loss at epoch 8: 0.841550350189209
loss at epoch 8: 0.7437366247177124
loss at epoch 8: 0.9212163090705872
loss at epoch 8: 0.733605086803436

Train Ep 9:   0%|          | 0/65 [00:00<?, ?it/s]

loss at epoch 9: 0.7054017782211304
loss at epoch 9: 0.6508780717849731
loss at epoch 9: 0.6160632967948914
loss at epoch 9: 0.811367928981781
loss at epoch 9: 0.8215184211730957
loss at epoch 9: 0.7300683259963989
loss at epoch 9: 0.7657039165496826
loss at epoch 9: 0.5751426815986633
loss at epoch 9: 0.7398453950881958
loss at epoch 9: 0.6963361501693726
loss at epoch 9: 0.9514893889427185
loss at epoch 9: 0.6708401441574097
loss at epoch 9: 0.7256298661231995
loss at epoch 9: 0.7666141986846924
loss at epoch 9: 0.5679004192352295
loss at epoch 9: 0.841062605381012
loss at epoch 9: 0.865411102771759
loss at epoch 9: 0.8882833123207092
loss at epoch 9: 0.6944157481193542
loss at epoch 9: 0.7517615556716919
loss at epoch 9: 0.556676983833313
loss at epoch 9: 1.0792620182037354
loss at epoch 9: 0.8405061364173889
loss at epoch 9: 0.8876953721046448
loss at epoch 9: 0.8203960657119751
loss at epoch 9: 0.7030701041221619
loss at epoch 9: 0.894467830657959
loss at epoch 9: 0.71985369920730

Train Ep 10:   0%|          | 0/65 [00:00<?, ?it/s]

loss at epoch 10: 0.6874386072158813
loss at epoch 10: 0.6322190761566162
loss at epoch 10: 0.6074788570404053
loss at epoch 10: 0.7930088043212891
loss at epoch 10: 0.8006817698478699
loss at epoch 10: 0.7167069315910339
loss at epoch 10: 0.762771487236023
loss at epoch 10: 0.5514660477638245
loss at epoch 10: 0.733320415019989
loss at epoch 10: 0.6972560286521912
loss at epoch 10: 0.8938800692558289
loss at epoch 10: 0.6533885598182678
loss at epoch 10: 0.6953824758529663
loss at epoch 10: 0.735884428024292
loss at epoch 10: 0.54905104637146
loss at epoch 10: 0.7815907001495361
loss at epoch 10: 0.85650235414505
loss at epoch 10: 0.8765488266944885
loss at epoch 10: 0.6889231204986572
loss at epoch 10: 0.7269969582557678
loss at epoch 10: 0.5359954237937927
loss at epoch 10: 1.0764018297195435
loss at epoch 10: 0.8114355206489563
loss at epoch 10: 0.8387885689735413
loss at epoch 10: 0.7813063859939575
loss at epoch 10: 0.7053109407424927
loss at epoch 10: 0.8679310083389282
loss at 

In [None]:
preds = evaluate(model, validation_dataloader)

for pred in preds:
  print(tokenizer.decode(pred, skip_special_tokens=True))

Train Ep 4:   0%|          | 0/65 [00:00<?, ?it/s]

What is the population of the Albany in Georgia?SEP>How many people are in New York?SEP>Is #1 greater than or equal to #2?
What language is used in Saint Vincent and the Grenadines?SEP>What language is English spoken in?SEP>Is #1 the same as #2?
What is greed?SEP>What are the Seven Deadly Sins?SEP>Is #1 the most prevalent?
What is Mount Fuji's topography?SEP>What is the topography of the Sea of Japan?SEP>Is #2 greater than or equal to #1?
What is Lil Jon's top ranked song on Billboard?SEP>What is The Lox's biggest hit?SEP>Is #1 the same as #2?
What is the population of Miami?SEP>What is the population of the American West Coast?SEP>Is #2 greater than or equal to #1?
What is the number of members of the Virginia General Assembly?SEP>What is the number of members of the Swiss Guard?SEP>Is #1 greater than or equal to #2?
What was the role of Switzerland in WWII?SEP>What was the role of Portugal in the Portuguese Colonial War?SEP>Is #2 the same as #2?
What language is Old English?SEP>What 

In [22]:
!git add -A

NotImplementedError: A UTF-8 locale is required. Got ANSI_X3.4-1968