In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# ================= 1. CONFIG =================
# Central hyper-parameter / path table read by the dataset, model and
# training loop. Everything a reader might want to tune lives here.
CONFIG = dict(
    # --- data ---
    feature_dir=Path("features"),   # root folder holding the ssl/ and sfm/ feature trees
    # --- optimisation ---
    batch_size=8,
    lr=5e-5,
    epochs=20,
    # --- transformer geometry ---
    d_model=64,
    nhead=4,
    num_layers=2,
    # --- per-frame input feature widths ---
    ssl_dim=768,
    sfm_dim=7,
    # --- runtime ---
    device="cuda" if torch.cuda.is_available() else "cpu",   # use the GPU when one is visible
    checkpoint_path="best_model.pth",
)

print(f"‚öôÔ∏è Running on {CONFIG['device']}")

# ================= 2. GLOBAL CLASS MAP =================
# The label vocabulary is defined by the sub-directories of the SSL training
# split; sorting the names keeps the class -> index mapping stable across runs.
TRAIN_SSL_DIR = CONFIG["feature_dir"].joinpath("ssl", "train")
CLASSES = sorted(entry.name for entry in TRAIN_SSL_DIR.iterdir() if entry.is_dir())
CLASS_TO_IDX = dict(zip(CLASSES, range(len(CLASSES))))

print(f"üéØ Classes: {CLASS_TO_IDX}")

# ================= 3. DATASET =================
class DualStreamDataset(Dataset):
    """Paired SSL / SFM feature files for one data split.

    For every class in ``CLASSES`` the constructor lists
    ``<feature_dir>/ssl/<split>/<class>/*.npy`` and keeps each file that has a
    same-named counterpart under ``<feature_dir>/sfm/<split>/<class>/``.

    Args:
        split: Name of the split sub-directory (e.g. ``"train"`` or ``"val"``).

    Attributes:
        ssl_dir: Root of the SSL feature tree for this split.
        sfm_dir: Root of the SFM feature tree for this split.
        files: One dict per sample with keys ``"ssl"`` (path), ``"sfm"``
            (path) and ``"label"`` (int index from ``CLASS_TO_IDX``).
    """

    def __init__(self, split="train"):
        self.ssl_dir = CONFIG["feature_dir"] / "ssl" / split
        self.sfm_dir = CONFIG["feature_dir"] / "sfm" / split
        self.files = []
        unmatched = 0  # SSL files that have no SFM partner

        for cls in CLASSES:
            ssl_cls = self.ssl_dir / cls
            sfm_cls = self.sfm_dir / cls

            # glob() yields entries in filesystem order, which is not stable
            # across machines; sorting makes sample indices reproducible.
            for f in sorted(ssl_cls.glob("*.npy")):
                sfm_path = sfm_cls / f.name
                if not sfm_path.exists():
                    unmatched += 1
                    continue
                self.files.append({
                    "ssl": f,
                    "sfm": sfm_path,
                    "label": CLASS_TO_IDX[cls]
                })

        print(f"‚úÖ Loaded {len(self.files)} samples for '{split}'")
        if unmatched:
            # Surface dropped samples instead of discarding them silently.
            print(f"WARNING: skipped {unmatched} SSL file(s) in '{split}' with no matching SFM file")

    def __len__(self):
        """Return the number of paired samples found for this split."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load one sample from disk.

        Returns:
            ``(ssl, sfm, label)`` where ``ssl`` and ``sfm`` are float32 tensors
            holding the contents of the two ``.npy`` files and ``label`` is a
            0-dim int64 tensor. The arrays are presumably shaped
            ``(T, ssl_dim)`` and ``(T, sfm_dim)`` -- TODO confirm against the
            feature-extraction step.
        """
        item = self.files[idx]
        ssl = torch.from_numpy(np.load(item["ssl"])).float()
        sfm = torch.from_numpy(np.load(item["sfm"])).float()
        label = torch.tensor(item["label"], dtype=torch.long)
        return ssl, sfm, label

