In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import mean_squared_error, r2_score

# -------------------------
# Dataset (reuse same as above)
# -------------------------
# Columns that are never used as model inputs: the script's main section
# builds its numerical feature list from every DataFrame column NOT listed
# here. Entries are grouped by name prefix; the Digester_/Source_/Type_/
# Waste_/Biomass_ groups look like one-hot indicator columns (TODO confirm
# against the CSV schema).
target_columns = [
    # Cell count and kinetic parameters (the regression target is picked from these)
    "Average-Total-ISD-Cells",
    "ACE-Xi", "ACE-km", "ACE-Ks",
    "H2-Xi", "H2-km", "H2-Ks",
    # Digester_*
    "Digester_BD", "Digester_BF", "Digester_CB", "Digester_CP",
    "Digester_FD", "Digester_GB", "Digester_GP", "Digester_JB",
    "Digester_LP", "Digester_MA", "Digester_NB", "Digester_NS",
    "Digester_PC", "Digester_PO", "Digester_SF", "Digester_SS",
    "Digester_SW", "Digester_WA", "Digester_WP", "Digester_WR",
    # Source_*
    "Source_I", "Source_M", "Source_P",
    # Type_*
    "Type_CSTR", "Type_EFB", "Type_EGSB", "Type_Lagoon", "Type_UASB",
    # Waste_*
    "Waste_BW", "Waste_Dairy", "Waste_FW", "Waste_HSI", "Waste_MPW",
    "Waste_MS", "Waste_MS+Dairy", "Waste_MS+HSI", "Waste_PP", "Waste_PR",
    "Waste_SDW",
    # Biomass_*
    "Biomass_F", "Biomass_G",
]

class TabularDataset(Dataset):
    """Expose a DataFrame as a torch ``Dataset`` of tabular samples.

    Preparation steps, in order:

    1. Rows containing a NaN in *any* column of ``df`` are dropped.
    2. Every column that is not listed in ``cat_cols`` is coerced with
       ``pd.to_numeric(errors='coerce')``; unparsable values become 0.
       Categorical columns are deliberately left untouched so their original
       labels reach the encoder (coercing them would turn every string label
       into 0 and collapse all categories into one).
    3. Numerical columns are standardised, categorical columns are
       label-encoded, and the target is stored as ``float32``.

    Args:
        df: Source DataFrame; it is copied, never modified.
        cat_cols: Names of categorical columns (may be empty).
        num_cols: Names of numerical feature columns.
        target_col: Name of the regression target column.
        scaler: Already-fitted scaler exposing ``transform``. When ``None`` a
            new ``StandardScaler`` is fitted on this dataset's rows.
        cat_encoders: Optional mapping ``column -> fitted encoder`` exposing
            ``transform``. Columns missing from the mapping get a new
            ``LabelEncoder`` fitted on this dataset's rows.

    Attributes set on the instance: ``df``, ``scaler``, ``cat_encoders``,
    ``num_data`` (scaled numerical features), ``cat_data`` (int64 codes, or
    ``None`` when ``cat_cols`` is empty) and ``target``.

    NOTE(review): a pre-fitted scikit-learn ``LabelEncoder`` raises
    ``ValueError`` from ``transform`` on labels it was not fitted on, so a
    validation split containing unseen categories will fail here.
    """

    def __init__(self, df, cat_cols, num_cols, target_col, scaler=None, cat_encoders=None):
        self.df = df.copy().dropna().reset_index(drop=True)

        categorical = set(cat_cols)

        def _coerce_unless_categorical(column):
            # Keep raw category labels; everything else becomes numeric with
            # unparsable entries replaced by 0.
            if column.name in categorical:
                return column
            return pd.to_numeric(column, errors='coerce').fillna(0)

        self.df = self.df.apply(_coerce_unless_categorical)
        self.cat_cols = cat_cols
        self.num_cols = num_cols
        self.target_col = target_col

        # Fit a new scaler only when the caller did not supply one, so that a
        # validation split can reuse the statistics of the training split.
        if scaler is None:
            self.scaler = StandardScaler()
            self.num_data = self.scaler.fit_transform(self.df[num_cols])
        else:
            self.scaler = scaler
            self.num_data = self.scaler.transform(self.df[num_cols])

        self.cat_encoders = {}
        cat_data_list = []
        for col in cat_cols:
            if cat_encoders is None or col not in cat_encoders:
                le = LabelEncoder()
                encoded = le.fit_transform(self.df[col])
            else:
                le = cat_encoders[col]
                encoded = le.transform(self.df[col])
            self.cat_encoders[col] = le
            cat_data_list.append(encoded.astype(np.int64))

        # One column of integer codes per categorical feature, or None.
        if len(cat_cols) > 0:
            self.cat_data = np.stack(cat_data_list, axis=1)
        else:
            self.cat_data = None
        self.target = self.df[target_col].values.astype(np.float32)

    def __len__(self):
        """Return the number of rows left after NaN rows were dropped."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(num, cat, target)`` tensors, or ``(num, target)`` when
        the dataset has no categorical columns."""
        num = torch.tensor(self.num_data[idx], dtype=torch.float32)
        target = torch.tensor(self.target[idx], dtype=torch.float32)
        if self.cat_data is None:
            return (num, target)
        cat = torch.tensor(self.cat_data[idx], dtype=torch.long)
        return (num, cat, target)

