diff --git a/Detectors/TPC/reconstruction/src/GPUCATracking.cxx b/Detectors/TPC/reconstruction/src/GPUCATracking.cxx index f7eca85185a19..b5c5d956e339a 100644 --- a/Detectors/TPC/reconstruction/src/GPUCATracking.cxx +++ b/Detectors/TPC/reconstruction/src/GPUCATracking.cxx @@ -83,50 +83,63 @@ int GPUCATracking::runTracking(GPUO2InterfaceIOPtrs* data, GPUInterfaceOutputs* float vzbinInv = 1.f / vzbin; Mapper& mapper = Mapper::instance(); - const ClusterNativeAccess* clusters; std::vector gpuDigits[Sector::MAXSECTOR]; + o2::dataformats::MCTruthContainer gpuDigitsMC[Sector::MAXSECTOR]; + GPUTrackingInOutDigits gpuDigitsMap; - GPUTPCDigitsMCInput gpuDigitsMC; + GPUTPCDigitsMCInput gpuDigitsMapMC; GPUTrackingInOutPointers ptrs; - if (data->compressedClusters) { - ptrs.tpcCompressedClusters = data->compressedClusters; - } else if (data->tpcZS) { - ptrs.tpcZS = data->tpcZS; - } else if (data->o2Digits) { - ptrs.clustersNative = nullptr; + ptrs.tpcCompressedClusters = data->compressedClusters; + ptrs.tpcZS = data->tpcZS; + if (data->o2Digits) { const float zsThreshold = mTrackingCAO2Interface->getConfig().configReconstruction.tpcZSthreshold; const int maxContTimeBin = mTrackingCAO2Interface->getConfig().configEvent.continuousMaxTimeBin; for (int i = 0; i < Sector::MAXSECTOR; i++) { const auto& d = (*(data->o2Digits))[i]; - gpuDigits[i].reserve(d.size()); - gpuDigitsMap.tpcDigits[i] = gpuDigits[i].data(); + if (zsThreshold > 0) { + gpuDigits[i].reserve(d.size()); + } for (int j = 0; j < d.size(); j++) { if (maxContTimeBin && d[j].getTimeStamp() >= maxContTimeBin) { throw std::runtime_error("Digit time bin exceeds time frame length"); } - if (d[j].getChargeFloat() >= zsThreshold) { - gpuDigits[i].emplace_back(d[j]); + if (zsThreshold > 0) { + if (d[j].getChargeFloat() >= zsThreshold) { + if (data->o2DigitsMC) { + for (const auto& element : (*data->o2DigitsMC)[i]->getLabels(j)) { + gpuDigitsMC[i].addElement(gpuDigits[i].size(), element); + } + } + gpuDigits[i].emplace_back(d[j]); + } + } + } + if (zsThreshold > 0) { + gpuDigitsMap.tpcDigits[i] = gpuDigits[i].data(); + gpuDigitsMap.nTPCDigits[i] = gpuDigits[i].size(); + if (data->o2DigitsMC) { + gpuDigitsMapMC.v[i] = &gpuDigitsMC[i]; + } + } else { + gpuDigitsMap.tpcDigits[i] = (*(data->o2Digits))[i].data(); + gpuDigitsMap.nTPCDigits[i] = (*(data->o2Digits))[i].size(); + if (data->o2DigitsMC) { + gpuDigitsMapMC.v[i] = (*data->o2DigitsMC)[i].get(); } } - gpuDigitsMap.nTPCDigits[i] = gpuDigits[i].size(); } if (data->o2DigitsMC) { - for (int i = 0; i < Sector::MAXSECTOR; i++) { - gpuDigitsMC.v[i] = (*data->o2DigitsMC)[i].get(); - } - gpuDigitsMap.tpcDigitsMC = &gpuDigitsMC; + gpuDigitsMap.tpcDigitsMC = &gpuDigitsMapMC; } ptrs.tpcPackedDigits = &gpuDigitsMap; - } else { - clusters = data->clusters; - ptrs.clustersNative = clusters; - ptrs.tpcPackedDigits = nullptr; } + ptrs.clustersNative = data->clusters; int retVal = mTrackingCAO2Interface->RunTracking(&ptrs, outputs); if (data->o2Digits || data->tpcZS || data->compressedClusters) { - clusters = ptrs.clustersNative; + data->clusters = ptrs.clustersNative; } + data->compressedClusters = ptrs.tpcCompressedClusters; const GPUTPCGMMergedTrack* tracks = ptrs.mergedTracks; int nTracks = ptrs.nMergedTracks; const GPUTPCGMMergedTrackHit* trackClusters = ptrs.mergedTrackHits; @@ -177,8 +190,8 @@ int GPUCATracking::runTracking(GPUO2InterfaceIOPtrs* data, GPUInterfaceOutputs* if (lastSide ^ (trackClusters[tracks[i].FirstClusterRef() + iCl].slice < Sector::MAXSECTOR / 2)) { auto& cacl1 = trackClusters[tracks[i].FirstClusterRef() + iCl]; auto& cacl2 = trackClusters[tracks[i].FirstClusterRef() + iCl - 1]; - auto& cl1 = clusters->clustersLinear[cacl1.num]; - auto& cl2 = clusters->clustersLinear[cacl2.num]; + auto& cl1 = data->clusters->clustersLinear[cacl1.num]; + auto& cl2 = data->clusters->clustersLinear[cacl2.num]; delta = fabs(cl1.getTime() - cl2.getTime()) * 0.5f; break; } @@ -188,8 +201,8 @@ int GPUCATracking::runTracking(GPUO2InterfaceIOPtrs* data, GPUInterfaceOutputs* // estimate max/min time increments which still keep track in the physical limits of the TPC auto& c1 = trackClusters[tracks[i].FirstClusterRef()]; auto& c2 = trackClusters[tracks[i].FirstClusterRef() + tracks[i].NClusters() - 1]; - float t1 = clusters->clustersLinear[c1.num].getTime(); - float t2 = clusters->clustersLinear[c2.num].getTime(); + float t1 = data->clusters->clustersLinear[c1.num].getTime(); + float t2 = data->clusters->clustersLinear[c2.num].getTime(); auto times = std::minmax(t1, t2); tFwd = times.first - time0; tBwd = time0 - (times.second - detParam.TPClength * vzbinInv); @@ -247,7 +260,7 @@ int GPUCATracking::runTracking(GPUO2InterfaceIOPtrs* data, GPUInterfaceOutputs* int clusterIdGlobal = trackClusters[tracks[i].FirstClusterRef() + j].num; Sector sector = trackClusters[tracks[i].FirstClusterRef() + j].slice; int globalRow = trackClusters[tracks[i].FirstClusterRef() + j].row; - int clusterIdInRow = clusterIdGlobal - clusters->clusterOffset[sector][globalRow]; + int clusterIdInRow = clusterIdGlobal - data->clusters->clusterOffset[sector][globalRow]; int regionNumber = 0; while (globalRow > mapper.getGlobalRowOffsetRegion(regionNumber) + mapper.getNumberOfRowsRegion(regionNumber)) { regionNumber++; @@ -256,8 +269,8 @@ int GPUCATracking::runTracking(GPUO2InterfaceIOPtrs* data, GPUInterfaceOutputs* sectorIndexArr[nOutCl] = sector; rowIndexArr[nOutCl] = globalRow; nOutCl++; - if (outputTracksMCTruth && clusters->clustersMCTruth) { - for (const auto& element : clusters->clustersMCTruth->getLabels(clusterIdGlobal)) { + if (outputTracksMCTruth && data->clusters->clustersMCTruth) { + for (const auto& element : data->clusters->clustersMCTruth->getLabels(clusterIdGlobal)) { bool found = false; for (int l = 0; l < labels.size(); l++) { if (labels[l].first == element) { @@ -291,10 +304,7 @@ int GPUCATracking::runTracking(GPUO2InterfaceIOPtrs* data, GPUInterfaceOutputs* } } outClusRefs->resize(clusterOffsetCounter.load()); // remove overhead - if (data->o2Digits || data->tpcZS || data->compressedClusters) { - data->clusters = ptrs.clustersNative; - } - data->compressedClusters = ptrs.tpcCompressedClusters; + mTrackingCAO2Interface->Clear(false); return (retVal);