def collate_fn(batch):
    """Pad a list of ``(ssl, sfm, label)`` samples into batch tensors.

    Args:
        batch: Sequence of ``(ssl, sfm, label)`` tuples; ``ssl`` and ``sfm``
            are tensors whose first dimension is the sequence length.

    Returns:
        ``(ssl_pad, sfm_pad, mask, labels)``: both streams zero-padded along
        dim 1 to the longest sequence in the batch, a boolean mask of shape
        ``(B, T)`` that is True on real frames and False on padding, and the
        stacked labels.

    Raises:
        ValueError: If any sample's two streams differ in sequence length.
            The single mask is derived from the SSL lengths and is applied to
            both streams downstream, so unequal lengths would mis-mask the SFM
            stream or fail later with an opaque shape error.
    """
    ssl, sfm, labels = zip(*batch)

    ssl_lengths = [x.size(0) for x in ssl]
    sfm_lengths = [x.size(0) for x in sfm]
    if ssl_lengths != sfm_lengths:
        bad = [i for i, (a, b) in enumerate(zip(ssl_lengths, sfm_lengths)) if a != b]
        raise ValueError(
            f"SSL/SFM sequence lengths differ for batch item(s) {bad}: "
            f"ssl={[ssl_lengths[i] for i in bad]}, sfm={[sfm_lengths[i] for i in bad]}"
        )

    ssl_pad = pad_sequence(ssl, batch_first=True, padding_value=0.0)
    sfm_pad = pad_sequence(sfm, batch_first=True, padding_value=0.0)

    lengths = torch.tensor(ssl_lengths)
    # Broadcast (1, T) positions against (B, 1) lengths -> (B, T) bool mask.
    mask = torch.arange(ssl_pad.size(1))[None, :] < lengths[:, None]

    return ssl_pad, sfm_pad, mask, torch.stack(labels)


‚öôÔ∏è Running on cuda
üéØ Classes: {'Cysts_Structural': 0, 'Dysarthia': 1, 'Hyperfunctional': 2, 'Laryngitis': 3, 'Vox senilis': 4, 'parkinson': 5, 'spasmodische_dysphonie': 6}


