In [67]:
# Memory management utilities
import gc
import torch

def clear_memory():
    """Clear GPU and CPU memory caches"""
    torch.cuda.empty_cache()
    gc.collect()
    print("Memory cleared")

# Run this before each full execution
clear_memory()

Memory cleared


In [68]:
# Installations & Setup (Run once per session)
!pip install --quiet numpy pandas matplotlib scikit-learn torch torchvision torchaudio pytorch-lightning wandb rich ipywidgets tabulate tqdm
!git clone https://github.com/fschmid56/PretrainedSED.git
import sys
sys.path.append('/content/PretrainedSED')

fatal: destination path 'PretrainedSED' already exists and is not an empty directory.


In [69]:
# Imports & Config (Run once per session)
import os
import pickle
import zipfile
import shutil
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchaudio
import torchvision
import torch.nn.functional as F
import pytorch_lightning as pl

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar
from pytorch_lightning.loggers import WandbLogger

from PretrainedSED.models.atstframe.audio_transformer import FrameASTModel

FP_costs = [1, 1, 2, 3, 3, 3, 3, 3, 3, 3]
FN_costs = [5, 5, 5, 10, 20, 15, 20, 15, 25, 15]

CONFIG = {
    # Seed for reproducibility
    "seed": 42,

    # Data
    "batch_size": 4,                   # Reduced to avoid OOM
    "accumulate_grad_batches": 2,     # Effective batch size = 8
    "num_workers": 4,
    "use_mel": True,                 # Switch between spectrogram vs. embedding mode
    "pca_components": 128,            # Required if use_mel=False

    # Model
    "model_type": "atst",
    "pretrained": True,
    "dropout": 0.3,
    "thresholds": [0.2, 0.2, 0.15, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
    "pos_weights": torch.tensor([fn/fp for fn, fp in zip(FN_costs, FP_costs)], dtype=torch.float32),

    # Optimization
    "lr": 1.5e-4,
    "weight_decay": 1e-5,
    "optimizer": "adamw",
    "gradient_clip_val": 0.,
    "precision": "bf16-mixed",
    "enable_flash_attention": True,

    # Scheduler
    "scheduler": "cosine",
    "warmup_epochs": 5,

    # Training
    "max_epochs": 75,
    "patience": 8,
    "limit_train_batches": 0.5,       # Faster iteration (optional)
}

In [70]:
# Data Download (Run once per session - with retry logic)
import time
from huggingface_hub import hf_hub_download, HfFileSystem

def download_with_retry(max_retries=3):
    for i in range(max_retries):
        try:
            # Download compute_cost.py
            if not os.path.exists("compute_cost.py"):
                print("Downloading compute_cost.py...")
                pyfile_path = hf_hub_download(
                    repo_id="fschmid56/mlpc2025_dataset",
                    filename="compute_cost.py",
                    repo_type="dataset"
                )
                shutil.copy(pyfile_path, "compute_cost.py")

            # Download dataset
            dataset_path = "/content/mlpc2025_dataset/data"
            if not os.path.exists(dataset_path):
                print("Downloading dataset...")
                os.makedirs("/content/mlpc2025_dataset", exist_ok=True)

                # Try direct download first
                try:
                    zip_path = hf_hub_download(
                        repo_id="fschmid56/mlpc2025_dataset",
                        filename="mlpc2025_dataset.zip",
                        repo_type="dataset"
                    )
                except Exception as e:
                    print(f"HF Hub download failed, trying alternative method: {str(e)}")
                    # Alternative download using wget
                    !wget https://huggingface.co/datasets/fschmid56/mlpc2025_dataset/resolve/main/mlpc2025_dataset.zip -O /content/mlpc2025_dataset.zip
                    zip_path = "/content/mlpc2025_dataset.zip"

                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall("/content/mlpc2025_dataset")

                # Clean up
                if os.path.exists(zip_path):
                    os.remove(zip_path)

            return dataset_path
        except Exception as e:
            print(f"Attempt {i+1} failed: {str(e)}")
            if i < max_retries - 1:
                print("Retrying in 10 seconds...")
                time.sleep(10)
    raise Exception("Failed after multiple retries. You can try manually downloading from:\n"
                   "https://huggingface.co/datasets/fschmid56/mlpc2025_dataset/tree/main\n"
                   "and uploading to your Colab session.")

# Run download
try:
    DATASET_PATH = download_with_retry()
    print("Successfully downloaded dataset to:", DATASET_PATH)

    # Import after setup
    from compute_cost import CLASSES as TARGET_CLASSES, get_ground_truth_df, get_segment_prediction_df, check_dataframe, total_cost
    print("TARGET_CLASSES:", TARGET_CLASSES)
    print("Functions available:", [get_ground_truth_df, get_segment_prediction_df, check_dataframe, total_cost])
except Exception as e:
    print(str(e))

Successfully downloaded dataset to: /content/mlpc2025_dataset/data
TARGET_CLASSES: ['Speech', 'Shout', 'Chainsaw', 'Jackhammer', 'Lawn Mower', 'Power Drill', 'Dog Bark', 'Rooster Crow', 'Horn Honk', 'Siren']
Functions available: [<function get_ground_truth_df at 0x7e110cbce340>, <function get_segment_prediction_df at 0x7e110cbce3e0>, <function check_dataframe at 0x7e110cbce160>, <function total_cost at 0x7e110cbce200>]


In [71]:
def load_data_with_cache(dataset_path):
    cache_file = os.path.join(dataset_path, "data_cache.pkl")

    if os.path.exists(cache_file):
        print("Loading cached data...")
        with open(cache_file, 'rb') as f:
            return pickle.load(f)

    print("Processing data from scratch...")
    metadata = pd.read_csv(os.path.join(dataset_path, 'metadata.csv'))
    all_files = metadata['filename'].unique()

    # Train/Val/Test split
    train_files, temp_files = train_test_split(
        all_files, test_size=0.4, random_state=CONFIG['seed'])
    val_files, test_files = train_test_split(
        temp_files, test_size=0.5, random_state=CONFIG['seed'])

    # Load features
    def load_features(files):
        X = []
        Y = {c: [] for c in TARGET_CLASSES}

        for fname in files:
            base = os.path.splitext(fname)[0]
            feat_path = os.path.join(dataset_path, 'audio_features', base + '.npz')
            features = np.load(feat_path)['embeddings']
            X.append(features)

            label_path = os.path.join(dataset_path, 'labels', base + '_labels.npz')
            labels = np.load(label_path)
            for c in TARGET_CLASSES:
                Y[c].append((np.max(labels[c], axis=1) > 0).astype(int))

        return X, Y

    X_train, Y_train = load_features(train_files)
    X_val, Y_val = load_features(val_files)
    X_test, Y_test = load_features(test_files)

    # Apply PCA if using embeddings
    if not CONFIG['use_mel']:
        # Determine safe number of components
        n_samples = X_train[0].shape[0]
        n_features = X_train[0].shape[1]
        safe_components = min(n_samples, n_features, CONFIG['pca_components'])
        print(f"Applying PCA with {safe_components} components (samples: {n_samples}, features: {n_features})")

        pca = PCA(safe_components)
        X_train = [pca.fit_transform(x) if i == 0 else pca.transform(x)
                  for i, x in enumerate(X_train)]
        X_val = [pca.transform(x) for x in X_val]
        X_test = [pca.transform(x) for x in X_test]

    # Cache results
    with open(cache_file, 'wb') as f:
        pickle.dump((X_train, Y_train, train_files,
                    X_val, Y_val, val_files,
                    X_test, Y_test, test_files), f)

    return X_train, Y_train, train_files, X_val, Y_val, val_files, X_test, Y_test, test_files

data = load_data_with_cache(DATASET_PATH)
X_train, Y_train, train_files, X_val, Y_val, val_files, X_test, Y_test, test_files = data

Loading cached data...


In [72]:
class ATSTFrameSED(nn.Module):
    def __init__(self, num_classes=10, pretrained=True, use_mel=True, input_dim=128):
        super().__init__()
        self.use_mel = use_mel
        self.target_frames = 512
        self.feature_adapter = nn.Sequential(
            nn.Linear(768, 1024),
            nn.GELU(),
            nn.LayerNorm(1024)
        )

        if self.use_mel:
            # FrameAST model for spectrogram input
            self.atst = FrameASTModel(
                patch_h=8,
                patch_w=4,
                atst_dropout=0.1,
                num_classes=num_classes,
                pos_type="cut",
                nprompt=0
            )
        else:
            # For embeddings: project input dim to transformer dim
            self.input_proj = nn.Linear(input_dim, 768)

        self.output_adapter = nn.Sequential(
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Linear(768, num_classes)
        )

        if pretrained and self.use_mel:
            self.load_pretrained_weights()

    def forward(self, x, lengths=None):
        if lengths is None:
            lengths = torch.full((x.size(0),), x.size(-1), device=x.device)

        if self.use_mel:
            # Input: [B, 1, 64, T]
            dummy_mask = torch.zeros((x.size(0), x.size(-1)),
                                     dtype=torch.bool, device=x.device)
            x_token, _, _, _, _, _ = self.atst.prepare_tokens(
                x, mask_index=dummy_mask, length=lengths, mask=False)

            for blk in self.atst.blocks:
                x_token = blk(x_token, lengths)

            frame_repr = self.atst.norm_frame(x_token)[:, :self.target_frames, :]
        else:
            # Input: [B, T, D] → project to match transformer output
            frame_repr = self.input_proj(x)

        return self.output_adapter(frame_repr)

    def load_pretrained_weights(self):
        pretrained_url = "https://github.com/fschmid56/PretrainedSED/releases/download/v0.0.1/ATST-F_strong_1.pt"
        try:
            state_dict = torch.hub.load_state_dict_from_url(
                pretrained_url, map_location='cpu'
            )
            if 'model' in state_dict:
                state_dict = state_dict['model']
            state_dict = {k: v for k, v in state_dict.items() if 'pos_embed' not in k}
            self.atst.load_state_dict(state_dict, strict=False)
            print("Loaded pretrained weights (excluding positional embeddings)")
        except Exception as e:
            print(f"Failed to load weights: {e}")

In [73]:
class SequenceDataset(Dataset):
    def __init__(self, Y, classes, filenames, audio_dir, is_training=False):
        self.Y = Y
        self.classes = classes
        self.filenames = filenames
        self.audio_dir = audio_dir
        self.is_training = is_training  # Training mode flag

        # Spectrogram transform
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000,
            n_fft=512,
            hop_length=160,
            n_mels=64,
            f_min=50,
            f_max=7500
        )
        self.amp_to_db = torchaudio.transforms.AmplitudeToDB()

        # Initialize augmentation transforms
        self.time_mask = torchaudio.transforms.TimeMasking(time_mask_param=100)  # ~1.0s
        self.freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=8)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        # Load audio
        audio_path = os.path.join(self.audio_dir, self.filenames[idx])
        try:
            x, sr = torchaudio.load(audio_path)
            assert sr == 16000, f"Unexpected sample rate {sr}"
            x = x.squeeze(0)  # [1, T] -> [T]

            # Apply random gain augmentation during training (-5dB to +5dB)
            if self.is_training:
                gain = torch.empty(1).uniform_(-5, 5)  # More efficient than rand()
                x = x * (10 ** (gain / 20))

        except Exception as e:
            print(f"Error loading {audio_path}: {str(e)}")
            # Return silent audio as fallback
            x = torch.zeros(16000)  # 1s of silence

        # Compute spectrogram
        with torch.no_grad():
            spec = self.mel_transform(x)  # [64, T']

            # Apply SpecAugment during training
            if self.is_training:
                spec = self.time_mask(self.freq_mask(spec))

            # Pad/trim to fixed 512 frames (~5.12s)
            spec = spec[..., :512] if spec.size(-1) > 512 else F.pad(spec, (0, 512 - spec.size(-1)))

            # Normalize dB-scale to [0,1]
            spec = self.amp_to_db(spec)
            spec = (spec + 80) / 80  # Assuming -80dB floor
            spec = spec.unsqueeze(0)  # [1, 64, 512]

        # Process labels
        y = torch.stack([
            torch.tensor(self.Y[c][idx], dtype=torch.float32)
            for c in self.classes
        ], dim=1)  # [T, 10]

        # Align labels with spectrogram length
        y = y[:512] if y.size(0) >= 512 else F.pad(y, (0, 0, 0, 512 - y.size(0)))

        return spec, y, torch.tensor(spec.shape[-1]), self.filenames[idx]

