# Assignment 3: MuseGAN

In [1]:
import random

import musegan
import numpy as np
import torch as t
from torch.utils.data import DataLoader
from torchinfo import summary

# Tensor dimensions and batch geometry shared by the model and dataset cells.
N_BARS = 4
N_TRACKS = 5
N_STEPS_PER_BAR = 48
BATCH_SIZE = 16
Z_DIM = 32  # passed as `z_dimension` to every network below

# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines without CUDA instead of failing on a hard-coded device string.
DEVICE = "cuda:0" if t.cuda.is_available() else "cpu"

# Seed torch, `random` and NumPy with the same value so that a fresh
# Restart & Run All produces repeatable results.
SEED = 0x0D000721
t.random.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

## 3.1. Model structure

### 3.1.1. Temporal network

In [2]:
# Build a stand-alone temporal network and show its layer-by-layer summary
# for one batch of latent vectors of size Z_DIM.
temporal_kwargs = dict(z_dimension=Z_DIM, hid_channels=1024, n_bars=N_BARS)
temp_net = musegan.temporal.TemporalNetwork(**temporal_kwargs)

temporal_input_shape = (BATCH_SIZE, Z_DIM)
summary(temp_net, input_size=temporal_input_shape)

Layer (type:depth-idx)                   Output Shape              Param #
TemporalNetwork                          [16, 32, 4]               --
├─Sequential: 1-1                        [16, 32, 4]               --
│    └─Reshape: 2-1                      [16, 32, 1, 1]            --
│    └─ConvTranspose2d: 2-2              [16, 1024, 2, 1]          66,560
│    └─BatchNorm2d: 2-3                  [16, 1024, 2, 1]          2,048
│    └─ReLU: 2-4                         [16, 1024, 2, 1]          --
│    └─ConvTranspose2d: 2-5              [16, 32, 4, 1]            98,336
│    └─BatchNorm2d: 2-6                  [16, 32, 4, 1]            64
│    └─ReLU: 2-7                         [16, 32, 4, 1]            --
│    └─Reshape: 2-8                      [16, 32, 4]               --
Total params: 167,008
Trainable params: 167,008
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 8.46
Input size (MB): 0.00
Forward/backward pass size (MB): 0.56
Params size (MB): 0.67
Estimated Total Siz

### 3.1.2. Bar generator

In [3]:
# Layer sizes for the bar generator; later cells reuse these constants.
HID_FEATURES = 1152
HID_CHANNELS = 192
N_PITCHES = 84

bar_gen_kwargs = {
    "z_dimension": Z_DIM,
    "hid_features": HID_FEATURES,
    "hid_channels": HID_CHANNELS,
    "n_steps_per_bar": N_STEPS_PER_BAR,
    "n_pitches": N_PITCHES,
}
bar_gen = musegan.bar_generator.BarGenerator(**bar_gen_kwargs)

# NOTE(review): the input width is 4 * Z_DIM, presumably four latent vectors
# concatenated per bar; confirm against BarGenerator.
bar_input_shape = (BATCH_SIZE, 4 * Z_DIM)
summary(bar_gen, input_size=bar_input_shape)

Layer (type:depth-idx)                   Output Shape              Param #
BarGenerator                             [16, 1, 1, 48, 84]        --
├─Sequential: 1-1                        [16, 1, 1, 48, 84]        --
│    └─Linear: 2-1                       [16, 1152]                148,608
│    └─BatchNorm1d: 2-2                  [16, 1152]                2,304
│    └─ReLU: 2-3                         [16, 1152]                --
│    └─Reshape: 2-4                      [16, 192, 6, 1]           --
│    └─ConvTranspose2d: 2-5              [16, 192, 12, 1]          73,920
│    └─BatchNorm2d: 2-6                  [16, 192, 12, 1]          384
│    └─ReLU: 2-7                         [16, 192, 12, 1]          --
│    └─ConvTranspose2d: 2-8              [16, 96, 24, 1]           36,960
│    └─BatchNorm2d: 2-9                  [16, 96, 24, 1]           192
│    └─ReLU: 2-10                        [16, 96, 24, 1]           --
│    └─ConvTranspose2d: 2-11             [16, 96, 48, 1]           

### 3.1.3. Generator

In [4]:
# Full generator; note that it is built with twice the bar generator's
# hidden channel count.
generator_kwargs = dict(
    z_dimension=Z_DIM,
    hid_channels=HID_CHANNELS * 2,
    hid_features=HID_FEATURES,
    n_tracks=N_TRACKS,
    n_bars=N_BARS,
    n_steps_per_bar=N_STEPS_PER_BAR,
    n_pitches=N_PITCHES,
)
muse_gen = musegan.generator.MuseGenerator(**generator_kwargs)

