Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/BenTenmann/setriq
Browse files Browse the repository at this point in the history
  • Loading branch information
BenTenmann committed Dec 5, 2021
2 parents 87ac901 + 228c910 commit 250df98
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 4 deletions.
24 changes: 24 additions & 0 deletions include/setriq/metrics/Levenshtein.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//
// Created by Benjamin Tenmann on 05/12/2021.
//

#ifndef SETRIQ_LEVENSHTEIN_H
#define SETRIQ_LEVENSHTEIN_H

#include <string>
#include "Metric.h"

namespace metric {
class Levenshtein : public Metric {
double extraCost = 0;

public:
Levenshtein() : extraCost() {};
explicit Levenshtein(double xCost) : extraCost {xCost} {};

double forward(const std::string &, const std::string &);

};
}

#endif //SETRIQ_LEVENSHTEIN_H
162 changes: 162 additions & 0 deletions src/setriq/_C/metrics/Levenshtein.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
//
// Created by Benjamin Tenmann on 05/12/2021.
//

#include <cstring>
#include <numeric>
#include <string>
#include <vector>
#include "metrics/Levenshtein.h"

double metric::Levenshtein::forward(const std::string &a, const std::string &b) {
size_t lengthOfA {a.size()};
size_t lengthOfB {b.size()};
size_t halfOfLengthA;

const char* ptrToA = &a.front();
const char* ptrToB = &b.front();

// catch the trivial cases
if (a.empty()) return lengthOfB;
if (b.empty()) return lengthOfA;

// grind down common prefix
while (lengthOfA > 0 && lengthOfB > 0 && (*ptrToA) == (*ptrToB)) {
lengthOfA--;
lengthOfB--;
ptrToA++;
ptrToB++;
}

// grind down common suffix
while (lengthOfA > 0 && lengthOfB > 0 && ptrToA[lengthOfA - 1] == ptrToB[lengthOfB - 1]) {
lengthOfA--;
lengthOfB--;
}

// again, catch trivial cases
if (lengthOfA == 0) return lengthOfB;
if (lengthOfB == 0) return lengthOfA;

if (lengthOfA > lengthOfB) { // enforce that b is the longer string
size_t temporaryLengthStore = lengthOfA;
const char *temporaryPtrStore = ptrToA;
lengthOfA = lengthOfB;
lengthOfB = temporaryLengthStore;
ptrToA = ptrToB;
ptrToB = temporaryPtrStore;
}

if (lengthOfA == 1) {
if (this->extraCost > 0)
return (double) lengthOfB + 1 - this->extraCost * (memchr(ptrToB, *ptrToA, lengthOfB) != nullptr);
else
return (double) lengthOfB - (memchr(ptrToB, *ptrToA, lengthOfB) != nullptr);
}
lengthOfA++;
lengthOfB++;
halfOfLengthA = lengthOfA >> 1;

// first row initialization
std::vector<size_t> row (lengthOfB);
std::iota(row.begin(), row.end() - (this->extraCost > 0 ? 0 : halfOfLengthA), 0);

size_t rowIndex;
size_t *end = &row.back() - lengthOfB - 1;

if (this->extraCost > 0) {
for (rowIndex = 1; rowIndex < lengthOfA; rowIndex++) {
size_t *ptrToRowElement = &row[1];
const char currentCharFromA = ptrToA[rowIndex - 1];
const char *ptrToCurrentCharFromB = ptrToB;

size_t rowIndexCopy1 = rowIndex;
size_t rowIndexCopy2 = rowIndex;
while (ptrToRowElement <= end) {
if (currentCharFromA == *(ptrToCurrentCharFromB++))
rowIndexCopy2 = --rowIndexCopy1;
else
rowIndexCopy2++;

rowIndexCopy1 = *ptrToRowElement;
rowIndexCopy1++;

if (rowIndexCopy2 > rowIndexCopy1)
rowIndexCopy2 = rowIndexCopy1;

*(ptrToRowElement++) = rowIndexCopy2;
}
}
}

else {
/*
*
* in this case we don't have to scan two corner triangles (of size len1/2)
* in the matrix because no best path can go throught them. note this
* breaks when len1 == len2 == 2 so the memchr() special case above is
* necessary
*
*/
row[0] = lengthOfA - halfOfLengthA - 1;
for (rowIndex = 1; rowIndex < lengthOfA; rowIndex++) {
size_t *ptrToRowElement;
const char currentCharFromA = ptrToA[rowIndex - 1];
const char *ptrToCurrentCharFromB;

size_t rowIndexCopy1, rowIndexCopy2;
/* skip the upper triangle */
if (rowIndex >= lengthOfA - halfOfLengthA) {
size_t offset = rowIndex - (lengthOfA - halfOfLengthA);
size_t rowIndexCopy3;

ptrToCurrentCharFromB = ptrToB + offset;
ptrToRowElement = &row[offset];

rowIndexCopy3 = *(ptrToRowElement++) + (currentCharFromA != *(ptrToCurrentCharFromB++));
rowIndexCopy2 = *ptrToRowElement;
rowIndexCopy2++;
rowIndexCopy1 = rowIndexCopy2;
if (rowIndexCopy2 > rowIndexCopy3)
rowIndexCopy2 = rowIndexCopy3;
*(ptrToRowElement++) = rowIndexCopy2;
}
else {
ptrToRowElement = &row[1];
ptrToCurrentCharFromB = ptrToB;
rowIndexCopy1 = rowIndexCopy2 = rowIndex;
}

/* skip the lower triangle */
if (rowIndex <= halfOfLengthA + 1)
end = &row[lengthOfB + rowIndex - halfOfLengthA - 2];

/* main */
while (ptrToRowElement <= end) {
size_t rowIndexCopy3 = --rowIndexCopy2 + (currentCharFromA != *(ptrToCurrentCharFromB++));
rowIndexCopy1++;

if (rowIndexCopy1 > rowIndexCopy3)
rowIndexCopy1 = rowIndexCopy3;

rowIndexCopy2 = *ptrToRowElement;
rowIndexCopy2++;

if (rowIndexCopy1 > rowIndexCopy2)
rowIndexCopy1 = rowIndexCopy2;
*(ptrToRowElement++) = rowIndexCopy1;
}

/* lower triangle sentinel */
if (rowIndex <= halfOfLengthA) {
size_t rowIndexCopy3 = --rowIndexCopy1 + (currentCharFromA != (*ptrToCurrentCharFromB));
rowIndexCopy2++;
if (rowIndexCopy2 > rowIndexCopy3)
rowIndexCopy2 = rowIndexCopy3;
*ptrToRowElement = rowIndexCopy2;
}
}
}

return (double) *end;
}
12 changes: 12 additions & 0 deletions src/setriq/_C/pythonWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <pybind11/stl.h>
#include "PairwiseDistanceComputer.h"
#include "metrics/CdrDist.h"
#include "metrics/Levenshtein.h"
#include "metrics/TcrDist.h"
#include "utils/typeDefs.h"

