In [1]:
import numpy as np
from Tests_ODaDiL import test_dadil, test_odadil, test_forgetting_odadil

from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
  register_backend(TensorflowBackend())
  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import torch  # noqa: E402
import random  # noqa: E402
import numpy as np  # noqa: E402
import matplotlib.pyplot as plt
import ot

from pydil.utils.igmm_modif import IGMM

from pydil.ipms.ot_ipms import (  # noqa: E402
    JointWassersteinDistance
)
from pydil.dadil.labeled_dictionary_GMM import LabeledDictionaryGMM
from pydil.torch_utils.measures import (  # noqa: E402
    UnsupervisedDatasetMeasure,
    SupervisedDatasetMeasure
)

from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

import os
import pickle

In [2]:
# Load the ten toy datasets stored as .npy files (indices 1..10) into a list.
# NOTE(review): `list_of_datasets` is not referenced by any later cell visible in
# this notebook, and the name `dataset` is reassigned by the pickle-loading cell.
list_of_datasets = []
for i in range(1, 11):
    dataset = np.load(f'data/toy_non_linear_100d_dataset_{i}.npy')
    list_of_datasets.append(dataset)

In [5]:
# Name of the target domain; it selects both the pickle file and the target entry.
target = 'C'

# NOTE(security): unpickling can execute arbitrary code. Only load files that
# come from a trusted source.
with open(os.path.join('data', 'mlp_fts_256_target_{}.pkl'.format(target)), 'rb') as f:
    dataset = pickle.load(f)

# Every key except the last one is read as a labeled source domain.
# NOTE(review): this assumes the target entry is stored last in the dict -- confirm.
keys = list(dataset.keys())
target_domain_index = len(keys) - 1

Xs, ys = [], []
domain_ids = []
for i in range(len(keys) - 1):
    features = dataset[keys[i]]['Features']
    # argmax over dim 1 turns per-class scores into class indices
    # (presumably the stored labels are one-hot -- confirm).
    labels = dataset[keys[i]]['Labels'].argmax(dim=1)
    domain = i * np.ones((features.shape[0], 1))
    Xs.append(features.float())
    ys.append(labels.float())
    domain_ids.append(domain)

Xt = dataset[target]['fold 0']['Train']['Features'].float()
yt = dataset[target]['fold 0']['Train']['Labels'].float().argmax(dim=1)

Xt_test = dataset[target]['fold 0']['Test']['Features'].float()
yt_test = dataset[target]['fold 0']['Test']['Labels'].float().argmax(dim=1)

# Column vector holding the domain index of every sample (sources first, then
# the target training split). The target index is derived from the number of
# keys instead of being hard-coded, so it never collides with a source index.
domain_ids.append(target_domain_index * np.ones((Xt.shape[0], 1)))
d = np.concatenate(domain_ids, axis=0)

n_domains = int(np.max(d)) + 1
n_features = Xt.shape[1]
n_classes = int(np.max(yt.numpy())) + 1

In [6]:
# Hyper-parameters forwarded to test_dadil / test_odadil.
n_samples = 1000   # passed to the dictionary as `n_samples`
batch_size = 200   # mini-batch size of the dataset measures; n_samples // batch_size batches per iteration
n_atoms = 3        # passed to the dictionary as `n_components`
# NOTE(review): this overwrites the `n_classes` inferred from `yt` in the data-loading
# cell -- confirm that 10 is the intended value for this dataset.
n_classes = 10
n_iter = 100       # passed to `fit` as `n_iter_max`

