# Setup


In [28]:
import os
import torch
# Project root; the empty string makes every derived path relative to the current working directory.
ROOT_DIR=""
DATA_DIR=os.path.join(ROOT_DIR,"data") ##Directory of dataset
# Training archive handed to the dataset class by the runner cell.
DATASET_FILE=os.path.join(DATA_DIR,"train_x_lpd_5_phr.npz")
# DATASET_FILE=os.path.join(DATA_DIR,"Jsb16thSeparated.npz")
# Parent directory of the per-experiment folders (config.json, weights, TensorBoard events).
EXPERIMENTS_DIR=os.path.join(ROOT_DIR, "logs/experiments")
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch .cuda.is_available()
DEVICE = torch.device("cuda" if use_cuda else "cpu")


False

## Utils

In [15]:
import logging
import os
import sys
from time import strftime
def setup_logger(args):
    """Configure the root logger to write to stdout and to a timestamped file.

    Parameters
    ----------
    args:
        Any object exposing a ``log_level`` attribute accepted by
        ``logging.Logger.setLevel`` (for example ``"INFO"``).

    Notes
    -----
    The log file lives in ``<ROOT_DIR>/logs/output_logs`` and is named after
    the current day and minute. Calling the function again (e.g. re-running a
    notebook cell) does not attach a second copy of an equivalent handler, so
    messages are not duplicated.
    """
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    a_logger = logging.getLogger()
    a_logger.setLevel(args.log_level)
    log_dir=os.path.join(ROOT_DIR,"logs","output_logs")
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.abspath(os.path.join(log_dir,strftime("log_%d_%m_%Y_%H_%M.log")))
    a_logger.propagate=False

    # Only attach a file handler if this exact file is not already being written to.
    has_file_handler = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
        for handler in a_logger.handlers
    )
    if not has_file_handler:
        output_file_handler = logging.FileHandler(log_path)
        # The file handler used to be left without a formatter, so the log
        # file contained bare messages without timestamp or level.
        output_file_handler.setFormatter(formatter)
        a_logger.addHandler(output_file_handler)

    # FileHandler subclasses StreamHandler, hence the explicit exclusion.
    has_stdout_handler = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        and getattr(handler, "stream", None) is sys.stdout
        for handler in a_logger.handlers
    )
    if not has_stdout_handler:
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setFormatter(formatter)
        a_logger.addHandler(stdout_handler)

import json
import os
from enum import Enum
from itertools import islice
import numpy as np

def read_json(path_json):
    """Parse the UTF-8 encoded JSON file at ``path_json`` and return its content."""
    with open(path_json, encoding='utf8') as handle:
        parsed = json.load(handle)
    return parsed
