In [None]:
# Definitions: model classes, evaluation metric, data conversion, and reporting/plotting helpers.

# Define Custom Model.
import torch
import torch.nn as nn
from transformers import AutoModel

class CustomModel(nn.Module):
    """Pretrained transformer encoder with a linear head over [CLS] + extra features.

    The hidden state at sequence position 0 of the encoder output is
    concatenated with a per-example vector of additional features, and the
    result is mapped to `num_labels` outputs by a single linear layer.
    """

    def __init__(self, checkpoint, num_labels, additional_feature_dim):
        """Load the encoder and build the head.

        Args:
            checkpoint: Identifier handed to `AutoModel.from_pretrained`.
            num_labels: Output width of the linear head.
            additional_feature_dim: Width of the extra feature vector that is
                concatenated to the encoder output before the head.
        """
        super().__init__()

        # Pretrained encoder; its config is re-exposed on this module.
        self.transformer = AutoModel.from_pretrained(checkpoint)
        self.config = self.transformer.config

        # One linear layer over [encoder hidden state | additional features].
        # NOTE: a deeper head (a second linear layer 256 -> num_labels with
        # ReLU and dropout 0.1) was sketched here previously but is disabled.
        head_in_features = self.config.hidden_size + additional_feature_dim
        self.fc1 = nn.Linear(head_in_features, num_labels)

    def forward(self, input_ids, attention_mask, additional_features):
        """Return the head output for a batch.

        Args:
            input_ids: Forwarded to the encoder as `input_ids`.
            attention_mask: Forwarded to the encoder as `attention_mask`.
            additional_features: Tensor concatenated to the position-0 hidden
                state along dim 1.
                (assumes shape (batch, additional_feature_dim) - TODO confirm)

        Returns:
            The output of the linear head applied to the concatenated input.
        """
        encoded = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # Position 0 of the sequence axis is used as the [CLS] representation.
        first_token_state = encoded.last_hidden_state[:, 0, :]
        head_input = torch.cat((first_token_state, additional_features), dim=1)

        return self.fc1(head_input)

# Define HF Wrapper for Custom Model.
class HuggingFaceModelWrapper(nn.Module):
    """Wrap a logits-only model so that it returns a loss/logits dict.

    The wrapped model is called with `input_ids`, `attention_mask` and
    `additional_features`; its return value is treated as the logits. When
    labels are supplied, a `BCEWithLogitsLoss` is computed against them.
    (Presumably intended for use with `transformers.Trainer` - confirm.)
    """

    def __init__(self, base_model):
        """Store the wrapped model and re-expose its `config` attribute.

        Args:
            base_model: Object callable with keyword arguments `input_ids`,
                `attention_mask` and `additional_features`, and exposing a
                `config` attribute.
        """
        super().__init__()
        self.base_model = base_model         # Custom model.
        self.config     = base_model.config  # Expose the base model's config.

    def forward(self, input_ids, attention_mask, additional_features, labels=None):
        """Run the wrapped model and optionally compute the BCE-with-logits loss.

        Args:
            input_ids: Forwarded to the wrapped model.
            attention_mask: Forwarded to the wrapped model.
            additional_features: Forwarded to the wrapped model.
            labels: Optional targets with the same shape as the logits. Any
                numeric dtype is accepted; they are cast to the logits dtype.

        Returns:
            dict with keys "loss" (a scalar tensor, or None when `labels` is
            None) and "logits" (the wrapped model's output).
        """
        logits = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            additional_features=additional_features,
        )

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss needs floating-point targets matching the
            # logits dtype; integer labels would otherwise raise a dtype error.
            targets = labels.to(dtype=logits.dtype)
            loss = nn.BCEWithLogitsLoss()(logits, targets)

        return {"loss": loss, "logits": logits}

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Delegate to the wrapped model's `prepare_inputs_for_generation`.

        Raises:
            AttributeError: If the wrapped model does not implement
                `prepare_inputs_for_generation`.
        """
        return self.base_model.prepare_inputs_for_generation(*args, **kwargs)

# Spearman's Corr.
from scipy.stats import spearmanr

def compute_metrics(eval_pred):
    """Return the mean per-label Spearman correlation of predictions vs. labels.

    Args:
        eval_pred: Pair `(logits, labels)` of array-likes. 3-D logits are
            reduced with `argmax` over axis 1; any other rank is used as-is.
            1-D predictions or labels are treated as a single label column.

    Returns:
        dict with key "spearman": the mean of the per-column Spearman
        correlations, ignoring NaN columns (e.g. constant inputs). NaN when
        no column yields a defined correlation.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1) if logits.ndim == 3 else logits

    # Promote single-label 1-D arrays to one column so the per-column loop
    # below works for both the single- and multi-label cases.
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    if predictions.ndim == 1:
        predictions = predictions[:, np.newaxis]
    if labels.ndim == 1:
        labels = labels[:, np.newaxis]

    # Spearman's correlation for each label column.
    spearman_corrs = []
    for i in range(labels.shape[1]):
        corr, _ = spearmanr(predictions[:, i], labels[:, i])
        spearman_corrs.append(corr)

    corrs = np.asarray(spearman_corrs, dtype=float)
    # np.nanmean warns on an empty / all-NaN input; return NaN explicitly.
    if corrs.size == 0 or np.isnan(corrs).all():
        return {"spearman": float("nan")}

    return {"spearman": float(np.nanmean(corrs))}