In [8]:
def test_dadil(Xs, ys, Xt, yt, n_features, n_samples, n_classes, n_atoms, batch_size, n_iter):
    """Fit a dictionary on sources + unlabeled target and score five strategies.

    A ``LabeledDictionaryGMM`` is fitted jointly on one labeled measure per
    source domain and one unlabeled measure for the target. Three classifiers
    (linear SVM, RBF SVM, random forest) are then scored on ``(Xt, yt)``:

    * ``'wda'``  -- trained on the concatenated sources, no adaptation.
    * ``'e'``    -- one fit per atom; the per-atom class probabilities on ``Xt``
      are mixed with the last row of ``dictionary.A`` (the target measure is
      appended last to the list given to ``fit``).
    * ``'e_ot'`` -- as ``'e'``, but each atom's labels are first transported
      onto ``Xt`` with an exact OT plan and the classifier is trained on ``Xt``
      with those pseudo-labels (mean accuracy over 10 repetitions).
    * ``'r'``    -- trained on the samples returned by ``dictionary.reconstruct``.
    * ``'r_ot'`` -- as ``'r'``, with the reconstruction's labels transported
      onto ``Xt`` first (mean accuracy over 10 repetitions).

    NOTE: this definition shadows the ``test_dadil`` imported from
    ``Tests_ODaDiL`` in the first cell of the notebook.

    Parameters
    ----------
    Xs, ys : lists of torch tensors with source features / labels, one per domain.
    Xt, yt : torch tensors with target features / labels (``yt`` is only used
        for scoring).
    n_features, n_samples, n_classes, n_atoms : forwarded to the dictionary as
        ``n_dim``, ``n_samples``, ``n_classes`` and ``n_components``.
    batch_size : mini-batch size of the dataset measures.
    n_iter : forwarded to ``fit`` as ``n_iter_max``.

    Returns
    -------
    dict
        ``{'lin' | 'rbf' | 'RF': {'wda', 'e', 'e_ot', 'r', 'r_ot'}}`` accuracies.
    """
    n_repeats = 10  # repetitions averaged for the OT-based variants
    metric_names = ('wda', 'e', 'e_ot', 'r', 'r_ot')
    results = {name: dict.fromkeys(metric_names, 0) for name in ('lin', 'rbf', 'RF')}

    def proba_over_all_classes(clf, X):
        """Return predict_proba with one column per class index in range(n_classes).

        scikit-learn orders the probability columns by ``clf.classes_``, i.e.
        only the classes seen during ``fit``. Scattering them into a fixed-width
        matrix keeps the columns aligned across atoms even when an atom's
        labels miss some class.
        """
        proba = clf.predict_proba(X)
        full = np.zeros((proba.shape[0], n_classes), dtype=proba.dtype)
        full[:, np.asarray(clf.classes_, dtype=int)] = proba
        return full

    def transport_labels(X_src, Y_src):
        """Push soft labels from ``X_src`` onto ``Xt`` through an exact OT plan.

        Uses uniform sample weights and the squared Euclidean ground cost, and
        returns the argmax of the transported labels for every row of ``Xt``.
        """
        weights_src = torch.ones(X_src.shape[0]) / X_src.shape[0]
        weights_t = torch.ones(Xt.shape[0]) / Xt.shape[0]
        C = torch.cdist(X_src, Xt, p=2) ** 2
        ot_plan = ot.emd(weights_src, weights_t, C, numItermax=1000000)
        return (ot_plan.T @ Y_src).argmax(dim=1)

    # One labeled measure per source domain, followed by the unlabeled target.
    Q = [
        SupervisedDatasetMeasure(
            features=Xs_k.numpy(),
            labels=ys_k.numpy(),
            stratify=True,
            batch_size=batch_size,
            device='cpu'
        )
        for Xs_k, ys_k in zip(Xs, ys)
    ]
    Q.append(
        UnsupervisedDatasetMeasure(
            features=Xt.numpy(),
            batch_size=batch_size,
            device='cpu'
        )
    )

    criterion = JointWassersteinDistance()
    dictionary = LabeledDictionaryGMM(XP=None,
                                      YP=None,
                                      A=None,
                                      n_samples=n_samples,
                                      n_dim=n_features,
                                      n_classes=n_classes,
                                      n_components=n_atoms,
                                      weight_initialization='uniform',
                                      n_distributions=len(Q),
                                      loss_fn=criterion,
                                      learning_rate_features=1e-1,
                                      learning_rate_labels=1e-1,
                                      learning_rate_weights=1e-1,
                                      reg_e=0.0,
                                      n_iter_barycenter=10,
                                      n_iter_sinkhorn=20,
                                      n_iter_emd=1000000,
                                      domain_names=None,
                                      grad_labels=True,
                                      optimizer_name='Adam',
                                      balanced_sampling=True,
                                      sampling_with_replacement=True,
                                      barycenter_tol=1e-9,
                                      barycenter_beta=None,
                                      tensor_dtype=torch.float32,
                                      track_atoms=False,
                                      schedule_lr=False)
    dictionary.fit(Q,
                   n_iter_max=n_iter,
                   batches_per_it=n_samples // batch_size,
                   verbose=True)

    # Last row of A: weights of the target, which was appended last to Q.
    weights = dictionary.A[-1, :].detach()
    mixing = weights.cpu().numpy()
    atoms = [
        (XPk.detach().clone().cpu(), YPk.detach().clone().softmax(dim=-1).cpu())
        for XPk, YPk in zip(dictionary.XP, dictionary.YP)
    ]
    Xr, Yr = dictionary.reconstruct(weights=weights)

    # The exact OT plans depend neither on the classifier nor on the repetition,
    # so the pseudo-labels are computed once instead of inside every loop.
    atom_pseudo_labels = [transport_labels(XP_k, YP_k) for XP_k, YP_k in atoms]
    reconstruction_pseudo_labels = transport_labels(Xr, Yr)

    classifiers_e = {'lin': SVC(kernel='linear', probability=True), 'rbf': SVC(kernel='rbf', probability=True), 'RF': RandomForestClassifier()}
    classifiers_r = {'lin': SVC(kernel='linear'), 'rbf': SVC(kernel='rbf',), 'RF': RandomForestClassifier()}

    for key in classifiers_e:
        clf_e = classifiers_e[key]
        # The same estimator object is refitted for 'wda', 'r' and 'r_ot'.
        clf_r = classifiers_r[key]

        # Without DA
        clf_r.fit(torch.cat(Xs, dim=0), torch.cat(ys, dim=0))
        results[key]['wda'] += accuracy_score(yt, clf_r.predict(Xt))

        # DaDiL-E
        predictions = []
        for XP_k, YP_k in atoms:
            clf_e.fit(XP_k, YP_k.argmax(dim=1))
            predictions.append(proba_over_all_classes(clf_e, Xt))
        # Weights atomic model predictions
        yp = np.einsum('i,inj->nj', mixing, np.stack(predictions)).argmax(axis=1)
        results[key]['e'] += accuracy_score(yt, yp)

        # DaDiL-E with last optimal transport
        s = 0
        for _ in range(n_repeats):
            predictions = []
            for yt_k in atom_pseudo_labels:
                clf_e.fit(Xt, yt_k)
                predictions.append(proba_over_all_classes(clf_e, Xt))
            yp = np.einsum('i,inj->nj', mixing, np.stack(predictions)).argmax(axis=1)
            s += accuracy_score(yt, yp)
        results[key]['e_ot'] += s / n_repeats

        # DaDiL-R
        clf_r.fit(Xr, Yr.argmax(dim=1))
        results[key]['r'] += accuracy_score(yt, clf_r.predict(Xt))

        # DaDiL-R with last optimal transport
        s = 0
        for _ in range(n_repeats):
            clf_r.fit(Xt, reconstruction_pseudo_labels)
            s += accuracy_score(yt, clf_r.predict(Xt))
        results[key]['r_ot'] += s / n_repeats

    return results

