In [1]:
import sys
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

import numpy as np
import pandas as pd 

import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px

from src.model.flow_matching import FlowMatching
from src.model.base_models.mlp import MLP

from examples.utils import *

In [None]:
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint


class FlowMatching(nn.Module):
    """
    Conditional flow matching with the optimal-transport (OT) Gaussian probability path.

    The path moves N(0, I) at t=0 to a narrow Gaussian of width ``sigma_min``
    around each observation x1 at t=1:

        phi_t(x0 | x1) = sigma_t * x0 + mu_t,
        sigma_t = 1 - (1 - sigma_min) * t,   mu_t = t * x1.

    ``model`` is trained to regress the conditional velocity
    u_t(x | x1) = sigma_t' / sigma_t * (x - mu_t) + mu_t'.

    NOTE(review): this definition shadows the ``FlowMatching`` imported from
    ``src.model.flow_matching`` in the first cell.

    :param model: network called as ``model(t, x)`` with ``t`` of shape (batch,)
        and ``x`` of shape (batch, *obs_dim); expected to return a tensor shaped like ``x``.
    :param obs_dim: shape of a single observation; used by :meth:`sample` to draw noise.
    :param sigma_min: width of the conditional Gaussian at t=1.
    :param n_samples: number of Hutchinson probes per divergence estimate
        (overwritten by :meth:`logp`).
    """

    def __init__(self, model, obs_dim=(2,), sigma_min=1e-6, n_samples=10):
        super().__init__()
        self.model = model
        self.sigma_min = sigma_min
        self.obs_dim = obs_dim
        self.n_samples = n_samples

    def process_timesteps(self, t, x):
        """
        Broadcast ``t`` to one timestep per batch element.

        :param t: Python number, 0-d tensor, or tensor of shape (batch,).
        :param x: batched tensor whose first dimension is the batch size.
        :return: tensor of shape (batch,).
        :raises ValueError: if ``t`` is neither a scalar nor of shape (batch,).
        """
        if not torch.is_tensor(t):
            # Accept plain Python numbers; match x's dtype/device.
            t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        if t.dim() == 0:
            t = t.repeat(x.shape[0])
        if t.dim() != 1 or t.shape[0] != x.shape[0]:
            raise ValueError("Timesteps shape should (batch_size, )")
        return t

    def forward(self, t, x0, x1):
        """
        Model velocity at the point phi_t(x0 | x1) on the conditional path.

        Computes v from the equation dphi(t, x) = v(t, phi(t, x)) dt.
        """
        t = self.process_timesteps(t, x0)
        x = self.conditional_flow(t, x0, x1)
        return self.model(t, x)

    def velocity(self, t, x0):
        """Model velocity v(t, x0); a scalar ``t`` is broadcast over the batch."""
        t = self.process_timesteps(t, x0)
        return self.model(t, x0)

    def reversed_velocity_with_div(self, t, state):
        """
        ODE right-hand side used by :meth:`logp`.

        Integrating over t in [0, 1] runs the learned flow backwards in model
        time s = 1 - t (from data at s=1 to noise at s=0) while accumulating
        the divergence of the velocity field.

        :param t: integration time (0-d tensor, as passed by the ODE solver).
        :param state: tuple ``(x, logp)``; ``x`` is (batch, *obs_dim) and
            ``logp`` is the running divergence integral with ``batch`` elements.
        :return: ``(-v(s, x), div v(s, x))``. The velocity is detached; the
            divergence is a Hutchinson estimate averaged over ``self.n_samples``
            probes and reshaped to ``logp``'s shape.
        """
        s = 1 - t
        x, logp = state
        # Fresh leaf so we can differentiate w.r.t. x; the solver may call this
        # under no_grad, hence set_grad_enabled(True) below.
        x_ = x.detach().clone().requires_grad_(True)
        n_probes = max(int(self.n_samples), 1)
        with torch.set_grad_enabled(True):
            # Evaluate the network once (assumed deterministic in (s, x)) and reuse
            # its graph for every probe; free it only on the last probe.
            v = self.velocity(s, x_)
            div_estimates = [
                self.approx_div(v, x_, retain_graph=(i < n_probes - 1))
                for i in range(n_probes)
            ]
        mean_div = torch.stack(div_estimates).mean(dim=0)
        return (-v).detach(), mean_div.reshape(logp.shape)

    def sigma(self, t, x1=None):
        """Path standard deviation sigma_t = 1 - (1 - sigma_min) * t."""
        return (1 - (1-self.sigma_min)*t)

    def dsigma_dt(self, t, x1):
        """Time derivative of sigma_t (a constant for the OT path)."""
        return - (1-self.sigma_min)

    def mu(self, t, x1):
        """Path mean mu_t = t * x1."""
        return t*x1

    def dmu_dt(self, t, x1):
        """Time derivative of mu_t, i.e. x1."""
        return x1

    def conditional_flow(self, t, x, x1):
        """
        Compute phi_t(x | x1) = sigma_t * x + mu_t(x1), so that phi_0(x) = x = x0.

        :param t: timesteps of shape (batch,); reshaped to broadcast over feature dims.
        :param x: starting point x0, drawn from N(0, I) by the callers in this notebook.
        :param x1: observation.
        """
        dims = [1] * (x.dim() - 1)
        t = t.view(-1, *dims)
        return self.sigma(t, x1)*x + self.mu(t, x1)

    def conditional_velocity(self, t, x, x1, eps=1e-7):
        """
        OT conditional velocity u_t(x | x1) = sigma_t' / sigma_t * (x - mu_t) + mu_t'.

        :param t: timesteps of shape (batch,).
        :param x: point on the path, shaped like ``x1``.
        :param x1: observation.
        :param eps: lower bound applied to sigma_t in the denominator; it only
            matters when ``sigma_min`` is 0 and t reaches 1.
        """
        dims = [1] * (x.dim() - 1)
        t = t.view(-1, *dims)
        # sigma' / sigma (not sigma / sigma'): for x = phi_t(x0) this yields the
        # constant OT target x1 - (1 - sigma_min) * x0.
        sigma_t = self.sigma(t, x1).clamp_min(eps)
        return self.dsigma_dt(t, x1) / sigma_t * (x - self.mu(t, x1)) + self.dmu_dt(t, x1)

    def target_velocity(self, t, x, x1):
        """Regression target: conditional velocity evaluated at phi_t(x | x1)."""
        return self.conditional_velocity(t, self.conditional_flow(t, x, x1), x1)

    def criterion(self, t, x0, x1):
        """
        Conditional flow-matching loss.

        :return: batch mean of the squared L2 error between model and target
            velocity, summed over all non-batch dimensions.
        """
        t = self.process_timesteps(t, x0)
        v = self.forward(t, x0, x1)
        target = self.target_velocity(t, x0, x1)
        feature_dims = tuple(range(1, x0.dim()))
        return (v - target).pow(2).sum(dim=feature_dims).mean()

    def sample(self, n_samples, method='dopri5', rtol=1e-5, atol=1e-5):
        """
        Draw samples by integrating the learned velocity from noise (t=0) to data (t=1).

        :param n_samples: number of samples to draw (unrelated to ``self.n_samples``).
        :return: tensor of shape (n_samples, *obs_dim) on the model's device.
        """
        self.device = next(self.parameters()).device
        x0 = torch.randn([n_samples] + list(self.obs_dim), device=self.device)
        t = torch.linspace(0, 1, 2, device=self.device)
        with torch.no_grad():
            # velocity() broadcasts the solver's scalar time to (batch,), as in
            # training. odeint_adjoint needs explicit adjoint_params for a non-Module
            # func; sampling never backpropagates, so none are required.
            trajectory = odeint(
                self.velocity, x0, t,
                rtol=rtol, atol=atol, method=method,
                adjoint_params=(),
            )
        return trajectory[-1]

    def approx_div(self, f_x, x, retain_graph=True):
        """
        Hutchinson estimate of div f = tr(df/dx), one Rademacher probe per call.

        :param f_x: output of a function of ``x`` (same shape as ``x``).
        :param x: input tensor with ``requires_grad=True``.
        :param retain_graph: keep the graph of ``f_x`` for further probes.
        :return: tensor of shape (batch,) with z^T (df/dx) z for each batch element.
        """
        # Rademacher probe z in {-1, +1}, created directly with x's dtype/device.
        z = torch.randint_like(x, low=0, high=2) * 2 - 1
        vjp = torch.autograd.grad(f_x, x, z, create_graph=True, retain_graph=retain_graph)[0]
        return (vjp * z).reshape(z.shape[0], -1).sum(dim=1)

    def logp(self, x1, n_samples=50, rtol=1e-05, atol=1e-05):
        """
        Estimate log p_1(x1) by integrating the flow backwards to the noise prior.

        log p_1(x1) = log N(phi_0; 0, I) - integral_0^1 div v(s, x(s)) ds.

        :param x1: batch of observations, shape (batch, *obs_dim).
        :param n_samples: Hutchinson probes per divergence evaluation (stored in
            ``self.n_samples``).
        :return: CPU tensor of shape (batch,), detached from the graph.
        """
        self.device = next(self.parameters()).device
        self.n_samples = n_samples
        t = torch.linspace(0, 1, 2, device=self.device)
        # Divergence accumulator on the same device/dtype as the data.
        div_integral0 = torch.zeros((x1.shape[0], 1), device=x1.device, dtype=x1.dtype)
        phi, f = odeint(
            self.reversed_velocity_with_div,
            (x1, div_integral0),
            t,
            rtol=rtol,
            atol=atol,
            adjoint_params=tuple(self.model.parameters()),
        )
        phi = phi[-1].detach().cpu().flatten(1)
        f = f[-1].detach().cpu().flatten()
        n_dims = phi.shape[1]
        logp_noise = -0.5 * (phi.pow(2).sum(1) + n_dims * torch.log(torch.tensor(2 * torch.pi)))
        return logp_noise - f

    def __str__(self):
        """
        Model prints with the number of parameters.
        """
        all_parameters = sum([p.numel() for p in self.model.parameters()])
        trainable_parameters = sum(
            [p.numel() for p in self.model.parameters() if p.requires_grad]
        )

        result_info = super().__str__()
        result_info = result_info + f"\nAll parameters: {all_parameters}"
        result_info = result_info + f"\nTrainable parameters: {trainable_parameters}"

        return result_info
        

