In [12]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np

# ---------------- Dataset ----------------
class AudioEditingDataset(Dataset):
    def __init__(self, csv_file):
        self.data = pd.read_csv(csv_file)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Load .npy files
        prompt = np.load(row['prompt_embedding'])         # shape (num_tokens, prompt_dim)
        input_audio = np.load(row['input_embedding'])    # shape (audio_dim,)
        output_audio = np.load(row['output_embedding'])  # shape (audio_dim,)

        # Pool prompt if needed
        if prompt.ndim > 1:
            prompt = prompt.mean(axis=0)

        # Flatten audio if multi-frame
        if input_audio.ndim > 1:
            input_audio = input_audio.mean(axis=0)
        if output_audio.ndim > 1:
            output_audio = output_audio.mean(axis=0)

        # Convert to tensors
        prompt = torch.tensor(prompt, dtype=torch.float32)
        input_audio = torch.tensor(input_audio, dtype=torch.float32)
        output_audio = torch.tensor(output_audio, dtype=torch.float32)

        return prompt, input_audio, output_audio

# ---------------- Transformer Model ----------------
class AudioEditingTransformer(nn.Module):
    def __init__(self, prompt_dim=1024, audio_dim=128, embedding_dim=128, nhead=4, num_layers=2):
        super().__init__()

        # Project prompt and audio embeddings to same dimension
        self.proj_prompt = nn.Linear(prompt_dim, embedding_dim)
        self.proj_audio = nn.Linear(audio_dim, embedding_dim)

        # Positional encoding for sequence length 2
        self.positional_encoding = nn.Parameter(torch.zeros(1, 2, embedding_dim))

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output MLP to predict edited audio embedding
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, audio_dim)
        )

    def forward(self, prompt_embedding, input_audio_embedding):
        # Project to same embedding_dim
        prompt_embedding = self.proj_prompt(prompt_embedding)
        input_audio_embedding = self.proj_audio(input_audio_embedding)

        # Stack into sequence length 2
        x = torch.stack([prompt_embedding, input_audio_embedding], dim=1)  # (batch_size, 2, embedding_dim)
        x = x + self.positional_encoding

        # Transformer
        x = self.transformer(x)

        # Use last token (audio) to predict output
        out = self.mlp(x[:, -1, :])
        return out

# ---------------- Training Loop ----------------
def train_model(csv_file, embedding_dim=128, batch_size=64, epochs=20, lr=1e-4, device='cuda'):
    dataset = AudioEditingDataset(csv_file)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = AudioEditingTransformer(embedding_dim=embedding_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        for prompt, input_audio, output_audio in dataloader:
            prompt = prompt.to(device)
            input_audio = input_audio.to(device)
            output_audio = output_audio.to(device)

            optimizer.zero_grad()
            pred = model(prompt, input_audio)
            loss = criterion(pred, output_audio)
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch+1}/{epochs} - Loss: {loss.item():.4f}")

    return model


# -----------------------------
# Example inference
# -----------------------------
def infer(model, prompt_embedding, input_audio_embedding, device='cuda'):
    model.eval()
    with torch.no_grad():
        prompt_embedding = prompt_embedding.to(device).float().unsqueeze(0)
        input_audio_embedding = input_audio_embedding.to(device).float().unsqueeze(0)
        pred_embedding = model(prompt_embedding, input_audio_embedding)
    return pred_embedding.squeeze(0)  # Return embedding vector


In [13]:
import numpy as np

emb = np.load('../embeddings/input_data_embeddings/1.npy', allow_pickle=True)

print(type(emb))
print("Shape:", emb.shape)
print("First 10 values:", emb[:10])

<class 'numpy.ndarray'>
Shape: (128,)
First 10 values: [175.   8. 147.  97. 215.  76.  77. 130. 148. 175.]


In [14]:
import numpy as np
x = np.load("../embeddings/prompts_embeddings/embeddings_snowfall.npy")
print(x.shape)

(10, 1024)


In [15]:
import torch

# 1. Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 2. Train the model
csv_file = '../50000_datapoint.csv'
embedding_dim = 512  # Set this to match your embedding size
model = train_model(csv_file, embedding_dim=embedding_dim, batch_size=64, epochs=20, lr=1e-4, device=device)

# 3. Save the trained model
torch.save(model.state_dict(), 'audio_editing_transformer.pth')

# # 4. Load the trained model later
# model_loaded = AudioEditingTransformer(embedding_dim=embedding_dim).to(device)
# model_loaded.load_state_dict(torch.load('audio_editing_transformer.pth'))

# # 5. Example inference with new embeddings
# # Suppose you have torch tensors: prompt_emb, input_audio_emb
# # They should have shape [embedding_dim]
# predicted_output_emb = infer(model_loaded, prompt_emb, input_audio_emb, device=device)

# # 6. Use predicted_output_emb to reconstruct audio via your decoder / vocoder


Epoch 1/20 - Loss: 96.7417
Epoch 2/20 - Loss: 2.1898
Epoch 3/20 - Loss: 86.3750
Epoch 4/20 - Loss: 39.4520
Epoch 5/20 - Loss: 19.0071
Epoch 6/20 - Loss: 26.9399
Epoch 7/20 - Loss: 54.3680
Epoch 8/20 - Loss: 1.1060
Epoch 9/20 - Loss: 14.2896
Epoch 10/20 - Loss: 27.4440
Epoch 11/20 - Loss: 14.5823
Epoch 12/20 - Loss: 1.0573
Epoch 13/20 - Loss: 7.1401
Epoch 14/20 - Loss: 4.3344
Epoch 15/20 - Loss: 32.1758
Epoch 16/20 - Loss: 25.2243
Epoch 17/20 - Loss: 40.1536
Epoch 18/20 - Loss: 55.1921
Epoch 19/20 - Loss: 1.1630
Epoch 20/20 - Loss: 29.1605
