# Step 2: Retraining with OOF (Out-of-Fold) Predictions

This notebook trains one architecture on the **expanded dataset** (original + pseudo-labels) with stratified K-fold cross-validation. It saves the out-of-fold (OOF) probabilities, plus fold-averaged test-set probabilities, for the next stacking step.

**Note**: Run this notebook three times, once per architecture, setting the `ARCH` variable to `'resnet18'`, `'resnet34'` and `'efficientnet_b0'` in turn.

In [None]:
import os, json, random, numpy as np, pandas as pd, torch, torch.nn as nn, torch.nn.functional as F, torchaudio, torchvision.models as models
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score

# ================= CONFIGURATION =================
# Backbone trained in this run. Supported values: 'resnet18', 'resnet34', 'efficientnet_b0'.
ARCH = 'resnet18'

# Competition inputs: label_map.json plus the public/private test waveforms.
DATA_DIR = '/kaggle/input/the-last-frequency'

# Directory holding expanded_train_waveforms.npy / expanded_train_labels.npy.
PSEUDO_DATA_DIR = '/kaggle/working'

class CFG:
    """Notebook-wide hyper-parameters, read as class attributes (e.g. ``CFG.lr``)."""

    # --- Audio front-end (mel spectrogram) ---
    sample_rate = 16000
    n_fft = 1024
    hop_length = 256
    n_mels = 128
    target_frames = 64  # spectrogram width after cropping / right-padding

    # --- Cross-validation and optimisation ---
    n_splits = 5
    batch_size = 64
    epochs = 35
    lr = 1e-3
    weight_decay = 1e-2
    label_smoothing = 0.1
    mixup_alpha = 0.2  # NOTE(review): not referenced by the training loop in this cell

    # --- Classifier head ---
    num_classes = 35

def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Setting PYTHONHASHSEED at runtime only affects child processes started
    # afterwards; the current interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

seed_everything(42)
# Use CUDA when PyTorch can see a GPU; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# =================================================

print(f"Loading expanded data for {ARCH}...")
# Expanded training set (original + pseudo-labelled clips) read from PSEUDO_DATA_DIR.
train_waveforms = np.load(f'{PSEUDO_DATA_DIR}/expanded_train_waveforms.npy')
train_labels = np.load(f'{PSEUDO_DATA_DIR}/expanded_train_labels.npy')
# Fail early with a clear message instead of deep inside the CV split.
if len(train_waveforms) != len(train_labels):
    raise ValueError(
        f"expanded data mismatch: {len(train_waveforms)} waveforms vs {len(train_labels)} labels"
    )
# JSON object keys are always strings, so convert them back to integer class ids.
# An explicit encoding avoids depending on the platform's default codec.
with open(f'{DATA_DIR}/label_map.json', encoding='utf-8') as f:
    label_map = {int(k): v for k, v in json.load(f).items()}

class SpecTransform(nn.Module):
    """Convert waveforms to fixed-width log-mel spectrograms, with optional masking.

    The time axis is cropped, or right-padded with zeros, to exactly
    ``CFG.target_frames`` frames. When ``augment`` is true, frequency masking
    and then time masking are applied.
    """

    def __init__(self):
        super().__init__()
        # These attribute names are state_dict keys in saved checkpoints; keep them stable.
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=CFG.sample_rate,
            n_fft=CFG.n_fft,
            hop_length=CFG.hop_length,
            n_mels=CFG.n_mels,
        )
        self.amp_to_db = torchaudio.transforms.AmplitudeToDB()
        self.f_mask = torchaudio.transforms.FrequencyMasking(20)
        self.t_mask = torchaudio.transforms.TimeMasking(15)

    def forward(self, x, augment=False):
        spec = self.amp_to_db(self.mel_spec(x))
        # Positive: too many frames, crop. Negative: too few, pad on the right.
        excess = spec.shape[-1] - CFG.target_frames
        if excess > 0:
            spec = spec[..., :CFG.target_frames]
        elif excess < 0:
            spec = F.pad(spec, (0, -excess))
        if augment:
            spec = self.f_mask(spec)
            spec = self.t_mask(spec)
        return spec

def get_model(arch):
    """Build a waveform classifier around a randomly initialised torchvision backbone.

    The backbone's stem is replaced to accept a single input channel, and its
    final Linear is replaced by Dropout(0.3) -> Linear(..., CFG.num_classes).
    The returned wrapper converts raw waveforms with ``SpecTransform`` and
    feeds the backbone a 1-channel spectrogram image.

    Args:
        arch: One of ``'resnet18'``, ``'resnet34'`` or ``'efficientnet_b0'``.

    Returns:
        The wrapped model, already moved to ``device``.

    Raises:
        ValueError: If ``arch`` is not a supported architecture name. Previously
            an unknown name crashed with an opaque UnboundLocalError.
    """
    builders = {
        'resnet18': models.resnet18,
        'resnet34': models.resnet34,
        'efficientnet_b0': models.efficientnet_b0,
    }
    if arch not in builders:
        raise ValueError(f"Unsupported arch {arch!r}; expected one of {sorted(builders)}")
    model = builders[arch](weights=None)

    if 'resnet' in arch:
        # 1-channel stem with the same geometry as the stock ResNet conv1.
        model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        model.fc = nn.Sequential(nn.Dropout(0.3), nn.Linear(model.fc.in_features, CFG.num_classes))
    else:
        # Swap the first conv for a 1-channel copy with identical geometry.
        old_conv = model.features[0][0]
        model.features[0][0] = nn.Conv2d(
            1, old_conv.out_channels, old_conv.kernel_size, old_conv.stride, old_conv.padding, bias=False
        )
        # NOTE(review): torchvision's EfficientNet head presumably already has a Dropout at
        # classifier[0]; replacing index 1 keeps it and stacks a second Dropout. Confirm intended.
        model.classifier[1] = nn.Sequential(
            nn.Dropout(0.3), nn.Linear(model.classifier[1].in_features, CFG.num_classes)
        )

    class Wrapper(nn.Module):
        """Waveform -> SpecTransform -> add channel dim -> backbone logits."""

        def __init__(self, backbone):
            super().__init__()
            # 'backbone' / 'spec_layer' are the state_dict prefixes of saved checkpoints.
            self.backbone = backbone
            self.spec_layer = SpecTransform()

        def forward(self, x, augment=False):
            return self.backbone(self.spec_layer(x, augment=augment).unsqueeze(1))

    return Wrapper(model).to(device)

