In [None]:
# !pip install ipywidgets # Hugging Face notebook widgets
# !pip install --upgrade timm # a newer timm avoids a timm error with Gemma 3n on GPU
# !pip install torchcodec
# !pip install librosa soundfile

## ffmpeg is needed to avoid audio decoding errors
# !sudo apt update
# !sudo apt install -y ffmpeg

In [None]:
# ##############################
# # Memory cleanup (uncomment to free GPU memory between runs)

# import torch
# import gc

# torch.cuda.empty_cache()
# gc.collect() # python garbage collector
# ##############################

In [None]:
import os

from huggingface_hub import login

# Authenticate with the Hugging Face Hub (presumably needed for the gated Gemma
# checkpoint — confirm). In a notebook, login() without a token only displays a
# widget and returns immediately, so "Run All" could race ahead unauthenticated.
# Taking HF_TOKEN from the environment avoids that; when it is unset,
# token=None keeps the original interactive behaviour.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
from datasets import load_dataset

# Load the Cypriot audio/transcription dataset from the Hugging Face Hub.
cypriot_audio = load_dataset(path="Elormiden/MilaMou_Cypriot_Dataset")

In [None]:
# Display the loaded dataset so its splits can be inspected before preprocessing.
cypriot_audio

In [None]:
import librosa
import numpy as np

def _waveform_from_entry(audio, target_sr):
    """Return ``(array, sampling_rate)`` for one dataset audio entry, or None if unreadable.

    Handled forms, in order:
      * dict carrying a non-None ``'array'`` (sampling rate from ``'sampling_rate'``,
        defaulting to ``target_sr``);
      * dict or object with a truthy ``path`` -> loaded with librosa at ``target_sr``;
      * non-dict object subscriptable by ``'array'`` / ``'sampling_rate'``.
    """
    if isinstance(audio, dict):
        if audio.get('array') is not None:
            return audio['array'], audio.get('sampling_rate', target_sr)
        path = audio.get('path')
        if path:
            return librosa.load(path, sr=target_sr)
        return None

    path = getattr(audio, 'path', None)
    if path:
        return librosa.load(path, sr=target_sr)
    try:
        # NOTE(review): presumably covers decoder objects from newer `datasets`
        # releases that support obj['array'] / obj['sampling_rate'] — verify.
        return audio['array'], audio['sampling_rate']
    except (TypeError, KeyError, IndexError):
        return None


def fix_audio_dataset(dataset, target_sr=16000):
    """Replace every split's ``audio`` entry with a float32 waveform dict.

    Each example's ``audio`` becomes ``{'array': float32 ndarray, 'sampling_rate':
    target_sr}``, resampling with librosa when the source rate differs. Entries
    with no readable audio, or that raise while decoding (the error is printed),
    become one second of silence so a single bad file does not abort the run.

    Args:
        dataset: Mapping of split name -> split object exposing ``.map``.
        target_sr: Output sampling rate in Hz (default 16000).

    Returns:
        A container of the same type as ``dataset`` (e.g. ``DatasetDict``) holding
        the mapped splits. Building it from ``type(dataset)`` avoids referencing the
        ``datasets`` module, which the notebook never imports as a whole.
    """
    def silence():
        return {'array': np.zeros(target_sr, dtype=np.float32), 'sampling_rate': target_sr}

    def process_audio(example):
        try:
            decoded = _waveform_from_entry(example['audio'], target_sr)
            if decoded is None:
                example['audio'] = silence()
                return example
            audio_array, sr = decoded
            audio_array = np.asarray(audio_array, dtype=np.float32)
            if sr != target_sr:
                audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=target_sr)
                sr = target_sr
            example['audio'] = {
                'array': np.asarray(audio_array, dtype=np.float32),
                'sampling_rate': sr,
            }
        except Exception as e:
            print(f"Ошибка: {e}")
            example['audio'] = silence()
        return example

    fixed = {}
    for split_name, split_data in dataset.items():
        print(f"Обрабатываю {split_name}...")
        fixed[split_name] = split_data.map(process_audio, num_proc=1)

    return type(dataset)(fixed)

