In [2]:
import torch
import torch.nn as nn
from datasets import load_dataset

# === Parameters ===
SEQ_LENGTH = 75
EMBED_DIM = 64
HIDDEN_DIM = 128
NUM_LAYERS = 1
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MAX_CONTINUE_LEN = 1000  # Maximum characters to generate
TEMPERATURE = 0.8  # <1 for less randomness, >1 for more randomness

# === Load dataset for vocab ===
print("Loading dataset and vocab...")
dataset = load_dataset('sander-wood/irishman', split='train')
texts = dataset['abc notation']
vocab = sorted(set(''.join(texts)))
char2idx = {ch: i for i, ch in enumerate(vocab)}
idx2char = {i: ch for ch, i in char2idx.items()}
VocabSize = len(vocab)

# === Define model ===
class SimpleRNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VocabSize, EMBED_DIM)
        self.lstm = nn.LSTM(EMBED_DIM, HIDDEN_DIM, NUM_LAYERS, batch_first=True)
        self.fc = nn.Linear(HIDDEN_DIM, VocabSize)

    def forward(self, x, hidden=None):
        x = self.embed(x)
        out, hidden = self.lstm(x, hidden)
        out = out.reshape(-1, HIDDEN_DIM)
        logits = self.fc(out)
        return logits, hidden

# === Load model ===
model = SimpleRNN().to(DEVICE)
checkpoint = torch.load("best_model.pth", map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print("Model loaded successfully!")

# === Initial tunes to test ===
initial_tunes = [
    """X:1\nT:The Blarney Pilgrim\nM:6/8\nL:1/8\nK:D\n|:DFA dfa|dfd ABc|""",
    """X:2\nT:The Kesh Jig\nM:6/8\nL:1/8\nK:G\n|:GFG AGA|BAB d2B|""",
    """X:3\nT:Morrison's Jig\nM:6/8\nL:1/8\nK:Em\n|:EFE B2B|EFE B2B|""",
    """X:4\nT:Cliffs of Moher\nM:6/8\nL:1/8\nK:D\n|:D2F AFD|G2B dBG|""",
    """X:5\nT:Connaughtman's Rambles\nM:6/8\nL:1/8\nK:D\n|:FGA B2d|gfg afd|"""
]

# === Generate continuations ===
for i, initial_tune in enumerate(initial_tunes):
    print(f"\n=== Generating Tune {i + 1}: {initial_tune.splitlines()[1][2:]} ===\n")
    seq = [char2idx.get(c, 0) for c in initial_tune]
    hidden = None

    for _ in range(MAX_CONTINUE_LEN):
        inp_seq = seq[-SEQ_LENGTH:]
        inp = torch.tensor(inp_seq).unsqueeze(0).to(DEVICE)
        logits, hidden = model(inp, hidden)
        
        # Apply temperature
        logits = logits[-1] / TEMPERATURE
        probs = torch.softmax(logits, dim=0)
        idx = torch.multinomial(probs, 1).item()
        seq.append(idx)
        
        # Stop if musical ending detected
        if len(seq) > len(initial_tune) + 3:
            last_chars = ''.join(idx2char[i] for i in seq[-3:])
            if '|]' in last_chars:
                break

    # === Decode and print ===
    continued = ''.join(idx2char[i] for i in seq)
    print(continued)


Loading dataset and vocab...


  checkpoint = torch.load("best_model.pth", map_location=DEVICE)


Model loaded successfully!

=== Generating Tune 1: The Blarney Pilgrim ===

X:1
T:The Blarney Pilgrim
M:6/8
L:1/8
K:D
|:DFA dfa|dfd ABc|| d/f/ | gece | agbg | fd D2 | B2 BB | dc BA | eA G2 | 
 Afaf | df f2 :: cfdf | cece | dgfg | 
 Bdfg | acac | 
 B2 A2 | 
 gf e2 | fedc | Bc _B2 | ABgf | 
 
 efed | 
 
"G7" ed f2 :| fa |: 
"F" abag | abag |"C" ee/g/ge' |"G7" fg/a/fd' | c'b/d'3 | b>d d'3 c'/b/ | ad'/d'"_g" d'c' | 
"G7" !>!b3 a/g/ |"C" .g3/2 .G/.G/.E/.G/ | .c .c"_h" b |"F7" a3 |"C7" (e !>!c2)) | 
"F" (c d e) |"C7" (e c3/2 d/) |"F" c3 |"C7" (c2 B) | (e f g) |"F" f2 f | e f g |"F" f3- | 
 (c2 A) |"C7" (B3 | B3- | B2 c) |"F" c2 c | c A c |"C" e3 |"G7" g3 | (g2 f) | 
"C" (e2 c)) | d2 c | (g2 a |"C" g3 | g3) |"G7" (g2 b) | (g2 e) |"C" (f2 e) |"C" (g2 e/g/) | 
"G7" (g2 B) | 
"C" (c2 d) |"F" (c3 | c) z A | (c2 c) |"C" (c2 A) | 
"G7" (B2 c) | (d2 c) | (B3 | B2) A |"C" (c2 d) |"F" (e2 d) | c3 | (c2 c) |"F" (c2 A) | 
 (c2 A) |"C" (c3 | c3) |"C" c2 c |"C7" (c2 B) |"F" A3 |"C" G3 | z C2 | (CE) z | 
"