In [1]:
import os
import torch
import librosa
import librosa.display
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
import torch.optim as optim

# Check PyTorch version
print(torch.__version__)

  from .autonotebook import tqdm as notebook_tqdm


2.6.0+cpu


# pretraining:

In [2]:
# Add src to the system path
import sys
import pickle
import os
sys.path.append(os.path.abspath('src'))

# Import SSAST-specific modules
from ssast.src.models.ast_models import ASTModel  # The main model class
from ssast.src.dataloader import AudioDataset  # For loading audio data
import importlib
import ssast.src.traintest_mask

# Force reload the traintest_mask module
importlib.reload(ssast.src.traintest_mask)
from ssast.src.traintest_mask import trainmask  # Pre-training function

# Parameters from the bash script
tr_data = r"dummy_data.json".replace("\\", "/")
te_data = r"dummy_data.json".replace("\\", "/")
dataset = 'custom'
dataset_mean = -4.2677393
dataset_std = 4.5689974
target_length = 1024
num_mel_bins = 128

model_size = 'base'
fshape = 16
tshape = 16
fstride = 16
tstride = 16
mask_patch = 400

batch_size = 8  # Will adjust below if needed
lr = 1e-4
lr_patience = 2
n_epochs = 1
freqm = 0
timem = 0
mixup = 0
bal = 'none'
n_print_steps = 100
epoch_iter = 4000
task = 'pretrain_joint'

exp_dir = r"exp/mask01-{}-f{}-t{}-b{}-lr{}-m{}-{}-custom".format(
    model_size, fshape, tshape, batch_size, lr, mask_patch, task
).replace("\\", "/")

# Audio configurations
audio_conf = {
    'num_mel_bins': num_mel_bins,
    'target_length': target_length,
    'freqm': freqm,
    'timem': timem,
    'mixup': mixup,
    'dataset': dataset,
    'mode': 'train',
    'mean': dataset_mean,
    'std': dataset_std,
    'noise': False
}

val_audio_conf = {
    'num_mel_bins': num_mel_bins,
    'target_length': target_length,
    'freqm': 0,
    'timem': 0,
    'mixup': 0,
    'dataset': dataset,
    'mode': 'evaluation',
    'mean': dataset_mean,
    'std': dataset_std,
    'noise': False
}

# Create DataLoaders
import torch

# Temporary dataset instance to check length
tmp_dataset = AudioDataset(tr_data, label_csv=r"dummy_labels.csv".replace("\\", "/"), audio_conf=audio_conf)
if len(tmp_dataset) < batch_size:
    batch_size = len(tmp_dataset)

train_loader = torch.utils.data.DataLoader(
    tmp_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
    drop_last=False  # ✅ FIXED: prevent dropping small batches
)

val_loader = torch.utils.data.DataLoader(
    AudioDataset(te_data, label_csv=r"dummy_labels.csv".replace("\\", "/"), audio_conf=val_audio_conf),
    batch_size=max(1, batch_size * 2),
    shuffle=False,
    num_workers=0,
    pin_memory=torch.cuda.is_available()
)

# Initialize the model
ast_mdl = ASTModel(
    fshape=fshape,
    tshape=tshape,
    fstride=fshape,
    tstride=tshape,
    input_fdim=num_mel_bins,
    input_tdim=target_length,
    model_size=model_size,
    pretrain_stage=True
)

if not isinstance(ast_mdl, torch.nn.DataParallel):
    ast_mdl = torch.nn.DataParallel(ast_mdl)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ast_mdl = ast_mdl.to(device)

# Set up arguments
class Args:
    def __init__(self):
        self.data_train = tr_data
        self.data_val = te_data
        self.data_eval = None
        self.label_csv = r"E:/bird/dummy_labels.csv".replace("\\", "/")
        self.n_class = 1
        self.dataset = dataset
        self.dataset_mean = dataset_mean
        self.dataset_std = dataset_std
        self.target_length = target_length
        self.num_mel_bins = num_mel_bins
        self.exp_dir = exp_dir
        self.lr = lr
        self.warmup = True
        self.optim = "adam"
        self.batch_size = batch_size
        self.num_workers = 0
        self.n_epochs = n_epochs
        self.lr_patience = lr_patience
        self.adaptschedule = False
        self.n_print_steps = n_print_steps
        self.save_model = False
        self.freqm = freqm
        self.timem = timem
        self.mixup = mixup
        self.bal = bal
        self.fstride = fstride
        self.tstride = tstride
        self.fshape = fshape
        self.tshape = tshape
        self.model_size = model_size
        self.task = task
        self.mask_patch = mask_patch
        self.cluster_factor = 3
        self.epoch_iter = epoch_iter
        self.pretrained_mdl_path = None
        self.head_lr = 1
        self.noise = False
        self.metrics = "mAP"
        self.lrscheduler_start = 10
        self.lrscheduler_step = 5
        self.lrscheduler_decay = 0.5
        self.wa = False
        self.wa_start = 16
        self.wa_end = 30
        self.loss = "BCE"

args = Args()
os.makedirs(args.exp_dir, exist_ok=True)
os.makedirs(os.path.join(args.exp_dir, "models"), exist_ok=True)
with open(os.path.join(args.exp_dir, "args.pkl"), "wb") as f:
    pickle.dump(args, f)

# Start pre-training
print(f"\n✅ Now starting self-supervised pretraining for {args.n_epochs} epochs")
trainmask(ast_mdl, train_loader, val_loader, args)
print("Files in exp_dir:", os.listdir(args.exp_dir))
print("Files in models dir:", os.listdir(os.path.join(args.exp_dir, "models")))



Successfully imported AverageMeter in traintest_mask.py
Successfully imported AverageMeter in traintest_mask.py
pretraining patch split stride: frequency=16, time=16
pretraining patch shape: frequency=16, time=16
pretraining patch array dimension: frequency=8, time=64
pretraining number of patches=512

✅ Now starting self-supervised pretraining for 1 epochs
Now running on : cpu
Total parameter number is : 88.763346000 million
Total trainable parameter number is : 88.763344000 million
current #steps=0, #epochs=1
start training...


