Skip to content

Commit

Permalink
formating
Browse files Browse the repository at this point in the history
  • Loading branch information
MoSafi2 committed May 16, 2024
1 parent 0ea0b2e commit d78d56c
Showing 1 changed file with 55 additions and 50 deletions.
105 changes: 55 additions & 50 deletions blazeseq/stats.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@ from utils.static_tuple import StaticTuple

alias py_lib: String = "./.pixi/envs/default/lib/python3.12/site-packages/"


fn hash_list() -> List[UInt64]:
var li: List[UInt64] = List[UInt64](
_seq_to_hash("AGATCGGAAGAG"),
_seq_to_hash("TGGAATTCTCGG"),
_seq_to_hash("GATCGTCGGACT"),
_seq_to_hash("CTGTCTCTTATA"),
_seq_to_hash("AAAAAAAAAAAA"),
_seq_to_hash("GGGGGGGGGGGG")
)
var li: List[UInt64] = List[UInt64](
_seq_to_hash("AGATCGGAAGAG"),
_seq_to_hash("TGGAATTCTCGG"),
_seq_to_hash("GATCGTCGGACT"),
_seq_to_hash("CTGTCTCTTATA"),
_seq_to_hash("AAAAAAAAAAAA"),
_seq_to_hash("GGGGGGGGGGGG"),
)
return li


Expand All @@ -35,9 +36,11 @@ trait Analyser(CollectionElement, Stringable):

fn report(self) -> Tensor[DType.int64]:
...

fn __str__(self) -> String:
...


@value
struct FullStats(Stringable, CollectionElement):
var num_reads: Int64
Expand Down Expand Up @@ -68,10 +71,11 @@ struct FullStats(Stringable, CollectionElement):
self.cg_content.tally_read(record) # Almost Free
self.dup_reads.tally_read(record)
self.kmer_content.tally_read(record)

# BUG: There is a bug here which causes core dumped
self.qu_dist.tally_read(record) #Expensive operation, a lot of memory access

# BUG: There is a bug here which causes core dumped
self.qu_dist.tally_read(
record
) # Expensive operation, a lot of memory access

@always_inline
fn plot(inout self) raises:
Expand Down Expand Up @@ -154,9 +158,7 @@ struct CGContent(Analyser):
):
cg_num += 1

var read_cg_content = int(
round(cg_num * 100 / record.len_record())
)
var read_cg_content = int(round(cg_num * 100 / record.len_record()))
self.cg_content[read_cg_content] += 1

fn report(self) -> Tensor[DType.int64]:
Expand All @@ -176,7 +178,7 @@ struct CGContent(Analyser):
return String("\nThe CpG content tensor is: ") + self.cg_content


#TODO: You should extraplolate from the number of reads in the unique reads to how it would look like for everything.
# TODO: You should extraplolate from the number of reads in the unique reads to how it would look like for everything.
@value
struct DupReads(Analyser):
var unique_dict: Dict[FastqRecord, Int64]
Expand All @@ -193,7 +195,6 @@ struct DupReads(Analyser):
self.corrected_counts = Dict[Int, Float64]()

fn tally_read(inout self, record: FastqRecord):

self.n += 1

if record in self.unique_dict:
Expand All @@ -211,11 +212,9 @@ struct DupReads(Analyser):
self.count_at_max = self.n
else:
pass



fn predict_reads(inout self):
#Construct Duplication levels dict
# Construct Duplication levels dict
var dup_dict = Dict[Int, Int]()
for entry in self.unique_dict.values():
if int(entry[]) in dup_dict:
Expand All @@ -226,45 +225,51 @@ struct DupReads(Analyser):
else:
dup_dict[int(entry[])] = 0


# Correct reads levels
var corrected_reads = Dict[Int, Float64]()
for entry in dup_dict:
try:
var count = dup_dict[entry[]]
var level = entry[]
var corrected_count = self.correct_values(level, count, self.count_at_max, self.n)
var corrected_count = self.correct_values(
level, count, self.count_at_max, self.n
)
corrected_reads[level] = corrected_count
except:
print("Error")

self.corrected_counts = corrected_reads

#Check how it is done in Falco.
# Check how it is done in Falco.
@staticmethod
fn correct_values(dup_level: Int, count_at_level: Int, count_at_max: Int, total_count: Int ) -> Float64:
fn correct_values(
dup_level: Int, count_at_level: Int, count_at_max: Int, total_count: Int
) -> Float64:
if count_at_max == total_count:
return count_at_level

if total_count - count_at_level < count_at_max:
return count_at_level

var pNotSeeingAtLimit: Float64 = 1
var limitOfCaring = Float64(1) - (count_at_level / (count_at_level + 0.01))
var limitOfCaring = Float64(1) - (
count_at_level / (count_at_level + 0.01)
)

for i in range(count_at_max):
pNotSeeingAtLimit *= ((total_count -i ) - dup_level) / (total_count - i)
pNotSeeingAtLimit *= ((total_count - i) - dup_level) / (
total_count - i
)

if pNotSeeingAtLimit < limitOfCaring:
pNotSeeingAtLimit = 0
break

var pSeeingAtLimit:Float64 = 1 - pNotSeeingAtLimit
var pSeeingAtLimit: Float64 = 1 - pNotSeeingAtLimit
var trueCount = count_at_level / pSeeingAtLimit

return trueCount


fn report(self) -> Tensor[DType.int64]:
var report = Tensor[DType.int64](1)
report[0] = len(self.unique_dict)
Expand All @@ -273,24 +278,24 @@ struct DupReads(Analyser):
fn __str__(self) -> String:
return String("\nNumber of duplicated reads is") + self.report()


