In [1]:
import torch
import os
import numpy as np

In [2]:
# Train/eval directories for the two dataset parts.
one_train_dir, one_eval_dir = (
    os.path.join('dataset', 'part_one_dataset', split) for split in ('train_data', 'eval_data')
)
two_train_dir, two_eval_dir = (
    os.path.join('dataset', 'part_two_dataset', split) for split in ('train_data', 'eval_data')
)

In [3]:
# Number of domains in each dataset part (files are numbered 1..N).
N_DOMAINS_PER_PART = 10


def _load_split(split_dir, domain_no, split):
    """Load ``{domain_no}_{split}_data.tar.pth`` from ``split_dir`` and flatten its samples.

    The loaded object must support ``obj['data']`` item access and assignment.
    Its ``'data'`` entry is reshaped to ``(n_samples, -1)``, with the sample
    count taken from the array's own leading dimension rather than a
    hard-coded constant.
    """
    path = os.path.join(split_dir, f'{domain_no}_{split}_data.tar.pth')
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code. Only load dataset files that come from a trusted source.
    payload = torch.load(path, weights_only=False)
    payload['data'] = payload['data'].reshape(len(payload['data']), -1)
    return payload


# Index 0-9: part one, domains 1-10; index 10-19: part two, domains 1-10.
domains = [
    _load_split(train_dir, i + 1, 'train')
    for train_dir in (one_train_dir, two_train_dir)
    for i in range(N_DOMAINS_PER_PART)
]
eval_domains = [
    _load_split(eval_dir, i + 1, 'eval')
    for eval_dir in (one_eval_dir, two_eval_dir)
    for i in range(N_DOMAINS_PER_PART)
]

In [4]:
# train_paths = [os.path.join(one_train_dir, f'{i+1}_train_data.tar.pth') for i in range(10)] + [os.path.join(two_train_dir, f'{i+1}_train_data.tar.pth') for i in range(10)]
# eval_paths = [os.path.join(one_eval_dir, f'{i+1}_eval_data.tar.pth') for i in range(10)] + [os.path.join(two_eval_dir, f'{i+1}_eval_data.tar.pth') for i in range(10)]

# eval_domains = [torch.load(eval_paths[i], weights_only=False) for i in range(20)]

## Model

In [5]:
from lwp import LWP

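## Training and Predictions

The model is trained sequentially on the 20 training domains: the 10 from part one, then the 10 from part two. Domains that provide `targets` are fitted on those labels; the remaining domains are fitted on the model's own predictions (pseudo-labels). After each step, accuracy is measured on the evaluation sets of every domain seen so far.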

In [6]:
import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler

class GMMGenerativeClassifier:
    """Generative classifier with one Gaussian mixture model (GMM) per class.

    ``fit`` standardizes the embeddings with a ``StandardScaler``. Each class
    ``c`` is modelled by its own ``GaussianMixture`` density ``p(x | c)``.
    Predictions combine that density with the empirical class prior ``p(c)``
    through Bayes' rule.
    """

    def __init__(self, n_components=3, n_classes=10, covariance_type='full', random_state=42):
        """
        Initialize GMM-based generative classifier

        Args:
            n_components (int): Number of Gaussian components per class
            n_classes (int): Number of classes; labels are expected in ``range(n_classes)``
            covariance_type (str): Type of covariance parameters ('full', 'tied', 'diag', 'spherical')
            random_state (int): Random seed for reproducibility
        """
        self.n_components = n_components
        self.n_classes = n_classes
        self.covariance_type = covariance_type
        self.random_state = random_state

        # One independent GMM per class, all sharing the same hyper-parameters.
        self.gmms = [
            GaussianMixture(
                n_components=n_components,
                covariance_type=covariance_type,
                random_state=random_state
            ) for _ in range(n_classes)
        ]

        self.scaler = StandardScaler()
        # Set by fit(): array of shape (n_classes,) holding p(c); None until fitted.
        self.class_priors = None

    def fit(self, embeddings, labels):
        """
        Fit the scaler, the class priors and one GMM per class.

        Args:
            embeddings: Array of shape (n_samples, n_features)
            labels: Array of shape (n_samples,) with integer class ids. Labels
                outside ``range(n_classes)`` are not used for the priors or the
                per-class GMMs.
        Returns:
            self
        Raises:
            ValueError: if no label lies in ``range(n_classes)``.

        NOTE: sklearn's ``GaussianMixture.fit`` raises when a class has fewer
        samples than ``n_components``.
        """
        labels = np.asarray(labels)

        # Empirical priors p(c) = n_c / N, indexed by class id so they stay
        # aligned with self.gmms even when some class is absent. The
        # denominator is the sample count, so the priors sum to one.
        counts = np.array(
            [np.count_nonzero(labels == class_idx) for class_idx in range(self.n_classes)],
            dtype=float,
        )
        total = counts.sum()
        if total == 0:
            raise ValueError(
                f"fit() received no labels in range({self.n_classes}); cannot estimate class priors"
            )

        scaled_embeddings = self.scaler.fit_transform(embeddings)
        self.class_priors = counts / total

        for class_idx in range(self.n_classes):
            class_embeddings = scaled_embeddings[labels == class_idx]
            # Classes with no samples keep an unfitted GMM and a prior of 0.
            if len(class_embeddings) > 0:
                self.gmms[class_idx].fit(class_embeddings)

        return self

    def predict_proba(self, embeddings):
        """
        Predict class posterior probabilities p(c | x).

        Args:
            embeddings: Array of shape (n_samples, n_features)
        Returns:
            Array of shape (n_samples, n_classes) whose rows sum to one.
            Classes with prior 0 (unseen during fit) get probability 0.
        """
        scaled_embeddings = self.scaler.transform(embeddings)

        # Unnormalized log-posteriors log p(x | c) + log p(c). Unseen classes
        # stay at -inf because their GMM was never fitted.
        log_probs = np.full((len(scaled_embeddings), self.n_classes), -np.inf)
        for class_idx in range(self.n_classes):
            prior = self.class_priors[class_idx]
            if prior > 0:
                log_probs[:, class_idx] = (
                    self.gmms[class_idx].score_samples(scaled_embeddings) + np.log(prior)
                )

        # Softmax over classes, using the max-shift trick for numerical
        # stability. keepdims gives shape (n_samples, 1), which broadcasts
        # row-wise against (n_samples, n_classes).
        log_probs -= log_probs.max(axis=1, keepdims=True)
        probs = np.exp(log_probs)
        probs /= probs.sum(axis=1, keepdims=True)
        return probs

    def predict(self, embeddings):
        """
        Predict classes for embeddings

        Args:
            embeddings: Array of shape (n_samples, n_features)
        Returns:
            Array of shape (n_samples,) containing the most probable class ids
        """
        probs = self.predict_proba(embeddings)
        return np.argmax(probs, axis=1)

    def generate_samples(self, n_samples_per_class):
        """
        Draw synthetic samples from each fitted class GMM.

        Args:
            n_samples_per_class (int): Number of samples to generate per class
        Returns:
            tuple: (generated_samples, labels). generated_samples is in the
            original (unscaled) feature space, and labels holds the class id of
            each row. Classes with a prior of 0 (no training samples) are skipped.
        """
        generated_samples = []
        labels = []

        for class_idx in range(self.n_classes):
            # An unseen class has no fitted GMM to sample from.
            if self.class_priors is not None and self.class_priors[class_idx] == 0:
                continue
            samples, _ = self.gmms[class_idx].sample(n_samples_per_class)
            # Undo the standardization applied in fit().
            generated_samples.append(self.scaler.inverse_transform(samples))
            labels.extend([class_idx] * n_samples_per_class)

        return np.vstack(generated_samples), np.array(labels)

def logsumexp(x, axis=None):
    """Compute ``log(sum(exp(x)))`` along ``axis`` in a numerically stable way.

    The maximum is subtracted before exponentiating so that large inputs do
    not overflow. Non-finite maxima (e.g. a slice that is entirely ``-inf``)
    are replaced by 0 before the shift. An all ``-inf`` slice therefore yields
    ``-inf`` rather than ``nan``.

    Args:
        x: array-like input.
        axis: axis or axes to reduce over; ``None`` reduces over all elements.
    Returns:
        ndarray with the reduced axes kept as size 1 (``keepdims=True``), so
        the result broadcasts against ``x``.
    """
    x = np.asarray(x)
    x_max = np.max(x, axis=axis, keepdims=True)
    x_max = np.where(np.isfinite(x_max), x_max, 0.0)
    # log(0) for an all -inf slice is the intended -inf; silence the warning.
    with np.errstate(divide='ignore'):
        return x_max + np.log(np.sum(np.exp(x - x_max), axis=axis, keepdims=True))


In [None]:
from sklearn.metrics import accuracy_score
import pandas as pd

# Sequential run: fit LWP on each training domain in turn. After every step,
# score the model on the evaluation sets of all domains seen so far.
# NOTE: GMMGenerativeClassifier (defined above) was previously instantiated
# here, but it was immediately overwritten by LWP, so that line was dead code.
model = LWP(distance_metric='manhattan')

n_domains = len(eval_domains)
accuracy_columns = {}

for idx, domain in enumerate(domains):
    x_train = domain['data']
    # Use ground-truth labels when present; otherwise pseudo-label the domain
    # with the current model's predictions.
    y_train = domain['targets'] if 'targets' in domain else model.predict(x_train)

    model.fit(x_train, y_train)
    print(model.class_counts)

    scores = [
        accuracy_score(eval_domain['targets'], model.predict(eval_domain['data']))
        for eval_domain in eval_domains[:idx + 1]
    ]
    # Pad with NaN so every column has one row per evaluation domain.
    accuracy_columns[f'Domain {idx+1}'] = scores + [np.nan] * (n_domains - len(scores))

# Build the frame once instead of growing it column by column.
df = pd.DataFrame(accuracy_columns)

{0: 253, 1: 243, 2: 255, 3: 244, 4: 262, 5: 236, 6: 250, 7: 253, 8: 254, 9: 250}


{0: 710, 1: 370, 2: 373, 3: 287, 4: 412, 5: 496, 6: 928, 7: 360, 8: 447, 9: 617}
{0: 1155, 1: 516, 2: 543, 3: 343, 4: 622, 5: 711, 6: 1576, 7: 483, 8: 657, 9: 894}


## Evaluation

In [None]:
# Accuracy matrix. Row r is accuracy on the eval set of domain r+1. Column
# "Domain k" is the model after training on domains 1..k. Entries for eval
# domains not yet seen at step k are NaN.
df

Unnamed: 0,Domain 1,Domain 2,Domain 3,Domain 4,Domain 5,Domain 6,Domain 7,Domain 8,Domain 9,Domain 10,Domain 11,Domain 12,Domain 13,Domain 14,Domain 15,Domain 16,Domain 17,Domain 18,Domain 19,Domain 20
0,0.902,0.8936,0.8908,0.8904,0.89,0.89,0.8892,0.8884,0.8872,0.8868,0.8852,0.8844,0.8832,0.8832,0.8844,0.882,0.8824,0.8828,0.8816,0.882
1,,0.904,0.8996,0.8988,0.8976,0.8968,0.8952,0.8956,0.8956,0.8952,0.8948,0.892,0.8924,0.8912,0.8904,0.8896,0.89,0.8884,0.8872,0.8876
2,,,0.9096,0.9076,0.9068,0.9072,0.9064,0.9056,0.9056,0.9048,0.9024,0.9012,0.9004,0.8992,0.8988,0.896,0.8944,0.8936,0.892,0.8916
3,,,,0.9208,0.9204,0.9188,0.918,0.9176,0.9172,0.9168,0.916,0.9152,0.9144,0.9148,0.9152,0.9128,0.9128,0.9124,0.9096,0.9088
4,,,,,0.9064,0.9056,0.9052,0.9044,0.904,0.9036,0.9036,0.9008,0.8992,0.8992,0.8988,0.8976,0.8968,0.896,0.8948,0.8952
5,,,,,,0.9128,0.9132,0.914,0.9148,0.914,0.9116,0.9124,0.9096,0.91,0.9104,0.9104,0.9088,0.9068,0.9068,0.9044
6,,,,,,,0.9064,0.9052,0.9044,0.9048,0.9048,0.9036,0.9024,0.9024,0.9016,0.9012,0.9008,0.9008,0.8992,0.8992
7,,,,,,,,0.8984,0.8988,0.8992,0.8976,0.896,0.8948,0.8948,0.894,0.8944,0.894,0.8932,0.8928,0.892
8,,,,,,,,,0.9072,0.9064,0.904,0.9024,0.9004,0.9,0.8992,0.8972,0.8968,0.8976,0.8956,0.8964
9,,,,,,,,,,0.9088,0.908,0.9064,0.9048,0.9056,0.9052,0.9044,0.9036,0.9036,0.9016,0.9016
