# Testing Gemma-3N Model for Parkinson's Disease Detection

This notebook loads and tests the fine-tuned Gemma-3N model on audio data for Parkinson's disease detection.

## Imports

In [1]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth

In [2]:
%%capture
# Install latest transformers for Gemma 3N
!pip install --no-deps --upgrade transformers # Only for Gemma 3N
!pip install --no-deps --upgrade timm # Only for Gemma 3N

In [3]:
!pip install opensmile

Collecting opensmile
  Downloading opensmile-2.6.0-py3-none-manylinux_2_17_x86_64.whl.metadata (15 kB)
Collecting audobject>=0.6.1 (from opensmile)
  Downloading audobject-0.7.12-py3-none-any.whl.metadata (2.7 kB)
Collecting audinterface>=0.7.0 (from opensmile)
  Downloading audinterface-1.3.1-py3-none-any.whl.metadata (4.3 kB)
Collecting audeer>=2.1.1 (from audinterface>=0.7.0->opensmile)
  Downloading audeer-2.2.2-py3-none-any.whl.metadata (4.1 kB)
Collecting audformat<2.0.0,>=1.0.1 (from audinterface>=0.7.0->opensmile)
  Downloading audformat-1.3.2-py3-none-any.whl.metadata (4.7 kB)
Collecting audiofile>=1.3.0 (from audinterface>=0.7.0->opensmile)
  Downloading audiofile-1.5.1-py3-none-any.whl.metadata (4.9 kB)
Collecting audmath>=1.4.1 (from audinterface>=0.7.0->opensmile)
  Downloading audmath-1.4.2-py3-none-any.whl.metadata (3.7 kB)
Collecting audresample<2.0.0,>=1.1.0 (from audinterface>=0.7.0->opensmile)
  Downloading audresample-1.3.4-py3-none-manylinux_2_17_x86_64.whl.metadat

In [4]:
import os
import gc
import json
import torch
import librosa
import numpy as np
from pathlib import Path
from unsloth import FastModel
from transformers import AutoTokenizer
from tqdm import tqdm
import soundfile as sf
import opensmile
import tempfile
from datasets import Dataset
from trl import SFTTrainer, SFTConfig

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


## ParkinsonTestInterface Class

In [5]:
class ParkinsonTestInterface:
    def __init__(self, model_path, sample_rate=16000):
        self.sample_rate = sample_rate
        self.smile = opensmile.Smile(
            feature_set=opensmile.FeatureSet.ComParE_2016,
            feature_level=opensmile.FeatureLevel.Functionals
        )

        print("Loading model and tokenizer...")
        self.model, self.tokenizer = self.load_model(model_path)
        print("Model loaded successfully!")

    def load_model(self, model_path):
        base_model, tokenizer = FastModel.from_pretrained(
            model_name = "unsloth/gemma-3n-E2B-it",
            dtype = torch.float16,
            load_in_4bit = True,
            device_map = "auto"
        )

        model = FastModel.get_peft_model(
            base_model,
            peft_model_id = model_path
        )
        return model, tokenizer

    def extract_features(self, y, sr):
        f0 = librosa.yin(y, fmin=80, fmax=450, sr=sr)
        pitch_stats = {
            "pitch_mean": float(np.mean(f0)),
            "pitch_std": float(np.std(f0)),
            "pitch_min": float(np.min(f0)),
            "pitch_max": float(np.max(f0)),
            "pitch_var": float(np.var(f0))
        }

        zcr = float(librosa.feature.zero_crossing_rate(y)[0].mean())
        centroid = float(librosa.feature.spectral_centroid(y=y, sr=sr)[0].mean())

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
            sf.write(temp_audio.name, y, sr)
            smile_feats = self.smile.process_file(temp_audio.name)
        os.unlink(temp_audio.name)

        smile_dict = smile_feats.iloc[0].to_dict()

        return {
            "zcr": zcr,
            "centroid": centroid,
            "hnr": float(smile_dict.get("HNRdBACF_sma[0]_amean", 0.0)),
            **pitch_stats,
            **{"jitter_local": float(smile_dict.get("jitterLocal_sma[0]_amean", 0.0))},
            **{"shimmer_db": float(smile_dict.get("shimmerLocaldB_sma[0]_amean", 0.0))}
        }

    def build_instruction(self, features):
        return (
            f"Analyze this audio clip for signs of Parkinson's. Acoustic features: "
            f"pitch={features['pitch_mean']:.1f}Hz, zcr={features['zcr']:.3f}, "
            f"centroid={features['centroid']:.1f}Hz, jitter={features['jitter_local']:.3f}, "
            f"shimmer={features['shimmer_db']:.2f}dB, hnr={features['hnr']:.2f}."
        )

    def predict(self, audio_path):
        try:
            y, sr = librosa.load(audio_path, sr=self.sample_rate)
            y = librosa.util.normalize(y)
            features = self.extract_features(y, sr)

            messages = [
                {"role": "system", "content": [{"type": "text", "text": "You are an assistant that detects Parkinson's disease from audio."}]},
                {"role": "user", "content": [
                    {"type": "audio", "audio": str(audio_path)},
                    {"type": "text", "text": self.build_instruction(features)}
                ]}
            ]

            # Generate inputs with attention mask
            inputs = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_tensors="pt",
                return_attention_mask=True  # Add this line
            )

            # Move inputs to device
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    max_new_tokens=128,
                    temperature=0.7,
                    top_p=0.95,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            if "model\n" in prediction:
                prediction = prediction.split("model\n")[1].strip()

            return prediction

        except Exception as e:
            print(f"Error processing {audio_path}: {str(e)}")
            return None

        finally:
            torch.cuda.empty_cache()
            gc.collect()

