In [1]:
import numpy as np
# NOTE(review): test_dadil, test_odadil and test_forgetting_odadil are redefined
# by later cells of this notebook; once those cells run, the in-notebook
# definitions shadow the versions imported here.
from Tests_ODaDiL import test_dadil, test_odadil, test_forgetting_odadil

# Turn on TensorFlow's NumPy-compatible tensor behavior. TensorFlow is not
# referenced directly in the visible cells; presumably this is needed by a
# backend imported elsewhere — TODO confirm it is still required.
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
  register_backend(TensorflowBackend())
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch  # noqa: E402
import random  # noqa: E402
import numpy as np  # noqa: E402
import matplotlib.pyplot as plt
import ot

from pydil.utils.igmm_modif import IGMM

from pydil.ipms.ot_ipms import (  # noqa: E402
    JointWassersteinDistance
)
from pydil.dadil.labeled_dictionary_GMM import LabeledDictionaryGMM
from pydil.torch_utils.measures import (  # noqa: E402
    UnsupervisedDatasetMeasure,
    SupervisedDatasetMeasure
)

from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

import os
import pickle

In [2]:
# Load data/toy_non_linear_100d_dataset_1.npy ... _10.npy, in index order.
# NOTE(review): list_of_datasets is not used by any later cell shown here.
list_of_datasets = [
    np.load(f'data/toy_non_linear_100d_dataset_{idx}.npy')
    for idx in range(1, 11)
]

In [3]:
target = 'C'
data_path = os.path.join('data', 'mlp_fts_256_target_{}.pkl'.format(target))
# SECURITY: unpickling can execute arbitrary code; only load trusted,
# locally produced files here.
with open(data_path, 'rb') as f:
    dataset = pickle.load(f)

keys = list(dataset.keys())
# Every key other than `target` is a labeled source domain. Sources are
# selected by name so the target does not have to be the last pickle key.
source_keys = [key for key in keys if key != target]

Xs, ys, domain_blocks = [], [], []
for i, key in enumerate(source_keys):
    features = dataset[key]['Features']
    # argmax over dim 1 turns the stored label matrix into class indices.
    labels = dataset[key]['Labels'].argmax(dim=1)
    Xs.append(features.float())
    ys.append(labels.float())
    # Column vector of this source's domain index, one row per sample.
    domain_blocks.append(np.full((features.shape[0], 1), i, dtype=float))

fold = dataset[target]['fold 0']
Xt = fold['Train']['Features'].float()
yt = fold['Train']['Labels'].float().argmax(dim=1)

Xt_test = fold['Test']['Features'].float()
yt_test = fold['Test']['Labels'].float().argmax(dim=1)

# The target domain gets the index right after the last source.
domain_blocks.append(np.full((Xt.shape[0], 1), len(source_keys), dtype=float))
d = np.concatenate(domain_blocks, axis=0)

n_domains = int(np.max(d)) + 1
n_features = Xt.shape[1]
# Count classes over source and target labels, not the target train split alone.
n_classes = max(int(y.max()) for y in (*ys, yt)) + 1

In [4]:
# Hyper-parameters consumed by the test_* functions below.
n_samples, batch_size = 1000, 200  # the fits use batches_per_it = n_samples // batch_size
n_atoms = 3                        # passed as n_components to LabeledDictionaryGMM
n_classes = 10                     # NOTE: overrides the n_classes derived from the data above
n_iter = 100                       # passed as n_iter_max to the dictionary fit

