In [None]:
import numpy as np 
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import music21
import matplotlib as plt
import torchvision.utils as vutils

import torch.backends.cudnn as cudnn
# Hand cached, currently unused GPU memory back before anything is allocated.
torch.cuda.empty_cache()
# Let cuDNN pick its kernels by benchmarking them on this hardware.
cudnn.benchmark = True

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# --- Training configuration ---
BATCH_SIZE = 8   # samples per mini-batch handed to the DataLoader
EPOCHS = 400     # intended number of training epochs
LR = 0.004       # learning rate; might need to adjust
BETA1 = 0.5      # beta1 hyperparameter for the Adam optimizer

# --- Model configuration ---
NOISE_DIM = 100      # length of the latent noise vector fed to the generator
NUM_CLASSES = 18     # number of classes (not referenced in the cells shown here)
EMBEDDING_DIM = 50   # embedding size (not referenced in the cells shown here)

In [None]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Turns a latent noise vector into a single-channel 128x256 map with values in [0, 1].

    A fully connected stem expands the noise into a 128-channel 8x16 feature
    map. Four stride-2 transposed convolutions then double the spatial size
    at every step: 8x16 -> 16x32 -> 32x64 -> 64x128 -> 128x256.
    """

    def __init__(self):
        super().__init__()

        # Fully connected stem producing the flattened (128, 8, 16) feature map.
        stem_features = 128 * 8 * 16
        self.fc = nn.Sequential(
            nn.Linear(NOISE_DIM, stem_features),
            nn.ReLU(True),
            nn.BatchNorm1d(stem_features),
        )

        def upsample(in_channels, out_channels):
            # kernel 4 / stride 2 / padding 1 exactly doubles height and width.
            return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)

        # Layers are created in the same order as they are applied in forward().
        self.conv1 = upsample(128, 64)   # -> 16x32
        self.conv1_bn = nn.BatchNorm2d(64)
        self.conv2 = upsample(64, 32)    # -> 32x64
        self.conv2_bn = nn.BatchNorm2d(32)
        self.conv3 = upsample(32, 1)     # -> 64x128
        self.conv4 = upsample(1, 1)      # -> 128x256

        # Registered but never applied in forward(); the output goes through a sigmoid.
        self.tanh_layer = nn.Tanh()

    def forward(self, noise):
        """Generate a batch of maps from ``noise`` of shape (N, NOISE_DIM); returns (N, 1, 128, 256)."""
        features = self.fc(noise).view(-1, 128, 8, 16)  # un-flatten for the conv stack
        features = F.relu(self.conv1_bn(self.conv1(features)))
        features = F.relu(self.conv2_bn(self.conv2(features)))
        features = F.relu(self.conv3(features))
        return torch.sigmoid(self.conv4(features))

class Discriminator(nn.Module):
    """Scores a single-channel 128x256 map with one sigmoid output per sample.

    Two stride-2 convolutions halve the spatial size twice
    (128x256 -> 64x128 -> 32x64) before a linear layer reduces the flattened
    features to a single value.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1)    # (N, 64, 64, 128)
        self.conv1_bn = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)  # (N, 128, 32, 64)
        self.conv2_bn = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128 * 32 * 64, 1)  # output layer

    def forward(self, img):
        """Return a (N, 1) tensor of sigmoid scores for ``img`` of shape (N, 1, 128, 256)."""
        hidden = F.leaky_relu(self.conv1_bn(self.conv1(img)), negative_slope=0.2)
        hidden = F.leaky_relu(self.conv2_bn(self.conv2(hidden)), negative_slope=0.2)
        flat = hidden.view(hidden.size(0), -1)  # one feature row per sample
        # Sigmoid squashes the score for binary classification.
        return torch.sigmoid(self.fc(flat))


In [None]:
from torch.utils.data import DataLoader, TensorDataset

# NOTE(review): torch.load unpickles the file contents, so only load tensor
# files that come from a trusted source.
inputs_seq = torch.load("Input_tensors.pt")
labels_seq = torch.load("Labels_tensors.pt")
dataset = TensorDataset(inputs_seq, labels_seq)

# Split into shuffled mini-batches. `batch_size` used to be set to 16 here but
# was never passed on, so the loader has always used BATCH_SIZE; the two are
# now kept in sync so the value read here is the value actually used.
batch_size = BATCH_SIZE
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Check that the data is loaded correctly
print("Number of input pianorolls: ", len(dataset))

In [None]:
import torch.optim as optim

# Initialize models and optimizers
netG = Generator().to(device)
netD = Discriminator().to(device)
gen = []  # generator samples collected while training
criterion = nn.BCELoss()  # applied to the discriminator's sigmoid outputs

# beta1 comes from the config cell (BETA1) instead of a duplicated literal.
# NOTE(review): the config cell also defines LR = 0.004, but training has been
# running with 0.0002; the effective value is kept here - confirm which one is intended.
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(BETA1, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(BETA1, 0.999))

# Per-iteration loss history, plotted after training.
G_losses = []
D_losses = []

