In [1]:
import warnings

import torch
from tqdm import tqdm

from config import get_config, get_weights_file_path
from dataset import get_ds
from model import get_tokenizer, get_model

# Suppress every warning category for the remainder of the kernel session.
warnings.filterwarnings('ignore')

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [2]:
def test(config, model, mask_padding=False):
    """Evaluate ``model`` on the 'test' split described by ``config``.

    The loader is built with ``get_ds(config, tokenizer, 'test')``. Each batch
    must provide ``'input_tokens'`` and ``'mask'`` tensors; the model is called
    with the input tokens as labels and its ``.loss`` is averaged over batches.

    NOTE: ``tokenizer`` and ``device`` are read from the notebook's global
    namespace, so both must be defined before this function is called.

    Args:
        config: Configuration object forwarded to ``get_ds``.
        model: Model called as ``model(tokens, attention_mask=..., labels=...)``
            whose output exposes a scalar ``.loss``. It is switched to eval mode.
        mask_padding: When True, label positions where ``mask == 0`` are
            replaced with ``-100`` so they can be excluded from the loss.
            Defaults to False, which reproduces the original behaviour
            (every position, padding included, is used as a label).
            NOTE(review): enabling this assumes ``mask`` has the same shape as
            ``input_tokens`` with 0 at padding positions, and that the model's
            loss ignores label ``-100`` (the Hugging Face convention) —
            confirm against ``dataset.py`` / ``model.py`` before turning it on.

    Returns:
        A 0-dim ``torch.Tensor`` holding ``exp(mean batch loss)``.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    test_dataloader = get_ds(config, tokenizer, 'test')
    batch_iterator = tqdm(test_dataloader, desc='Test epoch')
    model.eval()

    epoch_loss = 0.0
    # Count batches ourselves instead of relying on len(), so loaders without
    # a __len__ still work and an empty loader is detected explicitly.
    num_batches = 0

    for batch in batch_iterator:
        input_tokens = batch['input_tokens'].to(device)
        mask = batch['mask'].to(device)

        labels = input_tokens
        if mask_padding:
            # masked_fill returns a new tensor; input_tokens is left untouched.
            labels = input_tokens.masked_fill(mask == 0, -100)

        with torch.no_grad():
            output = model(input_tokens, attention_mask=mask, labels=labels)

        # Single .item() call: one host/device sync per batch instead of two.
        batch_loss = output.loss.item()
        epoch_loss += batch_loss
        num_batches += 1
        batch_iterator.set_postfix({'Test_loss': f'{batch_loss:6.3f}'})

    if num_batches == 0:
        raise ValueError('test dataloader produced no batches; cannot compute perplexity')

    # Mean of the per-batch losses (each batch weighted equally).
    mean_loss = epoch_loss / num_batches
    perplexity = torch.exp(torch.tensor(mean_loss))

    batch_iterator.write(f'TEST LOSS: {mean_loss}')
    batch_iterator.write(f'PERPLEXITY: {perplexity}')
    return perplexity
    

In [3]:
# Checkpoint file name for each evaluated variant, looked up inside
# config["model_folder"] by the evaluation loop. The 'baseline' entry has no
# checkpoint (None): the loop skips weight loading for it.
models = dict(
    baseline=None,
    full='gpt_no_lora_07.pt',
    lora_1='gpt_lora_no_proj_07.pt',
    lora_2='gpt_lora_no_ff_07.pt',
    lora_3='gpt_lora_no_output_06.pt',
    lora_4='gpt_lora_04.pt',
)

In [4]:
# One perplexity list per test corpus; entries follow the iteration order of `models`.
ppl_wikitext, ppl_lambada, ppl_imdb = [], [], []

# Value written to config["target_modules"] for each LoRA variant.
lora_targets = {
    'lora_1': ['attention'],
    'lora_2': ['attention', 'projection'],
    'lora_3': ['attention', 'projection', 'feed_forward'],
    'lora_4': ['attention', 'projection', 'feed_forward', 'output'],
}

# (result list, config entries set right before the run), executed in order.
# The first run uses whatever dataset get_config() returned.
# NOTE(review): overrides accumulate on the same config, so 'subset' is still
# "en" during the IMDB run — confirm get_ds handles that as intended.
eval_plan = (
    (ppl_wikitext, {}),
    (ppl_lambada, {"dataset": "EleutherAI/lambada_openai", 'subset': "en"}),
    (ppl_imdb, {"dataset": "stanfordnlp/imdb"}),
)

for model_name, checkpoint in models.items():
    print(f'Making predictions for {model_name} model')

    if model_name in ('baseline', 'full'):
        config = get_config()
    else:
        # get_config(True) presumably switches on the LoRA setup — confirm in config.py.
        config = get_config(True)
        targets = lora_targets.get(model_name)
        if targets is not None:
            config["target_modules"] = targets

    model = get_model(config).to(device)

    # The baseline keeps the weights get_model() produced; all others load a checkpoint.
    if model_name != 'baseline':
        model_filename = f'{config["model_folder"]}/{checkpoint}'
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state['model_state_dict'])

    # Must stay a notebook-level name: test() reads `tokenizer` from the globals.
    tokenizer = get_tokenizer(config)

    for results, overrides in eval_plan:
        for key, value in overrides.items():
            config[key] = value
        results.append(test(config, model))

    print('=========================================================')
    

Making predictions for baseline model


Test epoch: 100%|████████████████████████████████████████████████████| 207/207 [00:52<00:00,  3.97it/s, Test_loss=6.797]


TEST LOSS: 5.327636755606979
PERPLEXITY: 205.95068359375


Test epoch: 100%|████████████████████████████████████████████████████| 369/369 [01:32<00:00,  3.99it/s, Test_loss=7.213]


TEST LOSS: 7.026108217110155
PERPLEXITY: 1125.641357421875


Test epoch:   0%|▏                                                    | 8/1786 [00:02<07:43,  3.84it/s, Test_loss=5.622]Token indices sequence length is longer than the specified maximum sequence length for this model (1300 > 1024). Running this sequence through the model will result in indexing errors
Test epoch: 100%|██████████████████████████████████████████████████| 1786/1786 [07:54<00:00,  3.76it/s, Test_loss=5.931]


TEST LOSS: 5.714285751740049
PERPLEXITY: 303.1676330566406
Making predictions for full model


Test epoch: 100%|████████████████████████████████████████████████████| 207/207 [00:52<00:00,  3.94it/s, Test_loss=0.859]


TEST LOSS: 0.9056965089830511
PERPLEXITY: 2.473654270172119


Test epoch: 100%|████████████████████████████████████████████████████| 369/369 [01:32<00:00,  3.99it/s, Test_loss=1.247]


TEST LOSS: 1.1742488814563286
PERPLEXITY: 3.2357118129730225


Test epoch:   0%|▏                                                    | 8/1786 [00:02<07:29,  3.95it/s, Test_loss=2.753]Token indices sequence length is longer than the specified maximum sequence length for this model (1300 > 1024). Running this sequence through the model will result in indexing errors
Test epoch: 100%|██████████████████████████████████████████████████| 1786/1786 [07:29<00:00,  3.97it/s, Test_loss=3.315]


TEST LOSS: 2.9820067673621224
PERPLEXITY: 19.727365493774414
Making predictions for lora_1 model


Test epoch: 100%|████████████████████████████████████████████████████| 207/207 [00:53<00:00,  3.86it/s, Test_loss=0.864]


TEST LOSS: 0.9249808668442394
PERPLEXITY: 2.521820068359375


Test epoch: 100%|████████████████████████████████████████████████████| 369/369 [01:34<00:00,  3.90it/s, Test_loss=1.612]


TEST LOSS: 1.731662882052786
PERPLEXITY: 5.650041580200195


Test epoch:   0%|▏                                                    | 8/1786 [00:02<07:50,  3.78it/s, Test_loss=4.641]Token indices sequence length is longer than the specified maximum sequence length for this model (1300 > 1024). Running this sequence through the model will result in indexing errors
Test epoch: 100%|██████████████████████████████████████████████████| 1786/1786 [07:54<00:00,  3.77it/s, Test_loss=4.407]


TEST LOSS: 4.558662382111971
PERPLEXITY: 95.4557113647461
Making predictions for lora_2 model


Test epoch: 100%|████████████████████████████████████████████████████| 207/207 [00:56<00:00,  3.66it/s, Test_loss=0.853]


TEST LOSS: 0.9058684850300568
PERPLEXITY: 2.4740796089172363


Test epoch: 100%|████████████████████████████████████████████████████| 369/369 [01:41<00:00,  3.62it/s, Test_loss=1.299]


TEST LOSS: 1.1962203530438225
PERPLEXITY: 3.3075919151306152


Test epoch:   0%|▏                                                    | 8/1786 [00:02<08:37,  3.43it/s, Test_loss=2.706]Token indices sequence length is longer than the specified maximum sequence length for this model (1300 > 1024). Running this sequence through the model will result in indexing errors
Test epoch: 100%|██████████████████████████████████████████████████| 1786/1786 [08:15<00:00,  3.60it/s, Test_loss=3.222]


TEST LOSS: 2.9284559842720546
PERPLEXITY: 18.69873809814453
Making predictions for lora_3 model


Test epoch: 100%|████████████████████████████████████████████████████| 207/207 [00:59<00:00,  3.47it/s, Test_loss=0.859]


TEST LOSS: 0.9048694104568106
PERPLEXITY: 2.471609115600586


Test epoch: 100%|████████████████████████████████████████████████████| 369/369 [01:43<00:00,  3.57it/s, Test_loss=1.240]


TEST LOSS: 1.1720497753561996
PERPLEXITY: 3.2286036014556885


Test epoch:   0%|▏                                                    | 8/1786 [00:02<08:44,  3.39it/s, Test_loss=2.718]Token indices sequence length is longer than the specified maximum sequence length for this model (1300 > 1024). Running this sequence through the model will result in indexing errors
Test epoch: 100%|██████████████████████████████████████████████████| 1786/1786 [08:29<00:00,  3.50it/s, Test_loss=3.244]


TEST LOSS: 2.9440948541994727
PERPLEXITY: 18.99346351623535
Making predictions for lora_4 model


Test epoch: 100%|████████████████████████████████████████████████████| 207/207 [01:02<00:00,  3.31it/s, Test_loss=0.860]


TEST LOSS: 0.9122637098178196
PERPLEXITY: 2.489952564239502


Test epoch: 100%|████████████████████████████████████████████████████| 369/369 [01:50<00:00,  3.33it/s, Test_loss=1.291]


TEST LOSS: 1.1957918024321559
PERPLEXITY: 3.3061747550964355


Test epoch:   0%|▏                                                    | 8/1786 [00:02<09:08,  3.24it/s, Test_loss=2.782]Token indices sequence length is longer than the specified maximum sequence length for this model (1300 > 1024). Running this sequence through the model will result in indexing errors
Test epoch: 100%|██████████████████████████████████████████████████| 1786/1786 [09:08<00:00,  3.26it/s, Test_loss=3.290]

TEST LOSS: 2.982417516251828
PERPLEXITY: 19.735471725463867





In [46]:
def _rounded_perplexities(perplexities):
    """Map each model name in ``models`` to its perplexity rounded to 3 decimals.

    ``perplexities`` is indexed by position, in the iteration order of
    ``models``; every entry must expose ``.item()``.
    """
    return {
        model_name: round(perplexities[position].item(), ndigits=3)
        for position, model_name in enumerate(models)
    }


wikitext_results = _rounded_perplexities(ppl_wikitext)
lambada_results = _rounded_perplexities(ppl_lambada)
imdb_results = _rounded_perplexities(ppl_imdb)

In [47]:
# Show the rounded perplexities together with the lowest value among them.
wikitext_results, min(wikitext_results.values())

({'baseline': 205.951,
  'full': 2.474,
  'lora_1': 2.522,
  'lora_2': 2.474,
  'lora_3': 2.472,
  'lora_4': 2.49},
 2.472)

In [48]:
# Show the rounded perplexities together with the lowest value among them.
lambada_results, min(lambada_results.values())

({'baseline': 1125.641,
  'full': 3.236,
  'lora_1': 5.65,
  'lora_2': 3.308,
  'lora_3': 3.229,
  'lora_4': 3.306},
 3.229)

In [49]:
# Show the rounded perplexities together with the lowest value among them.
imdb_results, min(imdb_results.values())

({'baseline': 303.168,
  'full': 19.727,
  'lora_1': 95.456,
  'lora_2': 18.699,
  'lora_3': 18.993,
  'lora_4': 19.735},
 18.699)