# Baseline (per-sample Ridge regression), kept for reference

In [None]:
import numpy as np
import pandas as pd
import torchaudio
import os
from tqdm.auto import tqdm

# Clip length in samples (every signal is cropped / zero-padded to K) and the
# competition data root that all CSV-relative paths are resolved against.
K, IN_DIR = 16_000, "/kaggle/input/audio-demixing-aicc-round-2/"

def load_signal(path, K):
    """Load the first channel of the audio file at *path* as exactly *K* float32 samples.

    Longer clips are truncated to their first K samples; shorter clips are
    zero-padded at the end.
    """
    waveform, _sample_rate = torchaudio.load(path)
    samples = waveform[0].numpy()
    shortfall = K - len(samples)
    if shortfall > 0:
        samples = np.pad(samples, (0, shortfall), mode='constant')
    else:
        samples = samples[:K]
    return samples.astype(np.float32)

# Load CSVs (paths are built from IN_DIR so the data root is defined once)
train = pd.read_csv(os.path.join(IN_DIR, "train.csv"))
test = pd.read_csv(os.path.join(IN_DIR, "test.csv"))


def _load_signals(paths, desc):
    """Stack every IN_DIR-relative path in *paths* into a (len(paths), K) float32 array."""
    return np.vstack([load_signal(os.path.join(IN_DIR, p), K) for p in tqdm(paths, desc=desc)])


# Load raw training signals: the mixture ("file") and its two sources.
print("Loading training signals...")
X_train = _load_signals(train["file"], "X_train")
Y1_train = _load_signals(train["sig_1"], "Y1_train")
Y2_train = _load_signals(train["sig_2"], "Y2_train")

# Load raw test signals. The regressors take the *mixture* as input, which is
# stored in the "file" column exactly as in train.csv; "sig_1" is a source
# column and must not be used as model input.
print("Loading test signals...")
if "file" in test.columns:
    X_test = _load_signals(test["file"], "X_test")
else:
    # Best-effort placeholder so a correctly shaped submission can still be
    # produced, but make it loud: these predictions are not meaningful.
    print("WARNING: test.csv has no 'file' column; using X_train[0] as a placeholder mixture.")
    X_test = np.tile(X_train[0], (len(test), 1))

test_IDs = test["ID"].astype(str).values

In [None]:
from sklearn.linear_model import Ridge

# Flatten for single-step regression
# Treat every time step as an independent sample: the single feature is the
# mixture value, and each source value is regressed on it separately.
X_flat = X_train.reshape(-1, 1)
Y1_flat, Y2_flat = (y.reshape(-1) for y in (Y1_train, Y2_train))

# Fit one Ridge model per source (m1 -> sig_1, m2 -> sig_2).
print("Training Ridge models...")
m1, m2 = Ridge(alpha=1.0), Ridge(alpha=1.0)
m1.fit(X_flat, Y1_flat)
m2.fit(X_flat, Y2_flat)

# Predict every test sample, then fold back into (n_clips, K) float32 arrays.
X_test_flat = X_test.reshape(-1, 1)
pred1_flat, pred2_flat = m1.predict(X_test_flat), m2.predict(X_test_flat)

pred1, pred2 = (p.reshape(len(X_test), K).astype(np.float32) for p in (pred1_flat, pred2_flat))

In [None]:
import base64
import pandas as pd
from tqdm.auto import tqdm
import csv

def encode_array_base85(arr: np.ndarray) -> str:
    """Serialize *arr* as flat C-order float32 bytes and return them Base85-encoded (ASCII)."""
    payload = arr.astype(np.float32).tobytes(order="C")
    return base64.b85encode(payload).decode("ascii")

# Encode predictions
# Encode predictions: one row per test ID, each source stored as a Base85
# string of float32 samples.
if not (len(test_IDs) == len(pred1) == len(pred2)):
    raise ValueError(
        f"Length mismatch: {len(test_IDs)} IDs, {len(pred1)} sig_1 rows, {len(pred2)} sig_2 rows"
    )

