In [None]:
import torch
from torch import nn

import numpy as np

import matplotlib.pyplot as plt
import torch.nn.functional as F
import FrEIA.framework as Ff
import FrEIA.modules as Fm
import torch.distributions

from data.circleData import make_circles_ssl
from data.moonData import make_moons_ssl
from data.multiGaussian import make_multigaussian_ssl


from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


# Data

In [None]:
# Two-moons toy set (candidate source distribution).
# make_moons_ssl returns three values; the first is used as the 2-D points and
# the third as a per-point 0/1 mask for colouring. The second one is ignored here.
n_samples = 2000
data1, _, label1 = make_moons_ssl(n_samples=n_samples, label_ratio=0.05, seed=0)
print(len(data1))

# One scatter call per class: class 0 in red, class 1 in blue.
for class_id, colour in ((0, 'r'), (1, 'b')):
    members = data1[label1 == class_id]
    plt.scatter(members[:, 0], members[:, 1], c=colour, s=3)
plt.show()

In [None]:
# Concentric-circles toy set (candidate target distribution).
# NOTE(review): unlike the moons cell, no seed is passed here, so whether this
# draw is reproducible depends on make_circles_ssl's own defaults -- confirm.
n_samples = 2000
data2, _, label2 = make_circles_ssl(n_samples=n_samples, label_ratio=0.05)
print(len(data2))

# Boolean masks for the two label values, reused for both coordinates.
in_class0 = label2 == 0
in_class1 = label2 == 1
plt.scatter(data2[in_class0][:, 0], data2[in_class0][:, 1], c='r', s=3)
plt.scatter(data2[in_class1][:, 0], data2[in_class1][:, 1], c='b', s=3)
plt.show()

In [None]:
# Two 2-D Gaussians: a correlated one centred at (-8, 1) and a standard normal.
from scipy.linalg import sqrtm  # NOTE(review): unused in the cells visible here; kept in case a later cell relies on it

# Seed NumPy's global generator so the draws in this cell are repeatable.
np.random.seed(2023)

mean1 = np.array([-8., 1.])
# np.mat was removed in NumPy 2.0; np.asmatrix is the documented replacement
# and still yields an np.matrix, so any code relying on matrix semantics of
# Sigma1 / Sigma2 keeps working.
Sigma1 = np.asmatrix(np.array([[0.5, -0.25], [-0.25, 0.5]]))
data7 = np.random.multivariate_normal(mean=mean1, cov=Sigma1, size=2000).astype(np.float32)

mean2 = np.array([0., 0.])
Sigma2 = np.asmatrix(np.array([[1, 0], [0, 1]]))
data8 = np.random.multivariate_normal(mean=mean2, cov=Sigma2, size=2000).astype(np.float32)

# data7 in red, data8 in blue.
plt.scatter(data7[:, 0], data7[:, 1], s=3, c='red')
plt.scatter(data8[:, 0], data8[:, 1], s=3, c='blue')
plt.show()

In [None]:
# Mixture of n_gauss Gaussians generated with radius=8 and var=0.5.
# n_samples is 2000 // n_gauss -- presumably the per-component sample count,
# giving 2000 points in total (confirm against make_multigaussian_ssl).
n_gauss = 8
n_samples = 2000 // n_gauss
data9 = make_multigaussian_ssl(
    n_samples=n_samples,
    n_gauss=n_gauss,
    radius=8,
    var=.5,
)

data9_x, data9_y = data9[:, 0], data9[:, 1]
plt.scatter(data9_x, data9_y, c='red', s=3)
plt.show()

In [None]:
# NOTE(review): this cell repeats the preceding data9 cell verbatim and
# overwrites data9 with a second call to make_multigaussian_ssl. It is kept
# because removing it could change the state of the random stream seen by the
# cells that follow -- confirm whether the duplicate is intentional.
n_gauss = 8
n_samples = 2000 // n_gauss
mixture_kwargs = dict(n_samples=n_samples, n_gauss=n_gauss, radius=8, var=.5)
data9 = make_multigaussian_ssl(**mixture_kwargs)