def softmax(x):
    """Return the softmax of ``x``, normalised over all of its elements.

    The maximum is subtracted before exponentiating. This leaves the result
    mathematically unchanged but prevents ``exp`` from overflowing to ``inf``
    (and the quotient from becoming ``nan``) for large inputs.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to shift by; keep the historical "empty in, empty out" result.
        return np.exp(x) / np.sum(np.exp(x))
    exponentials = np.exp(x - np.max(x))
    return exponentials / np.sum(exponentials)
def chunks(data, SIZE):
    """Lazily split a dictionary into consecutive parts of at most ``SIZE`` items."""
    key_stream = iter(data)
    for _offset in range(0, len(data), SIZE):
        part = {}
        # islice consumes the shared iterator, so each part continues where the last one stopped.
        for key in islice(key_stream, SIZE):
            part[key] = data[key]
        yield part

def sorted_dict(x, ascending=True):
    """
    Return a new dict with the items of ``x`` ordered by value.

    Values must be mutually comparable (int, float, str...). Items with equal
    values keep their original relative order in both directions.

    @param x: dictionary to sort
    @param ascending: sort from smallest to largest value when True, largest
        to smallest otherwise
    @return: a new dict whose insertion order follows the sorted values
    """
    # ``reverse=`` is used instead of negating the value: multiplying a string
    # by -1 yields "" for every item, so descending order silently did nothing
    # for str values.
    return dict(sorted(x.items(), key=lambda item: item[1], reverse=not ascending))
def reverse_dict(input_dict):
    """
    Invert a dictionary, grouping keys that share the same value.

    Args:
        input_dict: mapping whose values are hashable.

    Returns:
        A dict mapping every value of ``input_dict`` to the list of keys that
        pointed to it, in their original iteration order.

    """
    inverted = {}
    for key, value in input_dict.items():
        inverted.setdefault(value, []).append(key)
    return inverted

def save_matrix(matrix,filename):
    """Write ``matrix`` to ``filename`` with ``np.save`` (file opened in binary mode)."""
    with open(filename, mode='wb') as sink:
        np.save(sink, matrix)
def load_matrix(filename,auto_delete=False):
    """Read back an array stored with ``np.save``.

    When ``auto_delete`` is true the file is removed once its content has
    been loaded.
    """
    with open(filename, 'rb') as source:
        loaded = np.load(source)
    if not auto_delete:
        return loaded
    os.remove(filename)
    return loaded



class Averager:
    """Running mean of the values passed to :meth:`send`."""

    def __init__(self):
        self.current_total, self.iterations = 0.0, 0.0

    def send(self, value):
        """Add one observation to the running mean."""
        self.iterations += 1
        self.current_total += value

    @property
    def value(self):
        """Mean of the observations so far, or 0 when nothing was sent."""
        if not self.iterations:
            return 0
        return 1.0 * self.current_total / self.iterations

    def reset(self):
        """Forget every observation."""
        self.current_total, self.iterations = 0.0, 0.0

# Network

In [16]:
"""Utils."""
from typing import List
from torch import Tensor
from torch import nn
def initialize_weights(layer: nn.Module, mean: float = 0.0, std: float = 0.02):
    """Initialize a module's weights from a normal distribution.

    ``Conv3d`` and ``ConvTranspose2d`` get their weight re-drawn; ``Linear``
    and ``BatchNorm2d`` additionally get their bias set to zero. Every other
    module type is left untouched, which makes the function suitable for
    ``nn.Module.apply``.

    Parameters
    ----------
    layer: nn.Module
        Layer.
    mean: float, (default=0.0)
        Mean value.
    std: float, (default=0.02)
        Standard deviation value.

    """
    if isinstance(layer, (nn.Conv3d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(layer.weight, mean, std)
    elif isinstance(layer, (nn.Linear, nn.BatchNorm2d)):
        # ``weight`` is None for BatchNorm2d(affine=False) and ``bias`` is None
        # for Linear(bias=False); the init helpers cannot handle None.
        if layer.weight is not None:
            torch.nn.init.normal_(layer.weight, mean, std)
        if layer.bias is not None:
            torch.nn.init.constant_(layer.bias, 0)


class Reshape(nn.Module):
    """Layer that views its input with a fixed per-sample shape.

    Parameters
    ----------
    shape: List[int]
        Target dimensions following the batch dimension.

    """

    def __init__(self, shape: List[int]) -> None:
        """Store the per-sample target shape."""
        super().__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        """View ``x`` as ``(batch_size, *shape)``.

        Parameters
        ----------
        x: Tensor
            Input batch; the first dimension is kept as the batch size.

        Returns
        -------
        Tensor:
            A view of ``x`` with shape ``(x.size(0), *self.shape)``.

        """
        batch_size = x.size(0)
        target_shape = (batch_size, *self.shape)
        return x.view(target_shape)


In [17]:
"""Bar Generator."""

from torch import Tensor
from torch import nn
class BarGenerator(nn.Module):
    """Generate one bar of one track from a concatenated latent vector.

    Parameters
    ----------
    z_dimension: int, (default=32)
        Noise space dimension; the network input has ``4 * z_dimension`` features.
    hid_features: int, (default=1024)
        Number of hidden features.
    hid_channels: int, (default=512)
        Number of hidden channels.
    out_channels: int, (default=1)
        Number of output channels.
    n_steps_per_bar: int, (default=16)
        Number of time steps used by the final reshape.
    n_pitches: int, (default=84)
        Number of pitches used by the final reshape.

    """

    def __init__(
        self,
        z_dimension: int = 32,
        hid_features: int = 1024,
        hid_channels: int = 512,
        out_channels: int = 1,
        n_steps_per_bar = 16,
        n_pitches = 84,
    ) -> None:
        """Build the layer stack."""
        super().__init__()
        self.n_steps_per_bar = n_steps_per_bar
        self.n_pitches = n_pitches
        half_channels = hid_channels // 2

        # (batch_size, 4*z_dimension) -> (batch_size, hid_channels, hid_features//hid_channels, 1)
        layers = [
            nn.Linear(4 * z_dimension, hid_features),
            nn.BatchNorm1d(hid_features),
            nn.ReLU(inplace=True),
            Reshape(shape=[hid_channels, hid_features // hid_channels, 1]),
        ]

        # Each entry is (in_channels, out_channels, window); the window is used
        # both as kernel size and stride. The three (2, 1) stages double the
        # time axis, the (1, 7) stage expands the pitch axis from 1 to 7.
        upsampling_plan = (
            (hid_channels, hid_channels, (2, 1)),
            (hid_channels, half_channels, (2, 1)),
            (half_channels, half_channels, (2, 1)),
            (half_channels, half_channels, (1, 7)),
        )
        for in_ch, out_ch, window in upsampling_plan:
            layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=window, stride=window, padding=0)
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))

        # (batch_size, hid_channels//2, 8*hid_features//hid_channels, 7)
        #   -> (batch_size, out_channels, 8*hid_features//hid_channels, 84)
        layers.append(
            nn.ConvTranspose2d(
                half_channels,
                out_channels,
                kernel_size=(1, 12),
                stride=(1, 12),
                padding=0,
            )
        )
        # -> (batch_size, 1, 1, n_steps_per_bar, n_pitches)
        layers.append(Reshape(shape=[1, 1, self.n_steps_per_bar, self.n_pitches]))

        self.net = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Run the generator.

        Parameters
        ----------
        x: Tensor
            Latent batch of shape ``(batch_size, 4 * z_dimension)``.

        Returns
        -------
        Tensor:
            Generated bar of shape
            ``(batch_size, 1, 1, n_steps_per_bar, n_pitches)``.

        """
        return self.net(x)


In [18]:
"""Temporal Network."""

from torch import Tensor