# Convert datasets, `df` -> `ds`.
def preprocess_data(df, labels):
    """Build the column-oriented input dict from a feature frame and labels.

    Args:
        df: DataFrame with 'input_ids' and 'attention_mask' columns. Every
            column except the last two is exported as an additional feature.
            (presumably the last two columns are 'input_ids' and
            'attention_mask' - verify against the caller)
        labels: Object exposing `to_numpy()` (e.g. a DataFrame or Series).

    Returns:
        dict with keys "input_ids", "attention_mask", "additional_features"
        and "labels", each holding plain Python lists.
    """
    token_ids = list(df['input_ids'])
    masks = list(df['attention_mask'])

    # Positional slice: all columns but the last two, one row list per example.
    feature_rows = df.iloc[:, :-2].values.tolist()
    label_rows = labels.to_numpy().tolist()

    return dict(
        input_ids=token_ids,
        attention_mask=masks,
        additional_features=feature_rows,
        labels=label_rows,
    )

# Method to Check Rank.
def check_rank(score, leaderboard_path='./leaderboard.csv'):
    """Print where `score` would rank on the leaderboard, plus summary stats.

    Args:
        score: The score to rank.
        leaderboard_path: CSV file with a 'Score' column. Defaults to
            './leaderboard.csv'.

    The rank is the number of leaderboard rows whose 'Score' is greater than
    or equal to `score`.
    NOTE(review): this equals the true rank only if `score` itself is already
    a row of the leaderboard - confirm that is the intended usage.
    """
    leaderboard = pd.read_csv(leaderboard_path)
    scores      = leaderboard['Score']
    num_team    = len(leaderboard)
    mean        = scores.mean()
    median      = scores.median()

    # Compare against the `score` argument; the previous code referenced an
    # undefined name (`my_score`) here.
    my_rank = (scores >= score).sum()

    print(f'My Rank = {my_rank} / {num_team}')
    print(f'My Score = {score:.4f}')
    print(f'Mean = {mean:.4f}')
    print(f'Median = {median:.4f}')

# Plot Learning Curve.
def plot_learning_curve(history):
    """Plot training loss, validation loss and eval Spearman against step.

    Args:
        history: Sequence of log dicts. Each entry is expected to carry a
            "step" key and may carry "loss", "eval_loss" and/or
            "eval_spearman". (presumably `Trainer.state.log_history` -
            confirm against the caller)

    Each curve is built only from the entries that contain its own key, so the
    x and y series always have equal length. Previously the Spearman values
    were read from entries filtered on "loss", which raised KeyError on
    training-only log entries.
    """
    def _series(key):
        # Steps and values taken from exactly the entries that log `key`.
        entries = [log for log in history if key in log]
        return [log["step"] for log in entries], [log[key] for log in entries]

    plt.figure(figsize=(10, 6))

    # (log key, legend label, marker) for each curve, in drawing order.
    curves = (
        ("loss", "Training Loss", "o"),
        ("eval_loss", "Validation Loss", "x"),
        ("eval_spearman", "Spearman Loss", "x"),
    )
    for key, label, marker in curves:
        steps, values = _series(key)
        plt.plot(steps, values, label=label, marker=marker)

    # Labels and legends
    plt.title("Training and Validation Loss Curve")
    plt.xlabel("Training Steps")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)

    # Show plot
    plt.show()