In [13]:
# 1) Defino la red (el Flujo)
from pathlib import Path
from torch.utils.data import Dataset
import torch
from torch import nn, Tensor
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons

class Flow(nn.Module):
    """MLP f(x_t, t) whose output ``step`` integrates as dx/dt.

    Args:
        dim: Width of the state ``x_t`` and of the network output.
        h: Width of each of the three hidden layers.
    """

    def __init__(self, dim: int = 2, h: int = 64):
        super().__init__()
        # The time column is concatenated to the state, hence ``dim + 1`` inputs.
        self.net = nn.Sequential(
            nn.Linear(dim + 1, h),
            nn.ELU(),
            nn.Linear(h, h),
            nn.ELU(),
            nn.Linear(h, h),
            nn.ELU(),
            nn.Linear(h, dim),
        )

    def forward(self, x_t: Tensor, t: Tensor) -> Tensor:
        """Run the MLP on ``[t, x_t]`` joined along the last axis."""
        features = torch.cat([t, x_t], dim=-1)
        return self.net(features)

    def step(self, x_t: Tensor, t_start: Tensor, t_end: Tensor) -> Tensor:
        """Advance ``x_t`` from ``t_start`` to ``t_end`` with one midpoint-method step.

        Used when sampling/evaluating, not during training. ``t_start`` goes
        through ``view(1, 1)``, so it must contain exactly one element; it is
        then broadcast to one row per sample of ``x_t``.
        """
        rows = x_t.shape[0]
        t0 = t_start.view(1, 1).expand(rows, 1)
        dt = t_end - t0
        # Half step with the slope at the start, then a full step with the
        # slope evaluated at the midpoint.
        x_mid = x_t + self(x_t, t0) * dt / 2
        t_mid = t0 + dt / 2
        return x_t + dt * self(x_mid, t_mid)

In [14]:

class WaveDataset(Dataset):
    """Random point samples ``([x, t], u(x, t))`` from one precomputed wave solution.

    The file at ``path`` is read with ``torch.load`` and must be a dict with
    keys ``'u'``, ``'x'``, ``'t'``, ``'f_x'`` and ``'k'``. ``u`` must be 2-D and
    is indexed as ``u[t_index, x_index]``.

    Args:
        path: Location of the saved solution.
        num_samples: Number of (t, x) grid points drawn, independently and with
            replacement.
        seed: Optional seed for the index draw. ``None`` (default) uses the
            global torch RNG; an int makes the sampled points reproducible
            without touching the global RNG state.
    """

    def __init__(self, path: Path, num_samples: int = 1000, seed: "int | None" = None):
        # NOTE(review): torch.load unpickles the file; only load trusted files.
        data = torch.load(path)
        self.u = data['u']              # expected (Nt+1, Nx+1)
        self.x = data['x']              # expected (Nx+1,)
        self.t = data['t']              # expected (Nt+1,)
        self.f_x = data['f_x']          # initial condition f(x); not used below
        self.k = data['k']              # parameter k (frequency); not used below

        # Despite the names these hold the *number* of time/space points
        # (Nt+1, Nx+1), i.e. the exclusive upper bound of the indices.
        self.Nt, self.Nx = self.u.shape
        self.num_samples = num_samples

        # generator=None falls back to the global RNG.
        generator = None if seed is None else torch.Generator().manual_seed(seed)
        self.t_idx = torch.randint(0, self.Nt, (num_samples,), generator=generator)
        self.x_idx = torch.randint(0, self.Nx, (num_samples,), generator=generator)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(inputs, u)`` for sample ``idx``.

        ``inputs`` is a float32 tensor ``[x, t]`` of shape (2,). ``u`` is a 0-d
        float32 tensor, so it has the same dtype as a float32 model's output
        regardless of the dtype ``u`` was saved with.
        """
        t_i = self.t_idx[idx]
        x_i = self.x_idx[idx]

        x_val = self.x[x_i]
        t_val = self.t[t_i]
        u_val = self.u[t_i, x_i]

        # float() accepts 0-d tensors as well as numpy / Python scalars.
        inputs = torch.tensor([float(x_val), float(t_val)], dtype=torch.float32)
        return inputs, torch.as_tensor(u_val, dtype=torch.float32)


In [18]:
wave_path='wave_solutions/u_wave_k1.pt'
model = Flow()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
dataset = WaveDataset(wave_path)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-2)


  data = torch.load(path)


In [None]:
for epoch in range(10):
    total_loss = 0
    for x, t_, s, u_s, v_target in dataloader:
        x = x.view(-1, 1)
        t_ = t_.view(-1, 1)
        s = s.view(-1, 1)
        u_s = u_s.view(-1, 1)
        v_target = v_target.view(-1, 1)

        pred = model(x, t_, s, u_s)
        loss = F.mse_loss(pred, v_target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")