In [None]:
import pandas as pd
import numpy as np
import math
from transformers import pipeline
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import re
from scipy import stats
from sklearn.preprocessing import MinMaxScaler
import torch
import torch.nn as nn
import torch.optim as optim
import networkx as nx
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics.pairwise import rbf_kernel
from torch_geometric.utils import dense_to_sparse
from torch_geometric.nn import TAGConv
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GATConv
from sklearn.metrics import classification_report, roc_auc_score, average_precision_score, f1_score

In [None]:


# Multiplicative factor that turns a value expressed in the given unit into
# the reference unit of its category (the entry whose factor is 1.0).
UNIT_CONVERSION = {
    'weight': {'kg': 1.0, 'g': 0.001, 'mg': 0.000001, 'lb': 0.453592, 'oz': 0.0283495},
    'length': {
        'm': 1.0, 'cm': 0.01, 'mm': 0.001, 'km': 1000.0,
        'in': 0.0254, 'ft': 0.3048, 'mi': 1609.34,
    },
    'time': {'s': 1.0, 'ms': 0.001, 'min': 60.0, 'h': 3600.0, 'day': 86400.0},
    'volume': {
        'l': 1.0, 'ml': 0.001, 'm3': 1000.0,
        'gal': 3.78541, 'qt': 0.946353, 'pt': 0.473176,
    },
}

# Reverse lookup: unit symbol -> name of the category it belongs to.
# Derived from UNIT_CONVERSION so the two tables cannot drift apart.
UNIT_CATEGORIES = {
    unit: category
    for category, factors in UNIT_CONVERSION.items()
    for unit in factors
}



def extract_unit_from_column(col_name):
    """Return the trailing unit-like token of a column name, or None.

    The name is lower-cased and three suffix layouts are tried in order:
    ``name_unit``, ``name(unit)`` and ``name unit``. The token must be one to
    four ASCII letters; it is NOT checked against the known unit tables, so
    e.g. ``heart_rate`` yields ``"rate"``.
    """
    lowered = str(col_name).lower()

    suffix_patterns = (
        r'_([a-z]{1,4})$',       # underscore-separated suffix
        r'\(([a-z]{1,4})\)$',    # suffix wrapped in parentheses
        r'\s([a-z]{1,4})$',      # whitespace-separated suffix
    )
    for pattern in suffix_patterns:
        found = re.search(pattern, lowered)
        if found is not None:
            return found.group(1)

    return None


def detect_unit_category(unit):
    """Look up the category name for a unit symbol (case-insensitive).

    Returns None when the symbol is not a key of UNIT_CATEGORIES.
    """
    symbol = unit.lower()
    return UNIT_CATEGORIES.get(symbol)


def are_units_convertible(unit1, unit2):
    """Tell whether both unit symbols are known and share one category."""
    first, second = (detect_unit_category(u) for u in (unit1, unit2))
    # Unknown units are never convertible, even with each other.
    return not (first is None or first != second)


def convert_units(value, from_unit, to_unit):
    """Convert ``value`` from ``from_unit`` to ``to_unit``.

    The value is returned unchanged when ``from_unit`` is unknown or the two
    units do not belong to the same category.
    """
    category = detect_unit_category(from_unit)
    if category is not None and are_units_convertible(from_unit, to_unit):
        table = UNIT_CONVERSION[category]
        source_factor = table.get(from_unit.lower(), 1.0)
        target_factor = table.get(to_unit.lower(), 1.0)
        # Go through the category's reference unit: source -> reference -> target.
        return value * (source_factor / target_factor)

    return value


def statistical_unit_check(series1, series2):
    """Test whether ``series2`` is roughly a constant multiple of ``series1``.

    Both series have their NaNs dropped and are truncated to a common length,
    then ``series2`` is regressed on ``series1``. The relation counts as
    proportional when the correlation coefficient exceeds 0.9 and the
    intercept is small compared with the larger of the two absolute means.

    NOTE(review): values are paired by position after the independent
    ``dropna`` calls, not by index - confirm that this pairing is intended.

    Returns:
        ``(True, slope)`` when the relation looks proportional, otherwise
        ``(False, 1.0)``. ``(False, 1.0)`` is also returned when either series
        has fewer than 10 non-null values or ``series1`` is constant.
    """
    x = series1.dropna()
    y = series2.dropna()

    if len(x) < 10 or len(y) < 10:
        return False, 1.0

    common_length = min(len(x), len(y))
    x = x.iloc[:common_length]
    # Fix: the original sliced s1 here, so the regression compared series1
    # with itself and every pair was reported as proportional with slope 1.
    y = y.iloc[:common_length]

    # A constant regressor has no defined slope (linregress rejects it).
    if x.max() == x.min():
        return False, 1.0

    slope, intercept, r_value, _p_value, _std_err = stats.linregress(
        x.to_numpy(dtype=float), y.to_numpy(dtype=float)
    )
    intercept_tolerance = 0.1 * max(abs(y.mean()), abs(x.mean()))
    if r_value > 0.9 and abs(intercept) < intercept_tolerance:
        return True, slope
    return False, 1.0


def find_most_similar_time_col(tab_col, time_cols):
    """Pick the entry of ``time_cols`` whose name is closest to ``tab_col``.

    Names are normalised with ``preprocess_column_name``, embedded with the
    module-level ``model`` and compared by cosine similarity.

    Returns:
        ``(best_time_col, similarity)`` for the highest-scoring candidate.
    """
    query_text = preprocess_column_name(tab_col)
    candidate_texts = [preprocess_column_name(name) for name in time_cols]

    query_vec = model.encode([query_text])
    candidate_vecs = model.encode(candidate_texts)

    # One query row against all candidates -> take that single row of scores.
    scores = cosine_similarity(query_vec, candidate_vecs)[0]
    best = np.argmax(scores)

    return time_cols[best], scores[best]


def preprocess_data(*dfs):
    """Min-max scale every frame independently.

    A single MinMaxScaler object is re-fitted on each frame in turn, so the
    frames do not share a scale. The result is one flat list with a scaled
    DataFrame per input, keeping column names but not the original index.
    """
    scaler = MinMaxScaler()
    return [
        pd.DataFrame(scaler.fit_transform(frame.values), columns=frame.columns)
        for frame in dfs
    ]