## Main Testing Code

In [6]:
# Set paths
if 'COLAB_GPU' in os.environ:
    model_path = "/content/drive/MyDrive/info/parkinsons_detector_gemma3n"
    data_dir = '/content/drive/MyDrive/data'
    results_dir = '/content/results'
    os.makedirs(results_dir, exist_ok=True)
else:
    model_path = "./info/parkinsons_detector_gemma3n"
    data_dir = './data'
    results_dir = './results'
    os.makedirs(results_dir, exist_ok=True)

In [7]:
# Initialize interface
interface = ParkinsonTestInterface(model_path)

Loading model and tokenizer...
==((====))==  Unsloth 2025.8.1: Fast Gemma3N patching. Transformers: 4.55.0.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3N does not support SDPA - switching to eager!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/2.65G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/469M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/210 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/98.0 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.70M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/777 [00:00<?, ?B/s]

Unsloth: Making `model.base_model.model.model.language_model` require gradients
Model loaded successfully!


## Prediction

In [51]:
def predict_with_combined_analysis(model, tokenizer, audio_path):
    """Predict using both extracted features and direct audio input"""
    try:
        # Load and process audio
        y, sr = librosa.load(audio_path, sr=16000)
        y = librosa.util.normalize(y)

        # Extract acoustic features
        f0 = librosa.yin(y, fmin=80, fmax=450, sr=sr)
        features = {
            "pitch_mean": float(np.mean(f0)),
            "zcr": float(librosa.feature.zero_crossing_rate(y)[0].mean()),
            "centroid": float(librosa.feature.spectral_centroid(y=y, sr=sr)[0].mean())
        }

        # Extract voice quality features using OpenSMILE
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
            sf.write(temp_audio.name, y, sr)
            smile = opensmile.Smile(
                feature_set=opensmile.FeatureSet.ComParE_2016,
                feature_level=opensmile.FeatureLevel.Functionals
            )
            smile_feats = smile.process_file(temp_audio.name)
        os.unlink(temp_audio.name)

        smile_dict = smile_feats.iloc[0].to_dict()
        features.update({
            "jitter": float(smile_dict.get("jitterLocal_sma[0]_amean", 0.0)),
            "shimmer": float(smile_dict.get("shimmerLocaldB_sma[0]_amean", 0.0)),
            "hnr": float(smile_dict.get("HNRdBACF_sma[0]_amean", 0.0))
        })

        # Create detailed prompt with features
        feature_text = (
            f"Analyze this audio for Parkinson's disease signs. "
            f"Acoustic measurements show:\n"
            f"- Pitch: {features['pitch_mean']:.1f}Hz (tremor may affect pitch stability)\n"
            f"- Jitter: {features['jitter']:.3f} (vocal fold vibration irregularity)\n"
            f"- Shimmer: {features['shimmer']:.2f}dB (amplitude perturbation)\n"
            f"- HNR: {features['hnr']:.2f} (voice quality measure)\n"
            f"- ZCR: {features['zcr']:.3f} (voice turbulence)\n"
            f"- Spectral Centroid: {features['centroid']:.1f}Hz (voice brightness)\n"
            f"Consider these measurements alongside the audio characteristics."
        )

        # Prepare messages with both audio and features
        messages = [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": "You are an assistant that detects Parkinson's disease from audio analysis."}
                ]
            },
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": str(audio_path)},
                    {"type": "text", "text": feature_text}
                ]
            }
        ]

        # Generate prediction
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt"
        ).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=256,
                temperature=0.7,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        # Clean up prediction
        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
        if "model\n" in prediction:
            prediction = prediction.split("model\n")[1].strip()

        return prediction

    except Exception as e:
        print(f"Error processing {audio_path}: {str(e)}")
        return None