In [5]:
def test_dadil(Xs, ys, Xt, yt, Xt_test, yt_test, n_features, n_samples, n_classes, n_atoms,
               batch_size, n_iter, n_repeats=10):
    """Learn a DaDiL dictionary on the sources plus the unlabeled target, then score it.

    A ``LabeledDictionaryGMM`` with ``n_atoms`` components is fitted jointly on
    one ``SupervisedDatasetMeasure`` per source and an
    ``UnsupervisedDatasetMeasure`` on ``Xt``. The last row of ``dictionary.A``
    is then used to evaluate a linear SVC, an RBF SVC and a random forest on
    the labeled target test split:

    * ``'wda'``  - trained on the pooled sources (no adaptation);
    * ``'e'``    - DaDiL-E: one classifier per atom, test probabilities mixed
      with the atom weights;
    * ``'e_ot'`` - DaDiL-E where each atom's labels are first transported onto
      ``Xt`` with an exact OT plan, averaged over ``n_repeats`` refits;
    * ``'r'``    - DaDiL-R: trained on the dictionary reconstruction;
    * ``'r_ot'`` - DaDiL-R with the reconstruction's labels transported onto
      ``Xt``, averaged over ``n_repeats`` refits.

    Args:
        Xs, ys: lists of per-source feature / label tensors.
        Xt: target training features (used without labels).
        yt: unused; kept for signature compatibility.
        Xt_test, yt_test: target test split used for every accuracy.
        n_features, n_samples, n_classes, n_atoms, batch_size, n_iter:
            dictionary hyper-parameters.
        n_repeats: number of refits averaged for the ``*_ot`` scores (>= 1).

    Returns:
        ``{'lin'|'rbf'|'RF': {'wda'|'e'|'e_ot'|'r'|'r_ot': accuracy}}``.
    """

    def aligned_proba(clf, X):
        # predict_proba has one column per label in clf.classes_, i.e. only the
        # labels seen during fit. Scatter them into a fixed (n_rows, n_classes)
        # layout so that column j is always class j; otherwise an atom whose
        # labels miss a class would not line up with the others when stacked.
        proba = clf.predict_proba(X)
        full = np.zeros((proba.shape[0], n_classes), dtype=proba.dtype)
        full[:, np.asarray(clf.classes_, dtype=int)] = proba
        return full

    def transport_labels(X_src, Y_src):
        # Exact OT plan (uniform weights, squared Euclidean cost) from X_src to
        # the target training samples Xt; push Y_src through the plan and
        # return the argmax label of each target sample.
        a = torch.ones(X_src.shape[0]) / X_src.shape[0]
        b = torch.ones(Xt.shape[0]) / Xt.shape[0]
        C = torch.cdist(X_src, Xt, p=2) ** 2
        plan = ot.emd(a, b, C, numItermax=1000000)
        return (plan.T @ Y_src).argmax(dim=1)

    def ensemble_accuracy(clf, training_sets, atom_weights):
        # Fit clf on each (X, y) pair in turn (one per atom), mix the aligned
        # test probabilities with the atom weights and score the argmax.
        proba = np.stack([aligned_proba(clf.fit(X, y), Xt_test) for X, y in training_sets])
        yp = np.einsum('i,inj->nj', atom_weights, proba).argmax(axis=1)
        return accuracy_score(yt_test, yp)

    Q = [
        SupervisedDatasetMeasure(
            features=Xs_k.numpy(),
            labels=ys_k.numpy(),
            stratify=True,
            batch_size=batch_size,
            device='cpu'
        )
        for Xs_k, ys_k in zip(Xs, ys)
    ]
    # The target enters the dictionary as the last, unlabeled measure.
    Q.append(
        UnsupervisedDatasetMeasure(
            features=Xt.numpy(),
            batch_size=batch_size,
            device='cpu'
        )
    )
    criterion = JointWassersteinDistance()
    dictionary = LabeledDictionaryGMM(XP=None,
                                      YP=None,
                                      A=None,
                                      n_samples=n_samples,
                                      n_dim=n_features,
                                      n_classes=n_classes,
                                      n_components=n_atoms,
                                      weight_initialization='uniform',
                                      n_distributions=len(Q),
                                      loss_fn=criterion,
                                      learning_rate_features=1e-1,
                                      learning_rate_labels=1e-1,
                                      learning_rate_weights=1e-1,
                                      reg_e=0.0,
                                      n_iter_barycenter=10,
                                      n_iter_sinkhorn=20,
                                      n_iter_emd=1000000,
                                      domain_names=None,
                                      grad_labels=True,
                                      optimizer_name='Adam',
                                      balanced_sampling=True,
                                      sampling_with_replacement=True,
                                      barycenter_tol=1e-9,
                                      barycenter_beta=None,
                                      tensor_dtype=torch.float32,
                                      track_atoms=False,
                                      schedule_lr=False)
    dictionary.fit(Q,
                   n_iter_max=n_iter,
                   batches_per_it=n_samples // batch_size,
                   verbose=True)

    # Row -1 of A: presumably the target's weights, since the target is the
    # last measure in Q.
    weights = dictionary.A[-1, :].detach()
    atom_weights = weights.cpu().numpy()
    XP = [XPk.detach().clone().cpu() for XPk in dictionary.XP]
    YP = [YPk.detach().clone().softmax(dim=-1).cpu() for YPk in dictionary.YP]
    Xr, Yr = dictionary.reconstruct(weights=weights)

    # None of these depend on the classifier or the repeat index, so the OT
    # plans are solved once here rather than inside every repeat.
    atom_sets = [(XP_k, YP_k.argmax(dim=1)) for XP_k, YP_k in zip(XP, YP)]
    atom_sets_ot = [(Xt, transport_labels(XP_k, YP_k)) for XP_k, YP_k in zip(XP, YP)]
    yr = Yr.argmax(dim=1)
    yr_ot = transport_labels(Xr, Yr)
    X_pool = torch.cat(Xs, dim=0)
    y_pool = torch.cat(ys, dim=0)

    classifiers_e = {'lin': SVC(kernel='linear', probability=True),
                     'rbf': SVC(kernel='rbf', probability=True),
                     'RF': RandomForestClassifier()}
    classifiers_r = {'lin': SVC(kernel='linear'),
                     'rbf': SVC(kernel='rbf'),
                     'RF': RandomForestClassifier()}
    results = {key: {'wda': 0, 'e': 0, 'e_ot': 0, 'r': 0, 'r_ot': 0} for key in classifiers_e}

    for key, clf_e in classifiers_e.items():
        clf_r = classifiers_r[key]

        # Without DA: train on the pooled sources.
        clf_r.fit(X_pool, y_pool)
        results[key]['wda'] += accuracy_score(yt_test, clf_r.predict(Xt_test))

        # DaDiL-E: one classifier per atom, trained on the atom's own samples.
        results[key]['e'] += ensemble_accuracy(clf_e, atom_sets, atom_weights)

        # DaDiL-E with last optimal transport: atom labels moved onto Xt.
        total = sum(ensemble_accuracy(clf_e, atom_sets_ot, atom_weights)
                    for _ in range(n_repeats))
        results[key]['e_ot'] += total / n_repeats

        # DaDiL-R: train on the reconstructed target sample.
        clf_r.fit(Xr, yr)
        results[key]['r'] += accuracy_score(yt_test, clf_r.predict(Xt_test))

        # DaDiL-R with last optimal transport: reconstruction labels moved onto Xt.
        total = sum(accuracy_score(yt_test, clf_r.fit(Xt, yr_ot).predict(Xt_test))
                    for _ in range(n_repeats))
        results[key]['r_ot'] += total / n_repeats

    return results

