In [None]:
import csv
import re

# Windows paths as raw strings: "\S" and "\l" are not valid escape sequences,
# which Python 3.12+ reports with a SyntaxWarning. The resulting values are unchanged.
input_file = r"E:\Spotify\loss_summary.csv"
output_file = r"E:\Spotify\loss_summary_parsed.csv"

def parse_config(config):
    """Split a run-config string into (strategy, length, mask_ratio).

    Expects names of the form ``posenc_<strategy>_len_<int>_mask_<number>``,
    matched from the start of the string. Anything after the mask number
    (e.g. a file extension or extra suffix) is ignored.

    Args:
        config: The config string, e.g. the ``config`` column of a CSV row.
            Non-string values (such as ``None`` for a short CSV row) are
            treated as unparseable.

    Returns:
        ``(strategy: str, length: int, mask: float)`` on success, otherwise
        ``(None, None, None)``.
    """
    if not isinstance(config, str):
        return None, None, None

    # The mask is a plain decimal number. A looser ``[0-9.]+`` would also
    # swallow trailing dots (e.g. "0.9." from "..._mask_0.9.ckpt"), making
    # float() raise ValueError; \d*\.?\d+ stops at the last digit instead.
    pattern = r"posenc_(.*?)_len_(\d+)_mask_(\d*\.?\d+)"
    m = re.match(pattern, config)
    if not m:
        return None, None, None

    strategy, length, mask = m.groups()
    return strategy, int(length), float(mask)


# Rewrite the loss summary with three parsed columns (strategy, length,
# mask_ratio) prepended to every original column.
parsed_columns = ["strategy", "length", "mask_ratio"]
unparsed_rows = 0

with open(input_file, "r", newline="") as f_in:
    reader = csv.DictReader(f_in)

    # Validate the header before opening output_file for writing, so a bad
    # input does not truncate a previously written result.
    if reader.fieldnames is None:
        raise ValueError(f"{input_file} is empty: no header row found")
    if "config" not in reader.fieldnames:
        raise ValueError(
            f"{input_file} has no 'config' column (columns: {reader.fieldnames})"
        )

    # If the input already carries one of the parsed columns (e.g. it is a
    # previous output), list that name only once so the header is not
    # duplicated. As before, the input's own value wins for such a column.
    fieldnames = parsed_columns + [
        name for name in reader.fieldnames if name not in parsed_columns
    ]

    with open(output_file, "w", newline="") as f_out:
        writer = csv.DictWriter(f_out, fieldnames=fieldnames)
        writer.writeheader()

        for row in reader:
            strategy, length, mask_ratio = parse_config(row["config"])
            if strategy is None:
                # parse_config returns all-None when the pattern does not
                # match; DictWriter writes those cells as empty strings.
                unparsed_rows += 1

            row_out = {
                "strategy": strategy,
                "length": length,
                "mask_ratio": mask_ratio,
            }
            # Original columns are copied after the parsed ones, so they take
            # precedence on any name collision.
            row_out.update(row)

            writer.writerow(row_out)

if unparsed_rows:
    print(
        f"Warning: {unparsed_rows} row(s) had a config that did not match "
        "the expected pattern; parsed columns left empty."
    )

print("Finished! Saved:", output_file)

In [None]:
from models.PositionalEmbeddings import AttentionClamping
from utils.Config import Config
from data.data_utils import *
from torch.utils.data import DataLoader
from utils import misc
from models.Myna import Myna
from training.contrastive_training import train_contrastive

# Contrastive training run "Myna-CLS-TEST".
# NOTE(review): looks like a small debug configuration (name ends in "-TEST",
# batch size 8) — confirm before relying on its results.
model_name = "Myna-CLS-TEST"
config = Config(
    save_path=f"trained_models\\{model_name}\\",
    num_epochs=16,
    learning_rate=3e-4,
    weight_decay=1e-4,
    num_workers=4,
    batch_size=8,
    dtype=torch.float32,
)

# Also used as the second dimension of Myna's image_size below.
chunk_size = 256