Epoch 1: 100%|██████████| 6/6 [03:10<00:00, 31.75s/it, Loss=8.98, Acc=0.00234, LR=0]

Files in exp_dir: ['args.pkl', 'models']
Files in models dir: ['audio_model.1.pth']





# Few-Shot Inference:

In [None]:
import torch
import torch.nn.functional as F
from ssast.src.models.ast_models import ASTModel
from ssast.src.dataloader import AudioDataset
import pandas as pd
import json

# Model parameters
fshape = 16
tshape = 16
fstride = 16
tstride = 16
num_mel_bins = 128
target_length = 1024
model_size = 'base'

exp_dir = "exp/mask01-base-f16-t16-b24-lr0.0001-m400-pretrain_joint-custom"
pretrained_model_path = f"{exp_dir}/models/audio_model.3.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

# Initialize the model
ast_mdl = ASTModel(
    fshape=fshape,
    tshape=tshape,
    fstride=fstride,
    tstride=tstride,
    input_fdim=num_mel_bins,
    input_tdim=target_length,
    model_size=model_size,
    pretrain_stage=False,
    load_pretrained_mdl_path=pretrained_model_path
)

ast_mdl = ast_mdl.to(device)
ast_mdl.eval()

# Audio configuration
dataset_mean = -4.2677393
dataset_std = 4.5689974
audio_conf = {
    'num_mel_bins': num_mel_bins,
    'target_length': target_length,
    'freqm': 0,
    'timem': 0,
    'mixup': 0,
    'dataset': 'custom',
    'mode': 'evaluation',
    'mean': dataset_mean,
    'std': dataset_std,
    'noise': False
}

# Paths to files
support_json_path = "data/support_data.json"
query_json_path = "data/query_data.json"
label_csv = "data/labels.csv"
taxonomy_path = "data/taxonomy.csv"

# Load taxonomy to get all unique inat_taxon_id values
df = pd.read_csv(taxonomy_path)
df['inat_taxon_id'] = df['inat_taxon_id'].astype(str)
unique_species_ids = sorted(df['inat_taxon_id'].unique().tolist())
num_species = len(unique_species_ids)
print(f"Total unique species (inat_taxon_id) in taxonomy: {num_species}")
print(f"Unique species IDs: {unique_species_ids}")

# Create mapping from inat_taxon_id to index
species_to_index = {sid: idx for idx, sid in enumerate(unique_species_ids)}

# Create DataLoaders
support_dataset = AudioDataset(support_json_path, label_csv=label_csv, audio_conf=audio_conf)
query_dataset = AudioDataset(query_json_path, label_csv=label_csv, audio_conf=audio_conf)

support_loader = torch.utils.data.DataLoader(
    support_dataset,
    batch_size=24,
    shuffle=False,
    num_workers=0,
    pin_memory=torch.cuda.is_available()
)

query_loader = torch.utils.data.DataLoader(
    query_dataset,
    batch_size=24,
    shuffle=False,
    num_workers=0,
    pin_memory=torch.cuda.is_available()
)

# Extract segment labels from the datasets
support_clip_species = support_dataset.get_segment_labels()
query_clip_species = query_dataset.get_segment_labels()

print(f"inat_taxon_id for each support segment: {support_clip_species}")
print(f"inat_taxon_id for each query segment: {query_clip_species}")

# Encode support and query clips
def encode_clips(model, data_loader, device, return_labels=True):
    model.eval()
    embeddings = []
    labels = [] if return_labels else None
    with torch.no_grad():
        for audio_input, label in data_loader:
            audio_input = audio_input.to(device)
            output = model(audio_input, task='ft_cls')
            embeddings.append(output.cpu())
            if return_labels:
                labels.append(label)
    
    embeddings = torch.cat(embeddings, dim=0)
    if return_labels:
        labels = torch.cat(labels, dim=0)
    return embeddings, labels

support_embeddings, _ = encode_clips(ast_mdl, support_loader, device, return_labels=False)
query_embeddings, _ = encode_clips(ast_mdl, query_loader, device, return_labels=False)

# Dynamically create one-hot encoded support_labels based on inat_taxon_id
support_labels = torch.zeros(len(support_clip_species), num_species)
for i, species_id in enumerate(support_clip_species):
    idx = species_to_index[species_id]
    support_labels[i, idx] = 1.0
print(f"Dynamically generated support_labels:\n{support_labels}")

print(f"Support embeddings shape: {support_embeddings.shape}")
print(f"Query embeddings shape: {query_embeddings.shape}")

# Convert one-hot support_labels to class indices
support_label_indices = torch.argmax(support_labels, dim=1)
print(f"Support label indices: {support_label_indices}")

# Compute number of classes (excluding 'unknown', which has index -1)
unique_classes = set(support_label_indices.tolist()) - {-1}
num_classes = len(unique_classes)
print(f"Number of classes: {num_classes}")

# Map global indices to local batch indices (0, 1, ..., num_classes-1)
global_to_local = {global_idx: local_idx for local_idx, global_idx in enumerate(sorted(unique_classes))}
# Map local batch indices back to global indices
local_to_global = {local_idx: global_idx for global_idx, local_idx in global_to_local.items()}
print(f"Global to local mapping: {global_to_local}")
print(f"Local to global mapping: {local_to_global}")

# Map global indices to inat_taxon_id
index_to_species = {global_idx: unique_species_ids[global_idx] for global_idx in unique_classes}
print(f"Index to species mapping: {index_to_species}")

# Attention-based few-shot inference
def attention_based_prototypes(query_emb, support_emb, support_label_indices, num_classes):
    query_emb = F.normalize(query_emb, dim=-1)
    support_emb = F.normalize(support_emb, dim=-1)
    attention_scores = torch.mm(query_emb, support_emb.t())
    attention_weights = F.softmax(attention_scores, dim=-1)
    
    # Map global support_label_indices to local indices for prototype computation
    local_support_label_indices = torch.tensor([global_to_local[idx.item()] for idx in support_label_indices])
    
    prototypes = torch.zeros(num_classes, support_emb.size(1))
    for c in range(num_classes):
        class_mask = (local_support_label_indices == c).float().unsqueeze(0)
        class_weights = (attention_weights * class_mask).sum(dim=0, keepdim=True)
        class_weights = class_weights / (class_weights.sum(dim=1, keepdim=True) + 1e-8)
        prototype = torch.mm(class_weights, support_emb)
        prototypes[c] = prototype.squeeze(0)
    return prototypes