In [7]:
results = test_dadil(
    Xs, ys, Xt, yt, Xt_test, yt_test,
    n_features=n_features,
    n_samples=n_samples,
    n_classes=n_classes,
    n_atoms=n_atoms,
    batch_size=batch_size,
    n_iter=n_iter,
)

It 1/100, Loss: 7250.020605468751
It 2/100, Loss: 5412.389453125
It 3/100, Loss: 3971.7236328125
It 4/100, Loss: 2991.436328125
It 5/100, Loss: 2328.8544921875
It 6/100, Loss: 1843.2712158203126
It 7/100, Loss: 1453.3168212890625
It 8/100, Loss: 1166.7173095703124
It 9/100, Loss: 945.2510131835938
It 10/100, Loss: 764.6506469726563
It 11/100, Loss: 632.8729858398438
It 12/100, Loss: 540.3836059570312
It 13/100, Loss: 466.92755126953125
It 14/100, Loss: 401.4289916992187
It 15/100, Loss: 359.0767761230469
It 16/100, Loss: 323.2957702636719
It 17/100, Loss: 291.80101318359374
It 18/100, Loss: 274.8337951660156
It 19/100, Loss: 265.99306640625
It 20/100, Loss: 248.4516387939453
It 21/100, Loss: 236.10530395507809
It 22/100, Loss: 216.57101745605468
It 23/100, Loss: 215.47197875976565
It 24/100, Loss: 218.15243530273438
It 25/100, Loss: 204.7766082763672
It 26/100, Loss: 196.22800292968748
It 27/100, Loss: 176.01340332031248
It 28/100, Loss: 177.65384521484373
It 29/100, Loss: 154.66368103

In [8]:
# test_dadil accuracies, per classifier ('lin', 'rbf', 'RF') and method.
results

