In [None]:
!git clone https://github.com/SolarFlareZ/amt-project.git
%cd amt-piano

In [None]:
!pip install -r requirements.txt

In [None]:
# Cell 3: Import check for the project's `src` package. If any import below
# fails, the cell stops with an error before reaching the final print.
from src.data.transforms import AudioTransform, MIDILabels
from src.data.dataset import MAESTRODataset
from src.data.datamodule import MAESTRODataModule
from src.models.cnn import PianoCNN, ConvBlock
from src.lightning_module import AMTLightningModule
from src.evaluation import optimize_threshold, evaluate_frame_metrics

print("All imports successful!")

In [None]:
# Cell 4: Build the "mel" config and construct both data transforms from it
from omegaconf import OmegaConf

# One shared config object; it is passed unchanged to AudioTransform and
# MIDILabels below.
cfg = OmegaConf.create(
    dict(
        name="mel",
        sample_rate=16000,
        hop_length=512,
        n_fft=2048,
        n_mels=229,
        fmin=30.0,
        fmax=8000.0,
    )
)

audio_transform = AudioTransform(cfg)
midi_labels = MIDILabels(cfg)

# Show one attribute of each transform as evidence that construction worked.
print(f"AudioTransform initialized: {audio_transform.spec_transform}")
print(f"MIDILabels FPS: {midi_labels.fps:.2f}")

In [None]:
# Cell 5: Smoke-test the CNN by pushing one random batch through it
import torch
from src.models.cnn import PianoCNN

model = PianoCNN(n_mels=229)

# Random input laid out as (batch, mels, time).
x = torch.randn((4, 229, 160))
out = model(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")

# The output must come back as 4 x 160 x 88 for the check to pass.
assert tuple(out.shape) == (4, 160, 88), "Shape mismatch!"
print("CNN model OK!")

In [None]:
# Cell 6: Run a single training step of the LightningModule on a synthetic batch
from src.lightning_module import AMTLightningModule

module = AMTLightningModule(n_mels=229)

# Random spectrogram input paired with all-zero frame and onset label tensors.
batch = dict(
    spec=torch.randn((4, 229, 160)),
    frame_labels=torch.zeros((4, 160, 88)),
    onset_labels=torch.zeros((4, 160, 88)),
)

# Call training_step directly with batch index 0 and report the returned loss.
loss = module.training_step(batch, 0)
print(f"Loss: {loss.item():.4f}")
print("LightningModule OK!")

In [None]:
# Cell 7: Instantiate the DataModule (only the constructor is called here)
from src.data.datamodule import MAESTRODataModule

# Keyword arguments for the instantiation check.
dm = MAESTRODataModule(
    **{
        "cache_dir": "./cache/mel",
        "sequence_length": 160,
        "batch_size": 4,
        "num_workers": 0,
    }
)
print("DataModule instantiation OK!")

In [None]:
# Cell 8: Report the total element count over all parameters of `module`
# (the LightningModule instance created in Cell 6).
total_params = sum(map(torch.numel, module.parameters()))
print(f"Total parameters: {total_params:,}")
print("\n All tests passed!")