In [10]:
import torch
from domains import load_domains

In [11]:
# Load the two collections used by the rest of the notebook:
#   - `domains`: iterated in order by the training cell; each entry is indexed
#     with 'features' and 'labels' ('labels' may be None, in which case the
#     training cell falls back to the model's own predictions).
#   - `eval_domains`: each entry is indexed with 'features' and 'labels' and is
#     scored with accuracy_score.
# NOTE(review): `load_domains` is a project-local helper; where the data comes
# from and its exact container types are not visible here — confirm in domains.py.
domains, eval_domains = load_domains()

## Model

In [12]:
from lwp import LWP

## Training and Predictions

In [13]:
from sklearn.metrics import accuracy_score
import numpy as np  # needed for np.nan padding below; was previously missing (NameError on a fresh kernel)
import pandas as pd

# The model is created inside this cell so that re-running the cell starts from
# a new LWP instance instead of fitting again on top of an already-fitted one
# (the printed class counts grow with every fit call).
model = LWP(distance_metric='euclidean')

# Column label -> list of accuracies (one slot per eval domain, NaN where that
# eval domain has not been scored yet). Converted to `df` once after the loop.
accuracy_columns = {}

for idx, domain in enumerate(domains):

    train_features = domain['features']
    # Domains without labels are pseudo-labelled with the current model's own
    # predictions; labelled domains use their provided labels.
    # NOTE(review): this presumes the first domain is labelled, otherwise
    # predict() is called before any fit() — confirm against load_domains.
    if domain['labels'] is None:
        train_labels = model.predict(train_features)
    else:
        train_labels = domain['labels']

    model.fit(train_features, train_labels)
    print(model.class_counts)

    # Score the updated model on every eval domain up to and including the
    # current position (the slice simply stops at the end of eval_domains when
    # there are more training domains than eval domains).
    scores = []
    for eval_domain in eval_domains[:idx + 1]:
        preds = model.predict(eval_domain['features'])
        scores.append(accuracy_score(eval_domain['labels'], preds))

    # Pad so every column has one entry per eval domain.
    n_missing = len(eval_domains) - len(scores)
    accuracy_columns[f'Domain {idx+1}'] = scores + [np.nan] * n_missing

# Rows: eval domain index; columns: model state after fitting on "Domain k".
df = pd.DataFrame(accuracy_columns)

{0: 253, 1: 243, 2: 255, 3: 244, 4: 262, 5: 236, 6: 250, 7: 253, 8: 254, 9: 250}
{0: 495, 1: 499, 2: 483, 3: 511, 4: 514, 5: 497, 6: 498, 7: 507, 8: 496, 9: 500}
{0: 768, 1: 740, 2: 696, 3: 762, 4: 786, 5: 793, 6: 748, 7: 724, 8: 748, 9: 735}
{0: 1024, 1: 998, 2: 903, 3: 1028, 4: 1055, 5: 1061, 6: 990, 7: 944, 8: 988, 9: 1009}
{0: 1277, 1: 1266, 2: 1116, 3: 1276, 4: 1322, 5: 1351, 6: 1221, 7: 1170, 8: 1226, 9: 1275}
{0: 1526, 1: 1486, 2: 1346, 3: 1527, 4: 1609, 5: 1660, 6: 1456, 7: 1398, 8: 1456, 9: 1536}
{0: 1781, 1: 1722, 2: 1579, 3: 1772, 4: 1889, 5: 1945, 6: 1685, 7: 1610, 8: 1730, 9: 1787}
{0: 2039, 1: 1968, 2: 1828, 3: 2026, 4: 2154, 5: 2245, 6: 1910, 7: 1818, 8: 1968, 9: 2044}
{0: 2287, 1: 2229, 2: 2036, 3: 2284, 4: 2446, 5: 2504, 6: 2137, 7: 2041, 8: 2223, 9: 2313}
{0: 2535, 1: 2464, 2: 2246, 3: 2545, 4: 2721, 5: 2832, 6: 2364, 7: 2254, 8: 2468, 9: 2571}
{0: 2812, 1: 2704, 2: 2466, 3: 2886, 4: 2991, 5: 3156, 6: 2595, 7: 2402, 8: 2697, 9: 2791}
{0: 3327, 1: 2781, 2: 2564, 3: 314

## Evaluation

In [14]:
# Accuracy table built by the training cell: column "Domain k" holds the scores
# measured right after the model was fitted on the k-th training domain, and
# row i is the i-th entry of `eval_domains`. NaN means that eval domain was not
# scored at that step (only eval_domains[:k] are evaluated after step k).
df

Unnamed: 0,Domain 1,Domain 2,Domain 3,Domain 4,Domain 5,Domain 6,Domain 7,Domain 8,Domain 9,Domain 10,Domain 11,Domain 12,Domain 13,Domain 14,Domain 15,Domain 16,Domain 17,Domain 18,Domain 19,Domain 20
0,0.902,0.8936,0.8908,0.8904,0.89,0.89,0.8892,0.8884,0.8872,0.8868,0.8852,0.8844,0.8832,0.8832,0.8844,0.882,0.8824,0.8828,0.8816,0.882
1,,0.904,0.8996,0.8988,0.8976,0.8968,0.8952,0.8956,0.8956,0.8952,0.8948,0.892,0.8924,0.8912,0.8904,0.8896,0.89,0.8884,0.8872,0.8876
2,,,0.9096,0.9076,0.9068,0.9072,0.9064,0.9056,0.9056,0.9048,0.9024,0.9012,0.9004,0.8992,0.8988,0.896,0.8944,0.8936,0.892,0.8916
3,,,,0.9208,0.9204,0.9188,0.918,0.9176,0.9172,0.9168,0.916,0.9152,0.9144,0.9148,0.9152,0.9128,0.9128,0.9124,0.9096,0.9088
4,,,,,0.9064,0.9056,0.9052,0.9044,0.904,0.9036,0.9036,0.9008,0.8992,0.8992,0.8988,0.8976,0.8968,0.896,0.8948,0.8952
5,,,,,,0.9128,0.9132,0.914,0.9148,0.914,0.9116,0.9124,0.9096,0.91,0.9104,0.9104,0.9088,0.9068,0.9068,0.9044
6,,,,,,,0.9064,0.9052,0.9044,0.9048,0.9048,0.9036,0.9024,0.9024,0.9016,0.9012,0.9008,0.9008,0.8992,0.8992
7,,,,,,,,0.8984,0.8988,0.8992,0.8976,0.896,0.8948,0.8948,0.894,0.8944,0.894,0.8932,0.8928,0.892
8,,,,,,,,,0.9072,0.9064,0.904,0.9024,0.9004,0.9,0.8992,0.8972,0.8968,0.8976,0.8956,0.8964
9,,,,,,,,,,0.9088,0.908,0.9064,0.9048,0.9056,0.9052,0.9044,0.9036,0.9036,0.9016,0.9016