{'lin': {'wda': 0.7316666666666667,
  'e': 0.55125,
  'e_ot': 0.9645000000000001,
  'r': 0.9854166666666667,
  'r_ot': 0.9695833333333332},
 'rbf': {'wda': 0.72125,
  'e': 0.8616666666666667,
  'e_ot': 1.0,
  'r': 1.0,
  'r_ot': 1.0},
 'RF': {'wda': 0.7129166666666666,
  'e': 0.6879166666666666,
  'e_ot': 0.9668749999999999,
  'r': 0.9966666666666667,
  'r_ot': 0.9345833333333331}}

In [9]:
def test_odadil(Xs, ys, Xt, yt, Xt_test, yt_test, n_features, n_samples, n_classes, n_atoms,
                batch_size, n_iter, n_repeats=10, online_batch_size=20):
    """Online DaDiL: learn atoms on the sources, then fit target weights from a stream.

    A ``LabeledDictionaryGMM`` is first fitted on the labeled sources only.
    Its atoms initialise a second, single-distribution dictionary whose
    feature and label learning rates are 0. That dictionary is fed the target
    training set ``Xt`` chunk by chunk through ``fit_target_sample``. The last
    row of its ``A`` is then used to evaluate a linear SVC, an RBF SVC and a
    random forest on the labeled target test split:

    * ``'wda'``  - trained on the pooled sources (no adaptation);
    * ``'e'``    - one classifier per atom, test probabilities mixed with the
      atom weights;
    * ``'e_ot'`` - as ``'e'`` but each atom's labels are first transported onto
      ``Xt`` with an exact OT plan, averaged over ``n_repeats`` refits;
    * ``'r'``    - trained on the dictionary reconstruction;
    * ``'r_ot'`` - reconstruction labels transported onto ``Xt``, averaged over
      ``n_repeats`` refits.

    Args:
        Xs, ys: lists of per-source feature / label tensors.
        Xt: target training features (streamed, used without labels).
        yt: unused; kept for signature compatibility.
        Xt_test, yt_test: target test split used for every accuracy.
        n_features, n_samples, n_classes, n_atoms, batch_size, n_iter:
            dictionary hyper-parameters.
        n_repeats: number of refits averaged for the ``*_ot`` scores (>= 1).
        online_batch_size: target samples per ``fit_target_sample`` call.
            Every full chunk of ``Xt`` is streamed; a trailing partial chunk
            is not.

    Returns:
        ``{'lin'|'rbf'|'RF': {'wda'|'e'|'e_ot'|'r'|'r_ot': accuracy}}``.
    """

    def aligned_proba(clf, X):
        # predict_proba has one column per label in clf.classes_ only; scatter
        # into a fixed (n_rows, n_classes) layout so column j is always class j
        # and per-atom probabilities line up when stacked.
        proba = clf.predict_proba(X)
        full = np.zeros((proba.shape[0], n_classes), dtype=proba.dtype)
        full[:, np.asarray(clf.classes_, dtype=int)] = proba
        return full

    def transport_labels(X_src, Y_src):
        # Exact OT plan (uniform weights, squared Euclidean cost) from X_src to
        # Xt; push Y_src through it and return each target sample's argmax label.
        a = torch.ones(X_src.shape[0]) / X_src.shape[0]
        b = torch.ones(Xt.shape[0]) / Xt.shape[0]
        C = torch.cdist(X_src, Xt, p=2) ** 2
        plan = ot.emd(a, b, C, numItermax=1000000)
        return (plan.T @ Y_src).argmax(dim=1)

    def ensemble_accuracy(clf, training_sets, atom_weights):
        # Fit clf on each (X, y) pair in turn, mix the aligned test
        # probabilities with the atom weights and score the argmax.
        proba = np.stack([aligned_proba(clf.fit(X, y), Xt_test) for X, y in training_sets])
        yp = np.einsum('i,inj->nj', atom_weights, proba).argmax(axis=1)
        return accuracy_score(yt_test, yp)

    # Stage 1: dictionary learned on the labeled sources only.
    Q_sources = [
        SupervisedDatasetMeasure(
            features=Xs_k.numpy(),
            labels=ys_k.numpy(),
            stratify=True,
            batch_size=batch_size,
            device='cpu'
        )
        for Xs_k, ys_k in zip(Xs, ys)
    ]

    criterion = JointWassersteinDistance()

    dictionary_sources = LabeledDictionaryGMM(XP=None,
                                              YP=None,
                                              A=None,
                                              n_samples=n_samples,
                                              n_dim=n_features,
                                              n_classes=n_classes,
                                              n_components=n_atoms,
                                              weight_initialization='uniform',
                                              n_distributions=len(Q_sources),
                                              loss_fn=criterion,
                                              learning_rate_features=1e-1,
                                              learning_rate_labels=1e-1,
                                              learning_rate_weights=1e-1,
                                              reg_e=0.0,
                                              n_iter_barycenter=10,
                                              n_iter_sinkhorn=20,
                                              n_iter_emd=1000000,
                                              domain_names=None,
                                              grad_labels=True,
                                              optimizer_name='Adam',
                                              balanced_sampling=True,
                                              sampling_with_replacement=True,
                                              barycenter_tol=1e-9,
                                              barycenter_beta=None,
                                              tensor_dtype=torch.float32,
                                              track_atoms=False,
                                              schedule_lr=False)

    dictionary_sources.fit(Q_sources,
                           n_iter_max=n_iter,
                           batches_per_it=n_samples // batch_size,
                           verbose=True)

    # Stage 2: target dictionary initialised with the source atoms; feature and
    # label learning rates are 0, only the weights are trained.
    # NOTE(review): XP/YP are passed by reference, not copied — if
    # LabeledDictionaryGMM mutates them in place, dictionary_sources changes
    # too. TODO confirm.
    dictionary_target = LabeledDictionaryGMM(XP=dictionary_sources.XP,
                                             YP=dictionary_sources.YP,
                                             A=None,
                                             n_samples=n_samples,
                                             n_dim=n_features,
                                             n_classes=n_classes,
                                             n_components=n_atoms,
                                             weight_initialization='uniform',
                                             n_distributions=1,
                                             loss_fn=criterion,
                                             learning_rate_features=0,
                                             learning_rate_labels=0,
                                             learning_rate_weights=1e-1,
                                             reg_e=0.0,
                                             n_iter_barycenter=10,
                                             n_iter_sinkhorn=20,
                                             n_iter_emd=1000000,
                                             domain_names=None,
                                             grad_labels=True,
                                             optimizer_name='Adam',
                                             balanced_sampling=True,
                                             sampling_with_replacement=True,
                                             barycenter_tol=1e-9,
                                             barycenter_beta=None,
                                             tensor_dtype=torch.float32,
                                             track_atoms=False,
                                             schedule_lr=False,
                                             min_components=10,
                                             max_step_components=10,
                                             max_components=20)

    # Stream Xt in chunks of online_batch_size. The bound `start + size <= N`
    # includes a final chunk that ends exactly at N (a `start < N - size` test
    # would skip it); a trailing partial chunk is not streamed.
    n_target = Xt.shape[0]
    for start in range(0, n_target - online_batch_size + 1, online_batch_size):
        dictionary_target.fit_target_sample(Xt[start:start + online_batch_size, :],
                                            batches_per_it=n_samples // batch_size,
                                            batch_size=batch_size,
                                            verbose=True,
                                            regularization=False)

    weights = dictionary_target.A[-1, :].detach()
    atom_weights = weights.cpu().numpy()
    XP = [XPk.detach().clone().cpu() for XPk in dictionary_target.XP]
    YP = [YPk.detach().clone().softmax(dim=-1).cpu() for YPk in dictionary_target.YP]
    Xr, Yr = dictionary_target.reconstruct(weights=weights)

    # Classifier- and repeat-independent training sets; OT plans solved once.
    atom_sets = [(XP_k, YP_k.argmax(dim=1)) for XP_k, YP_k in zip(XP, YP)]
    atom_sets_ot = [(Xt, transport_labels(XP_k, YP_k)) for XP_k, YP_k in zip(XP, YP)]
    yr = Yr.argmax(dim=1)
    yr_ot = transport_labels(Xr, Yr)
    X_pool = torch.cat(Xs, dim=0)
    y_pool = torch.cat(ys, dim=0)

    classifiers_e = {'lin': SVC(kernel='linear', probability=True),
                     'rbf': SVC(kernel='rbf', probability=True),
                     'RF': RandomForestClassifier()}
    classifiers_r = {'lin': SVC(kernel='linear'),
                     'rbf': SVC(kernel='rbf'),
                     'RF': RandomForestClassifier()}
    results = {key: {'wda': 0, 'e': 0, 'e_ot': 0, 'r': 0, 'r_ot': 0} for key in classifiers_e}

    for key, clf_e in classifiers_e.items():
        clf_r = classifiers_r[key]

        # Without DA: train on the pooled sources.
        clf_r.fit(X_pool, y_pool)
        results[key]['wda'] += accuracy_score(yt_test, clf_r.predict(Xt_test))

        # DaDiL-E
        results[key]['e'] += ensemble_accuracy(clf_e, atom_sets, atom_weights)

        # DaDiL-E with last optimal transport
        total = sum(ensemble_accuracy(clf_e, atom_sets_ot, atom_weights)
                    for _ in range(n_repeats))
        results[key]['e_ot'] += total / n_repeats

        # DaDiL-R
        clf_r.fit(Xr, yr)
        results[key]['r'] += accuracy_score(yt_test, clf_r.predict(Xt_test))

        # DaDiL-R with last optimal transport
        total = sum(accuracy_score(yt_test, clf_r.fit(Xt, yr_ot).predict(Xt_test))
                    for _ in range(n_repeats))
        results[key]['r_ot'] += total / n_repeats

    return results

