# 0) requirements

In [1]:
# Install required package (if not already installed)

!pip install miditok kagglehub mido pydub


Collecting miditok
  Downloading miditok-3.0.5.post1-py3-none-any.whl.metadata (10 kB)
Collecting mido
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Collecting symusic>=0.5.0 (from miditok)
  Downloading symusic-0.5.7-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (8.7 kB)
Collecting pySmartDL (from symusic>=0.5.0->miditok)
  Downloading pySmartDL-1.3.4-py3-none-any.whl.metadata (2.8 kB)
Downloading miditok-3.0.5.post1-py3-none-any.whl (158 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m158.3/158.3 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading symusic-0.5.7-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (2.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m31.1 MB/s[0m eta [36

https://www.kaggle.com/code/yashsrivastava51213/bart-pretraining-from-scratch

Training BART from scratch

## b) import dependencies

In [7]:
import os
import random
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import mido
import matplotlib.pyplot as plt

from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetMIDI
from pathlib import Path

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

In [3]:
import os

from miditok import REMI, TokenizerConfig

# Load the pretrained REMI tokenizer from the Hugging Face Hub.
# SECURITY: the access token must never be hardcoded in the notebook. It is read
# from the HF_TOKEN environment variable instead; any token that was previously
# committed in this cell should be considered leaked and revoked.
# When HF_TOKEN is unset, None is passed and no explicit token is sent.
tokenizer = REMI.from_pretrained(
    "Richatte2000/tokenizer_midi_piano",
    token=os.environ.get("HF_TOKEN"),
)

tokenizer.json:   0%|          | 0.00/1.28M [00:00<?, ?B/s]

  super().__init__(tokenizer_config, params)


In [5]:
# NOTE: the dataset download below is intentionally disabled: it is wrapped in a
# string literal, so running this cell only echoes the text and downloads nothing.
# Per the original author's note, the code will only execute successfully once
# compression is complete (presumably of the dataset archive - TODO confirm).
"""
import kagglehub

# Download latest version
path = kagglehub.dataset_download("pierrepauchet/midi-piano-chunks")

print("Path to dataset files:", path)
"""

'\nimport kagglehub\n\n# Download latest version\npath = kagglehub.dataset_download("pierrepauchet/midi-piano-chunks")\n\nprint("Path to dataset files:", path)\n'

In [4]:
from miditok.pytorch_data import DatasetMIDI, DataCollator
from torch.utils.data import DataLoader
from pathlib import Path 

def _load_split(split):
    """
    Build the DatasetMIDI of one split of the chunked MIDI corpus.

    Args:
        split: name of the sub-directory to load ("train", "val" or "test").

    Returns:
        A DatasetMIDI over every ``*.mid`` file found recursively under
        ``/kaggle/input/midi-piano-chunks/<split>``, tokenized with the notebook's
        ``tokenizer`` and limited to 512 tokens per sequence.
    """
    midi_paths = list(Path("/kaggle/input/midi-piano-chunks", split).resolve().glob("**/*.mid"))
    return DatasetMIDI(
        files_paths=midi_paths,
        tokenizer=tokenizer,
        max_seq_len=512,
        # Fix: sequences used to be framed with PAD as BOS and BOS as EOS, which
        # disagrees with the BOS/EOS ids handed to the model configuration.
        bos_token_id=tokenizer["BOS_None"],
        eos_token_id=tokenizer["EOS_None"],
    )


train_dataset = _load_split("train")
print("Train dataset loaded")
val_dataset = _load_split("val")
print("Val dataset loaded")
test_dataset = _load_split("test")
print("Test dataset loaded")


done
done
done


In [8]:
# Retrieve the tokenizer's special tokens and their ids.
special_tokens = tokenizer.special_tokens
special_tokens_ids = tokenizer.special_tokens_ids

# Both lists are unpacked positionally, in the order (pad, bos, eos, mask).
(pad_token, bos_token, eos_token, mask_token) = special_tokens
(pad_token_id, bos_token_id, eos_token_id, mask_token_id) = special_tokens_ids

In [50]:
import numpy as np
import torch
from miditok.pytorch_data import DataCollator

class DataCollatorForInfilling(DataCollator):
    """
    Data collator extending miditok's DataCollator with a span-infilling corruption.

    For each example of the batch, ONE contiguous span of n tokens is masked, with
    n ~ Poisson(poisson_lambda), clipped to the number of maskable positions.
    Maskable positions are the non-padding tokens, excluding position 0 and the
    last non-padding token (typically BOS and EOS).

    In ``input_ids`` the span is replaced by ``mask_token_id``. In ``labels`` the
    span keeps the original token ids and every other position is set to -100.

    Args:
        pad_token_id: id of the padding token (also forwarded to the base collator).
        mask_token_id: id written over the masked span in ``input_ids``.
        poisson_lambda: mean of the Poisson law the span length is drawn from.
        copy_inputs_as_labels: forwarded to the base DataCollator.
        shift_labels: forwarded to the base DataCollator.

    NOTE(review): the labels produced by the base collator are discarded and rebuilt
    here. If these labels feed a seq2seq model that derives its decoder inputs from
    ``labels`` (as Hugging Face BART does when no decoder inputs are given), the
    positions set to -100 do not carry the original tokens - confirm this is intended.
    """
    def __init__(self, pad_token_id, mask_token_id, poisson_lambda=15, copy_inputs_as_labels=True, shift_labels=True):
        super().__init__(pad_token_id, copy_inputs_as_labels=copy_inputs_as_labels, shift_labels=shift_labels)
        # Stored explicitly on the instance so __call__ does not depend on the
        # attribute names used by the base class.
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.poisson_lambda = poisson_lambda

    def __call__(self, batch):
        """
        Collate ``batch`` with the base collator, then mask one span per example.

        Returns the dict produced by the base collator with ``input_ids`` (masked)
        and ``labels`` (-100 everywhere except on the masked span) replaced.
        """
        # Base collator first: padding (and label shifting when requested).
        batch = super().__call__(batch)
        inputs = batch["input_ids"].clone()
        # Every position is ignored (-100) unless it gets masked below.
        labels = torch.full_like(inputs, -100)

        for i in range(inputs.size(0)):
            seq = inputs[i]
            # Indices of the real (non-padding) tokens of this example.
            non_pad = (seq != self.pad_token_id).nonzero(as_tuple=False).view(-1)
            if non_pad.numel() == 0:
                continue

            # Exclude position 0 and the LAST NON-PADDING token (usually BOS/EOS).
            # Fix: the last token used to be located with the padded length
            # (seq.size(0) - 1), so the final token of any sequence shorter than
            # the longest one in the batch could still be masked.
            last_real = non_pad[-1]
            maskable = non_pad[(non_pad != 0) & (non_pad != last_real)]
            if len(maskable) == 0:
                continue

            # Span length drawn from a Poisson law.
            n_mask = np.random.poisson(self.poisson_lambda)

            # Maskable positions are assumed contiguous (true when padding sits on
            # one side only), so the room available is last - first + 1.
            first, last = maskable[0].item(), maskable[-1].item()
            available_length = last - first + 1
            span_length = min(n_mask, available_length)
            if span_length <= 0:
                continue

            # np.random.randint's upper bound is exclusive: the latest admissible
            # start is last - span_length + 1, so the span never runs past `last`.
            start_idx = np.random.randint(first, last - span_length + 2)
            end_idx = start_idx + span_length

            # Copy the originals into the labels BEFORE overwriting the inputs.
            labels[i, start_idx:end_idx] = inputs[i, start_idx:end_idx]
            inputs[i, start_idx:end_idx] = self.mask_token_id

        batch["input_ids"] = inputs
        batch["labels"] = labels
        return batch


In [51]:
# One span-infilling collator shared by the three splits.
collator = DataCollatorForInfilling(
    tokenizer.pad_token_id,
    mask_token_id,
    poisson_lambda=15,
    copy_inputs_as_labels=True,
    shift_labels=True,
)

# Settings common to every loader; only the training loader is shuffled.
_loader_kwargs = {"batch_size": 16, "collate_fn": collator}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, **_loader_kwargs)
test_loader = DataLoader(test_dataset, **_loader_kwargs)

In [58]:
# Draw one training batch and look at its first example.
sample = next(iter(train_loader))

L_input = sample['input_ids'][0]
L_label = sample['labels'][0]
L_attention = sample['attention_mask'][0]

# Dump the three tensors, each followed by a separator line.
for _header, _tensor in (
    ("Inpus ID : ", L_input),
    ("LABELS : ", L_label),
    ("Attention MASK : ", L_attention),
):
    print(_header, _tensor)
    print("-------------------")

# Sanity check: for every masked input position, show the input id, the label
# and the attention-mask value side by side.
for _tok, _lab, _att in zip(L_input, L_label, L_attention):
    if _tok == mask_token_id:
        print("input : ", _tok, "labels : ", _lab, "attention mask : ", _att)

Inpus ID :  tensor([  897, 10269,  4146,  2613,  4468,   493,  2920, 16996,  3525,  3388,
         3299,   488,  1548, 15449, 19248, 13684,  2354,  6723,  2156,   615,
         1261, 14843, 18117, 15540,  2765,  2155,  2891,   487,  1169, 15655,
        13179,  1196, 15156,  2820,  3897,   548,  1242, 14015, 17381, 16161,
         5409,  3790,   526,  1067,  2289,   421,   463,  2020, 17565,  3516,
         3711,  2599,  4830,   506,  1396, 16762,   493,  1294,   463,  3057,
         4606,  4906,   586,  1080,  5078,  4468,  3033,   526,  2981,  3692,
         5769,   647,  7452,  7566,   582,  2070,   474,  1987,   474,  2077,
         2410,   560,  1949,   507,  2032,   507,  1891,   507,  2005,  2470,
        11909,  1986,   474,  1111,  4197,  2709,   599,  1998,   507,  1979,
          507,  1989,   507,  1161,   777,   455,   474,  1955,   474,  1836,
          507,  4952,   522,  1720,  6041,  7794,   474,  2210,   507,  2142,
          507,  2099,   507,  2052,   584,  2013,   

## BART training

In [59]:
#############################
# Base BART model definition (randomly initialised, not pretrained)
#############################
from transformers import BartConfig, BartForConditionalGeneration

# Architecture hyper-parameters; vocabulary size and special-token ids come from
# the MIDI tokenizer.
_bart_hparams = {
    "vocab_size": tokenizer.vocab_size,
    "max_position_embeddings": 1024,
    "encoder_layers": 6,
    "decoder_layers": 6,
    "encoder_attention_heads": 8,
    "decoder_attention_heads": 8,
    "d_model": 512,
    "bos_token_id": bos_token_id,
    "eos_token_id": eos_token_id,
    "pad_token_id": pad_token_id,
    "mask_token_id": mask_token_id,
}
config = BartConfig(**_bart_hparams)

# Instantiate from the config only (no pretrained weights) and move to the device.
model = BartForConditionalGeneration(config).to(device)

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

In [60]:
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

# `transformers.AdamW` is deprecated and no longer shipped by recent transformers
# releases; the PyTorch implementation is its replacement.
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
num_epochs = 10
# The scheduler is stepped once per batch, so the horizon is batches x epochs.
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),  # warm-up over the first 10% of the steps
    num_training_steps=total_steps,
)




In [61]:
#############################
# Training loop with per-epoch validation and optional push to the Hugging Face Hub
#############################
import os

# Hub destination and credentials are read from the environment rather than being
# hardcoded: HF_REPO_ID must look like "USERNAME/REPO_NAME" and HF_TOKEN must be a
# write token. The previous hardcoded values ("Richatte2000/" with a placeholder
# token) made the push fail only after a full epoch had been trained.
hub_repo_id = os.environ.get("HF_REPO_ID")
hub_token = os.environ.get("HF_TOKEN")

for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch in progress_bar:
        # Move the batch to the training device.
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()

        # Clip the gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()  # the learning-rate schedule advances once per batch

        total_train_loss += loss.item()
        progress_bar.set_postfix({"loss": loss.item()})

    avg_train_loss = total_train_loss / len(train_loader)

    # Evaluation on the validation set (no gradients, eval mode).
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            total_val_loss += outputs.loss.item()
    avg_val_loss = total_val_loss / len(val_loader)

    print(f"\nEpoch {epoch+1} terminé : Train Loss = {avg_train_loss:.4f} | Val Loss = {avg_val_loss:.4f}\n")

    # Push the model to the Hugging Face Hub after each epoch, when configured.
    if hub_repo_id and hub_token:
        model.push_to_hub(hub_repo_id, token=hub_token, commit_message=f"Epoch {epoch+1}")
    else:
        print("HF_REPO_ID / HF_TOKEN not set: skipping push to the Hugging Face Hub.")



Epoch 1/10:   4%|▍         | 183/4615 [05:14<2:06:53,  1.72s/it, loss=9.07]


KeyboardInterrupt: 