In [None]:
# import warnings
# warnings.filterwarnings("ignore", message="Attempting to use hipBLASLt on an unsupported architecture!")


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
import glob

In [9]:

# Model hyperparameters
LATENT_DIM = 64   # Size of latent space
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
EPOCHS = 50
# Load processed mel spectrograms
PROCESSED_DIR = "/home/renzo/projects/stempalooza/processed_mel/"

os.environ["TORCH_USE_MIOPEN"] = "1"

In [11]:
# ===============================
# Dataset Class
class MelDataset(Dataset):
    def __init__(self, root_dir, device):
        self.files = glob.glob(os.path.join(root_dir, "**/*.npy"), recursive=True)
        self.device = device

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        mel_path = self.files[idx]
        mel = torch.from_numpy(np.load(mel_path)).float()
        # If mel is 2D (H, W), add channel dimension to get (1, H, W)
        if len(mel.shape) == 2:
            mel = mel.unsqueeze(0)
        # Our expected shape is (1, 64, 16).
        # Preprocessed spectrograms are (64, ~15) so we pad width dimension to 16.
        if mel.shape[-1] < 16:
            pad_amount = 16 - mel.shape[-1]
            mel = F.pad(mel, (0, pad_amount), mode='constant', value=0)
        if mel.shape[-1] > 16:
            mel = mel[..., :16]
        return mel.to(self.device)

# ===============================
# VAE Model
class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        """
        input_dim: tuple, e.g. (64, 16)
        latent_dim: size of latent space
        """
        super(VAE, self).__init__()
        # Encoder: using kernel_size=4, stride=2, padding=1
        # Input shape: (1, 64, 16)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),  # -> (32, 32, 8)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # -> (64, 16, 4)
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # -> (128, 8, 2)
            nn.ReLU()
        )
        # Calculate the flattened size (should be 128*8*2 = 2048)
        self.conv_output_size = self._get_conv_output_size(input_dim)
        # Latent space mapping
        self.fc_mu = nn.Linear(self.conv_output_size, latent_dim)
        self.fc_logvar = nn.Linear(self.conv_output_size, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, self.conv_output_size)
        # Decoder: reshape back to (128, 8, 2)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=(1, 1)),
            nn.Sigmoid()
        )

    def _get_conv_output_size(self, input_dim):
        # Create a dummy input on CPU for now (we later move model to device)
        with torch.no_grad():
            x = torch.zeros(1, 1, *input_dim, device="cpu")
            x = self.encoder(x)
            print(f"Output size after encoder (for input_dim {input_dim}): {x.shape}")
            return int(np.prod(x.size()))

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        #print('encoding')
        # Ensure input is on the same device as model parameters.
        x = x.to(next(self.parameters()).device)
        x = self.encoder(x)
        # Calling contiguous() can help with ROCm/MIOpen issues.
        x = x.contiguous()
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        z = self.reparameterize(mu, logvar)
        x = self.fc_decode(z)
        # Reshape back to (batch, 128, 8, 2)
        x = x.view(-1, 128, 8, 2)
        x = self.decoder(x)
        return x, mu, logvar

# ===============================
# Loss function
def loss_function(recon_x, x, mu, logvar):
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_div

# ===============================
# Training function
def train(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    for batch in dataloader:
        batch = batch.to(device, dtype=torch.float32, non_blocking=True)
        # Check if there are NaN or Inf values in the preprocessed data
        if torch.isnan(batch).any() or torch.isinf(batch).any():
            print("NaN or Inf detected in batch input!")

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(batch)
        loss = loss_function(recon_batch, batch, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # if torch.isnan(batch).any():
        #     print("NaN detected in batch input!")
        # if torch.isnan(recon_batch).any():
        #     print("NaN detected in reconstruction!")
        # if torch.isnan(mu).any() or torch.isnan(logvar).any():
        #     print("NaN detected in mu or logvar!")



    return total_loss / len(dataloader.dataset)

# ===============================
# Main function
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #device = torch.device("cpu")
    print(f"Using device: {device}")
    # Create dataset and dataloader
    dataset = MelDataset(PROCESSED_DIR, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    # Input dimensions: (64, 16)
    input_dim = (64, 16)
    model = VAE(input_dim, LATENT_DIM).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    for epoch in range(EPOCHS):
        loss = train(model, dataloader, optimizer, device)
        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss:.4f}")
    torch.save(model.state_dict(), "vae_model.pth")
    print("Model saved!")

In [36]:
!export MIOPEN_DEBUG_FIND_ONLY_SOLUTIONS=1
!export HSA_OVERRIDE_GFX_VERSION=10.3.0

In [4]:
import torch
print("CUDA Available:", torch.cuda.is_available())
print("ROCm Available:", torch.version.hip)
print("Number of GPUs:", torch.cuda.device_count())
print("GPU Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU detected")


CUDA Available: True
ROCm Available: 6.2.41134-65d174c3e
Number of GPUs: 2
GPU Name: AMD Radeon RX 7800 XT


In [26]:
main()

Using device: cuda
Output size after encoder (for input_dim (64, 16)): torch.Size([1, 128, 8, 2])
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!
NaN or Inf detected in batch input!


KeyboardInterrupt: 