In [29]:
import datetime
import pathlib
import time

import numpy as np
import pandas as pd
import torch
import tqdm

from language_detection.data import DataLoader, load_wili_2018_dataset, batch_collate_function, get_mask_from_lengths
from language_detection.model import TrainingConfig, TransformerClassifier, create_datasets, evaluate_model
from language_detection.model import TrainingConfig, TransformerClassifier, create_datasets, evaluate_model

In [37]:
# Path of the training checkpoint to evaluate (the numeric suffix presumably encodes the epoch — TODO confirm).
checkpoint_filepath: str = "./experiments/wili2018/wili2018-checkpoint-000020.pt"

In [38]:
# Load the checkpoint dict and rebuild the training configuration stored in it.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only machine;
# the model is moved to the selected device in a later cell.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
if not pathlib.Path(checkpoint_filepath).is_file():
    raise ValueError(f"checkpoint file '{checkpoint_filepath}' does not exist!")
checkpoint = torch.load(checkpoint_filepath, map_location="cpu")
config = TrainingConfig(**checkpoint["config"])

In [24]:
# Rebuild the classifier with the class count recorded in the checkpoint and restore its weights.
# Print the checkpoint *path*; interpolating the checkpoint dict itself would dump the entire state dict.
print(f"\nloading model from checkpoint '{checkpoint_filepath}'")
model = TransformerClassifier(num_classes=checkpoint["num_classes"])
# load_state_dict is strict by default (raises on missing/unexpected keys); its return value is displayed
model.load_state_dict(checkpoint["model_state_dict"])



loading model with num_classes 235


<All keys matched successfully>

In [25]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU, and move the model there.
use_cuda = torch.cuda.is_available()
device_string = "cuda" if use_cuda else "cpu"
print("CUDA detected, using gpu" if use_cuda else "warning! no CUDA detected, using cpu")
_ = model.to(device_string)

CUDA detected, using gpu


In [39]:
# Load the WiLI-2018 data and build an ordered (unshuffled) DataLoader over the test split.
raw_data = load_wili_2018_dataset(config.data_path)

# Fail fast, before building datasets, if the checkpoint was trained on a different label set.
num_classes = len(raw_data.idx2lang)
if checkpoint["num_classes"] != num_classes:
    raise ValueError(f"model's {checkpoint['num_classes']} output classes != data's {num_classes} classes, is this the correct dataset?")

DEBUG_TEST_SIZE = 1024  # number of test samples kept when config.debug is set
if config.debug:
    print(f"debug mode is true, so truncate test to few elements only")
    raw_data.x_test = raw_data.x_test[:DEBUG_TEST_SIZE]
    raw_data.y_test = raw_data.y_test[:DEBUG_TEST_SIZE]

# Only the test split is used here; the train/dev splits are discarded.
_, _, test_dataset = create_datasets(
    raw_data, max_seq_len=config.max_length, dev_pct=config.dev_pct
)
# shuffle=False keeps batch order equal to test_dataset order, so predictions can be
# matched back to raw_data rows by position.
test_dataloader = DataLoader(
    test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=0, collate_fn=batch_collate_function
)

[32m2023-11-28 15:40:22.924[0m | [1mINFO    [0m | [36mlanguage_detection.data.loaders[0m:[36mload_wili_2018_dataset[0m:[36m67[0m - [1m'drop_duplicates' is true, dropping duplicates from *training* set...[0m
[32m2023-11-28 15:40:22.985[0m | [1mINFO    [0m | [36mlanguage_detection.data.loaders[0m:[36mload_wili_2018_dataset[0m:[36m74[0m - [1mdropped 3117 samples from training data that also appeared in the test data[0m


In [30]:
# eval on test
clf_criterion = torch.nn.CrossEntropyLoss(reduction="sum")
print(f"[{datetime.datetime.now().isoformat()}] starting evaluation on test set")
model.eval()
epoch_test_loss = []
test_epoch_targets = []
test_epoch_predictions = []
with torch.no_grad():
    test_iterator = iter(test_dataloader)
    for batch_idx, minibatch in enumerate(pbar := tqdm.tqdm(test_iterator, total=len(test_iterator))):
        # format data and move to gpu
        x, y, seq_lens, mask_indices, targets = minibatch
        x = x.to(device_string)
        y = y.to(device_string)
        targets = targets.to(device_string)
        pad_mask = get_mask_from_lengths(seq_lens, config.max_length, x.device)
        clf_logits, mlm_logits = model.forward(x, pad_mask)
        clf_loss = clf_criterion(clf_logits, targets)
        epoch_test_loss.append(clf_loss.item())
        test_epoch_targets += targets.detach().cpu().numpy().tolist()
        test_epoch_predictions += clf_logits.max(1).indices.detach().cpu().numpy().tolist()
    time.sleep(0.1)
    print(f"[{datetime.datetime.now().isoformat()}] test clf loss : {np.mean(epoch_test_loss):.5f}")
    test_results = evaluate_model(set_name="test", targets=test_epoch_targets, predictions=test_epoch_predictions)
    time.sleep(0.1)

[2023-11-28T15:21:55.600032] starting evaluation on test set


100%|██████████| 3672/3672 [09:47<00:00,  6.25it/s]


[2023-11-28T15:31:42.915205] test clf loss : 10.93545
[2023-11-28T15:31:43.672673] test micro prc: 0.90686,	macro 0.91144
[2023-11-28T15:31:43.672738] test micro rcl: 0.90686,	macro 0.90686
[2023-11-28T15:31:43.672745] test micro f1b: 0.90686,	macro 0.90768


In [41]:
# Label lookups stored in the checkpoint: class index -> language code ("output_mapping")
# and language code -> full language name ("extended_labels").
fullnames = checkpoint["extended_labels"]
mapping = checkpoint["output_mapping"]

In [57]:
# Map integer class indices back to language codes and full names. Before building
# anything from them, verify that the evaluated targets line up with the raw test rows.
true_codes = [mapping[idx] for idx in test_epoch_targets]
# list() makes the comparison element-wise whether y_test is a list, tuple or numpy array
# (a tuple would always compare unequal to a list; an array would raise on truth-testing).
if true_codes != list(raw_data.y_test):
    raise ValueError(f"raw data test labels do not match true codes, is your checkpoint mapping correct?")
pred_codes = [mapping[idx] for idx in test_epoch_predictions]
true_names = [fullnames[code] for code in true_codes]
pred_names = [fullnames[code] for code in pred_codes]
true_texts = raw_data.x_test

In [58]:
# Column name -> per-sample values for the results table (one entry per test sample).
df_data = dict(
    true_label=true_codes,
    true_name=true_names,
    pred_label=pred_codes,
    pred_name=pred_names,
    text=true_texts,
)

In [60]:
import pandas as pd

In [65]:
# Build the per-sample results table with a fixed column order.
# Each column is converted to a plain list so rows align by position (a pandas Series value
# would otherwise be aligned by its index). Unlike from_dict(orient="index").transpose(),
# the constructor raises on columns of unequal length instead of silently NaN-padding them.
# The column list is no longer named `ord`, which shadowed the builtin.
column_order = ["true_label", "true_name", "pred_label", "pred_name", "text"]
df = pd.DataFrame({name: list(df_data[name]) for name in column_order}, columns=column_order)

In [66]:
df  # per-sample test predictions; pandas truncates the rich display to head/tail rows

Unnamed: 0,true_label,true_name,pred_label,pred_name,text
0,mwl,Mirandese,mwl,Mirandese,Ne l fin de l seclo XIX l Japon era inda çconh...
1,nld,Dutch,nld,Dutch,Schiedam is gelegen tussen Rotterdam en Vlaard...
2,ava,Avar,ava,Avar,"ГIурусаз батальонал, гьоркьор гIарадабиги лъун..."
3,tcy,Tulu,kan,Kannada,ರಾಜ್ಯಶಾಸ್ತ್ರದ ಪಿತಾಮಹೆ ಅರಿಸ್ಟಾಟಲ್. ರಾಜ್ಯಶಾಸ್ತ್ರ...
4,bjn,Banjar,bjn,Banjar,Halukum adalah kelenjar tiroid nang menonjol d...
...,...,...,...,...,...
117495,swa,Swahili (macrolanguage),swa,Swahili (macrolanguage),"Wakati wa mimba,homa ya Q ni vigumu kutibu kwa..."
117496,glk,Gilaki,fas,Persian,گیلون یک ته تاریخی منطقه‌ سفیدرود دلتای طرف ای...
117497,khm,Central Khmer,khm,Central Khmer,តាម​រយៈ​ការ​ចិញ្ចឹម​មនោសញ្ចេតនា​ជាតិនិយម​បែប​ន...
117498,pnb,Western Panjabi,pnb,Western Panjabi,روس اک وفاق اے تے 1 مارچ 2008ء توں اسدیاں 83 و...


In [63]:
# Export the results table as tab-separated values (the index is written as the first column).
df.to_csv(path_or_buf="test.tsv", sep="\t")

In [34]:
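# # hotfix: edit existing checkpoints to include num_classes (this loop does not add the label mapping)
# # NOTE: each checkpoint must be loaded from its own path (`fpath`); loading `checkpoint_filepath`
# # here would overwrite every checkpoint file with a copy of that single checkpoint.
# import glob, os
# checkpoint_paths = glob.glob(os.path.join("./experiments/wili2018", "*.pt"))
# for fpath in checkpoint_paths:
#     checkpoint = torch.load(fpath)
#     checkpoint["num_classes"] = num_classes
#     torch.save(checkpoint, fpath)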