## This is the training file for the local model which stands as a proof of concept currently

In [9]:
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.optim import Adam

import sys
sys.path.append("~/Desktop/projects/audio-filler")
from models.model2.model2 import AudioEncoder1

ModuleNotFoundError: No module named 'models'

In [None]:
class MusicGenreDataset(Dataset):
    def __init__(self, data_dir, clip_duration=15, stride=5, sample_rate=16000):
        self.data_dir = Path(data_dir)
        self.clip_length = clip_duration * sample_rate
        self.stride = stride * sample_rate
        self.sample_rate = sample_rate
        
        # Get genre labels
        self.genres = sorted([d.name for d in self.data_dir.iterdir() if d.is_dir()])
        self.genre_to_idx = {genre: idx for idx, genre in enumerate(self.genres)}
        
        # Collect audio files and their genres
        self.audio_files = []
        for genre in self.genres:
            genre_dir = self.data_dir / genre
            for audio_file in genre_dir.glob("*.mp3"):
                self.audio_files.append((audio_file, genre))
        
        # Precompute clip segments
        self.clips = []
        for audio_file, genre in self.audio_files:
            info = torchaudio.info(audio_file)
            total_samples = info.num_frames
            start = 0
            while start < total_samples:
                end = start + self.clip_length
                if end > total_samples:
                    if total_samples - start >= 10 * sample_rate:  # Check if at least 10s
                        end = total_samples
                        padding = self.clip_length - (end - start)
                        self.clips.append((audio_file, start, end, padding, genre))
                    break
                else:
                    self.clips.append((audio_file, start, end, 0, genre))
                start += self.stride

    def __len__(self):
        return len(self.clips)

    def __getitem__(self, idx):
        audio_file, start, end, padding, genre = self.clips[idx]
        waveform, sr = torchaudio.load(audio_file, frame_offset=start, num_frames=end-start)
        waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)[0]  # Convert to mono
        
        if padding > 0:
            waveform = torch.nn.functional.pad(waveform, (0, padding))
        
        # Add channel dimension
        waveform = waveform.unsqueeze(0)
        label = self.genre_to_idx[genre]
        
        return waveform, label

In [None]:
# Initialize dataset and dataloader
dataset = MusicGenreDataset("path/to/your/data")
dataloader = DataLoader(dataset, batch_size=1000, shuffle=True, num_workers=4)

# Model and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AudioModel1().to(device)
optimizer = Adam(model.parameters(), lr=1e-4)

# Track metrics
metrics = {
    'total_loss': [], 'classification_loss': [], 
    'modulus_recon_loss': [], 'sign_recon_loss': [],
    'sign_accuracy': [], 'kl_loss': []
}

# Training loop
model.train()
for batch_idx, (data, targets) in enumerate(dataloader):
    data, targets = data.to(device), targets.to(device)
    
    optimizer.zero_grad()
    losses = model.loss_function(data, targets)
    losses['total_loss'].backward()
    optimizer.step()
    
    # Store metrics
    for k in metrics:
        metrics[k].append(losses[k].item())
    
    # Print batch statistics
    print(f"Batch {batch_idx+1}:")
    print(f"Total Loss: {losses['total_loss'].item():.4f}")
    print(f"Classification Loss: {losses['classification_loss'].item():.4f}")
    print(f"Modulus Recon Loss: {losses['modulus_recon_loss'].item():.4f}")
    print(f"Sign Recon Loss: {losses['sign_recon_loss'].item():.4f}")
    print(f"Sign Accuracy: {losses['sign_accuracy'].item():.4f}")
    print(f"KL Loss: {losses['kl_loss'].item():.4f}")
    
    # Plot every 5 batches
    if (batch_idx + 1) % 5 == 0:
        plt.figure(figsize=(12, 8))
        for i, (k, v) in enumerate(metrics.items()):
            plt.subplot(2, 3, i+1)
            plt.plot(v, label=k)
            plt.title(k)
            plt.xlabel('Batch')
        plt.tight_layout()
        plt.show()
        plt.close()