In [None]:
from models.PositionalEmbeddings import AttentionClamping
from utils.Config import Config
from data.data_utils import *
from torch.utils.data import DataLoader
from utils import misc
from models.Myna import Myna
from training.contrastive_training import train_contrastive

# --- Test run: small batch size, data loaded without worker processes ---
model_name = "Myna-CLS-TEST"

# NOTE(review): `torch` is not imported by name among the imports of this cell;
# presumably it is provided by `from data.data_utils import *` -- confirm.
config = Config(
    save_path=f"trained_models\\{model_name}\\",
    num_epochs=16,
    learning_rate=3e-4,
    weight_decay=1e-4,
    num_workers=4,
    batch_size=8,
    dtype=torch.float32,
)

# Passed to both datasets and reused as the second entry of the model's image_size.
chunk_size = 256

train_dataset = StreamViewDataset(
    "D:\\SongsDataset\\melspec-mtg-jamendo\\train_set\\",
    chunk_size=chunk_size,
)
test_dataset = StreamViewDataset(
    "D:\\SongsDataset\\melspec-mtg-jamendo\\test_set\\",
    chunk_size=chunk_size,
)

# Assigned but not handed to either loader below: the num_workers /
# prefetch_factor arguments are disabled in this cell, so DataLoader uses
# its own defaults for them (config.num_workers is likewise unused here).
prefetch_factor = 1

train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=True)

# Keyword semantics are defined by Myna; this run switches on every
# positional-encoding flag (sinusoidal, y-embedding, RoPE x/y, ALiBi x/y).
model = Myna(
    image_size=(128, chunk_size),
    channels=1,
    patch_size=(16, 16),
    latent_space=128,
    d_model=384,
    depth=12,
    heads=6,
    mlp_dim=1536,
    mask_ratio=0.9,
    use_cls=True,
    predict_tempo="CNN",
    use_sinusoidal=True,
    use_y_emb=True,
    use_rope_x=True,
    use_rope_y=True,
    rope_base=512,
    use_alibi_x=True,
    use_alibi_y=True,
    # Optional attention clamping, currently disabled:
    # clamping=AttentionClamping(method="cap", cap_type="tanh", cap_value=16, learnable=False)
)

# Report the model size before training starts.
param_count = misc.model_size(model)
print(f"{param_count} Parameters")

# The test loader is passed before the train loader, as in every other call in this notebook.
train_contrastive(model, test_dataloader, train_dataloader, config, start_epoch=0, views=2)

23277964 Parameters


  0%|          | 0/1242 [00:00<?, ?it/s]

In [None]:
# --- Full run: sinusoidal positional encoding, large batches, worker processes ---
# Uses names imported in an earlier cell (Config, StreamViewDataset, DataLoader,
# Myna, misc, train_contrastive, torch).
model_name = "Myna-CLS-Sinusoid-Stochastic-Length"

config = Config(
    save_path=f"trained_models\\{model_name}\\",
    num_epochs=16,
    learning_rate=3e-4,
    weight_decay=1e-4,
    num_workers=2,
    batch_size=256,
    dtype=torch.float32,
)

# Passed to both datasets and reused as the second entry of the model's image_size.
chunk_size = 256

train_dataset = StreamViewDataset(
    "D:\\SongsDataset\\melspec-mtg-jamendo\\train_set\\",
    chunk_size=chunk_size,
    views=2,
)
test_dataset = StreamViewDataset(
    "D:\\SongsDataset\\melspec-mtg-jamendo\\test_set\\",
    chunk_size=chunk_size,
    views=2,
)

prefetch_factor = 1

# Both loaders are configured identically, so the options are written once.
loader_options = dict(
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    prefetch_factor=prefetch_factor,
)
train_dataloader = DataLoader(train_dataset, **loader_options)
test_dataloader = DataLoader(test_dataset, **loader_options)

model = Myna(
    image_size=(128, chunk_size),
    channels=1,
    patch_size=(16, 16),
    latent_space=128,
    d_model=384,
    depth=12,
    heads=6,
    mlp_dim=1536,
    mask_ratio=0.9,
    use_cls=True,
    positional_encoding="sinusoidal",
    # Disabled options kept for reference:
    # rope_base=8192
    # clamping=AttentionClamping(method="cap", cap_type="tanh", cap_value=16, learnable=False)
)

# Report the model size before training starts.
param_count = misc.model_size(model)
print(f"{param_count} Parameters")

# The test loader is passed before the train loader, as in every other call in this notebook.
train_contrastive(model, test_dataloader, train_dataloader, config, start_epoch=0, views=2)

In [None]:
# --- Variant: positional_encoding="1D-ALIBI" ---
# This cell defines only `model`. It relies on `chunk_size`, `config`,
# `train_dataloader` and `test_dataloader` already existing in the kernel.
# NOTE(review): `config` is not rebuilt here, so `config.save_path` keeps the
# value set by whichever cell created it; if train_contrastive writes to that
# path, this run shares the directory with the earlier one -- confirm.
model = Myna(
    image_size=(128, chunk_size),
    channels=1,
    patch_size=(16, 16),
    latent_space=128,
    d_model=384,
    depth=12,
    heads=6,
    mlp_dim=1536,
    mask_ratio=0.9,
    use_cls=True,
    positional_encoding="1D-ALIBI",
    # Disabled options kept for reference:
    # rope_base=8192
    # clamping=AttentionClamping(method="cap", cap_type="tanh", cap_value=16, learnable=False)
)

