In [None]:
import csv
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# protogain
from model import Network
from hypers import Params
from dataset import generate_hint
from output import Metrics

# DANN & GAIN hybrid

In [None]:
# Encoder
class Encoder(nn.Module):
    """Feed-forward encoder that maps input features to a latent vector.

    Layer stack, in order: Linear -> BatchNorm1d -> ReLU -> Dropout(0.3) ->
    Linear -> ReLU -> BatchNorm1d.

    Args:
        input_dim: width of the incoming feature vector.
        hidden_dim: width of the intermediate layer.
        latent_dim: width of the produced latent vector.
    """

    def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int):
        super().__init__()

        hidden_stage = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
        ]
        latent_stage = [
            nn.Linear(hidden_dim, latent_dim),
            nn.ReLU(),
            nn.BatchNorm1d(latent_dim),
        ]

        # Attribute name and layer order are kept, so state_dict keys
        # ("encoder.0.weight", ...) stay the same.
        self.encoder = nn.Sequential(*hidden_stage, *latent_stage)

    def forward(self, x):
        """Return the latent encoding of ``x``."""
        return self.encoder(x)

# Decoder
class Decoder(nn.Module):
    """Map a latent vector back to the target feature space.

    Layer stack, in order: Linear -> Dropout(0.3) -> Linear. There is no
    activation between the two linear layers.

    Args:
        latent_dim: width of the incoming latent vector.
        hidden_dim: width of the intermediate layer.
        target_dim: width of the reconstructed output.
    """

    def __init__(self, latent_dim: int, hidden_dim: int, target_dim: int):
        super().__init__()

        to_hidden = nn.Linear(latent_dim, hidden_dim)
        regulariser = nn.Dropout(0.3)
        to_target = nn.Linear(hidden_dim, target_dim)

        self.decoder = nn.Sequential(to_hidden, regulariser, to_target)

    def forward(self, x):
        """Return the reconstruction computed from latent input ``x``."""
        return self.decoder(x)

In [None]:
class DomainClassifier(nn.Module):
    """Predict which domain an input vector comes from.

    Layer stack, in order: Linear(input_dim, input_dim) -> ReLU ->
    Linear(input_dim, n_class). The output is the raw result of the last
    linear layer (one score per class); no softmax is applied here.

    Args:
        input_dim: width of the incoming vector.
        n_class: number of domains to score.
    """

    def __init__(self, input_dim: int, n_class: int):
        super().__init__()

        hidden = nn.Linear(input_dim, input_dim)
        scores = nn.Linear(input_dim, n_class)

        self.domain_classifier = nn.Sequential(hidden, nn.ReLU(), scores)

    def forward(self, x):
        """Return the per-class scores for ``x``."""
        return self.domain_classifier(x)

In [None]:
class GradientReversalFunction(torch.autograd.Function):
    """Autograd function: identity in the forward pass, sign-flipped and
    ``lambd``-scaled gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd=1.0):
        # Remember the scaling factor for the backward pass.
        ctx.lambd = lambd
        identity = x.view_as(x)
        return identity

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = torch.neg(grad_output) * ctx.lambd
        # Second slot is the (non-existent) gradient for ``lambd``.
        return reversed_grad, None

class GradientReversalLayer(nn.Module):
    """Module wrapper that applies ``GradientReversalFunction`` with a fixed
    ``lambd`` stored on the instance.

    Args:
        lambd: factor handed to ``GradientReversalFunction`` on every call.
    """

    def __init__(self, lambd=1.0):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Pass ``x`` through the gradient reversal function."""
        strength = self.lambd
        return GradientReversalFunction.apply(x, strength)

In [None]:
class GAIN_DANN(nn.Module):
    """Encoder -> GAIN imputation on the latent code -> decoder, plus a domain
    classifier that reads the latent code.

    Args:
        input_dim: number of input features; also the decoder output width.
        latent_dim: width of the latent code produced by the encoder.
        n_class: number of domains scored by the domain classifier.
        params: hyper-parameters forwarded to ``Network`` as ``hypers``.
        metrics: metrics object forwarded to ``Network``.

    NOTE(review): the missingness mask has the shape of the input but is
    multiplied with the latent code, so the arithmetic in ``forward`` only
    lines up when ``latent_dim == input_dim`` (which is how ``mainkfold``
    builds the model) -- confirm this is intended.
    """

    def __init__(self, input_dim: int, latent_dim: int, n_class: int, params: Params, metrics: Metrics):
        super().__init__()

        self.encoder = Encoder(input_dim=input_dim, hidden_dim=128, latent_dim=latent_dim)

        # gradient reversal layer (default lambd)
        self.grl = GradientReversalLayer()

        self.domain_classifier = DomainClassifier(latent_dim, n_class=n_class)

        # gain: generator is built first, then the discriminator
        generator = self._gain_mlp(latent_dim)
        discriminator = self._gain_mlp(latent_dim)
        self.gain = Network(hypers=params,
                            net_G=generator,
                            net_D=discriminator,
                            metrics=metrics)

        self.decoder = Decoder(latent_dim=latent_dim, hidden_dim=128, target_dim=input_dim)

    @staticmethod
    def _gain_mlp(latent_dim: int) -> nn.Sequential:
        """Build the three-layer MLP shared by the GAIN generator and discriminator.

        Input width is ``2 * latent_dim``; output width is ``latent_dim`` and
        passes through a sigmoid.
        """
        return nn.Sequential(
            nn.Linear(latent_dim * 2, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, latent_dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Impute and reconstruct ``x`` and predict its domain.

        Missing values are expected as NaN; they are replaced by zeros before
        encoding, and a 0/1 mask (1 = observed) is derived from them.

        Returns:
            tuple: ``(x_reconstructed, x_domain)`` -- the decoder output and
            the arg-max domain index of every row.
        """

        # TODO: x must be scaled

        missing = torch.isnan(x)
        x_filled = x.masked_fill(missing, 0)  # zeros in the place of missing values
        mask = (~missing).float()

        # 1. Encode
        x_encoded = self.encoder(x_filled)
        # Identity in the forward direction; it only matters for gradients.
        x_grl = self.grl(x_encoded)

        # 2. Gain: keep the encoded values where the mask is 1, use the
        #    generated sample elsewhere
        sample = self.gain.generate_sample(x_grl, mask)
        x_imputed = x_encoded * mask + sample * (1 - mask)

        # 2.1. Domain Classifier
        domain_scores = self.domain_classifier(x_encoded)
        x_domain = domain_scores.argmax(dim=1)

        # 3. Decoder
        x_reconstructed = self.decoder(x_imputed)

        # TODO: transform the output back to the original (pre-scaling) range

        return x_reconstructed, x_domain

# Train

In [None]:
def init_weights(m):
    """Initialise an ``nn.Linear`` layer: Xavier-uniform weights, zero bias.

    Meant to be passed to ``module.apply``; modules that are not
    ``nn.Linear`` are left untouched.

    Args:
        m: the module visited by ``apply``.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        # A layer built with bias=False has ``bias is None``; there is
        # nothing to reset in that case.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [None]:
def train(model, train_loader, num_epochs: int, hint_rate: float):
    """Train the encoder, domain classifier, GAIN network and decoder of ``model``.

    Every batch runs four separate updates, in this order:

    1. Encoder: adversarial step through the gradient reversal layer, using
       the domain cross-entropy scaled by ``lambda_dann``.
    2. Domain classifier: cross-entropy on the detached latent code.
    3. GAIN: ``model.gain._update_D`` / ``_update_G`` on the detached latent code.
    4. Decoder: RMSE between the reconstruction and the input, restricted to
       the observed (non-NaN) entries.

    NOTE(review): the encoder is only ever updated by the adversarial loss;
    the reconstruction loss reaches the decoder alone because the decoder
    input is detached -- confirm this is intended.

    Args:
        model: object exposing ``encoder``, ``grl``, ``domain_classifier``,
            ``gain`` and ``decoder`` (e.g. ``GAIN_DANN``).
        train_loader: iterable of ``(x, domain_target)`` batches; missing
            entries of ``x`` are NaN.
        num_epochs: number of passes over ``train_loader``.
        hint_rate: forwarded to ``generate_hint`` together with the mask.

    Returns:
        list: mean decoder reconstruction RMSE of every epoch.
    """

    # Losses (stateless, so they are built once instead of once per batch)
    domain_criterion = nn.CrossEntropyLoss()
    loss_gain = nn.BCELoss(reduction="none")
    loss_mse_gain = nn.MSELoss(reduction="none")

    # Optimizers
    optimizer_domain = torch.optim.Adam(model.domain_classifier.parameters())
    optimizer_encoder = torch.optim.Adam(model.encoder.parameters())
    optimizer_decoder = torch.optim.Adam(model.decoder.parameters())

    # torch.autograd.set_detect_anomaly(True) test purposes

    # initialize weights of encoder and decoder
    model.encoder.apply(init_weights)
    model.decoder.apply(init_weights)

    # Dropout / BatchNorm must be in training mode, even if the model was
    # switched to eval mode before this call.
    model.train()

    rmse_per_epoch = []

    for epoch in range(num_epochs):
        print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")

        domain_accuracies = [] # across batches
        reconstruction_errors = [] # reconstruction error
        domain_losses = []
        gain_losses = []
        encoder_losses = []

        # Adversarial weight schedule; it depends on the epoch only, so it is
        # computed once per epoch. float() turns the NumPy scalar into a
        # plain Python float before it is multiplied with a tensor.
        p = float(epoch) / num_epochs # training progress 0 -> 1
        # gamma = 10 follows the DANN paper per the original author's note;
        # the extra factor 5 rescales the result.
        lambda_dann = float(5 * (2. / (1 + np.exp(-10 * p)) - 1))

        for x, domain_target in train_loader:
            missing = torch.isnan(x)
            x_filled = x.clone()
            x_filled[missing] = 0 # x filled with zeros in the place of missing values

            mask = (~missing).float()
            hint = generate_hint(mask, hint_rate)

            # =============================================
            #               Encoder training
            # =============================================

            x_encoded = model.encoder(x_filled)
            x_grl = model.grl(x_encoded)

            domain_pred = model.domain_classifier(x_grl) # domain labels prediction
            domain_loss = domain_criterion(domain_pred, domain_target)

            encoder_adv_loss = lambda_dann * domain_loss # scaled cross entropy loss
            encoder_losses.append(encoder_adv_loss.item())

            optimizer_encoder.zero_grad()
            # No retain_graph: every later step works on detached tensors, so
            # this graph is never traversed again.
            encoder_adv_loss.backward()
            optimizer_encoder.step()

            # =============================================
            #          Domain Classifier training
            # =============================================

            x_encoded_detach = x_encoded.detach()

            domain_pred = model.domain_classifier(x_encoded_detach) # domain labels prediction
            domain_loss = domain_criterion(domain_pred, domain_target)
            domain_losses.append(domain_loss.item())

            # domain classification accuracy (just to have a better insight)
            domain_pred_labels = torch.argmax(domain_pred, dim=1)
            domain_accuracy = (domain_pred_labels == domain_target).float().mean().item()
            domain_accuracies.append(domain_accuracy)

            # zero_grad also clears what the encoder step accumulated on the
            # classifier parameters.
            optimizer_domain.zero_grad()
            domain_loss.backward()
            optimizer_domain.step()

            # =============================================
            #               GAIN training
            # - as in function model.py/train
            # =============================================

            n_samples, dim = x_encoded_detach.shape

            Z = torch.rand((n_samples, dim)) * 0.01

            # The detached latent code is used directly; a NumPy round-trip
            # would yield the same values but only works on CPU tensors.
            model.gain._update_D(x_encoded_detach, mask, hint, Z, loss_gain)
            model.gain._update_G(x_encoded_detach, mask, hint, Z, loss_gain)

            samples = model.gain.generate_sample(x_encoded_detach, mask)

            # monitoring only: MSE between latent code and sample where mask == 1
            gain_mse = loss_mse_gain(mask * x_encoded_detach, mask * samples)
            gain_losses.append(gain_mse.mean().item())

            x_imputed = (x_encoded_detach * mask + samples * (1 - mask)).detach()

            # =============================================
            #               Decoder training
            # =============================================

            n_observed = mask.sum()
            if n_observed.item() == 0:
                # A batch without any observed entry would give a 0/0 loss
                # and NaN gradients; skip the decoder step for it.
                continue

            x_reconstructed = model.decoder(x_imputed)

            # x_filled is the target: observed values, zeros elsewhere (the
            # zeros are masked out of the loss below).
            squared_error = (x_reconstructed - x_filled) ** 2
            reconstruction_loss = torch.sqrt((squared_error * mask).sum() / n_observed) # RMSE error
            reconstruction_errors.append(reconstruction_loss.item())

            optimizer_decoder.zero_grad()
            reconstruction_loss.backward()
            optimizer_decoder.step()

        # rmse per epoch
        rmse_per_epoch.append(np.mean(reconstruction_errors))

        # === Evaluation ===
        print(f"GAIN Loss (MSE) {np.mean(gain_losses):.4f}")
        print(f"Domain Loss: {np.mean(domain_losses):.4f} | Encoder Adv Loss: {np.mean(encoder_losses):.4f}")
        print(f"Domain Accuracy: {np.mean(domain_accuracies):.4f}")
        print(f"Reconstruction loss RMSE: {np.mean(reconstruction_errors):.4f}")

    return rmse_per_epoch

# Evaluate model

In [None]:
def evaluate(model, test_loader):
    """Report domain accuracy and reconstruction RMSE of ``model`` on ``test_loader``.

    The model is switched to eval mode (so Dropout is disabled and BatchNorm
    uses its running statistics) and run without gradient tracking; its
    previous training/eval mode is restored afterwards.

    Args:
        model: ``nn.Module`` whose call returns ``(x_pred, x_domain_pred)``.
        test_loader: iterable of ``(x, x_domain)`` batches; missing entries
            of ``x`` are NaN and are excluded from the RMSE.

    Returns:
        float: RMSE over all observed entries, or NaN when there are none.
    """

    print("\nEvaluation...")

    total_correct = 0
    total_samples = 0
    total_squared_error = 0.0
    total_mask_elements = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, x_domain in test_loader:
                mask = ~torch.isnan(x)

                x_pred, x_domain_pred = model(x)

                total_correct += (x_domain_pred == x_domain).sum().item()
                total_samples += x_domain.size(0)

                squared_error = (x_pred - x) ** 2
                # NaN inputs give NaN differences; zero them so only observed
                # entries contribute.
                squared_error[~mask] = 0
                total_squared_error += squared_error.sum().item()
                total_mask_elements += mask.sum().item()
    finally:
        model.train(was_training)

    # An empty loader (or one without observed entries) yields NaN instead of
    # a ZeroDivisionError.
    nan = float("nan")
    domain_accuracy = total_correct / total_samples if total_samples else nan
    rmse = (total_squared_error / total_mask_elements) ** 0.5 if total_mask_elements else nan

    print(f"Domain Accuracy: {domain_accuracy:.4f}")
    print(f"Reconstruction error RMSE: {rmse:.4f}")

    return rmse

# Prepare dataset

In [None]:
def generate_missingness(data: pd.DataFrame, miss_rate: float=0.2, random_state=None):
    """Return a copy of ``data`` with values randomly replaced by NaN.

    Every cell except those of the last column is kept when an independent
    uniform draw in [0, 1) is greater than ``miss_rate`` and set to NaN
    otherwise. The last column is never altered (it holds the projects).
    Column dtypes are kept where possible instead of collapsing the whole
    frame to ``object``; ``data`` itself is not modified.

    Args:
        data: frame to mask; the last column is left untouched.
        miss_rate: threshold in [0, 1] for the uniform draws.
        random_state: optional seed (or anything accepted by
            ``np.random.default_rng``) for a self-contained, reproducible
            mask. When ``None`` the NumPy global RNG is used, so
            ``np.random.seed`` still controls the result.

    Returns:
        pd.DataFrame: masked frame with the same index and columns as ``data``.

    Raises:
        ValueError: if ``miss_rate`` is outside [0, 1].
    """
    if not 0.0 <= miss_rate <= 1.0:
        raise ValueError(f"miss_rate must be in [0, 1], got {miss_rate}")

    size, dim = data.shape
    if dim <= 1:
        # Only the protected column (or nothing at all): nothing to mask.
        return data.copy()

    shape = (size, dim - 1)
    if random_state is None:
        draws = np.random.rand(*shape)
    else:
        draws = np.random.default_rng(random_state).random(shape)
    keep = draws > miss_rate

    # do not alter the last column, since it corresponds to the projects
    missing_data = data.iloc[:, :-1].where(keep)
    missing_data.insert(dim - 1, data.columns[-1], data.iloc[:, -1], allow_duplicates=True)

    return missing_data

In [None]:
# Load the HeLa table (the first CSV column becomes the index), keep a subset
# of its rows and transpose it so the CSV rows become columns.
hela = (
    pd.read_csv('./hela_dann.csv', index_col=0)
    .iloc[8000:, :]  # drop the first 8000 rows -- subset used "if ran locally"
    .T
)

# "Project" is a column once the table has been transposed.
project_data = hela["Project"]

Generate artificial missingness

In [None]:
# Seed NumPy's global RNG so the artificial missingness mask is the same on
# every Restart & Run All.
np.random.seed(42)
hela_missing = generate_missingness(hela, miss_rate=0.2)

In [None]:
# if run with only a train and test split
def create_dataloaders(data: pd.DataFrame, batch_size: int=64, test_size: float=0.2):
    """Build class-balanced train and plain test loaders from ``data``.

    The ``"Project"`` column is turned into integer domain labels; all other
    columns are the features. The stratified split is done first and the
    min-max scaler is fitted on the training rows only, so no statistics of
    the test rows leak into training.

    Args:
        data: frame with a ``"Project"`` column; remaining columns must be
            convertible to float32 (NaN marks missing values).
        batch_size: batch size of both loaders.
        test_size: fraction of rows held out for the test loader.

    Returns:
        tuple: ``(train_loader, test_loader)``.
    """
    domain = data["Project"]
    projects = domain.unique()
    project_to_number = {name: idx for idx, name in enumerate(projects)} # map each project with a number
    labels = domain.map(project_to_number).to_numpy()
    features = data.drop(columns="Project").values.astype(np.float32)

    X_train_raw, X_test_raw, y_train_np, y_test_np = train_test_split(
        features, labels, test_size=test_size, stratify=labels, random_state=42
    )

    # Fit on the training split only, then apply the same scaling to the test split.
    scaler = MinMaxScaler(feature_range=(0, 1))
    X_train = torch.tensor(scaler.fit_transform(X_train_raw), dtype=torch.float32)
    X_test = torch.tensor(scaler.transform(X_test_raw), dtype=torch.float32)
    y_train = torch.tensor(y_train_np, dtype=torch.long)
    y_test = torch.tensor(y_test_np, dtype=torch.long)

    train_dataset = TensorDataset(X_train, y_train)
    test_dataset = TensorDataset(X_test, y_test)

    # balance classes: each sample is drawn with probability inversely
    # proportional to the size of its class
    class_samples_count = torch.bincount(y_train)
    weights = 1. / class_samples_count
    sample_weights = weights[y_train]

    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    return train_loader, test_loader

# Main

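Cross-validation using [Stratified K-Fold](https://github.com/xbeat/Machine-Learning/blob/main/Stratified%20K-Fold%20Cross-Validation%20in%20Python.md), which keeps the proportion of each project roughly the same in every fold. A fresh model is trained and evaluated for each fold.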

In [None]:
def mainkfold(k: int=5, num_epochs: int=10):
    """Run stratified k-fold cross-validation of ``GAIN_DANN`` on ``hela_missing``.

    For every fold a new model is built, trained with ``train`` and scored
    with ``evaluate``. The min-max scaler is fitted inside each fold on the
    training rows only, so the held-out rows do not influence the scaling.

    Args:
        k: number of folds.
        num_epochs: training epochs per fold.

    Returns:
        list: one list per epoch holding that epoch's training
        reconstruction RMSE for every fold.
    """
    data = hela_missing

    domain = data["Project"]
    projects = domain.unique()
    project_to_number = {name: idx for idx, name in enumerate(projects)} # map each project with a number
    labels = domain.map(project_to_number).to_numpy()
    features = data.drop(columns="Project").values.astype(np.float32)

    input_dim = features.shape[1] # every column except the project column
    # Taken from the same mapping that produces the labels, so the classifier
    # width always matches the label range.
    n_class = len(projects)

    print("Number of samples:", features.shape[0])
    print("Number of proteins:", input_dim)
    print("Number of unique projects:", n_class)

    params = Params()
    metrics = Metrics(params)

    skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=42)

    rmse_per_epoch = [[] for _ in range(num_epochs)] # rmse per epoch for every fold

    for fold, (train_index, test_index) in enumerate(skf.split(features, labels), 1):
        print(f"\n====== Fold {fold} ======")

        # Fit the scaler on this fold's training rows only.
        scaler = MinMaxScaler(feature_range=(0, 1))
        X_train = torch.tensor(scaler.fit_transform(features[train_index]), dtype=torch.float32)
        X_test = torch.tensor(scaler.transform(features[test_index]), dtype=torch.float32)
        y_train = torch.tensor(labels[train_index], dtype=torch.long)
        y_test = torch.tensor(labels[test_index], dtype=torch.long)

        train_dataset = TensorDataset(X_train, y_train)
        test_dataset = TensorDataset(X_test, y_test)

        # balance class distribution: each sample is drawn with probability
        # inversely proportional to the size of its class
        class_samples_count = torch.bincount(y_train)
        weights = 1. / class_samples_count
        sample_weights = weights[y_train]

        sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

        train_loader = DataLoader(train_dataset, batch_size=64, sampler=sampler)
        test_loader = DataLoader(test_dataset, batch_size=64)

        # latent_dim equals input_dim: the model multiplies the input-shaped
        # mask with the latent code.
        model = GAIN_DANN(input_dim, latent_dim=input_dim, n_class=n_class, params=params, metrics=metrics)

        rmse = train(model, train_loader, num_epochs=num_epochs, hint_rate=0.9)

        for i, val in enumerate(rmse):
            rmse_per_epoch[i].append(val)

        evaluate(model, test_loader)

    return rmse_per_epoch

# Seed the global RNGs so weight initialisation, dropout, the weighted
# sampler and the GAIN noise are repeatable between runs.
torch.manual_seed(42)
np.random.seed(42)

rmse_per_epoch = mainkfold(k=5, num_epochs=5)

# Plots

In [None]:
# Export the training RMSE: one CSV row per (epoch, fold) pair, epochs numbered from 1.
with open('./rmse.csv', mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Epoch', 'Reconstruction Error'])
    writer.writerows(
        [epoch_number, fold_rmse]
        for epoch_number, fold_values in enumerate(rmse_per_epoch, start=1)
        for fold_rmse in fold_values
    )

Plot the mean training RMSE per epoch, averaged over the cross-validation folds

In [None]:
# Mean RMSE of every epoch, averaged over the folds.
rmse_plot = [np.mean(fold_values) for fold_values in rmse_per_epoch]
epochs = list(range(1, len(rmse_plot) + 1))
n_folds = len(rmse_per_epoch[0]) if rmse_per_epoch else 0

# Style and rcParams only affect artists created afterwards, so they are set
# before the figure is built.
plt.style.use("ggplot")
plt.rcParams.update({'font.size': 12})

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs, rmse_plot, marker="o", color="y")
# Title is derived from the data instead of hard-coding the run size.
ax.set_title(f"RMSE over {len(epochs)} epochs ({n_folds}-Fold Cross-Validation)")
ax.set_xlabel("Epochs")
ax.set_ylabel("RMSE")
ax.set_xticks(epochs)

ax.grid(True, linestyle='--', alpha=0.5)
ax.set_ylim(min(rmse_plot) - 0.02, max(rmse_plot) + 0.02)

# Value labels above the markers; the second epoch is left-aligned and later
# epochs use a slightly larger vertical offset.
for epoch_number, mean_rmse in zip(epochs, rmse_plot):
    if epoch_number == 1:
        ax.text(epoch_number, mean_rmse + 0.003, f"{mean_rmse:.3f}", ha='center')
    elif epoch_number == 2:
        ax.text(epoch_number, mean_rmse + 0.003, f"{mean_rmse:.3f}", ha='left')
    else:
        ax.text(epoch_number, mean_rmse + 0.005, f"{mean_rmse:.3f}", ha='center')

# Create the output folder if needed so savefig does not fail on a fresh checkout.
output_path = Path("./imgs/rmse_dann_gain.png")
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=300, bbox_inches='tight')
plt.show()