# -------------------------
# FT-Transformer Model
# -------------------------
class FTTransformer(nn.Module):
    """Transformer regressor over per-feature tokens.

    Each numerical feature is projected from a scalar to a ``d_token`` vector
    by its own ``nn.Linear(1, d_token)``; each categorical feature is looked
    up in its own ``nn.Embedding``. A learnable [CLS] token is prepended, a
    learnable positional embedding is added, the sequence passes through a
    ``nn.TransformerEncoder``, and the [CLS] output is fed to a small MLP
    that emits one value per sample.
    """

    def __init__(self, num_num_features, cat_cardinalities, d_token=64, nhead=8,
                 num_layers=4, dropout=0.1, mlp_hidden=64):
        """
        num_num_features: Number of numerical features.
        cat_cardinalities: List with the number of unique values for each categorical feature.
        d_token: Width of every token (the transformer's ``d_model``).
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout rate passed to each encoder layer.
        mlp_hidden: Hidden width of the regression head.
        """
        super().__init__()

        # One tokenizer per feature: scalar -> token for numerical features,
        # index -> token for categorical features.
        numeric_projections = [nn.Linear(1, d_token) for _ in range(num_num_features)]
        self.num_tokens = nn.ModuleList(numeric_projections)
        category_embeddings = [nn.Embedding(card, d_token) for card in cat_cardinalities]
        self.cat_tokens = nn.ModuleList(category_embeddings)
        self.total_tokens = len(self.num_tokens) + len(self.cat_tokens)

        # Learnable [CLS] token and positional table covering CLS + all features.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_token))
        self.pos_embedding = nn.Parameter(torch.zeros(1, self.total_tokens + 1, d_token))

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_token, nhead=nhead, dropout=dropout),
            num_layers=num_layers,
        )

        # Regression head applied to the encoded [CLS] token.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(d_token),
            nn.Linear(d_token, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, 1),
        )
        self._init_weights()

    def _init_weights(self):
        """Re-initialise the tokenizers, the [CLS] token and the positional table."""
        for param in (self.pos_embedding, self.cls_token):
            nn.init.trunc_normal_(param, std=0.02)
        for projection in self.num_tokens:
            nn.init.xavier_uniform_(projection.weight)
            if projection.bias is not None:
                nn.init.zeros_(projection.bias)
        for embedding in self.cat_tokens:
            nn.init.uniform_(embedding.weight, -0.1, 0.1)

    def forward(self, num_data, cat_data):
        """Return one prediction per row.

        num_data: float tensor indexed as [batch, numerical feature].
        cat_data: integer tensor indexed as [batch, categorical feature], or
            None to skip the categorical tokens.
        """
        # Each entry is [batch, d_token]; numerical tokens come first.
        feature_tokens = [
            projection(num_data[:, idx:idx + 1])
            for idx, projection in enumerate(self.num_tokens)
        ]
        if cat_data is not None:
            feature_tokens.extend(
                embedding(cat_data[:, idx])
                for idx, embedding in enumerate(self.cat_tokens)
            )
        features = torch.stack(feature_tokens, dim=1)  # [batch, total_tokens, d_token]

        # Prepend [CLS] and add positional embeddings.
        cls = self.cls_token.expand(num_data.size(0), -1, -1)
        sequence = torch.cat([cls, features], dim=1) + self.pos_embedding

        # The encoder is built without batch_first, so it takes (seq_len, batch, d_token).
        hidden = self.transformer(sequence.transpose(0, 1)).transpose(0, 1)

        # Position 0 holds the encoded [CLS] token.
        return self.mlp_head(hidden[:, 0, :]).squeeze(1)