import torch
from torch import nn
class TemporalNetwork(nn.Module):
    """Expand one latent vector into a sequence of per-bar latent vectors.

    Parameters
    ----------
    z_dimension: int, (default=32)
        Noise space dimension.
    hid_channels: int, (default=1024)
        Number of hidden channels.
    n_bars: int, (default=2)
        Number of bars, i.e. length of the produced sequence.

    """

    def __init__(
        self,
        z_dimension: int = 32,
        hid_channels: int = 1024,
        n_bars: int = 2,
    ) -> None:
        """Build the layer stack."""
        super().__init__()
        self.n_bars = n_bars
        unit_stride = (1, 1)

        # (batch_size, z_dimension) -> (batch_size, z_dimension, 1, 1)
        to_feature_map = Reshape(shape=[z_dimension, 1, 1])
        # -> (batch_size, hid_channels, 2, 1)
        expand = nn.ConvTranspose2d(
            z_dimension,
            hid_channels,
            kernel_size=(2, 1),
            stride=unit_stride,
            padding=0,
        )
        expand_norm = nn.BatchNorm2d(hid_channels)
        # -> (batch_size, z_dimension, n_bars, 1)
        project = nn.ConvTranspose2d(
            hid_channels,
            z_dimension,
            kernel_size=(self.n_bars - 1, 1),
            stride=unit_stride,
            padding=0,
        )
        project_norm = nn.BatchNorm2d(z_dimension)
        # -> (batch_size, z_dimension, n_bars)
        to_sequence = Reshape(shape=[z_dimension, self.n_bars])

        self.net = nn.Sequential(
            to_feature_map,
            expand,
            expand_norm,
            nn.ReLU(inplace=True),
            project,
            project_norm,
            nn.ReLU(inplace=True),
            to_sequence,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Run the temporal network.

        Parameters
        ----------
        x: Tensor
            Latent batch of shape ``(batch_size, z_dimension)``.

        Returns
        -------
        Tensor:
            Per-bar latents of shape ``(batch_size, z_dimension, n_bars)``.

        """
        return self.net(x)


In [19]:
"""Muse Generator."""

from torch import Tensor

import torch
from torch import nn
class MuseGenerator(nn.Module):
    """Multi-track, multi-bar generator assembled from temporal and bar networks.

    Parameters
    ----------
    z_dimension: int, (default=32)
        Noise space dimension.
    hid_channels: int, (default=1024)
        Number of hidden channels.
    hid_features: int, (default=1024)
        Number of hidden features.
    out_channels: int, (default=1)
        Number of output channels.
    n_tracks: int, (default=4)
        Number of tracks; one melody network and one bar generator per track.
    n_bars: int, (default=2)
        Number of bars.
    n_steps_per_bar: int, (default=16)
        Number of time steps per bar.
    n_pitches: int, (default=84)
        Number of pitches.

    """

    def __init__(
        self,
        z_dimension: int = 32,
        hid_channels: int = 1024,
        hid_features: int = 1024,
        out_channels: int = 1,
        n_tracks: int = 4,
        n_bars: int = 2,
        n_steps_per_bar: int = 16,
        n_pitches: int = 84,
    ) -> None:
        """Create the sub-networks."""
        super().__init__()
        self.n_tracks = n_tracks
        self.n_bars = n_bars
        self.n_steps_per_bar = n_steps_per_bar
        self.n_pitches = n_pitches

        # One temporal network shared by all tracks for the chords latent.
        self.chords_network = TemporalNetwork(z_dimension, hid_channels, n_bars=n_bars)

        # One temporal network per track for the melody latent.
        self.melody_networks = nn.ModuleDict({})
        for track in range(self.n_tracks):
            self.melody_networks[f"melodygen_{track}"] = TemporalNetwork(
                z_dimension, hid_channels, n_bars=n_bars
            )

        # One bar generator per track; it runs with half the hidden channels.
        self.bar_generators = nn.ModuleDict({})
        for track in range(self.n_tracks):
            self.bar_generators[f"bargen_{track}"] = BarGenerator(
                z_dimension,
                hid_features,
                hid_channels // 2,
                out_channels,
                n_steps_per_bar=n_steps_per_bar,
                n_pitches=n_pitches,
            )

    def forward(self, chords: Tensor, style: Tensor, melody: Tensor, groove: Tensor) -> Tensor:
        """Generate a batch of multi-track, multi-bar samples.

        Parameters
        ----------
        chords: Tensor
            Chords latent, shape ``(batch_size, z_dimension)``.
        style: Tensor
            Style latent, shape ``(batch_size, z_dimension)``.
        melody: Tensor
            Melody latents, shape ``(batch_size, n_tracks, z_dimension)``.
        groove: Tensor
            Groove latents, shape ``(batch_size, n_tracks, z_dimension)``.

        Returns
        -------
        Tensor:
            Generated samples of shape
            ``(batch_size, n_tracks, n_bars, n_steps_per_bar, n_pitches)``.

        """
        chord_sequence = self.chords_network(chords)
        bars = []
        for bar in range(self.n_bars):
            chord_latent = chord_sequence[:, :, bar]
            per_track = []
            for track in range(self.n_tracks):
                melody_net = self.melody_networks[f"melodygen_{track}"]
                melody_latent = melody_net(melody[:, track, :])[:, :, bar]
                groove_latent = groove[:, track, :]
                # Bar generator input: chords | style | melody | groove.
                latent = torch.cat([chord_latent, style, melody_latent, groove_latent], dim=1)
                per_track.append(self.bar_generators[f"bargen_{track}"](latent))
            # Tracks are stacked on dim 1, bars on dim 2.
            bars.append(torch.cat(per_track, dim=1))
        return torch.cat(bars, dim=2)


In [20]:
"""Muse critic."""

from torch import Tensor
from torch import nn
class MuseCritic(nn.Module):
    """Critic scoring a batch of multi-track, multi-bar samples.

    Parameters
    ----------
    hid_channels: int, (default=128)
        Number of hidden channels.
    hid_features: int, (default=1024)
        Number of hidden features.
    out_features: int, (default=1)
        Number of output features.
    n_tracks: int, (default=4)
        Number of tracks (input channels).
    n_bars: int, (default=2)
        Number of bars.
    n_steps_per_bar: int, (default=16)
        Number of time steps per bar.
    n_pitches: int, (default=84)
        Number of pitches.

    Raises
    ------
    ValueError
        If the input dimensions are too small for the convolution stack.

    """

    @staticmethod
    def _conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
        """Output length of one convolution axis (standard Conv3d formula)."""
        out = (size + 2 * padding - kernel) // stride + 1
        if out < 1:
            raise ValueError(
                f"Dimension of size {size} is too small for kernel {kernel}, "
                f"stride {stride}, padding {padding}"
            )
        return out

    def __init__(
        self,
        hid_channels: int = 128,
        hid_features: int = 1024,
        out_features: int = 1,
        n_tracks: int = 4,
        n_bars: int = 2,
        n_steps_per_bar: int = 16,
        n_pitches: int = 84,
    ) -> None:
        """Build the layer stack."""
        super().__init__()
        self.n_tracks = n_tracks
        self.n_bars = n_bars
        self.n_steps_per_bar = n_steps_per_bar
        self.n_pitches = n_pitches

        # Size of the flattened feature map, derived from the convolutions
        # below instead of the former hard-coded "4*C if n_bars == 2 else 12*C",
        # which only matched a couple of input shapes.
        bars = self._conv_out(self.n_bars, 2, 1)
        bars = self._conv_out(bars, self.n_bars - 1, 1)
        pitches = self._conv_out(self.n_pitches, 12, 12)
        pitches = self._conv_out(pitches, 7, 7)
        steps = self._conv_out(self.n_steps_per_bar, 2, 2)
        steps = self._conv_out(steps, 2, 2)
        steps = self._conv_out(steps, 4, 2, 1)
        steps = self._conv_out(steps, 3, 2, 1)
        in_features = 4 * hid_channels * bars * steps * pitches

        self.net = nn.Sequential(
            # input shape: (batch_size, n_tracks, n_bars, n_steps_per_bar, n_pitches)
            nn.Conv3d(self.n_tracks, hid_channels, (2, 1, 1), (1, 1, 1), padding=0),
            nn.LeakyReLU(0.3, inplace=True),
            # output shape: (batch_size, hid_channels, n_bars-1, n_steps_per_bar, n_pitches)
            nn.Conv3d(hid_channels, hid_channels, (self.n_bars - 1, 1, 1), (1, 1, 1), padding=0),
            nn.LeakyReLU(0.3, inplace=True),
            # output shape: (batch_size, hid_channels, 1, n_steps_per_bar, n_pitches)
            nn.Conv3d(hid_channels, hid_channels, (1, 1, 12), (1, 1, 12), padding=0),
            nn.LeakyReLU(0.3, inplace=True),
            # output shape: (batch_size, hid_channels, 1, n_steps_per_bar, n_pitches//12)
            nn.Conv3d(hid_channels, hid_channels, (1, 1, 7), (1, 1, 7), padding=0),
            nn.LeakyReLU(0.3, inplace=True),
            # output shape: (batch_size, hid_channels, 1, n_steps_per_bar, n_pitches//84)
            nn.Conv3d(hid_channels, hid_channels, (1, 2, 1), (1, 2, 1), padding=0),
            nn.LeakyReLU(0.3, inplace=True),
            # output shape: (batch_size, hid_channels, 1, n_steps_per_bar//2, n_pitches//84)
            nn.Conv3d(hid_channels, hid_channels, (1, 2, 1), (1, 2, 1), padding=0),
            nn.LeakyReLU(0.3, inplace=True),
            # output shape: (batch_size, hid_channels, 1, n_steps_per_bar//4, n_pitches//84)
            nn.Conv3d(hid_channels, 2 * hid_channels, (1, 4, 1), (1, 2, 1), padding=(0, 1, 0)),
            nn.LeakyReLU(0.3, inplace=True),
            # output shape: (batch_size, 2*hid_channels, 1, n_steps_per_bar//8, n_pitches//84)
            nn.Conv3d(2 * hid_channels, 4 * hid_channels, (1, 3, 1), (1, 2, 1), padding=(0, 1, 0)),
            nn.LeakyReLU(0.3, inplace=True),
            # output shape: (batch_size, 4*hid_channels, 1, steps, pitches) with the
            # ``steps`` / ``pitches`` values computed above
            nn.Flatten(),
            nn.Linear(in_features, hid_features),
            nn.LeakyReLU(0.3, inplace=True),
            # output shape: (batch_size, hid_features)
            nn.Linear(hid_features, out_features),
            # output shape: (batch_size, out_features)
        )

    def forward(self, x: Tensor) -> Tensor:
        """Score a batch of samples.

        Parameters
        ----------
        x: Tensor
            Batch of shape
            ``(batch_size, n_tracks, n_bars, n_steps_per_bar, n_pitches)``.

        Returns
        -------
        Tensor:
            Critic scores of shape ``(batch_size, out_features)``.

        """
        fx = self.net(x)
        return fx


In [21]:
import logging
import os
import torch
import torchvision.models
from torch import nn
# Device on which MuseGan.setup_network places the generator; computed the same way as DEVICE.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
class MuseGan(nn.Module):
    """Generator/critic pair of one experiment, with checkpoint handling.

    Parameters
    ----------
    experiment_dir: str
        Directory holding ``config.json`` and the saved weights.
    reset: bool
        When True, previously saved weights are not loaded.
    load_best: bool
        Stored on the instance. NOTE(review): the constructor still calls
        ``load_state()`` with its default ``best=False``, so this flag does not
        influence which checkpoint is loaded; confirm the intended behaviour
        before wiring it through.
    """

    def __init__(self, experiment_dir="base_muse_gan",
                 reset=False, load_best=True):
        super(MuseGan, self).__init__()
        self.experiment_dir = experiment_dir
        self.model_name = os.path.basename(self.experiment_dir)
        self.reset = reset
        self.load_best = load_best
        self.setup_dirs()
        self.setup_network()

        if not reset: self.load_state()

    ##1. Defining network architecture
    def setup_network(self):
        """
        Build the generator and the critic from ``<experiment_dir>/config.json``.

        The config must provide ``z_dimension``, ``g_channels``, ``g_features``,
        ``c_channels`` and ``c_features``.
        @return:
        """

        #1: load model config
        config_file=os.path.join(self.experiment_dir,"config.json")
        assert os.path.exists(config_file),f"No config.json found in {self.experiment_dir}"
        model_config=read_json(config_file)
        self.z_dimension =model_config["z_dimension"]
        self.g_channels =model_config["g_channels"]
        self.g_features =model_config["g_features"]
        self.c_channels =model_config["c_channels"]
        self.c_features = model_config["c_features"]

        self.generator = MuseGenerator(
            z_dimension=self.z_dimension,
            hid_channels=self.g_channels,
            hid_features=self.g_features,
            out_channels=1,
        ).to(device)
        # Initialise the configured generator in place. The previous code
        # replaced it with a default-constructed MuseGenerator(), silently
        # discarding every value read from config.json.
        self.generator.apply(initialize_weights)

        self.critic=MuseCritic(
            hid_channels=self.c_channels,
            hid_features=self.c_features,
            out_features=1,

        )
        self.critic.apply(initialize_weights)

    ##2. Model Saving/Loading
    def load_state(self, best=False):
        """
        Load saved weights if a checkpoint exists.

        :param best: prefer the "best" checkpoint; falls back to the regular
            checkpoint when no best checkpoint has been saved.
        :return:
        """
        if best and os.path.exists(self.save_best_file):
            # Load (and report) the best checkpoint itself, not the regular one.
            logging.info(f"Loading best model state : {self.save_best_file}")
            self.load_state_dict(torch.load(self.save_best_file, map_location=DEVICE))
            return

        if os.path.exists(self.save_file):
            logging.info(f"Loading model state : {self.save_file}")
            self.load_state_dict(torch.load(self.save_file, map_location=DEVICE))

    def save_state(self, best=False):
        """Save the weights; with ``best=True`` also refresh the best checkpoint."""
        if best:
            logging.info("Saving best model")
            torch.save(self.state_dict(), self.save_best_file)
        torch.save(self.state_dict(), self.save_file)

    ##3. Setupping directories for weights /logs ... etc
    def setup_dirs(self):
        """
        Compute the checkpoint paths and make sure the experiment directory exists.
        @return:
        """
        self.save_file = os.path.join(self.experiment_dir, f"{self.model_name}.pt")
        self.save_best_file = os.path.join(self.experiment_dir, f"{self.model_name}_best.pt")
        os.makedirs(self.experiment_dir, exist_ok=True)


# Dataset

In [22]:
"""Midi dataset."""
from typing import Tuple
from torch import Tensor
import torch
from torch import nn
from torch.utils.data import Dataset
import numpy as np
from music21 import midi
from music21 import converter
from music21 import note, stream, duration, tempo
class LPDDataset(Dataset):
    """Dataset serving the array stored under ``arr_0`` of an ``.npz`` archive.

    Parameters
    ----------
    path: str
        Path to dataset.
    """

    def __init__(
        self,
        path: str,
    ) -> None:
        """Load the whole array into memory."""
        # NOTE: allow_pickle=True lets a crafted archive execute code while
        # loading; only open archives from a trusted source.
        # The context manager closes the underlying zip file, which the bare
        # np.load() call used to leave open for the lifetime of the process.
        with np.load(path, allow_pickle=True, encoding="bytes") as dataset:
            self.data_binary = dataset["arr_0"]

    def __len__(self) -> int:
        """Return the number of samples in dataset."""
        return len(self.data_binary)

    def __getitem__(self, index: int) -> Tensor:
        """Return one sample from dataset.

        Parameters
        ----------
        index: int
            Index of sample.

        Returns
        -------
        Tensor:
            Sample converted to a float tensor.

        """
        return torch.from_numpy(self.data_binary[index]).float()

class MidiDataset(Dataset):
    """Dataset turning integer pitch sequences into +1/-1 piano-roll tensors.

    Parameters
    ----------
    path: str
        Path to dataset.
    split: str, optional (default="train")
        Key of the array to read from the ``.npz`` archive.
    n_bars: int, optional (default=2)
        Number of bars.
    n_steps_per_bar: int, optional (default=16)
        Number of steps per bar.

    """

    def __init__(
        self,
        path: str,
        # split: str = "train",
        split: str = "train",
        n_bars: int = 2,
        n_steps_per_bar: int = 16,
    ) -> None:
        """Load the requested split and preprocess it."""
        self.n_bars = n_bars
        self.n_steps_per_bar = n_steps_per_bar
        # NOTE: allow_pickle=True lets a crafted archive execute code while
        # loading; only open archives from a trusted source.
        # The context manager closes the archive once the split is in memory.
        with np.load(path, allow_pickle=True, encoding="bytes") as archive:
            dataset = archive[split]
        self.data_binary, self.data_ints, self.data = self.__preprocess__(dataset)

    def __len__(self) -> int:
        """Return the number of samples in dataset."""
        return len(self.data_binary)

    def __getitem__(self, index: int) -> Tensor:
        """Return one sample from dataset.

        Parameters
        ----------
        index: int
            Index of sample.

        Returns
        -------
        Tensor:
            Sample of shape ``(n_tracks, n_bars, n_steps_per_bar, 84)``.

        """
        return torch.from_numpy(self.data_binary[index]).float()

    def __preprocess__(self, data: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Preprocess data.

        Parameters
        ----------
        data: np.ndarray
            Sequence of songs; each song is indexed as ``(time step, track)``
            and may contain NaN where no note is played.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray, np.ndarray]:
            Data binary (+1/-1 one-hot pitches), data ints (pitch indices) and
            the unmodified input data.

        Raises
        ------
        ValueError
            If no song is long enough to provide a full window.

        """
        window = self.n_bars * self.n_steps_per_bar
        data_ints = []
        for x in data:
            # Skip leading groups of 4 rows while they contain NaN. The loop
            # always ends: an empty slice past the end contains no NaN.
            skip_rows = 0
            while np.any(np.isnan(x[skip_rows: skip_rows + 4])):
                skip_rows += 4
            # The second condition drops songs whose remaining rows no longer
            # fill a complete window; they used to yield a shorter slice and a
            # ragged array that broke the reshape below.
            if window < x.shape[0] and skip_rows + window <= x.shape[0]:
                data_ints.append(x[skip_rows: window + skip_rows, :])
        if not data_ints:
            raise ValueError(
                f"No song provides {window} usable time steps; nothing to train on"
            )
        data_ints = np.array(data_ints)
        self.n_songs = data_ints.shape[0]
        self.n_tracks = data_ints.shape[2]
        data_ints = data_ints.reshape([self.n_songs, self.n_bars, self.n_steps_per_bar, self.n_tracks])
        # Pitches run from 0 to 83; remaining NaN entries get the extra class 84.
        max_note = 83
        mask = np.isnan(data_ints)
        data_ints[mask] = max_note + 1
        max_note = max_note + 1
        data_ints = data_ints.astype(int)
        num_classes = max_note + 1
        # One-hot encode, switch the "off" value from 0 to -1, then drop the
        # extra class so that its positions are -1 for every pitch.
        data_binary = np.eye(num_classes)[data_ints]
        data_binary[data_binary == 0] = -1
        data_binary = np.delete(data_binary, max_note, -1)
        # (songs, bars, steps, tracks, pitches) -> (songs, tracks, bars, steps, pitches)
        data_binary = data_binary.transpose([0, 3, 1, 2, 4])
        return data_binary, data_ints, data


def binarise_output(output: np.ndarray) -> np.ndarray:
    """Pick the strongest entry along the last axis.

    Parameters
    ----------
    output: np.ndarray
        Output array.

    Returns
    -------
    np.ndarray:
        Index of the maximum along the last axis; the last axis is removed.

    """
    return np.argmax(output, axis=-1)


def postprocess(
    output: np.ndarray,
    n_tracks: int = 4,
    n_bars: int = 2,
    n_steps_per_bar: int = 16,
) -> stream.Score:
    """Convert generated output into a music21 score with one part per track.

    Parameters
    ----------
    output: np.ndarray
        Output array; the arg-max over its last axis is used as the pitch.
    n_tracks: int, (default=4)
        Number of tracks.
    n_bars: int, (default=2)
        Number of bars.
    n_steps_per_bar: int, (default=16)
        Number of steps per bar.

    Returns
    -------
    stream.Score:
        Score containing a metronome mark followed by ``n_tracks`` parts.

    """
    parts = stream.Score()
    # Fixed tempo of 66 for every generated score.
    parts.append(tempo.MetronomeMark(number=66))
    max_pitches = binarise_output(output)
    # Every sample becomes a (time step, track) table; samples are stacked in time.
    # NOTE(review): the reshape only lines tracks up with columns if each
    # sample is laid out with the track axis last. MuseGenerator in this file
    # documents (batch, n_tracks, n_bars, n_steps_per_bar, n_pitches), so callers
    # presumably have to move the track axis first -- confirm against the caller.
    midi_note_score = np.vstack([
        max_pitches[i].reshape([n_bars * n_steps_per_bar, n_tracks]) for i in range(len(output))
    ])
    for i in range(n_tracks):
        last_x = int(midi_note_score[:, i][0])
        s = stream.Part()
        dur = 0
        for idx, x in enumerate(midi_note_score[:, i]):
            x = int(x)
            # Emit the pending note when the pitch changes or at every 4th
            # step, so a held note never spans more than 4 steps.
            if (x != last_x or idx % 4 == 0) and idx > 0:
                n = note.Note(last_x)
                n.duration = duration.Duration(dur)
                s.append(n)
                dur = 0
            last_x = x
            # Each time step adds 0.25 to the duration of the pending note.
            dur = dur + 0.25
        # Flush the note that is still pending after the last step.
        n = note.Note(last_x)
        n.duration = duration.Duration(dur)
        s.append(n)
        parts.append(s)
    return parts


# Loss and metrics

In [23]:

from torch import Tensor

import torch
from torch import nn


class WassersteinLoss(nn.Module):
    """Wasserstein loss: negated mean of prediction times target."""

    def __init__(self) -> None:
        """Initialize."""
        super().__init__()

    def forward(self, y_pred: Tensor, y_target: Tensor) -> Tensor:
        """Calculate Wasserstein loss.

        Parameters
        ----------
        y_pred: Tensor
            Prediction.
        y_target: Tensor
            Target, broadcastable against ``y_pred``.

        Returns
        -------
        Tensor:
            Scalar loss ``-mean(y_pred * y_target)``.

        """
        signed_scores = y_pred * y_target
        return -signed_scores.mean()

class GradientPenalty(nn.Module):
    """Penalty pushing the per-sample gradient norm towards 1."""

    def __init__(self) -> None:
        """Initialize."""
        super().__init__()

    def forward(self, inputs: Tensor, outputs: Tensor) -> Tensor:
        """Calculate gradient penalty.

        Parameters
        ----------
        inputs: Tensor
            Input with respect to which the gradient is taken.
        outputs: Tensor
            Output computed from ``inputs``.

        Returns
        -------
        Tensor:
            Mean over the batch of ``(1 - ||grad||_2) ** 2``; differentiable,
            because the gradient graph is kept.

        """
        gradients = torch.autograd.grad(
            outputs=outputs,
            inputs=inputs,
            grad_outputs=torch.ones_like(outputs),
            create_graph=True,
            retain_graph=True,
        )[0]
        # One row per sample so that the norm is taken sample-wise.
        per_sample = gradients.view(gradients.size(0), -1)
        norms = torch.norm(per_sample, p=2, dim=1)
        return torch.mean((1. - norms) ** 2)


# Trainer

In [24]:
import csv
import json
import logging
import os
import shutil

import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
class Trainer:
    """
    Class to manage the full training pipeline (critic and generator updates,
    TensorBoard logging, checkpointing and resuming).
    """
    def __init__(self, network:MuseGan,
                 g_optimizer,
                 c_optimizer,
                 nb_epochs=10,
                 repeat=5,
                 batch_size=128,
                 reset=False):
        """
        @param network: MuseGan instance holding ``generator`` and ``critic``
        @param g_optimizer: optimizer of the generator parameters
        @param c_optimizer: optimizer of the critic parameters
        @param nb_epochs: number of epochs to run in this session
        @param repeat: number of critic updates per generator update
        @param batch_size: batch size of the dataloader passed to ``fit``
        @param reset: start from scratch instead of resuming from ``model.json``
        """
        self.network = network
        self.batch_size = batch_size
        self.repeat=repeat
        self.g_optimizer=g_optimizer
        self.c_optimizer = c_optimizer

        self.g_criterion = WassersteinLoss().to(DEVICE)
        self.c_criterion = WassersteinLoss().to(DEVICE)

        self.c_penalty=GradientPenalty().to(DEVICE)

        self.nb_epochs = nb_epochs
        self.experiment_dir = self.network.experiment_dir
        self.model_info_file = os.path.join(self.experiment_dir, "model.json")
        self.model_info_best_file = os.path.join(self.experiment_dir, "model_best.json")

        if reset and os.path.isdir(self.experiment_dir):
            self._clear_experiment_dir()
        os.makedirs(self.experiment_dir, exist_ok=True)

        self.start_epoch = 0
        if not reset and os.path.exists(self.model_info_file):
            with open(self.model_info_file, "r") as f:
                self.start_epoch = json.load(f)["epoch"] + 1
            # nb_epochs counts the epochs of this session, so shift the end.
            self.nb_epochs += self.start_epoch
            logging.info("Resuming from epoch {}".format(self.start_epoch))

    def _clear_experiment_dir(self):
        """Delete previous results but keep ``config.json``.

        The network was just built from that config; removing it would leave
        the newly trained weights without the architecture they belong to.
        """
        for entry in os.listdir(self.experiment_dir):
            if entry == "config.json":
                continue
            entry_path = os.path.join(self.experiment_dir, entry)
            if os.path.isdir(entry_path) and not os.path.islink(entry_path):
                shutil.rmtree(entry_path)
            else:
                os.remove(entry_path)

    def save_model_info(self, infos, best=False):
        """Write ``infos`` to ``model.json`` (and ``model_best.json`` if ``best``)."""
        targets = [self.model_info_file]
        if best:
            targets.append(self.model_info_best_file)
        for target in targets:
            # Context manager so that the file is flushed and closed right away.
            with open(target, 'w') as f:
                json.dump(infos, f, indent=4)

    def _restore_learning_rates(self):
        """Re-apply the learning rates stored in ``model.json``, if any.

        ``g_lr`` / ``c_lr`` are used when present; a single ``lr`` entry is
        applied to both optimizers. Missing entries leave an optimizer as is.
        """
        if not os.path.exists(self.model_info_file):
            return
        with open(self.model_info_file, "r") as f:
            model_info = json.load(f)
        shared_lr = model_info.get("lr")
        for key, optimizer in (("g_lr", self.g_optimizer), ("c_lr", self.c_optimizer)):
            lr = model_info.get(key, shared_lr)
            if lr is None:
                continue
            logging.info(f"Setting {key} to {lr}")
            for group in optimizer.param_groups:
                group['lr'] = lr

    def _noise_dims(self):
        """Return ``(z_dimension, n_tracks)`` expected by the generator.

        The latent size is read from the first transposed convolution of the
        chords network, so it always matches the generator actually in use.
        """
        generator = self.network.generator
        n_tracks = getattr(generator, "n_tracks", 4)
        z_dimension = 32
        for module in generator.chords_network.modules():
            if isinstance(module, torch.nn.ConvTranspose2d):
                z_dimension = module.in_channels
                break
        return z_dimension, n_tracks

    def _sample_noise(self, batch, z_dimension, n_tracks):
        """Draw the four latent inputs (chords, style, melody, groove)."""
        chords = torch.randn(batch, z_dimension).to(DEVICE)
        style = torch.randn(batch, z_dimension).to(DEVICE)
        melody = torch.randn(batch, n_tracks, z_dimension).to(DEVICE)
        groove = torch.randn(batch, n_tracks, z_dimension).to(DEVICE)
        return chords, style, melody, groove

    def fit(self,train_dataloader):
        """Train the network on ``train_dataloader``.

        Every batch triggers ``repeat`` critic updates (Wasserstein loss plus
        gradient penalty weighted by 10) followed by one generator update.
        Losses are written to TensorBoard per step and per epoch; weights and
        ``model.json`` are saved after every epoch.
        """
        logging.info("Launch training on {}".format(DEVICE))
        self.network.train()
        self.network.to(DEVICE)
        self.summary_writer = SummaryWriter(log_dir=self.experiment_dir)
        itr = self.start_epoch * len(train_dataloader) * self.batch_size  ##Global counter for steps
        self._restore_learning_rates()
        z_dimension, n_tracks = self._noise_dims()
        generator = self.network.generator
        critic = self.network.critic

        try:
            for epoch in range(self.start_epoch, self.nb_epochs):  # Training loop
                epoch_gloss = Averager()
                epoch_cfloss = Averager()
                epoch_crloss = Averager()
                epoch_cploss = Averager()
                epoch_closs = Averager()
                pbar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{self.nb_epochs}")
                for real in pbar:
                    # ---- Train the critic ----
                    real=real.to(DEVICE)
                    # Size everything from the actual batch so that a shorter
                    # last batch works as well.
                    batch = real.size(0)
                    batch_closs = Averager()
                    batch_cfloss = Averager()
                    batch_crloss = Averager()
                    batch_cploss = Averager()
                    for _ in range(self.repeat):
                        cords, style, melody, groove = self._sample_noise(batch, z_dimension, n_tracks)

                        self.c_optimizer.zero_grad()
                        with torch.no_grad():
                            fake = generator(cords, style, melody, groove)
                        # Fresh interpolation coefficients for every critic
                        # step, one per sample, as WGAN-GP prescribes.
                        self.alpha = torch.rand((batch, 1, 1, 1, 1)).to(DEVICE)
                        realfake = (self.alpha * real + (1. - self.alpha) * fake).requires_grad_(True)

                        fake_pred = critic(fake)
                        real_pred = critic(real)
                        realfake_pred = critic(realfake)
                        fake_loss = self.c_criterion(fake_pred, - torch.ones_like(fake_pred))
                        real_loss = self.c_criterion(real_pred, torch.ones_like(real_pred))
                        penalty = self.c_penalty(realfake, realfake_pred)
                        closs = fake_loss + real_loss + 10 * penalty
                        # The graph is rebuilt on every iteration, so it does
                        # not need to be retained.
                        closs.backward()
                        self.c_optimizer.step()
                        batch_cfloss.send(fake_loss.item())
                        batch_crloss.send(real_loss.item())
                        batch_cploss.send(10 * penalty.item())
                        # The Averager already divides by the number of
                        # repeats; dividing here as well under-reported the loss.
                        batch_closs.send(closs.item())

                    # ---- Train the generator ----
                    self.g_optimizer.zero_grad()
                    cords, style, melody, groove = self._sample_noise(batch, z_dimension, n_tracks)

                    fake = generator(cords, style, melody, groove)
                    fake_pred = critic(fake)
                    b_gloss = self.g_criterion(fake_pred, torch.ones_like(fake_pred))
                    b_gloss.backward()
                    self.g_optimizer.step()

                    # ---- Per-step logs ----
                    batch_data={
                    "generator_loss":b_gloss.item(),
                    "critic_loss":batch_closs.value,
                    "critic_fake_loss":batch_cfloss.value,
                    "critic_real_loss":batch_crloss.value,
                    "critic_penalized_loss":batch_cploss.value
                        }

                    for k,v in batch_data.items():
                        self.summary_writer.add_scalar(f"Train steps/{k}", v, itr)
                    # Advance the global counter (same unit as its start value:
                    # samples seen); it used to stay constant, so every point
                    # was written to the same TensorBoard step.
                    itr += self.batch_size

                    epoch_gloss.send(b_gloss.item())
                    epoch_cfloss.send(batch_cfloss.value)
                    epoch_crloss.send(batch_crloss.value)
                    epoch_cploss.send(batch_cploss.value)
                    epoch_closs.send(batch_closs.value)

                epoch_data={
                    "generator_loss":epoch_gloss.value,
                    "critic_loss":epoch_closs.value,
                    "critic_fake_loss":epoch_cfloss.value,
                    "critic_real_loss":epoch_crloss.value,
                    "critic_penalized_loss":epoch_cploss.value
                }
                for k,v in epoch_data.items():
                    self.summary_writer.add_scalar(f"Train epochs/{k}",v,epoch)
                logging.info(f"Epoch {epoch + 1}/{self.nb_epochs} | Generator loss: {epoch_gloss.value:.3f} | Critic loss: {epoch_closs.value:.3f}")

                # Persist what is needed to resume: metrics, epoch index and
                # the learning rates currently in use.
                infos = dict(epoch_data)
                infos["epoch"]=epoch
                infos["g_lr"]=self.g_optimizer.param_groups[0]["lr"]
                infos["c_lr"]=self.c_optimizer.param_groups[0]["lr"]
                self.network.save_state()
                self.save_model_info(infos)
        finally:
            # Flush pending events even when training is interrupted.
            self.summary_writer.close()







# Runner

In [25]:
from collections import namedtuple
import argparse
import json
import logging
import os
import torch.utils.data
from torch.optim import Adam

def main(args):
    """Create (or resume) an experiment and train the network.

    ``args`` must expose: ``model_name``, ``reset``, ``z_dimension``,
    ``g_channels``, ``g_features``, ``c_channels``, ``c_features``, ``g_lr``,
    ``c_lr``, ``nb_epochs``, ``batch_size`` and ``num_workers``. An optional
    ``log_level`` enables log output.
    """
    # logging.info() is silent under the default WARNING level, so progress
    # messages were never shown. Install the handlers once; an already
    # configured root logger is left untouched.
    if getattr(args, "log_level", None) is not None and not logging.getLogger().handlers:
        setup_logger(args)

    model_name = "base_model" if args.model_name is None else args.model_name
    experiment_dir = os.path.join(EXPERIMENTS_DIR, model_name)
    config_file=os.path.join(experiment_dir,"config.json")

    os.makedirs(experiment_dir, exist_ok=True)
    # The architecture config is written on first use and whenever a reset is
    # requested; otherwise the stored one wins over the values in ``args``.
    if  not os.path.exists(config_file) or args.reset:
        model_config={
        "z_dimension":args.z_dimension,
        "g_channels": args.g_channels,
        "g_features" : args.g_features,
        "c_channels" : args.c_channels,
        "c_features" : args.c_features,
        }
        with open(config_file,"w") as f:
            json.dump(model_config,f,indent=4)

    network=MuseGan(experiment_dir=experiment_dir,
                    reset=args.reset,
                    )

    g_optimizer = Adam(network.generator.parameters(),lr=args.g_lr,betas=(0.5, 0.9))
    c_optimizer = Adam(network.critic.parameters(),lr=args.c_lr,betas=(0.5, 0.9))

    logging.info("Training : "+model_name)
    trainer = Trainer(network,
                      g_optimizer,
                      c_optimizer,
                      nb_epochs= args.nb_epochs,
                      batch_size=args.batch_size,
                      reset=args.reset,
                      )

    # NOTE(review): the archive key "nonzero" is passed as the split name --
    # confirm that DATASET_FILE stores per-song pitch sequences under that key.
    train_dataset=MidiDataset(DATASET_FILE, split="nonzero")
    # drop_last keeps every batch at exactly ``batch_size`` samples.
    train_dataloader=torch.utils.data.DataLoader(train_dataset,batch_size=args.batch_size,num_workers=args.num_workers,shuffle=True,drop_last=True)
    trainer.fit(train_dataloader)


In [27]:
# Run configuration. main() reads every entry except "learning_rate" (the
# optimizers use "g_lr" / "c_lr") and "log_level" (read by setup_logger).
args={
        "reset":False,
        "learning_rate":0.001,
        "nb_epochs":20,
        "model_name":"base_muse_gan",
        "num_workers":4,
        "batch_size":128,
        # Architecture values; written to the experiment's config.json by main().
        "z_dimension":32,
        "g_channels":1024,
        "g_features":1024,
        "g_lr":0.001,
        "c_channels":128,
        "c_features":1024,
        "c_lr":0.001,
        "log_level":"INFO"
        }
# Turn the dict into an object with attribute access, as main() expects.
args=namedtuple("args",args.keys())(*args.values())

main(args)


NameError: name 'EXPERIMENTS_DIR' is not defined