In [10]:
def _gmm_parameters(d):
    """Return ``(weights, means, covs)`` of the two-component demo mixture, truncated to the first ``d`` dimensions."""
    weights = [0.4, 0.6]
    means = [
        np.array([0.0, 0.0, 0.0])[:d],
        np.array([5.0, -10.0, 1.0])[:d],
    ]
    covs = [
        np.array([[0.5, 0.1, 0.0],
                  [0.1, 0.3, 0.0],
                  [0.0, 0.0, 0.2]])[:d, :d],
        np.array([[1.0, 0.2, 0.1],
                  [0.2, 1.0, 0.3],
                  [0.1, 0.3, 1.0]])[:d, :d],
    ]
    return weights, means, covs


def _train_flow_matching(fm, train_loader, n_epochs, lr, device, log_every=100):
    """Train ``fm`` with AdamW on the flow-matching loss and return the mean loss of each epoch."""
    optimizer = torch.optim.AdamW(fm.parameters(), lr=lr)
    epoch_losses = []
    for epoch in tqdm(range(n_epochs)):
        running_loss = 0.0
        for x1 in train_loader:
            x1 = x1.to(device)
            # One random time and one noise endpoint per data point.
            t = torch.rand(x1.shape[0], device=device)
            x0 = torch.randn_like(x1)
            loss = fm.criterion(t, x0, x1)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_losses.append(running_loss / len(train_loader))
        if epoch % log_every == 0:
            print(f"Total Loss Epoch {epoch+1}: ", epoch_losses[-1])
    return epoch_losses


