In [1]:
import torch
from argparse import Namespace

# Build the MVCL model

In [2]:
from MVCL import MultiViewModel_lit, MultiViewModel




## Configuration Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `use_inner_CL` | int | 1 | Enable **inner contrastive learning** within the same modality for feature discrimination |
| `use_inter_CL` | int | 1 | Enable **inter-modal contrastive learning** between different modalities |
| `use_cls_loss_1_2` | int | 1 | Enable classification loss for modality 1 and modality 2 tasks |
| `use_fusion` | int | 1 | Enable **feature fusion mechanism** to combine multi-modal features |
| `use_fusion1D` | int | 1 | Enable **1D fusion** strategy for processing sequential feature fusion |
| `use_fusion2D` | int | 1 | Enable **2D fusion** strategy for processing spatial feature map fusion |
| `use_mse_loss` | int | 0 | Enable Mean Squared Error loss for regression tasks |
| `only_1D` | int | 0 | **Use only 1D modality**, ignoring other dimensional features |
| `only_2D` | int | 0 | **Use only 2D modality**, ignoring other dimensional features |
| `drop_layer` | float | 0.0 | Dropout rate for regularization to prevent overfitting |
| `w_con` | float | 1.0 | Weight coefficient for contrastive learning loss in total loss |
| `w_cls` | float | 1.0 | Weight coefficient for classification loss in total loss |

### Parameter Categories

#### 🎯 **Loss Function Control**
- `use_inner_CL`, `use_inter_CL`: Control different types of contrastive learning
- `use_cls_loss_1_2`: Control classification loss
- `use_mse_loss`: Control regression loss

#### 🔄 **Feature Fusion Strategy**
- `use_fusion`: Master switch for feature fusion
- `use_fusion1D`, `use_fusion2D`: Control fusion methods for different dimensions

#### 🎛️ **Modality Selection**
- `only_1D`, `only_2D`: Control whether to use only specific dimensional modalities

#### ⚖️ **Weight Balancing**
- `w_con`, `w_cls`: Balance the importance of different loss functions
- `drop_layer`: Regularization parameter

In [3]:

### default model configuration
mvcl_cfg = Namespace(
    use_inner_CL=1,
    use_inter_CL=1,
    use_cls_loss_1_2=1,
    use_fusion=1,
    use_fusion1D=1,
    use_fusion2D=1,
    use_mse_loss=0,
    only_1D=0,
    only_2D=0,
    drop_layer=0.0,
    w_con=1.0,
    w_cls=1.0,
)
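To try an ablation, one option is to copy the default configuration and override a few flags. The sketch below uses only `argparse.Namespace`; the helper `override_cfg` and the flag combination shown are illustrative assumptions, not presets shipped with MVCL.

```python
from argparse import Namespace

# Default configuration, copied from the cell above.
default_cfg = Namespace(
    use_inner_CL=1,
    use_inter_CL=1,
    use_cls_loss_1_2=1,
    use_fusion=1,
    use_fusion1D=1,
    use_fusion2D=1,
    use_mse_loss=0,
    only_1D=0,
    only_2D=0,
    drop_layer=0.0,
    w_con=1.0,
    w_cls=1.0,
)


def override_cfg(cfg: Namespace, **changes) -> Namespace:
    """Return a copy of cfg with some fields replaced (hypothetical helper)."""
    unknown = set(changes) - set(vars(cfg))
    if unknown:
        # Catch typos early instead of silently adding a new field.
        raise KeyError(f"Unknown config keys: {sorted(unknown)}")
    return Namespace(**{**vars(cfg), **changes})


# Example: switch off both contrastive losses and add some dropout.
no_cl_cfg = override_cfg(default_cfg, use_inner_CL=0, use_inter_CL=0, drop_layer=0.1)
print(no_cl_cfg.use_inner_CL, no_cl_cfg.use_inter_CL, no_cl_cfg.drop_layer)
```

Because `override_cfg` returns a new `Namespace`, the default configuration stays untouched and can be reused for the next variant.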

The first time the model is built, it downloads the Wav2Clip and WavLM model checkpoints.

In [4]:
mvcl = MultiViewModel(cfg=mvcl_cfg)

The input to our MVCL model is a `torch.Tensor` of shape (batch, 1, audio_length).

Take a random tensor as an example. The batch size is 2, so the model processes two audio samples at once.

In [5]:
x = torch.randn(2, 1, 48000)
res = mvcl(x)

In [6]:
for k, v in res.items():
    if isinstance(v, torch.Tensor):
        print(f"{k}: {v.shape}")
    else:
        print(f"{k}: {v}")

raw_spec: torch.Size([2, 1, 257, 257])
raw_wav_feat: torch.Size([2, 149, 768])
feature1D: torch.Size([2, 768])
feature2D: torch.Size([2, 512])
feature: torch.Size([2, 1280])
logit1D: torch.Size([2])
logit2D: torch.Size([2])
logit: torch.Size([2])


The output is a dict:

| Feature Name | Shape | Dimension | Description |
|--------------|-------|-----------|-------------|
| `raw_spec` | `[2, 1, 257, 257]` | 4D | **Raw spectrogram** - Original frequency-time representation of audio signal with 257 frequency bins and 257 time frames |
| `raw_wav_feat` | `[2, 149, 768]` | 3D | **Raw waveform features** - Sequential audio features extracted from backbone (e.g., WavLM), 149 time steps with 768-dimensional embeddings |
| `feature1D` | `[2, 768]` | 2D | **1D modality features** - Final classification feature of the 1D branch |
| `feature2D` | `[2, 512]` | 2D | **2D modality features** - Final classification feature of the 2D branch |
| `feature` | `[2, 1280]` | 2D | **Fused features** - Combined multi-modal features (1D + 2D), concatenated to 1280 dimensions (768 + 512) |
| `logit1D` | `[2]` | 1D | **1D modality logits** - Classification scores from 1D feature branch for binary classification |
| `logit2D` | `[2]` | 1D | **2D modality logits** - Classification scores from 2D feature branch for binary classification |
| `logit` | `[2]` | 1D | **Final logits** - Combined classification scores from fused features for final prediction |
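The 149 time steps of `raw_wav_feat` can be checked by hand. The sketch below assumes that the 1D backbone uses the standard wav2vec 2.0-style convolutional feature encoder (seven unpadded 1D convolutions with kernel sizes 10, 3, 3, 3, 3, 2, 2 and strides 5, 2, 2, 2, 2, 2, 2), which is the default layout for WavLM. The function name is made up for illustration.

```python
# Kernel sizes and strides of the assumed wav2vec 2.0-style feature encoder.
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def num_wav_frames(num_samples: int) -> int:
    """Number of output frames after the unpadded convolution stack."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        # Output length of a 1D convolution without padding or dilation.
        length = (length - kernel) // stride + 1
    return length


print(num_wav_frames(48000))  # 149, matching raw_wav_feat's shape [2, 149, 768]
```

At a 16 kHz sample rate, 48000 samples are 3 seconds of audio, so each of the 149 frames covers roughly 20 ms.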


**Feature Processing Pipeline**

```
Audio Input (batch, 1, 48000)
    ↓
┌─────────────────────────────────────────────────────────────┐
│                    Stage 1 (No Grad)                       │
├──────────────────────────┬──────────────────────────────────┤
│      1D Branch           │           2D Branch              │
│                          │                                  │
│ feature_model1D          │ feature_model2D                  │
│ .compute_stage1(x)       │ .compute_stage1(x, spec_aug)     │
│      ↓                   │      ↓                           │
│    wav1                  │   spec1, raw_spec               │
└──────────────────────────┴──────────────────────────────────┘
                           ↓
┌─────────────────────────────────────────────────────────────┐
│                 Cross-Modal Fusion                          │
├──────────────────────────┬──────────────────────────────────┤
│   squeeze_modules[0]     │     expand_modules[0]            │
│   (wav1, spec1)          │     (wav1, spec1)                │
│      ↓                   │      ↓                           │
│   fused_wav1             │   fused_spec1                    │
└──────────────────────────┴──────────────────────────────────┘
                           ↓
┌─────────────────────────────────────────────────────────────┐
│                    Stage 2                                  │
├──────────────────────────┬──────────────────────────────────┤
│ feature_model1D          │ feature_model2D                  │
│ .compute_stage2          │ .compute_stage2                  │
│ (fused_wav1)             │ (fused_spec1)                    │
│      ↓                   │      ↓                           │
│ wav2, position_bias      │   spec2                          │
└──────────────────────────┴──────────────────────────────────┘
                           ↓
┌─────────────────────────────────────────────────────────────┐
│              Cross-Modal Fusion + Stage 3                   │
├──────────────────────────┬──────────────────────────────────┤
│   squeeze_modules[1]     │     expand_modules[1]            │
│   (wav2, spec2)          │     (wav2, spec2)                │
│      ↓                   │      ↓                           │
│ feature_model1D          │ feature_model2D                  │
│ .compute_stage3          │ .compute_stage3                  │
│      ↓                   │      ↓                           │
│ wav3, position_bias      │   spec3                          │
└──────────────────────────┴──────────────────────────────────┘
                           ↓
┌─────────────────────────────────────────────────────────────┐
│              Cross-Modal Fusion + Stage 4                   │
├──────────────────────────┬──────────────────────────────────┤
│   squeeze_modules[2]     │     expand_modules[2]            │
│   (wav3, spec3)          │     (wav3, spec3)                │
│      ↓                   │      ↓                           │
│ feature_model1D          │ feature_model2D                  │
│ .compute_stage4          │ .compute_stage4                  │
│      ↓                   │      ↓                           │
│ wav4, position_bias      │   spec4                          │
└──────────────────────────┴──────────────────────────────────┘
                           ↓
┌─────────────────────────────────────────────────────────────┐
│           Final Cross-Modal Fusion + Latent Features        │
├──────────────────────────┬──────────────────────────────────┤
│   squeeze_modules[3]     │     expand_modules[3]            │
│   (wav4, spec4)          │     (wav4, spec4)                │
│      ↓                   │      ↓                           │
│ feature_model1D          │ feature_model2D                  │
│ .compute_latent_feature  │ .compute_latent_feature          │
│      ↓                   │      ↓                           │
│ wav5, raw_wav_feat       │   spec5                          │
└──────────────────────────┴──────────────────────────────────┘
                           ↓
┌─────────────────────────────────────────────────────────────┐
│                Feature Normalization                        │
├──────────────────────────┬──────────────────────────────────┤
│   norm_feat(wav5)        │   norm_feat(spec5)               │
│      ↓                   │      ↓                           │
│   feature1D [B, 768]     │   feature2D [B, 512]             │
│      ↓                   │      ↓                           │
│   cls1D(feature1D)       │   cls2D(feature2D)               │
│      ↓                   │      ↓                           │
│   logit1D [B]            │   logit2D [B]                    │
└──────────────────────────┴──────────────────────────────────┘
                           ↓
┌─────────────────────────────────────────────────────────────┐
│                   Multi-Modal Fusion                        │
│                                                             │
│        concat([wav5, spec5], dim=-1)                        │
│                     ↓                                       │
│              norm_feat(concat)                              │
│                     ↓                                       │
│               feature [B, 1280]                             │
│                     ↓                                       │
│              cls_final(feature)                             │
│                     ↓                                       │
│                logit [B]                                    │
└─────────────────────────────────────────────────────────────┘
```

# Lit model

> PyTorch Lightning is the deep learning framework for professional AI researchers and machine learning engineers who need maximal flexibility without sacrificing performance at scale. Lightning evolves with you as your projects go from idea to paper/production

We use [pytorch_lightning](https://lightning.ai/docs/pytorch/stable/) to train, validate, and test our model. It also handles logging, model checkpointing, and callbacks.

In [7]:
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.loggers import CSVLogger

We wrap the model in a PyTorch Lightning `LightningModule`, which defines the training step, the validation/prediction step, the loss functions, and the optimizer.

In [8]:
mvcl_lit = MultiViewModel_lit(cfg=mvcl_cfg)

BCE loss with label smoothing:  0.1
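The message above indicates that the classification loss is a binary cross-entropy with a label smoothing factor of 0.1. The exact formula inside MVCL's `LabelSmoothingBCE` is not shown in this notebook; the sketch below uses one common variant, in which a hard label `y` is replaced by `y * (1 - eps) + 0.5 * eps`, and is written in plain Python for readability.

```python
import math


def smoothed_bce(logit: float, label: int, eps: float = 0.1) -> float:
    """Binary cross-entropy on a logit with a smoothed target (assumed variant)."""
    # Pull the hard label towards 0.5: 1 -> 0.95 and 0 -> 0.05 for eps = 0.1.
    target = label * (1.0 - eps) + 0.5 * eps
    # Numerically stable log(sigmoid(logit)).
    if logit >= 0:
        log_p = -math.log1p(math.exp(-logit))
    else:
        log_p = logit - math.log1p(math.exp(logit))
    # log(1 - sigmoid(logit)) = log(sigmoid(logit)) - logit.
    log_not_p = log_p - logit
    return -(target * log_p + (1.0 - target) * log_not_p)


print(round(smoothed_bce(0.0, 1), 4))  # 0.6931, i.e. ln 2
```

A logit of 0 gives ln 2 ≈ 0.693 for either label, which is the value an uninformed classifier would produce.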


## Test the forward pass

In the lit model, we use the `_shared_pred` method to predict the logits of the input batch. During training, we also apply `audio_transform` to augment the spectrogram.

Generate a random batch:

In [9]:
batch = {
    "label": torch.randint(0, 2, (3,)),
    "audio": torch.randn(3, 1, 48000),
    "sample_rate": [16000, 16000, 16000],
}

Note: **your batch must be a dict with the keys above**.

The `_shared_pred` output is also a dict. We use it to compute the loss, AUC, and EER scores.

In [10]:
batch_res = mvcl_lit._shared_pred(batch=batch, batch_idx=0)
for key, value in batch_res.items():
    print(key, value.shape)

raw_spec torch.Size([3, 1, 257, 257])
raw_wav_feat torch.Size([3, 149, 768])
feature1D torch.Size([3, 768])
feature2D torch.Size([3, 512])
feature torch.Size([3, 1280])
logit1D torch.Size([3])
logit2D torch.Size([3])
logit torch.Size([3])
pred torch.Size([3])
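For readers unfamiliar with the equal error rate (EER) reported by `EER_Callback` later in this notebook: it is the error rate at the decision threshold where the false positive rate and the false negative rate are equal. The callback's implementation is not shown here; the sketch below is a simple, quadratic-time threshold sweep in plain Python that assumes label 1 is the positive class and that a higher score means "more positive".

```python
def equal_error_rate(labels, scores):
    """Approximate EER by sweeping every distinct score as a threshold."""
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("EER needs at least one positive and one negative sample")

    best_gap, eer = float("inf"), 1.0
    # Also try a threshold above every score, so that "reject all" is covered.
    thresholds = sorted(set(scores)) + [float("inf")]
    for thr in thresholds:
        false_pos = sum(1 for y, s in zip(labels, scores) if y == 0 and s >= thr)
        false_neg = sum(1 for y, s in zip(labels, scores) if y == 1 and s < thr)
        fpr, fnr = false_pos / n_neg, false_neg / n_pos
        # Keep the threshold where the two error rates are closest.
        if abs(fpr - fnr) < best_gap:
            best_gap, eer = abs(fpr - fnr), (fpr + fnr) / 2
    return eer


print(equal_error_rate([0, 0, 1, 1], [0.1, 0.6, 0.4, 0.9]))  # 0.5
```

An EER of 0 means the scores separate the two classes perfectly, while values around 0.5 correspond to chance-level scores.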


# Demo training

We first build simple dataloaders for training, in which all samples are randomly generated.

In [11]:
from callbacks import EER_Callback, BinaryAUC_Callback, BinaryACC_Callback

In [12]:
import torch
from torch.utils.data import Dataset, DataLoader

Generate a dataloader with random values for demo training.

In [13]:
class SimpleTestDataset(Dataset):
    def __init__(self, num_samples=10):
        # Generate synthetic samples with the same keys as the example batch above
        self.samples = []
        for _ in range(num_samples):
            self.samples.append({
                "audio": torch.randn(1, 48000),
                "label": torch.randint(0, 2, (1,)).item(),
                "sample_rate": 16000,
            })
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        return self.samples[idx]



# Create the dataset and dataloader
train_dataloader = DataLoader(
    SimpleTestDataset(num_samples=100),
    batch_size=3,
    shuffle=True,
)
val_dataloader = DataLoader(
    SimpleTestDataset(num_samples=50),
    batch_size=3,
    shuffle=False,
)
test_dataloader = DataLoader(
    SimpleTestDataset(num_samples=20),
    batch_size=3,
    shuffle=False,
)
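The number of batches per epoch follows directly from the dataset sizes and the batch size. The small sketch below (the helper name is made up) reproduces the counts for the three dataloaders above, assuming the `DataLoader` default of `drop_last=False`, so the last, smaller batch is kept.

```python
import math


def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches per epoch for a map-style dataset."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


for name, n in [("train", 100), ("val", 50), ("test", 20)]:
    print(name, num_batches(n, batch_size=3))
# train 34
# val 17
# test 7
```

These values match the `34/34` and `7/7` counters shown in the training and testing progress bars.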

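We build a simple trainer to train and test our model, which uses:
- a `CSVLogger` that writes to `./logs`
- at most 4 training epochs
- callbacks that compute accuracy, AUC, and EER from the `logit` output and the `label` entry of each batch
- a single GPU (`cuda:0`)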

In [14]:
trainer = Trainer(
    logger=CSVLogger(save_dir="./logs", version=0),
    max_epochs=4,
    callbacks=[
        BinaryACC_Callback(batch_key="label", output_key="logit"),
        BinaryAUC_Callback(batch_key="label", output_key="logit"),
        EER_Callback(batch_key="label", output_key="logit"),
    ],
    devices=[0], # use cuda:0 device
    accelerator="gpu", # use GPU acceleration
)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Start training

In [15]:
trainer.fit(mvcl_lit, train_dataloader, val_dataloaders=val_dataloader)

You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/ay/anaconda3/envs/mvcl/lib/python3.9/site-packages/lightning_fabric/loggers/csv_logs.py:268: Experiment logs directory ./logs/lightning_logs/version_0 exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!
/home/ay/anaconda3/envs/mvcl/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:658: Checkpoint directory ./logs/lightning_logs/version_0/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]



  | Name                | Type                            | Params | Mode 
--------------------------------------------------------------------------------
0 | model               | MultiViewModel                  | 128 M  | train
1 | clip_heads          | ModuleList                      | 1.6 M  | train
2 | bce_loss            | LabelSmoothingBCE               | 0      | train
3 | contrast_loss2      | BinaryTokenContrastLoss         | 0      | train
4 | triplet_loss        | TripletMarginLoss               | 0      | train
5 | clip_loss           | CLIPLoss1D                      | 1      | train
6 | reconstruction_loss | TimeFrequencyReconstructionLoss | 379 K  | train
--------------------------------------------------------------------------------
130 M     Trainable params
0         Non-trainable params
130 M     Total params
521.061   Total estimated model params size (MB)
210       Modules in train mode
233       Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/home/ay/anaconda3/envs/mvcl/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


                                                                           

/home/ay/anaconda3/envs/mvcl/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
/home/ay/anaconda3/envs/mvcl/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (34) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 3: 100%|██████████| 34/34 [00:03<00:00,  8.89it/s, v_num=0, val-clip_loss=2.200, val-mse_loss=1.000, val-cls_loss1D=0.688, val-cls_loss2D=0.692, val-cls_loss=0.689, val-contrast_loss=0.310, val-contrast_loss1D=0.330, val-contrast_loss2D=0.284, val-loss=4.880, val-acc=0.440, val-auc=0.458, val-eer=0.545, train-clip_loss=2.250, train-mse_loss=1.000, train-cls_loss1D=0.690, train-cls_loss2D=0.688, train-cls_loss=0.686, train-contrast_loss=0.296, train-contrast_loss1D=0.336, train-contrast_loss2D=0.276, train-loss=4.930, train-acc=0.440, train-auc=0.524, train-eer=0.536]

`Trainer.fit` stopped: `max_epochs=4` reached.




After training, you can view the logged losses in the logger's output file, for example `logs/lightning_logs/version_0/metrics.csv`.
![](imgs/loss.png)
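To inspect the logged values without a plotting tool, the CSV file can be read with the standard library. The sketch below assumes the usual `CSVLogger` layout: one row per logging call, `epoch` and `step` columns, and empty cells for metrics that were not logged in that row. The sample rows are made up so that the snippet runs without a training run.

```python
import csv
import io


def read_metric(csv_file, name):
    """Return (step, value) pairs for one metric column, skipping empty cells."""
    points = []
    for row in csv.DictReader(csv_file):
        value = row.get(name)
        if value:  # rows that did not log this metric leave the cell empty
            points.append((int(row["step"]), float(value)))
    return points


# Made-up sample in the assumed layout; real files contain many more columns.
sample = io.StringIO(
    "epoch,step,train-loss,val-loss\n"
    "0,33,4.95,\n"
    "0,33,,4.90\n"
    "1,67,4.93,\n"
)
print(read_metric(sample, "val-loss"))  # [(33, 4.9)]
```

For a real run, pass an open file instead of the sample, for example `open("logs/lightning_logs/version_0/metrics.csv", newline="")`.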

# Demo testing

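After testing, the results are also saved in the logger file. Because the demo samples and labels are random, the metrics are expected to stay near chance level.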

In [16]:
trainer.test(mvcl_lit, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
/home/ay/anaconda3/envs/mvcl/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 7/7 [00:00<00:00, 29.50it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test-acc            0.44999998807907104
        test-auc            0.4545454978942871
     test-clip_loss         2.1567835807800293
      test-cls_loss         0.6900644302368164
     test-cls_loss1D        0.6895066499710083
     test-cls_loss2D        0.6917417049407959
   test-contrast_loss       0.31804102659225464
  test-contrast_loss1D      0.34499993920326233
  test-contrast_loss2D      0.28468912839889526
        test-eer            0.5454545617103577
        test-loss            4.857785701751709
      test-mse_loss         1.0041100978851318
──────────────────────────────────────────────────────────────────────────────

[{'test-clip_loss': 2.1567835807800293,
  'test-mse_loss': 1.0041100978851318,
  'test-cls_loss1D': 0.6895066499710083,
  'test-cls_loss2D': 0.6917417049407959,
  'test-cls_loss': 0.6900644302368164,
  'test-contrast_loss': 0.31804102659225464,
  'test-contrast_loss1D': 0.34499993920326233,
  'test-contrast_loss2D': 0.28468912839889526,
  'test-loss': 4.857785701751709,
  'test-acc': 0.44999998807907104,
  'test-auc': 0.4545454978942871,
  'test-eer': 0.5454545617103577}]

<div class="alert alert-success">
Note: the training, validation, and test metrics are all logged to the same file, `metrics.csv`.
</div>