# Pocket Music Generator (For Kaggle)

# Проверка GPU

In [None]:
# Show the NVIDIA GPUs visible to this session (sanity check before training on CUDA).
!nvidia-smi

# Установка окружения (запускать один раз)

In [None]:
!pip install --upgrade pip
!pip install einops
!pip install torch-summary
!pip install sklearn
!pip install tqdm
!pip install matplotlib
!pip install torch==2.0.0 torchvision==0.15.1
!pip install numpy<2.0 --quiet

# Импорт библиотек

In [None]:
print('Loading modules...')

!git clone https://github.com/asigalov61/tegridy-tools

import sys
sys.path.append('/kaggle/working/tegridy-tools/tegridy-tools')
import TMIDIX

sys.path.append('/kaggle/working/tegridy-tools/tegridy-tools/X-Transformer')
from x_transformer import TransformerWrapper, Decoder, AutoregressiveWrapper

import os
import pickle
import random
import secrets
import tqdm

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

torch.set_float32_matmul_precision('high')

import matplotlib.pyplot as plt

from torchsummary import summary
from sklearn import metrics

import sys

sys.path.append('/kaggle/input/tegridy-tools-clean')
import TMIDIX

sys.path.append('/kaggle/input/tegridy-tools-clean/X-Transformer')
from x_transformer import TransformerWrapper, Decoder, AutoregressiveWrapper

if not os.path.exists('/content/INTS'):
    os.makedirs('/content/INTS')

print('Done')

# Загрузка данных для обучения

In [None]:
print('Loading training data... Please wait...')

# Path to the training data file inside Kaggle (an attached input dataset).
file_path = "/kaggle/input/weqrtwethbfvdcs/combined_file (1).pickle"

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point file_path at data you trust (e.g. your own dataset).
with open(file_path, "rb") as f:
    data = pickle.load(f)

# Build the tensor after the file is closed. torch.Tensor stores float32;
# MusicDataset converts each sampled window back to int64 with .long().
# NOTE(review): later cells index this along dim 0 and use len() as a token
# count, so it presumably holds a flat sequence of token ids — confirm.
train_data = torch.Tensor(data)
# The raw unpickled object is no longer needed; dropping the name lets it be
# garbage-collected instead of staying alive next to the tensor.
del data

print('Loaded file:', file_path)
print('train_data shape:', train_data.shape)
print('Done!')


# Настройка модели

In [None]:
# --- Training hyperparameters ---------------------------------------------
SEQ_LEN = 512                    # context length; each sample is SEQ_LEN + 1 tokens
BATCH_SIZE = 64
NUM_EPOCHS = 5
GRADIENT_ACCUMULATE_EVERY = 1    # micro-batches per optimizer step

# Optimizer steps: how many times one step's worth of tokens
# (SEQ_LEN * BATCH_SIZE * GRADIENT_ACCUMULATE_EVERY) fits into the data,
# repeated for every epoch.
NUM_BATCHES = len(train_data) // (SEQ_LEN * BATCH_SIZE * GRADIENT_ACCUMULATE_EVERY) * NUM_EPOCHS

LEARNING_RATE = 2e-4

# --- Periodic actions, measured in optimizer steps ---------------------------
VALIDATE_EVERY = 1000
SAVE_EVERY = 1000
GENERATE_EVERY = 1000
PRINT_STATS_EVERY = 300

GENERATE_LENGTH = 32             # tokens to sample when generation is enabled

def cycle(loader):
    """Yield items from ``loader`` forever, restarting it after each full pass.

    Args:
        loader: any re-iterable object (e.g. a DataLoader or a list).

    Raises:
        ValueError: if a full pass over ``loader`` yields nothing; without this
            check an empty loader would make ``next()`` loop forever.
    """
    while True:
        produced = False
        for item in loader:
            produced = True
            yield item
        if not produced:
            raise ValueError('cycle() received an empty loader; it would loop forever without yielding.')


# Decoder-only transformer over a 3088-token vocabulary with a SEQ_LEN context.
# AutoregressiveWrapper is what the training loop calls to get (loss, acc);
# DataParallel wraps it so batches can be spread over the visible GPUs.
model = torch.nn.DataParallel(
    AutoregressiveWrapper(
        TransformerWrapper(
            num_tokens=3088,
            max_seq_len=SEQ_LEN,
            attn_layers=Decoder(dim=512, depth=12, heads=10, use_flash_attn=True),
        )
    )
)

model.cuda()

print('Done!')

summary(model)

class MusicDataset(Dataset):
    """Random fixed-length windows drawn from a single token tensor.

    Every item is a fresh random slice of ``seq_len + 1`` consecutive entries of
    ``data``. Each item is converted to int64 and moved to ``device``. The
    ``index`` argument is ignored. ``__len__`` reports ``data.size(0)``, which
    only controls how many items a DataLoader draws per pass.

    Args:
        data: tensor sliced along dim 0 (presumably a flat stream of token ids
            — confirm against the data-loading cell).
        seq_len: context length; each item has ``seq_len + 1`` entries
            (presumably so inputs/targets can be shifted by one — TODO confirm).
        device: where items are placed; defaults to ``'cuda'`` (the current
            CUDA device).

    Raises:
        ValueError: if ``data`` has fewer than ``seq_len + 1`` entries, so no
            window fits.
    """

    def __init__(self, data, seq_len, device='cuda'):
        super().__init__()
        if data.size(0) < seq_len + 1:
            raise ValueError(
                f'data has {data.size(0)} entries; need at least seq_len + 1 = {seq_len + 1}'
            )
        self.data = data
        self.seq_len = seq_len
        self.device = device

    def __getitem__(self, index):
        # Valid start offsets are 0 .. size - seq_len - 1 inclusive; the last
        # one ends exactly at the final entry, so the whole stream is reachable.
        # random (unlike secrets) honours random.seed(), so runs are reproducible.
        idx = random.randrange(self.data.size(0) - self.seq_len)
        full_seq = self.data[idx: idx + self.seq_len + 1].long()

        return full_seq.to(self.device)

    def __len__(self):
        return self.data.size(0)

