In [None]:
import torch
from spherebsde import nn, utils
# Compute device for all tensors and the model. Fall back to CPU when CUDA is
# unavailable so the notebook still runs (slowly) instead of failing at the
# first `.to(device)`.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# TODO(review): no RNG seed is set anywhere in the visible notebook; consider
# torch.manual_seed(...) here if runs need to be reproducible.

In [2]:
# Admissible range of r: lamb and mu are masked to zero outside [rmin, rmax].
rmin, rmax = 10, 100
# Radius of the ball centred at x0 that constrains y (see xb), also used in
# the boundary angle arccos(r * 8 / 1000 / (2 R)).
R = 0.5
# Discretisation size; keep in sync with the N argument of nn.MKCapture.
N = 100
# Capture window used by `capture`: tip distance in [0, rtip] and lateral
# distance from the x-axis line at most a.
rtip, a = 0.05, 0.05
# Rate parameters: c1 scales lamb, c2 scales mu; both weights peak at k = kbar.
c1, c2 = 5, 0.1
kbar = 4
# Centre of the radius-R ball, as a float32 tensor on the compute device.
x0 = torch.tensor([0., 0., R], dtype=torch.float32, device=device)

def u0(r,x,y):
    """Function passed to nn.MKCapture as ``u0``.

    With ratio = r * 8 / 1000 / (2 R) and theta_b = arccos(ratio), returns
    1 / (1 - ratio) where all of the following hold, and 0 otherwise:
      * the polar angle of x is below theta_b,
      * r <= 50,
      * the last coordinate of y exceeds R.
    The result is a float tensor with a trailing dimension of 1, because r is
    unsqueezed into a column.

    Assumes r is a 1-D tensor and that column 0 of ``utils.polar_corr(x)`` is
    the polar angle — TODO confirm. arccos yields NaN if ratio > 1.
    NOTE(review): 8/1000 looks like a per-unit length conversion for r; confirm.
    """
    theta = utils.polar_corr(x)[:, :1]
    r_col = r.unsqueeze(1)
    ratio = r_col * 8 / 1000 / 2 / R
    theta_b = torch.arccos(ratio)

    inside_cap = theta < theta_b
    short_enough = r_col <= 50
    y_above = y[:, -1:] > R
    return inside_cap / (1 - ratio) * short_enough * y_above

def xb(r,x,y):
    """Boundary points, unit normals and inside-flags for x and y.

    For x: the boundary is the polar angle theta_b = arccos(r * 8 / 1000 / (2 R)).
      * inside_x   -- polar angle of x is below theta_b (bool, one per row)
      * x_bd       -- x with its polar angle replaced by theta_b, mapped back
                      through ``utils.transform_x``
      * n_x        -- normalised (x*z, y*z, -(x^2 + y^2)) evaluated at x_bd;
                      assuming x_bd lies on the unit sphere, this points in the
                      direction of increasing polar angle (TODO confirm)
    For y: the boundary is the sphere of radius R centred at x0.
      * inside_y   -- |y - x0| < R (bool, one per row)
      * y_bd       -- radial projection of y onto that sphere
      * n_y        -- unit outward radial normal at y_bd

    Assumes r is 1-D and column 0 of ``utils.polar_corr(x)`` is the polar angle
    — TODO confirm.

    Returns:
        (x_bd, n_x, inside_x, y_bd, n_y, inside_y)
    """
    polar = utils.polar_corr(x)
    theta_b = torch.arccos(r * 8 / 1000 / 2 / R)
    inside_x = polar[:, 0] < theta_b

    # Move every x onto the boundary angle, keeping its other polar coordinates.
    boundary_polar = polar.clone()
    boundary_polar[:, 0] = theta_b
    x_bd = utils.transform_x(boundary_polar)

    z_part = x_bd[:, 2:]
    lateral_sq = x_bd[:, :2].norm(dim=1, keepdim=True) ** 2
    n_x = torch.cat([x_bd[:, :1] * z_part, x_bd[:, 1:2] * z_part, -lateral_sq], dim=1)
    n_x = n_x / n_x.norm(dim=1, keepdim=True)

    offset = y - x0
    inside_y = offset.norm(dim=1) < R
    y_bd = offset / offset.norm(dim=1, keepdim=True) * R + x0
    n_y = y_bd - x0
    n_y = n_y / n_y.norm(dim=1, keepdim=True)
    return x_bd, n_x, inside_x, y_bd, n_y, inside_y

