Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multibunch wakefieldpass #493

Merged
merged 23 commits into from
Sep 29, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 26 additions & 103 deletions atintegrators/WakeFieldPass.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ struct elem
int nslice;
int nelem;
int nturns;
double circumference;
double *normfact;
double *waketableT;
double *waketableDX;
Expand All @@ -25,7 +24,8 @@ struct elem
};


void WakeFieldPass(double *r_in,int num_particles,double circumference,struct elem *Elem) {
void WakeFieldPass(double *r_in,int num_particles,double circumference,int nbunch,
double *bunch_spos,double *bunch_currents,struct elem *Elem) {
/*
* r_in - 6-by-N matrix of initial conditions reshaped into
* 1-d array of 6*N elements
Expand All @@ -43,7 +43,7 @@ void WakeFieldPass(double *r_in,int num_particles,double circumference,struct el
double *turnhistory = Elem->turnhistory;
double *z_cuts = Elem->z_cuts;

size_t sz = 5*nslice*sizeof(double) + num_particles*sizeof(int);
size_t sz = 5*nslice*nbunch*sizeof(double) + num_particles*sizeof(int);
int c;

int *pslice;
Expand All @@ -57,19 +57,20 @@ void WakeFieldPass(double *r_in,int num_particles,double circumference,struct el
double *dptr = (double *) buffer;
int *iptr;

kx = dptr; dptr += nslice;
ky = dptr; dptr += nslice;
kx2 = dptr; dptr += nslice;
ky2 = dptr; dptr += nslice;
kz = dptr; dptr += nslice;
kx = dptr; dptr += nslice*nbunch;
ky = dptr; dptr += nslice*nbunch;
kx2 = dptr; dptr += nslice*nbunch;
ky2 = dptr; dptr += nslice*nbunch;
kz = dptr; dptr += nslice*nbunch;

iptr = (int *) dptr;
pslice = iptr; iptr += num_particles;

/*slices beam and compute kick*/
rotate_table_history(nturns,nslice,turnhistory,circumference);
slice_bunch(r_in,num_particles,nslice,nturns,turnhistory,pslice,z_cuts);
compute_kicks(nslice,nturns,nelem,turnhistory,waketableT,waketableDX,
rotate_table_history(nturns,nslice*nbunch,turnhistory,circumference);
slice_bunch(r_in,num_particles,nslice,nturns,nbunch,bunch_spos,bunch_currents,
turnhistory,pslice,z_cuts);
compute_kicks(nslice*nbunch,nturns,nelem,turnhistory,waketableT,waketableDX,
waketableDY,waketableQX,waketableQY,waketableZ,
normfact,kx,ky,kx2,ky2,kz);

Expand Down Expand Up @@ -97,8 +98,7 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem,
{
if (!Elem) {
long nslice,nelem,nturns;
int i;
double num_charges, wakefact;
double wakefact;
static double lnf[3];
double *normfact;
double *waketableT;
Expand All @@ -109,11 +109,11 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem,
double *waketableZ;
double *turnhistory;
double *z_cuts;
int i;

nslice=atGetLong(ElemData,"_nslice"); check_error();
nelem=atGetLong(ElemData,"_nelem"); check_error();
nturns=atGetLong(ElemData,"_nturns"); check_error();
num_charges=atGetDouble(ElemData,"NumParticles"); check_error();
wakefact=atGetDouble(ElemData,"_wakefact"); check_error();
waketableT=atGetDoubleArray(ElemData,"_wakeT"); check_error();
turnhistory=atGetDoubleArray(ElemData,"_turnhistory"); check_error();
Expand All @@ -131,8 +131,8 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem,
Elem->nslice=nslice;
Elem->nelem=nelem;
Elem->nturns=nturns;
for (i=0;i<3;i++){
lnf[i]=normfact[i]*num_charges*wakefact;
for(i=0;i<3;i++){
lnf[i]=normfact[i]*wakefact;
}
Elem->normfact=lnf;
Elem->waketableT=waketableT;
Expand All @@ -144,100 +144,23 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem,
Elem->turnhistory=turnhistory;
Elem->z_cuts=z_cuts;
}
WakeFieldPass(r_in,num_particles,Param->RingLength,Elem);
if(num_particles<Param->nbunch){
atError("Number of particles has to be greater or equal to the number of bunches.");
}else if (num_particles%Param->nbunch!=0){
atWarning("Number of particles not a multiple of the number of bunches: uneven bunch load.");
}
WakeFieldPass(r_in,num_particles,Param->RingLength,Param->nbunch,Param->bunch_spos,
Param->bunch_currents,Elem);
return Elem;
}

MODULE_DEF(WakeFieldPass) /* Dummy module initialisation */

#endif /*defined(MATLAB_MEX_FILE) || defined(PYAT)*/

#ifdef MATLAB_MEX_FILE

#if defined(MATLAB_MEX_FILE)
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
if (nrhs == 2) {
double *r_in;
const mxArray *ElemData = prhs[0];
int num_particles = mxGetN(prhs[1]);
int i;
struct elem El, *Elem=&El;

long nslice,nelem,nturns;
double num_charges, wakefact;
double *normfact;
double *waketableT;
double *waketableDX;
double *waketableDY;
double *waketableQX;
double *waketableQY;
double *waketableZ;
double *turnhistory;
double *z_cuts;

nslice=atGetLong(ElemData,"_nslice"); check_error();
nelem=atGetLong(ElemData,"_nelem"); check_error();
nturns=atGetLong(ElemData,"_nturns"); check_error();
num_charges=atGetDouble(ElemData,"NumParticles"); check_error();
wakefact=atGetDouble(ElemData,"_wakefact"); check_error();
waketableT=atGetDoubleArray(ElemData,"_wakeT"); check_error();
turnhistory=atGetDoubleArray(ElemData,"_turnhistory"); check_error();
normfact=atGetDoubleArray(ElemData,"NormFact"); check_error();
/*optional attributes*/
waketableDX=atGetOptionalDoubleArray(ElemData,"_wakeDX"); check_error();
waketableDY=atGetOptionalDoubleArray(ElemData,"_wakeDY"); check_error();
waketableQX=atGetOptionalDoubleArray(ElemData,"_wakeQX"); check_error();
waketableQY=atGetOptionalDoubleArray(ElemData,"_wakeQY"); check_error();
waketableZ=atGetOptionalDoubleArray(ElemData,"_wakeZ"); check_error();
z_cuts=atGetOptionalDoubleArray(ElemData,"ZCuts"); check_error();

Elem->nslice=nslice;
Elem->nelem=nelem;
Elem->nturns=nturns;
for (i=0;i<3;i++){
normfact[i]*=num_charges*wakefact;
}
Elem->normfact=normfact;
Elem->waketableT=waketableT;
Elem->waketableDX=waketableDX;
Elem->waketableDY=waketableDY;
Elem->waketableQX=waketableQX;
Elem->waketableQY=waketableQY;
Elem->waketableZ=waketableZ;
Elem->turnhistory=turnhistory;
Elem->z_cuts=z_cuts;

if (mxGetM(prhs[1]) != 6) mexErrMsgIdAndTxt("AT:WrongArg","Second argument must be a 6 x N matrix: particle array");
/* ALLOCATE memory for the output array of the same size as the input */
plhs[0] = mxDuplicateArray(prhs[1]);
r_in = mxGetDoubles(plhs[0]);
WakeFieldPass(r_in, num_particles, 0.0, Elem);
}
else if (nrhs == 0) {
/* list of required fields */
plhs[0] = mxCreateCellMatrix(8,1);
mxSetCell(plhs[0],0,mxCreateString("_nelem"));
mxSetCell(plhs[0],1,mxCreateString("_nslice"));
mxSetCell(plhs[0],2,mxCreateString("_nturns"));
mxSetCell(plhs[0],3,mxCreateString("NumParticles"));
mxSetCell(plhs[0],4,mxCreateString("_wakefact"));
mxSetCell(plhs[0],5,mxCreateString("_wakeT"));
mxSetCell(plhs[0],6,mxCreateString("_turnhistory"));
mxSetCell(plhs[0],7,mxCreateString("Normfact"));

if (nlhs>1) {
/* list of optional fields */
plhs[1] = mxCreateCellMatrix(6,1); /* No optional fields */
mxSetCell(plhs[0],0,mxCreateString("_wakeDX"));
mxSetCell(plhs[0],1,mxCreateString("_wakeDY"));
mxSetCell(plhs[0],2,mxCreateString("_wakeQX"));
mxSetCell(plhs[0],3,mxCreateString("_wakeQY"));
mxSetCell(plhs[0],4,mxCreateString("_wakeZ"));
mxSetCell(plhs[0],5,mxCreateString("ZCuts"));
}
}
else {
mexErrMsgIdAndTxt("AT:WrongArg","Needs 2 or 0 arguments");
}
atError("WakeFieldPass: mex function undefined");
}
#endif
#endif /*defined(MATLAB_MEX_FILE)*/
4 changes: 4 additions & 0 deletions atintegrators/atelem.c
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ static void *atCalloc(size_t count, size_t size)

typedef mxArray atElem;
#define check_error()
#define atError(...) mexErrMsgIdAndTxt("AT:PassError", __VA_ARGS__)
#define atWarning(...) mexWarnMsgIdAndTxt("AT:PassWarning", __VA_ARGS__)

static mxArray *get_field(const mxArray *pm, const char *fieldname)
{
Expand Down Expand Up @@ -141,6 +143,8 @@ static double* atGetOptionalDoubleArray(const mxArray *ElemData, const char *fie

typedef PyObject atElem;
#define check_error() if (PyErr_Occurred()) return NULL
#define atError(...) return (struct elem *) PyErr_Format(PyExc_ValueError, __VA_ARGS__)
#define atWarning(...) if (PyErr_WarnFormat(PyExc_RuntimeWarning, 0, __VA_ARGS__) != 0) return NULL

static int array_imported = 0;

Expand Down
Loading