In [None]:
# === All Imports ===
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import RepeatedKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score
from google.colab import files
import io

# === Fusion Network ===
class FusionNetwork(nn.Module):
    """Blend several generator outputs with learned per-feature weights.

    One logit is stored per (generator, feature) pair. A softmax over the
    generator axis turns the logits into mixing coefficients, so for every
    feature the coefficients across generators sum to one.
    """

    def __init__(self, num_generators, feature_dim):
        super().__init__()
        # All logits start equal, which makes the initial softmax uniform:
        # every generator contributes the same share to every feature.
        initial_logits = torch.ones(num_generators, feature_dim) / num_generators
        self.weights = nn.Parameter(initial_logits)

    def forward(self, outputs):
        """Return ``(fused, gamma)`` for a sequence of generator outputs.

        ``outputs`` holds one 2-D tensor (batch, feature) per generator;
        stacking them yields the (generator, batch, feature) layout the
        einsum below expects. ``gamma`` is the (generator, feature) matrix
        of softmax mixing coefficients and ``fused`` is the weighted sum
        over generators, shaped (batch, feature).
        """
        per_generator = torch.stack(outputs, dim=0)
        mixing = torch.softmax(self.weights, dim=0)
        blended = torch.einsum('gd,gbd->bd', mixing, per_generator)
        return blended, mixing

# === Masked Generator ===
class MaskedGenerator(nn.Module):
    """Small MLP that maps a masked feature vector back to full width.

    The network is Linear(input_dim, 128) -> ReLU -> Linear(128, input_dim),
    so the output has the same number of features as the input.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, input_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x, mask):
        """Run the network on ``x`` multiplied element-wise by ``mask``."""
        # Entries where the mask is zero are zeroed out before the MLP sees them.
        visible = x * mask
        return self.net(visible)

# === Discriminator ===
class Discriminator(nn.Module):
    """Binary critic producing one probability-like score per input row.

    The network is Linear(feature_dim, 128) -> ReLU -> Linear(128, 1) ->
    Sigmoid, so every output lies in the open interval (0, 1).
    """

    def __init__(self, feature_dim):
        super().__init__()
        hidden_units = 128
        layers = (
            nn.Linear(feature_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),
        )
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the sigmoid score for each row of ``x``."""
        score = self.net(x)
        return score