def few_shot_inference(query_emb, support_emb, support_label_indices, num_classes):
    query_emb = query_emb.to(device)
    support_emb = support_emb.to(device)
    support_label_indices = support_label_indices.to(device)
    prototypes = attention_based_prototypes(query_emb, support_emb, support_label_indices, num_classes).to(device)
    query_emb = F.normalize(query_emb, dim=-1)
    prototypes = F.normalize(prototypes, dim=-1)
    similarity = torch.mm(query_emb, prototypes.t())
    predictions = similarity.argmax(dim=1)
    return predictions

# Perform inference
predictions = few_shot_inference(query_embeddings, support_embeddings, support_label_indices, num_classes)

print(f"Predictions for query clips (local indices): {predictions}")

# Map local indices to global indices, then to inat_taxon_id
predicted_global_indices = [local_to_global[pred.item()] for pred in predictions]
predicted_species = [index_to_species[global_idx] for global_idx in predicted_global_indices]
print(f"Predicted species (inat_taxon_id) for query clips: {predicted_species}")

# Evaluate accuracy using query segment labels
query_labels = torch.zeros(len(query_clip_species), num_species)
for i, species_id in enumerate(query_clip_species):
    idx = species_to_index[species_id]
    query_labels[i, idx] = 1.0
query_ground_truth = torch.argmax(query_labels, dim=1)
query_ground_truth_local = torch.tensor([global_to_local[idx.item()] for idx in query_ground_truth])
accuracy = (predictions.cpu() == query_ground_truth_local).float().mean().item()
print(f"Few-shot classification accuracy: {accuracy * 100:.2f}%")

Running on device: cpu
now load a SSL pretrained models from E:/bird/exp/mask01-base-f16-t16-b24-lr0.0001-m400-pretrain_joint-custom/models/audio_model.3.pth
pretraining patch split stride: frequency=16, time=16
pretraining patch shape: frequency=16, time=16
pretraining patch array dimension: frequency=8, time=64
pretraining number of patches=512
fine-tuning patch split stride: frequncey=16, time=16
fine-tuning number of patches=512
Total unique species (inat_taxon_id) in taxonomy: 206
Unique species IDs: ['10057', '10126', '10199', '10288', '10295', '10297', '10359', '107936', '10890', '11063', '11281', '11364', '1139490', '116877', '11872', '1192948', '11937', '1194042', '11972', '126247', '1286908', '1289601', '1289646', '1300', '1346504', '134933', '135045', '14235', '1432779', '144455', '144643', '144857', '144878', '145234', '145236', '1462711', '1462737', '1468', '1538', '1564122', '15807', '15810', '15909', '1593', '16063', '16371', '16447', '16559', '16567', '16714', '16737', 

# GNN:

In [None]:
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch.optim as optim
import json
from ssast.src.dataloader import AudioDataset

# ------------------ Input Setup ------------------
taxonomy_path = "data/taxonomy.csv"
support_json_path = "data/support_data.json"

# Audio configuration (same as few-shot script)
dataset_mean = -4.2677393
dataset_std = 4.5689974
audio_conf = {
    'num_mel_bins': 128,
    'target_length': 1024,
    'freqm': 0,
    'timem': 0,
    'mixup': 0,
    'dataset': 'custom',
    'mode': 'evaluation',
    'mean': dataset_mean,
    'std': dataset_std,
    'noise': False
}

# Load support dataset to get segment labels
support_dataset = AudioDataset(support_json_path, audio_conf=audio_conf, label_csv=None)
support_segment_species = support_dataset.get_segment_labels()
print(f"inat_taxon_id for each support segment: {support_segment_species}")

# Load taxonomy to get all unique inat_taxon_id values
df = pd.read_csv(taxonomy_path)
df['inat_taxon_id'] = df['inat_taxon_id'].astype(str)
unique_species_ids = sorted(df['inat_taxon_id'].unique().tolist())
num_species = len(unique_species_ids)

# Create mapping from inat_taxon_id to index
species_to_index = {sid: idx for idx, sid in enumerate(unique_species_ids)}

# Dynamically create one-hot encoded support_labels based on segment labels
num_segments = len(support_segment_species)
support_labels = torch.zeros(num_segments, num_species)
for i, species_id in enumerate(support_segment_species):
    idx = species_to_index.get(species_id, -1)
    if idx != -1:
        support_labels[i, idx] = 1.0
    else:
        print(f"Warning: {species_id} not found in taxonomy, skipping.")
print(f"Dynamically generated support_labels shape: {support_labels.shape}")

# Verify support_embeddings shape matches num_segments
assert support_embeddings.shape[0] == num_segments, f"support_embeddings shape[0] ({support_embeddings.shape[0]}) must match num_segments ({num_segments})"

# Convert one-hot support_labels to global indices
support_label_indices = torch.argmax(support_labels, dim=1)

# Map global indices to inat_taxon_id
mapped_species_ids = [unique_species_ids[idx.item()] if idx.item() != -1 else 'unknown' for idx in support_label_indices]

# Unique species in the support set
unique_support_ids = list(dict.fromkeys(mapped_species_ids))
if 'unknown' in unique_support_ids:
    unique_support_ids.remove('unknown')

# ------------------ Build Taxonomy Graph ------------------
support_df = df[df['inat_taxon_id'].isin(unique_support_ids)]
class_names = support_df['class_name'].unique().tolist()

# Build node mapping
node_id_map = {}
edges = []
current_index = 0

# Map species (inat_taxon_id)
for sid in unique_support_ids:
    node_id_map[sid] = current_index
    current_index += 1

# Map class nodes
for cname in class_names:
    node_id_map[cname] = current_index
    current_index += 1

# Edges species ↔ class
for _, row in support_df.iterrows():
    sid = row['inat_taxon_id']
    cname = row['class_name']
    src = node_id_map[sid]
    dst = node_id_map[cname]
    edges.append((src, dst))
    edges.append((dst, src))

edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
num_nodes = len(node_id_map)
embed_dim = support_embeddings.shape[1]

# Assign support embeddings to species nodes by averaging embeddings per species
x = torch.zeros((num_nodes, embed_dim))
for species_id in unique_support_ids:
    species_mask = torch.tensor([sid == species_id for sid in mapped_species_ids], dtype=torch.bool)
    species_embeddings = support_embeddings[species_mask].mean(dim=0)
    x[node_id_map[species_id]] = species_embeddings

# Initialize class nodes with average of support embeddings
avg_embed = support_embeddings.mean(dim=0)
for cname in class_names:
    x[node_id_map[cname]] = avg_embed

# Build Data object
data = Data(x=x, edge_index=edge_index)

# ------------------ GNN Model ------------------
class TaxonomyGNN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, in_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.conv2(x, edge_index)
        return x

# ------------------ Train ------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TaxonomyGNN(in_dim=embed_dim, hidden_dim=256).to(device)
data = data.to(device)
support_embeddings = support_embeddings.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data)
    species_nodes = [node_id_map[sid] for sid in unique_support_ids if sid in node_id_map]
    averaged_embeddings = torch.zeros(len(unique_support_ids), embed_dim, device=device)
    for i, species_id in enumerate(unique_support_ids):
        species_mask = torch.tensor([sid == species_id for sid in mapped_species_ids], dtype=torch.bool, device=device)
        averaged_embeddings[i] = support_embeddings[species_mask].mean(dim=0)
    loss = F.mse_loss(out[species_nodes], averaged_embeddings)
    loss.backward()
    optimizer.step()
    return loss.item()