class TimeSeriesTransformer(nn.Module):
    """Transformer encoder that maps a sequence to a Gaussian ``(mu, sigma)``.

    Args:
        input_dim: Feature width of each time step (used as ``d_model``).
        hidden_dim: Accepted for interface compatibility; not used by the
            layers built here.
        latent_dim: Width of ``mu`` / ``sigma``. Defaults to ``input_dim``.

    Changes from the previous version:
      * ``hidden_dim`` and ``latent_dim`` are optional, so the one-argument
        call made by ``train_cgan`` constructs successfully.
      * The head count is the largest of 4, 3, 2, 1 that divides
        ``input_dim``; a fixed ``nhead=4`` fails for other widths.
      * The encoder uses ``batch_first=True`` so that dim 1 - the axis pooled
        in ``forward`` - is the sequence axis rather than the batch axis.
      * 2-D input ``(batch, input_dim)`` is treated as length-1 sequences.
    """

    def __init__(self, input_dim, hidden_dim=64, latent_dim=None):

        super().__init__()
        if latent_dim is None:
            latent_dim = input_dim

        # Multi-head attention needs d_model to be divisible by nhead.
        num_heads = next(h for h in (4, 3, 2, 1) if input_dim % h == 0)

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=input_dim, nhead=num_heads, batch_first=True
            ),
            num_layers=4,
        )
        self.mu_layer = nn.Linear(input_dim, latent_dim)
        self.sigma_layer = nn.Linear(input_dim, latent_dim)

    def forward(self, x_seq):
        """Encode ``x_seq`` and return ``(mu, sigma)``.

        ``x_seq`` is ``(batch, seq_len, input_dim)`` or ``(batch, input_dim)``.
        ``sigma`` is already exponentiated (strictly positive); callers must
        not treat it as a log-variance.
        """
        if x_seq.dim() == 2:
            x_seq = x_seq.unsqueeze(1)  # (batch, 1, input_dim)

        h = self.transformer(x_seq)
        pooled = h.mean(dim=1)  # average over the sequence axis
        mu = self.mu_layer(pooled)
        sigma = self.sigma_layer(pooled).exp()
        return mu, sigma


# class TimeSeriesTransformer(nn.Module):

#     def __init__(self, input_dim, hidden_dim=64, num_layers=3):

#         super().__init__()
#         self.encoder = nn.Sequential(
#             nn.Linear(input_dim, hidden_dim),
#             nn.ReLU(),
#             *[nn.Sequential(
#                 nn.Linear(hidden_dim, hidden_dim),
#                 nn.ReLU()
#             ) for _ in range(num_layers-1)]
#         )
#         self.mu = nn.Linear(hidden_dim, hidden_dim)
#         self.logvar = nn.Linear(hidden_dim, hidden_dim)

#     def forward(self, x):

#         h = self.encoder(x)
#         return self.mu(h), self.logvar(h)


class Generator(nn.Module):
    """Conditional generator: MLP over the concatenation of noise and condition.

    Layer widths are ``noise_dim + cond_dim -> 256 -> 512 -> output_dim`` with
    ReLU between the hidden layers and a final Sigmoid, so outputs lie in
    [0, 1].
    """

    def __init__(self, noise_dim, cond_dim, output_dim):

        super().__init__()
        widths = (noise_dim + cond_dim, 256, 512)

        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], output_dim))
        layers.append(nn.Sigmoid())

        self.net = nn.Sequential(*layers)

    def forward(self, z, c):
        """Generate samples from noise ``z`` and condition ``c`` (joined on dim 1)."""
        joined = torch.cat((z, c), dim=1)
        return self.net(joined)


class Discriminator(nn.Module):
    """Conditional discriminator: MLP over the concatenation of sample and condition.

    Layer widths are ``input_dim + cond_dim -> 512 -> 256 -> 1`` with
    LeakyReLU(0.2) between the hidden layers and a final Sigmoid, so the
    output is a probability-like score in [0, 1].
    """

    def __init__(self, input_dim, cond_dim):

        super().__init__()
        stack = []
        width = input_dim + cond_dim
        for size in (512, 256):
            stack.extend([nn.Linear(width, size), nn.LeakyReLU(0.2)])
            width = size
        stack.extend([nn.Linear(width, 1), nn.Sigmoid()])

        self.net = nn.Sequential(*stack)

    def forward(self, x, c):
        """Score sample ``x`` given condition ``c`` (joined on dim 1)."""
        joined = torch.cat((x, c), dim=1)
        return self.net(joined)


def compute_mmd(x, y, sigma=1.0):
    """Compute an RBF-kernel MMD estimate between samples ``x`` and ``y``.

    The kernel bandwidth is ``sigma`` (``gamma = 1 / (2 * sigma**2)``). The
    value is the mean of the x-x and y-y kernel matrices minus twice the mean
    of the x-y kernel matrix; diagonal terms are included in the means.
    """
    gamma = 1.0 / (2 * sigma ** 2)
    k_xx, k_yy, k_xy = (
        rbf_kernel(a, b, gamma=gamma) for a, b in ((x, x), (y, y), (x, y))
    )
    return k_xx.mean() + k_yy.mean() - 2 * k_xy.mean()


