In [None]:
import torch
import torch.nn as nn
from collections import defaultdict
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.animation
from scipy.spatial import Voronoi, voronoi_plot_2d
from IPython.display import HTML

In [None]:
class MCL_MSE(nn.Module):
    """Winner-takes-all Multiple Choice Learning (MCL) loss with a squared-error distance.

    Every target is assigned to its closest prediction (smallest squared Euclidean
    distance) and the loss is the mean of those minimal distances over the batch
    and the targets. No target is masked out.

    Args:
        return_assigmnent: also return the one-hot prediction -> target assignment
            (only honoured together with ``return_score_loss``). The misspelled
            keyword is kept for backward compatibility.
        epsilon: stored as ``self.epsilon`` / ``self.EPS``; not used by ``forward``.
        return_score_loss: also return a BCE loss that trains ``score_list`` to
            predict which predictions are the closest one to at least one target.
    """

    def __init__(self, return_assigmnent=True, epsilon=1e-9, return_score_loss=True):
        super(MCL_MSE, self).__init__()
        # forward() reads both of these; they were previously never defined
        self.score_metric = torch.nn.BCELoss()
        self.return_score_loss = return_score_loss
        self.return_assignment = return_assigmnent
        self.epsilon = self.EPS = epsilon

    def forward(self, prediction_list, score_list, target_list, epoch=None):
        """Compute the winner-takes-all loss.

        Args:
            prediction_list: ``[batch_size, n_prediction, dim]`` hypotheses.
            score_list: ``[batch_size, n_prediction]`` probabilities in ``[0, 1]``;
                only used when ``return_score_loss`` is set.
            target_list: ``[batch_size, n_target, dim]`` targets.
            epoch: unused; accepted for signature parity with ``Annealed_MCL_MSE``.

        Returns:
            ``(loss, score_loss, assignment)`` if both ``return_score_loss`` and
            ``return_assignment`` are set, ``(loss, score_loss)`` if only
            ``return_score_loss`` is set, otherwise the scalar ``loss`` alone.
        """
        n_prediction = prediction_list.shape[1]

        # negative squared distance between every prediction and every target:
        # [batch_size, n_prediction, n_target]
        neg_sq_distance = -torch.square(prediction_list.unsqueeze(2) - target_list.unsqueeze(1)).sum(dim=-1)

        # each target keeps only its closest prediction (winner-takes-all)
        best_neg_distance, target_assignment = neg_sq_distance.max(dim=1)  # [batch_size, n_target] each
        mcl_loss = -best_neg_distance.mean()  # []

        # the assignment is only ever returned alongside the score loss
        if not self.return_score_loss:
            return mcl_loss

        # one-hot prediction -> target assignment: [batch_size, n_prediction, n_target]
        prediction_assignment = torch.nn.functional.one_hot(
            target_assignment, num_classes=n_prediction
        ).float().permute(0, 2, 1)
        # a prediction is a positive for scoring if it won at least one target
        score_loss = self.score_metric(score_list, prediction_assignment.any(dim=-1).float())  # []

        if self.return_assignment:
            return mcl_loss, score_loss, prediction_assignment
        return mcl_loss, score_loss

class Annealed_MCL_MSE(nn.Module):
    """Annealed Multiple Choice Learning loss with a squared-error distance.

    For every target, the squared Euclidean distance to each prediction is computed.
    When ``temperature_schedule(epoch) > min_temperature``, targets are softly
    assigned to predictions with a softmax over the negative distances divided by
    the temperature. The weights are detached, so gradients only flow through the
    distances. Otherwise each target is hard-assigned to its closest prediction
    (winner-takes-all, i.e. vanilla MCL). Targets whose coordinates are all zero
    are treated as inactive and excluded from the regression loss.

    Args:
        temperature_schedule: callable mapping ``epoch`` to a temperature.
        return_score_loss: also return a BCE loss that trains ``score_list`` to
            predict which predictions are the closest one to at least one target.
        return_assignment: also return the one-hot prediction -> target assignment
            (only honoured together with ``return_score_loss``).
        sample_normalization: if True, average over the active targets of each
            sample first, then over samples; otherwise average over all active
            targets of the batch at once.
        epsilon: stored as ``self.epsilon``; not used by ``forward``.
        min_temperature: temperatures at or below this value use hard assignment.
    """

    def __init__(self, temperature_schedule, return_score_loss=True, return_assignment=False, sample_normalization=True, epsilon=1e-9, min_temperature=1e-4):
        super(Annealed_MCL_MSE, self).__init__()
        self.score_metric = torch.nn.BCELoss()
        self.return_score_loss = return_score_loss
        self.return_assignment = return_assignment
        self.sample_normalization = sample_normalization
        self.epsilon = epsilon
        self.temperature_schedule = temperature_schedule
        self.min_temperature = min_temperature

    def forward(self, prediction_list, score_list, target_list, epoch=None):
        """Compute the (annealed) MCL loss.

        Args:
            prediction_list: ``[batch_size, n_prediction, dim]`` hypotheses.
            score_list: ``[batch_size, n_prediction]`` probabilities in ``[0, 1]``;
                only used when ``return_score_loss`` is set.
            target_list: ``[batch_size, n_target, dim]`` targets; an all-zero
                target is treated as inactive.
            epoch: forwarded to ``temperature_schedule``.

        Returns:
            ``(loss, score_loss, assignment)`` if both ``return_score_loss`` and
            ``return_assignment`` are set, ``(loss, score_loss)`` if only
            ``return_score_loss`` is set, otherwise the scalar ``loss`` alone.
        """
        n_prediction = prediction_list.shape[1]

        # negative squared distance: [batch_size, n_target, n_prediction]
        pairwise_distance = -torch.square(prediction_list.unsqueeze(1) - target_list.unsqueeze(2)).sum(dim=-1)

        # hard assignment of each target to its closest prediction (also used for scoring)
        best_distance, target_assignment = pairwise_distance.max(dim=2)  # [batch_size, n_target]

        temperature = self.temperature_schedule(epoch)
        if temperature > self.min_temperature:
            # soft assignment; detached weights so gradients only flow through the distances
            soft_weights = torch.softmax(pairwise_distance / temperature, dim=2).detach()
            amcl = (soft_weights * pairwise_distance).sum(dim=2)  # [batch_size, n_target]
        else:
            amcl = best_distance  # vanilla (winner-takes-all) MCL

        # Mask inactive (all-zero) targets. Do not squeeze: with n_target == 1 a [batch_size]
        # mask would broadcast against [batch_size, 1] into a [batch_size, batch_size] product.
        target_mask = target_list.abs().sum(dim=-1) > 0.  # [batch_size, n_target]
        if self.sample_normalization:
            n_active = target_mask.sum(dim=-1)  # [batch_size]
            per_sample = (amcl * target_mask).sum(dim=-1) / n_active.clamp(min=1)
            # samples without any active target would otherwise contribute 0/0 = NaN
            amcl_loss = per_sample[n_active > 0].mean()
        else:
            amcl_loss = amcl[target_mask].mean()
        amcl_loss = -amcl_loss  # []

        # the assignment is only ever returned alongside the score loss
        if not self.return_score_loss:
            return amcl_loss

        # one-hot prediction -> target assignment: [batch_size, n_prediction, n_target]
        # NOTE(review): inactive targets still take part in this assignment (and thus the score labels)
        prediction_assignment = torch.nn.functional.one_hot(
            target_assignment, num_classes=n_prediction
        ).float().permute(0, 2, 1)
        score_loss = self.score_metric(score_list, prediction_assignment.any(dim=-1).float())  # []

        if self.return_assignment:
            return amcl_loss, score_loss, prediction_assignment
        return amcl_loss, score_loss


### Vanilla MCL (Annealed MCL loss with the temperature fixed at 0)
We fit the hypotheses to a synthetic dataset with three Gaussian modes. This first example illustrates the collapse issue of vanilla MCL: most of the hypotheses converge to the barycenter of the data.

In [None]:
# experiment parameters
SEED = 0  # fixed seed so that "Restart & Run All" reproduces the same run
torch.manual_seed(SEED)

# data parameters
input_dim = 2
# the training set is generated once at the beginning and reused at every epoch
len_training = 1  # optimizer steps per epoch
n_epoch = 20
device = "cpu"

# training set: three isotropic 2-D Gaussian modes with standard deviation 0.1
mode_centers = [[-2, 2], [0, -2], [2, 2]]
mode_scale = torch.ones(2) * 0.1
mvn_set = [
    torch.distributions.MultivariateNormal(loc=torch.Tensor(center), scale_tril=torch.diag(mode_scale))
    for center in mode_centers
]
mvn1, mvn2, mvn3 = mvn_set

batch_size = 1000  # samples drawn per mode
# targets of shape [3 * batch_size, 1, input_dim]: a single target per sample
target_list = torch.cat([mvn.sample((batch_size,)) for mvn in mvn_set], dim=0).unsqueeze(1).to(device)

# model parameters
n_model = 10  # number of hypotheses
hidden_dim = 128

# a temperature of 0 (<= min_temperature) makes Annealed_MCL_MSE use hard
# winner-takes-all assignment, i.e. vanilla MCL
temperature_schedule = lambda epoch: 0
training_metric = Annealed_MCL_MSE(temperature_schedule=temperature_schedule)

# model: one MLP per hypothesis, plus one scoring head per hypothesis
predictor_list = [torch.nn.Sequential(
    torch.nn.Linear(input_dim, hidden_dim),
    torch.nn.Sigmoid(),
    torch.nn.Linear(hidden_dim, input_dim),
).to(device) for _ in range(n_model)]
score_model_list = [torch.nn.Sequential(
    torch.nn.Linear(input_dim, 1),
    torch.nn.Sigmoid(),
    torch.nn.Flatten(start_dim=0),
).to(device) for _ in range(n_model)]

# optimizer over the parameters of every predictor and scoring head
parameter_list = []
for predictor, score_model in zip(predictor_list, score_model_list):
    parameter_list += list(predictor.parameters())
    parameter_list += list(score_model.parameters())
optimizer = torch.optim.Adam(parameter_list, 1e-3, betas=(0.9, 0.999))

# the network input is constant (all zeros), so each hypothesis learns an unconditional point;
# built once since it never changes
x = torch.zeros((target_list.shape[0], input_dim)).to(device)

log_dict = defaultdict(list)
for epoch in tqdm(range(n_epoch)):
    for _ in range(len_training):
        prediction_list = torch.stack([predictor(x) for predictor in predictor_list], dim=1)  # [n_samples, n_model, input_dim]
        score_list = torch.stack([score_model(x) for score_model in score_model_list], dim=1)  # [n_samples, n_model]
        prediction_loss, score_loss = training_metric(prediction_list, score_list, target_list, epoch=epoch)
        # the score loss is computed and logged, but deliberately given zero weight
        loss = prediction_loss + 0 * score_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        log_dict[f"loss/{epoch}"].append(loss.item())
        log_dict[f"prediction_loss/{epoch}"].append(prediction_loss.item())
        log_dict[f"score_loss/{epoch}"].append(score_loss.item())
    log_dict[f"prediction_list/{epoch}"].append(prediction_list[0, :, :].detach().cpu())
    log_dict[f"target_list/{epoch}"].append(target_list[:, 0, :].detach().cpu())
    log_dict[f"temperature/{epoch}"].append(temperature_schedule(epoch))

In [None]:
# plot Voronoi: animate the Voronoi cells of the hypotheses over the training epochs
mcl_target_history = np.stack([log_dict[f"target_list/{epoch}"][0] for epoch in range(n_epoch)], axis=0)  # [n_epoch, n_samples, 2]
mcl_hypothesis_history = np.stack([log_dict[f"prediction_list/{epoch}"][0] for epoch in range(n_epoch)], axis=0)  # [n_epoch, n_model, 2]
# one shared range for both axes, taken from the data
x_min = y_min = np.min(mcl_target_history)
x_max = y_max = np.max(mcl_target_history)

fig, ax = plt.subplots()

def draw_mcl_frame(i):
    """Redraw frame ``i``: the data points and the Voronoi diagram of the hypotheses at epoch ``i``."""
    ax.cla()
    ax.scatter(mcl_target_history[i, :, 0], mcl_target_history[i, :, 1], label="data", s=50, c="red")
    voronoi_plot_2d(Voronoi(mcl_hypothesis_history[i]), show_vertices=False, ax=ax, point_size=10)
    ax.set_xlim([x_min - 0.1, x_max + 0.1])
    ax.set_ylim([y_min - 0.1, y_max + 0.1])
    ax.set_title(f"Vanilla MCL, epoch {i}")

mcl_animation = matplotlib.animation.FuncAnimation(fig, draw_mcl_frame, frames=n_epoch)
# Pillow ships with matplotlib, whereas 'imagemagick' needs an external binary
mcl_animation.save('./mcl.gif', writer='pillow', fps=10)
mcl_animation_html = mcl_animation.to_jshtml()
plt.close(fig)  # avoid an extra static copy of the last frame under the cell
display(HTML(mcl_animation_html));

### Annealed MCL 
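In this example, a temperature schedule is added (Annealed MCL): while the temperature is above `min_temperature`, targets are softly assigned to the hypotheses instead of only to the closest one. The example shows that the hypotheses converge to every cluster, without the collapse observed above. Note that with the parameters used below (`threshold = 5.0` is larger than `temperature_start = 3`), `temperature_schedule` returns a constant temperature of 5.0 at every epoch.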

In [None]:
# experiment parameters
SEED = 0  # fixed seed so that "Restart & Run All" reproduces the same run
torch.manual_seed(SEED)

# data parameters
input_dim = 2
len_training = 1  # optimizer steps per epoch
device = "cpu"

# training set: three isotropic 2-D Gaussian modes with standard deviation 0.1
mode_centers = [[-2, 2], [0, -2], [2, 2]]
mode_scale = torch.ones(2) * 0.1
mvn_set = [
    torch.distributions.MultivariateNormal(loc=torch.Tensor(center), scale_tril=torch.diag(mode_scale))
    for center in mode_centers
]
mvn1, mvn2, mvn3 = mvn_set

# model parameters
n_model = 30  # number of hypotheses
hidden_dim = 128

# training parameters
batch_size = 2000  # samples drawn per mode
n_epoch = 200

# Linear decay from `temperature_start` to 0 over `max_epoch` epochs, floored at `threshold`.
# NOTE(review): threshold (5.0) > temperature_start (3), so max() always returns `threshold`:
# the temperature is a constant 5.0 at every epoch and no annealing actually happens.
# Candidate floors noted by the author: 0.025, 2*variance-1e-2. Lower `threshold` for a real
# decaying schedule and re-check the resulting figure.
temperature_start, max_epoch, threshold = 3, 80, 5.0
temperature_schedule = lambda epoch: max(threshold, (temperature_start * (max_epoch - epoch) / max_epoch))
training_metric = Annealed_MCL_MSE(temperature_schedule=temperature_schedule)

# model: one MLP per hypothesis, plus one scoring head per hypothesis
predictor_list = [torch.nn.Sequential(
    torch.nn.Linear(input_dim, hidden_dim),
    torch.nn.Sigmoid(),
    torch.nn.Linear(hidden_dim, input_dim),
).to(device) for _ in range(n_model)]
score_model_list = [torch.nn.Sequential(
    torch.nn.Linear(input_dim, 1),
    torch.nn.Sigmoid(),
    torch.nn.Flatten(start_dim=0),
).to(device) for _ in range(n_model)]

# optimizer over the parameters of every predictor and scoring head
parameter_list = []
for predictor, score_model in zip(predictor_list, score_model_list):
    parameter_list += list(predictor.parameters())
    parameter_list += list(score_model.parameters())
optimizer = torch.optim.Adam(parameter_list, 1e-3, betas=(0.9, 0.999))

# predefine data once: targets of shape [3 * batch_size, 1, input_dim], a single target per sample
target_list = torch.cat([mvn.sample((batch_size,)) for mvn in mvn_set], dim=0).unsqueeze(1).to(device)

# the network input is constant (all zeros), so each hypothesis learns an unconditional point;
# built once since it never changes
x = torch.zeros((target_list.shape[0], input_dim)).to(device)

log_dict = defaultdict(list)
for epoch in tqdm(range(n_epoch)):
    for _ in range(len_training):
        prediction_list = torch.stack([predictor(x) for predictor in predictor_list], dim=1)  # [n_samples, n_model, input_dim]
        score_list = torch.stack([score_model(x) for score_model in score_model_list], dim=1)  # [n_samples, n_model]
        prediction_loss, score_loss = training_metric(prediction_list, score_list, target_list, epoch=epoch)
        # the score loss is computed and logged, but deliberately given zero weight
        loss = prediction_loss + 0 * score_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        log_dict[f"loss/{epoch}"].append(loss.item())
        log_dict[f"prediction_loss/{epoch}"].append(prediction_loss.item())
        log_dict[f"score_loss/{epoch}"].append(score_loss.item())
    log_dict[f"prediction_list/{epoch}"].append(prediction_list[0, :, :].detach().cpu())
    log_dict[f"target_list/{epoch}"].append(target_list[:, 0, :].detach().cpu())
    log_dict[f"temperature/{epoch}"].append(temperature_schedule(epoch))

In [None]:
# plot Voronoi: animate the Voronoi cells of the hypotheses over the training epochs
amcl_target_history = np.stack([log_dict[f"target_list/{epoch}"][0] for epoch in range(n_epoch)], axis=0)  # [n_epoch, n_samples, 2]
amcl_hypothesis_history = np.stack([log_dict[f"prediction_list/{epoch}"][0] for epoch in range(n_epoch)], axis=0)  # [n_epoch, n_model, 2]
# one shared range for both axes, taken from the data
x_min = y_min = np.min(amcl_target_history)
x_max = y_max = np.max(amcl_target_history)

fig, ax = plt.subplots()

def draw_amcl_frame(i):
    """Redraw frame ``i``: the data points and the Voronoi diagram of the hypotheses at epoch ``i``."""
    ax.cla()
    ax.scatter(amcl_target_history[i, :, 0], amcl_target_history[i, :, 1], label="data", s=50, c="red")
    voronoi_plot_2d(Voronoi(amcl_hypothesis_history[i]), show_vertices=False, ax=ax, point_size=10)
    ax.set_xlim([x_min - 0.1, x_max + 0.1])
    ax.set_ylim([y_min - 0.1, y_max + 0.1])
    temperature = log_dict[f"temperature/{i}"][0]
    ax.set_title(f"Annealed MCL, epoch {i}, temperature {temperature:g}")

amcl_animation = matplotlib.animation.FuncAnimation(fig, draw_amcl_frame, frames=n_epoch)
# Pillow ships with matplotlib, whereas 'imagemagick' needs an external binary
amcl_animation.save('./amcl.gif', writer='pillow', fps=10)
amcl_animation_html = amcl_animation.to_jshtml()
plt.close(fig)  # avoid an extra static copy of the last frame under the cell
display(HTML(amcl_animation_html));