In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, Dataset
import numpy as np
from torch_cubic_spline_grids import CubicBSplineGrid2d
import tqdm

In [None]:
class HistDataset(Dataset):
    """Map-style dataset over the bins of a 2-D histogram built from two data columns.

    Item ``k`` is bin ``(k // y_len, k % y_len)`` (row-major order) and is returned as
    ``(coords, count)``:

    * ``coords`` -- float32 tensor ``[x_center, y_center]``, the centre of that bin;
    * ``count``  -- float32 scalar tensor, the number of samples that fell into the bin.

    NOTE(review): coordinates are in the units of the data columns and are NOT rescaled
    to [0, 1]; confirm this matches the input domain the downstream model expects.
    """

    def __init__(self, particle_dataset, x_index, y_index, bins):
        """Histogram two columns of ``particle_dataset.data``.

        Args:
            particle_dataset: object exposing a 2-D array-like ``.data`` (rows are samples).
            x_index: column of ``.data`` used for the first histogram axis.
            y_index: column of ``.data`` used for the second histogram axis.
            bins: forwarded unchanged to ``torch.histogramdd`` (e.g. bins per axis).
        """
        super().__init__()
        coords = torch.as_tensor(particle_dataset.data[:, [x_index, y_index]])
        if not coords.is_floating_point():
            # torch.histogramdd only accepts floating-point input.
            coords = coords.to(torch.float32)
        self.hist = torch.histogramdd(coords, bins=bins)
        counts, (x_edges, y_edges) = self.hist
        self.x_len, self.y_len = counts.shape
        # float32 counts match the float32 model output in the loss (MPS has no float64).
        self._counts = counts.to(torch.float32)
        # Represent each bin by its centre rather than its left edge, so the sample
        # coordinate lies in the middle of the region whose count it carries.
        self._x_centers = ((x_edges[:-1] + x_edges[1:]) / 2).to(torch.float32)
        self._y_centers = ((y_edges[:-1] + y_edges[1:]) / 2).to(torch.float32)

    def __getitem__(self, item):
        """Return ``(tensor([x_center, y_center]), count)`` for flat bin index ``item``.

        Raises:
            IndexError: if ``item`` is outside ``[-len(self), len(self))``.
        """
        n_bins = len(self)
        if item < 0:
            item += n_bins
        if not 0 <= item < n_bins:
            raise IndexError(f"bin index {item} out of range for {n_bins} bins")
        x, y = divmod(int(item), self.y_len)
        coords = torch.stack((self._x_centers[x], self._y_centers[y]))
        return coords, self._counts[x, y]

    def __len__(self):
        """Total number of histogram bins (``x_len * y_len``)."""
        return self.x_len * self.y_len

In [None]:
# One 400x400 histogram dataset + loader per pair of data columns.
# NOTE(review): `dataset` is not defined in this part of the notebook; it must come
# from an earlier cell (restart-and-run-all will fail without it).
HIST_BINS = 400
BATCH_SIZE = 2**6

xy_dataset, et_dataset, rth_dataset = (
    HistDataset(dataset, col_a, col_b, bins=HIST_BINS)
    for col_a, col_b in ((0, 1), (5, 6), (7, 8))
)
xy_DL, et_DL, rth_DL = (
    DataLoader(hist_ds, batch_size=BATCH_SIZE)
    for hist_ds in (xy_dataset, et_dataset, rth_dataset)
)

In [None]:
# One cubic B-spline grid per histogram; the resolution equals the 400 bins per axis
# used for the histograms.
# NOTE(review): with one control point per bin the fit may follow per-bin noise;
# consider a coarser resolution if a smooth density is wanted.
GRID_RESOLUTION = (400, 400)
net_xy, net_et, net_rth = (
    CubicBSplineGrid2d(resolution=GRID_RESOLUTION) for _ in range(3)
)

In [None]:
# An independent Adam optimizer for each spline grid, all with the same learning rate.
LEARNING_RATE = 0.01
optimizer_xy, optimizer_et, optimizer_rth = (
    optim.Adam(grid.parameters(), lr=LEARNING_RATE)
    for grid in (net_xy, net_et, net_rth)
)