dataset_root = "D:\\SongsDataset\\melspec-mtg-jamendo\\"
# NOTE(review): no `views=` is passed to StreamViewDataset here, although
# train_contrastive is called with views=2 — confirm the dataset's default.
train_dataset = StreamViewDataset(dataset_root + "train_set\\", chunk_size=chunk_size)
test_dataset = StreamViewDataset(dataset_root + "test_set\\", chunk_size=chunk_size)

# Defined but unused while prefetching is disabled for the loaders below.
prefetch_factor = 1

# Both loaders share identical settings. num_workers / prefetch_factor are not
# passed, so DataLoader's defaults apply even though config.num_workers is 4.
shared_loader_kwargs = {"batch_size": config.batch_size, "shuffle": True}
train_dataloader = DataLoader(train_dataset, **shared_loader_kwargs)
test_dataloader = DataLoader(test_dataset, **shared_loader_kwargs)

# NOTE(review): these positional-encoding flags differ from the single
# `positional_encoding=` argument used in the other training cell — confirm
# they are still accepted by Myna.
myna_hparams = dict(
    image_size=(128, chunk_size),
    channels=1,
    patch_size=(16, 16),
    latent_space=128,
    d_model=384,
    depth=12,
    heads=6,
    mlp_dim=1536,
    mask_ratio=0.9,
    use_cls=True,
    predict_tempo="CNN",
    use_sinusoidal=True,
    use_y_emb=True,
    use_rope_x=True,
    use_rope_y=True,
    rope_base=512,
    use_alibi_x=True,
    use_alibi_y=True,
    # clamping=AttentionClamping(method="cap", cap_type="tanh", cap_value=16, learnable=False),
)
model = Myna(**myna_hparams)

print(f"{misc.model_size(model)} Parameters")
# NOTE(review): the test loader is passed before the train loader — confirm
# this matches train_contrastive's positional parameter order.
train_contrastive(model, test_dataloader, train_dataloader, config, start_epoch=0, views=2)

In [None]:
# Contrastive training run with sinusoidal positional encoding and two views
# per dataset item.
# NOTE(review): "Stochastic-Length" in the name is not configured visibly in
# this cell — presumably handled inside StreamViewDataset or Myna; confirm.
model_name = "Myna-CLS-Sinusoid-Stochastic-Length"
config = Config(
    save_path=f"trained_models\\{model_name}\\",
    num_epochs=16,
    learning_rate=3e-4,
    weight_decay=1e-4,
    num_workers=2,
    batch_size=256,
    dtype=torch.float32,
)

chunk_size = 256

# Train split is built first, then test (generator unpacking evaluates in order).
dataset_dir = "D:\\SongsDataset\\melspec-mtg-jamendo\\"
train_dataset, test_dataset = (
    StreamViewDataset(f"{dataset_dir}{split}\\", chunk_size=chunk_size, views=2)
    for split in ("train_set", "test_set")
)

# In PyTorch, prefetch_factor only applies when num_workers > 0, which holds
# here (config.num_workers is 2).
prefetch_factor = 1
train_dataloader, test_dataloader = (
    DataLoader(
        split_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        prefetch_factor=prefetch_factor,
    )
    for split_dataset in (train_dataset, test_dataset)
)

encoder_hparams = dict(
    image_size=(128, chunk_size),
    channels=1,
    patch_size=(16, 16),
    latent_space=128,
    d_model=384,
    depth=12,
    heads=6,
    mlp_dim=1536,
    mask_ratio=0.9,
    use_cls=True,
    positional_encoding="sinusoidal",
    # rope_base=8192,
    # clamping=AttentionClamping(method="cap", cap_type="tanh", cap_value=16, learnable=False),
)
model = Myna(**encoder_hparams)

param_count = misc.model_size(model)
print(f"{param_count} Parameters")
# NOTE(review): the test loader is passed before the train loader — confirm
# this matches train_contrastive's positional parameter order.
train_contrastive(model, test_dataloader, train_dataloader, config, start_epoch=0, views=2)