class SEDDataModule(pl.LightningDataModule):
    def __init__(self, dataset_path, batch_size=16, num_workers=4):
        super().__init__()
        self.dataset_path = dataset_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.audio_dir = os.path.join(dataset_path, 'audio')

    def setup(self, stage=None):
        data = load_data_with_cache(self.dataset_path)
        _, Y_train, train_files, _, Y_val, val_files, _, Y_test, test_files = data

        self.train_ds = SequenceDataset(Y_train, TARGET_CLASSES, train_files,
                                     self.audio_dir, is_training=True)  # Enable augmentations
        self.val_ds = SequenceDataset(Y_val, TARGET_CLASSES, val_files, self.audio_dir)
        self.test_ds = SequenceDataset(Y_test, TARGET_CLASSES, test_files, self.audio_dir)

    def collate_fn(self, batch):
        # Each spec is [1, 64, 624] → stack becomes [B, 1, 64, 624]
        specs = torch.stack([item[0] for item in batch])      # [B, 1, 64, 624]
        labels = torch.stack([item[1] for item in batch])     # [B, 624, 10]
        lengths = torch.stack([item[2] for item in batch])    # [B]
        filenames = [item[3] for item in batch]               # List[str]
        return specs, labels, lengths, filenames

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            persistent_workers=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            persistent_workers=True
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            persistent_workers=True
        )


