In [None]:
# === All Imports ===
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score
from google.colab import files
import io

# === Fusion Network ===
class FusionNetwork(nn.Module):
    """Blend several generator outputs using learnable per-feature weights.

    One logit is stored for every (generator, feature) pair. On each call the
    logits are softmax-normalised along the generator axis, so for any single
    feature the blending coefficients are positive and sum to one.
    """

    def __init__(self, num_generators, feature_dim):
        super().__init__()
        # Every logit starts at 1 / num_generators; equal logits mean the
        # softmax initially gives each generator the same coefficient.
        initial_logits = torch.ones(num_generators, feature_dim).div(num_generators)
        self.weights = nn.Parameter(initial_logits)

    def forward(self, outputs):
        """Fuse the per-generator outputs into one tensor.

        Args:
            outputs: list of G tensors, each of shape [B, D].

        Returns:
            Tuple ``(fused, gamma)`` where ``fused`` is [B, D] and ``gamma``
            is the [G, D] matrix of softmax-normalised weights that was used.
        """
        per_generator = torch.stack(outputs)  # [G, B, D]
        gamma = self.weights.softmax(dim=0)  # [G, D], normalised over generators
        # fused[b, d] = sum over g of gamma[g, d] * per_generator[g, b, d]
        fused = torch.einsum('gd,gbd->bd', gamma, per_generator)
        return fused, gamma

# === Masked Generator ===
class MaskedGenerator(nn.Module):
    """Two-layer MLP that reconstructs a row from a masked copy of it.

    Args:
        input_dim: number of features per row; also the output width.
        hidden_dim: width of the single hidden layer. Defaults to 128, the
            value that was previously hard-coded.
    """

    def __init__(self, input_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x, mask):
        """Return the network output for ``x`` with masked features zeroed.

        Args:
            x: input rows, shape [B, D].
            mask: multiplicative mask, shape [B, D]; a 0 entry hides the
                corresponding feature from the network.

        Returns:
            Tensor of shape [B, D].
        """
        # The mask is applied by element-wise multiplication, so hidden
        # features reach the network as exact zeros.
        return self.net(x * mask)