In [None]:
def splineTrain(DL, optimizer, net, num_epochs, device=None):
    """Fit ``net`` to the ``(inputs, targets)`` batches of ``DL`` with an MSE loss.

    Args:
        DL: iterable of batches whose first element is passed to ``net`` and whose
            second element is the target compared against ``net(inputs)[:, 0]``
            (e.g. a DataLoader over ``HistDataset``).
        optimizer: optimizer already constructed over ``net.parameters()``.
        net: module to train; it is moved to ``device`` in place and left there.
        num_epochs: number of full passes over ``DL``.
        device: device to train on. Defaults to Apple MPS when available, else CPU.

    Returns:
        list of float: the mean batch loss of each epoch.

    Raises:
        ValueError: if ``DL`` yields no batches.
    """
    try:
        from tqdm.auto import tqdm as progress
    except ImportError:  # progress bars are cosmetic; train without them if tqdm is absent
        def progress(iterable, **_kwargs):
            return iterable

    if device is None:
        # Replaces the implicit global `mps_device`, which this cell never defined.
        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    loss_func = nn.MSELoss()
    net.to(device)
    net.train()
    losses = []
    print("Starting Training Loop")
    for _ in progress(range(num_epochs), desc="epochs"):
        total_loss = 0.0
        n_batches = 0
        for batch in DL:
            # Batches must live on the same device as the net. Cast to float32 because
            # float64 histogram counts break MSELoss against float32 predictions, and
            # MPS has no float64 support.
            inputs = batch[0].to(device=device, dtype=torch.float32)
            targets = batch[1].to(device=device, dtype=torch.float32)
            optimizer.zero_grad()
            pred = net(inputs)
            loss = loss_func(pred[:, 0], targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        if n_batches == 0:
            raise ValueError("DL yielded no batches; cannot compute an epoch loss")
        losses.append(total_loss / n_batches)
    return losses

In [None]:
# Fit each spline grid to its histogram; splineTrain returns the mean batch MSE per epoch.
num_epochs = 100
xy_losses = splineTrain(xy_DL, optimizer_xy, net_xy, num_epochs)
et_losses = splineTrain(et_DL, optimizer_et, net_et, num_epochs)
rth_losses = splineTrain(rth_DL, optimizer_rth, net_rth, num_epochs)

# One titled, labelled figure per fit so each curve can be read on its own.
for name, curve in (("xy", xy_losses), ("et", et_losses), ("rth", rth_losses)):
    fig, ax = plt.subplots()
    ax.plot(curve)
    ax.set(title=f"net_{name} training loss", xlabel="epoch", ylabel="mean batch MSE")
plt.show()

In [None]:
# Evaluate each fitted spline on a regular grid over [0, 1]^2 and show log10 of its output.
n_grid = 200
x = torch.linspace(0, 1, n_grid)
y = torch.linspace(0, 1, n_grid)
net_xy.eval()
net_et.eval()
net_rth.eval()
# np.meshgrid's default "xy" indexing gives X[i, j] == x[j] and Y[i, j] == y[i].
# Every prediction array below uses that same layout, so pcolormesh(X, Y, R) is not
# transposed (the previous R[i, j] = net(x[i], y[j]) was).
X, Y = np.meshgrid(x.numpy(), y.numpy())


def _predict_on_grid(net, X, Y):
    """Evaluate ``net`` at every point (X[i, j], Y[i, j]) in one batched, no-grad call.

    Returns a float64 NumPy array with the same shape as ``X``.
    """
    # The net may have been moved to another device (e.g. MPS) during training.
    device = next(net.parameters()).device
    points = torch.from_numpy(np.stack((X.ravel(), Y.ravel()), axis=1))
    points = points.to(device=device, dtype=torch.float32)
    with torch.no_grad():
        pred = net(points)
    return pred.reshape(X.shape).cpu().numpy().astype(np.float64)


Rxy = _predict_on_grid(net_xy, X, Y)
Ret = _predict_on_grid(net_et, X, Y)
Rrth = _predict_on_grid(net_rth, X, Y)

for name, R in (("xy", Rxy), ("et", Ret), ("rth", Rrth)):
    fig, ax = plt.subplots()
    # np.ma.log10 masks non-positive predictions (drawn blank) instead of producing -inf/NaN.
    mesh = ax.pcolormesh(X, Y, np.ma.log10(R))
    fig.colorbar(mesh, ax=ax, label="log10(prediction)")
    ax.set(
        title=f"net_{name} prediction",
        xlabel="first input coordinate",
        ylabel="second input coordinate",
    )
plt.show()