In [74]:
class SEDLightningModule(pl.LightningModule):
    def __init__(self, classes, lr=1e-4, thresholds=None, pos_weight=None):
        super().__init__()
        self.save_hyperparameters()
        self.model = ATSTFrameSED(
            num_classes=len(classes),
            pretrained=True,
            use_mel=CONFIG['use_mel'],
            input_dim=CONFIG.get('pca_components', 128)
        )

        # Convert thresholds to tensor if provided
        if thresholds is not None:
            self.register_buffer('thresholds', torch.tensor(thresholds, dtype=torch.float32))
        else:
            self.thresholds = None

        self.criterion = nn.BCEWithLogitsLoss(
            pos_weight=pos_weight,
            reduction='none'
        )
        self.criterion = nn.BCEWithLogitsLoss(
            pos_weight=pos_weight,
            reduction='none'
        )
        self.val_preds = {c: [] for c in classes}
        self.val_targets = {c: [] for c in classes}
        self.val_filenames = []

    def forward(self, x, lengths=None):
        return self.model(x, lengths)

    def training_step(self, batch, batch_idx):
        if not isinstance(batch, (list, tuple)):
            raise ValueError(f"Unexpected batch type: {type(batch)}")

        if batch_idx == 0:
            print("[DEBUG] Entered first training_step of epoch")
            print(f"[DEBUG] Batch type: {type(batch)}, length: {len(batch)}")
            for i, item in enumerate(batch):
                print(f"[DEBUG] Item {i}: type={type(item)}, shape={getattr(item, 'shape', 'list' if isinstance(item, list) else 'unknown')}")

        # Safely unpack or fail with more info
        try:
            X, Y, lengths, _ = batch
        except ValueError as e:
            raise ValueError(f"[ERROR] Batch unpacking failed (expected 4 items). Batch content: {batch}") from e

        logits = self(X, lengths)
        mask = (torch.arange(logits.size(1), device=logits.device)[None, :] < lengths[:, None]).unsqueeze(-1).float()
        loss = (self.criterion(logits, Y.float()) * mask).sum() / mask.sum()
        self.log("train/loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, Y, lengths, filenames = batch  # Matches dataset format
        logits = self(X, lengths)
        probs = torch.sigmoid(logits)
        if self.thresholds is not None:
            # Expand thresholds to match logits shape [B,T,C]
            thresholds = self.thresholds.view(1, 1, -1).expand_as(probs)
            preds = (probs > thresholds).long()
        else:
            # Fallback to default 0.5 threshold
            preds = (probs > 0.5).long()

        for i, cls in enumerate(self.hparams.classes):
            self.val_preds[cls].append(preds[..., i].cpu())
            self.val_targets[cls].append(Y[..., i].long().cpu())

        self.val_filenames.extend(filenames)
        mask = (torch.arange(logits.size(1), device=logits.device))[None, :] < lengths[:, None]
        mask = mask.unsqueeze(-1).float()
        loss = (self.criterion(logits, Y.float()) * mask).sum() / mask.sum()
        self.log("val/loss", loss, prog_bar=True, sync_dist=True)

    def on_validation_epoch_end(self):
        preds_numpy = {
            cls: torch.cat(self.val_preds[cls]).numpy()
            for cls in self.hparams.classes
        }
        targets_numpy = {
            cls: torch.cat(self.val_targets[cls]).numpy()
            for cls in self.hparams.classes
        }

        total_cost, _ = self.evaluate_cost(
            preds_numpy=preds_numpy,
            targets_numpy=targets_numpy,
            filenames=list(set(self.val_filenames))
        )
        self.log("val/cost", total_cost, prog_bar=True, sync_dist=True)

        self.val_preds = {c: [] for c in self.hparams.classes}
        self.val_targets = {c: [] for c in self.hparams.classes}
        self.val_filenames = []

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=CONFIG['weight_decay'],
            betas=(0.9, 0.999)
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=CONFIG['warmup_epochs'] * len(self.trainer.datamodule.train_dataloader()),
            T_mult=1,
            eta_min=1e-6
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step"
            }
        }

    def evaluate_cost(self, preds_numpy, targets_numpy, filenames):
        predictions_dict = {
            filename: {cls: preds_numpy[cls][i] for cls in self.hparams.classes}
            for i, filename in enumerate(filenames)
        }

        try:
            gt_df = get_ground_truth_df(filenames, DATASET_PATH)
            pred_df = get_segment_prediction_df(predictions_dict, self.hparams.classes)

            merged = pred_df.merge(
                gt_df,
                on=["filename", "onset"],
                how="inner",
                suffixes=("_pred", "_true")
            )
            if len(merged) == 0:
                raise ValueError("No matching rows between predictions and ground truth")

            cost_value, breakdown = total_cost(
                merged[[col for col in merged.columns if col.endswith("_pred") or col in ["filename", "onset"]]].rename(
                    columns={col: col.replace("_pred", "") for col in merged.columns if col.endswith("_pred")}
                ),
                merged[[col for col in merged.columns if col.endswith("_true") or col in ["filename", "onset"]]].rename(
                    columns={col: col.replace("_true", "") for col in merged.columns if col.endswith("_true")}
                )
            )
            return cost_value, breakdown
        except Exception as e:
            print(f"Error during cost calculation: {str(e)}")
            return float('inf'), {}

    test_step = validation_step
    on_test_epoch_end = on_validation_epoch_end