Expand All @@ -22,6 +23,14 @@ py::list cdr_dist(const stringVector &sequences, const doubleMatrix& substitutio
return py::cast(out);
}

py::list levenshtein(const stringVector sequences, double extra_cost) {
metric::Levenshtein metric {extra_cost};
PairwiseDistanceComputer computer { &metric };

doubleVector out = computer.computeDistance(sequences);
return py::cast(out);
}

py::list tcr_dist_component(const stringVector& sequences,
const doubleMatrix& substitutionMatrix,
const stringIndexMap& index,
Expand All @@ -41,6 +50,9 @@ PYBIND11_MODULE(_C, m) {
m.def("cdr_dist", &cdr_dist, "Compute the pairwise CDR-dist metric for a set of CDR3 sequences.",
py::arg("sequences"), py::arg("substitution_matrix"), py::arg("index"));

m.def("levenshtein", &levenshtein, "Compute the pairwise Levenshtein distances for a set of sequences.",
py::arg("sequences"), py::arg("extra_cost"));

m.def("tcr_dist_component", &tcr_dist_component, "Compute pairwise TCR-dist for a set of TCR components.",
py::arg("sequences"), py::arg("substitution_matrix"), py::arg("index"),
py::arg("gap_penalty"), py::arg("gap_symbol"), py::arg("weight"));
Expand Down
2 changes: 1 addition & 1 deletion src/setriq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@
"""

from .modules import (
CdrDist, TcrDist,
CdrDist, Levenshtein, TcrDist,
SubstitutionMatrix, BLOSUM45, BLOSUM62, BLOSUM90
)
6 changes: 4 additions & 2 deletions src/setriq/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from ._substitution import (
from .substitution import (
SubstitutionMatrix,
BLOSUM45,
BLOSUM62,
BLOSUM90
)
from ._distances import (
from .distances import (
CdrDist,
Levenshtein,
TcrDist
)

Expand All @@ -15,5 +16,6 @@
'BLOSUM62',
'BLOSUM90',
'CdrDist',
'Levenshtein',
'TcrDist',
]
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@
from glom import glom

import setriq._C as C
from ._substitution import *
from .substitution import (
BLOSUM45,
BLOSUM62,
SubstitutionMatrix
)

__all__ = [
'CdrDist',
'Levenshtein',
'TcrDist',
'TcrDistComponent',
]
Expand Down Expand Up @@ -73,6 +78,36 @@ def forward(self, sequences: List[str]) -> List[float]:
return out


class Levenshtein(Metric):
"""
The Levenshtein class. Inherits from Metric. It uses a refactor of the `python-Levenshtein` implementation in the
backend.
Examples
--------
>>> sequences = ['CASSLKPNTEAFF', 'CASSAHIANYGYTF', 'CASRGATETQYF']
>>>
>>> metric = Levenshtein()
>>> distances = metric(sequences)
References
----------
[1] Levenshtein, V.I., 1966, February. Binary codes capable of correcting deletions, insertions, and reversals. In
Soviet physics doklady (Vol. 10, No. 8, pp. 707-710). ()
[2] python-Levenshtein (https://github.com/ztane/python-Levenshtein)
"""
def __init__(self, extra_cost: float = 0.):
self.call_args = {
'extra_cost': extra_cost
}
self.fn = C.levenshtein

def forward(self, sequences: List[str]):
out = self.fn(sequences, **self.call_args)

return out


class TcrDistComponent(Metric):
"""
The TcrDistComponent class. Inherits from Metric.
Expand Down
File renamed without changes.

0 comments on commit 250df98

Please sign in to comment.