In [None]:
# Mount Google Drive at /content/drive; the data paths used in later cells
# all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


TRAK (Tracing with the Randomly-projected After Kernel) is a data-attribution method that estimates how much influence each training sample had on a model's output. It assigns every training sample a TRAK score that quantifies this influence; the scores are computed from randomly projected gradients of the trained model. In this notebook, the training data and the generated output are encoded as CLAP audio embeddings, which are then compared. TRAK is efficient and scalable, and it is useful for debugging, fine-tuning analysis, and understanding model behavior in tasks such as music generation, vision, and multimodal applications.

**1. Set Up the Environment**

In [None]:
# Install TRAK (fast version if GPU is available)
!pip install traker[fast]

# Install CLAP dependencies
!pip install transformers torchaudio

# Install PyTorch with CUDA support (if not pre-installed in Colab)
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117


Collecting traker[fast]
  Downloading traker-0.3.2.tar.gz (36 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting fast_jl (from traker[fast])
  Downloading fast_jl-0.1.3.tar.gz (5.3 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: fast_jl, traker
  Building wheel for fast_jl (setup.py) ... [?25l[?25hdone
  Created wheel for fast_jl: filename=fast_jl-0.1.3-cp310-cp310-linux_x86_64.whl size=971503 sha256=668818613823481f19d33a469f18c75df58736cd1d9224f63e54c879ca6dd686
  Stored in directory: /root/.cache/pip/wheels/c0/5e/37/3a9d828d49fbbcb9ba54a11d3b48a5a5ac627bd8beb46bf05b
  Building wheel for traker (setup.py) ... [?25l[?25hdone
  Created wheel for traker: filename=traker-0.3.2-py3-none-any.whl size=28986 sha256=81b02493096cb1318f4ba4d8a7929c42376a52978ed520f318fa865180054529
  Stored in directory: /root/.cache/pip/wheels/2a/bb/be/1e35e69a11e1aba84adfedeae2691798134199591cfe6d4f4e
Successfully built fast_jl traker
Insta

**3. Encode Training Data Using CLAP**

In [None]:
from transformers import ClapProcessor, ClapModel
import torchaudio
import torchaudio.transforms as T
import os
import torch

# Load CLAP model
processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")
model = ClapModel.from_pretrained("laion/clap-htsat-fused")

# Directory paths
audio_dir = '/content/drive/MyDrive/TRAK_Data/audio'
output_dir = '/content/drive/MyDrive/TRAK_Data/embeddings'
os.makedirs(output_dir, exist_ok=True)

# Sampling rate the audio is resampled to and that is reported to the processor.
CLAP_SAMPLE_RATE = 48000


def embed_segment(segment_path):
    """Return the CLAP audio embedding of one audio file.

    The file is loaded with torchaudio, resampled to ``CLAP_SAMPLE_RATE``
    when its native rate differs, and averaged over channels to mono.

    Args:
        segment_path: Path of the audio file to encode.

    Returns:
        The tensor returned by ``model.get_audio_features`` for this single
        clip, computed without autograd tracking.
    """
    waveform, sample_rate = torchaudio.load(segment_path)

    # Resample to 48,000 Hz if necessary
    if sample_rate != CLAP_SAMPLE_RATE:
        resampler = T.Resample(orig_freq=sample_rate, new_freq=CLAP_SAMPLE_RATE)
        waveform = resampler(waveform)

    # Convert to mono and hand the processor a 1-D NumPy array: the CLAP
    # feature extractor expects each clip as a 1-D sequence, and a
    # (1, num_samples) torch tensor is not treated as one mono clip.
    mono_audio = waveform.mean(dim=0).numpy()

    audio_inputs = processor(
        audios=mono_audio, sampling_rate=CLAP_SAMPLE_RATE, return_tensors="pt"
    )

    # Inference only: without no_grad every stored embedding would keep its
    # autograd graph alive for as long as the embedding list exists.
    with torch.no_grad():
        return model.get_audio_features(**audio_inputs)


# Process each song folder. os.listdir returns entries in arbitrary order,
# so sort to make the processing order (and the log) reproducible.
for song_folder in sorted(os.listdir(audio_dir)):
    song_path = os.path.join(audio_dir, song_folder)
    if not os.path.isdir(song_path):
        continue

    print(f"Processing song: {song_folder}")
    song_embeddings = []

    # Process each segment in the song folder
    for segment in sorted(os.listdir(song_path)):
        # Accept upper-case extensions (.WAV) as well.
        if not segment.lower().endswith('.wav'):
            continue
        try:
            embedding = embed_segment(os.path.join(song_path, segment))
        except Exception as e:
            # Best effort: report the failing segment and continue with the rest.
            print(f"Error processing {segment}: {e}")
            continue
        song_embeddings.append(embedding)
        print(f"Processed segment: {segment}, Embedding shape: {embedding.shape}")

    # Save embeddings for this song (a list with one tensor per segment)
    if song_embeddings:
        torch.save(song_embeddings, os.path.join(output_dir, f"{song_folder}_embeddings.pt"))
        print(f"Saved embeddings for song: {song_folder}")
    else:
        print(f"No embeddings generated for song: {song_folder}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/537 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/384 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/280 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/5.42k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/615M [00:00<?, ?B/s]

Processing song: 1, 2 Step (Supersonic)


**Steps to Aggregate Dataset Embeddings**

**Combine the segment embeddings of each song into a single embedding**

In [None]:
import torch
import os

# Directory holding the saved per-song segment-embedding files (input of this cell)
embedding_dir = '/content/drive/MyDrive/TRAK_Data/embeddings'
# Directory that receives one aggregated embedding file per input file
aggregated_output_dir = '/content/drive/MyDrive/TRAK_Data/aggregated_embeddings'
# Create the output directory; exist_ok=True makes re-running the cell harmless
os.makedirs(aggregated_output_dir, exist_ok=True)

# Function to aggregate embeddings (mean pooling)
def aggregate_embeddings(embeddings):
    """Mean-pool a sequence of equally shaped tensors into one tensor.

    Args:
        embeddings: Non-empty sequence of tensors that all share one shape.

    Returns:
        A tensor of that same shape holding the element-wise mean.
    """
    stacked = torch.stack(embeddings)
    return stacked.mean(dim=0)

# Process each song's segment embeddings.
# Only *.pt files are read: any other entry in the folder (stray files,
# subfolders) would just make torch.load fail. Entries are visited in sorted
# order because os.listdir returns them in arbitrary order.
for file in sorted(os.listdir(embedding_dir)):
    if not file.endswith('.pt'):
        continue
    file_path = os.path.join(embedding_dir, file)
    try:
        # Load segment embeddings
        embeddings = torch.load(file_path)
        if len(embeddings) > 0:
            # Aggregate segment embeddings
            aggregated_embedding = aggregate_embeddings(embeddings)
            # Save aggregated embedding. The output name is the full input
            # file name (including its .pt suffix) plus "_aggregated.pt", so a
            # re-run overwrites the files written by an earlier run.
            save_path = os.path.join(aggregated_output_dir, f"{file}_aggregated.pt")
            torch.save(aggregated_embedding, save_path)
            print(f"Aggregated embedding saved for: {file}")
        else:
            print(f"No embeddings found in: {file}")
    except Exception as e:
        # Best effort: report the failing file and continue with the rest.
        print(f"Error processing {file}: {e}")


Why Aggregate?
Aggregating the dataset embeddings:

Simplifies Analysis:
Instead of comparing the generated audio embedding to multiple segment embeddings, you compare it to a single embedding per song.

**Steps to Compute the Embedding of the Generated Output**

In [None]:
from transformers import ClapProcessor, ClapModel
import torchaudio
import torchaudio.transforms as T
import torch

# Load CLAP model
processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")
model = ClapModel.from_pretrained("laion/clap-htsat-fused")

# Path to generated audio file
generated_audio_path = '/content/drive/MyDrive/TRAK_Data/generated_audio.wav'

# Load generated audio
waveform, sample_rate = torchaudio.load(generated_audio_path)

# Resample to 48,000 Hz if necessary
if sample_rate != 48000:
    resampler = T.Resample(orig_freq=sample_rate, new_freq=48000)
    waveform = resampler(waveform)

# Normalize waveform (convert to mono if multi-channel)
waveform = waveform.mean(dim=0, keepdim=True)

# The CLAP feature extractor expects each clip as a 1-D sequence; a
# (1, num_samples) torch tensor is not treated as one mono clip, so pass a
# 1-D NumPy view of the mono signal instead.
mono_audio = waveform.squeeze(0).numpy()

# Process audio inputs
audio_inputs = processor(audios=mono_audio, sampling_rate=48000, return_tensors="pt")

# Extract embedding. Inference only, so skip autograd tracking: the saved
# tensor then carries no gradient state.
with torch.no_grad():
    generated_audio_embedding = model.get_audio_features(**audio_inputs)

# Save the generated audio embedding
output_path = '/content/drive/MyDrive/TRAK_Data/generated_audio_embedding.pt'
torch.save(generated_audio_embedding, output_path)
print(f"Generated audio embedding saved successfully at {output_path}")


**Load the Generated Audio Embedding**

In [None]:
# Read the saved generated-audio embedding back from Drive.
# NOTE: this cell has no imports of its own; it relies on `torch` having been
# imported by an earlier cell.
generated_audio_embedding = torch.load('/content/drive/MyDrive/TRAK_Data/generated_audio_embedding.pt')


**Load Aggregated Dataset Embeddings**

In [None]:
aggregated_output_dir = '/content/drive/MyDrive/TRAK_Data/aggregated_embeddings'

# os.listdir returns entries in arbitrary order, so load in sorted order and
# keep the file names: training_embedding_files[i] is the file that
# training_embeddings[i] was read from, which lets a score be traced back to
# its song. Entries that are not *.pt files are ignored.
training_embedding_files = sorted(
    name for name in os.listdir(aggregated_output_dir) if name.endswith('.pt')
)
training_embeddings = [
    torch.load(os.path.join(aggregated_output_dir, name))
    for name in training_embedding_files
]

print(f"Loaded {len(training_embeddings)} training embeddings.")


**Calculate TRAK Scores:**

In [None]:
# The `traker` distribution is imported under the module name `trak`.
# The import is kept so TRAKer stays available to later cells.
from trak import TRAKer

# NOTE(review): TRAKer attributes through gradients of the model that was
# trained on the data (featurize / score passes over its checkpoints). It has
# no `processor` argument and no call that takes two sets of precomputed
# embeddings. Until that model and its checkpoints are wired in, the scores
# below are an embedding-similarity proxy, not TRAK attributions.


def embedding_similarity_scores(query_embedding, reference_embeddings):
    """Cosine similarity between one embedding and each reference embedding.

    Every embedding is flattened to a vector first, so inputs of shape
    (dim,) and (1, dim) are both accepted.

    Args:
        query_embedding: Tensor holding the embedding to compare.
        reference_embeddings: Sequence of tensors with the same number of
            elements as ``query_embedding``.

    Returns:
        A 1-D float tensor with one score per reference embedding, in the
        order of ``reference_embeddings``; an empty tensor when there are no
        references.
    """
    if len(reference_embeddings) == 0:
        return torch.empty(0)
    query = query_embedding.detach().reshape(1, -1).float()
    references = torch.stack(
        [reference.detach().reshape(-1).float() for reference in reference_embeddings]
    )
    # (1, dim) against (n, dim) broadcasts to one similarity per reference.
    return torch.nn.functional.cosine_similarity(query, references, dim=1)


# Calculate the scores (name kept as `trak_scores` for downstream cells)
trak_scores = embedding_similarity_scores(generated_audio_embedding, training_embeddings)
print(f"TRAK scores (embedding-similarity proxy): {trak_scores}")