# Hold out the tail of the token stream so validation metrics are measured on
# tokens the model is not trained on.
VAL_FRACTION = 0.05
_num_val_tokens = int(len(train_data) * VAL_FRACTION)
# Each part needs at least SEQ_LEN + 2 tokens so a random window always fits.
_min_part_tokens = SEQ_LEN + 2
if _num_val_tokens >= _min_part_tokens and len(train_data) - _num_val_tokens >= _min_part_tokens:
    train_dataset = MusicDataset(train_data[:-_num_val_tokens], SEQ_LEN)
    val_dataset   = MusicDataset(train_data[-_num_val_tokens:], SEQ_LEN)
else:
    print('Warning: not enough data for a held-out validation split; validating on the training data.')
    train_dataset = MusicDataset(train_data, SEQ_LEN)
    val_dataset   = MusicDataset(train_data, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# NOTE: this rebinds `optim`, shadowing the `torch.optim as optim` module
# import; the training cell calls optim.step()/optim.zero_grad() on this Adam.
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Обучение

In [None]:
# Directory that receives the periodic checkpoints.
checkpoint_dir = "/kaggle/working/model_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Metric histories: one entry per optimizer step for training,
# one entry per validation pass for validation.
train_losses = []
val_losses = []

train_accs = []
val_accs = []


def _plot_series(values, title):
    """Draw one metric history as a blue line, show it inline, then free the figure."""
    plt.plot(values, 'b')
    plt.title(title)
    plt.show()
    plt.close()


for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='Training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss, acc = model(next(train_loader))
        # `loss` may hold one value per GPU (DataParallel gathers them); summing
        # gives the same gradients as backward(ones_like(loss)) without building
        # a CPU tensor and copying it to the GPU on every step.
        loss.sum().backward()

    # Stats come from the last micro-batch; read each scalar once (one sync).
    step_loss = loss.mean().item()
    step_acc = acc.mean().item()

    if i % PRINT_STATS_EVERY == 0:
        print(f'Training loss: {step_loss}')
        print(f'Training acc: {step_acc}')

    train_losses.append(step_loss)
    train_accs.append(step_acc)

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        # Validation uses a single batch from val_loader.
        model.eval()
        with torch.no_grad():
            val_loss, val_acc = model(next(val_loader))
        val_loss_value = val_loss.mean().item()
        val_acc_value = val_acc.mean().item()

        print(f'Validation loss: {val_loss_value}')
        print(f'Validation acc: {val_acc_value}')

        val_losses.append(val_loss_value)
        val_accs.append(val_acc_value)

        print('Plotting training loss graph...')
        _plot_series(train_losses, "Training Loss")

        print('Plotting training acc graph...')
        _plot_series(train_accs, "Training Accuracy")

        print('Plotting validation loss graph...')
        _plot_series(val_losses, "Validation Loss")

        print('Plotting validation acc graph...')
        _plot_series(val_accs, "Validation Accuracy")

    if i % GENERATE_EVERY == 0:
        model.eval()
        # Sample generation is currently disabled. When uncommented, it takes a
        # random validation window minus its last token and generates
        # GENERATE_LENGTH tokens from it, calling generate() on model.module
        # when the model is wrapped in DataParallel.
#        inp = random.choice(val_dataset)[:-1]

#        print(inp)

#        sample = model.module.generate(inp[None, ...], GENERATE_LENGTH) if hasattr(model, 'module') else model.generate(inp[None, ...], GENERATE_LENGTH)
#        print(sample)

    if i % SAVE_EVERY == 0:
        print('Saving model progress. Please wait...')

        checkpoint_name = f'model_checkpoint_{i}_steps_{round(float(train_losses[-1]), 4)}_loss_{round(float(train_accs[-1]), 4)}_acc.pth'
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)

        # Save only the state_dict, unwrapped from DataParallel when present.
        torch.save(model.module.state_dict() if hasattr(model, 'module') else model.state_dict(), checkpoint_path)


# Final Save

In [None]:
# Save the final weights, unwrapped from DataParallel when present.
fname = '/kaggle/working/model_checkpoints/final_checkpoint'

torch.save(model.module.state_dict() if hasattr(model, 'module') else model.state_dict(), fname)

# Save each metric history as a PNG in the Kaggle output directory.
from matplotlib import pyplot as plt

for values, png_path in (
    (train_losses, '/kaggle/working/training_loss_graph.png'),
    (train_accs, '/kaggle/working/training_acc_graph.png'),
    (val_losses, '/kaggle/working/validation_loss_graph.png'),
    (val_accs, '/kaggle/working/validation_acc_graph.png'),
):
    plt.plot(range(len(values)), values, 'b')
    plt.savefig(png_path)
    plt.close()
    print('Done!')

# Metric histories in a fixed order: train loss, train acc, val loss, val acc.
data = [train_losses, train_accs, val_losses, val_accs]

# NOTE(review): '/content' is a Colab-style path, so this copy is probably not
# kept as a Kaggle output; the pickle below is written to /kaggle/working.
TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accuracies')

# Pickle the same lists into the Kaggle output directory.
with open('/kaggle/working/losses_accuracies.pkl', 'wb') as f:
    pickle.dump(data, f)