In [10]:
# Run the DaDiL benchmark with the hyper-parameters defined above; the dictionary
# is fitted with verbose=True, so the per-iteration loss is logged below.
results = test_dadil(Xs, ys, Xt, yt, n_features, n_samples, n_classes, n_atoms, batch_size, n_iter)

It 1/100, Loss: 7264.110546874999
It 2/100, Loss: 5393.187207031249
It 3/100, Loss: 3963.7130859374997
It 4/100, Loss: 3006.333056640625
It 5/100, Loss: 2353.006787109375
It 6/100, Loss: 1858.0480224609375
It 7/100, Loss: 1474.4659423828125
It 8/100, Loss: 1184.0206298828125
It 9/100, Loss: 959.5632934570312
It 10/100, Loss: 786.7889770507812
It 11/100, Loss: 645.5423217773439
It 12/100, Loss: 542.1340576171875
It 13/100, Loss: 462.6333801269531
It 14/100, Loss: 399.08770141601565
It 15/100, Loss: 360.6198303222656
It 16/100, Loss: 332.4620666503906
It 17/100, Loss: 303.83948364257816
It 18/100, Loss: 277.38637084960936
It 19/100, Loss: 262.4359436035156
It 20/100, Loss: 246.7506072998047
It 21/100, Loss: 230.53716430664062
It 22/100, Loss: 223.20847778320314
It 23/100, Loss: 238.5696044921875
It 24/100, Loss: 211.16941528320314
It 25/100, Loss: 204.50079040527345
It 26/100, Loss: 196.11647949218752
It 27/100, Loss: 198.8087341308594
It 28/100, Loss: 185.7472595214844
It 29/100, Loss: 

