In [1]:
import csv
import json
import os
from collections import Counter
from shutil import copyfile

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import wandb
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

from dataset.e_piano import create_epiano_datasets
from model.music_transformer import MusicTransformer
from utilities.constants import *
from utilities.device import get_device, use_cuda
from utilities.lr_scheduling import LrStepTracker, get_lr
from utilities.run_model import eval_model, parse_json, train_epoch

# Baseline is an untrained epoch that we evaluate as a baseline loss and accuracy
BASELINE_EPOCH = -1

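## Inspecting how `train.py` works

The cells below load the configuration files, build the e-piano datasets and the validation DataLoader, and then inspect one batch and a few individual samples.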

In [5]:
def load_json(path):
    """Read and parse a JSON file, closing the file handle afterwards.

    Args:
        path: path to a JSON file.

    Returns:
        The parsed JSON value (a dict for the config files below).
    """
    # JSON is UTF-8 by spec, so don't depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as fh:
        return json.load(fh)


args = load_json('parameters.json')
dataset_params = load_json('dataset/dataset_parameters.json')
model_params = load_json('model/model_params.json')

In [7]:
# Build the train/val/test splits. Every sample is cut to the model's context
# size (max_seq = model_params['max_sequence']). This is slow: roughly 19 minutes
# in the recorded run.
#
# Each dataset item is a tuple of 4 tensors (see the inspection cells below).
# The DataLoader stacks each of them along a new leading batch dimension. With
# the default drop_last=False it yields ceil(len(dataset) / batch_size) batches,
# so the last batch may be partial.
train_dataset, val_dataset, test_dataset = create_epiano_datasets(
    max_seq=model_params['max_sequence'],
    **dataset_params,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=args['batch_size'],
    num_workers=args['n_workers'],
)

# Sanity check: every validation sample lands in exactly one batch (ceil division).
assert len(val_loader) == -(-len(val_dataset) // args['batch_size'])

100%|██████████| 17123/17123 [15:42<00:00, 18.16it/s]
100%|██████████| 1955/1955 [01:49<00:00, 17.89it/s]
100%|██████████| 2179/2179 [01:43<00:00, 21.09it/s]


In [None]:
# Number of validation samples vs. number of batches the loader yields.
tuple(map(len, (val_dataset, val_loader)))

(1955, 489)

In [38]:
# Pull the first batch from the validation loader to inspect the collated data.
# The loader is not shuffled and a fresh iterator is created here, so re-running
# this cell always yields the same first batch. `val_loader_iter` is kept so later
# cells can continue with the next batches.
val_loader_iter = iter(val_loader)
first_batch = next(val_loader_iter)

# Compact overview (shape and dtype of each collated tensor) instead of printing
# the whole batch. The next cell shows the full values.
[(tuple(part.shape), part.dtype) for part in first_batch]

In [37]:
# Full repr of the batch fetched above: a list of 4 tensors, each with the batch
# as its leading dimension. Judging from the printed values, the 4th tensor looks
# like the 1st shifted left by one token (model input vs. next-token target).
# TODO confirm in dataset/e_piano.py.
first_batch

[tensor([[364,  53, 256,  ..., 258, 221, 256],
         [386,  52, 316,  ..., 389, 389, 389],
         [377,  75, 374,  ..., 364,  59, 256],
         [372,  64, 372,  ..., 371,  93, 256]]),
 tensor([[[  0,   0],
          [  1,   0],
          [  1,   0],
          ...,
          [  3,   9],
          [  2,   9],
          [  2,   9]],
 
         [[  0,   0],
          [  1,   0],
          [  1,   0],
          ...,
          [ 10, 127],
          [ 10, 127],
          [ 10, 127]],
 
         [[  0,   0],
          [  1,   0],
          [  1,   0],
          ...,
          [  1,  10],
          [  2,  10],
          [  2,  10]],
 
         [[  0,   0],
          [  1,   0],
          [  1,   0],
          ...,
          [  5,  29],
          [  6,  29],
          [  6,  29]]]),
 tensor([[ 0, 14],
         [11,  0],
         [ 0,  0],
         [ 8,  0]]),
 tensor([[ 53, 256, 365,  ..., 221, 256, 207],
         [ 52, 316, 180,  ..., 389, 389, 389],
         [ 75, 374,  56,  ...,  59, 25

In [40]:
# Row 0 of each of the four collated tensors, i.e. the first sample exactly as
# the DataLoader delivered it.
tuple(first_batch[part_idx][0] for part_idx in range(4))

(tensor([364,  53, 256,  ..., 258, 221, 256]),
 tensor([[0, 0],
         [1, 0],
         [1, 0],
         ...,
         [3, 9],
         [2, 9],
         [2, 9]]),
 tensor([ 0, 14]),
 tensor([ 53, 256, 365,  ..., 221, 256, 207]))

In [71]:
# Index the dataset directly and check that it matches row 1 of the first batch.
# The validation DataLoader was built without shuffle=True, so batch 0 should hold
# samples 0..batch_size-1 in order.
# NOTE(review): this assumes __getitem__ is deterministic (no random cropping).
# TODO confirm in dataset/e_piano.py.
sample_1 = val_dataset[1]
assert len(sample_1) == len(first_batch) and all(
    torch.equal(part, batched[1]) for part, batched in zip(sample_1, first_batch)
), "val_dataset[1] differs from row 1 of the first validation batch"
sample_1

(tensor([386,  52, 316,  ..., 389, 389, 389]),
 tensor([[  0,   0],
         [  1,   0],
         [  1,   0],
         ...,
         [ 10, 127],
         [ 10, 127],
         [ 10, 127]]),
 tensor([11,  0]),
 tensor([ 52, 316, 180,  ..., 389, 389, 389]))

In [77]:
def feature_value_counts(dataset, sample_idx: int, feature_idx: int) -> Counter:
    """Count how often each value appears in one column of a sample's second tensor.

    Args:
        dataset: indexable dataset whose items are sequences of tensors. Item ``[1]``
            is indexed as 2-D here (``[:, feature_idx]``). In the printed samples
            above its trailing dimension is 2.
            TODO confirm what the two columns encode.
        sample_idx: index of the sample to inspect.
        feature_idx: column of ``dataset[sample_idx][1]`` whose values are counted.

    Returns:
        Counter mapping each value to its number of occurrences.
    """
    column = dataset[sample_idx][1][:, feature_idx]
    # .tolist() yields plain Python scalars (stable repr across NumPy versions) and,
    # unlike .numpy(), also works for tensors that live on a GPU.
    return Counter(column.tolist())


# Value distribution of column 0 in sample 3 and of column 1 in sample 5.
(
    feature_value_counts(val_dataset, sample_idx=3, feature_idx=0),
    feature_value_counts(val_dataset, sample_idx=5, feature_idx=1),
)

(Counter({0: 37,
          1: 76,
          2: 102,
          3: 133,
          4: 173,
          5: 187,
          6: 217,
          7: 289,
          8: 350,
          9: 307,
          10: 177}),
 Counter({0: 16,
          1: 13,
          2: 9,
          3: 13,
          4: 9,
          5: 11,
          6: 12,
          7: 5,
          8: 4,
          9: 13,
          10: 8,
          11: 14,
          12: 18,
          13: 8,
          14: 14,
          15: 6,
          16: 17,
          17: 8,
          18: 14,
          19: 9,
          20: 14,
          21: 8,
          22: 17,
          23: 8,
          24: 14,
          25: 18,
          26: 8,
          27: 14,
          28: 8,
          29: 14,
          30: 6,
          31: 21,
          32: 8,
          33: 17,
          34: 8,
          35: 16,
          36: 8,
          37: 7,
          38: 13,
          39: 8,
          40: 14,
          41: 8,
          42: 14,
          43: 8,
          44: 18,
          45: 5,
     