fixed_dataset = fix_audio_dataset(cypriot_audio)

# Sanity check: inspect the audio entry of the first preprocessed training example.
sample = fixed_dataset['train'][0]
sample_audio = sample['audio']
print(f"Audio type: {type(sample_audio)}")
print(f"Audio shape: {sample_audio['array'].shape}")

In [None]:
from transformers import (
    AutoTokenizer,
    AutoProcessor,
    AutoModelForCausalLM,
    Trainer,
    Gemma3nForConditionalGeneration,
    Gemma3nProcessor,
    TrainingArguments,
    AutoModelForImageTextToText,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
    )

from sklearn.model_selection import train_test_split
import torch
import safetensors.torch
from datasets import Dataset
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

import requests
from PIL import Image
import librosa
from io import BytesIO
import logging
from typing import Union, Tuple
from dataclasses import dataclass
import os

# `warnings` must be imported before filterwarnings() is called; previously the
# import came after the first call and failed on a fresh kernel.
import warnings

# Setup logging: INFO for this notebook, ERROR-only for chatty third-party loggers.
logging.basicConfig(level=logging.INFO)
warnings.filterwarnings('ignore')
logging.getLogger("pyngrok").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("torch").setLevel(logging.ERROR)
logger = logging.getLogger(__name__)

print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

In [None]:
import os
import subprocess
import sys

import mlflow
from pyngrok import ngrok

# Credentials come from the environment rather than being hardcoded in the
# notebook; they default to "" (the previous placeholders) when unset.
wandb_token = os.environ.get("WANDB_API_KEY", "")
ngrok_token = os.environ.get("NGROK_AUTHTOKEN", "")

# Only register a real token: set_auth_token("") would overwrite a previously
# saved ngrok token with an empty one.
if ngrok_token:
    ngrok.set_auth_token(ngrok_token)
else:
    print("NGROK_AUTHTOKEN is not set; using any previously configured ngrok token.")
port = "5000"

# Launch the MLflow UI with the kernel's interpreter as a background process.
# It outlives this cell; stop it with mlflow_proc.terminate() when done.
mlflow_proc = subprocess.Popen([
    sys.executable, "-m", "mlflow", "ui", "--port", port
])

mlflow.autolog()

# Expose the local UI port through an ngrok tunnel and print its address.
public_url = ngrok.connect(port)
print(f"MLflow UI: {public_url}")

In [None]:
# Configuration class
import os
from dataclasses import dataclass, field

GEMMA_PATH = "google/gemma-3n-e2b-it"


@dataclass
class Config:
    """Settings for loading and running the Gemma 3n model in this notebook."""

    # Gemma3n model checkpoint on the Hugging Face Hub.
    MODEL_NAME: str = GEMMA_PATH

    # Flag requesting 4-bit (QLoRA-style) loading; it only takes effect if the
    # model-loading code reads it.
    LOAD_4_BIT: bool = True

    # Generation parameters
    MAX_NEW_TOKENS: int = 512

    # Device configuration. TORCH_DTYPE holds a torch.dtype, not a string.
    TORCH_DTYPE: torch.dtype = torch.bfloat16
    # Evaluated once, when the class is defined.
    DEVICE_MAP: str = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Image preprocessing
    IMAGE_SIZE: int = 512

    # Hugging Face token, read from the HF_TOKEN environment variable at
    # instantiation ("" if unset) so secrets never live in the notebook.
    HF_TOKEN: str = field(default_factory=lambda: os.environ.get("HF_TOKEN", ""))

In [None]:
config = Config()

# Echo the active settings so the run's configuration is visible in the output.
print(
    "\n".join(
        [
            f"Model: {config.MODEL_NAME}",
            f"Device: {config.DEVICE_MAP}",
            f"4_BIT: {config.LOAD_4_BIT}",
            f"Data type: {config.TORCH_DTYPE}",
        ]
    )
)