In [12]:
# Display the accuracies returned by test_dadil (classifier -> strategy -> accuracy).
results

{'lin': {'wda': 0.7307142857142858,
  'e': 0.4180357142857143,
  'e_ot': 0.8917142857142858,
  'r': 0.9048214285714286,
  'r_ot': 0.9305357142857142},
 'rbf': {'wda': 0.7217857142857143,
  'e': 0.7391071428571429,
  'e_ot': 1.0,
  'r': 1.0,
  'r_ot': 1.0},
 'RF': {'wda': 0.6928571428571428,
  'e': 0.6258928571428571,
  'e_ot': 0.7177499999999999,
  'r': 0.9541071428571428,
  'r_ot': 0.7744642857142858}}

In [13]:
def test_odadil(Xs, ys, Xt, yt, n_features, n_samples, n_classes, n_atoms, batch_size, n_iter):
    """Fit a dictionary on the sources, stream the target in, and score five strategies.

    Stage 1 fits a ``LabeledDictionaryGMM`` on the labeled source measures
    only. Stage 2 builds a second dictionary from the stage-1 atoms, with the
    feature and label learning rates set to 0 and a single distribution, and
    feeds it ``Xt`` in consecutive mini-batches of 20 rows through
    ``fit_target_sample``. A trailing partial mini-batch is not fed.

    Three classifiers (linear SVM, RBF SVM, random forest) are then scored on
    ``(Xt, yt)``:

    * ``'wda'``  -- trained on the concatenated sources, no adaptation.
    * ``'e'``    -- one fit per atom; the per-atom class probabilities on ``Xt``
      are mixed with the last row of ``dictionary_target.A``.
    * ``'e_ot'`` -- as ``'e'``, but each atom's labels are first transported
      onto ``Xt`` with an exact OT plan and the classifier is trained on ``Xt``
      with those pseudo-labels (mean accuracy over 10 repetitions).
    * ``'r'``    -- trained on the samples returned by ``reconstruct``.
    * ``'r_ot'`` -- as ``'r'``, with the reconstruction's labels transported
      onto ``Xt`` first (mean accuracy over 10 repetitions).

    NOTE: this definition shadows the ``test_odadil`` imported from
    ``Tests_ODaDiL`` in the first cell of the notebook.

    Parameters
    ----------
    Xs, ys : lists of torch tensors with source features / labels, one per domain.
    Xt, yt : torch tensors with target features / labels (``yt`` is only used
        for scoring).
    n_features, n_samples, n_classes, n_atoms : forwarded to both dictionaries
        as ``n_dim``, ``n_samples``, ``n_classes`` and ``n_components``.
    batch_size : mini-batch size of the source measures and of
        ``fit_target_sample``.
    n_iter : forwarded to the stage-1 ``fit`` as ``n_iter_max``.

    Returns
    -------
    dict
        ``{'lin' | 'rbf' | 'RF': {'wda', 'e', 'e_ot', 'r', 'r_ot'}}`` accuracies.
    """
    n_repeats = 10  # repetitions averaged for the OT-based variants
    metric_names = ('wda', 'e', 'e_ot', 'r', 'r_ot')
    results = {name: dict.fromkeys(metric_names, 0) for name in ('lin', 'rbf', 'RF')}

    def proba_over_all_classes(clf, X):
        """Return predict_proba with one column per class index in range(n_classes).

        scikit-learn orders the probability columns by ``clf.classes_``, i.e.
        only the classes seen during ``fit``. Scattering them into a fixed-width
        matrix keeps the columns aligned across atoms even when an atom's
        labels miss some class.
        """
        proba = clf.predict_proba(X)
        full = np.zeros((proba.shape[0], n_classes), dtype=proba.dtype)
        full[:, np.asarray(clf.classes_, dtype=int)] = proba
        return full

    def transport_labels(X_src, Y_src):
        """Push soft labels from ``X_src`` onto ``Xt`` through an exact OT plan.

        Uses uniform sample weights and the squared Euclidean ground cost, and
        returns the argmax of the transported labels for every row of ``Xt``.
        """
        weights_src = torch.ones(X_src.shape[0]) / X_src.shape[0]
        weights_t = torch.ones(Xt.shape[0]) / Xt.shape[0]
        C = torch.cdist(X_src, Xt, p=2) ** 2
        ot_plan = ot.emd(weights_src, weights_t, C, numItermax=1000000)
        return (ot_plan.T @ Y_src).argmax(dim=1)

    Q_sources = [
        SupervisedDatasetMeasure(
            features=Xs_k.numpy(),
            labels=ys_k.numpy(),
            stratify=True,
            batch_size=batch_size,
            device='cpu'
        )
        for Xs_k, ys_k in zip(Xs, ys)
    ]

    criterion = JointWassersteinDistance()

    # Keyword arguments shared by the source and the target dictionaries.
    shared_kwargs = dict(A=None,
                         n_samples=n_samples,
                         n_dim=n_features,
                         n_classes=n_classes,
                         n_components=n_atoms,
                         weight_initialization='uniform',
                         loss_fn=criterion,
                         learning_rate_weights=1e-1,
                         reg_e=0.0,
                         n_iter_barycenter=10,
                         n_iter_sinkhorn=20,
                         n_iter_emd=1000000,
                         domain_names=None,
                         grad_labels=True,
                         optimizer_name='Adam',
                         balanced_sampling=True,
                         sampling_with_replacement=True,
                         barycenter_tol=1e-9,
                         barycenter_beta=None,
                         tensor_dtype=torch.float32,
                         track_atoms=False,
                         schedule_lr=False)

    # Stage 1: learn the atoms from the labeled sources only.
    dictionary_sources = LabeledDictionaryGMM(XP=None,
                                              YP=None,
                                              n_distributions=len(Q_sources),
                                              learning_rate_features=1e-1,
                                              learning_rate_labels=1e-1,
                                              **shared_kwargs)
    dictionary_sources.fit(Q_sources,
                           n_iter_max=n_iter,
                           batches_per_it=n_samples // batch_size,
                           verbose=True)

    XP_sources = dictionary_sources.XP
    YP_sources = dictionary_sources.YP

    # Stage 2: reuse the source atoms with zero learning rate on features and
    # labels, so only the weights keep a non-zero learning rate.
    # NOTE(review): min/max/step components are presumably bounds for the online
    # GMM -- confirm against LabeledDictionaryGMM.
    dictionary_target = LabeledDictionaryGMM(XP=XP_sources,
                                             YP=YP_sources,
                                             n_distributions=1,
                                             learning_rate_features=0,
                                             learning_rate_labels=0,
                                             min_components=10,
                                             max_step_components=10,
                                             max_components=20,
                                             **shared_kwargs)

    # Stream the target in consecutive full mini-batches. The upper bound is
    # inclusive of the last full batch, which the former `i < n - n_batch` test
    # skipped whenever the number of samples was a multiple of n_batch.
    n_batch = 20
    for i in range(0, Xt.shape[0] - n_batch + 1, n_batch):
        dictionary_target.fit_target_sample(Xt[i:i+n_batch, :],
                                            batches_per_it=n_samples // batch_size,
                                            batch_size=batch_size,
                                            verbose=True,
                                            regularization=False,)

    weights = dictionary_target.A[-1, :].detach()
    mixing = weights.cpu().numpy()
    atoms = [
        (XPk.detach().clone().cpu(), YPk.detach().clone().softmax(dim=-1).cpu())
        for XPk, YPk in zip(dictionary_target.XP, dictionary_target.YP)
    ]

    Xr, Yr = dictionary_target.reconstruct(weights=weights)

    # The exact OT plans depend neither on the classifier nor on the repetition,
    # so the pseudo-labels are computed once instead of inside every loop.
    atom_pseudo_labels = [transport_labels(XP_k, YP_k) for XP_k, YP_k in atoms]
    reconstruction_pseudo_labels = transport_labels(Xr, Yr)

    classifiers_e = {'lin': SVC(kernel='linear', probability=True), 'rbf': SVC(kernel='rbf', probability=True), 'RF': RandomForestClassifier()}
    classifiers_r = {'lin': SVC(kernel='linear'), 'rbf': SVC(kernel='rbf',), 'RF': RandomForestClassifier()}

    for key in classifiers_e:
        clf_e = classifiers_e[key]
        # The same estimator object is refitted for 'wda', 'r' and 'r_ot'.
        clf_r = classifiers_r[key]

        # Without DA
        clf_r.fit(torch.cat(Xs, dim=0), torch.cat(ys, dim=0))
        results[key]['wda'] += accuracy_score(yt, clf_r.predict(Xt))

        # DaDiL-E
        predictions = []
        for XP_k, YP_k in atoms:
            clf_e.fit(XP_k, YP_k.argmax(dim=1))
            predictions.append(proba_over_all_classes(clf_e, Xt))
        # Weights atomic model predictions
        yp = np.einsum('i,inj->nj', mixing, np.stack(predictions)).argmax(axis=1)
        results[key]['e'] += accuracy_score(yt, yp)

        # DaDiL-E with last optimal transport
        s = 0
        for _ in range(n_repeats):
            predictions = []
            for yt_k in atom_pseudo_labels:
                clf_e.fit(Xt, yt_k)
                predictions.append(proba_over_all_classes(clf_e, Xt))
            yp = np.einsum('i,inj->nj', mixing, np.stack(predictions)).argmax(axis=1)
            s += accuracy_score(yt, yp)
        results[key]['e_ot'] += s / n_repeats

        # DaDiL-R
        clf_r.fit(Xr, Yr.argmax(dim=1))
        results[key]['r'] += accuracy_score(yt, clf_r.predict(Xt))

        # DaDiL-R with last optimal transport
        s = 0
        for _ in range(n_repeats):
            clf_r.fit(Xt, reconstruction_pseudo_labels)
            s += accuracy_score(yt, clf_r.predict(Xt))
        results[key]['r_ot'] += s / n_repeats

    return results