def _plot_sample_comparison(train_data, samples, d):
    """Scatter ground-truth points against generated samples (3-D scatter when d > 2)."""
    # Convert to numpy on the CPU first: pandas cannot consume CUDA tensors.
    df_true = pd.DataFrame(train_data.detach().cpu().numpy()).assign(sample='ground truth')
    df_fm = pd.DataFrame(samples.detach().cpu().numpy()).assign(sample='flow matching')
    df_all = pd.concat([df_true, df_fm], axis=0)
    title = "Comparison: True Data vs Flow Matching"
    if d > 2:
        fig = px.scatter_3d(df_all, x=0, y=1, z=2,
                            color='sample',
                            size_max=1, opacity=1,
                            title=title)
    else:
        fig = px.scatter(df_all, x=0, y=1,
                         color='sample',
                         size_max=1, opacity=1,
                         title=title)
    fig.update_traces(marker_size=2)
    fig.show()


def main(d=2, device='cpu', seed=None):
    """
    Fit flow matching to a two-component Gaussian mixture and evaluate the fit.

    Samples a training set, trains ``FlowMatching(MLP(d, 128, 5))``, then:
    - scatter-plots generated samples against the data;
    - histograms |model logp - true mixture logp| on 256 training points;
    - for d == 2, shows the interactive density plot.

    :param d: data dimension; must be 2 or 3 (the mixture is defined in 3-D and
        the plots need at least two columns).
    :param device: torch device string or ``torch.device``.
    :param seed: optional seed for numpy's global RNG and torch; ``None`` leaves
        both untouched.
    :raises ValueError: if ``d`` is not 2 or 3.
    """
    if d not in (2, 3):
        raise ValueError(f"d must be 2 or 3, got {d}")
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)
    device = torch.device(device)

    weights, means, covs = _gmm_parameters(d)

    train_data_length = 1024
    train_data = torch.from_numpy(
        mixture_samples(train_data_length, d, weights, means, covs)
    ).to(torch.float)
    train_loader = DataLoader(train_data, batch_size=128, shuffle=True)

    fm = FlowMatching(MLP(d, 128, 5), (d,)).to(device)
    _train_flow_matching(fm, train_loader, n_epochs=500, lr=0.001, device=device)
    fm.eval()

    samples_fm = fm.sample(train_data_length)
    _plot_sample_comparison(train_data, samples_fm, d)

    # Compare the model's log-density with the exact mixture log-density.
    x1 = train_data[:256].to(device)
    logp = fm.logp(x1, n_samples=50, atol=1e-4, rtol=1e-5)
    logp_true = compute_gmm_logp(x1, weights, means, covs)

    sns.histplot(abs(logp.numpy() - logp_true))
    plt.title('Absolute error in logp')
    plt.show()

    if d == 2:
        plot_density_interactive(fm, weights, means, covs, device=device, range_lim=15, n_grid=32)