In [11]:
results_odadil = test_odadil(
    Xs, ys, Xt, yt, Xt_test, yt_test,
    n_features=n_features,
    n_samples=n_samples,
    n_classes=n_classes,
    n_atoms=n_atoms,
    batch_size=batch_size,
    n_iter=n_iter,
)

It 1/100, Loss: 6155.6748046875
It 2/100, Loss: 4598.67529296875
It 3/100, Loss: 3395.6925781249997
It 4/100, Loss: 2567.266064453125
It 5/100, Loss: 2030.51181640625
It 6/100, Loss: 1634.2913574218749
It 7/100, Loss: 1315.9779296875
It 8/100, Loss: 1076.7628173828125
It 9/100, Loss: 885.8607177734375
It 10/100, Loss: 720.4839965820313
It 11/100, Loss: 579.9966430664062
It 12/100, Loss: 482.60099487304683
It 13/100, Loss: 398.63926391601564
It 14/100, Loss: 329.35023193359376
It 15/100, Loss: 274.9776092529297
It 16/100, Loss: 231.20325622558593
It 17/100, Loss: 196.91969299316406
It 18/100, Loss: 168.05216064453126
It 19/100, Loss: 144.07355346679688
It 20/100, Loss: 124.32376251220703
It 21/100, Loss: 108.23246765136719
It 22/100, Loss: 93.77239532470702
It 23/100, Loss: 82.05999145507812
It 24/100, Loss: 73.50467834472657
It 25/100, Loss: 65.47496948242187
It 26/100, Loss: 59.25986175537109
It 27/100, Loss: 54.32483749389648
It 28/100, Loss: 49.90577697753907
It 29/100, Loss: 45.150