In [75]:
def verify_data_pipeline(dm):
    print("\n=== VERIFYING DATA PIPELINE ===")

    # Check dataset
    sample = dm.train_ds[0]
    print(f"Single sample contains {len(sample)} items:")
    for i, item in enumerate(sample):
        print(f"  Item {i}: Type={type(item)}, Shape={item.shape if hasattr(item, 'shape') else 'str'}")

    # Check batch
    batch = next(iter(dm.train_dataloader()))
    print(f"\nBatch contains {len(batch)} items:")
    for i, item in enumerate(batch):
        if isinstance(item, (list, tuple)):
            print(f"  Item {i}: Type={type(item)}, Length={len(item)}")
        else:
            print(f"  Item {i}: Type={type(item)}, Shape={item.shape if hasattr(item, 'shape') else 'N/A'}")

    # Verify model can process the batch
    model = SEDLightningModule(TARGET_CLASSES, lr=CONFIG['lr'])
    try:
        X, Y, lengths, _ = batch
        outputs = model(X.to('cuda'), lengths.to('cuda'))
        print(f"\nModel output shape: {outputs.shape} (should be [16, 624, 10])")
        return True
    except Exception as e:
        print(f"\nModel forward pass failed: {str(e)}")
        return False

In [76]:
torch.cuda.empty_cache()
gc.collect()