for epoch in range(100):
    loss = train()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")

# ------------------ Output Refined Support Embeddings ------------------
model.eval()
with torch.no_grad():
    all_embeddings = model(data)

refined_support_embeddings = torch.stack([
    all_embeddings[node_id_map[sid]] for sid in unique_support_ids if sid in node_id_map
])

print("Refined support embeddings shape:", refined_support_embeddings.shape)
print("Refined support embeddings sample:\n", refined_support_embeddings[0][:5])

inat_taxon_id for each support segment: ['21116', '21116', '46010', '46010', '46010', '46010', '46010', '46010', '46010', '46010', '46010', '46010', '46010', '46010', '46010', '46010', '46010', '46010', '46010', '46010']
Dynamically generated support_labels shape: torch.Size([20, 206])
Epoch 0, Loss: 0.5284
Epoch 10, Loss: 0.0416
Epoch 20, Loss: 0.0197
Epoch 30, Loss: 0.0415
Epoch 40, Loss: 0.0103
Epoch 50, Loss: 0.0073
Epoch 60, Loss: 0.0072
Epoch 70, Loss: 0.0217
Epoch 80, Loss: 0.0698
Epoch 90, Loss: 0.0592
Refined support embeddings shape: torch.Size([2, 527])
Refined support embeddings sample:
 tensor([ 0.1119, -0.7223, -0.2006,  0.4223,  0.1636])


# Temporal Memory:

In [7]:
import torch
import torch.nn as nn

# Define the TemporalMemory module with an embedding projection
class TemporalMemory(nn.Module):
    def __init__(self, input_dim=527, output_dim=512, hidden_dim=768, num_layers=2, nhead=8):
        super(TemporalMemory, self).__init__()
        # Project input embeddings to a dimension divisible by nhead
        self.embedding_projection = nn.Linear(input_dim, output_dim)
        # Ensure output_dim is divisible by nhead
        assert output_dim % nhead == 0, f"output_dim ({output_dim}) must be divisible by nhead ({nhead})"
        encoder_layer = nn.TransformerEncoderLayer(d_model=output_dim, nhead=nhead, dim_feedforward=hidden_dim)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Project back to original input_dim for consistency
        self.output_projection = nn.Linear(output_dim, input_dim)

    def forward(self, x):
        """
        x: Tensor of shape (sequence_length, batch_size, input_dim)
        Returns: Context-aware embeddings with shape (sequence_length, batch_size, input_dim)
        """
        print(f"Input shape to TemporalMemory: {x.shape}")
        # Ensure x has 3 dimensions
        if x.dim() != 3:
            raise ValueError(f"Expected 3D tensor, got {x.dim()}D tensor with shape {x.shape}")
        sequence_length, batch_size, input_dim = x.shape
        # Reshape to combine sequence_length and batch_size for nn.Linear
        x_reshaped = x.view(sequence_length * batch_size, input_dim)  # Shape: [8*1, 527] = [8, 527]
        # Project embeddings to transformer-compatible dimension
        x_projected = self.embedding_projection(x_reshaped)  # Shape: [8*1, 512] = [8, 512]
        x_projected = x_projected.view(sequence_length, batch_size, -1)  # Shape: [8, 1, 512]
        print(f"Shape after embedding_projection: {x_projected.shape}")
        # Apply transformer
        x_transformed = self.transformer_encoder(x_projected)  # Shape: [8, 1, 512]
        print(f"Shape after transformer_encoder: {x_transformed.shape}")
        # Reshape again for output projection
        x_transformed = x_transformed.view(sequence_length * batch_size, -1)  # Shape: [8*1, 512] = [8, 512]
        # Project back to original dimension
        x_output = self.output_projection(x_transformed)  # Shape: [8*1, 527] = [8, 527]
        x_output = x_output.view(sequence_length, batch_size, input_dim)  # Shape: [8, 1, 527]
        print(f"Shape after output_projection: {x_output.shape}")
        return x_output

# Assume query_embeddings comes from the few-shot code (previous cell)
# It should have shape [num_query_clips, embedding_dim], e.g., [8, 527], but currently [8, 1, 1, 1, 1, 527]

