In [4]:
import pretty_midi
import numpy as np

W = 16      # timesteps per bar (16 = 16th notes)
H = 128     # midi pitches

def midi_to_bars(p: pretty_midi.PrettyMIDI, tempo=120, time_sig=(4,4)):
    """Split a MIDI melody into fixed-size, per-bar binary piano rolls.

    Bar boundaries are not read from the file; they are derived from the
    constant ``tempo`` and ``time_sig`` supplied by the caller, with bar 0
    starting at time 0. Each bar is divided into ``W`` equal steps and every
    note is quantised (rounded) onto that grid.

    NOTE(review): the original stub only said "melody channel" without
    specifying how to pick it. This implementation uses the first non-drum
    instrument in ``p.instruments`` -- confirm that matches the dataset.
    NOTE(review): cells are binarised to 0.0 / 1.0 (velocity is discarded),
    on the assumption that the models consume binary rolls -- confirm.

    Args:
        p: Parsed MIDI file.
        tempo: Tempo in beats per minute; a beat is assumed to be a quarter
            note (TODO confirm for the data at hand).
        time_sig: ``(numerator, denominator)`` of the time signature.

    Returns:
        A list of ``float32`` arrays of shape ``(H, W)``, one per bar, in
        chronological order. Empty if there is no usable melody note.

    Raises:
        ValueError: If ``tempo`` or either ``time_sig`` entry is not positive.
    """
    beats_per_bar, beat_unit = time_sig
    if tempo <= 0 or beats_per_bar <= 0 or beat_unit <= 0:
        raise ValueError("tempo and both time_sig entries must be positive")

    # Seconds per quarter note, scaled to the signature's beat unit.
    bar_seconds = (60.0 / tempo) * beats_per_bar * (4.0 / beat_unit)
    steps_per_second = W / bar_seconds

    melody = next((inst for inst in p.instruments if not inst.is_drum), None)
    if melody is None:
        return []

    # Quantise by rounding (rather than truncating) so notes slightly off the
    # grid snap to the nearest step, and give every note at least one step so
    # very short notes are not lost.
    spans = []
    for note in melody.notes:
        if not 0 <= note.pitch < H:
            continue
        first = int(round(note.start * steps_per_second))
        last = max(first + 1, int(round(note.end * steps_per_second)))
        spans.append((note.pitch, first, last))
    if not spans:
        return []

    # Ceil-divide so the final, partially filled bar is kept (zero padded).
    n_bars = -(-max(last for _, _, last in spans) // W)
    roll = np.zeros((H, n_bars * W), dtype=np.float32)
    for pitch, first, last in spans:
        roll[pitch, first:last] = 1.0

    return np.split(roll, n_bars, axis=1)


In [5]:
# models.py
import torch
import torch.nn as nn
import torch.nn.functional as F

# helper: conv block
def conv_block(in_ch, out_ch, kernel, stride, padding, bn=True):
    """Build a Conv2d -> [BatchNorm2d] -> LeakyReLU(0.2) stack.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel: Kernel size forwarded to ``nn.Conv2d``.
        stride: Stride forwarded to ``nn.Conv2d``.
        padding: Padding forwarded to ``nn.Conv2d``.
        bn: When true, insert batch normalisation after the convolution.

    Returns:
        An ``nn.Sequential`` holding the layers in the order listed above.
    """
    conv = nn.Conv2d(in_ch, out_ch, kernel, stride, padding)
    if bn:
        return nn.Sequential(
            conv,
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        )
    return nn.Sequential(conv, nn.LeakyReLU(0.2, inplace=True))

# helper: deconv block
def deconv_block(in_ch, out_ch, kernel, stride, padding, bn=True):
    """Build a ConvTranspose2d -> [BatchNorm2d] -> ReLU stack.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel: Kernel size forwarded to ``nn.ConvTranspose2d``.
        stride: Stride forwarded to ``nn.ConvTranspose2d``.
        padding: Padding forwarded to ``nn.ConvTranspose2d``.
        bn: When true, insert batch normalisation after the transposed conv.

    Returns:
        An ``nn.Sequential`` holding the layers in the order listed above.
    """
    upsample = nn.ConvTranspose2d(in_ch, out_ch, kernel, stride, padding)
    maybe_norm = (nn.BatchNorm2d(out_ch),) if bn else ()
    return nn.Sequential(upsample, *maybe_norm, nn.ReLU(inplace=True))

class ConditionerCNN(nn.Module):
    """
    Encodes a 2D piano-roll (H x W) into a pyramid of feature maps.

    The first convolution spans the whole pitch axis, collapsing it to height
    1; each of the following three convolutions halves the time axis. All four
    intermediate maps are returned so a generator can concatenate them at the
    layers with matching width.

    Args:
        h: Height (number of pitches) of the input roll. Used as the height of
            the first kernel, so the input must have exactly this height for
            the output height to be 1.
        w: Width (time steps) of the input roll. Not used to build any layer;
            the convolutions adapt to whatever width is fed in.
        base_filters: Number of channels in every returned feature map.
    """
    def __init__(self, h=128, w=16, base_filters=64):
        super().__init__()
        # Input shape: (batch, 1, H, W)
        # The kernel height follows `h` (previously hard-coded to 128, which
        # made every non-default `h` fail in forward()).
        self.conv1 = nn.Conv2d(1, base_filters, kernel_size=(h, 1), stride=(1, 1))  # -> (B, base, 1, W)
        # Each (1,2)/stride-(1,2) conv halves the width and keeps height 1.
        self.conv2 = nn.Conv2d(base_filters, base_filters, kernel_size=(1, 2), stride=(1, 2), padding=0)
        self.conv3 = nn.Conv2d(base_filters, base_filters, kernel_size=(1, 2), stride=(1, 2))
        self.conv4 = nn.Conv2d(base_filters, base_filters, kernel_size=(1, 2), stride=(1, 2))
        self.act = nn.ReLU()

    def forward(self, x):
        """Return ``[f1, f2, f3, f4]`` for an input of shape ``(B, 1, H, W)``.

        The maps have shapes ``(B, base, 1, W)``, ``(B, base, 1, W/2)``,
        ``(B, base, 1, W/4)`` and ``(B, base, 1, W/8)`` respectively (widths
        use floor division when W is not a multiple of 8).
        """
        f1 = self.act(self.conv1(x))   # (B,base,1,W)
        f2 = self.act(self.conv2(f1))  # (B,base,1,W/2)
        f3 = self.act(self.conv3(f2))  # (B,base,1,W/4)
        f4 = self.act(self.conv4(f3))  # (B,base,1,W/8)
        return [f1, f2, f3, f4]

class Generator(nn.Module):
    """
    Maps a noise vector (plus optional conditions) to a piano-roll in [0, 1].

    Three fully connected layers produce a ``(B, C, 1, 2)`` seed, three
    transposed convolutions double the width (2 -> 4 -> 8 -> 16), and a final
    transposed convolution with a ``(h, 1)`` kernel expands the pitch axis.
    After each of the first three transposed convolutions, a 2-D conditioning
    feature map of matching width is concatenated on the channel axis.

    Args:
        z_dim: Size of the noise vector.
        h: Height (number of pitches) of the generated roll.
        w: Stored on the instance only; the layer stack always produces a
            width of 16 regardless of this value.
        base_filters: Channel count of the seed and of each transposed-conv
            output.
        cond1d_dim: Size of the optional 1-D condition concatenated to ``z``
            (``None`` / 0 disables it).
        cond2d_channels: Channel count of each 2-D conditioning feature map.
            ``None`` keeps the original layout, i.e. ``base_filters``; pass
            the conditioner's own filter count when it differs.
    """
    def __init__(self, z_dim=100, h=128, w=16, base_filters=128, cond1d_dim=None,
                 cond2d_channels=None):
        super().__init__()
        self.z_dim = z_dim
        self.h = h
        self.w = w
        self.cond2d_channels = base_filters if cond2d_channels is None else cond2d_channels

        # fc layers
        self.fc1 = nn.Linear(z_dim + (cond1d_dim or 0), 1024)
        self.fc2 = nn.Linear(1024, 512)
        # reshape to (B, C, 1, 2); C = base_filters
        self.initial_C = base_filters
        self.fc3 = nn.Linear(512, self.initial_C * 1 * 2)

        # Channels seen by every layer that follows a concatenation.
        merged = base_filters + self.cond2d_channels

        # Transposed conv stack. Kernel/stride/padding are passed positionally:
        # the helper's parameter is called `kernel`, so the former
        # `kernel_size=` keyword raised TypeError on construction.
        self.deconv1 = deconv_block(self.initial_C, base_filters, (1, 2), (1, 2), 0)
        self.deconv2 = deconv_block(merged, base_filters, (1, 2), (1, 2), 0)
        self.deconv3 = deconv_block(merged, base_filters, (1, 2), (1, 2), 0)
        # last layer expands height 1 -> h with a (h,1) kernel, stride 1
        self.deconv4 = nn.ConvTranspose2d(merged, 1, kernel_size=(h, 1), stride=(1, 1))

        # Name kept for compatibility; the activation is a sigmoid so the
        # output lies in [0,1] for a binary piano roll.
        self.tanh = nn.Sigmoid()

    def _with_condition(self, x, cond2d_features, index):
        """Concatenate the conditioning map ``cond2d_features[index]`` to ``x``.

        When no features are supplied, zeros of the expected channel count are
        used instead so the following layer still receives the number of
        input channels it was built for.
        """
        if cond2d_features is None:
            cond = x.new_zeros(x.size(0), self.cond2d_channels, x.size(2), x.size(3))
        else:
            cond = cond2d_features[index]
        return torch.cat([x, cond], dim=1)

    def forward(self, z, cond2d_features=None, cond1d=None):
        """Generate a batch of piano rolls.

        Args:
            z: Noise of shape ``(B, z_dim)``.
            cond2d_features: Optional sequence of feature maps ordered from
                widest to narrowest, where index 0 has width 16, index 1
                width 8 and index 2 width 4, each with ``cond2d_channels``
                channels and height 1. Any further entries are ignored.
            cond1d: Optional ``(B, cond1d_dim)`` tensor appended to ``z``.

        Returns:
            Tensor of shape ``(B, h, 16)`` with values in [0, 1].
        """
        if cond1d is not None:
            z = torch.cat([z, cond1d], dim=1)
        x = F.relu(self.fc1(z))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = x.view(x.size(0), self.initial_C, 1, 2)  # (B, C, 1, 2)

        # Each upsampling step is paired with the conditioning map of the
        # SAME width; the earlier indices (3, 2, 1) were one level too coarse
        # and made torch.cat fail on mismatched widths.
        x = self._with_condition(self.deconv1(x), cond2d_features, 2)  # width 4
        x = self._with_condition(self.deconv2(x), cond2d_features, 1)  # width 8
        x = self._with_condition(self.deconv3(x), cond2d_features, 0)  # width 16

        # final projection to H x 16
        x = self.deconv4(x)  # -> (B,1,H,16)
        x = self.tanh(x)     # [0,1] activations
        # optionally force monophony per time-step by argmax along pitch dim in postprocessing
        return x.squeeze(1)  # (B, H, 16)

class Discriminator(nn.Module):
    """
    Scores a piano-roll (optionally with a chord condition) as real or fake.

    Two convolutions reduce the roll to a feature vector, which is optionally
    concatenated with a chord vector and passed through two linear layers and
    a sigmoid.

    Args:
        h: Height (number of pitches) of the input roll; used as the height of
            the first kernel so the pitch axis collapses to 1.
        w: Width (time steps) of the input roll; determines the size of the
            flattened conv output and therefore of the first linear layer.
        base_filters: Accepted for signature symmetry with the other models;
            not used by any layer here.
        chord_dim: Size of the optional chord vector concatenated before the
            first linear layer (``None`` / 0 disables it).

    Raises:
        ValueError: If ``w`` is too small for the two convolutions to produce
            at least one output column (``w`` must be >= 8).
    """
    def __init__(self, h=128, w=16, base_filters=64, chord_dim=None):
        super().__init__()
        # Input: (B,1,H,W).
        self.conv1 = nn.Conv2d(1, 14, kernel_size=(h, 2), stride=(1, 2))
        self.conv2 = nn.Conv2d(14, 77, kernel_size=(1, 4), stride=(1, 2))

        # Width after each unpadded conv: floor((in - kernel) / stride) + 1.
        # The first linear layer used to assume a single remaining column,
        # which is only true for w in 8..11; for the default w=16 there are
        # three, so forward() failed on a shape mismatch.
        width_after_conv1 = (w - 2) // 2 + 1
        width_after_conv2 = (width_after_conv1 - 4) // 2 + 1
        if width_after_conv2 < 1:
            raise ValueError(f"w={w} is too small for the conv stack; need w >= 8")

        self.fc = nn.Linear(77 * 1 * width_after_conv2 + (chord_dim or 0), 1024)
        self.out = nn.Linear(1024, 1)
        self.act = nn.LeakyReLU(0.2)
        self.sig = nn.Sigmoid()

    def forward(self, x, chord=None):
        """Return a ``(B,)`` tensor of probabilities in [0, 1].

        Args:
            x: Piano rolls of shape ``(B, H, W)``; a channel axis is added here.
            chord: Optional ``(B, chord_dim)`` condition, concatenated to the
                flattened conv features.
        """
        x = x.unsqueeze(1)
        f = self.act(self.conv1(x))
        f = self.act(self.conv2(f))
        f = f.view(f.size(0), -1)
        if chord is not None:
            f = torch.cat([f, chord], dim=1)
        f = self.act(self.fc(f))
        return self.sig(self.out(f)).view(-1)