if __name__ == '__main__':
    # Demo configuration: a 2-D mixture trained on the CPU.
    d, device = 2, 'cpu'
    main(d, device)


  2%|▏         | 8/500 [00:00<00:06, 77.28it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

  6%|▋         | 32/500 [00:00<00:04, 108.90it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 11%|█▏        | 57/500 [00:00<00:03, 112.41it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 16%|█▌        | 81/500 [00:00<00:03, 114.79it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 21%|██        | 105/500 [00:00<00:03, 117.20it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 26%|██▌       | 129/500 [00:01<00:03, 114.85it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 31%|███       | 153/500 [00:01<00:02, 116.78it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 36%|███▌      | 178/500 [00:01<00:02, 118.53it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 41%|████      | 203/500 [00:01<00:02, 119.20it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 46%|████▌     | 228/500 [00:01<00:02, 119.64it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 50%|█████     | 252/500 [00:02<00:02, 116.07it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 55%|█████▌    | 277/500 [00:02<00:01, 118.16it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 60%|██████    | 301/500 [00:02<00:01, 118.84it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 65%|██████▌   | 326/500 [00:02<00:01, 119.52it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 70%|███████   | 351/500 [00:03<00:01, 119.92it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 75%|███████▌  | 375/500 [00:03<00:01, 116.34it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 77%|███████▋  | 387/500 [00:03<00:00, 117.12it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 82%|████████▏ | 412/500 [00:03<00:00, 118.54it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 87%|████████▋ | 437/500 [00:03<00:00, 119.25it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 92%|█████████▏| 462/500 [00:03<00:00, 115.96it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

 97%|█████████▋| 486/500 [00:04<00:00, 117.79it/s]

t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

100%|██████████| 500/500 [00:04<00:00, 116.95it/s]


t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Size([128, 2])
t torch.Size([128, 1])
x torch.Siz

KeyboardInterrupt: 