# -------------------------
# Cross-Validation and Training for FT-Transformer
# -------------------------
def _ft_forward_batch(model, batch, device, has_categorical):
    """Move one DataLoader batch to ``device`` and return ``(predictions, targets)``."""
    if has_categorical:
        X_num, X_cat, y = batch
        X_cat = X_cat.to(device)
    else:
        X_num, y = batch
        X_cat = None
    X_num, y = X_num.to(device), y.to(device)
    return model(X_num, X_cat), y


def cross_validate_ft(df, target_column, cat_cols, num_cols, k_folds=5,
                      epochs=50, batch_size=32, learning_rate=0.001):
    """Run shuffled K-fold cross-validation of an FT-Transformer regressor.

    For every fold a fresh ``FTTransformer`` is trained with Adam and MSE
    loss for ``epochs`` epochs on the training split, then scored on the
    validation split. The scaler and label encoders are fitted on the
    training split only and reused for the validation split. Progress and
    per-fold metrics are printed, and a predicted-vs-actual regression plot
    over all folds is shown at the end.

    Args:
        df: DataFrame holding features and the target.
        target_column: Name of the regression target column.
        cat_cols: Categorical feature column names (may be empty).
        num_cols: Numerical feature column names.
        k_folds: Number of folds (split is shuffled with ``random_state=42``).
        epochs: Training epochs per fold.
        batch_size: Mini-batch size for both loaders.
        learning_rate: Adam learning rate.

    Returns:
        dict with keys ``"mse_scores"`` and ``"r2_scores"`` (one entry per
        fold) and ``"actuals"`` / ``"predictions"`` (validation targets and
        predictions pooled over all folds, in fold order). Earlier versions
        returned ``None``; callers that ignore the result are unaffected.
    """
    kf = KFold(n_splits=k_folds, shuffle=True, random_state=42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    has_categorical = len(cat_cols) > 0
    mse_scores, r2_scores = [], []
    all_actuals, all_predictions = [], []

    for fold, (train_idx, val_idx) in enumerate(kf.split(df)):
        print(f"\n[FT-Transformer] Fold {fold+1}/{k_folds}")
        df_train = df.iloc[train_idx].reset_index(drop=True)
        df_val = df.iloc[val_idx].reset_index(drop=True)
        train_ds = TabularDataset(df_train, cat_cols, num_cols, target_column)
        # Reuse the training split's scaler/encoders to avoid leaking
        # validation statistics into preprocessing.
        val_ds = TabularDataset(df_val, cat_cols, num_cols, target_column,
                                scaler=train_ds.scaler, cat_encoders=train_ds.cat_encoders)
        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

        num_num_features = len(num_cols)
        cat_cardinalities = [len(train_ds.cat_encoders[col].classes_) for col in cat_cols]

        model = FTTransformer(num_num_features=num_num_features, cat_cardinalities=cat_cardinalities,
                              d_token=64, nhead=8, num_layers=4, dropout=0.1, mlp_hidden=64).to(device)
        criterion = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        model.train()
        for epoch in range(epochs):
            epoch_loss = 0
            for batch in train_loader:
                optimizer.zero_grad()
                preds, y = _ft_forward_batch(model, batch, device, has_categorical)
                loss = criterion(preds, y)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            # Mean of the per-batch mean losses.
            print(f"Epoch [{epoch+1}/{epochs}] Loss: {epoch_loss/len(train_loader):.4f}")

        model.eval()
        fold_preds, fold_actuals = [], []
        with torch.no_grad():
            for batch in val_loader:
                preds, y = _ft_forward_batch(model, batch, device, has_categorical)
                fold_preds.extend(preds.cpu().numpy())
                fold_actuals.extend(y.cpu().numpy())
        mse = mean_squared_error(fold_actuals, fold_preds)
        r2 = r2_score(fold_actuals, fold_preds)
        mse_scores.append(mse)
        r2_scores.append(r2)
        all_actuals.extend(fold_actuals)
        all_predictions.extend(fold_preds)
        print(f"Fold {fold+1} Results: MSE = {mse:.4f}, R² = {r2:.4f}")

    print(f"\n[FT-Transformer] Final CV: Mean MSE = {np.mean(mse_scores):.4f}, Mean R² = {np.mean(r2_scores):.4f}")
    plt.figure(figsize=(8,6))
    sns.regplot(x=all_actuals, y=all_predictions, scatter_kws={'alpha':0.5}, line_kws={'color':'red'})
    plt.xlabel("Actual Values")
    plt.ylabel("Predicted Values")
    plt.title(f"FT-Transformer: Predicted vs Actual ({target_column})")
    plt.show()

    return {
        "mse_scores": mse_scores,
        "r2_scores": r2_scores,
        "actuals": all_actuals,
        "predictions": all_predictions,
    }

# -------------------------
# Main Execution for FT-Transformer
# -------------------------
if __name__ == "__main__":
    # Script entry point: load the table, pick the target, and cross-validate.
    df = pd.read_csv("Data/New_data.csv")  # Update as needed
    target_column = "ACE-km"
    cat_cols = []  # (Populate if needed)
    # Numerical features are all CSV columns not listed in target_columns,
    # kept in their original column order.
    num_cols = df.columns[~df.columns.isin(target_columns)].tolist()
    cross_validate_ft(
        df=df,
        target_column=target_column,
        cat_cols=cat_cols,
        num_cols=num_cols,
        k_folds=5,
        epochs=50,
        batch_size=32,
        learning_rate=0.001,
    )



[FT-Transformer] Fold 1/5




Epoch [1/50] Loss: 267.6487
Epoch [2/50] Loss: 248.6782
Epoch [3/50] Loss: 240.2321
Epoch [4/50] Loss: 246.3775
Epoch [5/50] Loss: 239.7164
Epoch [6/50] Loss: 216.2830
Epoch [7/50] Loss: 256.2126
Epoch [8/50] Loss: 184.9729
Epoch [9/50] Loss: 206.0536
Epoch [10/50] Loss: 250.0005
Epoch [11/50] Loss: 174.8703
Epoch [12/50] Loss: 179.0890
Epoch [13/50] Loss: 230.0662
Epoch [14/50] Loss: 210.4991
Epoch [15/50] Loss: 152.1523
Epoch [16/50] Loss: 172.1486
Epoch [17/50] Loss: 208.0616
Epoch [18/50] Loss: 143.0674
Epoch [19/50] Loss: 135.8951
Epoch [20/50] Loss: 160.6101
Epoch [21/50] Loss: 183.8126
Epoch [22/50] Loss: 150.4085
Epoch [23/50] Loss: 121.6874
Epoch [24/50] Loss: 118.6678
Epoch [25/50] Loss: 114.3801
Epoch [26/50] Loss: 122.3542
Epoch [27/50] Loss: 106.6539
Epoch [28/50] Loss: 134.2132
Epoch [29/50] Loss: 107.7782
Epoch [30/50] Loss: 119.7869
Epoch [31/50] Loss: 112.5130
Epoch [32/50] Loss: 128.2462
Epoch [33/50] Loss: 122.0447
Epoch [34/50] Loss: 108.3985
Epoch [35/50] Loss: 105



Epoch [1/50] Loss: 233.0992
Epoch [2/50] Loss: 222.8120
Epoch [3/50] Loss: 265.5637
Epoch [4/50] Loss: 200.8480
Epoch [5/50] Loss: 205.1207
Epoch [6/50] Loss: 184.2469
Epoch [7/50] Loss: 183.5707
Epoch [8/50] Loss: 177.9490
Epoch [9/50] Loss: 183.9512
Epoch [10/50] Loss: 222.7711
Epoch [11/50] Loss: 175.5731
Epoch [12/50] Loss: 216.4648
Epoch [13/50] Loss: 156.3698
Epoch [14/50] Loss: 184.6034
Epoch [15/50] Loss: 180.0985
Epoch [16/50] Loss: 155.9326
Epoch [17/50] Loss: 172.7530
Epoch [18/50] Loss: 134.9559
Epoch [19/50] Loss: 120.9826
Epoch [20/50] Loss: 117.1790
Epoch [21/50] Loss: 148.5008
Epoch [22/50] Loss: 141.3353
Epoch [23/50] Loss: 112.7411
