In [1]:
import os
import torch
import torch.nn as nn
import numpy as np
import librosa
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from tqdm import tqdm

# ========== CONFIG ==========
# Number of MFCC coefficients computed per audio frame (passed to librosa as n_mfcc).
N_MFCC = 13
# Number of MFCC time frames kept per clip: shorter clips are zero-padded,
# longer clips are truncated to this length.
MAX_AUDIO_LEN = 100
# (height, width) of each generated frame; also the target size of the Resize transform.
IMG_SIZE = (128, 128)
# Mini-batch size used by both the training and the test DataLoader.
BATCH_SIZE = 8
# Number of passes over the training split.
EPOCHS = 10
# Number of randomly generated grayscale frames stacked into one sample.
FRAMES_PER_SAMPLE = 10  # we simulate 10 grayscale frames per sample

# ========== AUDIO FEATURE ==========

def extract_mfcc(wav_path, n_mfcc=N_MFCC, max_len=MAX_AUDIO_LEN):
    """Load an audio file and return its MFCCs as a fixed-width float32 tensor.

    The audio is read at its native sampling rate (``sr=None``). The MFCC
    matrix is then forced to exactly ``max_len`` time frames: clips with fewer
    frames are right-padded with zeros, all others are cut after ``max_len``
    frames.

    Args:
        wav_path: Path of the audio file to load.
        n_mfcc: Number of MFCC coefficients to compute per frame.
        max_len: Number of time frames in the returned tensor.

    Returns:
        A ``torch.float32`` tensor whose last dimension has length ``max_len``.
    """
    signal, sample_rate = librosa.load(wav_path, sr=None)
    coeffs = librosa.feature.mfcc(y=signal, sr=sample_rate, n_mfcc=n_mfcc)

    # Positive when the clip is too short and needs zero-padding on the right.
    missing_frames = max_len - coeffs.shape[1]
    if missing_frames > 0:
        pad_widths = ((0, 0), (0, missing_frames))
        coeffs = np.pad(coeffs, pad_widths, mode='constant')
    else:
        coeffs = coeffs[:, :max_len]

    return torch.tensor(coeffs, dtype=torch.float32)

# ========== DATASET ==========

class MultiModalSpeechDataset(Dataset):
    """Paired ultrasound/audio samples that all share a single class label.

    ``root_dir`` is walked recursively. Every ``*.ult`` file that has a
    ``.wav`` file with the same base name in the same directory becomes one
    sample. Directories and files are visited in sorted order, so the sample
    list does not depend on the order in which the filesystem enumerates
    entries.

    NOTE: the ``.ult`` file is only used to discover samples. ``__getitem__``
    never reads it; the image modality is freshly drawn random noise on every
    access, so it carries no information about the recording.

    Args:
        root_dir: Directory tree to scan for ``.ult``/``.wav`` pairs.
        label: Class index returned with every sample of this dataset.
        transform: Optional callable applied to each generated PIL frame. When
            it is ``None`` (or falsy), ``transforms.ToTensor()`` is used.
    """

    def __init__(self, root_dir, label, transform=None):
        self.samples = []  # list of (ult_path, wav_path) tuples
        self.label = label
        self.transform = transform
        for root, dirs, files in os.walk(root_dir):
            # Sorting ``dirs`` in place makes os.walk descend in a fixed order.
            dirs.sort()
            for file in sorted(files):
                if not file.endswith(".ult"):
                    continue
                base = os.path.splitext(file)[0]
                wav_path = os.path.join(root, base + ".wav")
                # Keep only recordings that have a matching audio file.
                if os.path.exists(wav_path):
                    self.samples.append((os.path.join(root, file), wav_path))

    def __len__(self):
        """Return the number of ``.ult``/``.wav`` pairs found."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(frames, mfcc, label)`` for the sample at ``idx``.

        ``frames`` stacks ``FRAMES_PER_SAMPLE`` transformed random grayscale
        images of size ``IMG_SIZE``; ``mfcc`` is the output of
        ``extract_mfcc`` for the sample's audio file; ``label`` is the label
        given to the constructor.
        """
        _, wav_path = self.samples[idx]

        # Resolve the frame converter once per sample instead of once per frame.
        to_tensor = self.transform if self.transform else transforms.ToTensor()

        # Simulate the ultrasound frames as uniform noise. The upper bound of
        # randint is exclusive, so 256 is needed to cover the full uint8 range.
        imgs = []
        for _ in range(FRAMES_PER_SAMPLE):
            noise = np.random.randint(0, 256, IMG_SIZE, dtype=np.uint8)
            imgs.append(to_tensor(Image.fromarray(noise)))
        # With the ToTensor-based transforms used in this file the stacked
        # shape is [FRAMES_PER_SAMPLE, 1, H, W].
        imgs = torch.stack(imgs)

        mfcc = extract_mfcc(wav_path)
        return imgs, mfcc, self.label