# === MEG Adapter (Exact Paper Implementation) ===
class MEG(nn.Module):
    """Ensemble of masked generators, a fusion layer and a discriminator.

    Each generator sees the input through its own random mask and proposes a
    full-width reconstruction. The fusion network blends the proposals into
    one output, and a discriminator is trained adversarially against it.

    Args:
        input_dim: Number of features in every input row.
        num_generators: How many generators to build; defaults to
            ``input_dim`` when ``None``.
        alpha: Weight of the per-generator (group) reconstruction loss in the
            generator objective.
        beta: Weight of the adversarial loss in the generator objective.
    """

    def __init__(self, input_dim, num_generators=None, alpha=0.1, beta=1.0):
        super().__init__()
        self.input_dim = input_dim
        # default one generator per original feature
        self.num_generators = input_dim if num_generators is None else num_generators
        self.alpha = alpha
        self.beta = beta
        # instantiate generators, fusion, discriminator
        self.generators = nn.ModuleList([
            MaskedGenerator(input_dim) for _ in range(self.num_generators)
        ])
        self.fusion = FusionNetwork(self.num_generators, input_dim)
        self.discriminator = Discriminator(input_dim)
        # Generators and fusion weights share one optimizer; the discriminator
        # has its own so the two sides can be stepped alternately.
        self.opt_gen = optim.Adam(
            list(self.generators.parameters()) + list(self.fusion.parameters()), lr=0.001
        )
        self.opt_disc = optim.Adam(self.discriminator.parameters(), lr=0.001)
        # losses
        self.bce = nn.BCELoss()
        self.mse = nn.MSELoss()

    def forward(self, x):
        """Return ``(outputs, fused, gamma)`` for a batch ``x``.

        ``outputs`` is the list of per-generator reconstructions, ``fused`` is
        their blend and ``gamma`` the fusion weights. Masks are re-sampled on
        every call, so two calls on the same input generally differ.
        """
        # Each entry is kept (mask = 1) when a uniform draw falls below 0.8.
        masks = [(torch.rand_like(x) < 0.8).float() for _ in range(self.num_generators)]
        outputs = [g(x, m) for g, m in zip(self.generators, masks)]
        fused, gamma = self.fusion(outputs)
        return outputs, fused, gamma

    def train_meg(self, data, epochs=50, batch_size=64):
        """Alternate discriminator and generator updates over ``data``.

        Args:
            data: 2-D tensor of training rows; mini-batches are drawn from a
                fresh random permutation every epoch.
            epochs: Number of passes over ``data``.
            batch_size: Rows per mini-batch (the last batch may be smaller).

        Raises:
            ValueError: If ``data`` has no rows. Without this check the
                progress log would reference losses that were never computed.
        """
        num_rows = data.size(0)
        if num_rows == 0:
            raise ValueError("train_meg requires at least one training sample")

        for epoch in range(epochs):
            perm = torch.randperm(num_rows)
            for i in range(0, num_rows, batch_size):
                idx = perm[i:i + batch_size]
                real_batch = data[idx]
                bs = real_batch.size(0)
                # Allocate the targets directly on the batch's device.
                real_labels = torch.ones(bs, 1, device=real_batch.device)
                fake_labels = torch.zeros(bs, 1, device=real_batch.device)

                # Discriminator update
                self.opt_disc.zero_grad()
                loss_real = self.bce(self.discriminator(real_batch), real_labels)
                # The generators are not updated in this step, so skip building
                # their autograd graph instead of building and detaching it.
                with torch.no_grad():
                    _, fused, _ = self(real_batch)
                loss_fake = self.bce(self.discriminator(fused), fake_labels)
                loss_D = loss_real + loss_fake
                loss_D.backward()
                self.opt_disc.step()

                # Generator + Fusion update
                self.opt_gen.zero_grad()
                outputs, fused, gamma = self(real_batch)
                # Reconstruction error of the fused output.
                loss_proxy = self.mse(fused, real_batch)
                # Reconstruction error of each generator, weighted per feature
                # by that generator's fusion coefficients.
                loss_group = sum(
                    torch.mean(w * (out - real_batch).pow(2))
                    for out, w in zip(outputs, gamma)
                )
                # Generator wants the discriminator to label the fused output
                # as real. Gradients also land on the discriminator here, but
                # they are cleared by opt_disc.zero_grad() on the next batch.
                loss_adv = self.bce(self.discriminator(fused), real_labels)
                loss_G = loss_proxy + self.alpha * loss_group + self.beta * loss_adv
                loss_G.backward()
                self.opt_gen.step()

            # Progress report uses the losses of the epoch's final mini-batch.
            if epoch % 10 == 0:
                print(
                    f"Epoch {epoch}: "
                    f"Loss_D={loss_D.item():.4f}, "
                    f"Loss_proxy={loss_proxy.item():.4f}, "
                    f"Loss_group={loss_group.item():.4f}, "
                    f"Loss_adv={loss_adv.item():.4f}"
                )

    @torch.no_grad()
    def generate(self, x):
        """Return the fused output for ``x`` without tracking gradients.

        Masks are still sampled randomly, exactly as in ``forward``.
        """
        _, fused, _ = self(x)
        return fused

