In [5]:
import torch
import torch.nn as nn
import random
import json
from datasets import load_dataset
import os

# === Parameters ===
SEQ_LENGTH = 75
EMBED_DIM = 64
HIDDEN_DIM = 128
NUM_LAYERS = 1
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
GENERATE_LENGTH = 800
TEMPERATURE = 1.0  # Set between 0.7 and 1.2 for different output styles

# === Load or rebuild vocab ===
vocab_path = "vocab.json"
if os.path.exists(vocab_path):
    print("Loading saved vocabulary...")
    with open(vocab_path) as f:
        vocab_data = json.load(f)
    char2idx = {k: int(v) for k, v in vocab_data['char2idx'].items()}
    idx2char = {int(k): v for k, v in vocab_data['idx2char'].items()}
    VocabSize = len(char2idx)
else:
    print("Building vocabulary from dataset...")
    dataset = load_dataset('sander-wood/irishman', split='train')
    texts = dataset['abc notation']
    vocab = sorted(set(''.join(texts)))
    char2idx = {ch: i for i, ch in enumerate(vocab)}
    idx2char = {i: ch for ch, i in char2idx.items()}
    VocabSize = len(vocab)

    with open(vocab_path, "w") as f:
        json.dump({
            "char2idx": char2idx,
            "idx2char": {str(k): v for k, v in idx2char.items()}
        }, f)
    print("Vocabulary saved to vocab.json.")

# === Define model ===
class SimpleRNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VocabSize, EMBED_DIM)
        self.lstm = nn.LSTM(EMBED_DIM, HIDDEN_DIM, NUM_LAYERS, batch_first=True)
        self.fc = nn.Linear(HIDDEN_DIM, VocabSize)

    def forward(self, x, hidden=None):
        x = self.embed(x)
        out, hidden = self.lstm(x, hidden)
        out = out.reshape(-1, HIDDEN_DIM)
        logits = self.fc(out)
        return logits, hidden

# === Load model ===
model = SimpleRNN().to(DEVICE)

if not os.path.exists("best_model.pth"):
    raise FileNotFoundError("Checkpoint 'best_model.pth' not found!")

checkpoint = torch.load("best_model.pth", map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print("Model loaded successfully.")

# === Generate ABC header ===
M_options = ["M:6/8", "M:4/4", "M:3/4"]
L_options = ["L:1/8", "L:1/16", "L:1/4"]
K_options = ["K:D", "K:G", "K:C"]

start = f"""X:1
T:Generated Tune
{random.choice(M_options)}
{random.choice(L_options)}
{random.choice(K_options)}
"""

seq = [char2idx.get(c, 0) for c in start]
hidden = None

# === Generate sequence ===
print("Generating ABC tune...")
for i in range(GENERATE_LENGTH):
    inp_seq = seq[-SEQ_LENGTH:]
    inp = torch.tensor(inp_seq, dtype=torch.long).unsqueeze(0).to(DEVICE)
    logits, hidden = model(inp, hidden)

    # Reset hidden periodically to prevent drift
    if i % 100 == 0:
        hidden = None

    probs = torch.softmax(logits[-1] / TEMPERATURE, dim=0)
    idx = torch.multinomial(probs, 1).item()
    seq.append(idx)

# === Decode and save ===
generated = ''.join(idx2char[i] for i in seq)
print("\n=== Generated ABC Notation ===\n")
print(generated)

with open("generated_tune.abc", "w") as f:
    f.write(generated)



Loading saved vocabulary...
Model loaded successfully.
Generating ABC tune...


  checkpoint = torch.load("best_model.pth", map_location=DEVICE)



=== Generated ABC Notation ===

X:1
T:Generated Tune
M:4/4
L:1/8
K:D
 fAG FED BAB | ded dAF DEF | G3 AFA F2 d | AFD D2 C D2 :| Bd | e3 efe e3 cde | 
"^13" f2 f fdf b3 f2 f | a2 f afd d^cd fed | e2 B dfe dcB ABd | f2 e fef A3 faf | B2 B AGF E3 e2 :|2 
 aba afb a2 f aba | d'2 b f3 a6 |: g2 f ede b3 b2 b | f2 e f2 a f2 e d2 A | d2 A d3 f2 e d2 f a2 g f2 e e3 e2 f | f2 c f2 a a2 f f2 a g3 f3 f2 e e2 f e2 f e2 d c2 d e3 c3 A2 e c d c B3 c2 G E2 F F4 E3 C a e g8 c3 z2 d d2 e c A6 z2 A | 
 e2 d e2 a c2 a e8 d e4 d e4 | g f3 B4 A4 E4 D4 E8 | D6 F2 | E3 E TF3 E G3 E A6 (E2 G2) F4 ^G4 E4 D2 D4 || B>B | A3 d d2 d6 | d4 dc/B/ c3 A A2 D4 E2 E6 E2 || e3 ^f !>!a3 A .^G3 .A .B4 G4 .A6 B2 | (A6 A2) A6 A6 z2 | 
 d2 B4 (A3 B) B8 B4 A3 G F6 (FG) A6 z4 (b4 B4) | a2 g4 ^f6 g4 (e6 e2) e6 e6 e6 e2 z4 z2 z4 z2 z2 | z6 z6 d6 _e6 e4 f2 z4 z6 z2 fe d6 e4 e6 b4 f2 b2 e4 fe ^d2 e6 f4 e8