# Check and adjust query_embeddings shape
print(f"Initial query_embeddings shape: {query_embeddings.shape}")
# Squeeze out all singleton dimensions except the last two
query_embeddings = query_embeddings.squeeze()  # Remove all singleton dimensions
print(f"Shape after squeezing singletons: {query_embeddings.shape}")
print(f"Initial query_embeddings sample (first 5 elements of first clip):\n{query_embeddings[0][:5]}")

# Check if the result is 2D or 3D
if query_embeddings.dim() == 2:
    query_embeddings = query_embeddings.unsqueeze(1)  # Add batch dimension to make it 3D
    print(f"Adjusted query_embeddings shape after unsqueeze: {query_embeddings.shape}")
elif query_embeddings.dim() != 3:
    raise ValueError(f"query_embeddings must be 2D or 3D after squeezing, got {query_embeddings.dim()}D tensor with shape {query_embeddings.shape}")

# Initialize temporal memory module
temporal_memory = TemporalMemory(input_dim=query_embeddings.size(2))  # input_dim should now be correct

# Ensure model and input are on the same device
device = query_embeddings.device
temporal_memory = temporal_memory.to(device)

# Get context-aware query embeddings
contextual_embeddings = temporal_memory(query_embeddings)

# Remove batch dimension to match original shape
contextual_embeddings = contextual_embeddings.squeeze(1)  # Shape: [num_query_clips, embedding_dim]

# Output shape
print(f"Contextual embeddings shape: {contextual_embeddings.shape}")

Initial query_embeddings shape: torch.Size([48, 527])
Shape after squeezing singletons: torch.Size([48, 527])
Initial query_embeddings sample (first 5 elements of first clip):
tensor([ 0.4867, -0.6439, -0.2926, -0.0734,  0.2799])
Adjusted query_embeddings shape after unsqueeze: torch.Size([48, 1, 527])
Input shape to TemporalMemory: torch.Size([48, 1, 527])
Shape after embedding_projection: torch.Size([48, 1, 512])
Shape after transformer_encoder: torch.Size([48, 1, 512])
Shape after output_projection: torch.Size([48, 1, 527])
Contextual embeddings shape: torch.Size([48, 527])




# labaling:

In [None]:
import torch
import torch.nn.functional as F
import pandas as pd
import json
from ssast.src.dataloader import AudioDataset

# Assume refined_support_embeddings comes from the GNN code
# Shape: [num_unique_species, 527], e.g., [2, 527] for unique_support_ids ['21116', '46010']
# Assume contextual_embeddings comes from the TemporalMemory code
# Shape: [num_query_segments, 527], e.g., [48, 527]

# Load taxonomy and data paths
taxonomy_path = "data/taxonomy.csv"
support_json_path = "data/support_data.json"
query_json_path = "data/query_data.json"

# Audio configuration (same as previous scripts)
dataset_mean = -4.2677393
dataset_std = 4.5689974
audio_conf = {
    'num_mel_bins': 128,
    'target_length': 1024,
    'freqm': 0,
    'timem': 0,
    'mixup': 0,
    'dataset': 'custom',
    'mode': 'evaluation',
    'mean': dataset_mean,
    'std': dataset_std,
    'noise': False
}

# Load taxonomy
df = pd.read_csv(taxonomy_path)
df['inat_taxon_id'] = df['inat_taxon_id'].astype(str)
unique_species_ids = sorted(df['inat_taxon_id'].unique().tolist())
num_species = len(unique_species_ids)

# Create mapping from inat_taxon_id to index
species_to_index = {sid: idx for idx, sid in enumerate(unique_species_ids)}

# Load support data
with open(support_json_path, 'r') as f:
    support_data = json.load(f)

if isinstance(support_data, dict) and 'data' in support_data:
    if not support_data['data']:
        raise ValueError("The 'data' key in support_data.json is empty.")
    support_clip_species = [str(item['labels']) for item in support_data['data']]
else:
    raise ValueError(f"Unexpected structure for support_data.json. Expected a dict with 'data' key, got {type(support_data)}")

# Load query data to get original clip labels
with open(query_json_path, 'r') as f:
    query_data = json.load(f)

if isinstance(query_data, dict) and 'data' in query_data:
    ground_truth_species = [str(item['labels']) for item in query_data['data']]
    print(f"Original ground truth species from query_data.json: {ground_truth_species}")
else:
    raise ValueError(f"Unexpected structure for query_data.json. Expected a dict with 'data' key, got {type(query_data)}")

# Map refined_support_embeddings back to the number of unique support species
unique_support_ids = sorted(list(set(support_clip_species)))  # e.g., ['21116', '46010']
species_to_refined_idx = {sid: idx for idx, sid in enumerate(unique_support_ids)}

# Expand refined_support_embeddings to match the number of support clips
expanded_support_embeddings = torch.zeros(len(support_clip_species), refined_support_embeddings.size(1))
for i, species_id in enumerate(support_clip_species):
    refined_idx = species_to_refined_idx[species_id]
    expanded_support_embeddings[i] = refined_support_embeddings[refined_idx]

# Set refined support and contextual query embeddings
support_embeddings = expanded_support_embeddings  # Shape: [7, 527]
query_embeddings = contextual_embeddings         # Shape: [48, 527]

# Dynamically create one-hot encoded support_labels based on inat_taxon_id
support_labels = torch.zeros(len(support_clip_species), num_species)
for i, species_id in enumerate(support_clip_species):
    idx = species_to_index[species_id]
    support_labels[i, idx] = 1.0
print(f"Dynamically generated support_labels:\n{support_labels}")

print(f"Support embeddings shape: {support_embeddings.shape}")
print(f"Query embeddings shape: {query_embeddings.shape}")

# Convert one-hot support_labels to class indices
support_label_indices = torch.argmax(support_labels, dim=1)
print(f"Support label indices: {support_label_indices}")

# Compute number of classes (excluding 'unknown', which has index -1)
unique_classes = set(support_label_indices.tolist()) - {-1}
num_classes = len(unique_classes)
print(f"Number of classes: {num_classes}")