In [14]:
# Run the online variant with the same data and hyper-parameters as the DaDiL run;
# the source dictionary is fitted with verbose=True, so its loss is logged below.
results_odadil = test_odadil(Xs, ys, Xt, yt, n_features, n_samples, n_classes, n_atoms, batch_size, n_iter)

It 1/100, Loss: 6234.732324218749
It 2/100, Loss: 4638.29609375
It 3/100, Loss: 3427.78125
It 4/100, Loss: 2587.117431640625
It 5/100, Loss: 2061.343920898438
It 6/100, Loss: 1649.3478027343751
It 7/100, Loss: 1352.8155761718751
It 8/100, Loss: 1094.2416259765625
It 9/100, Loss: 894.1794189453126
It 10/100, Loss: 710.6870239257813
It 11/100, Loss: 590.6761108398437
It 12/100, Loss: 484.77144165039067
It 13/100, Loss: 404.81666870117186
It 14/100, Loss: 338.74931640625005
It 15/100, Loss: 278.3031311035156
It 16/100, Loss: 237.87523498535157
It 17/100, Loss: 197.33466796874998
It 18/100, Loss: 168.2675109863281
It 19/100, Loss: 142.91923217773436
It 20/100, Loss: 124.55365753173828
It 21/100, Loss: 107.83352355957032
It 22/100, Loss: 93.39157104492188
It 23/100, Loss: 82.75978393554688
It 24/100, Loss: 73.69974060058594
It 25/100, Loss: 65.78096313476563
It 26/100, Loss: 59.21064147949219
It 27/100, Loss: 53.48043365478515
It 28/100, Loss: 48.749976348876956
It 29/100, Loss: 44.82263488

In [17]:
# Display the accuracies returned by test_odadil (classifier -> strategy -> accuracy).
results_odadil

{'lin': {'wda': 0.7307142857142858,
  'e': 0.7516071428571428,
  'e_ot': 0.9882142857142855,
  'r': 0.8239285714285715,
  'r_ot': 0.9876785714285712},
 'rbf': {'wda': 0.7217857142857143,
  'e': 0.7598214285714285,
  'e_ot': 1.0,
  'r': 0.9175,
  'r_ot': 1.0},
 'RF': {'wda': 0.6610714285714285,
  'e': 0.7153571428571428,
  'e_ot': 0.9875714285714287,
  'r': 0.7158928571428571,
  'r_ot': 0.9846428571428572}}