In [None]:
# Honour config.LOAD_4_BIT: previously it was printed but ignored, so the
# "QLoRA" pipeline trained a full-precision model. 4-bit loading through
# bitsandbytes is only attempted when CUDA is available.
quantization_config = None
if config.LOAD_4_BIT and torch.cuda.is_available():
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=config.TORCH_DTYPE,
    )
    # NOTE(review): confirm the Gemma 3n audio/vision towers train correctly
    # when quantized; skip modules via llm_int8_skip_modules if they do not.

model = Gemma3nForConditionalGeneration.from_pretrained(
    config.MODEL_NAME,
    # NOTE(review): an earlier note claimed only float32 worked for this model,
    # yet config.TORCH_DTYPE is bfloat16 — confirm the intended dtype.
    torch_dtype=config.TORCH_DTYPE,
    device_map=config.DEVICE_MAP,
    quantization_config=quantization_config,
)
processor = Gemma3nProcessor.from_pretrained(config.MODEL_NAME)

In [None]:
class LoraTrainerPipeline:
    """LoRA fine-tuning pipeline for audio transcription with a Gemma 3n model.

    Wraps ``model`` with PEFT LoRA adapters, builds a collator that turns
    ``{'audio', 'sentence'}`` examples into chat-formatted inputs plus labels,
    and runs a Hugging Face ``Trainer`` on the ``train``/``validation`` splits.

    Args:
        model: Base model to adapt (passed to ``prepare_model_for_kbit_training``).
        processor: Processor exposing ``apply_chat_template``, a ``__call__``
            accepting ``text``/``audio``/``sampling_rate``, and ``.tokenizer``.
        dataset: Mapping with ``'train'`` and ``'validation'`` splits whose rows
            have ``'audio'`` and ``'sentence'`` fields.
        output_dir: Directory where the Trainer writes checkpoints.
    """

    # Sampling rate handed to the processor; audio at other rates is resampled.
    TARGET_SAMPLING_RATE = 16000

    # Tokenizer attributes whose token ids get label -100 (excluded from the loss).
    # Attributes that are missing or None are skipped.
    MASKED_TOKEN_ATTRS = (
        "pad_token_id",
        "boa_token_id", "eoa_token_id", "audio_token_id",
        "boi_token_id", "eoi_token_id", "image_token_id",
    )

    def __init__(self, model, processor, dataset, output_dir="./gemma-3n-qlora-fine-tuned-steps"):
        self.base_model = model
        self.processor = processor
        self.dataset = dataset
        self.output_dir = output_dir

        # Populated by lora_training() / model_train().
        self.lora_model = None
        self.trainer = None

        self.train_dataset = dataset['train']
        self.val_dataset = dataset['validation']

    @staticmethod
    def _build_messages(audio_array, sentence):
        """Build the two-turn chat: user (audio + instruction), assistant (transcript)."""
        return [
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": audio_array},
                    {"type": "text", "text": "Please transcribe this audio."},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": sentence},
                ],
            },
        ]

    @classmethod
    def _extract_audio(cls, audio_data):
        """Return the waveform for one example at ``TARGET_SAMPLING_RATE``.

        Dicts with an ``'array'`` key and objects subscriptable by ``'array'`` /
        ``'sampling_rate'`` are unpacked and resampled if needed; anything else is
        assumed to already be a raw waveform and returned unchanged.
        """
        if isinstance(audio_data, dict) and 'array' in audio_data:
            audio_array = audio_data['array']
            sampling_rate = audio_data.get('sampling_rate', cls.TARGET_SAMPLING_RATE)
        else:
            try:
                # NOTE(review): presumably newer `datasets` audio decoder objects,
                # which support obj['array'] / obj['sampling_rate'] — verify.
                audio_array = audio_data['array']
                sampling_rate = audio_data['sampling_rate']
            except (TypeError, KeyError, IndexError):
                return audio_data

        if sampling_rate != cls.TARGET_SAMPLING_RATE:
            import librosa
            audio_array = librosa.resample(
                audio_array,
                orig_sr=sampling_rate,
                target_sr=cls.TARGET_SAMPLING_RATE,
            )
        return audio_array

    def _mask_labels(self, input_ids):
        """Return a copy of ``input_ids`` with padding and media special tokens set to -100."""
        labels = input_ids.clone()
        tokenizer = self.processor.tokenizer
        for attr in self.MASKED_TOKEN_ATTRS:
            token_id = getattr(tokenizer, attr, None)
            if token_id is not None:
                labels[labels == token_id] = -100
        # NOTE(review): the user/prompt tokens are not masked, so the loss also
        # covers the instruction turn — confirm this is intended.
        return labels

    def create_audio_collator(self, verbose=False):
        """Create a collate function mapping ``{'audio', 'sentence'}`` rows to a model batch.

        Args:
            verbose: When True, print the batch keys and tensor shapes for every
                batch (debug aid; off by default to keep training logs readable).

        Returns:
            A callable ``features -> batch`` whose batch holds the processor outputs
            plus ``'labels'``.
        """
        def audio_data_collator(features):
            audio_arrays = []
            texts = []

            for feature in features:
                audio_array = self._extract_audio(feature['audio'])
                audio_arrays.append(audio_array)
                messages = self._build_messages(audio_array, feature["sentence"])
                texts.append(
                    self.processor.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=False,
                    )
                )

            batch = self.processor(
                text=texts,
                audio=audio_arrays,
                return_tensors="pt",
                padding=True,
                sampling_rate=self.TARGET_SAMPLING_RATE,
            )
            batch["labels"] = self._mask_labels(batch["input_ids"])

            if verbose:
                print(f"Batch keys: {batch.keys()}")
                print(f"input_ids shape: {batch['input_ids'].shape}")
                if 'input_features' in batch:
                    print(f"input_features shape: {batch['input_features'].shape}")
                if 'input_features_mask' in batch:
                    print(f"input_features_mask shape: {batch['input_features_mask'].shape}")

            return batch

        return audio_data_collator

    def lora_training(self):
        """Prepare the base model and wrap it with LoRA adapters in ``self.lora_model``."""
        model = prepare_model_for_kbit_training(self.base_model)
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=[
                # Text/language modules (main targets).
                # NOTE(review): PEFT matches these names anywhere in the model,
                # so audio/vision tower layers with the same names may be hit too.
                "q_proj", "v_proj", "k_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
                # Audio-encoder modules can be added here if needed:
                # "audio_encoder.layers.*.attention.self.query",
                # "audio_encoder.layers.*.attention.self.key",
                # "audio_encoder.layers.*.attention.self.value",
            ],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )

        self.lora_model = get_peft_model(model, lora_config)
        self.lora_model.print_trainable_parameters()

    def model_train(self, max_steps=1000, batch_size=4, learning_rate=1e-5):
        """Train the LoRA adapters with a Hugging Face Trainer.

        LoRA wrapping happens on the first call only; repeat calls continue
        training the existing adapters instead of nesting new ones.

        Raises:
            Exception: Re-raises any training error after attempting to save an
                emergency checkpoint under ``<output_dir>/emergency_checkpoint``.
        """
        # TrainerCallback was never imported at notebook level; import here.
        from transformers import Trainer, TrainerCallback, TrainingArguments

        if self.lora_model is None:
            self.lora_training()
        data_collator = self.create_audio_collator()

        training_args = TrainingArguments(
            output_dir=self.output_dir,
            overwrite_output_dir=True,
            max_steps=max_steps,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=8,
            save_steps=100,
            save_total_limit=2,
            remove_unused_columns=False,  # keep 'audio'/'sentence' for the collator
            dataloader_pin_memory=False,
            prediction_loss_only=True,
            #####################
            # NOTE(review): fp16 autocast while the notebook's Config uses
            # bfloat16; Gemma models can overflow in fp16 — confirm vs bf16=True.
            fp16=True,
            gradient_checkpointing=True,
            #####################
            learning_rate=learning_rate,
            warmup_steps=50,
            logging_steps=10,
            eval_strategy="steps",
            eval_steps=50,
            dataloader_num_workers=0,
        )

        self.trainer = Trainer(
            model=self.lora_model,
            args=training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.val_dataset,
            # `tokenizer=` is deprecated in favour of `processing_class=`.
            processing_class=self.processor.tokenizer,
            data_collator=data_collator,
        )

        class LoggingCallback(TrainerCallback):
            """Print every log dict the Trainer emits, prefixed with the global step."""

            def on_log(self, args, state, control, model=None, logs=None, **kwargs):
                if logs:
                    print(f"Step {state.global_step}: {logs}")

        self.trainer.add_callback(LoggingCallback())

        try:
            self.trainer.train()
        except Exception as e:
            print(f"Training error: {e}")
            self._save_emergency_checkpoint()
            raise

    def _save_emergency_checkpoint(self):
        """Best-effort save after a training failure; never masks the original error."""
        try:
            self.trainer.save_model(f"{self.output_dir}/emergency_checkpoint")
        except Exception as save_error:
            print(f"Emergency checkpoint save failed: {save_error}")

    def merge_and_unload(self, checkpoint_path=None):
        """Merge LoRA weights into the base model and return the plain model.

        Args:
            checkpoint_path: Optional saved adapter directory. Without a trained
                ``lora_model`` it is attached to the base model; otherwise it is
                loaded as an extra adapter and made active before merging.

        Raises:
            ValueError: If there is neither a trained LoRA model nor a checkpoint.
        """
        print("Merging LoRA and unloading PEFT weights...")
        if checkpoint_path:
            from peft import PeftModel  # never imported at notebook level

            if self.lora_model is None:
                self.lora_model = PeftModel.from_pretrained(self.base_model, checkpoint_path)
            else:
                # Re-wrapping an existing PeftModel would nest adapters; load the
                # checkpoint as its own adapter and activate it instead.
                self.lora_model.load_adapter(checkpoint_path, adapter_name="checkpoint")
                self.lora_model.set_adapter("checkpoint")
        if self.lora_model is None:
            raise ValueError("No LoRA model to merge: run model_train() or pass checkpoint_path.")
        return self.lora_model.merge_and_unload()