rows = [
    (str(sample_id), encode_array_base85(p1), encode_array_base85(p2))
    for sample_id, p1, p2 in tqdm(zip(test_IDs, pred1, pred2), total=len(test_IDs), desc="Encoding rows")
]

# Build submission DataFrame
submission = pd.DataFrame(rows, columns=["ID", "sig_1", "sig_2"])
submission.to_csv("/kaggle/working/submission.csv", index=False, quoting=csv.QUOTE_ALL)
print("✓ submission.csv written")

# Main solution: Wave-U-Net source separation

In [None]:
import os
import random
import base64
import csv
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm.auto import tqdm
import warnings
import copy

warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio")

In [None]:
# --- Configuration ---
CONFIG = {
    "SR": 16000,
    "LEN": 16000,
    "BATCH_SIZE": 24,
    "EPOCHS": 50,          # Increased slightly for augmentation
    "LR": 1e-3,
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    "ROOT_DIR": "/kaggle/input/audio-demixing-aicc-round-2/"
}

print(f"Running on device: {CONFIG['DEVICE']}")

In [None]:
class AudioDataset(Dataset):
    """Rows of ``df`` as fixed-length audio tensors.

    Paths in the ``file`` (mixture), ``sig_1`` and ``sig_2`` (sources) columns
    are resolved relative to ``root_dir``. Each clip is returned as a float32
    tensor of shape (1, CONFIG['LEN']).

    Training items are ``(mixed, s1, s2)`` with shared random augmentation;
    test items are ``(mixed, id_str)``.
    """

    def __init__(self, df, root_dir, is_train=True):
        self.df = df
        self.root_dir = root_dir
        self.is_train = is_train

    def __len__(self):
        return len(self.df)

    def load_wav(self, path):
        """Load the first channel of ``path`` as a (1, CONFIG['LEN']) float32 tensor.

        Audio stored at a rate other than CONFIG['SR'] is resampled first, so
        every clip shares one time base. The clip is then zero-padded at the
        end or cropped to exactly CONFIG['LEN'] samples.
        """
        full_path = os.path.join(self.root_dir, path)
        wav, sr = torchaudio.load(full_path)
        if sr != CONFIG['SR']:
            wav = torchaudio.functional.resample(wav, sr, CONFIG['SR'])
        wav = wav[0].numpy()

        # Pad or Crop
        if len(wav) < CONFIG['LEN']:
            wav = np.pad(wav, (0, CONFIG['LEN'] - len(wav)), mode='constant')
        else:
            wav = wav[:CONFIG['LEN']]

        return torch.tensor(wav, dtype=torch.float32).unsqueeze(0)

    def augment(self, mixed, s1, s2):
        """Apply identical augmentation to the mixture and both targets.

        A random gain drawn uniformly from [0.8, 1.2] is applied to all three.
        With probability ~1/2 all three are also polarity-inverted. Applying
        the same transform keeps mixture and sources consistent.
        """
        # 1. Random Gain (0.8x to 1.2x)
        gain = random.uniform(0.8, 1.2)
        mixed = mixed * gain
        s1 = s1 * gain
        s2 = s2 * gain

        # 2. Random Phase Inversion (Flip polarity)
        if random.random() > 0.5:
            mixed = -mixed
            s1 = -s1
            s2 = -s2

        return mixed, s1, s2

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        mixed = self.load_wav(row["file"])

        if self.is_train:
            s1 = self.load_wav(row["sig_1"])
            s2 = self.load_wav(row["sig_2"])
            return self.augment(mixed, s1, s2)

        # Return the ID as a string. A numeric ID (e.g. numpy.int64) would be
        # collated by DataLoader into a tensor, and str(tensor[i]) gives
        # "tensor(7)" rather than "7".
        return mixed, str(row["ID"])