def train_cgan(tab_dfs, time_df, num_epochs=1000, batch_size=32):
    """Train a conditional GAN on tabular rows, conditioned on time-series rows.

    Each time-series row is encoded by a ``TimeSeriesTransformer`` into
    ``(mu, sigma)``; a sample ``mu + sigma * eps`` is the condition fed to the
    generator and the discriminator.

    Args:
        tab_dfs: Sequence of numeric DataFrames; all must have the same number
            of columns because their rows are concatenated.
        time_df: Numeric DataFrame; its column count is the condition width.
        num_epochs: Number of passes over the zipped loaders.
        batch_size: Batch size of both loaders.

    Returns:
        ``(generator, transformer, feature_mmd_losses)`` where
        ``feature_mmd_losses`` maps each feature index to the list of
        per-batch MMD values recorded for that feature.

    Fixes relative to the previous version:
      * The encoder is constructed with all of its arguments.
      * The encoder's second output is already a standard deviation
        (exponentiated), so it is used directly rather than as a log-variance.
      * The discriminator step uses a detached condition. Backpropagating
        through the encoder graph there and again in the generator step
        raised "backward through the graph a second time".
      * Tabular and time batches are truncated to a common size; the two
        loaders iterate over different row counts.
      * Each time row is passed as a length-1 sequence so that the encoder
        returns one condition per row.
    """
    noise_dim = 64
    cond_dim = time_df.shape[1]
    feature_dim = tab_dfs[0].shape[1]

    # latent_dim must equal cond_dim: the sampled latent is the condition.
    transformer = TimeSeriesTransformer(cond_dim, 64, cond_dim)
    generator = Generator(noise_dim, cond_dim, feature_dim)
    discriminator = Discriminator(feature_dim, cond_dim)

    optim_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optim_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optim_t = optim.Adam(transformer.parameters(), lr=0.0001)

    bce = nn.BCELoss()
    mse = nn.MSELoss()

    time_tensor = torch.FloatTensor(time_df.values)
    all_tab_data = torch.cat([torch.FloatTensor(df.values) for df in tab_dfs], dim=0)

    tab_loader = DataLoader(TensorDataset(all_tab_data), batch_size=batch_size, shuffle=True)
    time_loader = DataLoader(TensorDataset(time_tensor), batch_size=batch_size, shuffle=True)

    feature_mmd_losses = {i: [] for i in range(feature_dim)}
    d_loss = g_loss = None

    for epoch in range(num_epochs):

        for (tab_batch,), (time_batch,) in zip(tab_loader, time_loader):

            # The final batches of the two loaders can differ in size.
            current_batch_size = min(tab_batch.size(0), time_batch.size(0))
            tab_batch = tab_batch[:current_batch_size]
            time_batch = time_batch[:current_batch_size]

            real_labels = torch.ones(current_batch_size, 1)
            fake_labels = torch.zeros(current_batch_size, 1)

            mu, sigma = transformer(time_batch.unsqueeze(1))
            cond = mu + sigma * torch.randn_like(sigma)
            cond_fixed = cond.detach()  # no encoder gradient in the D step

            # ---- discriminator step ----
            noise = torch.randn(current_batch_size, noise_dim)
            fake_data = generator(noise, cond_fixed).detach()

            optim_d.zero_grad()
            real_loss = bce(discriminator(tab_batch, cond_fixed), real_labels)
            fake_loss = bce(discriminator(fake_data, cond_fixed), fake_labels)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optim_d.step()

            # ---- generator + encoder step ----
            optim_g.zero_grad()
            optim_t.zero_grad()

            noise = torch.randn(current_batch_size, noise_dim)
            fake_data = generator(noise, cond)
            g_loss_adv = bce(discriminator(fake_data, cond), real_labels)

            # NOTE(review): compute_mmd works on detached NumPy arrays, so this
            # term shifts the reported loss but contributes no gradient.
            real_np = tab_batch.detach().numpy()
            fake_np = fake_data.detach().numpy()
            mmd_loss = 0
            for i in range(feature_dim):
                feature_mmd = compute_mmd(real_np[:, i:i + 1], fake_np[:, i:i + 1])
                mmd_loss += feature_mmd
                feature_mmd_losses[i].append(feature_mmd)

            # NOTE(review): compares generated rows with the first feature_dim
            # columns of random time rows; assumes cond_dim >= feature_dim.
            sampled_time = time_tensor[torch.randint(0, len(time_tensor), (current_batch_size,))]
            l2_loss = mse(fake_data, sampled_time[:, :feature_dim])
            g_loss = g_loss_adv + mmd_loss + l2_loss

            g_loss.backward()
            optim_g.step()
            optim_t.step()

        # Losses stay None only when the loaders produced no batch at all.
        if d_loss is not None and g_loss is not None:
            print(f"Epoch {epoch+1}/{num_epochs}, D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

    return generator, transformer, feature_mmd_losses



def generate_and_build_graph(generator, transformer, time_df, tab_dfs, num_samples=1000,
                             feature_mmd_losses=None):
    """Sample synthetic rows and build a weighted feature-correlation graph.

    Args:
        generator: Trained generator taking ``(noise, cond)``.
        transformer: Trained encoder returning ``(mu, sigma)`` per time row.
        time_df: Numeric time-series DataFrame used to draw conditions.
        tab_dfs: DataFrames whose column correlations define the graph edges.
        num_samples: Number of synthetic rows to generate.
        feature_mmd_losses: Mapping ``feature index -> list of MMD values`` as
            returned by ``train_cgan``. When omitted, a module-level variable
            of the same name is used, which is what the previous version
            relied on implicitly.

    Returns:
        ``(new_data, merged_graph)``: the generated samples as a NumPy array,
        and a networkx graph whose nodes are column names carrying a
        ``'weight'`` attribute in [0, 1] (1 = lowest average MMD).

    Raises:
        ValueError: If no MMD history is available.

    Fixes relative to the previous version:
      * The encoder's second output is used as a standard deviation (it is
        already exponentiated by the encoder), not as a log-variance.
      * Time rows are passed as length-1 sequences so that the encoder yields
        one condition per row.
      * Node names without a numeric ``_<index>`` suffix no longer raise; the
        feature index then falls back to the column's position in the first
        table that contains it.
    """
    if feature_mmd_losses is None:
        feature_mmd_losses = globals().get('feature_mmd_losses')
    if not feature_mmd_losses:
        raise ValueError("feature_mmd_losses is required to weight the graph nodes")

    with torch.no_grad():
        noise = torch.randn(num_samples, 64)  # 64 = noise_dim used by train_cgan
        mu, sigma = transformer(torch.FloatTensor(time_df.values).unsqueeze(1))
        cond = mu + sigma * torch.randn_like(sigma)
        cond = cond[torch.randint(0, len(cond), (num_samples,))]
        new_data = generator(noise, cond).detach().numpy()

    merged_graph = nx.Graph()
    column_position = {}
    threshold = 0.5
    for df in tab_dfs:

        for position, name in enumerate(df.columns):
            column_position.setdefault(name, position)

        # Edge between two columns when |correlation| exceeds the threshold.
        adj_matrix = (df.corr().abs() > threshold).astype(int)
        merged_graph = nx.compose(merged_graph, nx.from_pandas_adjacency(adj_matrix))

    avg_mmd_losses = {i: np.mean(losses) for i, losses in feature_mmd_losses.items()}
    min_mmd = min(avg_mmd_losses.values())
    max_mmd = max(avg_mmd_losses.values())

    for node in merged_graph.nodes():

        suffix = str(node).rsplit('_', 1)[-1]
        if suffix.isdigit():
            feature_idx = int(suffix)
        else:
            feature_idx = column_position.get(node)

        # Features without a recorded MMD get the worst (max) value -> weight ~0.
        mmd = avg_mmd_losses.get(feature_idx, max_mmd)
        weight = 1 - (mmd - min_mmd) / (max_mmd - min_mmd + 1e-8)
        merged_graph.nodes[node]['weight'] = weight

    return new_data, merged_graph


def target_minmax_scale(source_series, target_series):
    """Linearly map ``source_series`` onto the min/max range of ``target_series``.

    If either series has zero range, ``source_series`` is returned as is.
    """
    target_min, target_max = target_series.min(), target_series.max()
    source_min, source_max = source_series.min(), source_series.max()

    target_span = target_max - target_min
    source_span = source_max - source_min

    if source_span != 0 and target_span != 0:
        # Normalise to [0, 1] first, then stretch/shift into the target range.
        unit_scaled = (source_series - source_min) / source_span
        return unit_scaled * target_span + target_min

    return source_series


def auto_encode_features(df, skip_columns=None, max_unique_for_label=20):
    """Return a copy of ``df`` with boolean and text columns encoded numerically.

    * Boolean columns become 0/1 integers.
    * Text columns (object dtype, or whose first value is a ``str``) with 2 to
      ``max_unique_for_label`` distinct values are replaced by category codes.
    * Text columns with more distinct values are replaced by one-hot columns
      appended at the end of the frame (prefix = original column name).
    * Text columns with at most one distinct value, and all other columns,
      are left untouched.

    Args:
        df: Input frame; it is not modified.
        skip_columns: Column names to leave as they are.
        max_unique_for_label: Largest cardinality that is still label-encoded.

    Fix: an empty frame no longer raises IndexError when the first value of a
    non-object column is inspected.
    """
    df = df.copy()
    skip_columns = skip_columns or []

    # Iterates over the original column index; one-hot columns added below
    # are therefore not revisited.
    for col in df.columns:
        if col in skip_columns:
            continue

        column = df[col]
        if column.dtype == 'bool':
            df[col] = column.astype(int)
            continue

        holds_text = column.dtype == 'object' or (
            len(column) > 0 and isinstance(column.iloc[0], str)
        )
        if not holds_text:
            continue

        num_unique = column.nunique()
        if 1 < num_unique <= max_unique_for_label:
            df[col] = column.astype('category').cat.codes
        elif num_unique > max_unique_for_label:
            dummies = pd.get_dummies(column, prefix=col)
            df = pd.concat([df.drop(columns=[col]), dummies], axis=1)

    return df


class LearnableGraph(nn.Module):
    """Holds a trainable dense adjacency matrix of size ``num_nodes x num_nodes``.

    ``hidden_dim`` is accepted but not used by this module.
    """

    def __init__(self, num_nodes, hidden_dim):

        super().__init__()
        self.num_nodes = num_nodes
        initial = torch.randn(num_nodes, num_nodes)
        self.learnable_adj = nn.Parameter(initial)

    def forward(self, x):
        """Return the non-negative, row-L1-normalised adjacency.

        The argument ``x`` is ignored; the result depends only on the
        parameter.
        """
        non_negative = torch.relu(self.learnable_adj)
        return nn.functional.normalize(non_negative, p=1, dim=1)

class SelfAttention(nn.Module):
    """Attention pooling: collapse dim 1 of ``(batch, steps, hidden_dim)``.

    A single linear layer scores each step; the scores are softmax-normalised
    over the step axis and used as weights for a weighted sum of the steps.
    """

    def __init__(self, hidden_dim):

        super().__init__()
        self.att = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return the weighted sum over dim 1, shape ``(batch, hidden_dim)``."""
        # Fix: the original called F.softmax, but no name F is imported in
        # this file, so the call raised NameError.
        weights = torch.softmax(self.att(x).squeeze(-1), dim=1)
        # (batch, 1, steps) @ (batch, steps, hidden) -> (batch, 1, hidden)
        weighted = torch.bmm(weights.unsqueeze(1), x).squeeze(1)
        return weighted


# class GraphTimeModel(nn.Module):
#     def __init__(self, num_features, num_classes, num_nodes=20,
#                  hidden_dim=64, heads=4, static_edge_index=None):
#         super().__init__()
#         self.static_edge_index = static_edge_index

#         self.graph_learner = LearnableGraph(num_nodes, hidden_dim)

#         self.gat1 = GATConv(num_features, hidden_dim, heads=heads)
#         self.gat2 = GATConv(hidden_dim * heads, hidden_dim, heads=1)

#         self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)

#         self.classifier = nn.Linear(hidden_dim, num_classes)

#     def forward(self, x_seq, edge_index_static):
#         B, T, N, F = x_seq.shape
#         device = x_seq.device

#         outputs = []

#         for t in range(T):
#             x_t = x_seq[:, t, :, :]
#             adj = self.graph_learner.learnable_adj
#             edge_index, edge_weight = dense_to_sparse(adj)

#             edge_index_combined = torch.cat([edge_index_static.to(device), edge_index.to(device)], dim=1)
#             edge_weight_combined = None

#             h = F.relu(self.gat1(x_t, edge_index_combined, edge_weight=edge_weight_combined))
#             h = self.gat2(h, edge_index_combined)

#             outputs.append(h.mean(dim=1))
#         h_seq = torch.stack(outputs, dim=1)  # shape: (B, T, H)
#         out, _ = self.gru(h_seq)
#         logits = self.classifier(out[:, -1, :])

#         return logits


class GraphTimeModel(nn.Module):
    """Graph convolution per time step followed by a temporal transformer.

    For every time step the node features are passed through four graph
    convolutions over the union of a learned adjacency and a fixed ("static")
    edge set, averaged over nodes, and the resulting per-step vectors are fed
    to a transformer encoder. The last step's output is classified.

    Args:
        num_features: Feature width per node.
        num_classes: Number of output classes.
        num_nodes: Size of the learned adjacency; must equal the node axis of
            the input.
        hidden_dim: Width of the graph and transformer layers (must be
            divisible by 4, the transformer's head count).
        heads: Accepted for interface compatibility; not used.
        static_edge_index: Optional default for ``forward``'s
            ``edge_index_static``.

    Fixes relative to the previous version (which could not run):
      * ``edge_weight_dynamic`` / ``edge_weight_static`` were never defined.
      * The layers were ``GATConv`` but were called with an ``edge_weight``
        keyword on batched 3-D input; ``GCNConv`` (already imported, and
        matching the ``gcn*`` attribute names) is used instead.
        NOTE(review): assumes the installed PyG ``GCNConv`` accepts
        ``(batch, nodes, features)`` input with a shared ``edge_index`` -
        confirm against the PyG GNN cheatsheet ("static" column).
      * The transformer output was unpacked as a 2-tuple (a GRU leftover).
      * The transformer lacked ``batch_first=True`` although its input is
        stacked as ``(batch, time, hidden)``.
      * The raw parameter (with negative entries) was used as the adjacency;
        the graph learner's normalised output is used now, computed once per
        forward pass instead of once per time step.
    """

    def __init__(self, num_features, num_classes, num_nodes=13,
                 hidden_dim=64, heads=4, static_edge_index=None):

        super().__init__()

        self.static_edge_index = static_edge_index
        self.graph_learner = LearnableGraph(num_nodes, hidden_dim)

        self.gcn1 = GCNConv(num_features, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)
        self.gcn3 = GCNConv(hidden_dim, hidden_dim)
        self.gcn4 = GCNConv(hidden_dim, hidden_dim)

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=4, batch_first=True),
            num_layers=2,
        )

        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x_seq, edge_index_static=None):
        """Classify a batch of graph sequences.

        Args:
            x_seq: Tensor of shape ``(batch, time, nodes, features)``.
            edge_index_static: ``(2, E)`` edge index of the fixed graph;
                defaults to the one given at construction.

        Returns:
            Logits of shape ``(batch, num_classes)``.

        Raises:
            ValueError: If no static edge index is available or an edge refers
                to a node index outside the node axis of ``x_seq``.
        """
        if edge_index_static is None:
            edge_index_static = self.static_edge_index
        if edge_index_static is None:
            raise ValueError("edge_index_static is required")

        num_steps, num_nodes = x_seq.size(1), x_seq.size(2)
        device = x_seq.device

        # The learned graph does not depend on the time step: build it once.
        adj = self.graph_learner(x_seq)
        edge_index_dynamic, edge_weight_dynamic = dense_to_sparse(adj)

        edge_index_static = edge_index_static.to(device)
        # The static graph carries no weights of its own; use 1 per edge.
        edge_weight_static = torch.ones(
            edge_index_static.size(1), device=device, dtype=edge_weight_dynamic.dtype
        )

        edge_index_combined = torch.cat([edge_index_dynamic, edge_index_static], dim=1)
        if edge_index_combined.numel() > 0 and int(edge_index_combined.max()) >= num_nodes:
            raise ValueError("Edge index contains invalid node indices")

        alpha = 0.5  # mixing weight between learned and static edges
        edge_weight_combined = torch.cat([
            alpha * edge_weight_dynamic,
            (1 - alpha) * edge_weight_static,
        ])

        outputs = []
        for t in range(num_steps):

            h = x_seq[:, t, :, :]
            h = torch.relu(self.gcn1(h, edge_index_combined, edge_weight=edge_weight_combined))
            h = self.gcn2(h, edge_index_combined, edge_weight=edge_weight_combined)
            h = self.gcn3(h, edge_index_combined, edge_weight=edge_weight_combined)
            h = self.gcn4(h, edge_index_combined, edge_weight=edge_weight_combined)

            outputs.append(h.mean(dim=1))  # average over nodes -> (batch, hidden)

        h_seq = torch.stack(outputs, dim=1)  # (batch, time, hidden)
        out = self.transformer(h_seq)
        logits = self.classifier(out[:, -1, :])
        return logits


def align_and_standardize_units(df_time, df_tab_aligned, similarity_threshold=0.6):
    """Rename tabular columns to their time-series counterparts and rescale them.

    For every column of ``df_tab_aligned``:

    1. If a time-series column of the same name exists it is the target;
       otherwise the most similar time-series column name is looked up and the
       column is renamed to it (skipped when the similarity is below
       ``similarity_threshold``).
    2. Non-numeric columns are reported and left unchanged.
    3. Numeric columns are brought onto the time-series scale:
       * both names carry a unit suffix: explicit conversion when the units
         are convertible, otherwise min-max scaling to the target's range;
       * only the tabular name carries a unit: left unchanged;
       * the tabular name carries no unit: a regression-based proportionality
         check decides between multiplying by the fitted factor and min-max
         scaling to the target's range.

    Args:
        df_time: Time-series frame supplying target names and value ranges.
        df_tab_aligned: Tabular frame to process; it is not modified.
        similarity_threshold: Minimum name similarity required for a rename.

    Returns:
        ``(processed_df, report)`` where ``report`` is a list of dicts
        describing what was done. Rename entries use the keys
        ``original_column`` / ``new_column`` / ``similarity``; scaling entries
        use ``column``.

    Fixes relative to the previous version:
      * A column is no longer renamed onto a name already present in the
        frame. That produced duplicate column names and made the following
        ``.dtype`` access fail.
      * Explicit unit conversion multiplies the whole column at once instead
        of calling ``convert_units`` per element (same values).
    """
    report = []
    processed_df = df_tab_aligned.copy()
    time_cols = df_time.columns.tolist()

    common_cols = set(processed_df.columns) & set(time_cols)

    # The helpers read ``processed_df`` at call time, so they always see the
    # frame produced by the most recent rename.
    def scale_to_time_range(column, action):
        """Min-max scale ``column`` into the range of the same time-series column."""
        processed_df[column] = target_minmax_scale(processed_df[column], df_time[column])
        report.append({'column': column, 'action': action, 'conversion': None})

    def multiply_by_factor(column, factor, action):
        """Multiply ``column`` by a statistically estimated factor."""
        processed_df[column] = processed_df[column] * factor
        report.append({
            'column': column,
            'action': action,
            'conversion': f'factor: {factor:.4f}'
        })

    for tab_col in df_tab_aligned.columns:

        if tab_col in common_cols:
            time_col = tab_col
        else:
            time_col, similarity = find_most_similar_time_col(tab_col, time_cols)
            if similarity < similarity_threshold:
                report.append({
                    'original_column': tab_col,
                    'new_column': tab_col,
                    'similarity': similarity,
                    'action': 'No sufficiently similar time series column found',
                    'conversion': None
                })
                continue

            if time_col in processed_df.columns:
                # Another column already occupies the target name.
                report.append({
                    'original_column': tab_col,
                    'new_column': tab_col,
                    'similarity': similarity,
                    'action': 'Target time series column already present - column left unchanged',
                    'conversion': None
                })
                continue

            processed_df = processed_df.rename(columns={tab_col: time_col})
            report.append({
                'original_column': tab_col,
                'new_column': time_col,
                'similarity': similarity,
                'action': 'Column renamed based on semantic similarity',
                'conversion': None
            })

        if not np.issubdtype(processed_df[time_col].dtype, np.number):
            report.append({
                'column': time_col,
                'action': 'Non-numeric column - no unit conversion',
                'conversion': None
            })
            continue

        time_unit = extract_unit_from_column(time_col)
        tab_unit = extract_unit_from_column(tab_col)

        if time_unit and tab_unit:

            if are_units_convertible(time_unit, tab_unit):
                # convert_units is a plain multiplication, so it can be
                # applied to the whole Series in one call.
                processed_df[time_col] = convert_units(
                    processed_df[time_col], tab_unit, time_unit
                )
                report.append({
                    'column': time_col,
                    'action': 'Converted using explicit units',
                    'conversion': f'{tab_unit} → {time_unit}',
                    'factor': convert_units(1.0, tab_unit, time_unit)
                })
            else:
                scale_to_time_range(
                    time_col, 'Incompatible units - scaled to time series range'
                )

        elif tab_unit:
            # Only the tabular name carries a unit suffix.
            report.append({
                'column': time_col,
                'action': 'Kept time series unitless format',
                'conversion': None
            })

        else:
            # The tabular name carries no unit: estimate the relation from data.
            is_proportional, factor = statistical_unit_check(
                processed_df[time_col], df_time[time_col]
            )

            if time_unit:
                if is_proportional:
                    multiply_by_factor(time_col, factor, 'Converted using statistical scaling')
                else:
                    scale_to_time_range(time_col, 'Scaled to time series range')
            elif is_proportional and abs(factor - 1.0) > 0.01:
                multiply_by_factor(time_col, factor, 'Adjusted by statistical scaling')
            else:
                scale_to_time_range(time_col, 'Scaled to time series range')

    return processed_df, report


def preprocess_column_name(col_name):
    """Normalise a column name: lower-case, underscores to spaces, trimmed."""
    lowered = str(col_name).lower()
    spaced = lowered.translate(str.maketrans('_', ' '))
    return spaced.strip()

def detect_antonym_relation(col1, col2):
    """Ask the NLI ``classifier`` whether two column names contradict each other.

    Both names are normalised with ``preprocess_column_name`` and inserted
    into the sentence pair "The feature is <name>".

    Returns:
        ``(is_antonym, score)``: ``is_antonym`` is True when the contradiction
        score exceeds 0.7; ``(False, 0.0)`` when the classifier output has no
        contradiction label.

    Fix: the label is compared case-insensitively. MNLI checkpoints may
    report upper-case labels (e.g. ``"CONTRADICTION"``); with the exact
    lower-case comparison such output never matched, so the function always
    returned ``(False, 0.0)``.
    """
    premise = preprocess_column_name(col1)
    hypothesis = preprocess_column_name(col2)

    result = classifier({
        'text': f"The feature is {premise}",
        'text_pair': f"The feature is {hypothesis}",
    }, top_k=3)

    for item in result:
        if str(item['label']).lower() == 'contradiction':
            score = item['score']
            return (score > 0.7), score

    return False, 0.0


def is_boolean_column(series):
    """Return True when the non-null values of ``series`` are all 0 or 1.

    A series without any non-null value is not considered boolean.
    """
    distinct = set(series.dropna().unique())
    if not distinct:
        return False
    return distinct <= {0, 1}

def align_column_names(df_time, df_tab_relevant, threshold=0.5):
    """Match tabular columns to time-series columns and flip antonym booleans.

    Every tabular column is paired with the time-series column whose
    normalised name has the highest embedding cosine similarity. If the pair
    is judged contradictory and the tabular column is boolean (0/1), its
    values are inverted and stored under the time-series column's name. In
    every other case the column is copied under its own name.

    ``threshold`` is accepted but not used.

    Returns:
        ``(aligned_df, report)`` where ``report`` has the lists
        ``'column_mappings'`` and ``'antonym_processed'``.
    """
    time_cols = df_time.columns.tolist()
    tab_cols = df_tab_relevant.columns.tolist()

    time_texts = [preprocess_column_name(name) for name in time_cols]
    tab_texts = [preprocess_column_name(name) for name in tab_cols]

    time_embeddings = model.encode(time_texts)
    tab_embeddings = model.encode(tab_texts)
    # Rows: tabular columns, columns: time-series columns.
    similarity_matrix = cosine_similarity(tab_embeddings, time_embeddings)

    mappings = []
    antonyms = []
    aligned_df = pd.DataFrame(index=df_tab_relevant.index)

    for scores, tab_col in zip(similarity_matrix, tab_cols):

        best_idx = np.argmax(scores)
        best_match_col = time_cols[best_idx]
        similarity_score = scores[best_idx]
        is_antonym, antonym_confidence = detect_antonym_relation(tab_col, best_match_col)

        if not is_antonym:
            aligned_df[tab_col] = df_tab_relevant[tab_col]
            mappings.append({
                'original_column': tab_col,
                'new_column': tab_col,
                'similarity': similarity_score,
                'action': 'Aligned by semantic similarity'
            })
            continue

        if is_boolean_column(df_tab_relevant[tab_col]):
            # Opposite meaning: swap 0 and 1 and adopt the time-series name.
            aligned_df[best_match_col] = df_tab_relevant[tab_col].apply(
                lambda x: 1 if x == 0 else 0
            )
            new_name = best_match_col
            action = 'Boolean values inverted'
        else:
            aligned_df[tab_col] = df_tab_relevant[tab_col]
            new_name = tab_col
            action = 'Non-boolean column, only renamed'

        antonyms.append({
            'original_column': tab_col,
            'new_column': new_name,
            'similarity': similarity_score,
            'antonym_confidence': antonym_confidence,
            'action': action
        })

    report = {
        'column_mappings': mappings,
        'antonym_processed': antonyms
    }
    return aligned_df, report


def prepare_data(df_time):
    """Turn the long-format frame into stacked per-group arrays and labels.

    Rows are grouped by ``'Date'``. Within a group the rows are ordered by
    ascending ``'timestep'`` value, the ``'Date'``, ``'timestep'`` and
    ``'Target'`` columns are dropped, and the remaining values form the
    group's sample. The label is the ``'Target'`` of the group's last row.

    Returns:
        ``(X, y)``: ``X`` stacks the per-group samples (all groups must yield
        the same shape), ``y`` holds one label per group.
    """
    meta_cols = ['Date', 'timestep', 'Target']
    samples, labels = [], []

    for _, group in df_time.groupby('Date'):
        ordered_steps = sorted(group['timestep'].unique())
        step_blocks = [
            group[group['timestep'] == step].drop(columns=meta_cols).values
            for step in ordered_steps
        ]
        samples.append(np.concatenate(step_blocks, axis=0))
        labels.append(group['Target'].iloc[-1])

    return np.stack(samples), np.array(labels)

def preprocess_column_names(columns):
    """Normalise a list of column names.

    Each name is lower-cased, ``_`` and ``-`` become spaces, and runs of
    whitespace are collapsed to a single space (leading/trailing removed).
    """
    separators = str.maketrans('_-', '  ')
    return [
        ' '.join(str(name).lower().translate(separators).split())
        for name in columns
    ]

def find_most_similar_columns(time_series_cols, table_cols, threshold=0.5):
    """Find, for each time-series column, its closest table column by name.

    Names are normalised, embedded with the module-level ``model`` and
    compared by cosine similarity.

    Returns:
        Dict ``time_col -> (table_col, similarity)`` containing only the
        time-series columns whose best similarity is strictly above
        ``threshold``.
    """
    time_texts = preprocess_column_names(time_series_cols)
    table_texts = preprocess_column_names(table_cols)

    time_embeddings = model.encode(time_texts)
    table_embeddings = model.encode(table_texts)
    # Rows: time-series columns, columns: table columns.
    similarity_matrix = cosine_similarity(time_embeddings, table_embeddings)

    results = {}
    for time_col, row in zip(time_series_cols, similarity_matrix):
        best = np.argmax(row)
        best_score = row[best]
        if best_score > threshold:
            results[time_col] = (table_cols[best], best_score)

    return results


def get_most_relevant_features(df_time, df_tables, threshold=0.5):
    """Run ``find_most_similar_columns`` for every table.

    Args:
        df_time: Time-series frame whose column names are the queries.
        df_tables: Either a dict ``name -> DataFrame`` or a sequence of
            DataFrames (then named ``table_1``, ``table_2``, ...).
        threshold: Passed through to ``find_most_similar_columns``.

    Returns:
        Dict ``table name -> matches`` with the per-table match dicts.
    """
    time_cols = df_time.columns.tolist()

    if isinstance(df_tables, dict):
        names = list(df_tables.keys())
        frames = list(df_tables.values())
    else:
        frames = df_tables
        names = [f'table_{i+1}' for i in range(len(frames))]

    return {
        name: find_most_similar_columns(time_cols, frame.columns.tolist(), threshold)
        for frame, name in zip(frames, names)
    }


def convert_to_timesteps(df, time_col='Time', patient_id_col='Date'):
    """Replace a timestamp column by a zero-based step index per group.

    ``time_col`` is parsed with ``pd.to_datetime``; within each
    ``patient_id_col`` group, rows are ranked by time (ties broken by order
    of appearance) and the rank minus one is stored in a new ``'timestep'``
    column. ``time_col`` is removed from the result. The input frame is not
    modified.
    """
    frame = df.copy()
    frame[time_col] = pd.to_datetime(frame[time_col])

    rank_in_group = frame.groupby(patient_id_col)[time_col].rank(method='first')
    frame['timestep'] = rank_in_group.astype(int) - 1

    return frame.drop(columns=[time_col])


def select_relevant_columns(df_tab, matches):
    """Select the table columns named in ``matches``.

    Args:
        df_tab: Table to select from.
        matches: Dict whose values are ``(table_column, similarity)`` tuples,
            as produced by ``find_most_similar_columns``.

    Returns:
        A frame with the matched columns in order of first appearance.

    Fix: a table column matched by several time-series columns is selected
    only once. Previously it was repeated, giving duplicate column names.
    """
    # dict.fromkeys removes duplicates while keeping first-seen order.
    relevant_cols = list(dict.fromkeys(match[0] for match in matches.values()))
    return df_tab[relevant_cols]


if __name__ == "__main__":
    # End-to-end script: load data, align the tables with the time series,
    # train the conditional GAN, build the feature graph, then train and
    # evaluate the graph/time classifier.
    #
    # Fixes relative to the previous version:
    #   * The alignment steps now run BEFORE preprocess_data/train_cgan; the
    #     df_tab*_final frames were used before they were defined.
    #   * preprocess_data returns one flat list, so it is unpacked into four
    #     names instead of two nested lists.
    #   * train_test_split, from_networkx and F were used without an import.
    #   * The undefined ``lengths`` argument of train_test_split is dropped
    #     (its outputs were never used).
    #   * The classifier gets its own name instead of rebinding ``model``,
    #     which the helper functions use as the sentence-embedding model.
    #   * num_nodes is taken from the data instead of a hard-coded 13.

    # Names needed only by this script; the file-level import block lacks them.
    import torch.nn.functional as F
    from sklearn.model_selection import train_test_split
    from torch_geometric.utils import from_networkx

    df_time = pd.read_csv("df_time.csv")

    df_tab1 = pd.read_csv("df_tab1.csv")
    df_tab2 = pd.read_csv("df_tab2.csv")
    df_tab3 = pd.read_csv("df_tab3.csv")

    df_time = df_time.dropna()
    df_time = convert_to_timesteps(df_time)

    # The helper functions look these two names up as module globals.
    model = SentenceTransformer('all-mpnet-base-v2')
    classifier = pipeline("text-classification", model="roberta-large-mnli")

    tables = {
        'table_1': df_tab1,
        'table_2': df_tab2,
        'table_3': df_tab3
    }

    # ---- align tabular columns with the time series ----
    relevant_features = get_most_relevant_features(df_time, tables)

    df_tab1_relevant = select_relevant_columns(df_tab1, relevant_features['table_1'])
    df_tab2_relevant = select_relevant_columns(df_tab2, relevant_features['table_2'])
    df_tab3_relevant = select_relevant_columns(df_tab3, relevant_features['table_3'])

    df_tab1_aligned, report1 = align_column_names(df_time, df_tab1_relevant)
    df_tab2_aligned, report2 = align_column_names(df_time, df_tab2_relevant)
    df_tab3_aligned, report3 = align_column_names(df_time, df_tab3_relevant)

    df_tab1_final, report1 = align_and_standardize_units(df_time, df_tab1_aligned)
    df_tab2_final, report2 = align_and_standardize_units(df_time, df_tab2_aligned)
    df_tab3_final, report3 = align_and_standardize_units(df_time, df_tab3_aligned)

    # ---- scale, train the cGAN, build the graph ----
    # NOTE(review): preprocess_data feeds every column to MinMaxScaler, so all
    # frames (including the 'Date' column of df_time) must be numeric here,
    # and train_cgan needs the three tables to have equal column counts.
    df_tab1, df_tab2, df_tab3, df_time_scaled = preprocess_data(
        df_tab1_final, df_tab2_final, df_tab3_final, df_time
    )

    # generate_and_build_graph reads feature_mmd_losses as a module global.
    generator, transformer, feature_mmd_losses = train_cgan(
        [df_tab1, df_tab2, df_tab3], df_time_scaled
    )

    new_data, G_unified = generate_and_build_graph(
        generator, transformer, df_time_scaled, [df_tab1, df_tab2, df_tab3]
    )

    # ---- build the classification dataset ----
    df_time = auto_encode_features(df_time, skip_columns=['Date'])

    # Truncate every group to the shortest one so samples stack to one shape.
    min_timesteps = df_time.groupby('Date')['timestep'].max().min()
    df_time = df_time[df_time['timestep'] <= min_timesteps]

    X, y = prepare_data(df_time)

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=torch.long)
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
    y_test_tensor = torch.tensor(y_test, dtype=torch.long)

    # Re-map the labels of each split to consecutive integers.
    _, y_train_tensor = torch.unique(y_train_tensor, return_inverse=True)
    _, y_test_tensor = torch.unique(y_test_tensor, return_inverse=True)

    # NOTE(review): the graph nodes are tabular column names; verify that the
    # resulting edge indices line up with the feature axis of X.
    data_static = from_networkx(G_unified)
    edge_index_static = data_static.edge_index

    # Each feature column becomes one graph node carrying a single value
    # (see the unsqueeze(-1) below), hence num_features=1.
    graph_model = GraphTimeModel(
        num_features=1,
        num_classes=len(np.unique(y)),
        num_nodes=X.shape[2],
        hidden_dim=32,
        heads=2,
        static_edge_index=edge_index_static)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(graph_model.parameters(), lr=0.01)

    # ---- full-batch training ----
    for epoch in range(2000):

        graph_model.train()
        optimizer.zero_grad()

        x_input = X_train_tensor.unsqueeze(-1)
        output = graph_model(x_input, edge_index_static)
        loss = criterion(output, y_train_tensor)

        loss.backward()
        optimizer.step()
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    # ---- evaluation ----
    graph_model.eval()

    with torch.no_grad():

        x_input = X_test_tensor.unsqueeze(-1)
        logits = graph_model(x_input, edge_index_static)

        probs = F.softmax(logits, dim=1)
        pred = logits.argmax(dim=1)
        y_true = y_test_tensor.numpy()
        y_pred = pred.numpy()
        y_probs = probs.numpy()

        f1 = f1_score(y_true, y_pred, average='weighted')
        print(f"F1 Score: {f1:.4f}")

        if len(np.unique(y_true)) == 2:
            # Binary case: score with the probability of the positive class.
            auroc = roc_auc_score(y_true, y_probs[:, 1], average='micro')
            auprc = average_precision_score(y_true, y_probs[:, 1], average='micro')
        else:
            auroc = roc_auc_score(y_true, y_probs, multi_class='ovr', average='micro')
            auprc = average_precision_score(y_true, y_probs, average='micro')

        print(f"AUROC: {auroc:.4f}")
        print(f"AUPRC: {auprc:.4f}")