# Map global indices to local batch indices (0, 1, ..., num_classes-1)
sorted_unique_classes = sorted(unique_classes)
global_to_local = {global_idx: local_idx for local_idx, global_idx in enumerate(sorted_unique_classes)}
local_to_global = {local_idx: global_idx for local_idx, global_idx in enumerate(sorted_unique_classes)}
print(f"Global to local mapping: {global_to_local}")
print(f"Local to global mapping: {local_to_global}")

# Map global indices to inat_taxon_id
index_to_species = {global_idx: unique_species_ids[global_idx] for global_idx in unique_classes}
print(f"Index to species mapping: {index_to_species}")

# Attention-based few-shot inference
def attention_based_prototypes(query_emb, support_emb, support_label_indices, num_classes):
    query_emb = F.normalize(query_emb, dim=-1)
    support_emb = F.normalize(support_emb, dim=-1)
    attention_scores = torch.mm(query_emb, support_emb.t())
    print(f"Attention scores shape: {attention_scores.shape}")
    attention_weights = F.softmax(attention_scores, dim=-1)
    print(f"Attention weights shape: {attention_weights.shape}")
    
    local_support_label_indices = torch.tensor([global_to_local[idx.item()] for idx in support_label_indices])
    print(f"Local support label indices: {local_support_label_indices}")
    
    prototypes = torch.zeros(num_classes, support_emb.size(1))
    for c in range(num_classes):
        class_mask = (local_support_label_indices == c).float().unsqueeze(0)
        print(f"Class {c} mask shape: {class_mask.shape}")
        class_weights = (attention_weights * class_mask).sum(dim=0, keepdim=True)
        class_weights = class_weights / (class_weights.sum(dim=1, keepdim=True) + 1e-8)
        prototype = torch.mm(class_weights, support_emb)
        prototypes[c] = prototype.squeeze(0)
    return prototypes

def few_shot_inference(query_emb, support_emb, support_label_indices, num_classes, device):
    query_emb = query_emb.to(device)
    support_emb = support_emb.to(device)
    support_label_indices = support_label_indices.to(device)
    prototypes = attention_based_prototypes(query_emb, support_emb, support_label_indices, num_classes).to(device)
    query_emb = F.normalize(query_emb, dim=-1)
    prototypes = F.normalize(prototypes, dim=-1)
    similarity = torch.mm(query_emb, prototypes.t())
    predictions = similarity.argmax(dim=1)
    print(f"Predictions before mapping: {predictions}")
    return predictions

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Perform inference
predictions = few_shot_inference(query_embeddings, support_embeddings, support_label_indices, num_classes, device)

# Validate predictions
valid_local_indices = list(local_to_global.keys())
print(f"Valid local indices: {valid_local_indices}")
for pred in predictions:
    if pred.item() not in local_to_global:
        raise ValueError(f"Prediction {pred.item()} is not a valid local index. Valid indices are {valid_local_indices}")

print(f"Predictions for query clips (local indices): {predictions}")

# Map local indices to global indices, then to inat_taxon_id
predicted_global_indices = [local_to_global[pred.item()] for pred in predictions]
predicted_segment_species = [index_to_species[global_idx] for global_idx in predicted_global_indices]
print(f"Predicted species (inat_taxon_id) for query segments: {predicted_segment_species}")

# Load query_data.json and map segment predictions to clips
with open(query_json_path, 'r') as f:
    query_data = json.load(f)

if isinstance(query_data, dict) and 'data' in query_data:
    num_clips = len(query_data['data'])
    # Assume segments are grouped by clip (e.g., 6 segments per clip for 8 clips = 48 segments)
    segments_per_clip = len(predictions) // num_clips  # e.g., 48 // 8 = 6
    if len(predictions) % num_clips != 0:
        print(f"Warning: Number of segments ({len(predictions)}) not evenly divisible by number of clips ({num_clips}). Using floor division.")
    
    predicted_clip_species = []
    for i in range(num_clips):
        start_idx = i * segments_per_clip
        end_idx = min((i + 1) * segments_per_clip, len(predictions))
        clip_predictions = predictions[start_idx:end_idx]
        # Majority vote for clip-level prediction
        if len(clip_predictions) > 0:
            majority_pred = torch.mode(clip_predictions).values.item()
            predicted_clip_species.append(index_to_species[local_to_global[majority_pred]])
        else:
            predicted_clip_species.append('unknown')
    
    # Update each clip's 'labels' with the majority-predicted species
    for i, item in enumerate(query_data['data']):
        item['labels'] = predicted_clip_species[i]
    print(f"Original ground truth species from query_data.json: {ground_truth_species}")
    print(f"Updated predicted species in query_data.json: {predicted_clip_species}")
else:
    raise ValueError(f"Unexpected structure for query_data.json. Expected a dict with 'data' key, got {type(query_data)}")

# Save the updated query_data.json
with open(query_json_path, 'w') as f:
    json.dump(query_data, f, indent=2)
print(f"Updated query_data.json with predicted labels saved to {query_json_path}")