fn plot(inout self) raises:
print("Enterd Plotting")
self.predict_reads()
print("Predicted Reads")
var temp_tensor = Tensor[DType.int64](len(self.corrected_counts) * 2 + 1)
# Make this a matrix
var temp_tensor = Tensor[DType.int64](
len(self.corrected_counts) * 2 + 1
)
var i = 0
for index in self.corrected_counts:
print(index[])
print(self.corrected_counts[index[]])
temp_tensor[i * 2] = index[]
temp_tensor[i * 2 + 1] = self.corrected_counts[index[]]
i += 1

var np = Python.import_module("numpy")
var arr = tensor_to_numpy_1d(temp_tensor)
np.save("arr_DupReads.npy", arr)


@value
struct LengthDistribution(Analyser):
Expand Down Expand Up @@ -341,7 +346,7 @@ struct LengthDistribution(Analyser):
return String("\nLength Distribution: ") + self.length_vector


#TODO: FIX this struct to reflect FastQC
# TODO: FIX this struct to reflect FastQC
@value
struct QualityDistribution(Analyser):
var qu_dist: Tensor[DType.int64]
Expand Down Expand Up @@ -369,19 +374,18 @@ struct QualityDistribution(Analyser):
self.max_qu = base_qu

# Use this answer for plotting: https://stackoverflow.com/questions/58053594/how-to-create-a-boxplot-from-data-with-weights
#TODO: Make an abbreviator of the plot to get always between 50-60 bars per plot
#TODO: Stylize the plot
# TODO: Make an abbreviator of the plot to get always between 50-60 bars per plot
# TODO: Stylize the plot
fn plot(self) raises:
var arr = matrix_to_numpy(self.qu_dist)

Python.add_to_path(py_lib)
var np = Python.import_module("numpy")
var plt = Python.import_module("matplotlib.pyplot")
var sns = Python.import_module("seaborn")
var py_builtin = Python.import_module("builtins")
np.save("arr_qu.npy", arr)


################# Quality Histogram ##################

var mean_line = np.sum(arr * np.arange(1, 41), axis=1) / np.sum(
Expand Down Expand Up @@ -414,12 +418,14 @@ struct QualityDistribution(Analyser):
ax.plot(mean_line)
fig.savefig("QualityDistribution.png")

#################### Quality Heatmap #########################
###############################################################
### Quality Heatmap ###
###############################################################

var y = plt.subplots()
var fig2 = y[0]
var ax2 = y[1]
sns.heatmap(np.flipud(arr).T, cmap="Blues", robust= True, ax = ax2)
sns.heatmap(np.flipud(arr).T, cmap="Blues", robust=True, ax=ax2)
fig2.savefig("QualityDistributionHeatMap.png")

fn report(self) -> Tensor[DType.int64]:
Expand All @@ -439,7 +445,7 @@ struct QualityDistribution(Analyser):
@value
struct KmerContent[bits: Int = 3](Analyser):
var kmer_len: Int
var hash_counts:Tensor[DType.int64]
var hash_counts: Tensor[DType.int64]
var hash_list: List[UInt64]

fn __init__(inout self, hashes: List[UInt64], kmer_len: Int = 0):
Expand All @@ -453,13 +459,12 @@ struct KmerContent[bits: Int = 3](Analyser):
# TODO: Check if it will be easier to use the bool_tuple and hashes as a list instead
@always_inline
fn tally_read(inout self, record: FastqRecord):

var hash: UInt64 = 0
var end = 0
# Make a custom bit mask of 1s by certain length
var mask: UInt64 = (0b1 << self.kmer_len * bits) - 1
var neg_mask = mask >> bits
var bit_shift = (0b1 << bits) -1
var bit_shift = (0b1 << bits) - 1

# Check initial Kmer
if len(self.hash_list) > 0:
Expand All @@ -470,7 +475,7 @@ struct KmerContent[bits: Int = 3](Analyser):
hash = hash & neg_mask

# Mask for the least sig. three bits, add to hash
var rem = record.SeqStr[i] & bit_shift
var rem = record.SeqStr[i] & bit_shift
hash = (hash << bits) + int(rem)
if len(self.hash_list) > 0:
self._check_hashes(hash)
Expand All @@ -492,11 +497,12 @@ fn _seq_to_hash(seq: String) -> UInt64:
# Remove the most signifcant 3 bits
hash = hash & 0x1FFFFFFFFFFFFFFF
# Mask for the least sig. three bits, add to hash
var rem = ord(seq[i]) & 0b111
var rem = ord(seq[i]) & 0b111
hash = (hash << 3) + int(rem)
return hash

#TODO: Make this also parametrized on the number of bits per bp, this now works only for 3bits

# TODO: Make this also parametrized on the number of bits per bp, this now works only for 3bits
fn _hash_to_seq(hash: UInt64) -> String:
var inner = hash
var out: String = ""
Expand Down Expand Up @@ -525,6 +531,7 @@ def tensor_to_numpy_1d[T: DType](tensor: Tensor[T]) -> PythonObject:
ar.itemset(i, tensor[i])
return ar


def matrix_to_numpy[T: DType](tensor: Tensor[T]) -> PythonObject:
np = Python.import_module("numpy")
ar = np.zeros([tensor.shape()[0], tensor.shape()[1]])
Expand Down Expand Up @@ -552,11 +559,9 @@ fn grow_matrix[
return new_tensor



# TODO: Add module for adapter content
@value
struct AdapterContent(Analyser):

fn tally_read(inout self, read: FastqRecord):
pass

Expand Down

0 comments on commit d78d56c

Please sign in to comment.