In [None]:
# Test on random 100 samples
np.random.seed(3407)
test_indices = np.random.permutation(len(test_data))[:100]
selected_test_data = [test_data[i] for i in test_indices]

print(f"Selected {len(selected_test_data)} test examples")

# Process test examples
results = []
for idx, example in enumerate(tqdm(selected_test_data, desc="Testing samples")):
    try:
        # Extract audio path
        user_message = next(msg for msg in example['messages'] if msg['role'] == 'user')
        audio_content = next(content for content in user_message['content'] if content['type'] == 'audio')

        # Process path
        audio_path = audio_content['audio'].replace('\\', '/')
        audio_path = os.path.join(data_dir, audio_path)

        if os.path.exists(audio_path):
            prediction = predict_with_combined_analysis(interface.model, interface.tokenizer, audio_path)
            if prediction:
                results.append({
                    'audio_path': audio_path,
                    'prediction': prediction,
                    'ground_truth': example['messages'][-1]['content'][0]['text']
                })
                if idx < 5:
                    print(f"\nProcessed example {idx + 1}")
                    print(f"Audio: {os.path.basename(audio_path)}")
                    print(f"Prediction: {prediction[:100]}...")

    except Exception as e:
        print(f"\nError processing example {idx}: {str(e)}")

print(f"\nSuccessfully processed {len(results)} samples")

Selected 100 test examples


Testing samples:   1%|          | 1/100 [01:43<2:51:23, 103.87s/it]


Processed example 1
Audio: ID28_hc_0_0_0_clip1.wav
Prediction: Okay, let's analyze the provided acoustic measurements for Parkinson's disease (PD).  Here's a break...


Testing samples:   2%|▏         | 2/100 [02:59<2:22:04, 86.98s/it] 


Processed example 2
Audio: B1VLIATFOO55M300320171237_clip32.wav
Prediction: Okay, I will analyze the provided acoustic measurements for potential signs of Parkinson's disease. ...


Testing samples:   3%|▎         | 3/100 [04:12<2:10:44, 80.87s/it]


Processed example 3
Audio: ID10_hc_0_0_0_clip9.wav
Prediction: Okay, let's analyze the provided acoustic measurements for Parkinson's disease (PD).  Here's a break...


Testing samples:   4%|▍         | 4/100 [05:26<2:04:49, 78.02s/it]


Processed example 4
Audio: B1GMIAUSST39F100220171156_clip10.wav
Prediction: Okay, let's analyze the provided acoustic measurements for Parkinson's disease (PD).  Here's a break...


Testing samples:   5%|▌         | 5/100 [06:39<2:00:56, 76.39s/it]


Processed example 5
Audio: VA1LPUUITGI41M230320171111_clip2.wav
Prediction: Okay, let's analyze the provided acoustic measurements for Parkinson's disease (PD). Here's a breakd...


In [None]:
results[0]

In [None]:
print(f"\nSuccessfully processed {len(results)} samples")

# Save results
results_dir = '/content/results' if 'COLAB_GPU' in os.environ else './results'
os.makedirs(results_dir, exist_ok=True)
output_path = os.path.join(results_dir, 'audio_test_results.json')

with open(output_path, 'w') as f:
    json.dump(results, f, indent=2)

print(f"Results saved to: {output_path}")