In [None]:
class ConvBlock(nn.Module):
    """Conv1d -> BatchNorm1d -> LeakyReLU(0.2), the basic unit of WaveUNet.

    With the defaults (kernel=3, stride=1, padding=1) the sequence length is
    preserved and only the channel count changes from ``in_c`` to ``out_c``.
    """

    def __init__(self, in_c, out_c, kernel=3, stride=1, padding=1):
        super().__init__()
        layers = [
            nn.Conv1d(in_c, out_c, kernel, stride, padding),
            nn.BatchNorm1d(out_c),
            nn.LeakyReLU(0.2),
        ]
        # Kept under the attribute name `conv` so state_dict keys stay stable.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = x
        for layer in self.conv:
            out = layer(out)
        return out

In [None]:
class WaveUNet(nn.Module):
    """1-D U-Net that maps a (B, 1, L) waveform to two (B, 2, L) source estimates.

    Four ConvBlock + 2x max-pool encoder stages (16 -> 128 channels) feed a
    256-channel bottleneck. Four transposed-conv decoder stages follow, each
    concatenated with the matching encoder skip, and a 1x1 conv produces the
    two output channels.
    """

    # Four 2x pools: skip connections only line up when the length is a
    # multiple of 2**4, so forward() right-pads to this multiple and crops back.
    _LENGTH_MULTIPLE = 2 ** 4

    def __init__(self):
        super().__init__()

        # Encoder
        self.enc1 = ConvBlock(1, 16)
        self.pool1 = nn.MaxPool1d(2)

        self.enc2 = ConvBlock(16, 32)
        self.pool2 = nn.MaxPool1d(2)

        self.enc3 = ConvBlock(32, 64)
        self.pool3 = nn.MaxPool1d(2)

        self.enc4 = ConvBlock(64, 128)
        self.pool4 = nn.MaxPool1d(2)

        # Bottleneck
        self.bottleneck = ConvBlock(128, 256)

        # Decoder (input channels are doubled by the skip concatenation)
        self.up4 = nn.ConvTranspose1d(256, 128, kernel_size=2, stride=2)
        self.dec4 = ConvBlock(256, 128)

        self.up3 = nn.ConvTranspose1d(128, 64, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(128, 64)

        self.up2 = nn.ConvTranspose1d(64, 32, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(64, 32)

        self.up1 = nn.ConvTranspose1d(32, 16, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(32, 16)

        # Final output
        self.final = nn.Conv1d(16, 2, kernel_size=1)

    def forward(self, x):
        length = x.shape[-1]
        # Right-pad so every pooling stage halves the length exactly; a no-op
        # (and therefore identical output) for lengths already divisible by 16.
        pad = (-length) % self._LENGTH_MULTIPLE
        if pad:
            x = F.pad(x, (0, pad))

        e1 = self.enc1(x)
        p1 = self.pool1(e1)
        e2 = self.enc2(p1)
        p2 = self.pool2(e2)
        e3 = self.enc3(p2)
        p3 = self.pool3(e3)
        e4 = self.enc4(p3)
        p4 = self.pool4(e4)

        b = self.bottleneck(p4)

        d4 = self.up4(b)
        d4 = torch.cat([d4, e4], dim=1)
        d4 = self.dec4(d4)
        d3 = self.up3(d4)
        d3 = torch.cat([d3, e3], dim=1)
        d3 = self.dec3(d3)
        d2 = self.up2(d3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)
        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)

        # Drop the padded tail so the output length matches the input length.
        return self.final(d1)[..., :length]

In [None]:
def pit_mse_loss(pred, target_s1, target_s2):
    """Permutation-invariant MSE for two-source separation.

    For every example, the loss is scored under both output-to-target
    assignments (ch0->s1 / ch1->s2 and ch0->s2 / ch1->s1). The cheaper
    assignment is kept, and the result is averaged over the batch.

    pred: (B, 2, L); target_s1 / target_s2: (B, 1, L). Returns a scalar tensor.
    """
    est_a = pred[:, 0:1, :]
    est_b = pred[:, 1:2, :]

    def per_example(estimate, target):
        # Mean squared error over channel and time, one value per batch item.
        return F.mse_loss(estimate, target, reduction='none').mean(dim=(1, 2))

    identity = per_example(est_a, target_s1) + per_example(est_b, target_s2)
    swapped = per_example(est_a, target_s2) + per_example(est_b, target_s1)

    best = torch.stack([identity, swapped], dim=1).min(dim=1).values
    return best.mean()

In [None]:
def train_model():
    """Train a WaveUNet on train.csv with the permutation-invariant MSE loss.

    Hyper-parameters come from CONFIG. The learning rate is halved by
    ReduceLROnPlateau when the epoch-average *training* loss stops improving
    (there is no validation split). Returns the model as it stands after the
    final epoch.
    """
    df_train = pd.read_csv(os.path.join(CONFIG['ROOT_DIR'], "train.csv"))
    device = CONFIG['DEVICE']

    # Dataset & Loader
    train_ds = AudioDataset(df_train, CONFIG['ROOT_DIR'], is_train=True)
    train_loader = DataLoader(train_ds, batch_size=CONFIG['BATCH_SIZE'], shuffle=True, num_workers=2, persistent_workers=True)

    model = WaveUNet().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['LR'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    print(f"Starting training for {CONFIG['EPOCHS']} epochs...")

    for epoch in range(CONFIG['EPOCHS']):
        model.train()
        running_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['EPOCHS']}")

        for step, (mixed, s1, s2) in enumerate(pbar, start=1):
            mixed = mixed.to(device)
            s1 = s1.to(device)
            s2 = s2.to(device)

            optimizer.zero_grad()
            preds = model(mixed)
            loss = pit_mse_loss(preds, s1, s2)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # tqdm only syncs pbar.n on display refreshes, so it lags the true
            # batch count; divide by our own step counter instead.
            pbar.set_postfix({'loss': running_loss / step})

        # Step the scheduler based on the epoch-average training loss
        avg_loss = running_loss / len(train_loader)
        scheduler.step(avg_loss)

    return model

In [None]:
def encode_array_base85(arr: np.ndarray) -> str:
    """Return *arr* as Base85 text of its float32 samples in flat C order."""
    flat = arr.astype(np.float32).reshape(-1)
    encoded = base64.b85encode(flat.tobytes())
    return encoded.decode("ascii")

In [None]:
model = train_model()

df_test = pd.read_csv(os.path.join(CONFIG['ROOT_DIR'], "test.csv"))
test_ds = AudioDataset(df_test, CONFIG['ROOT_DIR'], is_train=False)
test_loader = DataLoader(test_ds, batch_size=CONFIG['BATCH_SIZE'], shuffle=False, num_workers=2)

model.eval()
submission_rows = []

print("Generating predictions...")
with torch.no_grad():
    for mixed_batch, ids in tqdm(test_loader, desc="Inference"):
        mixed_batch = mixed_batch.to(CONFIG['DEVICE'])
        preds = model(mixed_batch).cpu().numpy()  # (B, 2, L)

        # If the dataset yields numeric IDs, default_collate batches them into a
        # tensor, and str(tensor[i]) would give "tensor(7)". Convert to plain
        # Python values first; string IDs arrive as a list and pass through.
        batch_ids = ids.tolist() if torch.is_tensor(ids) else list(ids)

        for sample_id, (sig1, sig2) in zip(batch_ids, preds):
            submission_rows.append(
                (str(sample_id), encode_array_base85(sig1), encode_array_base85(sig2))
            )

submission = pd.DataFrame(submission_rows, columns=["ID", "sig_1", "sig_2"])
submission.to_csv("submission.csv", index=False, quoting=csv.QUOTE_ALL)
print("submission.csv written")