In [1]:
import tarfile

file = tarfile.open('./jsb_chorales.tgz', 'r')
file.extractall()

In [34]:
import os
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader

DATASET_DIR = './jsb_chorales'

class BachChorales(Dataset):
    def __init__(self, path, window_size=32, transforms=None):
        self.path = path
        self.window_size = window_size
        self.transforms = transforms
        self.chorales = [pd.read_csv(os.path.join(path, p)) for p in os.listdir(os.path.join(path))]
    
    def __len__(self):
        return len(self.chorales)
    
    def __getitem__(self, idx):
        chorale = self.chorales[idx]
        chorale = np.array(chorale)
        chorale = chorale.reshape(-1,1)
        chorale = chorale[:self.window_size*4+1] # Add one for target
        chorale -= 36 # Subtract lowest node so all nodes are in range 0,46
        chorale, target = chorale[:-1], chorale[-1]
        sample = (chorale, target)
        if self.transforms:
            sample = self.transforms(sample)
        return sample

In [41]:
import torch
import torch.nn as nn
from torchvision.transforms import Compose

model = nn.Sequential(*[
    nn.RNN(1,1,20,batch_first=True)
])

In [42]:
class ToTensor(object):
    def __call__(self, sample):
        chorale, target = sample
        chorale = torch.tensor(chorale, dtype=torch.float)
        target = torch.tensor(target, dtype=torch.float)
        return (chorale, target)

transforms = Compose([
    ToTensor()
])

dataset = BachChorales(os.path.join(DATASET_DIR,'train'), transforms=transforms)
loader = DataLoader(dataset, batch_size=32)
chorales, targets = next(iter(loader))
print(chorales.shape, targets.shape)
out, hc = model(chorales)
out.shape

torch.Size([32, 128, 1]) torch.Size([32, 1])


torch.Size([32, 128, 1])

In [43]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

for epoch in range(50):
    for chorales, target in loader:
        out, _ = model(chorales)
        out = out[:,-1]
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        print(loss.item())
        optimizer.zero_grad()

1208.2098388671875
1274.236083984375
1236.970703125
1238.7669677734375
1276.3419189453125
1245.673095703125
1216.258544921875
1271.2977294921875
1196.7625732421875
1262.7218017578125
1225.8726806640625
1227.8597412109375
1265.5177001953125
1235.133056640625
1206.039306640625
1260.99560546875
1186.9840087890625
1252.85400390625
1216.3544921875
1218.508056640625
1256.249267578125
1226.125732421875
1197.326171875
1252.227294921875
1178.6734619140625
1244.472900390625
1208.2734375
1210.56884765625
1248.3807373046875
1218.48095703125
1189.9359130859375
1244.7991943359375
1171.6485595703125
1237.4097900390625
1201.4910888671875
1203.9403076171875
1241.8529052734375
1212.187255859375
1183.905517578125
1238.797607421875
1166.0355224609375
1231.8323974609375
1196.203857421875
1198.8424072265625
1236.90234375
1207.4832763671875
1179.4644775390625
1234.442626953125
1162.0230712890625
1227.9036865234375
1192.5335693359375
1195.3529052734375
1233.559814453125
1204.348388671875
1176.5419921875
1231.

KeyboardInterrupt: 