In [1]:
from pathlib import Path
from libribrain_experiments.grouped_dataset import MyGroupedDatasetV3
from typing import Literal
from pnpl.datasets import LibriBrainCompetitionHoldout, LibriBrainPhoneme

# Shared settings: the training source and the competition holdout must be
# read from the same root and cut with the same time window, so the values
# are defined once instead of being repeated per dataset.
DATA_PATH = "./data/"
TMIN = 0.0
TMAX = 0.5
GROUP_SIZE = 100  # number of samples combined per group by MyGroupedDatasetV3

# Labelled phoneme data; standardisation is requested from the dataset itself.
raw_source_dataset = LibriBrainPhoneme(
    data_path=DATA_PATH,
    tmin=TMIN,
    tmax=TMAX,
    standardize=True,
)

# Grouped (and averaged) view of the source data. The cache file name is
# derived from GROUP_SIZE so it cannot drift from the grouping actually used.
# NOTE: `balance` is deliberately not passed (it was disabled in the original
# experiment).
source_dataset = MyGroupedDatasetV3(
    raw_source_dataset,
    grouped_samples=GROUP_SIZE,
    drop_remaining=False,
    average_grouped_samples=True,
    state_cache_path=Path(f"./data_preprocessed/groupedv3/all_grouped_{GROUP_SIZE}.pt"),
    shuffle=True,
)

# Competition holdout split, windowed exactly like the source data.
# Per the original author's note the holdout data is already standardized,
# hence standardize=False -- TODO confirm against the pnpl documentation.
holdout_dataset = LibriBrainCompetitionHoldout(
    data_path=DATA_PATH,
    task="phoneme",
    tmin=TMIN,
    tmax=TMAX,
    standardize=False,
)

In [2]:
import torch
from torch.utils.data import ConcatDataset, DataLoader

# Pool both datasets; which dataset a sample came from becomes the
# classification target (assigned later in collate_fn).
full_dataset = ConcatDataset([source_dataset, holdout_dataset])

# 80/20 train/validation split.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

# Use a dedicated, explicitly seeded generator: the global seed is only set
# in a later cell (L.seed_everything), so without this the split would differ
# on every kernel restart.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset,
    [train_size, val_size],
    generator=split_generator,
)

def collate_fn(batch):
    """Collate a mixed batch into inputs and dataset-of-origin labels.

    Samples that are tuples are treated as coming from the source dataset:
    their first element is used as the input and they get label 0. Any other
    sample is used as the input directly and gets label 1 (holdout dataset).

    Args:
        batch: list of samples, each either a tuple whose first element is a
            tensor, or a tensor. All input tensors must share one shape.

    Returns:
        A pair ``(xs, ys)`` where ``xs`` is the inputs stacked along a new
        leading dimension and ``ys`` is a ``torch.long`` tensor of 0/1 labels
        with one entry per sample.

    Raises:
        RuntimeError: from ``torch.stack`` if the batch is empty or the input
            tensors have mismatched shapes.
    """
    inputs = []
    labels = []
    for sample in batch:
        # isinstance (rather than an exact type check) also accepts tuple
        # subclasses such as namedtuples.
        if isinstance(sample, tuple):
            inputs.append(sample[0])
            labels.append(0)  # from source_dataset
        else:
            inputs.append(sample)
            labels.append(1)  # from holdout_dataset
    # New lists are built so the caller's batch list is left untouched.
    xs = torch.stack(inputs)
    ys = torch.tensor(labels, dtype=torch.long)
    return xs, ys

# Both loaders share the batch size and the origin-labelling collate function;
# only the training loader shuffles its samples.
_loader_kwargs = dict(batch_size=32, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

In [3]:
import torch
import lightning as L
from torch import nn
from torchmetrics import F1Score


class SourceClassificationModel(L.LightningModule):
    """Two-class classifier that predicts which dataset a sample came from.

    The network is a pointwise (kernel size 1) 1-D convolution over the
    channel axis, a ReLU, a flatten, and a linear layer producing two logits.
    Because the convolution keeps the time length unchanged, inputs must have
    shape ``(batch, n_channels, n_times)``.

    Args:
        n_channels: number of input channels (default 306).
        hidden_channels: number of convolution output channels (default 128).
        n_times: number of time samples per input (default 125, which gives
            the original 128 * 125 = 16000 flattened features; presumably
            0.5 s of data at 250 Hz -- TODO confirm).
        lr: Adam learning rate (default 0.0005).
    """

    def __init__(self, n_channels=306, hidden_channels=128, n_times=125, lr=0.0005):
        super().__init__()
        self.lr = lr
        self.model = nn.Sequential(
            nn.Conv1d(n_channels, hidden_channels, 1),
            nn.ReLU(),
            nn.Flatten(),
            # Flatten yields hidden_channels * n_times features per sample.
            nn.Linear(hidden_channels * n_times, 2),
        )
        self.criterion = nn.CrossEntropyLoss()
        # Separate metric objects per stage: a Metric accumulates state, so
        # sharing one instance would mix training and validation statistics.
        # `f1_macro` keeps its original name and is the training metric.
        self.f1_macro = F1Score(
            num_classes=2, average='macro', task="multiclass")
        self.val_f1_macro = F1Score(
            num_classes=2, average='macro', task="multiclass")

    def forward(self, x):
        """Return the two class logits for a batch ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and macro F1; return the loss."""
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        # Update the metric, then log the Metric object itself so Lightning
        # handles compute/reset instead of logging a detached batch value.
        self.f1_macro(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_f1_macro', self.f1_macro)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and macro F1; return the loss."""
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        # Logging the Metric object makes the epoch value the F1 over all
        # validation samples rather than a mean of per-batch F1 scores.
        self.val_f1_macro(y_hat, y)
        self.log('val_loss', loss)
        self.log('val_f1_macro', self.val_f1_macro, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over the network parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

In [4]:
import os
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger

# Output locations: CSV training logs and the final checkpoint file.
LOG_DIR = "lightning_logs"
CHECKPOINT_PATH = "models/phoneme_model.ckpt"

# CSV logger writing under LOG_DIR, with no run name and no fixed version.
logger = CSVLogger(save_dir=LOG_DIR, name="", version=None)

# Fix the global seed before the model is created so that weight
# initialisation is repeatable.
L.seed_everything(42)

# NOTE(review): this value is not passed to any DataLoader in the visible
# cells (the loaders are built earlier, without num_workers), so it currently
# has no effect on training.
num_workers = 4

# Build the dataset-of-origin classifier.
model = SourceClassificationModel()

# Record the model's hyperparameters with the logger (empty unless the module
# saves them).
logger.log_hyperparams(model.hparams)

# Trainer configuration: automatic device selection, 15 epochs, CSV logging,
# checkpointing enabled.
trainer_config = dict(
    devices="auto",
    max_epochs=15,
    logger=logger,
    enable_checkpointing=True,
)
trainer = L.Trainer(**trainer_config)

# Run training with validation on the held-out split.
trainer.fit(model, train_loader, val_loader)

# Persist the final trainer/model state.
trainer.save_checkpoint(CHECKPOINT_PATH)

Seed set to 42
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
  return _C._get_float32_matmul_precision()
You are using a CUDA device ('NVIDIA GeForce RTX 5090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params | Mode 
--------------------------------------------------------
0 | model     | Sequential        | 71.3 K | train
1 | criterion | CrossEntropyLoss  | 0      | train
2 |

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/dogeon/libribrain/phoneme/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/dogeon/libribrain/phoneme/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