plt.scatter(data9[:, 0], data9[:, 1], c='red', s=3)
plt.show()

In [None]:
# Second Gaussian mixture: same component count as data9 but generated with
# radius=4 and an explicit start_angle=0 (an earlier experiment used radius=2).
n_gauss = 8
n_samples = 2000 // n_gauss
data10 = make_multigaussian_ssl(
    n_samples=n_samples,
    n_gauss=n_gauss,
    start_angle=0,
    radius=4,
    var=.5,
)

data10_x, data10_y = data10[:, 0], data10[:, 1]
plt.scatter(data10_x, data10_y, c='red', s=3)
plt.show()

In [None]:
# Ten isotropic blobs (std 0.65, 200 points each) whose centres lie on the
# line through (t - 10, 2t + 1) for t = 1..10.
mean = [(t - 10, 2 * t + 1) for t in range(1, 11)]

# Draw the blobs in centre order and stack them: rows [200*k, 200*(k+1))
# belong to the k-th centre, matching a block-wise fill of a (2000, 2) array.
blobs = [np.random.randn(200, 2) * 0.65 + centre for centre in mean]
data13 = np.concatenate(blobs, axis=0).astype(np.float32)

plt.scatter(data13[:, 0], data13[:, 1], s=3)
plt.show()

In [None]:
# Ten tight isotropic blobs (std 0.2, 200 points each) centred at (t, 0) for
# t = 1..10, i.e. spaced one unit apart along the x-axis.
mean = [(t, 0) for t in range(1, 11)]
data14 = np.zeros((2000, 2))
for i, centre in enumerate(mean):
    # Rows reserved for the i-th blob.
    rows = slice(i * 200, (i + 1) * 200)
    data14[rows, :] = np.random.randn(200, 2) * 0.2 + centre
data14 = data14.astype(np.float32)

plt.scatter(data14[:, 0], data14[:, 1], s=3)
plt.show()

In [None]:
class ToyDateset(torch.utils.data.Dataset):
    """Index-aligned pairing of two NumPy sample arrays.

    Item ``i`` is ``(source_data[i], target_data[i])``, each converted to a
    tensor with ``torch.from_numpy`` (so the tensors share memory with the
    underlying arrays and keep their dtype).

    Args:
        source_data: NumPy array indexed along its first axis.
        target_data: NumPy array indexed along its first axis.
    """

    def __init__(self, source_data, target_data):
        self.x1 = source_data
        self.x2 = target_data

    def __getitem__(self, index):
        """Return the ``index``-th source and target samples as tensors."""
        return torch.from_numpy(self.x1[index]), torch.from_numpy(self.x2[index])

    def __len__(self):
        """Number of index-aligned pairs available.

        Both arrays are indexed with the same position, so the usable length
        is the shorter of the two; reporting only ``len(x1)`` would let a
        loader request indices that do not exist in a shorter ``x2``.
        """
        return min(len(self.x1), len(self.x2))

In [None]:
# Choose which pair of toy sets the flow is trained on (moons -> circles) and
# overlay them: source in blue, target in red.
source_data, target_data = data1, data2
for points, colour in ((source_data, 'blue'), (target_data, 'red')):
    plt.scatter(points[:, 0], points[:, 1], s=3, c=colour)
plt.show()

# MMD

