In [1]:
from domains import load_domains

domains, eval_domains = load_domains()

In [2]:
from lwp import LWP

In [None]:
import numpy as np

def sample_from_gmms(gmms, n_samples, class_counts, num_classes = 10):
    pseudo_features = []
    pseudo_labels = []
    sampling_probabilities = class_counts / np.sum(class_counts)
    print('class counts are' , class_counts)
    
    for i in range(num_classes):
        # Determine the number of samples for this class based on its probability
        num_class_samples = int(n_samples * sampling_probabilities[i])
        
        # Sample from the ith GMM
        class_samples, _ = gmms[i].sample(num_class_samples)
        
        # Append the samples and corresponding class labels
        pseudo_features.append(class_samples)
        pseudo_labels.extend([i] * num_class_samples)
    
    # Concatenate the features and labels
    pseudo_features = np.concatenate(pseudo_features, axis=0)
    pseudo_labels = np.array(pseudo_labels)
    
    return pseudo_features, pseudo_labels

In [None]:
from sklearn.mixture import GaussianMixture

num_classes = 10
models = []

source_dataset = domains[0]
gmms = [None] * num_classes

model = LWP(distance_metric='euclidean')
model.fit(source_dataset['features'], source_dataset['labels'])
models.append(model)

class_frequencies = [np.sum(source_dataset['labels'] == i) for i in range(num_classes)]
total_samples = np.sum(class_frequencies)
sampling_probabilities = np.array(class_frequencies) / total_samples

# Update GMM Models
for i in range(num_classes):
    gmms[i] = GaussianMixture(n_components=2, covariance_type='full', random_state=42)
    gmms[i].fit(source_dataset['features'][source_dataset['labels'] == i])

In [None]:
pseudo_size = 2500
num_iters = 10

for i in range(1, 20) :
    curr_dataset = domains[i]['features']
    curr_dataset_labels = model.predict(curr_dataset)
    
    pseudo_dataset = {}
    pseudo_dataset['features'], pseudo_dataset['labels'] = sample_from_gmms(gmms, pseudo_size, list(model.class_counts.values()), num_classes = 10)
    
    new_dataset = {}
    new_dataset['features'] = np.concatenate([pseudo_dataset['features'], curr_dataset])
    new_dataset['labels'] = np.concatenate([pseudo_dataset['labels'], curr_dataset_labels])
    model.fit(new_dataset['features'], new_dataset['labels'])
    
    models.append(model)
    
    # Update GMM Models
    for i in range(num_classes):
        gmms[i] = GaussianMixture(n_components=2, covariance_type='full', random_state=42)
        gmms[i].fit(new_dataset['features'][new_dataset['labels'] == i])

class counts are [253, 243, 255, 244, 262, 236, 250, 253, 254, 250]
class counts are [748, 742, 737, 755, 776, 733, 748, 760, 750, 750]
class counts are [1264, 1231, 1197, 1260, 1303, 1270, 1248, 1230, 1255, 1237]
class counts are [1765, 1737, 1644, 1783, 1831, 1789, 1736, 1700, 1750, 1756]
class counts are [2268, 2253, 2094, 2288, 2360, 2327, 2215, 2168, 2238, 2274]
class counts are [2765, 2724, 2560, 2797, 2909, 2889, 2695, 2637, 2718, 2786]
class counts are [3264, 3207, 3031, 3296, 3455, 3430, 3167, 3093, 3239, 3292]
class counts are [3767, 3700, 3516, 3807, 3987, 3988, 3632, 3541, 3726, 3805]
class counts are [4260, 4206, 3961, 4322, 4549, 4510, 4096, 4002, 4231, 4328]
class counts are [4755, 4688, 4406, 4841, 5093, 5096, 4562, 4453, 4727, 4839]
class counts are [5265, 5175, 4870, 5454, 5632, 5668, 5026, 4843, 5209, 5313]
class counts are [5993, 5499, 5216, 5988, 6151, 6292, 5575, 5202, 5775, 5759]
class counts are [6508, 5968, 5623, 6538, 6701, 6834, 6089, 5619, 6297, 6267]
class 

In [None]:
from sklearn.metrics import accuracy_score
import pandas as pd

df = pd.DataFrame()

for idx,model in enumerate(models) :
    
    scores = []
    for eval_domain in eval_domains[:idx+1]:
        
        features = eval_domain['features']
        labels = eval_domain['labels']
        
        preds = model.predict(features)
        acc = accuracy_score(labels, preds)
        
        scores.append(acc)
    
    df[f'Domain {idx+1}'] = scores + [np.nan] * (len(eval_domains) - len(scores))

In [None]:
df

Unnamed: 0,Domain 1,Domain 2,Domain 3,Domain 4,Domain 5,Domain 6,Domain 7,Domain 8,Domain 9,Domain 10,Domain 11,Domain 12,Domain 13,Domain 14,Domain 15,Domain 16,Domain 17,Domain 18,Domain 19,Domain 20
0,0.8816,0.8816,0.8816,0.8816,0.8816,0.8816,0.8816,0.8816,0.8816,0.8816,0.8816,0.8816,0.8816,0.8816,0.8816,0.8816,0.8816,0.8816,0.8816,0.8816
1,,0.8876,0.8876,0.8876,0.8876,0.8876,0.8876,0.8876,0.8876,0.8876,0.8876,0.8876,0.8876,0.8876,0.8876,0.8876,0.8876,0.8876,0.8876,0.8876
2,,,0.8944,0.8944,0.8944,0.8944,0.8944,0.8944,0.8944,0.8944,0.8944,0.8944,0.8944,0.8944,0.8944,0.8944,0.8944,0.8944,0.8944,0.8944
3,,,,0.9092,0.9092,0.9092,0.9092,0.9092,0.9092,0.9092,0.9092,0.9092,0.9092,0.9092,0.9092,0.9092,0.9092,0.9092,0.9092,0.9092
4,,,,,0.8984,0.8984,0.8984,0.8984,0.8984,0.8984,0.8984,0.8984,0.8984,0.8984,0.8984,0.8984,0.8984,0.8984,0.8984,0.8984
5,,,,,,0.9056,0.9056,0.9056,0.9056,0.9056,0.9056,0.9056,0.9056,0.9056,0.9056,0.9056,0.9056,0.9056,0.9056,0.9056
6,,,,,,,0.8972,0.8972,0.8972,0.8972,0.8972,0.8972,0.8972,0.8972,0.8972,0.8972,0.8972,0.8972,0.8972,0.8972
7,,,,,,,,0.8932,0.8932,0.8932,0.8932,0.8932,0.8932,0.8932,0.8932,0.8932,0.8932,0.8932,0.8932,0.8932
8,,,,,,,,,0.8964,0.8964,0.8964,0.8964,0.8964,0.8964,0.8964,0.8964,0.8964,0.8964,0.8964,0.8964
9,,,,,,,,,,0.902,0.902,0.902,0.902,0.902,0.902,0.902,0.902,0.902,0.902,0.902