In [13]:
# test_odadil accuracies, per classifier ('lin', 'rbf', 'RF') and method.
results_odadil

{'lin': {'wda': 0.7316666666666667,
  'e': 0.7529166666666667,
  'e_ot': 0.9854166666666668,
  'r': 0.7520833333333333,
  'r_ot': 0.985},
 'rbf': {'wda': 0.72125,
  'e': 0.75625,
  'e_ot': 1.0,
  'r': 0.6995833333333333,
  'r_ot': 1.0},
 'RF': {'wda': 0.6883333333333334,
  'e': 0.6829166666666666,
  'e_ot': 0.990375,
  'r': 0.5975,
  'r_ot': 0.9896666666666667}}

In [14]:
def test_forgetting_odadil(Xs, ys, Xt, yt, Xt_test, yt_test, n_features, n_samples, n_classes,
                           n_atoms, batch_size, n_iter, n_repeats=10, online_batch_size=20):
    """Compare source-weight reconstructions before and after online target adaptation.

    Steps:

    1. A ``LabeledDictionaryGMM`` is fitted on the labeled sources
       (``n_distributions=len(Q_sources)``, so presumably one row of ``A`` per
       source).
    2. For every row of its ``A``, the dictionary reconstructs a labeled
       sample. That sample is scored with DaDiL-R and DaDiL-R + OT (labels
       transported onto ``Xt``) on the target test split.
    3. A second dictionary is initialised with the source atoms and 0
       feature/label learning rates. It is fed ``Xt`` in chunks via
       ``fit_target_sample``.
    4. The same weight rows are reconstructed with the second dictionary and
       scored again.

    Args:
        Xs, ys: lists of per-source feature / label tensors.
        Xt: target training features (streamed, used without labels).
        yt: unused; kept for signature compatibility.
        Xt_test, yt_test: target test split used for every accuracy.
        n_features, n_samples, n_classes, n_atoms, batch_size, n_iter:
            dictionary hyper-parameters.
        n_repeats: number of refits averaged for the ``'r_ot'`` scores (>= 1).
        online_batch_size: target samples per ``fit_target_sample`` call.
            Every full chunk of ``Xt`` is streamed; a trailing partial chunk
            is not.

    Returns:
        ``(before_online_results, after_online_results)``. Each is a
        ``{'lin'|'rbf'|'RF': {'r': [...], 'r_ot': [...]}}`` dict with one
        accuracy per row of the source dictionary's ``A``, in row order.
    """
    classifiers_r = {'lin': SVC(kernel='linear'), 'rbf': SVC(kernel='rbf'),
                     'RF': RandomForestClassifier()}
    before_online_results = {key: {'r': [], 'r_ot': []} for key in classifiers_r}
    after_online_results = {key: {'r': [], 'r_ot': []} for key in classifiers_r}

    def transport_labels(X_src, Y_src):
        # Exact OT plan (uniform weights, squared Euclidean cost) from X_src to
        # Xt; push Y_src through it and return each target sample's argmax label.
        a = torch.ones(X_src.shape[0]) / X_src.shape[0]
        b = torch.ones(Xt.shape[0]) / Xt.shape[0]
        C = torch.cdist(X_src, Xt, p=2) ** 2
        plan = ot.emd(a, b, C, numItermax=1000000)
        return (plan.T @ Y_src).argmax(dim=1)

    def evaluate(dictionary, weight_rows, store):
        # For each weight row, reconstruct with `dictionary` and append the
        # DaDiL-R and DaDiL-R + OT accuracies of every classifier to `store`.
        for row in weight_rows:
            Xr, Yr = dictionary.reconstruct(weights=row)
            yr = Yr.argmax(dim=1)
            # Independent of the classifier and of the repeat index: solve once.
            yr_ot = transport_labels(Xr, Yr)
            for key, clf_r in classifiers_r.items():
                # DaDiL-R
                clf_r.fit(Xr, yr)
                store[key]['r'].append(accuracy_score(yt_test, clf_r.predict(Xt_test)))
                # DaDiL-R with last optimal transport
                total = sum(accuracy_score(yt_test, clf_r.fit(Xt, yr_ot).predict(Xt_test))
                            for _ in range(n_repeats))
                store[key]['r_ot'].append(total / n_repeats)

    Q_sources = [
        SupervisedDatasetMeasure(
            features=Xs_k.numpy(),
            labels=ys_k.numpy(),
            stratify=True,
            batch_size=batch_size,
            device='cpu'
        )
        for Xs_k, ys_k in zip(Xs, ys)
    ]

    criterion = JointWassersteinDistance()

    dictionary_sources = LabeledDictionaryGMM(XP=None,
                                              YP=None,
                                              A=None,
                                              n_samples=n_samples,
                                              n_dim=n_features,
                                              n_classes=n_classes,
                                              n_components=n_atoms,
                                              weight_initialization='uniform',
                                              n_distributions=len(Q_sources),
                                              loss_fn=criterion,
                                              learning_rate_features=1e-1,
                                              learning_rate_labels=1e-1,
                                              learning_rate_weights=1e-1,
                                              reg_e=0.0,
                                              n_iter_barycenter=10,
                                              n_iter_sinkhorn=20,
                                              n_iter_emd=1000000,
                                              domain_names=None,
                                              grad_labels=True,
                                              optimizer_name='Adam',
                                              balanced_sampling=True,
                                              sampling_with_replacement=True,
                                              barycenter_tol=1e-9,
                                              barycenter_beta=None,
                                              tensor_dtype=torch.float32,
                                              track_atoms=False,
                                              schedule_lr=False)

    dictionary_sources.fit(Q_sources,
                           n_iter_max=n_iter,
                           batches_per_it=n_samples // batch_size,
                           verbose=True)

    weights_list = dictionary_sources.A.detach()

    # Source classification before online learning.
    evaluate(dictionary_sources, weights_list, before_online_results)

    # Online learning: target dictionary initialised with the source atoms.
    # NOTE(review): XP/YP are passed by reference, not copied — if
    # LabeledDictionaryGMM mutates them in place, dictionary_sources changes
    # too. TODO confirm.
    dictionary_target = LabeledDictionaryGMM(XP=dictionary_sources.XP,
                                             YP=dictionary_sources.YP,
                                             A=None,
                                             n_samples=n_samples,
                                             n_dim=n_features,
                                             n_classes=n_classes,
                                             n_components=n_atoms,
                                             weight_initialization='uniform',
                                             n_distributions=1,
                                             loss_fn=criterion,
                                             learning_rate_features=0,
                                             learning_rate_labels=0,
                                             learning_rate_weights=1e-1,
                                             reg_e=0.0,
                                             n_iter_barycenter=10,
                                             n_iter_sinkhorn=20,
                                             n_iter_emd=1000000,
                                             domain_names=None,
                                             grad_labels=True,
                                             optimizer_name='Adam',
                                             balanced_sampling=True,
                                             sampling_with_replacement=True,
                                             barycenter_tol=1e-9,
                                             barycenter_beta=None,
                                             tensor_dtype=torch.float32,
                                             track_atoms=False,
                                             schedule_lr=False,
                                             min_components=10,
                                             max_step_components=10,
                                             max_components=20)

    # Stream Xt in chunks of online_batch_size. The bound `start + size <= N`
    # includes a final chunk that ends exactly at N; a trailing partial chunk
    # is not streamed.
    n_target = Xt.shape[0]
    for start in range(0, n_target - online_batch_size + 1, online_batch_size):
        dictionary_target.fit_target_sample(Xt[start:start + online_batch_size, :],
                                            batches_per_it=n_samples // batch_size,
                                            batch_size=batch_size,
                                            verbose=True,
                                            regularization=False)

    # Source classification after online learning, same source weight rows.
    evaluate(dictionary_target, weights_list, after_online_results)

    return before_online_results, after_online_results