# ========== MODEL ==========

class MultiModalNet(nn.Module):
    """Two-branch classifier that fuses image-sequence and audio features.

    * ``cnn`` maps every single-channel frame to a 32-dimensional vector; the
      vectors of one sequence are averaged over the frame axis.
    * ``audio_net`` flattens the audio features and maps them to 32 dimensions.
    * ``classifier`` takes the 64-dimensional concatenation of both and
      produces ``num_classes`` logits.

    Args:
        audio_feat_dim: Number of elements in one flattened audio feature.
        num_classes: Number of output logits.
    """

    def __init__(self, audio_feat_dim, num_classes=3):
        super().__init__()

        # Image branch: two conv stages, then global average pooling to 32 values.
        image_layers = [
            nn.Conv2d(1, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*image_layers)

        # Audio branch: small MLP on the flattened feature matrix.
        audio_layers = [
            nn.Flatten(),
            nn.Linear(audio_feat_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
        ]
        self.audio_net = nn.Sequential(*audio_layers)

        # Fusion head on the concatenated 32 + 32 features.
        head_layers = [
            nn.Linear(32 + 32, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, ult_imgs, audio_feat):
        """Return class logits of shape ``[batch, num_classes]``.

        Args:
            ult_imgs: 5-D tensor ``[batch, frames, channels, height, width]``.
            audio_feat: Audio features with the batch on the first axis.
        """
        n_batch, n_frames, channels, height, width = ult_imgs.size()

        # Fold the frame axis into the batch axis so the CNN sees plain images.
        frames = ult_imgs.view(n_batch * n_frames, channels, height, width)
        frame_embeddings = self.cnn(frames)  # [n_batch * n_frames, 32]

        # Restore the frame axis and average over it: one vector per sequence.
        image_embedding = frame_embeddings.view(n_batch, n_frames, -1).mean(dim=1)

        audio_embedding = self.audio_net(audio_feat)
        fused = torch.cat((image_embedding, audio_embedding), dim=1)
        return self.classifier(fused)

# ========== TRAINING SCRIPT ==========

if __name__ == "__main__":
    # Every generated frame is resized to IMG_SIZE and converted to a tensor.
    transform = transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor()
    ])

    # 🔍 Load datasets: one dataset per corpus directory, each with its own class index.
    dataset_uxtd = MultiModalSpeechDataset("D:/UltraSuite/core-uxtd/core", label=0, transform=transform)
    dataset_uxssd = MultiModalSpeechDataset("D:/UltraSuite/core-uxssd/core", label=1, transform=transform)
    dataset_upx   = MultiModalSpeechDataset("D:/UltraSuite/core-upx/core", label=2, transform=transform)

    # Adding Dataset objects concatenates them into a single dataset.
    full_dataset = dataset_uxtd + dataset_uxssd + dataset_upx
    print(f"📦 Total samples: {len(full_dataset)}")

    # ✂ Split: 80 % for training, the remainder for testing.
    train_size = int(0.8 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    # Fail early with a clear message. Without this check an empty split only
    # surfaces later as an unrelated error (e.g. division by zero when the
    # accuracy is computed), typically because a dataset path is wrong.
    if train_size == 0 or test_size == 0:
        raise RuntimeError(
            f"Need at least one training and one test sample, got "
            f"{train_size} train / {test_size} test. Check the dataset paths."
        )
    train_data, test_data = random_split(full_dataset, [train_size, test_size])
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

    # 🚀 Model Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The audio branch flattens each MFCC matrix, hence N_MFCC * MAX_AUDIO_LEN inputs.
    model = MultiModalNet(audio_feat_dim=N_MFCC * MAX_AUDIO_LEN).to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    print(f"🚀 Training on {device}...\n")
    for epoch in range(EPOCHS):
        model.train()
        # total_loss is the SUM of the per-batch losses of this epoch (not a mean).
        total_loss, correct, total = 0, 0, 0
        for imgs, mfccs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            imgs, mfccs, labels = imgs.to(device), mfccs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(imgs, mfccs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)

        train_acc = 100 * correct / total

        # Test set: evaluated after every epoch for monitoring only.
        model.eval()
        test_correct, test_total = 0, 0
        with torch.no_grad():
            for imgs, mfccs, labels in test_loader:
                imgs, mfccs, labels = imgs.to(device), mfccs.to(device), labels.to(device)
                outputs = model(imgs, mfccs)
                test_correct += (outputs.argmax(1) == labels).sum().item()
                test_total += labels.size(0)
        test_acc = 100 * test_correct / test_total

        print(f"📊 Epoch {epoch+1}: Loss={total_loss:.4f} | Train Acc={train_acc:.2f}% | Test Acc={test_acc:.2f}%\n")

    # 📈 Final Evaluation: collect every test prediction for the sklearn reports.
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, mfccs, labels in test_loader:
            # Labels are not moved to the device here; they are only compared on the CPU below.
            imgs, mfccs = imgs.to(device), mfccs.to(device)
            outputs = model(imgs, mfccs)
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())

    print(f"✅ Final Test Accuracy: {accuracy_score(all_labels, all_preds) * 100:.2f}%")
    print("\n📋 Classification Report:")
    print(classification_report(all_labels, all_preds, digits=3))
    print("\n📊 Confusion Matrix:")
    print(confusion_matrix(all_labels, all_preds))


📦 Total samples: 4167
🚀 Training on cpu...



Epoch 1/10: 100%|██████████| 417/417 [04:49<00:00,  1.44it/s]


📊 Epoch 1: Loss=316.3685 | Train Acc=75.40% | Test Acc=80.82%



Epoch 2/10: 100%|██████████| 417/417 [03:11<00:00,  2.18it/s]


📊 Epoch 2: Loss=212.0157 | Train Acc=81.55% | Test Acc=75.78%



Epoch 3/10: 100%|██████████| 417/417 [03:13<00:00,  2.15it/s]


📊 Epoch 3: Loss=190.6772 | Train Acc=83.11% | Test Acc=82.61%



Epoch 4/10: 100%|██████████| 417/417 [02:41<00:00,  2.58it/s]


📊 Epoch 4: Loss=181.0587 | Train Acc=83.86% | Test Acc=84.41%



Epoch 5/10: 100%|██████████| 417/417 [03:01<00:00,  2.30it/s]


📊 Epoch 5: Loss=185.0133 | Train Acc=83.98% | Test Acc=82.73%



Epoch 6/10: 100%|██████████| 417/417 [03:14<00:00,  2.14it/s]


📊 Epoch 6: Loss=175.8726 | Train Acc=84.40% | Test Acc=81.29%



Epoch 7/10: 100%|██████████| 417/417 [03:15<00:00,  2.14it/s]


📊 Epoch 7: Loss=169.7753 | Train Acc=84.88% | Test Acc=83.69%



Epoch 8/10: 100%|██████████| 417/417 [03:05<00:00,  2.24it/s]


📊 Epoch 8: Loss=160.8600 | Train Acc=85.54% | Test Acc=80.70%



Epoch 9/10: 100%|██████████| 417/417 [03:15<00:00,  2.13it/s]


📊 Epoch 9: Loss=158.5417 | Train Acc=86.59% | Test Acc=84.29%



Epoch 10/10: 100%|██████████| 417/417 [03:02<00:00,  2.29it/s]


📊 Epoch 10: Loss=156.5718 | Train Acc=86.32% | Test Acc=82.25%

✅ Final Test Accuracy: 82.25%

📋 Classification Report:
              precision    recall  f1-score   support

           0      0.622     0.220     0.326       127
           1      0.810     0.965     0.880       564
           2      0.974     0.797     0.877       143

    accuracy                          0.823       834
   macro avg      0.802     0.661     0.694       834
weighted avg      0.809     0.823     0.795       834


📊 Confusion Matrix:
[[ 28  99   0]
 [ 17 544   3]
 [  0  29 114]]