# === Discriminator ===
class Discriminator(nn.Module):
    """Two-layer MLP that scores each row with a value in [0, 1].

    The final layer is a sigmoid, so the output is suitable for ``nn.BCELoss``
    (not ``BCEWithLogitsLoss``, which expects raw logits).

    Args:
        feature_dim: number of features per input row.
        hidden_dim: width of the single hidden layer. Defaults to 128, the
            value that was previously hard-coded.
    """

    def __init__(self, feature_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return one sigmoid score per row.

        Args:
            x: input rows, shape [B, feature_dim].

        Returns:
            Tensor of shape [B, 1].
        """
        return self.net(x)

# === MEG Adapter (Exact Paper Implementation) ===
class MEG(nn.Module):
    """Ensemble of masked generators trained against a discriminator.

    Each generator reconstructs the input from an independently drawn random
    feature mask, a fusion network blends the reconstructions per feature,
    and a discriminator is trained to separate real rows from fused rows.

    Args:
        input_dim: number of features per row.
        num_generators: size of the generator ensemble; must be >= 1.
        alpha: weight of the group similarity loss.
        beta: weight of the adversarial loss.
        mask_keep_prob: probability that a feature is kept (mask value 1)
            when a random mask is drawn; must lie in [0, 1]. Defaults to 0.8,
            the value that was previously hard-coded.

    Raises:
        ValueError: if ``num_generators`` < 1 or ``mask_keep_prob`` is
            outside [0, 1].
    """

    def __init__(self, input_dim, num_generators=5, alpha=1.0, beta=1.0,
                 mask_keep_prob=0.8):
        super().__init__()
        # Fail at construction time instead of deep inside torch.stack /
        # torch.bernoulli on the first forward pass.
        if num_generators < 1:
            raise ValueError("num_generators must be at least 1")
        if not 0.0 <= mask_keep_prob <= 1.0:
            raise ValueError("mask_keep_prob must lie in [0, 1]")

        self.input_dim = input_dim
        self.num_generators = num_generators
        self.alpha = alpha  # weight of group similarity loss
        self.beta = beta    # weight of adversarial loss
        self.mask_keep_prob = mask_keep_prob

        # Generators
        self.generators = nn.ModuleList(
            [MaskedGenerator(input_dim) for _ in range(num_generators)]
        )
        # Fusion network
        self.fusion = FusionNetwork(num_generators, input_dim)
        # Discriminator
        self.discriminator = Discriminator(input_dim)

        # Optimizers: generators and fusion weights are updated together,
        # the discriminator separately.
        self.opt_gen = optim.Adam(
            list(self.generators.parameters()) + list(self.fusion.parameters()),
            lr=0.001,
        )
        self.opt_disc = optim.Adam(self.discriminator.parameters(), lr=0.001)
        # Losses
        self.bce = nn.BCELoss()
        self.mse = nn.MSELoss()

    def forward(self, x):
        """Mask, generate and fuse in a single pass.

        A fresh random mask is drawn for every generator on every call, so
        two calls with the same ``x`` generally return different results.

        Args:
            x: input rows, shape [B, D].

        Returns:
            Tuple ``(outputs, fused, gamma)``: the list of per-generator
            outputs, the fused output, and the fusion weights.
        """
        # Build the probability tensor directly on x's device rather than on
        # the CPU followed by a host-to-device copy.
        masks = [
            torch.bernoulli(
                torch.full(x.shape, self.mask_keep_prob, device=x.device)
            )
            for _ in self.generators
        ]
        outputs = [g(x, m) for g, m in zip(self.generators, masks)]
        fused, gamma = self.fusion(outputs)
        return outputs, fused, gamma

    def _discriminator_step(self, real_batch, real_labels, fake_labels):
        """Run one discriminator update and return its loss tensor."""
        self.opt_disc.zero_grad()
        loss_real = self.bce(self.discriminator(real_batch), real_labels)

        # The fused sample is only an input to the discriminator here, so no
        # autograd graph through the generators is needed.
        with torch.no_grad():
            _, fused, _ = self(real_batch)
        loss_fake = self.bce(self.discriminator(fused), fake_labels)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        self.opt_disc.step()
        return loss_D

    def _generator_step(self, real_batch, real_labels):
        """Run one generator + fusion update; return (proxy, group, adv) losses."""
        self.opt_gen.zero_grad()
        outputs, fused, gamma = self(real_batch)

        # Proxy supervised loss (reconstruction)
        loss_proxy = self.mse(fused, real_batch)
        # Group similarity loss: per-generator squared error, weighted per
        # feature by that generator's fusion weight g_w ([D]).
        loss_group = sum(
            torch.mean(g_w * (g_out - real_batch).pow(2))
            for g_out, g_w in zip(outputs, gamma)
        )
        # Adversarial loss: the fused sample should be scored as real.
        loss_adv = self.bce(self.discriminator(fused), real_labels)

        loss_G = loss_proxy + self.alpha * loss_group + self.beta * loss_adv
        # backward() also deposits gradients in the discriminator parameters;
        # opt_gen does not own them and opt_disc.zero_grad() clears them at
        # the start of the next discriminator step.
        loss_G.backward()
        self.opt_gen.step()
        return loss_proxy, loss_group, loss_adv

    def train_meg(self, data, epochs=100, batch_size=64):
        """Alternate discriminator and generator updates over mini-batches.

        Every 10th epoch (starting with epoch 0) the losses of that epoch's
        LAST mini-batch are printed.

        Args:
            data: training rows, a [N, D] torch tensor.
            epochs: number of passes over ``data``.
            batch_size: rows per mini-batch; the final batch may be smaller.

        Raises:
            ValueError: if ``data`` has no rows or ``batch_size`` < 1.
        """
        num_rows = data.size(0)
        if num_rows == 0:
            raise ValueError("train_meg requires at least one row of data")
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")

        for epoch in range(epochs):
            perm = torch.randperm(num_rows)
            for start in range(0, num_rows, batch_size):
                real_batch = data[perm[start:start + batch_size]]
                bs = real_batch.size(0)
                real_labels = torch.ones(bs, 1, device=real_batch.device)
                fake_labels = torch.zeros(bs, 1, device=real_batch.device)

                loss_D = self._discriminator_step(real_batch, real_labels, fake_labels)
                loss_proxy, loss_group, loss_adv = self._generator_step(
                    real_batch, real_labels
                )

            if epoch % 10 == 0:
                # Adjacent f-strings instead of a backslash continuation
                # inside one literal, which embedded the source indentation
                # in the printed line.
                print(
                    f"Epoch {epoch}: Loss_D={loss_D.item():.4f}, "
                    f"Loss_proxy={loss_proxy.item():.4f}, "
                    f"Loss_group={loss_group.item():.4f}, "
                    f"Loss_adv={loss_adv.item():.4f}"
                )

    def generate(self, x):
        """Return the fused output for ``x``.

        This is a plain forward pass: random masks are still drawn, and the
        result stays attached to the autograd graph, so callers that only
        need values should ``detach()`` it.
        """
        _, fused, _ = self(x)
        return fused

# === Upload & Preprocess Data ===
uploaded = files.upload()
if not uploaded:
    # An empty upload mapping (e.g. the dialog was dismissed) would otherwise
    # surface below as a bare StopIteration from next().
    raise ValueError("No file was uploaded; expected a CSV file.")
# Only the first uploaded file is read; any further uploads are ignored.
df = pd.read_csv(io.BytesIO(next(iter(uploaded.values()))))
df = df.dropna()
# Integer-encode every string-typed column (including the label column if it
# happens to be one).
for col in df.select_dtypes(include='object').columns:
    df[col] = pd.factorize(df[col])[0]
# Convention used by this script: the last column is the label, everything
# before it is a feature. X holds the UNSCALED features.
X = df.iloc[:, :-1].values
y = df.iloc[:, -1].values
# Split first, then fit the scaler on the training rows only, so that no
# statistic of the held-out rows leaks into training. The row partition is
# the same as before: it depends only on the row count and random_state.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
X_train_t = torch.tensor(X_train, dtype=torch.float32)
X_test_t = torch.tensor(X_test, dtype=torch.float32)

# === Train MEG ===
input_dim = X_train.shape[1]
meg = MEG(input_dim=input_dim, num_generators=5, alpha=1.0, beta=1.0)
meg.train_meg(X_train_t, epochs=100, batch_size=64)

# === Generate Synthetic Data ===
# Synthesise from the TRAINING rows and reuse their labels. Generating from
# X_test would make the synthetic set a near-copy of the very rows the TSTR
# classifiers are later scored on, which inflates the reported accuracy.
X_synth = meg.generate(X_train_t).detach().cpu().numpy()
y_synth = y_train

# === Evaluate TSTR ===
def evaluate_tstr(X_syn, y_syn, X_real, y_real):
    """Train-on-synthetic, test-on-real accuracy for three classifiers.

    Each classifier is fitted on ``(X_syn, y_syn)`` and scored by accuracy
    on ``(X_real, y_real)``.

    Returns:
        Dict mapping the classifier's display name to its accuracy.
    """
    # All estimators are built up front and fitted in this fixed order.
    classifiers = [
        ('Random Forest', RandomForestClassifier()),
        ('Logistic Regression', LogisticRegression(max_iter=500)),
        ('MLP', MLPClassifier(hidden_layer_sizes=(64, 32), max_iter=500)),
    ]
    scores = {}
    for label, model in classifiers:
        # fit() returns the fitted estimator, so predict() can be chained.
        predictions = model.fit(X_syn, y_syn).predict(X_real)
        scores[label] = accuracy_score(y_real, predictions)
    return scores

# Score classifiers trained on the synthetic set against the real held-out
# rows, then print one "name: accuracy" line per classifier.
tstr_scores = evaluate_tstr(X_synth, y_synth, X_test, y_test)
print("\n🔍 TSTR Results:")
for model_name, accuracy in tstr_scores.items():
    print(f"{model_name}: {accuracy:.4f}")


Saving Bank_Personal_Loan.csv to Bank_Personal_Loan.csv
Epoch 0: Loss_D=1.2380, Loss_proxy=0.2945,                       Loss_group=0.3756, Loss_adv=0.5560
Epoch 10: Loss_D=1.3875, Loss_proxy=0.0557,                       Loss_group=0.2314, Loss_adv=0.6948
Epoch 20: Loss_D=1.3800, Loss_proxy=0.0678,                       Loss_group=0.2186, Loss_adv=0.7017
Epoch 30: Loss_D=1.3832, Loss_proxy=0.0370,                       Loss_group=0.1734, Loss_adv=0.6935
Epoch 40: Loss_D=1.3915, Loss_proxy=0.0559,                       Loss_group=0.2286, Loss_adv=0.6761
Epoch 50: Loss_D=1.3889, Loss_proxy=0.0622,                       Loss_group=0.2427, Loss_adv=0.7043
Epoch 60: Loss_D=1.3677, Loss_proxy=0.0518,                       Loss_group=0.2103, Loss_adv=0.7123
Epoch 70: Loss_D=1.3516, Loss_proxy=0.0383,                       Loss_group=0.1820, Loss_adv=0.7344
Epoch 80: Loss_D=1.3522, Loss_proxy=0.0624,                       Loss_group=0.2258, Loss_adv=0.7130
Epoch 90: Loss_D=1.3092, Loss_proxy=