In [6]:
# ================= 4. MODEL (CORRECTED FOR TORCH VERSION) =================
class DualStreamTransformer(nn.Module):
    """Two-stream transformer classifier fused by cross-attention.

    Each input stream (SSL and SFM features) is linearly projected to
    ``CONFIG["d_model"]``, layer-normalised and passed through its own
    transformer encoder. The SSL stream then queries the SFM stream through
    multi-head attention; the attended output is added back to the SSL
    stream, normalised, mean-pooled over valid frames and classified.

    Args:
        num_classes: Size of the output logit vector.
    """

    @staticmethod
    def _make_encoder(width):
        """Build one stream's encoder stack (the two streams share no weights)."""
        layer = nn.TransformerEncoderLayer(
            d_model=width,
            nhead=CONFIG["nhead"],
            dim_feedforward=256,
            dropout=0.1,
            batch_first=True,
        )
        return nn.TransformerEncoder(layer, num_layers=CONFIG["num_layers"])

    def __init__(self, num_classes):
        super().__init__()
        width = CONFIG["d_model"]

        # Per-stream input projections into the shared model width.
        self.ssl_proj = nn.Linear(CONFIG["ssl_dim"], width)
        self.sfm_proj = nn.Linear(CONFIG["sfm_dim"], width)

        self.ssl_norm = nn.LayerNorm(width)
        self.sfm_norm = nn.LayerNorm(width)

        # One independent encoder per stream.
        self.ssl_encoder = self._make_encoder(width)
        self.sfm_encoder = self._make_encoder(width)

        # Cross-attention module; whether per-head weights are returned is
        # chosen at call time in forward(), not here.
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=width,
            num_heads=CONFIG["nhead"],
            batch_first=True,
        )

        self.fusion_norm = nn.LayerNorm(width)

        # Small MLP head on the pooled representation.
        head_layers = [
            nn.Linear(width, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, ssl, sfm, mask):
        """Classify a padded batch.

        Args:
            ssl:  (B, T, ssl_dim) SSL features.
            sfm:  (B, T, sfm_dim) SFM features.
            mask: (B, T) boolean, True on valid frames and False on padding.

        Returns:
            ``(logits, attn_weights)`` where ``logits`` is (B, num_classes)
            and ``attn_weights`` holds the un-averaged per-head
            cross-attention weights, shape (B, nhead, T, T).
        """
        # PyTorch's padding masks use the opposite polarity: True = ignore.
        is_padding = ~mask

        # Project each stream to the model width and normalise.
        ssl_tokens = self.ssl_norm(self.ssl_proj(ssl))
        sfm_tokens = self.sfm_norm(self.sfm_proj(sfm))

        # Encode the two streams independently.
        ssl_tokens = self.ssl_encoder(ssl_tokens, src_key_padding_mask=is_padding)
        sfm_tokens = self.sfm_encoder(sfm_tokens, src_key_padding_mask=is_padding)

        # SSL frames query the SFM frames; keep one weight map per head.
        attended, attn_weights = self.cross_attn(
            query=ssl_tokens,
            key=sfm_tokens,
            value=sfm_tokens,
            key_padding_mask=is_padding,
            need_weights=True,
            average_attn_weights=False,
        )

        # Residual connection back onto the SSL stream, then normalise.
        fused = self.fusion_norm(attended + ssl_tokens)

        # Mean over valid frames only; the epsilon guards the division.
        valid = mask.unsqueeze(-1).float()
        frame_sum = (fused * valid).sum(dim=1)
        frame_count = valid.sum(dim=1) + 1e-8
        pooled = frame_sum / frame_count

        logits = self.classifier(pooled)
        return logits, attn_weights




In [7]:
# ================= 5. TRAINING (PURE TERMINAL OUTPUT) =================
from tqdm.std import tqdm

def train():
    """Train a DualStreamTransformer and track validation metrics.

    Builds the "train" and "val" datasets and loaders, then runs
    ``CONFIG["epochs"]`` epochs of AdamW optimisation with cross-entropy loss
    and gradient-norm clipping at 1.0. After each epoch the model is evaluated
    on the validation split; whenever validation accuracy strictly improves,
    the weights are written to ``CONFIG["checkpoint_path"]``.

    Returns:
        ``(model, history)``: the model as it stands after the final epoch
        (not necessarily the best checkpoint) and a dict with per-epoch lists
        under ``"train_loss"``, ``"val_loss"`` and ``"val_acc"`` (percent).

    Raises:
        ValueError: If either split contains no samples.
    """
    device = CONFIG["device"]

    train_ds = DualStreamDataset("train")
    val_ds   = DualStreamDataset("val")

    # Fail fast with a clear message: an empty split would otherwise surface
    # as a bare ZeroDivisionError (for "val", only after a full training epoch).
    if len(train_ds) == 0 or len(val_ds) == 0:
        raise ValueError(
            f"Empty data split: train={len(train_ds)}, val={len(val_ds)} samples "
            f"under {CONFIG['feature_dir']}"
        )

    train_loader = DataLoader(
        train_ds,
        batch_size=CONFIG["batch_size"],
        shuffle=True,
        collate_fn=collate_fn
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=CONFIG["batch_size"],
        shuffle=False,
        collate_fn=collate_fn
    )

    model = DualStreamTransformer(len(CLASSES)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=CONFIG["lr"], weight_decay=1e-4)

    best_acc = 0.0
    history = {"train_loss": [], "val_loss": [], "val_acc": []}

    print("\nüöÄ Training Started")
    print("Epoch | Train Loss | Val Loss | Val Acc")
    print("-" * 45)

    for epoch in range(CONFIG["epochs"]):

        # ---------- TRAIN ----------
        model.train()
        train_loss = 0.0

        for batch in tqdm(
            train_loader,
            desc=f"Epoch {epoch+1}/{CONFIG['epochs']}",
            ncols=80,
            ascii=True
        ):
            ssl, sfm, mask, labels = (t.to(device) for t in batch)

            optimizer.zero_grad()
            outputs, _ = model(ssl, sfm, mask)
            loss = criterion(outputs, labels)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # The criterion returns a per-batch mean; weight it by batch size
            # so a smaller final batch does not skew the epoch average.
            train_loss += loss.item() * labels.size(0)

        train_loss /= len(train_ds)

        # ---------- VALIDATION ----------
        model.eval()
        val_loss = 0.0
        correct, total = 0, 0

        with torch.no_grad():
            for batch in val_loader:
                ssl, sfm, mask, labels = (t.to(device) for t in batch)

                outputs, _ = model(ssl, sfm, mask)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * labels.size(0)

                preds = outputs.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        val_loss /= total
        val_acc = 100 * correct / total

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f"{epoch+1:>5} | {train_loss:.4f} | {val_loss:.4f} | {val_acc:.2f}%")

        # ---------- CHECKPOINT ----------
        # Only a strict improvement in validation accuracy overwrites the file.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), CONFIG["checkpoint_path"])
            print("   üíæ Best model saved")

    return model, history