def train_model():
    # Memory Management
    torch.cuda.empty_cache()
    gc.collect()
    print(f"GPU Memory cleared: {torch.cuda.memory_allocated()/1e9:.1f}GB free")

    # Data Setup
    dm = SEDDataModule(
        dataset_path=DATASET_PATH,
        batch_size=CONFIG['batch_size'],
        num_workers=CONFIG['num_workers']
    )
    dm.setup()
    print("Data module ready")

    # Model Initialization
    model = SEDLightningModule(
        classes=TARGET_CLASSES,
        lr=CONFIG['lr'],
        thresholds=CONFIG['thresholds'],
        pos_weight=CONFIG['pos_weights']  # Now properly a tensor
    )
    print(f"Model initialized with lr={CONFIG['lr']}")

    # Setup Weights & Biases logger
    run_type = "mel" if CONFIG['use_mel'] else "embed"
    wandb_logger = WandbLogger(project="atst_sed", name=f"run_{run_type}", log_model=True)
    wandb_logger.log_hyperparams(CONFIG)

    # Trainer Configuration
    trainer = pl.Trainer(
        logger=wandb_logger,
        accelerator="auto",
        devices="auto",
        max_epochs=CONFIG['max_epochs'],
        precision=CONFIG['precision'],
        accumulate_grad_batches=CONFIG['accumulate_grad_batches'],
        gradient_clip_val=CONFIG['gradient_clip_val'],
        callbacks=[
            ModelCheckpoint(
                monitor="val/cost",
                mode="min",
                save_top_k=1,
                filename="best-{epoch:02d}-{val_cost:.2f}",
                auto_insert_metric_name=False
            ),
            EarlyStopping(
                monitor="val/cost",
                patience=CONFIG['patience'],
                mode="min",
                min_delta=0.1,
                verbose=True
            ),
            LearningRateMonitor(),
            RichProgressBar(refresh_rate=5)
        ],
        log_every_n_steps=1,
        check_val_every_n_epoch=3,         # <-- validate every 3 epochs (faster)
        limit_train_batches=0.5,           # <-- use only 50% of data per epoch (for tuning)
        enable_model_summary=True
    )

    # Training Execution
    print("\nStarting Training...")
    print(f"Batch Size: {CONFIG['batch_size']} (Effective: {CONFIG['batch_size']*CONFIG['accumulate_grad_batches']})")
    print(f"Steps per Epoch: {len(dm.train_dataloader())}")

    # === RUN without suppressing errors ===
    trainer.fit(model, dm)
    test_results = trainer.test(model, dm, ckpt_path="best")
    print(f"\nFinal Validation Cost: {test_results[0]['val/cost']:.2f}")

    return model


In [None]:
# Execution Block (Run to train)
if __name__ == "__main__":
    import torch.backends.cuda

    print("Flash Attention Enabled (before):", torch.backends.cuda.flash_sdp_enabled())

    # Force-enable Flash Attention modes
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    torch.backends.cuda.enable_math_sdp(False)

    print("Flash Attention Enabled (after):", torch.backends.cuda.flash_sdp_enabled())

    # Optional: set matmul precision for speed
    torch.set_float32_matmul_precision('medium')

    # Launch training
    best_model = train_model()


Flash Attention Enabled (before): True
Flash Attention Enabled (after): True
GPU Memory cleared: 0.0GB free
Loading cached data...
Data module ready


INFO:pytorch_lightning.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Loaded pretrained weights (excluding positional embeddings)
Model initialized with lr=0.00015

Starting Training...
Batch Size: 4 (Effective: 8)
Steps per Epoch: 1235
Loading cached data...


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()