class SpeechDataset(Dataset):
    """Map-style dataset over in-memory waveforms, with an optional random circular shift.

    Args:
        waveforms: Indexable collection of 1-D NumPy waveforms.
        labels: Optional per-sample labels. When given, items are ``(wave, label)``
            tuples; otherwise items are the waveform tensor alone.
        augment: If true, roll each waveform by a random offset of up to +/-10%
            of its length.
    """

    def __init__(self, waveforms, labels=None, augment=False):
        self.waveforms = waveforms
        self.labels = labels
        self.augment = augment

    def __len__(self):
        return len(self.waveforms)

    def __getitem__(self, idx):
        # Copy so neither the roll nor the returned tensor aliases the source array.
        signal = self.waveforms[idx].copy()
        if self.augment:
            offset = int(random.uniform(-0.1, 0.1) * signal.shape[0])
            signal = np.roll(signal, offset)
        tensor = torch.from_numpy(signal).float()
        if self.labels is None:
            return tensor
        return tensor, self.labels[idx]

skf = StratifiedKFold(n_splits=CFG.n_splits, shuffle=True, random_state=42)
# Row i is filled by the fold in which sample i is held out.
# NOTE(review): the saved probabilities come from the epoch with the best accuracy
# on that same fold, so OOF scores are slightly optimistic.
oof_probs = np.zeros((len(train_labels), CFG.num_classes))
pub, priv = np.load(f'{DATA_DIR}/public_test_waveforms.npy'), np.load(f'{DATA_DIR}/private_test_waveforms.npy')
# Public rows come first, then private; the saved test probabilities keep this order.
test_wavs = np.concatenate([pub, priv])
# The test set is identical for every fold, so build its loader once.
test_loader = DataLoader(SpeechDataset(test_wavs, augment=False), batch_size=CFG.batch_size, shuffle=False)
all_test_probs = []
for fold, (train_idx, val_idx) in enumerate(skf.split(train_waveforms, train_labels)):
    print(f"Fold {fold+1}/{CFG.n_splits}")
    ckpt_path = f'{ARCH}_fold_{fold}.pth'
    train_loader = DataLoader(
        SpeechDataset(train_waveforms[train_idx], train_labels[train_idx], augment=True),
        batch_size=CFG.batch_size, shuffle=True,
    )
    val_loader = DataLoader(
        SpeechDataset(train_waveforms[val_idx], train_labels[val_idx], augment=False),
        batch_size=CFG.batch_size, shuffle=False,
    )
    model = get_model(ARCH)
    opt = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    # OneCycleLR is sized per batch, so sched.step() is called after every optimiser step.
    sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=CFG.lr*2, steps_per_epoch=len(train_loader), epochs=CFG.epochs)
    crit = nn.CrossEntropyLoss(label_smoothing=CFG.label_smoothing)
    # NOTE(review): CFG.mixup_alpha is defined but mixup is not applied here.
    # Start below any reachable accuracy so epoch 1 always writes a checkpoint and fills
    # oof_probs[val_idx]. With 0, an all-wrong fold left zero OOF rows and crashed on torch.load.
    best_acc = -1.0
    for epoch in range(1, CFG.epochs + 1):
        model.train()
        for x, y in train_loader:
            logits = model(x.to(device), augment=True)
            # CrossEntropyLoss expects int64 class indices; the saved label dtype is not visible here.
            loss = crit(logits, y.to(device).long())
            opt.zero_grad()
            loss.backward()
            opt.step()
            sched.step()
        model.eval()
        vp, vt = [], []
        with torch.no_grad():
            for x, y in val_loader:
                out = model(x.to(device))
                vp.append(F.softmax(out, dim=1).cpu().numpy())
                vt.extend(y.numpy())
        vp = np.concatenate(vp)
        acc = accuracy_score(vt, vp.argmax(1))
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), ckpt_path)
            oof_probs[val_idx] = vp
        if epoch % 10 == 0:
            print(f"Epoch {epoch} Val Acc: {acc:.4f}")
    # Predict the test set with the best checkpoint of this fold.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()
    tp = []
    with torch.no_grad():
        for x in tqdm(test_loader):
            tp.append(F.softmax(model(x.to(device)), dim=1).cpu().numpy())
    all_test_probs.append(np.concatenate(tp))
np.save(f'{ARCH}_oof_probs.npy', oof_probs)
# Average the per-fold test probabilities.
np.save(f'{ARCH}_test_probs.npy', np.mean(all_test_probs, axis=0))
print(f"Saved results for {ARCH}")