def debug_dataset_sample(dataset, processor, index=0):
    """Print the structure of one example and what the processor produces for it.

    Shows the example's keys, its audio entry (keys, array shape and sampling
    rate when it is a dict), the sentence, the chat-template text, and the shape
    of every shaped value the processor returns. Errors raised while templating
    or processing are printed with a traceback instead of propagating.

    Args:
        dataset: Indexable collection of examples with 'audio' and 'sentence'.
        processor: Object providing apply_chat_template() and __call__().
        index: Position of the example to inspect.
    """
    example = dataset[index]
    print(f"Sample keys: {example.keys()}")

    audio = example['audio']
    print(f"Audio type: {type(audio)}")
    if isinstance(audio, dict):
        print(f"Audio keys: {audio.keys()}")
        print(f"Audio array shape: {audio['array'].shape}")
        print(f"Sampling rate: {audio['sampling_rate']}")
    print(f"Sentence: {example['sentence']}")

    try:
        user_turn = {
            "role": "user",
            "content": [
                {"type": "audio", "audio": example['audio']['array']},
                {"type": "text", "text": "Please transcribe this audio."},
            ],
        }
        assistant_turn = {
            "role": "assistant",
            "content": [{"type": "text", "text": example["sentence"]}],
        }

        prompt = processor.apply_chat_template([user_turn, assistant_turn], tokenize=False)
        print(f"Generated text template:\n{prompt}")

        encoded = processor(
            text=[prompt],
            audio=[example['audio']['array']],
            return_tensors="pt",
            padding=True,
            sampling_rate=16000,
        )
        print(f"Processed batch keys: {encoded.keys()}")
        shaped_items = ((name, value) for name, value in encoded.items() if hasattr(value, 'shape'))
        for name, value in shaped_items:
            print(f"{name} shape: {value.shape}")

    except Exception as e:
        print(f"Processing error: {e}")
        import traceback
        traceback.print_exc()

In [None]:
# Train on the preprocessed dataset (16 kHz float32 waveforms from
# fix_audio_dataset) rather than the raw download, which the notebook
# previously passed here by mistake.
trainer_pipeline = LoraTrainerPipeline(
    model=model,
    processor=processor,
    dataset=fixed_dataset,
)

max_steps = 100
trainer_pipeline.model_train(max_steps=max_steps)