### Import

In [None]:
import torch
from Dataset.bank_dataset import BankTxnDataset, pad_collate_fn
from Models.transformer import TransformerClassifier
from Config.config import load_config
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import tqdm

### Loading Config

In [None]:
# Load the project configuration once; later cells read hyper-parameters from `cfg`.
cfg = load_config()

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The ANSI escape codes wrap the device name (92 = bright green on most terminals).
print(f"Current device: \033[92m{device}\033[0m")

### Loading Dataset

In [None]:
# Both splits are built from the same config and the same validation ratio.
val_ratio = cfg.dataset['validateSplit']
train_ds = BankTxnDataset(cfg, split="train", val_ratio=val_ratio)
val_ds = BankTxnDataset(cfg, split="val", val_ratio=val_ratio)
# test_ds = BankTxnDataset(cfg, split="test")
print(f"Total number of training data: \033[92m{len(train_ds.data)}\033[0m")

# Arguments shared by every loader; only `shuffle` differs between splits.
loader_kwargs = dict(
    batch_size=cfg.parameter['batchSize'],
    num_workers=4,
    pin_memory=True,                # speeds host→GPU copies
    collate_fn=pad_collate_fn,
)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)

# Validation order is kept fixed; shuffling is not needed for evaluation.
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

# test_loader = DataLoader(
#     test_ds,
#     batch_size=cfg.parameter['batchSize'],
#     shuffle=False,
#     num_workers=4,
#     pin_memory=True,
#     collate_fn=pad_collate_fn
# )


In [None]:
# Pull a single batch to discover the feature width the model must accept.
sample_batch = next(iter(train_loader))

# The first element of a batch is the input tensor; its third axis (index 2)
# is used as the per-step feature dimension.
x_sample = sample_batch[0]
print(f"Input tensor shape: {x_sample.shape}")

actual_feat_dim = x_sample.size(2)
print(f"Feature dimension from data: {actual_feat_dim}")

### Loading Model & Optimizer

In [None]:
# Build the classifier with the feature width measured from the data and the
# architecture hyper-parameters from the config. `num_classes=1` yields a single
# logit per sequence, which the training cells feed to a binary cross-entropy loss.
model = TransformerClassifier(
    feat_dim=actual_feat_dim,
    d_model=cfg.parameter['d_model'],
    nhead=cfg.parameter['attention_head'],
    num_layers=cfg.parameter['num_layers'],
    num_classes=1
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=cfg.parameter['learningRate'])

# Mixed-precision loss scaling. GradScaler defaults to the "cuda" device, so on a
# CPU-only machine the bare constructor emits a warning and disables itself.
# Tie it to the selected device and enable it only on CUDA; when disabled,
# scale()/step()/update() fall through to a plain optimizer step.
scaler = torch.amp.GradScaler(device.type, enabled=(device.type == "cuda"))

### Training

In [None]:
# Add imports for metrics
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