In [15]:
before_online_results, after_online_results = test_forgetting_odadil(
    Xs, ys, Xt, yt, Xt_test, yt_test,
    n_features=n_features,
    n_samples=n_samples,
    n_classes=n_classes,
    n_atoms=n_atoms,
    batch_size=batch_size,
    n_iter=n_iter,
)

It 1/100, Loss: 6184.32822265625
It 2/100, Loss: 4652.443798828125
It 3/100, Loss: 3393.0194335937495
It 4/100, Loss: 2564.2015625
It 5/100, Loss: 2066.425610351562
It 6/100, Loss: 1667.5993896484376
It 7/100, Loss: 1336.59296875
It 8/100, Loss: 1108.1639404296875
It 9/100, Loss: 877.7685546875
It 10/100, Loss: 731.2276367187501
It 11/100, Loss: 601.156787109375
It 12/100, Loss: 491.2104309082031
It 13/100, Loss: 402.89794921875
It 14/100, Loss: 338.68431396484374
It 15/100, Loss: 283.22911376953124
It 16/100, Loss: 236.00957946777345
It 17/100, Loss: 198.70504760742188
It 18/100, Loss: 171.68735961914064
It 19/100, Loss: 145.5631591796875
It 20/100, Loss: 125.6956069946289
It 21/100, Loss: 108.38717498779296
It 22/100, Loss: 95.46341094970703
It 23/100, Loss: 84.44511718749999
It 24/100, Loss: 73.37073211669922
It 25/100, Loss: 66.61158294677735
It 26/100, Loss: 60.346449279785155
It 27/100, Loss: 54.38810729980469
It 28/100, Loss: 49.41547775268555
It 29/100, Loss: 45.52946395874024


In [18]:
# Per-source-weight DaDiL-R / DaDiL-R+OT accuracies from dictionary_sources
# (before the target stream), per classifier.
before_online_results

{'lin': {'r': [0.7820833333333334, 0.72125],
  'r_ot': [0.98125, 0.9908333333333333]},
 'rbf': {'r': [0.8270833333333333, 0.8691666666666666], 'r_ot': [1.0, 1.0]},
 'RF': {'r': [0.7520833333333333, 0.7433333333333333],
  'r_ot': [0.9861666666666669, 0.9935833333333333]}}

In [19]:
# Same source weight rows reconstructed with dictionary_target after streaming Xt.
after_online_results

{'lin': {'r': [0.6770833333333334, 0.7508333333333334],
  'r_ot': [0.9825000000000002, 0.9870833333333333]},
 'rbf': {'r': [0.74125, 0.8533333333333334], 'r_ot': [1.0, 1.0]},
 'RF': {'r': [0.8454166666666667, 0.7875], 'r_ot': [0.9850416666666666, 0.99]}}