# === Load Adult Train/Test Files Directly ===
# Column names for UCI Adult dataset
columns = [
    'age','workclass','fnlwgt','education','education-num',
    'marital-status','occupation','relationship','race','sex',
    'capital-gain','capital-loss','hours-per-week','native-country','income'
]
# URLs for train and test
t_train = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data'
t_test  = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test'
# Load data ('?' marks a missing value in both files)
df_train = pd.read_csv(
    t_train, header=None, names=columns, sep=', ', engine='python', na_values='?'
)
# header=0 together with names= discards the first line of the test file
# and substitutes our column names.
df_test = pd.read_csv(
    t_test, header=0, names=columns, sep=', ', engine='python', na_values='?'
)
# Remove trailing dot in test labels so they match the training labels.
# Done unconditionally: a dtype == object guard can be skipped when pandas
# uses a dedicated string dtype, leaving the dots in place.
df_test['income'] = df_test['income'].str.rstrip('.')
# Drop missing values
df_train.dropna(inplace=True)
df_test.dropna(inplace=True)

# Original feature count (excluding target)
orig_feature_count = df_train.shape[1] - 1
# One-hot encode features only (exclude target) per paper
feature_cols = df_train.columns.drop('income')
# Dummies on training features; dtype=float keeps the frame purely numeric
# instead of mixing integer and boolean columns.
df_train_enc = pd.get_dummies(df_train[feature_cols], drop_first=False, dtype=float)
# Dummies on test features and align columns to the training layout
df_test_enc = pd.get_dummies(df_test[feature_cols], drop_first=False, dtype=float)
df_test_enc = df_test_enc.reindex(columns=df_train_enc.columns, fill_value=0)

# Separate features and label
X_train = df_train_enc.to_numpy(dtype=np.float64)
X_test  = df_test_enc.to_numpy(dtype=np.float64)
# Fit the label mapping once on the training labels and reuse it for the
# test labels so both splits are guaranteed to share the same encoding.
label_encoder = LabelEncoder().fit(df_train['income'].values)
y_train = label_encoder.transform(df_train['income'].values)
y_test  = label_encoder.transform(df_test['income'].values)

# Standardize all columns with statistics from the training set only
scaler = StandardScaler().fit(X_train)
X_train_t = torch.tensor(scaler.transform(X_train), dtype=torch.float32)
X_test_t  = torch.tensor(scaler.transform(X_test), dtype=torch.float32)

# Train MEG on full training set
meg = MEG(
    input_dim=X_train_t.shape[1],
    num_generators=orig_feature_count,
    alpha=0.1,
    beta=1.0
)
meg.train_meg(X_train_t, epochs=50, batch_size=64)

# Generate synthetic training data (row i of X_syn is derived from row i of
# X_train_t, so the original labels are reused)
X_syn = meg.generate(X_train_t).detach().cpu().numpy()
y_syn = y_train

# The synthetic data lives in standardized feature space, so the real test
# set must be evaluated in that same space, not as raw unscaled values.
X_test_scaled = X_test_t.numpy()

# Evaluate classifiers on real test set
clfs = {
    'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42),
    'MLP': MLPClassifier(hidden_layer_sizes=(100,50), max_iter=500, random_state=42)
}
results = {}
for name, clf in clfs.items():
    clf.fit(X_syn, y_syn)
    preds = clf.predict(X_test_scaled)
    results[name] = accuracy_score(y_test, preds)

print(" TSTR Results on Adult (train/test files):")
for name, acc in results.items():
    print(f"{name}: {acc:.4f}")


Epoch 0: Loss_D=1.3917, Loss_proxy=1.8908, Loss_group=2.2425, Loss_adv=1.7723
Epoch 10: Loss_D=1.3975, Loss_proxy=0.0743, Loss_group=0.5728, Loss_adv=0.6381
Epoch 20: Loss_D=1.3751, Loss_proxy=0.0309, Loss_group=0.2257, Loss_adv=0.7101
Epoch 30: Loss_D=1.3579, Loss_proxy=0.0919, Loss_group=0.4769, Loss_adv=0.7514
Epoch 40: Loss_D=1.3293, Loss_proxy=0.0639, Loss_group=0.3966, Loss_adv=0.8977
 TSTR Results on Adult (train/test files):
Random Forest: 0.7780
MLP: 0.7834