In [None]:

# ================= 6. RUN =================
# Run the full training loop. `model` is the network after the last epoch;
# `history` holds the per-epoch "train_loss", "val_loss" and "val_acc" lists
# that the plotting cell below reads.
model, history = train()


‚úÖ Loaded 14740 samples for 'train'
‚úÖ Loaded 920 samples for 'val'

üöÄ Training Started
Epoch | Train Loss | Val Loss | Val Acc
---------------------------------------------


Epoch 1/20: 100%|###########################| 1843/1843 [01:30<00:00, 20.40it/s]


    1 | 1.0033 | 0.5160 | 83.91%
   üíæ Best model saved


Epoch 2/20: 100%|###########################| 1843/1843 [00:40<00:00, 45.94it/s]


    2 | 0.4972 | 0.4466 | 85.98%
   üíæ Best model saved


Epoch 3/20: 100%|###########################| 1843/1843 [00:39<00:00, 46.89it/s]


    3 | 0.4351 | 0.4349 | 85.87%


Epoch 4/20: 100%|###########################| 1843/1843 [00:38<00:00, 47.38it/s]


    4 | 0.3997 | 0.4242 | 86.52%
   üíæ Best model saved


Epoch 5/20: 100%|###########################| 1843/1843 [00:41<00:00, 44.15it/s]


    5 | 0.3835 | 0.4442 | 86.41%


Epoch 6/20: 100%|###########################| 1843/1843 [00:42<00:00, 43.06it/s]


    6 | 0.3625 | 0.4549 | 87.07%
   üíæ Best model saved


Epoch 7/20: 100%|###########################| 1843/1843 [01:17<00:00, 23.91it/s]


    7 | 0.3449 | 0.4253 | 88.04%
   üíæ Best model saved


Epoch 8/20: 100%|###########################| 1843/1843 [01:06<00:00, 27.71it/s]


    8 | 0.3318 | 0.4935 | 85.98%


Epoch 9/20: 100%|###########################| 1843/1843 [01:08<00:00, 27.08it/s]


    9 | 0.3179 | 0.4209 | 87.07%


Epoch 10/20: 100%|##########################| 1843/1843 [00:59<00:00, 30.81it/s]


   10 | 0.3061 | 0.4418 | 87.72%


Epoch 11/20: 100%|##########################| 1843/1843 [00:59<00:00, 30.74it/s]


   11 | 0.2936 | 0.4086 | 88.80%
   üíæ Best model saved


Epoch 12/20: 100%|##########################| 1843/1843 [01:01<00:00, 29.93it/s]


   12 | 0.2833 | 0.4790 | 87.50%


Epoch 13/20: 100%|##########################| 1843/1843 [01:00<00:00, 30.57it/s]


   13 | 0.2724 | 0.4732 | 88.70%


Epoch 14/20: 100%|##########################| 1843/1843 [00:47<00:00, 38.59it/s]


   14 | 0.2601 | 0.4616 | 87.50%


Epoch 15/20: 100%|##########################| 1843/1843 [00:46<00:00, 40.06it/s]


   15 | 0.2510 | 0.4434 | 88.80%


Epoch 16/20: 100%|##########################| 1843/1843 [01:08<00:00, 26.99it/s]


In [None]:

# ================= 7. PLOT =================
# Plot against 1-based epoch numbers so the x-axis matches the epoch numbers
# printed in the training log.
epoch_numbers = range(1, len(history["train_loss"]) + 1)

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: training vs. validation loss.
loss_ax.plot(epoch_numbers, history["train_loss"], label="Train")
loss_ax.plot(epoch_numbers, history["val_loss"], label="Val")
loss_ax.set_title("Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Cross-entropy loss")
loss_ax.legend()
loss_ax.grid()

# Right panel: validation accuracy (stored in percent by train()).
acc_ax.plot(epoch_numbers, history["val_acc"], label="Val Acc", color="green")
acc_ax.set_title("Validation Accuracy")
acc_ax.set_xlabel("Epoch")
acc_ax.set_ylabel("Accuracy (%)")
acc_ax.legend()
acc_ax.grid()

plt.show()