Original ground truth species from query_data.json: ['46010', '46010', '46010', '46010', '46010', '46010', '46010', '46010']
Dynamically generated support_labels:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
Support embeddings shape: torch.Size([7, 527])
Query embeddings shape: torch.Size([48, 527])
Support label indices: tensor([ 89,  89, 122, 122, 122, 122, 122])
Number of classes: 2
Global to local mapping: {89: 0, 122: 1}
Local to global mapping: {0: 89, 1: 122}
Index to species mapping: {89: '21116', 122: '46010'}
Attention scores shape: torch.Size([48, 7])
Attention weights shape: torch.Size([48, 7])
Local support label indices: tensor([0, 0, 1, 1, 1, 1, 1])
Class 0 mask shape: torch.Size([1, 7])
Class 1 mask shape: torch.Size([1, 7])
Predictions before mapping: tensor([0, 0

# Fine tuning:

In [None]:
import torch
import pandas as pd
import json
import os
import sys
import pickle
from ssast.src.models.ast_models import ASTModel
from ssast.src.dataloader import AudioDataset
import importlib

# Add the correct path to sys.path
sys.path.append(os.path.abspath('ssast/src'))

# Verify utilities package exists
utilities_dir = os.path.join(os.path.abspath('ssast/src'), 'utilities')
if not os.path.exists(utilities_dir) or not os.path.isdir(utilities_dir):
    raise FileNotFoundError(
        f"utilities directory not found at {utilities_dir}. "
        "Please ensure the utilities package is present with __init__.py, util.py, and stats.py."
    )
if not os.path.exists(os.path.join(utilities_dir, '__init__.py')):
    raise FileNotFoundError(
        f"__init__.py not found in {utilities_dir}. "
        "Please ensure the utilities package is correctly set up."
    )
if not os.path.exists(os.path.join(utilities_dir, 'util.py')) or not os.path.exists(os.path.join(utilities_dir, 'stats.py')):
    raise FileNotFoundError(
        f"util.py or stats.py not found in {utilities_dir}. "
        "Please download them from https://github.com/YuanGongND/ssast/tree/main/src/utilities "
        "and place them in E:/bird/ssast/src/utilities/"
    )

# Import traintest for fine-tuning
import ssast.src.traintest
importlib.reload(ssast.src.traintest)
from ssast.src.traintest import train

# Paths and data preparation
taxonomy_path = "data/taxonomy.csv"
support_json_path = "data/support_data.json"
query_json_path = "data/query_data.json"
combined_json_path = "data/combined_data.json"
labels_csv_path = "data/labels.csv"

# Load taxonomy
df = pd.read_csv(taxonomy_path)
df['inat_taxon_id'] = df['inat_taxon_id'].astype(str)
unique_species_ids = sorted(df['inat_taxon_id'].unique().tolist())
num_species = len(unique_species_ids)
species_to_index = {sid: idx for idx, sid in enumerate(unique_species_ids)}

# Create or update labels.csv
labels_df = pd.DataFrame({
    'mid': unique_species_ids,
    'index': range(num_species),
    'display_name': df.set_index('inat_taxon_id').loc[unique_species_ids, 'common_name'].fillna(df['scientific_name'])
})
labels_df.to_csv(labels_csv_path, index=False)
print(f"Created/updated labels.csv at {labels_csv_path} with {num_species} classes")

# Combine support and query data
with open(support_json_path, 'r') as f:
    support_data = json.load(f)
with open(query_json_path, 'r') as f:
    query_data = json.load(f)

combined_data = {"data": support_data['data'] + query_data['data']}
with open(combined_json_path, 'w') as f:
    json.dump(combined_data, f, indent=2)
print(f"Combined dataset saved to {combined_json_path}")

# Audio and model configurations
dataset_mean = -4.2677393
dataset_std = 4.5689974
target_length = 1024
num_mel_bins = 128
model_size = 'base'
fshape = 16
tshape = 16
fstride = 16
tstride = 16

audio_conf = {
    'num_mel_bins': num_mel_bins,
    'target_length': target_length,
    'freqm': 24,
    'timem': 96,
    'mixup': 0.5,
    'dataset': 'custom',
    'mode': 'train',
    'mean': dataset_mean,
    'std': dataset_std,
    'noise': True
}

val_audio_conf = {
    'num_mel_bins': num_mel_bins,
    'target_length': target_length,
    'freqm': 0,
    'timem': 0,
    'mixup': 0,
    'dataset': 'custom',
    'mode': 'evaluation',
    'mean': dataset_mean,
    'std': dataset_std,
    'noise': False
}

# Create DataLoaders with adjustable batch size
train_dataset = AudioDataset(combined_json_path, label_csv=labels_csv_path, audio_conf=audio_conf)
batch_size = min(8, len(train_dataset))  # Adjustable, max 16 based on your success with batch size
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,  # Parallel data loading (adjust if needed)
    pin_memory=torch.cuda.is_available(),
    drop_last=False
)

val_loader = torch.utils.data.DataLoader(
    AudioDataset(combined_json_path, label_csv=labels_csv_path, audio_conf=val_audio_conf),
    batch_size=max(1, batch_size * 2),
    shuffle=False,
    num_workers=4,
    pin_memory=torch.cuda.is_available()
)

# Initialize the pre-trained model
exp_dir = r"E:/bird/exp/mask01-base-f16-t16-b24-lr0.0001-m400-pretrain_joint-custom"
pretrained_model_path = os.path.join(exp_dir, "models/audio_model.1.pth")
if not os.path.exists(pretrained_model_path):
    raise FileNotFoundError(f"Pre-trained model not found at {pretrained_model_path}. Ensure pre-training completed.")

ast_mdl = ASTModel(
    fshape=fshape,
    tshape=tshape,
    fstride=fstride,
    tstride=tstride,
    input_fdim=num_mel_bins,
    input_tdim=target_length,
    model_size=model_size,
    pretrain_stage=False,
    load_pretrained_mdl_path=pretrained_model_path
)

# Optimize device usage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() <= 1:
    ast_mdl = ast_mdl.to(device)
else:
    ast_mdl = torch.nn.DataParallel(ast_mdl).to(device)

# Set up arguments for fine-tuning
class Args:
    def __init__(self):
        self.data_train = combined_json_path
        self.data_val = combined_json_path
        self.data_eval = None
        self.label_csv = labels_csv_path
        self.n_class = num_species
        self.dataset = 'custom'
        self.dataset_mean = dataset_mean
        self.dataset_std = dataset_std
        self.target_length = target_length
        self.num_mel_bins = num_mel_bins
        self.exp_dir = os.path.join(exp_dir.replace("pretrain", "finetune"), f"base-f16-t16-b{batch_size}-lr0.00001-ft_cls-custom")
        self.lr = 1e-5
        self.warmup = False
        self.optim = "adam"
        self.batch_size = batch_size
        self.num_workers = 4
        self.n_epochs = 5
        self.lr_patience = 2
        self.adaptschedule = True
        self.n_print_steps = 100
        self.save_model = True
        self.freqm = 24
        self.timem = 96
        self.mixup = 0.5
        self.bal = 'none'
        self.fstride = fstride
        self.tstride = tstride
        self.fshape = fshape
        self.tshape = tshape
        self.model_size = model_size
        self.task = 'ft_cls'
        self.cluster_factor = 3
        self.epoch_iter = 4000
        self.pretrained_mdl_path = pretrained_model_path
        self.head_lr = 1
        self.noise = True
        self.metrics = "acc"
        self.lrscheduler_start = 5
        self.lrscheduler_step = 3
        self.lrscheduler_decay = 0.5
        self.wa = False
        self.wa_start = 16
        self.wa_end = 30
        self.loss = "CE"

args = Args()
os.makedirs(args.exp_dir, exist_ok=True)
os.makedirs(os.path.join(args.exp_dir, "models"), exist_ok=True)
with open(os.path.join(args.exp_dir, "args.pkl"), "wb") as f:
    pickle.dump(args, f)

# Start fine-tuning
print(f"\n✅ Now starting fine-tuning for {args.n_epochs} epochs with task: {args.task}")
train(ast_mdl, train_loader, val_loader, args)
print("Files in exp_dir:", os.listdir(args.exp_dir))
print("Files in models dir:", os.listdir(os.path.join(args.exp_dir, "models")))



Created/updated labels.csv at E:/bird/labels.csv with 206 classes
Combined dataset saved to E:/bird/combined_data.json
now load a SSL pretrained models from E:/bird/exp/mask01-base-f16-t16-b24-lr0.0001-m400-pretrain_joint-custom\models/audio_model.1.pth
pretraining patch split stride: frequency=16, time=16
pretraining patch shape: frequency=16, time=16
pretraining patch array dimension: frequency=8, time=64
pretraining number of patches=512
fine-tuning patch split stride: frequncey=16, time=16
fine-tuning number of patches=512

✅ Now starting fine-tuning for 5 epochs with task: ft_cls
running on cpu
Total parameter number is : 87.594 million
Total trainable parameter number is : 87.594 million
The mlp header uses 1 x larger lr
Total mlp parameter number is : 0.407 million
Total base parameter number is : 87.188 million
now use adaptive learning rate scheduler.
now training with custom, main metrics: acc, loss function: CrossEntropyLoss(), learning rate scheduler: <torch.optim.lr_schedu



start validation




acc: 0.970588
AUC: nan
Avg Precision: 0.004854
Avg Recall: 1.000000
d_prime: nan
train_loss: 1.794949
valid_loss: 5.773499
validation finished
adaptive learning rate scheduler step
Epoch-1 lr: 1e-05
Epoch-1 lr: 1e-05
epoch 1 training time: 191.226
---------------
2025-05-16 13:03:58.020213
current #epochs=2, #steps=9
start validation




acc: 0.970588
AUC: nan
Avg Precision: 0.004854
Avg Recall: 1.000000
d_prime: nan
train_loss: 0.280909
valid_loss: 5.767474
validation finished
adaptive learning rate scheduler step
Epoch-2 lr: 1e-05
Epoch-2 lr: 1e-05
epoch 2 training time: 190.556
---------------
2025-05-16 13:07:08.576308
current #epochs=3, #steps=18
start validation




acc: 0.970588
AUC: nan
Avg Precision: 0.004854
Avg Recall: 1.000000
d_prime: nan
train_loss: 0.187093
valid_loss: 5.766757
validation finished
adaptive learning rate scheduler step
Epoch-3 lr: 1e-05
Epoch-3 lr: 1e-05
epoch 3 training time: 189.190
---------------
2025-05-16 13:10:17.766475
current #epochs=4, #steps=27
start validation




acc: 0.970588
AUC: nan
Avg Precision: 0.004854
Avg Recall: 1.000000
d_prime: nan
train_loss: 0.151513
valid_loss: 5.766628
validation finished
adaptive learning rate scheduler step
Epoch-4 lr: 1e-05
Epoch-4 lr: 1e-05
epoch 4 training time: 198.436
---------------
2025-05-16 13:13:36.202675
current #epochs=5, #steps=36
start validation




acc: 0.970588
AUC: nan
Avg Precision: 0.004854
Avg Recall: 1.000000
d_prime: nan
train_loss: 0.134475
valid_loss: 5.762209
validation finished
adaptive learning rate scheduler step
Epoch-5 lr: 1e-05
Epoch-5 lr: 1e-05
epoch 5 training time: 238.483
Files in exp_dir: ['args.pkl', 'models', 'predictions', 'progress.pkl', 'result.csv', 'stats_1.pickle', 'stats_2.pickle', 'stats_3.pickle', 'stats_4.pickle', 'stats_5.pickle']
Files in models dir: ['audio_model.1.pth', 'audio_model.2.pth', 'audio_model.3.pth', 'audio_model.4.pth', 'audio_model.5.pth', 'best_audio_model.pth', 'best_optim_state.pth']


In [3]:
import os
import json
import torchaudio

def create_json_with_duration(audio_dir, output_json_path, label="dummy_label"):
    audio_files = [f for f in os.listdir(audio_dir) if f.lower().endswith(('.wav', '.ogg', '.mp3'))]

    data = []
    for fname in audio_files:
        full_path = os.path.join(audio_dir, fname).replace("\\", "/")
        try:
            waveform, sample_rate = torchaudio.load(full_path)
            duration = waveform.shape[1] / sample_rate  # samples / rate = seconds
        except Exception as e:
            print(f"⚠️ Could not load {fname}: {e}")
            duration = 0.0
        
        data.append({
            "wav": full_path,
            "labels": label,
            "duration": round(duration, 2)
        })

    with open(output_json_path, 'w') as f:
        json.dump({"data": data}, f, indent=2)

    print(f"✅ Created JSON with {len(data)} entries at: {output_json_path}")

# Example usage
audio_directory = "E:/bird/data/train_soundscapes"
output_json = "E:/bird/data/unlabeled_data.json"
create_json_with_duration(audio_directory, output_json)


✅ Created JSON with 9726 entries at: E:/bird/data/unlabeled_data.json