In [None]:
def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Sum of Gaussian kernels evaluated on all pairs of ``source`` and ``target`` rows.

    The rows of ``source`` and ``target`` are concatenated (source first) and a
    pairwise squared-Euclidean distance matrix ``D`` is built. The result is
    ``sum_k exp(-D / b_k)`` with ``b_k = b * kernel_mul**k`` for
    ``k in range(kernel_num)`` and ``b = base / kernel_mul**(kernel_num // 2)``.

    Args:
        source: 2-D tensor of shape ``(n, d)``.
        target: 2-D tensor of shape ``(m, d)``.
        kernel_mul: Geometric ratio between consecutive bandwidths.
        kernel_num: Number of kernels summed.
        fix_sigma: Base bandwidth. When falsy (``None`` or ``0``) the base is the
            sum of all pairwise squared distances divided by ``N**2 - N`` with
            ``N = n + m``; that estimate is detached from the autograd graph.

    Returns:
        Tensor of shape ``(n + m, n + m)``.
    """
    n_samples = int(source.size()[0]) + int(target.size()[0])
    total = torch.cat([source, target], dim=0)
    # Broadcasting (1, N, d) against (N, 1, d) gives every pairwise difference.
    L2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        # detach(): the bandwidth is treated as a constant during backprop.
        bandwidth = torch.sum(L2_distance.detach()) / (n_samples**2 - n_samples)
    # Out-of-place division on purpose: ``bandwidth /= ...`` would modify a
    # caller-supplied tensor ``fix_sigma`` in place.
    bandwidth = bandwidth / kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
    return sum(kernel_val)

def mmd(source, target, kernel_type='gaussian_kernel', kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Multi-kernel MMD estimate between the rows of ``source`` and ``target``.

    Builds the joint kernel matrix with ``guassian_kernel`` (source rows first)
    and returns ``mean(K_ss) + mean(K_tt) - mean(K_st) - mean(K_ts)``, where the
    four blocks are split at ``source.size(0)``.

    Args:
        source: 2-D tensor of samples.
        target: 2-D tensor of samples with the same feature size.
        kernel_type: Must be ``'gaussian_kernel'`` or ``'ed_kernel'``.
            NOTE(review): the value is only validated, never used -- the
            Gaussian kernel is applied in both cases.
        kernel_mul, kernel_num, fix_sigma: Forwarded to ``guassian_kernel``.

    Returns:
        Scalar tensor.
    """
    assert kernel_type in ['gaussian_kernel', 'ed_kernel']

    n_src = int(source.size()[0])
    gram = guassian_kernel(
        source,
        target,
        kernel_mul=kernel_mul,
        kernel_num=kernel_num,
        fix_sigma=fix_sigma,
    )

    # Split the joint matrix into its four source/target blocks.
    within_src = gram[:n_src, :n_src].mean()
    within_tgt = gram[n_src:, n_src:].mean()
    src_vs_tgt = gram[:n_src, n_src:].mean()
    tgt_vs_src = gram[n_src:, :n_src].mean()
    return within_src + within_tgt - src_vs_tgt - tgt_vs_src


# Training

In [None]:
# Train an invertible flow so that flow(x1) matches the target samples and
# flow(x2, rev=True) matches the source samples (MMD in both directions), with
# an extra mean-squared-displacement penalty on both mappings.

# Use GPU 5 when it exists; otherwise fall back to the first GPU or the CPU so
# the notebook still runs on machines with fewer (or no) CUDA devices.
preferred_gpu = 5
if torch.cuda.is_available() and torch.cuda.device_count() > preferred_gpu:
    device = 'cuda:{}'.format(preferred_gpu)
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

dataset = ToyDateset(source_data, target_data)
train_dataloader = DataLoader(dataset, batch_size=200, shuffle=True, drop_last=True)

n_dim = (2, )

def subnet_fc(dims_in, dims_out):
    """Build the subnetwork handed to each flow block: Linear -> ReLU -> Linear, hidden width 256."""
    return nn.Sequential(nn.Linear(dims_in, 256), nn.ReLU(),
                         nn.Linear(256,  dims_out))

n_epoch = 2000
ot_weight = 0.15   # weight of the displacement penalty in the total loss
log_every = 100    # print running averages every `log_every` epochs

flow = Ff.SequenceINN(*n_dim)
for _ in range(8):
    flow.append(
        Fm.AllInOneBlock,
        subnet_constructor=subnet_fc,
        affine_clamping=2.,
        permute_soft=True
    )


flow = flow.to(device)
optimizer = torch.optim.Adam(flow.parameters(), lr=3e-3, weight_decay=1e-5) # 3e-3 1e-4


for t in range(1, n_epoch + 1):
    loss_record = 0.
    pot_record = 0.

    cnt = 0
    for x1, x2 in train_dataloader:
        x1, x2 = x1.to(device), x2.to(device)
        # The flow returns a pair; only the mapped samples are used here
        # (the second element is presumably the log-Jacobian -- not needed).
        x2_fake, _ = flow(x1)
        x1_fake, _ = flow(x2, rev=True)

        mmd_loss = mmd(x2_fake, x2, kernel_mul=2, kernel_num=10) + mmd(x1_fake, x1, kernel_mul=2, kernel_num=10)
        # Mean squared displacement between each input and its image, both directions.
        penalty_ot = 0.5 * torch.mean((x1 - x2_fake) ** 2) + 0.5 * torch.mean((x2 - x1_fake) ** 2)

        loss = mmd_loss + ot_weight * penalty_ot

        optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(parameters=flow.parameters(), max_norm=10, norm_type=2)
        optimizer.step()

        loss_record += loss.item()
        pot_record += penalty_ot.item()

        cnt += 1

    # Epoch averages, computed once per epoch; max(..., 1) keeps them defined
    # even if the loader yields no batch (dataset smaller than one batch).
    loss_avg = loss_record / max(cnt, 1)
    pot_avg = pot_record / max(cnt, 1)

    if t % log_every == 0:
        print('iter: {}, loss: {:.5f}, p_ot: {:.5f}'.format(t, loss_avg, pot_avg))



In [None]:
# Optional: restore a saved checkpoint instead of using the freshly trained
# weights. If you do, make sure source_data / target_data (set in the pair
# selection cell above, before training) are the pair that checkpoint was
# trained on.
# flow.load_state_dict(torch.load('mmd_results_m2c/20230502-122106/net_params.pth'))
# flow.load_state_dict(torch.load('params_new/mmd/gauss2gauss_o_2.pth'))
# flow.load_state_dict(torch.load('params_new/mmd/8gauss28gauss_o.pth'))
# flow.load_state_dict(torch.load('params_new/mmd/gauss_classfication_o_1.pth'))

# Evaluate on the same source/target pair the flow was trained on.
# source_data / target_data are deliberately NOT rebound here: switching them
# to a different data set after training would push samples from a
# distribution the flow never saw through it when the notebook is run top to
# bottom.
with torch.no_grad():
    # Forward pass maps source samples; reverse pass maps target samples.
    x1_samples_fake, _ = flow(torch.from_numpy(source_data).to(device))
    x2_samples_fake, _ = flow(torch.from_numpy(target_data).to(device), rev=True)


# Move the results to the CPU so they can be converted to NumPy for plotting.
x1_samples_fake = x1_samples_fake.cpu()
x2_samples_fake = x2_samples_fake.cpu()


In [None]:
def _plot_transport(panel, original, generated, stride=5):
    """Draw one panel: original points, their images under the flow, and connecting segments.

    Args:
        panel: 1-based subplot index inside a 1x2 grid.
        original: NumPy array of input points; columns 0 and 1 are plotted.
        generated: CPU tensor holding the flow output for ``original``, row-aligned with it.
        stride: Draw a connecting segment for every ``stride``-th point.
    """
    generated = generated.numpy()  # convert once instead of on every access
    plt.subplot(1, 2, panel)
    plt.scatter(original[:, 0], original[:, 1], s=3, label='original')
    plt.scatter(generated[:, 0], generated[:, 1], s=3, label='generated')
    plt.legend()
    # plt.xlim([-2,2]); plt.ylim([-2,2])  # optional fixed view

    # Derive the number of points from the data instead of assuming 2000 rows.
    n_points = min(len(original), len(generated))
    picked = np.arange(0, n_points, stride)
    if picked.size:
        # 2-D inputs to plt.plot draw one line per column: each column holds
        # the start (original) and end (generated) coordinate of one segment.
        seg_x = np.stack([original[picked, 0], generated[picked, 0]])
        seg_y = np.stack([original[picked, 1], generated[picked, 1]])
        plt.plot(seg_x, seg_y, c='green', linewidth=0.3)


plt.figure(figsize=(20,10))
_plot_transport(1, source_data, x1_samples_fake)   # forward pass on the source set
_plot_transport(2, target_data, x2_samples_fake)   # reverse pass on the target set
plt.show()