# Report the model size before training starts.
param_count = misc.model_size(model)
print(f"{param_count} Parameters")

train_contrastive(model, test_dataloader, train_dataloader, config, start_epoch=0, views=2)

21428608 Parameters


  0%|          | 0/39 [00:10<?, ?it/s]

  0%|          | 0/5 [00:35<?, ?it/s]

[Epoch 0] Train: Same Song Contrastive Loss = 5.4035
Test: Same Song Contrastive Loss = 5.1766



  0%|          | 0/39 [00:48<?, ?it/s]

  0%|          | 0/5 [00:35<?, ?it/s]

[Epoch 1] Train: Same Song Contrastive Loss = 5.2613
Test: Same Song Contrastive Loss = 5.0854



  0%|          | 0/39 [00:35<?, ?it/s]

  0%|          | 0/5 [00:35<?, ?it/s]

[Epoch 2] Train: Same Song Contrastive Loss = 5.1631
Test: Same Song Contrastive Loss = 4.9340



  0%|          | 0/39 [00:37<?, ?it/s]

  0%|          | 0/5 [00:33<?, ?it/s]

[Epoch 3] Train: Same Song Contrastive Loss = 5.0303
Test: Same Song Contrastive Loss = 4.8351



  0%|          | 0/39 [00:35<?, ?it/s]

  0%|          | 0/5 [00:34<?, ?it/s]

[Epoch 4] Train: Same Song Contrastive Loss = 4.9564
Test: Same Song Contrastive Loss = 4.8074



  0%|          | 0/39 [00:34<?, ?it/s]

  0%|          | 0/5 [00:34<?, ?it/s]

[Epoch 5] Train: Same Song Contrastive Loss = 4.9150
Test: Same Song Contrastive Loss = 4.8076



  0%|          | 0/39 [00:35<?, ?it/s]

In [None]:
# --- Variant: positional_encoding="1D-ALIBI" ---
# This cell defines only `model`. It relies on `chunk_size`, `config`,
# `train_dataloader` and `test_dataloader` already existing in the kernel.
# NOTE(review): `config` is not rebuilt here, so `config.save_path` keeps the
# value set by whichever cell created it; if train_contrastive writes to that
# path, this run shares the directory with the earlier one -- confirm.
model = Myna(
    image_size=(128, chunk_size),
    channels=1,
    patch_size=(16, 16),
    latent_space=128,
    d_model=384,
    depth=12,
    heads=6,
    mlp_dim=1536,
    mask_ratio=0.9,
    use_cls=True,
    positional_encoding="1D-ALIBI",
    # Disabled options kept for reference:
    # rope_base=8192
    # clamping=AttentionClamping(method="cap", cap_type="tanh", cap_value=16, learnable=False)
)

# Report the model size before training starts.
param_count = misc.model_size(model)
print(f"{param_count} Parameters")

train_contrastive(model, test_dataloader, train_dataloader, config, start_epoch=0, views=2)

In [None]:
# --- Variant: positional_encoding="2D-ALIBI" with rope_base=8192 ---
# This cell defines only `model`. It relies on `chunk_size`, `config`,
# `train_dataloader` and `test_dataloader` already existing in the kernel.
# NOTE(review): `config` is not rebuilt here, so `config.save_path` keeps the
# value set by whichever cell created it; if train_contrastive writes to that
# path, this run shares the directory with the earlier one -- confirm.
model = Myna(
    image_size=(128, chunk_size),
    channels=1,
    patch_size=(16, 16),
    latent_space=128,
    d_model=384,
    depth=12,
    heads=6,
    mlp_dim=1536,
    mask_ratio=0.9,
    use_cls=True,
    positional_encoding="2D-ALIBI",
    rope_base=8192,
    # Disabled option kept for reference:
    # clamping=AttentionClamping(method="cap", cap_type="tanh", cap_value=16, learnable=False)
)

# Report the model size before training starts.
param_count = misc.model_size(model)
print(f"{param_count} Parameters")

train_contrastive(model, test_dataloader, train_dataloader, config, start_epoch=0, views=2)

In [None]:
# --- Variant: positional_encoding="2D-ALIBI", rope_base left unset ---
# This cell defines only `model`. It relies on `chunk_size`, `config`,
# `train_dataloader` and `test_dataloader` already existing in the kernel.
# NOTE(review): `config` is not rebuilt here, so `config.save_path` keeps the
# value set by whichever cell created it; if train_contrastive writes to that
# path, this run shares the directory with the earlier one -- confirm.
model = Myna(
    image_size=(128, chunk_size),
    channels=1,
    patch_size=(16, 16),
    latent_space=128,
    d_model=384,
    depth=12,
    heads=6,
    mlp_dim=1536,
    mask_ratio=0.9,
    use_cls=True,
    positional_encoding="2D-ALIBI",
    # Disabled options kept for reference:
    # rope_base=8192
    # clamping=AttentionClamping(method="cap", cap_type="tanh", cap_value=16, learnable=False)
)

# Report the model size before training starts.
param_count = misc.model_size(model)
print(f"{param_count} Parameters")

train_contrastive(model, test_dataloader, train_dataloader, config, start_epoch=0, views=2)