Skip to content

Commit

Permalink
Multibunch wakefieldpass (#493)
Browse files Browse the repository at this point in the history
* initial commit

* move clear history to wake element

* append wake element

* import collective

* error handling

* update

* help

* c windows error

* c windows error

* zero weight slices

* utils

* bugfix: proper bunch weight

* bug

* update

* fix mpi

* bugfix: proper for weights, initialisation

* Added example files for LCBI simulation and analysis

* Added atError and atWarning C functions. WakefieldPass modified as example

* add mexfunction

* restore mexfunction

* Added TRW example files

Co-authored-by: Lee Carver <carver@slurm-nice-devel1904.esrf.fr>
Co-authored-by: Laurent Farvacque <laurent.farvacque@esrf.fr>
  • Loading branch information
3 people authored Sep 29, 2022
1 parent 555c2af commit d6b2ccc
Show file tree
Hide file tree
Showing 13 changed files with 833 additions and 112 deletions.
70 changes: 41 additions & 29 deletions atintegrators/WakeFieldPass.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ struct elem
int nslice;
int nelem;
int nturns;
double circumference;
double *normfact;
double *waketableT;
double *waketableDX;
Expand All @@ -25,7 +24,8 @@ struct elem
};


void WakeFieldPass(double *r_in,int num_particles,double circumference,struct elem *Elem) {
void WakeFieldPass(double *r_in,int num_particles,double circumference,int nbunch,
double *bunch_spos,double *bunch_currents,struct elem *Elem) {
/*
* r_in - 6-by-N matrix of initial conditions reshaped into
* 1-d array of 6*N elements
Expand All @@ -43,7 +43,7 @@ void WakeFieldPass(double *r_in,int num_particles,double circumference,struct el
double *turnhistory = Elem->turnhistory;
double *z_cuts = Elem->z_cuts;

size_t sz = 5*nslice*sizeof(double) + num_particles*sizeof(int);
size_t sz = 5*nslice*nbunch*sizeof(double) + num_particles*sizeof(int);
int c;

int *pslice;
Expand All @@ -57,19 +57,20 @@ void WakeFieldPass(double *r_in,int num_particles,double circumference,struct el
double *dptr = (double *) buffer;
int *iptr;

kx = dptr; dptr += nslice;
ky = dptr; dptr += nslice;
kx2 = dptr; dptr += nslice;
ky2 = dptr; dptr += nslice;
kz = dptr; dptr += nslice;
kx = dptr; dptr += nslice*nbunch;
ky = dptr; dptr += nslice*nbunch;
kx2 = dptr; dptr += nslice*nbunch;
ky2 = dptr; dptr += nslice*nbunch;
kz = dptr; dptr += nslice*nbunch;

iptr = (int *) dptr;
pslice = iptr; iptr += num_particles;

/*slices beam and compute kick*/
rotate_table_history(nturns,nslice,turnhistory,circumference);
slice_bunch(r_in,num_particles,nslice,nturns,turnhistory,pslice,z_cuts);
compute_kicks(nslice,nturns,nelem,turnhistory,waketableT,waketableDX,
rotate_table_history(nturns,nslice*nbunch,turnhistory,circumference);
slice_bunch(r_in,num_particles,nslice,nturns,nbunch,bunch_spos,bunch_currents,
turnhistory,pslice,z_cuts);
compute_kicks(nslice*nbunch,nturns,nelem,turnhistory,waketableT,waketableDX,
waketableDY,waketableQX,waketableQY,waketableZ,
normfact,kx,ky,kx2,ky2,kz);

Expand Down Expand Up @@ -97,8 +98,7 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem,
{
if (!Elem) {
long nslice,nelem,nturns;
int i;
double num_charges, wakefact;
double wakefact;
static double lnf[3];
double *normfact;
double *waketableT;
Expand All @@ -109,11 +109,11 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem,
double *waketableZ;
double *turnhistory;
double *z_cuts;
int i;

nslice=atGetLong(ElemData,"_nslice"); check_error();
nelem=atGetLong(ElemData,"_nelem"); check_error();
nturns=atGetLong(ElemData,"_nturns"); check_error();
num_charges=atGetDouble(ElemData,"NumParticles"); check_error();
wakefact=atGetDouble(ElemData,"_wakefact"); check_error();
waketableT=atGetDoubleArray(ElemData,"_wakeT"); check_error();
turnhistory=atGetDoubleArray(ElemData,"_turnhistory"); check_error();
Expand All @@ -131,8 +131,8 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem,
Elem->nslice=nslice;
Elem->nelem=nelem;
Elem->nturns=nturns;
for (i=0;i<3;i++){
lnf[i]=normfact[i]*num_charges*wakefact;
for(i=0;i<3;i++){
lnf[i]=normfact[i]*wakefact;
}
Elem->normfact=lnf;
Elem->waketableT=waketableT;
Expand All @@ -144,14 +144,21 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem,
Elem->turnhistory=turnhistory;
Elem->z_cuts=z_cuts;
}
WakeFieldPass(r_in,num_particles,Param->RingLength,Elem);
if(num_particles<Param->nbunch){
atError("Number of particles has to be greater or equal to the number of bunches.");
}else if (num_particles%Param->nbunch!=0){
atWarning("Number of particles not a multiple of the number of bunches: uneven bunch load.");
}
WakeFieldPass(r_in,num_particles,Param->RingLength,Param->nbunch,Param->bunch_spos,
Param->bunch_currents,Elem);
return Elem;
}

MODULE_DEF(WakeFieldPass) /* Dummy module initialisation */

#endif /*defined(MATLAB_MEX_FILE) || defined(PYAT)*/


#ifdef MATLAB_MEX_FILE

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Expand All @@ -164,7 +171,8 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
struct elem El, *Elem=&El;

long nslice,nelem,nturns;
double num_charges, wakefact;
double wakefact;
static double lnf[3];
double *normfact;
double *waketableT;
double *waketableDX;
Expand All @@ -178,7 +186,6 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
nslice=atGetLong(ElemData,"_nslice"); check_error();
nelem=atGetLong(ElemData,"_nelem"); check_error();
nturns=atGetLong(ElemData,"_nturns"); check_error();
num_charges=atGetDouble(ElemData,"NumParticles"); check_error();
wakefact=atGetDouble(ElemData,"_wakefact"); check_error();
waketableT=atGetDoubleArray(ElemData,"_wakeT"); check_error();
turnhistory=atGetDoubleArray(ElemData,"_turnhistory"); check_error();
Expand All @@ -194,10 +201,10 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Elem->nslice=nslice;
Elem->nelem=nelem;
Elem->nturns=nturns;
for (i=0;i<3;i++){
normfact[i]*=num_charges*wakefact;
for(i=0;i<3;i++){
lnf[i]=normfact[i]*wakefact;
}
Elem->normfact=normfact;
Elem->normfact=lnf;
Elem->waketableT=waketableT;
Elem->waketableDX=waketableDX;
Elem->waketableDY=waketableDY;
Expand All @@ -211,19 +218,24 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
/* ALLOCATE memory for the output array of the same size as the input */
plhs[0] = mxDuplicateArray(prhs[1]);
r_in = mxGetDoubles(plhs[0]);
WakeFieldPass(r_in, num_particles, 0.0, Elem);
double *bspos = malloc(sizeof(double));
double *bcurr = malloc(sizeof(double));
bspos[0] = 0.0;
bcurr[0] = 0.0;
WakeFieldPass(r_in,num_particles, 1, 1, bspos, bcurr, Elem);
free(bspos);
free(bcurr);
}
else if (nrhs == 0) {
/* list of required fields */
plhs[0] = mxCreateCellMatrix(8,1);
plhs[0] = mxCreateCellMatrix(7,1);
mxSetCell(plhs[0],0,mxCreateString("_nelem"));
mxSetCell(plhs[0],1,mxCreateString("_nslice"));
mxSetCell(plhs[0],2,mxCreateString("_nturns"));
mxSetCell(plhs[0],3,mxCreateString("NumParticles"));
mxSetCell(plhs[0],4,mxCreateString("_wakefact"));
mxSetCell(plhs[0],5,mxCreateString("_wakeT"));
mxSetCell(plhs[0],6,mxCreateString("_turnhistory"));
mxSetCell(plhs[0],7,mxCreateString("Normfact"));
mxSetCell(plhs[0],3,mxCreateString("_wakefact"));
mxSetCell(plhs[0],4,mxCreateString("_wakeT"));
mxSetCell(plhs[0],5,mxCreateString("_turnhistory"));
mxSetCell(plhs[0],6,mxCreateString("Normfact"));

if (nlhs>1) {
/* list of optional fields */
Expand Down
4 changes: 4 additions & 0 deletions atintegrators/atelem.c
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ static void *atCalloc(size_t count, size_t size)

typedef mxArray atElem;
#define check_error()
#define atError(...) mexErrMsgIdAndTxt("AT:PassError", __VA_ARGS__)
#define atWarning(...) mexWarnMsgIdAndTxt("AT:PassWarning", __VA_ARGS__)

static mxArray *get_field(const mxArray *pm, const char *fieldname)
{
Expand Down Expand Up @@ -141,6 +143,8 @@ static double* atGetOptionalDoubleArray(const mxArray *ElemData, const char *fie

typedef PyObject atElem;
#define check_error() if (PyErr_Occurred()) return NULL
#define atError(...) return (struct elem *) PyErr_Format(PyExc_ValueError, __VA_ARGS__)
#define atWarning(...) if (PyErr_WarnFormat(PyExc_RuntimeWarning, 0, __VA_ARGS__) != 0) return NULL

static int array_imported = 0;

Expand Down
Loading

0 comments on commit d6b2ccc

Please sign in to comment.