# The generator takes four latent inputs: two of shape (batch, Z_DIM)
# followed by two with an extra N_TRACKS axis.
vector_latent_shape = (BATCH_SIZE, Z_DIM)
track_latent_shape = (BATCH_SIZE, N_TRACKS, Z_DIM)
summary(
    muse_gen,
    input_size=(
        vector_latent_shape,
        vector_latent_shape,
        track_latent_shape,
        track_latent_shape,
    ),
)

Layer (type:depth-idx)                        Output Shape              Param #
MuseGenerator                                 [16, 5, 4, 48, 84]        --
├─TemporalNetwork: 1-1                        [16, 32, 4]               --
│    └─Sequential: 2-1                        [16, 32, 4]               --
│    │    └─Reshape: 3-1                      [16, 32, 1, 1]            --
│    │    └─ConvTranspose2d: 3-2              [16, 384, 2, 1]           24,960
│    │    └─BatchNorm2d: 3-3                  [16, 384, 2, 1]           768
│    │    └─ReLU: 3-4                         [16, 384, 2, 1]           --
│    │    └─ConvTranspose2d: 3-5              [16, 32, 4, 1]            36,896
│    │    └─BatchNorm2d: 3-6                  [16, 32, 4, 1]            64
│    │    └─ReLU: 3-7                         [16, 32, 4, 1]            --
│    │    └─Reshape: 3-8                      [16, 32, 4]               --
├─ModuleDict: 1-40                            --                        (recursive)
│ 

### 3.1.4. Discriminator (critic)

In [5]:
# Critic network, summarised on an input with the same per-sample layout as
# the generator output: (tracks, bars, steps per bar, pitches).
critic_kwargs = dict(
    hid_channels=128,
    n_tracks=N_TRACKS,
    n_bars=N_BARS,
    n_steps_per_bar=N_STEPS_PER_BAR,
    n_pitches=N_PITCHES,
)
critic = musegan.critic.MuseCritic(**critic_kwargs)

critic_input_shape = (BATCH_SIZE, N_TRACKS, N_BARS, N_STEPS_PER_BAR, N_PITCHES)
summary(critic, input_size=critic_input_shape)

Layer (type:depth-idx)                   Output Shape              Param #
MuseCritic                               [16, 1]                   --
├─Sequential: 1-1                        [16, 1]                   --
│    └─Conv3d: 2-1                       [16, 128, 3, 48, 84]      1,408
│    └─LeakyReLU: 2-2                    [16, 128, 3, 48, 84]      --
│    └─Conv3d: 2-3                       [16, 128, 1, 48, 84]      49,280
│    └─LeakyReLU: 2-4                    [16, 128, 1, 48, 84]      --
│    └─Conv3d: 2-5                       [16, 128, 1, 48, 7]       196,736
│    └─LeakyReLU: 2-6                    [16, 128, 1, 48, 7]       --
│    └─Conv3d: 2-7                       [16, 128, 1, 48, 1]       114,816
│    └─LeakyReLU: 2-8                    [16, 128, 1, 48, 1]       --
│    └─Conv3d: 2-9                       [16, 128, 1, 24, 1]       32,896
│    └─LeakyReLU: 2-10                   [16, 128, 1, 24, 1]       --
│    └─Conv3d: 2-11                      [16, 128, 1, 12, 1]    

## 3.2. Dataset

In [6]:
# Relative path of the prepared .npz training archive; it is handed to
# musegan.dataset.LPDDataset in the next cell.
DATASET_PATH = "prepared/train_x_lpd_5_phr.npz"

In [7]:
def seed_worker(_):
    """Seed NumPy and `random` from torch's initial seed.

    Intended as a DataLoader `worker_init_fn`; the positional argument it
    receives is ignored. The torch seed is reduced modulo 2**32 before it is
    applied to both libraries.
    """
    derived_seed = t.initial_seed() % (1 << 32)
    # NumPy first, then the stdlib RNG.
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)


# Dedicated, explicitly seeded generator handed to the DataLoader.
g = t.Generator().manual_seed(0x0D000721)

dataset = musegan.dataset.LPDDataset(DATASET_PATH)

loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,  # discard the final incomplete batch
    worker_init_fn=seed_worker,
    generator=g,
)
loader = DataLoader(dataset, **loader_kwargs)

# Sanity check: shape of one sample, then the number of samples.
first_sample = dataset[0]
print(first_sample.shape)
print(len(dataset))

torch.Size([5, 4, 48, 84])
102378