def weights_init(m):
    """Re-initialise one layer in place; intended for ``Module.apply``.

    The layer is selected by its class name:

    * names containing ``Conv`` get weights drawn from N(0, 0.02); the bias is untouched
    * names containing ``BatchNorm`` get weights drawn from N(1, 0.02) and a zero bias
    * every other layer is left as it is

    :param m: the module to initialise
    :returns: None
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

netG.apply(weights_init)
netD.apply(weights_init)

# Number of passes over the dataset for this run.
# NOTE(review): the config cell defines EPOCHS = 400, but this loop has been
# running 10 epochs; the effective value is kept - confirm which is intended.
num_epochs = 10

# One noise batch reused for every snapshot, so that the samples stored in
# `gen` show the same latent points at different stages of training.
fixed_noise = torch.randn(BATCH_SIZE, NOISE_DIM, device=device)

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        # Prepare real data: move to the training device and add a channel axis.
        # (Despite its name, `real_cpu` lives on `device`.)
        # assumes data[0] is (N, 128, 256), matching the Discriminator input - TODO confirm
        real_cpu = data[0].to(device).float().unsqueeze(1)
        b_size = real_cpu.size(0)

        # The generator's BatchNorm1d cannot compute batch statistics from a
        # single sample in training mode, so a leftover batch of one is skipped.
        if b_size < 2:
            continue

        # Soft targets for real and fake samples (0.9 / 0.15 rather than 1 / 0).
        real_label = torch.full((b_size,), 0.9, dtype=torch.float, device=device)
        fake_label = torch.full((b_size,), 0.15, dtype=torch.float, device=device)

        # (1) Update D network
        netD.zero_grad()

        # Train with real data
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, real_label)
        errD_real.backward()

        # Train with fake data; detach() keeps this step from reaching into G.
        noise = torch.randn(b_size, NOISE_DIM, device=device)
        fake = netG(noise)
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, fake_label)
        errD_fake.backward()

        # Update D
        errD = errD_real + errD_fake
        optimizerD.step()

        # (2) Update G network: G is rewarded when D labels its fakes as real.
        netG.zero_grad()
        output = netD(fake).view(-1)
        errG = criterion(output, real_label)
        errG.backward()
        optimizerG.step()

        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Output training stats
        if i % 2 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
                  % (epoch, num_epochs, i, len(dataloader), errD.item(), errG.item()))

    # Snapshot the generator once per epoch. The samples are moved to the CPU
    # so that `gen` does not keep growing GPU memory for the whole run.
    with torch.no_grad():
        generated_samples = netG(fixed_noise).detach().cpu()
    gen.append(generated_samples)
    # Here you can visualize `generated_samples` using matplotlib or similar

In [None]:
import music21 as m21
import os
OUTPUT_FILES = "training_files"  # default directory for the written files
TIME_STEP = 0.25  # default length of one piano-roll column, in quarter lengths


def _extract_notes(piano_roll, step_duration, min_pitch=21, max_pitch=95):
    """Collapse each run of consecutive 'on' cells in a row into a (pitch, offset, duration) tuple."""
    rows, cols = piano_roll.shape
    events = []
    # Only rows inside the accepted pitch window can produce notes.
    for pitch in range(max(min_pitch, 0), min(max_pitch, rows - 1) + 1):
        run_start = None
        # One virtual step past the last column closes notes still sounding at the end.
        for col in range(cols + 1):
            is_on = col < cols and bool(piano_roll[pitch, col])
            if is_on:
                if run_start is None:
                    run_start = col
            elif run_start is not None:
                events.append((pitch, run_start * step_duration, (col - run_start) * step_duration))
                run_start = None
    # Chronological order; ties (chords) are ordered by pitch.
    events.sort(key=lambda event: (event[1], event[0]))
    return events


def convert_stream(matrix, format="midi", file_name='output.mid', filepath=OUTPUT_FILES,
                   step_duration=TIME_STEP, threshold=0.51):
    """
    Converts the piano roll matrix back into a music 21 stream. Writes this stream to a midi file.

    Every cell strictly greater than ``threshold`` counts as 'on'. The row index
    is used as the MIDI pitch and the column index as the time step. Consecutive
    'on' cells in a row are merged into one note. Notes that start at the same
    offset are all kept, and notes still sounding in the last column are closed
    at the end of the matrix. Only pitches 21..95 are written (the original
    code described this window as the range of a piano).

    :params matrix: 2D piano roll matrix (torch tensor or array-like), rows = pitch, columns = time
    :params format: format file type to write the stream
    :params file_name: the file name of the output file
    :params filepath: the output path of the directory holding the output files
    :params step_duration: the size of the step on the x axis of the piano roll matrix
    :params threshold: values strictly above this are treated as an active note
    :raises ValueError: if ``matrix`` is not two-dimensional
    :returns None.
    """
    # Accept torch tensors (possibly on the GPU or attached to a graph) as well as arrays.
    if hasattr(matrix, "detach"):
        matrix = matrix.detach().cpu().numpy()
    piano_roll = np.asarray(matrix) > threshold
    if piano_roll.ndim != 2:
        raise ValueError("matrix must be 2D (pitch x time), got %d dimension(s)" % piano_roll.ndim)

    # Build the stream; a Note object is only created for notes that are actually written.
    new_stream = m21.stream.Stream()
    for pitch, offset, duration in _extract_notes(piano_roll, step_duration):
        note = m21.note.Note(pitch)
        note.quarterLength = duration
        new_stream.insert(offset, note)

    # Makes the directory if it doesn't exist
    os.makedirs(filepath, exist_ok=True)

    # Writes the stream as a midi file to the path
    path = os.path.join(filepath, file_name)
    new_stream.write(format, fp=path)

fixed_noise = torch.randn(BATCH_SIZE, NOISE_DIM, device=device)

# Inference only: no_grad() avoids building an autograd graph for the forward
# pass, so the output is a plain tensor that can be handed on for export.
# NOTE(review): netG is used in whatever mode it is currently in; call
# netG.eval() first if the BatchNorm running statistics should be used instead.
with torch.no_grad():
    output = netG(fixed_noise)
print(output)
# First sample of the batch, single channel -> 2D piano roll.
song = output[0][0]
print(song.shape)
convert_stream(song.cpu(), file_name="output8.mid")

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
# Plot both loss histories on one set of axes, one point per training iteration.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
for series, tag in ((G_losses, "G"), (D_losses, "D")):
    ax.plot(series, label=tag)
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()