# Define evaluation function
def evaluate(model, data_loader, device):
    """Run one pass over ``data_loader`` and report loss and classification metrics.

    The model is switched to eval mode and left there; callers that continue
    training must call ``model.train()`` again.

    Args:
        model: Module called as ``model(x, src_key_padding_mask=mask)`` that
            returns one logit per sequence.
        data_loader: Iterable yielding ``(x, lengths, y)`` batches.
        device: ``torch.device`` the batches are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy, precision, recall, f1)``. ``avg_loss`` is
        the mean of the per-batch mean losses; predictions use a 0.5
        probability threshold.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()
    val_loss = 0.0
    num_batches = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for x, lengths, y in data_loader:
            x, lengths, y = x.to(device), lengths.to(device), y.to(device)

            # True at every time step at or beyond a sequence's length, so the
            # model can ignore those positions.
            padding_mask = (torch.arange(x.size(1), device=device)
                            .unsqueeze(0)
                            .ge(lengths.unsqueeze(1)))

            with torch.amp.autocast(device_type=device.type):
                # reshape(-1) rather than squeeze(): a bare squeeze() turns a
                # batch of size 1 into a 0-d tensor, which breaks both the
                # loss (shape mismatch with y) and the extend() calls below.
                logits = model(x, src_key_padding_mask=padding_mask).reshape(-1)
                loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    logits, y
                )

            val_loss += loss.item()
            num_batches += 1

            # Get predictions (0 or 1)
            preds = torch.sigmoid(logits) >= 0.5
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())

    if num_batches == 0:
        raise ValueError("evaluate() received an empty data_loader")

    # Calculate metrics; zero_division=0 avoids warnings when a class is
    # never predicted.
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)

    # Return average loss and metrics
    return val_loss / num_batches, accuracy, precision, recall, f1


In [None]:
# Train for the configured number of epochs, validating after each one and
# keeping a snapshot of the weights with the lowest validation loss.
epochs = cfg.parameter['epochs']
print(f"Starting training for {epochs} epochs on {device}")
train_losses = []       # mean training loss per epoch
val_losses = []         # validation loss per epoch
val_metrics = []        # one dict of accuracy/precision/recall/f1 per epoch
best_val_loss = float('inf')
best_model_state = None

for epoch in range(1, epochs+1):
    # Training phase. evaluate() leaves the model in eval mode, so training
    # mode has to be re-enabled at the start of every epoch.
    model.train()
    epoch_loss = 0.0
    pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", ncols=80)

    for x, lengths, y in pbar:
        x, lengths, y = x.to(device), lengths.to(device), y.to(device)

        # True at every time step at or beyond a sequence's length.
        padding_mask = (torch.arange(x.size(1), device=device)
                        .unsqueeze(0)
                        .ge(lengths.unsqueeze(1)))

        optimizer.zero_grad()
        with torch.amp.autocast(device_type=device.type):
            logits = model(x, src_key_padding_mask=padding_mask)
            # reshape(-1) rather than squeeze(): a bare squeeze() collapses a
            # final batch of size 1 to a 0-d tensor and the loss then fails
            # on a shape mismatch with y.
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                logits.reshape(-1), y
            )
        # Scaled backward/step/update sequence required by GradScaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        pbar.set_postfix(loss=f"{batch_loss:.4f}")

    avg_train_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validation phase
    val_loss, accuracy, precision, recall, f1 = evaluate(model, val_loader, device)
    val_losses.append(val_loss)
    val_metrics.append({
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    })

    # Save best model. state_dict() returns references to the live parameter
    # tensors and dict.copy() is shallow, so each tensor must be cloned;
    # otherwise the "best" snapshot silently follows later optimizer steps.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        print(f"New best model saved! (val_loss: {val_loss:.4f})")

    print(f"Epoch {epoch} - Train loss: {avg_train_loss:.4f}, Val loss: {val_loss:.4f}")
    print(f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

# Load best model for final evaluation
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Loaded best model based on validation performance")

# Two side-by-side panels: loss curves on the left, validation metrics on the right.
fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(12, 5))
epoch_axis = range(1, epochs+1)

# Left panel: training vs. validation loss per epoch.
loss_ax.plot(epoch_axis, train_losses, marker='o', linestyle='-', color='b', label='Training Loss')
loss_ax.plot(epoch_axis, val_losses, marker='o', linestyle='-', color='r', label='Validation Loss')
loss_ax.set_title('Loss Over Epochs')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.grid(True, linestyle='--', alpha=0.7)

# Right panel: one curve per validation metric, each with its own marker.
metric_styles = (
    ('accuracy', 'o', 'Accuracy'),
    ('precision', 's', 'Precision'),
    ('recall', '^', 'Recall'),
    ('f1', 'd', 'F1'),
)
for metric_key, metric_marker, metric_label in metric_styles:
    metric_ax.plot(
        epoch_axis,
        [epoch_metrics[metric_key] for epoch_metrics in val_metrics],
        marker=metric_marker,
        label=metric_label,
    )
metric_ax.set_title('Validation Metrics')
metric_ax.set_xlabel('Epoch')
metric_ax.set_ylabel('Score')
metric_ax.legend()
metric_ax.grid(True, linestyle='--', alpha=0.7)

plt.tight_layout()
plt.show()

### Testing

In [None]:
# — inference on test split —
import pandas as pd
from torch.utils.data import DataLoader

# Step 1: dataset and loader for the test split (order preserved, no shuffling).
test_ds = BankTxnDataset(cfg, split="test")
test_loader = DataLoader(
    test_ds,
    batch_size=cfg.parameter['batchSize'],
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=pad_collate_fn,
)

# Step 2: collect one probability per test sequence with gradients disabled.
model.eval()
preds = []
with torch.no_grad():
    for x, lengths, _ in test_loader:
        x = x.to(device)
        lengths = lengths.to(device)

        # mask[i, t] is True when step t lies at or beyond the length of sequence i.
        positions = torch.arange(x.size(1), device=device)
        mask = positions[None, :] >= lengths[:, None]

        # One logit per sequence; sigmoid turns it into a probability.
        logits = model(x, src_key_padding_mask=mask).squeeze(1)
        probs = torch.sigmoid(logits)
        preds.extend(probs.cpu().numpy())

# Step 3: attach predictions to the label frame.
# NOTE(review): this relies on get_label() returning rows in the same order as the
# dataset's sequences (the original note says sorted by ACCT_NBR) — confirm.
df_out = test_ds.get_label()
df_out['prediction'] = preds

# Step 4: persist the result.
df_out.to_csv('predictions.csv', index=False)
print("Wrote", len(df_out), "rows to predictions.csv")