def capture(r,x,y):
    """Per-row boolean capture test.

    Let gamma be the angle between x and y and L = r * 8 / 1000. Row i is
    captured when
      * tip = L - |y| cos(gamma) lies in [0, rtip], and
      * the distance of y from the line along x, |y| sin(gamma), is at most a.

    Assumes each row of x has unit norm, as in data_gen: the cosine below
    divides by |y| only. Rows with y = 0 give NaN and are reported as not
    captured.

    Returns:
        Bool tensor with one entry per row.
    """
    ynorm = y.norm(dim=1)
    # Rounding can push |cos| slightly above 1 when y is (anti)parallel to x.
    # arccos would then return NaN and every comparison below would silently
    # be False, so clamp into arccos's domain. The clamp still propagates NaN.
    cos_gamma = ((x * y).sum(dim=1) / ynorm).clamp(-1.0, 1.0)
    gamm = torch.arccos(cos_gamma)
    length = r * 8 / 1000
    tip = length - ynorm * torch.cos(gamm)
    lateral = ynorm * torch.sin(gamm)
    return (tip >= 0) & (tip <= rtip) & (lateral <= a)

def data_gen(batch):
    """Generate a batch of identical starting states (r, x, y).

    Every row gets r = 40 (randint's upper bound 41 is exclusive), x = (0, 0, 1)
    and y = (0, 0, 0.6). All tensors live on ``device``, and x and y are
    independent, writable (batch, 3) copies.
    """
    # Kept as randint, not a constant fill, so the CPU RNG stream is consumed
    # exactly as before.
    r = torch.randint(40,41,[batch]).to(device)
    x_start = torch.tensor([0,0,1.],device=device)
    y_start = torch.tensor([0,0,0.6],device=device)
    x = x_start.repeat(batch, 1)
    y = y_start.repeat(batch, 1)
    return r, x, y

def lamb(r,k,c1,kbar):
    """Jump intensity c1 * exp(-|k - kbar|), masked by admissibility.

    Non-zero only when rmin <= r <= rmax and 0 < k <= rmax - r, so r + k never
    exceeds rmax. Presumably the intensity of an upward jump of size k —
    confirm against nn.MKCapture.
    """
    weight = torch.exp(-(k - kbar).abs())
    admissible = (r >= rmin) * (r <= rmax) * (k > 0) * (k <= (rmax - r))
    return c1 * weight * admissible

def mu(r,k,c2,kbar):
    """Jump intensity c2 * r * exp(-|k - kbar|), masked by admissibility.

    Non-zero only when rmin <= r <= rmax and 0 < k <= r - rmin, so r - k never
    drops below rmin. Presumably the intensity of a downward jump of size k —
    confirm against nn.MKCapture.
    """
    weight = torch.exp(-(k - kbar).abs())
    admissible = (r >= rmin) * (r <= rmax) * (k > 0) * (k <= (r - rmin))
    return c2 * r * weight * admissible

# Assemble the capture model. The rate functions are bound to the configured
# c1, c2 and kbar through lambdas, so MKCapture calls them with (r, k) only.
spherebsde = nn.MKCapture(
    rmin = rmin,
    rmax = rmax,
    lamb = lambda r,k:lamb(r,k,c1,kbar),
    mu = lambda r,k:mu(r,k,c2,kbar),
    u0 = u0,
    # NOTE(review): t presumably the time horizon and Dx/Dy diffusion
    # coefficients for x and y — confirm against nn.MKCapture.
    t = torch.tensor(0.1).to(device),
    Dx = torch.tensor(50.).to(device),
    Dy = torch.tensor(1.).to(device),
    data_gen = data_gen,
    # Use the configured N instead of a duplicated literal, so that changing N
    # in the config cell actually reaches the model.
    N = N,
    xb = xb,
    capture = capture,
    # NOTE(review): mc_size and P are unexplained constants; confirm their
    # meaning in nn.MKCapture and consider moving them to the config cell.
    mc_size = 10**4,
    P = 128
).to(device)

In [3]:
# Hyper-parameters consumed by nn.train.
train_params = {
    'epoch': 10_000,  # iteration count (the recorded run shows 10000/10000)
    'batch': 512,     # presumably samples drawn from data_gen per step
    'lr': 1e-3,       # presumably the optimiser learning rate
}

# NOTE(review): the positional True is presumably a verbose/progress flag (the
# recorded output shows a progress bar); confirm against nn.train and pass it
# by keyword once its name is known.
loss_values, res_values = nn.train(spherebsde, train_params, True)

10000/10000|##################################################|17319.06s  [Loss: 9.111197e-02, Result: 0.811610] 
Training has been completed.
