# INSTRUCTIONS

### **Overview**

This notebook helps you verify, at low cost, the sequences of the plasmids you have cloned.\
A mixture of multiple plasmids is sequenced in a single Nanopore run, and the resulting reads are then sorted back to each plasmid, which brings the cost of plasmid sequencing down to between a fraction and a tenth of the usual cost.

The notebook consists of 3 sections. You can start the analysis at the beginning of any section below.

- Pre-survey (Step 0)
- Alignment & Calculation of consensus (Step 1-5)
- Visualization of results (Step 6)

### **Quick start**

**Before you submit your samples**

<details><summary>Pre-survey (Step 0)</summary><div>

1. Upload your plasmid sequences (`*.fasta` or `*.dna`) under the `sample_data` directory.
2. Select the cell "**0. Pre-survey**", then press "`Runtime` -> `Run the focused cell`".
3. The recommended combinations of plasmids that are safe to mix will be displayed.

</div></details>

**Submit your samples**

<details><summary>Details</summary><div>

1. Verify that your plasmid is built properly by using diagnostic restriction digestion or a method with an equivalent degree of confidence.
2. Mix plasmids according to the grouping recommended by the pre-survey. The mass concentration (not the molar ratio) of each plasmid in the mixture should be the same.
   
   <table>
      <caption>e.g. Making mixture of 3 plasmids (final 20 uL, 30 ng/uL)</caption>
      <tr>
        <th>
        </th>
        <th>
          conc. [ng/uL]
        </th>
        <th>
          volume [uL]
        </th>
      </tr>
      <tr align="center">
        <td>
          Plasmid 1
        </td>
        <td>
          A
        </td>
        <td>
          200 (= 20 * 30 / #_of_plasmids) / A
        </td>
      </tr>
      <tr align="center">
        <td>
          Plasmid 2
        </td>
        <td>
          B
        </td>
        <td>
          200 / B
        </td>
      </tr>
      <tr align="center">
        <td>
          Plasmid 3
        </td>
        <td>
          C
        </td>
        <td>
          200 / C
        </td>
      </tr>
      <tr align="center">
        <td>
          H2O
        </td>
        <td>
          –
        </td>
        <td>
          20 – (200 / A + 200 / B + 200 / C)
        </td>
      </tr>
   </table>

3. Submit your samples for Nanopore sequencing. Be sure to receive the raw reads as `*.fastq` files. The consensus sequence that the sequencing company usually returns is not meaningful here because you are sending a mixture of plasmids.
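
The arithmetic in the table above can be sketched in a few lines of Python. This is only an illustration: the function name `mixing_volumes` and its arguments are hypothetical and are not part of this notebook.

```python
def mixing_volumes(concentrations_ng_per_ul, final_volume_ul=20.0, final_conc_ng_per_ul=30.0):
    """Return (plasmid volumes, water volume) in uL for an equal-mass mixture.

    Hypothetical helper for illustration; not defined in the notebook.
    """
    n = len(concentrations_ng_per_ul)
    # Mass of each plasmid in the final mixture [ng], e.g. 20 * 30 / 3 = 200
    mass_per_plasmid_ng = final_volume_ul * final_conc_ng_per_ul / n
    volumes = [mass_per_plasmid_ng / conc for conc in concentrations_ng_per_ul]
    # Water makes up whatever volume is left
    water = final_volume_ul - sum(volumes)
    if water < 0:
        raise ValueError("Stock solutions are too dilute for the requested mixture.")
    return volumes, water


# Three plasmid stocks at 100, 200 and 50 ng/uL
volumes, water = mixing_volumes([100.0, 200.0, 50.0])
print(volumes, water)  # [2.0, 1.0, 4.0] 13.0
```

If the water volume comes out negative, at least one stock is too dilute, and the final concentration or the number of plasmids in the mixture has to be reduced.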

</div></details>

**After you get Nanopore results**

<details><summary>Alignment & Calculation of consensus (Step 1-5)
</summary><div>

1. Upload your plasmid sequences (`*.fasta` or `*.dna`) and Nanopore sequencing results (`*.fastq`) under the `sample_data` directory.
2. Select the cell "**1. Upload and select files**", then press "`Runtime` -> `Run after`".
3. The currently running cell is indicated by a circle with a stop sign next to it.
4. After the completion of all processes, a `result.zip` file will appear. Right-click and select "Download".

</div></details>

<details><summary>Visualization of results (Step 6)</summary><div>

1. Upload a result file (`*.alignment_with_prior.txt` or `*.alignment_without_prior.txt`) under the `sample_data` directory.
2. Enter parameters in the cell "**6. Visualize results**" (`target_file_path`, `target_position`, and `display_range`).
3. Press "`Runtime` -> `Run the focused cell`".

</div></details>

### **Details**
<details><summary>0. Pre-survey</summary><div>

---

# 0-0. Overview

Before you submit your samples, run this cell to find out which combinations of plasmids are safe to mix.\
If you just want to analyze Nanopore sequencing results, start from the `1. Upload and select files` cell.

# 0-1. Upload reference (plasmid) sequence files.

Click on the little folder icon to the left, then drag and drop files under the `sample_data` directory.

Currently supported files are: 
- `*.dna` (SnapGene file)
- `*.fasta` (FASTA file)

# 0-2. Advanced settings

## Parameters used for alignment

- `gap_open_penalty`
- `gap_extend_penalty`
- `match_score`
- `mismatch_score`

Ref 1: [parasail-python](https://github.com/jeffdaily/parasail-python)\
Ref 2: [Daily, Jeff. (2016). Parasail: SIMD C library for global, semi-global, and local pairwise sequence alignments. BMC Bioinformatics, 17(1), 1-11.](http://dx.doi.org/10.1186/s12859-016-0930-z)\
Ref 3: [Smith-Waterman algorithm](https://doi.org/10.1016/0022-2836(81)90087-5)

## Parameters used for classification
- `score_threshold`

The "score" in `score_threshold` refers to the alignment score calculated using the above four parameters, but with normalization.
Normalization is performed by dividing the score by the length of the reference (plasmid) sequence.
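
As a rough illustration of the normalized score, the sketch below replaces parasail with a small pure-Python Smith-Waterman (affine gaps) so that it runs without extra packages. The function names are hypothetical. The doubling of the reference mirrors what the pre-survey cell does for circular plasmids, and a gap of length k is assumed to cost `gap_open_penalty + (k - 1) * gap_extend_penalty`.

```python
def sw_score(query, ref, gap_open=3, gap_extend=1, match=1, mismatch=-2):
    """Best local alignment score (Smith-Waterman with affine gaps).

    Pure-Python stand-in for parasail, for illustration only.
    """
    neg_inf = float("-inf")
    n_cols = len(ref) + 1
    h_prev = [0] * n_cols        # H[i-1][*]: best score ending at each cell
    f_prev = [neg_inf] * n_cols  # F[i-1][*]: best score ending in a vertical gap
    best = 0
    for q in query:
        h_cur = [0] * n_cols
        f_cur = [neg_inf] * n_cols
        e = neg_inf              # E[i][j]: best score ending in a horizontal gap
        for j in range(1, n_cols):
            e = max(e - gap_extend, h_cur[j - 1] - gap_open)
            f_cur[j] = max(f_prev[j] - gap_extend, h_prev[j] - gap_open)
            diag = h_prev[j - 1] + (match if q == ref[j - 1] else mismatch)
            h_cur[j] = max(0, diag, e, f_cur[j])
            best = max(best, h_cur[j])
        h_prev, f_prev = h_cur, f_cur
    return best


def normalized_score(query, ref, **scoring):
    """Alignment score divided by the length of the (single) reference."""
    doubled = ref + ref  # lets an alignment run across the origin of a circular plasmid
    return sw_score(query, doubled, **scoring) / len(ref)


ref = "ACGTACGGTCAT"
rotated = "GGTCATACGTAC"  # same circular sequence, different start position
print(normalized_score(rotated, ref))         # 1.0
print(normalized_score("TTTTTTTTTTTT", ref))  # close to 0
```

With `match_score = 1`, a query identical to the reference scores 1.0, so `score_threshold` can be read as a fraction of the best possible score.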

# 0-3. Hit `Runtime` -> `Run the focused cell`
- The normalized alignment score and the recommended combination of plasmids will be displayed.
- The results are also exported as `recommended_group_(# of group).svg` under the `sample_data` directory.
- A lower alignment score between two sequences indicates that they can be mixed more safely.
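
To make the role of `score_threshold` concrete, here is a simplified sketch that lists the plasmid pairs that are too similar to mix. It is not the notebook's actual grouping procedure (the pre-survey cell uses hierarchical clustering); the function name `pairs_too_similar` and the example scores are made up for illustration.

```python
from itertools import combinations


def pairs_too_similar(score_matrix, names, score_threshold=0.96):
    """Return (name_a, name_b, score) for pairs at or above the threshold.

    Hypothetical helper: the score of a pair is taken as the larger of the
    two directions (reference vs. query and query vs. reference).
    """
    risky = []
    for i, j in combinations(range(len(names)), 2):
        score = max(score_matrix[i][j], score_matrix[j][i])
        if score >= score_threshold:
            risky.append((names[i], names[j], score))
    return risky


# Made-up normalized scores for three plasmids
scores = [
    [1.00, 0.98, 0.40],
    [0.97, 1.00, 0.35],
    [0.42, 0.30, 1.00],
]
print(pairs_too_similar(scores, ["P1", "P2", "P3"]))  # [('P1', 'P2', 0.98)]
```

In this example P1 and P2 should go into different mixtures, while P3 can be combined with either of them.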

---

</div></details>

<details><summary>1. Upload and select files</summary><div>

---

# 1-1. Upload files

Click on the little folder icon to the left, then drag and drop files under the `sample_data` directory to upload your reference (plasmid) sequences and Nanopore sequencing results.

- Reference sequence files can be the same files as you uploaded in the **0. Pre-survey** step.
- Multiple `*.fastq` files can be uploaded. They will be combined inside the program.
- Reads in the `*.fastq` files are aligned to each reference sequence in the **2. Execute alignment** step. Each read is then assigned to one of the references in the **4. Calculate consensus** step, based on the threshold set in the **3. Set threshold for assignment** step.

Currently supported reference sequence files:
- `*.dna` (SnapGene file)
- `*.fasta` (FASTA file)

Nanopore sequencing results:
- `*.fastq` files (nanopore sequence results)
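
For orientation, a `*.fastq` file stores each read as four lines: an ID line starting with `@`, the sequence, a `+` line, and a quality string in which each character encodes a Phred score as `ord(char) - 33`. The sketch below is a minimal reader under that assumption; `parse_fastq` is a hypothetical name and is not the class used in this notebook.

```python
import io


def parse_fastq(handle):
    """Read 4-line FASTQ records into (read_id, sequence, q_scores) tuples."""
    lines = [line.rstrip("\n") for line in handle if line.strip()]
    if len(lines) % 4 != 0:
        raise ValueError("Each FASTQ record must consist of exactly 4 lines.")
    records = []
    for i in range(0, len(lines), 4):
        read_id, seq, plus, qual = lines[i:i + 4]
        if not read_id.startswith("@") or not plus.startswith("+"):
            raise ValueError(f"Malformed record starting at line {i + 1}")
        if len(seq) != len(qual):
            raise ValueError(f"Sequence/quality length mismatch in {read_id}")
        # Phred+33 encoding: 'I' -> 40, '5' -> 20, '!' -> 0
        q_scores = [ord(ch) - 33 for ch in qual]
        records.append((read_id[1:], seq, q_scores))
    return records


file_a = io.StringIO("@read1\nACGT\n+\nII5!\n")
file_b = io.StringIO("@read2\nGG\n+\n5I\n")
# Combining several files amounts to concatenating their records
records = parse_fastq(file_a) + parse_fastq(file_b)
print(records)
```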

# 1-2. Select this cell and hit `Runtime` -> `Run after`

If you do not want to use all files you uploaded, press `Runtime` -> `Run the focused cell` to select files. Follow these steps:

1. After you run this cell, checkboxes for selection will appear below the cell.
2. Select files you want to include.
3. At least one `*.fastq` file and one reference sequence file (`*.dna` or `*.fasta`) are required for the analysis.
4. Select "**2. Execute alignment**" cell and hit `Runtime` -> `Run after` to process all cells after.

---

</div></details>

<details><summary>2. Execute alignment</summary><div>

---

- If the `save_to_google_drive` option was selected, the result zip file will be uploaded to your Google Drive.
- The meanings of the parameters are the same as in the **0. Pre-survey** cell.
- If the parameters set in **0. Pre-survey** cell and in this cell are different, the former will be ignored.

---

</div></details>

<details><summary>3. Set threshold for assignment</summary><div>

---

- The meaning of the parameter is the same as that used in **0. Pre-survey** cell.
- If the parameter set in **0. Pre-survey** cell and in this cell are different, the former will be ignored.

---

</div></details>

<details><summary>4. Calculate consensus</summary><div>

---

The following parameters provide prior information on how often errors occur during plasmid construction (PCR, ligation, assembly, etc.).

- `error_rate`\
   e.g. A -> T, C, G, - (They are treated equally)

- `ins_rate`\
   e.g. AA -> ANA (N represents one of A, T, C, or G)

Regardless of the settings of the above parameters, the results without taking prior information into account are also generated (i.e., `error_rate`=0.8 and `ins_rate`=0.2).
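
One way to see why `error_rate`=0.8 corresponds to "no prior information" is sketched below. It assumes, as an interpretation of the description above rather than a statement about the notebook's internal code, that `error_rate` is spread equally over the four alternatives to the reference base. The name `substitution_prior` is hypothetical.

```python
SYMBOLS = "ATCG-"  # four bases plus a deletion


def substitution_prior(ref_base, error_rate):
    """Prior probability of each symbol at a position, given the reference base.

    Assumption for illustration: error_rate is shared equally among the four
    symbols that differ from the reference base.
    """
    alternatives = [s for s in SYMBOLS if s != ref_base]
    prior = {s: error_rate / len(alternatives) for s in alternatives}
    prior[ref_base] = 1.0 - error_rate
    return prior


# A small error rate strongly favors the reference base
print(substitution_prior("A", 0.0001))
# With error_rate = 0.8 every symbol gets (up to rounding) the same probability of 0.2
print(substitution_prior("A", 0.8))
```

Under this reading, 0.8 is the value at which the reference base is no more likely than any other symbol, so the reference no longer influences the result.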

---

</div></details>

<details><summary>5. Export results</summary><div>

---

- After the completion of all processes, a result.zip file will appear. Right-click and select "Download".
- If the `save_to_google_drive` option was selected, the result zip file will be uploaded to your Google Drive.

---

</div></details>

<details><summary>6. Visualize results</summary><div>

---

To execute this cell, follow the steps described in **Quick start**.

- `target_file_path`\
   exported alignment result file\
   e.g. `result/your_plasmid_file_name.alignment_without_prior.txt` ()\
   e.g. `your_plasmid_file_name.alignment_without_prior.txt` (if it is directly uploaded under the `sample_data` folder)

- `target_position`\
   position from 5' end on the reference sequence

---

</div></details>

<details><summary>Known issues</summary><div>

---

# `result.zip` file cannot be uploaded to Google Drive

## Feb. 01, 2023

### symptom(s)

Attempting to upload large files from Colaboratory to Google Drive using PyDrive fails with a `RedirectMissingLocation` exception.

### cause(s)

The cause appears to be a bug in httplib2, on which PyDrive depends.
See the following issue for details.\
[httplib2 v0.16.0 breaks the library · Issue #803 · googleapis/google-api-python-client · GitHub](https://github.com/googleapis/google-api-python-client/issues/803)

### workaround

1. Execute this notebook normally from **1. Upload and select files**.
2. When **2. Execute alignment** is executed and access to Google Drive is permitted, the following warning will appear:
   ```
   WARNING: The following packages were previously imported in this runtime:
     [httplib2]
   You must restart the runtime in order to use newly installed versions.
   ```
3. Click the `RESTART RUNTIME` button just below the warning or hit `Runtime` -> `Restart runtime`.
4. Execute this notebook normally from **1. Upload and select files** again.

---

</div></details>

In [None]:
#@title # 0. Pre-survey
#@markdown ## 0-1. Upload reference (plasmid) sequence files.
#@markdown - `*.dna` (SnapGene file) or `*.fasta` (FASTA file)

#@markdown ## 0-2. Advanced settings
gap_open_penalty = 3   #@param {type:"integer"}
gap_extend_penalty = 1 #@param {type:"integer"}
match_score = 1        #@param {type:"integer"}
mismatch_score = -2    #@param {type:"integer"}
score_threshold = 0.96  #@param {type:"number"}

#@markdown ## 0-3. Hit `Runtime` -> `Run the focused cell`

# install dependencies
print("installing dependencies...")
import sys, os
# save_stdout =  sys.stdout
# sys.stdout = open(os.devnull, 'w')

!apt-get install libcairo2-dev libjpeg-dev libgif-dev
!pip install snapgene-reader
!pip install parasail
!pip install pycairo

import re
import matplotlib.pyplot as plt
import scipy.spatial.distance as distance
import numpy as np
import itertools
import gc
import parasail
import xml.etree.ElementTree as ET
import cairo
from pathlib import Path
from snapgene_reader import snapgene_file_to_dict, snapgene_file_to_seqrecord
from Bio.Seq import Seq  # used by MyRefSeq_Minimum.reverse_complement
from scipy.cluster.hierarchy import dendrogram, linkage
from mpl_toolkits.axes_grid1 import make_axes_locatable

# sys.stdout.close()
# sys.stdout = save_stdout
print("installation: DONE")

%matplotlib inline

my_custom_matrix = parasail.matrix_create("ACGT", match_score, mismatch_score)

pwd = Path('/content/sample_data/')
uploaded_refseq_file_paths = [path for path in pwd.glob("*.*") if path.suffix in (".dna", ".fa", ".fasta")]

if not len(uploaded_refseq_file_paths) > 1:
    raise Exception("Please upload at least 2 reference files under the 'sample_data' directory!")

# functions
def pre_survery(refseq_list):
    N = len(refseq_list)
    total_N = N ** 2
    i = 0
    score_matrix = np.empty((N, N), dtype=float)
    for r, my_refseq in enumerate(refseq_list):
        for c, query in enumerate(refseq_list):
            print(f"\rProcessing... {i+1}/{total_N}", end="")
            i += 1
            if r != c:
                duplicated_refseq_seq = my_refseq.seq + my_refseq.seq
                score_matrix[r, c] = calc_corrected_alignment_score(duplicated_refseq_seq, query.seq)
            else:
                score_matrix[r, c] = 1
    print()
    return score_matrix

def calc_corrected_alignment_score(duplicated_refseq_seq, query_seq):
    result = parasail.sw_trace(query_seq, duplicated_refseq_seq, gap_open_penalty, gap_extend_penalty, my_custom_matrix)
    result = MyResult_Minimum(result)
    gc.collect()
    return result.score / (len(duplicated_refseq_seq) / 2)

def recommended_combination(score_matrix):

    distance_matrix = 1 - score_matrix
    N = len(distance_matrix)
    for i in range(N - 1):
        for j in range(i + 1, N):
            v = min(distance_matrix[i, j], distance_matrix[j, i])
            distance_matrix[i, j] = v
            distance_matrix[j, i] = v

    # draw_heatmap_core(a, x_labels=np.arange(score_matrix.shape[1]), y_labels=np.arange(score_matrix.shape[1]))
    # plt.show()

    # clustering by sequence similarity (similar sequences will be grouped)
    dArray = distance.squareform(distance_matrix)
    result = linkage(dArray, method='complete')
    # np.set_printoptions(suppress=True)
    # print(result)
    # distance_matrix_sorted, dendro_levels = draw_dendrogram_core(distance_matrix, result)
    # plt.show()
    # quit()

    # organize result
    threshold = 1 - score_threshold
    grouping = [[i] for i in range(N)]
    for idx1, idx2, d, number_of_sub_cluster in result:
        if d > threshold:
            break
        grouping.append(grouping[int(idx1)] + grouping[int(idx2)])
        grouping[int(idx1)] = None
        grouping[int(idx2)] = None
    grouping = [g for g in grouping if g is not None]

    print(grouping)

    # distance_matrix_sorted, dendro_levels = draw_dendrogram_core(distance_matrix, result)
    # plt.show()

    score_matrix_tmp = np.copy(score_matrix)
    def calc_group_score(index_list):
        combination_of_index = list(itertools.combinations(index_list, 2))
        scores = score_matrix_tmp[tuple(zip(*combination_of_index))]
        return scores.max() * len(scores)

    # Choose combinations so that similar sequences never end up in the same one
    # (group: a set of similar sequences; combination: a set of dissimilar sequences)
    N_combination = max(len(g) for g in grouping)
    combination_list = [[] for i in range(N_combination)]
    # Process the groups from the largest to the smallest
    for c in range(N_combination, 0, -1):
        # Extract groups of the given length and pad them with None to match len(combination_list)
        selected_groups = [g + [None for i in range(N_combination - c)] for g in grouping if len(g) == c]
        # Enumerate every way of distributing the group members over the combinations
        selected_group_permuation = [list(set(itertools.permutations(g, N_combination))) for g in selected_groups]
        product_of_selected_group_permutation = list(itertools.product(*selected_group_permuation))
        # Calculate the average score
        average_score_list = []
        for prod in product_of_selected_group_permutation:
            scores = [calc_group_score(combination_list[i] + [p_sub for p_sub in p if p_sub is not None]) for i, p in enumerate(zip(*prod))]
            if len(scores):
                average_score_list.append(np.average(scores))
            else:
                average_score_list.append(np.nan)
        selected_prod = product_of_selected_group_permutation[np.argmin(average_score_list)]
        for i, p in enumerate(zip(*selected_prod)):
            combination_list[i].extend([p_sub for p_sub in p if p_sub is not None])
    print(combination_list)
    return combination_list

class MyRefSeq_Minimum():
    def __init__(self, path: Path):
        self.path = path
        if self.path.suffix == ".dna":
            snapgene_dict = snapgene_file_to_dict(self.path.as_posix())
            # seqrecord = snapgene_file_to_seqrecord(self.path.as_posix())
            assert snapgene_dict["isDNA"]
            self.topology = snapgene_dict["dna"]["topology"]
            self.strandedness = snapgene_dict["dna"]["strandedness"]
            self.length = snapgene_dict["dna"]["length"]
            self.seq = snapgene_dict["seq"]
            if self.topology != "circular":
                print(f"WARNING: {self.path.name} is not circular!")
            assert self.strandedness == "double"
            assert self.length == len(self.seq)
        elif self.path.suffix in (".fasta", ".fa"):
            with open(self.path.as_posix(), 'r') as f:
                self.seq=''
                for line in f.readlines():
                    if line[0] != '>':
                        self.seq += line.strip()
            self.topology = "circular"
            self.strandedness = "double"
            self.length = len(self.seq)
        else:
            raise Exception(f"Unsupported type of sequence file: {self.path}")
    def reverse_complement(self):
        return str(Seq(self.seq).reverse_complement())

class MyResult_Minimum():
    def __init__(self, parasail_result) -> None:
        self.cigar = parasail_result.cigar.decode.decode("ascii")
        self.score = parasail_result.score
        self.beg_ref = parasail_result.cigar.beg_ref
        self.beg_query = parasail_result.cigar.beg_query
        self.end_ref = parasail_result.end_ref
        self.end_query = parasail_result.end_query

class StringSizeWithSuffix():
    def __init__(self, size_with_suffix):
        if isinstance(size_with_suffix, str):
            m = re.match(r"([0-9.]+)([a-z]+)", size_with_suffix)
            self.size = float(m.group(1))
            self.suffix = m.group(2)
        elif isinstance(size_with_suffix, list):
            self.size = size_with_suffix[0]
            self.suffix = size_with_suffix[1]
        else:
            raise Exception("error!")
    def __add__(self, v):
        return StringSizeWithSuffix([self.size + v, self.suffix])
    def __sub__(self, v):
        return StringSizeWithSuffix([self.size - v, self.suffix])
    def __mul__(self, v):
        return StringSizeWithSuffix([self.size * v, self.suffix])
    def __truediv__(self, v):
        return StringSizeWithSuffix([self.size / v, self.suffix])
    def __floordiv__(self, v):
        return StringSizeWithSuffix([self.size // v, self.suffix])
    def __iadd__(self, v):
        self.size += v
        return self
    def __isub__(self, v):
        self.size -= v
        return self
    def __imul__(self, v):
        self.size *= v
        return self
    def __itruediv__(self, v):
        self.size /= v
        return self
    def __ifloordiv__(self, v):
        self.size //= v
        return self
    def __str__(self):
        return f"{self.size}{self.suffix}"

class Svg():
    def __init__(self, path):
        ET.register_namespace("","http://www.w3.org/2000/svg")
        tree = ET.parse(path)
        self.svg = tree.getroot()   # svg
    def adjust_margin(self, path, l, r, t, b):  # l, r, t, b represents margins to add to the viewbox of svg file.
        self.svg.attrib["width"] = str(StringSizeWithSuffix(self.svg.attrib["width"]) + l + r)
        self.svg.attrib["height"] = str(StringSizeWithSuffix(self.svg.attrib["height"]) + t + b)
        view_box = list(map(float, self.svg.attrib["viewBox"].split(" ")))
        view_box[0] -= l    # x0
        view_box[1] -= t    # y0
        view_box[2] += l + r    # x1
        view_box[3] += t + b    # y1
        self.svg.attrib["viewBox"] = " ".join(map(str, view_box))
    def draw_path(self, d, stroke="#000000", stroke_width=2, stroke_linecap="butt", stroke_linejoin="miter", stroke_opacity=1, stroke_miterlimit=4, stroke_dasharray="none"):
        self.svg.append(ET.Element(
            'path', 
            attrib={
                "style" :f"fill:none;stroke:{stroke};stroke-width:{stroke_width};stroke-linecap:{stroke_linecap};stroke-linejoin:{stroke_linejoin};stroke-opacity:{stroke_opacity};stroke-miterlimit:{stroke_miterlimit};stroke-dasharray:{stroke_dasharray}", 
                "d"     :d
            }
        ))
    def draw_text(self, string, x, y, font_style="normal", font_weight="normal", font_size="30px", line_height=1.25, font_family="sans-serif", fill="#000000", fill_opacity=1, stroke="none", stroke_width=0.75, text_anchor="middle", text_align="center"):
        text = ET.Element(
            "text", 
            attrib={
                "xml:space" :"preserve", 
                "style"     :f"font-style:{font_style};font-weight:{font_weight};font-size:{font_size};line-height:{line_height};font-family:{font_family};fill:{fill};fill-opacity:{fill_opacity};stroke:{stroke};stroke-width:{stroke_width};text-anchor:{text_anchor};text-align:{text_align}", 
                "x"         :str(x), 
                "y"         :str(y), 
            }
        )
        text.text = string
        self.svg.append(text)
    @staticmethod
    def textsize(text, fontsize, font_style):
        tmp_svg_path = "undefined.svg"
        surface = cairo.SVGSurface(tmp_svg_path, 1280, 200)
        cr = cairo.Context(surface)
        cr.select_font_face(font_style, cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_BOLD)
        cr.set_font_size(fontsize)
        xbearing, ybearing, width, height, xadvance, yadvance = cr.text_extents(text)
        os.remove(tmp_svg_path)
        return {
            "xbearing":xbearing, 
            "ybearing":ybearing, 
            "width":width, 
            "height":height, 
            "xadvance":xadvance, 
            "yadvance":yadvance
        }
    def save(self, path):   # overwrites an existing file
        tree = ET.ElementTree(element=self.svg)
        tree.write(path, encoding='utf-8', xml_declaration=True)

class D(str):
    def __new__(cls, path_command_list=[]):
        initial_string = " ".join([f"{path_command} {','.join(map(str, values))}" for path_command, values in path_command_list])
        self = super().__new__(cls, initial_string)
        return self
    def append(self, path_command, values):
        # example of horizontal line: "M 6.7760638,-8.370432 H 169.80019"
        self += f"{path_command} {','.join(map(str, values))}"

def draw_heatmap(score_matrix, refseq_names, comb_idx_list, threshold_used, save_path):
    # label
    comb_txt_list = [f"threshold={threshold_used}"] + [', '.join(['P' + str(i+1) for i in comb_idx_list])]
    font_style = "Helvetica"
    tmp_names = [f"P{i+1}" for i in range(len(refseq_names))]
    details_list = [f"{i: <4}: {j}" for i, j in zip(tmp_names, refseq_names)] + [""] + comb_txt_list

    # size
    dpi = 72
    value_font_size=10
    tick_font_size=12
    label_font_size = 14
    details_font_size = 8
    left_top_margin = (tick_font_size + label_font_size) * 2
    bottom_margin = details_font_size * (len(details_list) + 1)
    details_width = max([Svg.textsize(details, details_font_size, font_style)["xadvance"] for details in details_list])

    # make matplotlib fig
    fig, ax = draw_heatmap_core(score_matrix, x_labels=tmp_names, y_labels=tmp_names, value_font_size=value_font_size, tick_font_size=tick_font_size)

    # highlight combination
    for i in comb_idx_list:
        for j in comb_idx_list:
            if i == j:
                continue
            else:
                highlight_cell(i, j, ax=ax, color="r")

    # add titles etc.
    ax.set_xlabel("query", fontsize=label_font_size)
    ax.set_ylabel("reference", fontsize=label_font_size)
    plt.savefig(save_path, dpi=dpi)

    # adjust saved svg
    svg = Svg(save_path)
    x0, y0, x1, y1 = svg.svg.attrib["viewBox"].split(" ")
    assert float(x0) == float(y0) == 0
    for i, details in enumerate(details_list):
        svg.draw_text(details, x=label_font_size - left_top_margin, y=float(y1) + float(y0) + details_font_size * (i + 1), font_size=details_font_size, text_anchor="left", font_style=font_style)
    right_margin = max(details_width + label_font_size - left_top_margin - float(x1), (tick_font_size + label_font_size) * 2)
    svg.adjust_margin(save_path, l=left_top_margin, r=right_margin, t=left_top_margin, b=bottom_margin)
    svg.save(save_path)

def draw_heatmap_core(score_matrix, x_labels, y_labels, value_font_size=10, tick_font_size=14, subplot=[1,1,1]):
    assert score_matrix.shape == (len(y_labels), len(x_labels))
    figsize_unit = 0.5

    fig = plt.figure(figsize=(len(x_labels) * figsize_unit, len(y_labels) * figsize_unit))
    ax = plt.subplot(*subplot)
    im = plt.imshow(score_matrix, cmap="YlGn", vmin=0, vmax=1)
    bar = plt.colorbar(im, fraction=0.046, pad=0.04)
    # Loop over data dimensions and create text annotations.
    for i in range(score_matrix.shape[0]):
        for j in range(score_matrix.shape[1]):
            if np.isnan(score_matrix[i, j]):
                continue
            value = f"{np.round(score_matrix[i, j], 3):0<5}"
            if np.absolute(score_matrix[i, j]) < score_matrix.max()/2:
                text = ax.text(j, i, value, ha="center", va="center", color="k", fontsize=value_font_size)
            else:
                text = ax.text(j, i, value, ha="center", va="center", color="w", fontsize=value_font_size)

    # Show all ticks and label them with the respective list entries
    # x
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')
    ax.set_xticks(np.arange(len(x_labels)))
    ax.set_xticklabels(labels=x_labels, fontsize=tick_font_size)
    
    # y
    ax.set_yticks(np.arange(len(y_labels)))
    ax.set_yticklabels(labels=y_labels, fontsize=tick_font_size)
    plt.subplots_adjust(bottom=0.0, left=0.0, right=1, top=1)
    return fig, ax

def highlight_cell(x, y, ax=None, **kwargs):
    rect = plt.Rectangle((x-0.45, y-0.45), 0.9, 0.9, fill=False, **kwargs)
    ax = ax or plt.gca()
    ax.add_patch(rect)
    return rect

# open files
my_refseq_list = [
    MyRefSeq_Minimum(refseq_file_path) for refseq_file_path in uploaded_refseq_file_paths
]

# calc distance & propose optimized combination
score_matrix = pre_survery(my_refseq_list)
comb = recommended_combination(score_matrix)
refseq_names = [refseq.path.name for refseq in my_refseq_list]

# remove before make
for i in pwd.glob(f"recommended_group_*.svg"):
    i.unlink()

# draw heatmap(s)
for group_idx, comb_idx_list in enumerate(comb):
    save_path = pwd / (f"recommended_group_{group_idx}.svg")
    draw_heatmap(score_matrix, refseq_names, comb_idx_list, score_threshold, save_path)

print()
print("#########################")
print("# Recommended groupings #")
print("#########################")
for group_idx, comb_idx_list in enumerate(comb):
    print(f"Group{group_idx}")
    for comb_idx in comb_idx_list:
        print(f"P{comb_idx+1: <4}: " + refseq_names[comb_idx])
    print()



In [None]:
#@title # 1. Upload and select files

app_name = "MyApp"
version = "0.1.1"
description = "written by MU"

from pathlib import Path
pwd = Path('/content/sample_data/')
uploaded_fastq_files = [path for path in pwd.glob("*.fastq")]
uploaded_refseq_files = [path for path in pwd.glob("*.*") if path.suffix in (".dna", ".fa", ".fasta")]

if not len(uploaded_fastq_files) > 0:
    raise Exception("Please upload fastq files under the 'sample_data' directory!")
if not len(uploaded_refseq_files) > 0:
    raise Exception("Please upload reference files under the 'sample_data' directory!")

#@markdown ## 1-1. Upload files
#@markdown - `*.dna` (SnapGene file) or `*.fasta` (FASTA file)
#@markdown - `*.fastq` (Nanopore sequence results)

#@markdown ## 1-2. Select this cell and hit `Runtime` -> `Run after`

from IPython.display import display
from ipywidgets import Checkbox, VBox, Layout, interactive_output, Label
# widgets
child_widget_list = [Label("# fastq files")]
arg_dict = {}
for uploaded_fastq in uploaded_fastq_files:
    ckbx = Checkbox(value=True, description=uploaded_fastq.name, indent=False, layout=Layout(width='80%'))
    child_widget_list.append(ckbx)
    arg_dict[uploaded_fastq.name] = ckbx
child_widget_list.extend([Label(" "), Label("# dna files")])
for uploaded_refseq in uploaded_refseq_files:
    ckbx = Checkbox(value=True, description=uploaded_refseq.name, indent=False, layout=Layout(width='80%'))
    child_widget_list.append(ckbx)
    arg_dict[uploaded_refseq.name] = ckbx
ui = VBox(children=child_widget_list)

# observation function
def select_data(**kwargs):
    N_fastq = sum(1 for k, checked in kwargs.items() if checked and k.endswith(".fastq"))
    N_refseq = sum(1 for k, checked in kwargs.items() if checked and not k.endswith(".fastq"))
    if N_fastq > 0 and N_refseq > 0:
        print(f"\n{N_fastq} fastq files selected\n{N_refseq} reference sequence files selected")
    else:
        error_text = ""
        if N_fastq == 0:
            error_text += f"\nPlease select at least 1 fastq file!"
        if N_refseq == 0:
            error_text += f"\nPlease select at least 1 reference sequence file!"
        print(error_text)

# display
out = interactive_output(select_data, arg_dict)
display(ui, out)




In [None]:
#@title # 2. Execute alignment
save_to_google_drive = True #@param {type:"boolean"}
if save_to_google_drive:
  from pydrive.drive import GoogleDrive
  from pydrive.auth import GoogleAuth
  from google.colab import auth
  from oauth2client.client import GoogleCredentials
  auth.authenticate_user()
  gauth = GoogleAuth()
  gauth.credentials = GoogleCredentials.get_application_default()
  drive = GoogleDrive(gauth)
  print("You are logged into Google Drive and are good to go!")

gap_open_penalty = 3   #@param {type:"integer"}
gap_extend_penalty = 1 #@param {type:"integer"}
match_score = 1        #@param {type:"integer"}
mismatch_score = -2    #@param {type:"integer"}

print("installing dependencies...")
import sys, os
# save_stdout =  sys.stdout
# sys.stdout = open(os.devnull, 'w')

!pip install -U httplib2==0.15.0
!pip install snapgene-reader
!pip install parasail

# sys.stdout.close()
# sys.stdout = save_stdout

import io
import numpy as np
import pandas as pd
import re
import copy
import zipfile
import parasail
import gc
import datetime
import textwrap
import contextlib
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from itertools import product
from collections import OrderedDict, namedtuple, defaultdict
from snapgene_reader import snapgene_file_to_dict, snapgene_file_to_seqrecord
from Bio.Seq import Seq
from numpy import uint8
from PIL import Image

class MyFastQ(OrderedDict):
    def __init__(self, path=None):
        super().__init__()
        self.path = path
        if self.path is not None: # for deep copy
            with open(self.path.as_posix(), "r") as f:
                fastq_txt = f.readlines()
            # check
            self.N_seq, mod = divmod(len(fastq_txt), 4)
            assert mod == 0
            # register
            for i in range(self.N_seq):
                seq_id = fastq_txt[4 * i].strip()
                seq = fastq_txt[4 * i + 1].strip()
                p = fastq_txt[4 * i + 2].strip()
                q_scores = [ord(q) - 33 for q in fastq_txt[4 * i + 3].strip()]
                assert p == "+"
                assert len(seq) == len(q_scores)
                self[seq_id] = [seq, q_scores]
        else:
            pass
    def get_read_lengths(self):
        return np.array([len(v[0]) for v in self.values()])
    def get_q_scores(self):
        q_scores = []
        for v in self.values():
            q_scores.extend(v[1])
        return np.array(q_scores)
    def get_new_seq_id(self, k):
        if k not in self.keys():
            return k
        else:
            n = 1
            new_k = f"{k} {n}"
            while new_k in self.keys():
                n += 1
                new_k = f"{k} {n}"
            return new_k
    def append(self, fastq):
        for k, v in fastq.items():
            new_k = self.get_new_seq_id(k)
            self[new_k] = v
    @staticmethod
    def combine(fastq_list):
        assert len(fastq_list) > 0  # a single file is fine; the loop below then does nothing
        combined_fastq = copy.deepcopy(fastq_list[0])
        for fastq in fastq_list[1:]:
            combined_fastq.append(fastq)
        return combined_fastq
    def __getitem__(self, k):
        if not isinstance(k, slice):
            return OrderedDict.__getitem__(self, k)
        x = self.__class__()
        if k.start is None: start = 0
        else:               start = k.start
        if k.stop is None: stop = len(self)  # open-ended slice runs to the last read
        else:              stop = k.stop
        assert (0 <= start <= stop)
        for idx, key in enumerate(self.keys()):
            if start <= idx < stop:
                x[key] = self[key]
        return x

class MyRefSeq():
    def __init__(self, path: Path):
        self.path = path
        if self.path.suffix == ".dna":
            snapgene_dict = snapgene_file_to_dict(self.path.as_posix())
            # seqrecord = snapgene_file_to_seqrecord(self.path.as_posix())
            assert snapgene_dict["isDNA"]
            self.topology = snapgene_dict["dna"]["topology"]
            self.strandedness = snapgene_dict["dna"]["strandedness"]
            self.length = snapgene_dict["dna"]["length"]
            self.seq = snapgene_dict["seq"]
            if self.topology != "circular":
                print(f"WARNING: {self.path.name} is not circular!")
            assert self.strandedness == "double"
            assert self.length == len(self.seq)
        elif self.path.suffix in (".fasta", ".fa"):
            with open(self.path.as_posix(), 'r') as f:
                self.seq=''
                for line in f.readlines():
                    if line[0] != '>':
                        self.seq += line.strip()
            self.topology = "circular"
            self.strandedness = "double"
            self.length = len(self.seq)
        else:
            raise Exception(f"Unsupported type of sequence file: {self.path}")
    def reverse_complement(self):
        return str(Seq(self.seq).reverse_complement())

# get files selected above and generate class objects
fastq_list = []
refseq_list = []
for child_widget in child_widget_list:
    if not isinstance(child_widget, Checkbox):
        continue
    if child_widget.value and child_widget.description.endswith(".fastq"):
        fastq = MyFastQ(pwd / child_widget.description)
        fastq_list.append(fastq)
    elif child_widget.value and (not child_widget.description.endswith(".fastq")):
        refseq = MyRefSeq(pwd / child_widget.description)
        refseq_list.append(refseq)

# assert refseq
if len(refseq_list) == 0:
    raise Exception("Please select at least 1 reference sequence file!")
refseq_stem_list = [refseq.path.stem for refseq in refseq_list]
if len(refseq_stem_list) != len(set(refseq_stem_list)):
    raise Exception("Reference file names must be unique, even if their extensions differ.")

# assert fastq
if len(fastq_list) > 1:
    combined_fastq = MyFastQ.combine(fastq_list)
elif len(fastq_list) == 1:
    combined_fastq = copy.deepcopy(fastq_list[0])
else:
    raise Exception("Please select at least 1 fastq file!")
combined_fastq.path = [fastq.path for fastq in fastq_list]
print("installation: DONE")

# Definition of main classes
class MyResult():
    def __init__(self, parasail_result) -> None:
        self.cigar = parasail_result.cigar.decode.decode("ascii")
        self.score = parasail_result.score
        self.beg_ref = parasail_result.cigar.beg_ref
        self.beg_query = parasail_result.cigar.beg_query
        self.end_ref = parasail_result.end_ref
        self.end_query = parasail_result.end_query

class MyAligner():
    gap_open_penalty = gap_open_penalty
    gap_extend_penalty = gap_extend_penalty
    match_score = match_score
    mismatch_score = mismatch_score
    my_custom_matrix = parasail.matrix_create("ACGT", match_score, mismatch_score)
    def __init__(self, refseq_list, combined_fastq):
        self.refseq_list = refseq_list
        self.combined_fastq = combined_fastq
        self.duplicated_refseq_seq_list = []
        self.is_refseq_seq_all_ATGC_list = []
        for refseq in refseq_list:
            self.duplicated_refseq_seq_list.append(refseq.seq + refseq.seq)
            is_all_ATCG = all([b.upper() in "ATCG" for b in refseq.seq])
            if not is_all_ATCG:
                print(f"\033[38;2;255;0;0mWARNING: Non-ATCG letter(s) were found in '{refseq.path.name}'.\nWhen calculating the alignment score, they are treated as 'mismatched', no matter what characters they are.\033[0m")
            self.is_refseq_seq_all_ATGC_list.append(is_all_ATCG)
    # Because the reference is a circular plasmid, this custom score is used to linearize it again
    # (i.e. to decide where on the plasmid each read starts and ends)
    def get_custom_cigar_score_dict(self):
        return {
            "=":self.match_score, 
            "X":self.mismatch_score, 
            "D":self.gap_open_penalty * -1, 
            "H":self.gap_open_penalty * -1, 
            "S":0, 
            "N":0, 
            "I":0
        }
    # calculated based on gap_open_penalty, gap_extend_penalty, match_score, mismatch_score
    def clac_cigar_score(self, cigar_str):
        # For some reason the CIGAR in the result sometimes starts with a long run of D or I; drop that leading run
        cigar_str_NL_list = re.findall(r'(\d+)(\D)', cigar_str)
        if cigar_str_NL_list[0][1] in ("D", "I"):
            cigar_str_NL_list = cigar_str_NL_list[1:]
        score = 0
        for N, L in cigar_str_NL_list:
            N = int(N)
            if L == "=":
                score += self.match_score * N
            elif L == "X":
                score += self.mismatch_score * N
            elif L in "DI":
                score -= self.gap_open_penalty + self.gap_extend_penalty * (N - 1)
            else:
                raise Exception(f"unknown letter code: {L}")
        return score
    def align_all(self):
        fastq_len = len(self.combined_fastq)
        result_dict = OrderedDict()
        for query_idx, (seq_id, (query_seq, q_scores)) in enumerate(list(self.combined_fastq.items())):
            print(f"\rExecuting alignment: {query_idx + 1} out of {fastq_len} ({seq_id})", end="")
            # calc scores for each refseq
            result_list = []
            for duplicated_refseq_seq, is_refseq_seq_all_ATGC in zip(self.duplicated_refseq_seq_list, self.is_refseq_seq_all_ATGC_list):
                # For some reason the CIGAR in the result sometimes starts with a long run of D or I, but the score itself appears to be computed correctly
                result = parasail.sw_trace(query_seq, duplicated_refseq_seq, self.gap_open_penalty, self.gap_extend_penalty, self.my_custom_matrix)
                result = MyResult(result)
                result_rc = parasail.sw_trace(str(Seq(query_seq).reverse_complement()), duplicated_refseq_seq, self.gap_open_penalty, self.gap_extend_penalty, self.my_custom_matrix)
                result_rc = MyResult(result_rc)
                # sanity-check the score
                if is_refseq_seq_all_ATGC:
                    assert result.score == self.clac_cigar_score(result.cigar)
                    assert result_rc.score == self.clac_cigar_score(result_rc.cigar)
                # register
                result_list.append(result)
                result_list.append(result_rc)
                gc.collect()
            result_dict[seq_id] = result_list
        return result_dict

class MyCigarStr(str):
    def __new__(cls, cigar_str):
        # when common cigar strings are passed
        if cigar_str[0].isdecimal():
            val = "".join([
                L for N, L in re.findall(r'(\d+)(\D)', cigar_str)
                    for i in range(int(N))
            ])
            self = super().__new__(cls, val)
            return self
        # when "MyCigarStr" strings are passed
        else:
            self = super().__new__(cls, cigar_str)
            return self
    def __iadd__(self, other):
        return self.__class__(self + other)
    def invert(self):
        return self.__class__(self[::-1])
    def number_of_letters_on_5prime(self, letters):
        for i, l in enumerate(self):
            if l not in letters:
                return i
        # every letter belongs to `letters`
        return len(self)
    def number_of_letters_on_3prime(self, letters):
        for k, l in enumerate(self[::-1]):
            if l not in letters:
                return k
        # every letter belongs to `letters`
        return len(self)
    def clip_from_both_ends(self, letters):
        i = self.number_of_letters_on_5prime(letters)
        k = self.number_of_letters_on_3prime(letters)
        return self.__class__(self[i:len(self) - k])
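    def clipped_len(self, letters="H"):
        # length after clipping `letters` from both ends (the default "H" is an assumption)
        return len(self.clip_from_both_ends(letters))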

# Execute
my_aligner = MyAligner(refseq_list, combined_fastq)
result_dict = my_aligner.align_all()
print()
print("alignment: DONE")
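
The cell above relies on two CIGAR manipulations: expanding a run-length CIGAR into one letter per alignment column (as `MyCigarStr` does) and recomputing an affine-gap score from the CIGAR (as `clac_cigar_score` does). The following is a minimal standalone sketch of both, not the notebook's own implementation. The function names and the default penalty values are hypothetical, since the actual parameters are set elsewhere in the notebook.

```python
import re

# Matches (count, letter) pairs in a run-length CIGAR such as "5=1X2D4=".
CIGAR_TOKEN = re.compile(r"(\d+)(\D)")

def expand_cigar(cigar: str) -> str:
    """Expand a run-length CIGAR into one letter per alignment column."""
    return "".join(letter * int(count) for count, letter in CIGAR_TOKEN.findall(cigar))

def affine_cigar_score(cigar: str, match=2, mismatch=-3, gap_open=5, gap_extend=1) -> int:
    """Recompute an alignment score from a CIGAR with affine gap penalties.

    The default values are illustrative assumptions, not the notebook's settings.
    """
    score = 0
    for count, letter in CIGAR_TOKEN.findall(cigar):
        n = int(count)
        if letter == "=":
            score += match * n
        elif letter == "X":
            score += mismatch * n
        elif letter in "DI":
            # a gap of length n costs one opening plus (n - 1) extensions
            score -= gap_open + gap_extend * (n - 1)
        else:
            raise ValueError(f"unknown CIGAR letter: {letter}")
    return score

print(expand_cigar("3=1X2D"))          # one letter per column
print(affine_cigar_score("5=1X2D4="))  # 10 - 3 - 6 + 8
```

Comparing such a recomputed score with the score reported by the aligner is a cheap consistency check, which is what the assertions inside `align_all` do when the reference contains only A, T, C and G.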


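The next cell assigns each read to the reference with the highest normalized score. As a rough sketch of that rule (a simplified restatement, not the notebook's code): each raw score is divided by the length of the single, non-duplicated reference and capped at 1, the best entry is taken over the forward and reverse-complement alignments of all references, and the read is accepted only if that value reaches the threshold and the read is no longer than twice the reference. The function name `assign_read` and the example numbers are hypothetical.

```python
def assign_read(scores, ref_lengths, read_length, threshold=0.5):
    """Pick the best reference for one read.

    scores      -- [ref0_fwd, ref0_rc, ref1_fwd, ref1_rc, ...] raw alignment scores
    ref_lengths -- length of each (non-duplicated) reference
    Returns (ref_idx, is_reverse_complement, assigned, normalized_scores).
    """
    # normalize by reference length and cap at 1
    normalized = [min(s / ref_lengths[i // 2], 1.0) for i, s in enumerate(scores)]
    # index of the first maximum, mirroring np.argmax
    best = max(range(len(normalized)), key=normalized.__getitem__)
    ref_idx, is_rc = divmod(best, 2)
    assigned = (normalized[best] >= threshold) and (read_length <= 2 * ref_lengths[ref_idx])
    return ref_idx, bool(is_rc), assigned, normalized

# Hypothetical read of 4200 bases scored against two references (5000 bp and 3000 bp)
ref_idx, is_rc, assigned, normalized = assign_read(
    scores=[100, 4000, 900, 50], ref_lengths=[5000, 3000], read_length=4200
)
print(ref_idx, is_rc, assigned)
```

Raising `score_threshold` trades the number of assigned reads for confidence in each assignment. Reads that fall below it are still listed in the score summary, with `assigned` set to 0.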
In [None]:
#@title # 3. Set threshold for assignment

score_threshold = 0.5  #@param {type:"number"}

class AlignmentResult():
    score_threshold = score_threshold
    def __init__(self, result_dict, my_aligner):
        self.result_dict = result_dict
        self.my_aligner = my_aligner
        # attributes to register results
        self.score_list_ALL = None
        self.result_info_assigned = None
        self.aligned_result_list = None
    def get_score_summary_df(self):
        records = []
        for info in self.score_list_ALL:
            d = OrderedDict()
            d["query_idx"] = info["query_idx"]
            d["seq_id"] = info["seq_id"]
            for i, refseq in enumerate(self.my_aligner.refseq_list):
                d[f"{refseq.path.name} (idx={i})"]= info["score_list"][2 * i]
                d[f"{refseq.path.name} (idx={i},rc)"] = info["score_list"][2 * i + 1]
            for i, refseq in enumerate(self.my_aligner.refseq_list):
                d[f"{refseq.path.name} (idx={i}, normalized)"]= info["normalized_score_list"][2 * i]
                d[f"{refseq.path.name} (idx={i},rc, normalized)"] = info["normalized_score_list"][2 * i + 1]
            d["assigned_refseq_idx"] = info["assigned_refseq_idx"]
            d["is_reverse_compliment"] = info["is_reverse_compliment"]
            d["assigned"] = info["assigned"]
            records.append(d)
        return pd.DataFrame.from_records(records)
    def save_score_summary(self, save_path):
        score_summary = (
            "query_idx" 
            + "\tseq_id" 
            + "\t" 
            + "\t".join([f"{refseq.path.name} (idx={i})\t{refseq.path.name} (idx={i},rc)" for i, refseq in enumerate(self.my_aligner.refseq_list)])
            + "\t"
            + "\t".join([f"{refseq.path.name} (idx={i}, normalized)\t{refseq.path.name} (idx={i},rc, normalized)" for i, refseq in enumerate(self.my_aligner.refseq_list)])
            + "\tassigned_refseq_idx"
            + "\tis_reverse_compliment"
            + "\tassigned\n"
            + "\n".join(
                [(
                    str(info["query_idx"])
                    + "\t" + info["seq_id"]
                    + "\t" + "\t".join(map(str, info["score_list"]))
                    + "\t" + "\t".join(map(str, info["normalized_score_list"]))
                    + "\t" + str(info["assigned_refseq_idx"])
                    + "\t" + str(info["is_reverse_compliment"])
                    + "\t" + str(info["assigned"])
                ) for info in self.score_list_ALL]
            )
        )
        with open(save_path, "w") as f:
            f.write(score_summary)
        return score_summary
    def normalize_scores_and_apply_threshold(self):
        self.score_list_ALL = []
        self.result_info_assigned = [[] for i in self.my_aligner.refseq_list] # [[[seq_id, is_reverse_compliment, result, query_idx], ...], ...]
        assert len(self.result_dict) == len(self.my_aligner.combined_fastq)
        for query_idx, (seq_id, result_list) in enumerate(self.result_dict.items()):
            assert len(result_list) == len(self.my_aligner.duplicated_refseq_seq_list) * 2
            # normalize scores for each refseq
            score_list = []
            normalized_score_list = []
            for result_idx, result in enumerate(result_list):
                score_list.append(result.score)
                duplicated_refseq_seq = self.my_aligner.duplicated_refseq_seq_list[result_idx // 2]
                normalized_score = result.score / len(duplicated_refseq_seq) * 2
                if normalized_score > 1:
                    normalized_score = 1
                normalized_score_list.append(normalized_score)
            # choose sequence with maximum score
            idx = np.argmax(normalized_score_list)
            refseq_idx, is_reverse_compliment = divmod(idx, 2)
            # quality check (query sequences longer than twice the refseq length are omitted)
            assigned = (
                (normalized_score_list[idx] >= self.score_threshold)
                and (len(self.my_aligner.combined_fastq[seq_id][0]) <= len(self.my_aligner.duplicated_refseq_seq_list[refseq_idx]))
            )

            # register
            self.score_list_ALL.append({
                "query_idx":query_idx, 
                "seq_id":seq_id, 
                "score_list":score_list, 
                "normalized_score_list":normalized_score_list, 
                "assigned_refseq_idx":refseq_idx, 
                "is_reverse_compliment":is_reverse_compliment, 
                "assigned":int(assigned)
            })
            if assigned:
                self.result_info_assigned[refseq_idx].append([
                    seq_id, 
                    is_reverse_compliment, 
                    result_list[idx], 
                    query_idx
                ])
    def integrate_assigned_result_info(self):
        self.aligned_result_list = []
        assert len(self.my_aligner.refseq_list) == len(self.result_info_assigned)
        total_N = len(self.my_aligner.refseq_list)
        for cur_idx, (refseq, result_info_list) in enumerate(zip(self.my_aligner.refseq_list, self.result_info_assigned)):
            print(f"\rIntegrating alignment results: {cur_idx + 1} out of {total_N}", end="")
            if len(result_info_list) > 0:
                my_cigar_str_list = []
                new_q_scores_list = []
                new_seq_list = []
                seq_id_list, is_reverse_compliment_list, result_list, query_idx_list = list(zip(*result_info_list))
                for result, is_reverse_compliment, seq_id in zip(result_list, is_reverse_compliment_list, seq_id_list):
                    # query info
                    seq = self.my_aligner.combined_fastq[seq_id][0]
                    q_scores = self.my_aligner.combined_fastq[seq_id][1]
                    if is_reverse_compliment:
                        seq = str(Seq(seq).reverse_complement())
                        q_scores = q_scores[::-1]
                    # results
                    my_cigar_str = MyCigarStr(result.cigar)
                    # organize alignment based on refseq
                    number_of_ref_bases_before_query = result.beg_ref - result.beg_query
                    number_of_ref_bases_after_query = (refseq.length * 2 - result.end_ref - 1) - (len(seq) - result.end_query - 1)
                    """
                                beg_ref(9)      end_ref(24)
                                     |                |
                    pos     0         10         20         30
                    ref     atcgatcggGGCTATG-CTTGCAT-GCatcgatcg
                    align   HHHHHHHSS====X==I===D===N==SSSHHHHH
                    query          caGGCTGTGACTT-CAT-GCtga
                    pos            0         10          20
                                     |                |         
                                beg_query(2)    end_query(17)
                    """
                    # truncate
                    assert (number_of_ref_bases_before_query >= 0) or (number_of_ref_bases_after_query >= 0)
                    if (number_of_ref_bases_before_query < 0):
                        q_scores = q_scores[-number_of_ref_bases_before_query:]
                        seq      = seq[-number_of_ref_bases_before_query:]
                        result.beg_query -= -number_of_ref_bases_before_query
                        result.end_query -= -number_of_ref_bases_before_query                        
                        number_of_ref_bases_before_query = 0
                    if (number_of_ref_bases_after_query < 0):
                        q_scores = q_scores[:number_of_ref_bases_after_query]
                        seq      = seq[:number_of_ref_bases_after_query]
                        # result.beg_query # do nothing
                        # result.end_query # do nothing
                        number_of_ref_bases_after_query = 0
                    assert (number_of_ref_bases_before_query >= 0) & (number_of_ref_bases_after_query >= 0)
                    # organize my_cigar_str
                    my_cigar_str = MyCigarStr(
                        "H" * number_of_ref_bases_before_query      # add deletion of ref
                        + "S" * result.beg_query                    # soft clip of query
                        + my_cigar_str                              # aligned region
                        + "S" * (len(seq) - result.end_query - 1)   # soft clip of query
                        + "H" * number_of_ref_bases_after_query     # add deletion of ref
                    )
                    # parasail sometimes returns a run of D at the 5' end, so replace it with H
                    # (arguably beg_ref should already account for this)
                    my_cigar_str = MyCigarStr(
                        "H" * my_cigar_str.number_of_letters_on_5prime("HD")
                        + my_cigar_str.clip_from_both_ends("HD")
                        + "H" * my_cigar_str.number_of_letters_on_3prime("HD")
                    )
                    my_cigar_str_H_clip = my_cigar_str.clip_from_both_ends("H")
                    assert len(q_scores) == len(seq) == len(my_cigar_str_H_clip) - my_cigar_str_H_clip.count("D")
                    assert refseq.length * 2 == len(my_cigar_str) - my_cigar_str_H_clip.count("I")

                    # parasail sometimes returns a run of I at the 5' end, so strip it (arguably beg_query should already account for this)
                    number_of_I_on_5prime = my_cigar_str.number_of_letters_on_5prime("I")
                    if number_of_I_on_5prime > 0:
                        my_cigar_str = MyCigarStr(my_cigar_str[number_of_I_on_5prime:])
                        q_scores = q_scores[number_of_I_on_5prime:]
                        seq = seq[number_of_I_on_5prime:]

                    my_cigar_str_list.append(my_cigar_str)
                    new_q_scores_list.append(q_scores)
                    new_seq_list.append(seq)
                # further organize to match refseq, new_seq (query), and new_qscores.
                my_cigar_str_net_length_list = [len(my_cigar_str) - my_cigar_str.count("I") for my_cigar_str in my_cigar_str_list]
                assert all(my_cigar_str_net_length_list[0] == x for x in my_cigar_str_net_length_list)

                duplicated_refseq = refseq.seq + refseq.seq
                duplicated_refseq_with_insertion = ""
                my_cigar_str_list_with_insertion = ["" for i in my_cigar_str_list]
                new_q_scores_list_with_insertion = [[] for i in new_q_scores_list]
                new_seq_list_with_insertion = ["" for i in new_seq_list]

                # print(duplicated_refseq)
                # for i in new_seq_list:
                #     print(i)

                # current idx (positions of new_q_scores and new_seq are the same)
                cur_refseq_idx = 0
                cur_my_cigar_str_idx_list = [0 for i in my_cigar_str_list]
                cur_q_scores_idx_list = [0 for i in new_q_scores_list]
                max_refseq_idx = refseq.length * 2 - 1
                max_my_cigar_str_idx_list = [len(my_cigar_str) - 1 for my_cigar_str in my_cigar_str_list]
                max_q_scores_idx_list = [len(new_q_scores) - 1 for new_q_scores in new_q_scores_list]

                # print(max_refseq_idx)
                # print(max_my_cigar_str_idx_list)
                # print(max_q_scores_idx_list)

                all_done = False
                cur_idx = -1
                while not all_done:
                    cur_idx += 1
                    cur_my_cigar_letter_list = [my_cigar_str[cur_my_cigar_str_idx] for my_cigar_str, cur_my_cigar_str_idx in zip(my_cigar_str_list, cur_my_cigar_str_idx_list)]
                    if "I" not in cur_my_cigar_letter_list:
                        for i, L in enumerate(cur_my_cigar_letter_list):
                            if L in "DH":
                                my_cigar_str_list_with_insertion[i] += L
                                new_q_scores_list_with_insertion[i] += [-1]
                                new_seq_list_with_insertion[i]      += "-"
                                cur_my_cigar_str_idx_list[i]        += 1
                            elif L in "SX=":
                                my_cigar_str_list_with_insertion[i] += L
                                new_q_scores_list_with_insertion[i] += [ new_q_scores_list[i][cur_q_scores_idx_list[i]] ]
                                new_seq_list_with_insertion[i]      += new_seq_list[i][cur_q_scores_idx_list[i]]
                                cur_my_cigar_str_idx_list[i]        += 1
                                cur_q_scores_idx_list[i]            += 1
                            else:
                                raise Exception(f"unexpected CIGAR letter: {L!r}")
                        else:
                            duplicated_refseq_with_insertion += duplicated_refseq[cur_refseq_idx]
                            if cur_refseq_idx == refseq.length:
                                turning_idx = cur_idx   # start of the second half
                            cur_refseq_idx += 1
                    else:
                        # TODO: insertions from different reads are not aligned against each other!
                        for i, L in enumerate(cur_my_cigar_letter_list):
                            if L == "I":
                                my_cigar_str_list_with_insertion[i] += "I"
                                new_q_scores_list_with_insertion[i] += [ new_q_scores_list[i][cur_q_scores_idx_list[i]] ]
                                new_seq_list_with_insertion[i]      += new_seq_list[i][cur_q_scores_idx_list[i]]
                                cur_my_cigar_str_idx_list[i]        += 1
                                cur_q_scores_idx_list[i]            += 1
                            else:
                                my_cigar_str_list_with_insertion[i] += "N"
                                new_q_scores_list_with_insertion[i] += [-1]
                                new_seq_list_with_insertion[i]      += "-"
                        else:
                            duplicated_refseq_with_insertion += "-"

                    # compare against the maximum indices and stop once everything has been consumed
                    all_done = bool(
                        (cur_refseq_idx > max_refseq_idx)
                        * all([cur_my_cigar_str_idx > max_my_cigar_str_idx for cur_my_cigar_str_idx, max_my_cigar_str_idx in zip(cur_my_cigar_str_idx_list, max_my_cigar_str_idx_list)])
                        * all([cur_q_scores_idx > max_q_scores_idx for cur_q_scores_idx, max_q_scores_idx in zip(cur_q_scores_idx_list, max_q_scores_idx_list)])
                    )
                # print(duplicated_refseq_with_insertion)
                # print(duplicated_refseq_with_insertion[turning_idx:])
                for i in my_cigar_str_list_with_insertion:
                    assert len(duplicated_refseq_with_insertion) == len(i)
                for i in new_seq_list_with_insertion:
                    assert len(duplicated_refseq_with_insertion) == len(i)
                for i in new_q_scores_list_with_insertion:
                    assert len(duplicated_refseq_with_insertion) == len(i)

                # linearize
                refseq_with_insertion_1            = duplicated_refseq_with_insertion[:turning_idx]
                refseq_with_insertion_2            = duplicated_refseq_with_insertion[turning_idx:]
                assert (len(refseq_with_insertion_1) - refseq_with_insertion_1.count("-")) == (len(refseq_with_insertion_2) - refseq_with_insertion_2.count("-")) == refseq.length
                my_cigar_str_list_with_insertion_1 = [i[:turning_idx] for i in my_cigar_str_list_with_insertion]
                my_cigar_str_list_with_insertion_2 = [i[turning_idx:] for i in my_cigar_str_list_with_insertion]
                new_seq_list_with_insertion_1      = [i[:turning_idx] for i in new_seq_list_with_insertion]
                new_seq_list_with_insertion_2      = [i[turning_idx:] for i in new_seq_list_with_insertion]
                new_q_scores_list_with_insertion_1 = [i[:turning_idx] for i in new_q_scores_list_with_insertion]
                new_q_scores_list_with_insertion_2 = [i[turning_idx:] for i in new_q_scores_list_with_insertion]

                # align the first half against the second half (only insertions need to be considered)
                # trim trailing gaps from the end of the first half
                assert refseq_with_insertion_2[-1] != "-"
                N_gap_refseq_with_insertion_1_end = 1
                while True:
                    if refseq_with_insertion_1[-N_gap_refseq_with_insertion_1_end] != "-":
                        break
                    N_gap_refseq_with_insertion_1_end += 1
                if N_gap_refseq_with_insertion_1_end > 1:
                    refseq_with_insertion_1 = refseq_with_insertion_1[:1 - N_gap_refseq_with_insertion_1_end]
                    my_cigar_str_list_with_insertion_1 = [i[:1 - N_gap_refseq_with_insertion_1_end] for i in my_cigar_str_list_with_insertion_1]
                    new_seq_list_with_insertion_1 = [i[:1 - N_gap_refseq_with_insertion_1_end] for i in new_seq_list_with_insertion_1]
                    new_q_scores_list_with_insertion_1 = [i[:1 - N_gap_refseq_with_insertion_1_end] for i in new_q_scores_list_with_insertion_1]
                # start aligning the first and second halves
                idx = -1
                refseq_idx1 = 0
                refseq_idx2 = 0
                refseq_max_idx1 = len(refseq_with_insertion_1) - 1
                refseq_max_idx2 = len(refseq_with_insertion_2) - 1
                while True:
                    idx += 1
                    if refseq_with_insertion_1[idx] == refseq_with_insertion_2[idx]:
                        refseq_idx1 += 1
                        refseq_idx2 += 1
                    elif refseq_with_insertion_1[idx] == "-":
                        refseq_idx1 += 1
                        refseq_with_insertion_2 = refseq_with_insertion_2[:idx] + "-" + refseq_with_insertion_2[idx:]
                        for i, j in enumerate(my_cigar_str_list_with_insertion_2):
                            my_cigar_str_list_with_insertion_2[i] = j[:idx] + "N" + j[idx:]
                        for i, j in enumerate(new_seq_list_with_insertion_2):
                            new_seq_list_with_insertion_2[i]      = j[:idx] + "-" + j[idx:]
                        for i in new_q_scores_list_with_insertion_2:
                            i.insert(idx, -1)
                    elif refseq_with_insertion_2[idx] == "-":
                        refseq_idx2 += 1
                        refseq_with_insertion_1 = refseq_with_insertion_1[:idx] + "-" + refseq_with_insertion_1[idx:]
                        for i, j in enumerate(my_cigar_str_list_with_insertion_1):
                            my_cigar_str_list_with_insertion_1[i] = j[:idx] + "N" + j[idx:]
                        for i, j in enumerate(new_seq_list_with_insertion_1):
                            new_seq_list_with_insertion_1[i]      = j[:idx] + "-" + j[idx:]
                        for i in new_q_scores_list_with_insertion_1:
                            i.insert(idx, -1)
                    else:
                        raise Exception(f"reference halves disagree at column {idx}")
                    # end
                    if (refseq_idx1 == refseq_max_idx1) & (refseq_idx2 == refseq_max_idx2):
                        break
                # check
                assert refseq_with_insertion_1 == refseq_with_insertion_2
                for i in my_cigar_str_list_with_insertion_1:
                    assert len(refseq_with_insertion_1) == len(i)
                for i in my_cigar_str_list_with_insertion_2:
                    assert len(refseq_with_insertion_2) == len(i)
                for i in new_seq_list_with_insertion_1:
                    assert len(refseq_with_insertion_1) == len(i)
                for i in new_seq_list_with_insertion_2:
                    assert len(refseq_with_insertion_2) == len(i)
                for i in new_q_scores_list_with_insertion_1:
                    assert len(refseq_with_insertion_1) == len(i)
                for i in new_q_scores_list_with_insertion_2:
                    assert len(refseq_with_insertion_2) == len(i)

                # print("Alignment of top half and bottom half: DONE")

                # find the boundary between the first and second halves that maximizes the score
                custom_cigar_score_dict = self.my_aligner.get_custom_cigar_score_dict()
                refseq_with_insertion = refseq_with_insertion_1
                my_cigar_str_list_with_insertion = [None for i in my_cigar_str_list_with_insertion_1]
                new_seq_list_with_insertion = [None for i in new_seq_list_with_insertion_1]
                new_q_scores_list_with_insertion = [None for i in new_q_scores_list_with_insertion_1]
                for i, (j1, j2, k1, k2, l1, l2) in enumerate(zip(
                        my_cigar_str_list_with_insertion_1, 
                        my_cigar_str_list_with_insertion_2, 
                        new_seq_list_with_insertion_1, 
                        new_seq_list_with_insertion_2, 
                        new_q_scores_list_with_insertion_1, 
                        new_q_scores_list_with_insertion_2
                )):
                    my_cigar_scores_with_insertion_1 = np.array([custom_cigar_score_dict[j] for j in j1])
                    my_cigar_scores_with_insertion_2 = np.array([custom_cigar_score_dict[j] for j in j2])
                    switching_idx = np.argmin(np.cumsum(my_cigar_scores_with_insertion_1 - my_cigar_scores_with_insertion_2))
                    # register
                    my_cigar_str_list_with_insertion[i] = j2[:switching_idx + 1] + j1[switching_idx + 1:]
                    new_seq_list_with_insertion[i]      = k2[:switching_idx + 1] + k1[switching_idx + 1:]
                    new_q_scores_list_with_insertion[i] = l2[:switching_idx + 1] + l1[switching_idx + 1:]

                # print(seq_id_list)
                # print(refseq_with_insertion)
                # for i, seq_id in enumerate(seq_id_list):
                #     # print(seq_id)
                #     print(my_cigar_str_list_with_insertion[i])
                #     print(new_seq_list_with_insertion[i])
                #     print(new_q_scores_list_with_insertion[i])
                self.aligned_result_list.append({
                    "refseq_with_insertion": refseq_with_insertion, 
                    "query_idx_list": query_idx_list, 
                    "seq_id_list": seq_id_list, 
                    "my_cigar_str_list_with_insertion": my_cigar_str_list_with_insertion, 
                    "new_seq_list_with_insertion": new_seq_list_with_insertion, 
                    "new_q_scores_list_with_insertion": new_q_scores_list_with_insertion
                })
            else:
                self.aligned_result_list.append({
                    "refseq_with_insertion": refseq.seq, 

                    "query_idx_list": (), 
                    "seq_id_list": (), 
                    "my_cigar_str_list_with_insertion": [], 
                    "new_seq_list_with_insertion": [], 
                    "new_q_scores_list_with_insertion": []

                    # "query_idx_list": (-1, ), 
                    # "seq_id_list": ("@None", ), 
                    # "my_cigar_str_list_with_insertion": ["X" * len(refseq.seq)], 
                    # "new_seq_list_with_insertion": ["-" * len(refseq.seq)], 
                    # "new_q_scores_list_with_insertion": [[-1] * len(refseq.seq)]
                })
        assert len(self.my_aligner.refseq_list) == len(self.aligned_result_list)
    def export_as_text(self, save_dir):
        text_list = []
        save_path_list = []
        for refseq, aligned_result in zip(self.my_aligner.refseq_list, self.aligned_result_list):
            text = ""
            idx_label_minimum = "consensus"
            query_idx_list = aligned_result["query_idx_list"]
            query_idx_len_max = max([len(str(query_idx)) for query_idx in query_idx_list] + [0])
            label_N0 = max(query_idx_len_max, len(idx_label_minimum)) + 1
            seq_id_list = aligned_result["seq_id_list"]
            seq_id_len_max = max([len(seq_id) for seq_id in seq_id_list] + [0])
            label_N1 = max(seq_id_len_max, len(refseq.path.name)) + 1
            text += (
                "ref"
                + " " * (label_N0 - 3)
                + refseq.path.name
                + " " * (label_N1 - len(refseq.path.name))
                + aligned_result["refseq_with_insertion"]
            )
            consensus_seq, consensus_q_scores, consensus_seq_all, consensus_q_scores_all = self.consensus_dict[refseq.path.name]
            text += (
                "\n"
                + "consensus"
                + " " * (label_N0 - 9 + label_N1)
                + consensus_seq_all
            )
            text += (
                "\n"
                + "consensus"
                + " " * (label_N0 - 9 + label_N1)
                + "".join([chr(q) for q in (np.array(consensus_q_scores_all) + 33)])
            )
            for query_idx, seq_id, my_cigar_str_with_insertion, new_seq_with_insertion, new_q_scores_with_insertion in \
                zip(
                    aligned_result["query_idx_list"], 
                    aligned_result["seq_id_list"], 
                    aligned_result["my_cigar_str_list_with_insertion"], 
                    aligned_result["new_seq_list_with_insertion"], 
                    aligned_result["new_q_scores_list_with_insertion"]
                ):
                label = (
                    "\n"
                    + str(query_idx)
                    + " " * (label_N0 - len(str(query_idx)))
                    + seq_id
                    + " " * (label_N1 - len(seq_id))
                )
                text += (
                    label
                    + new_seq_with_insertion
                    + label
                    + my_cigar_str_with_insertion
                    + label
                    + "".join([chr(q) for q in (np.array(new_q_scores_with_insertion) + 33)])
                )
            save_path = (save_dir / refseq.path.name).with_suffix(".txt")
            with open(save_path, "w") as f:
                f.write(text)
            text_list.append(text)
            save_path_list.append(save_path)
        return text_list, save_path_list
    def alignment_reuslt_list_2_text_list(self, linewidth=""):
        text_list = []
        highlight_pos_list = []
        refseq_name_list = []
        for refseq, aligned_result in zip(self.my_aligner.refseq_list, self.aligned_result_list):
            ref_label = "REF"
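            # default=0 keeps this working when no read was assigned to the reference (empty query_idx_list)
            label_N = max(len(str(max(aligned_result["query_idx_list"], default=0))), len(ref_label)) + 1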
            refseq_name_list.append(refseq.path.name)
            # wrap each sequence into lines of at most `linewidth` characters
            child_lines = re.findall(fr".{{1,{linewidth}}}", aligned_result["refseq_with_insertion"].upper())
            master_lines = [list(map(lambda l: ref_label + " " * (label_N - len(ref_label)) + l, child_lines))]
            for idx, (new_seq_with_insertion, query_idx) in enumerate(zip(aligned_result["new_seq_list_with_insertion"], aligned_result["query_idx_list"])):
                child_lines = re.findall(fr".{{1,{linewidth}}}", new_seq_with_insertion.upper())
                master_lines.append(list(map(lambda l: f"{query_idx}" + " " * (label_N - len(str(query_idx))) + l, child_lines)))
            # interleave the wrapped lines block by block using zip
            text = ""
            for i, lines in enumerate(zip(*master_lines)):
                text += (
                    f"{i * linewidth + 1}-{(i + 1) * linewidth}\n"
                    + "\n".join(lines)
                    + "\n\n"
                )
            text_list.append(text.strip())
            # positions to highlight
            highlight_pos_in_text = []
            for i, my_cigar_str_with_insertion in enumerate(aligned_result["my_cigar_str_list_with_insertion"]):
                # print(len(my_cigar_str_with_insertion))
                true_highlight_pos = [m.start() for m in re.finditer('[IXDS]', my_cigar_str_with_insertion)]
                for p in true_highlight_pos:
                    r, c = divmod(p, linewidth) 
                    r = r * (len(master_lines) + 2) + i + 2 # each block = position line + master_lines + blank line
                    c += label_N
                    highlight_pos_in_text.append((r, c))
            highlight_pos_list.append(highlight_pos_in_text)
        return refseq_name_list, text_list, highlight_pos_list
    def export_log(self, save_path, header):
        text = (
            header
            + f"\n{datetime.datetime.now()}"
            + "\n\ninput reference files"
            + "\n"
            + "\n".join([refseq.path.as_posix() for refseq in self.my_aligner.refseq_list])
            + "\n\ninput fastq files"
            + "\n"
            + "\n".join([fastq_path.as_posix() for fastq_path in self.my_aligner.combined_fastq.path])
            + "\n\nalignment params"
            + f"\ngap_open_penalty\t{self.my_aligner.gap_open_penalty}"
            + f"\ngap_extend_penalty\t{self.my_aligner.gap_extend_penalty}"
            + f"\nmatch_score\t{self.my_aligner.match_score}"
            + f"\nmismatch_score\t{self.my_aligner.mismatch_score}"
            + f"\nscore_threshold\t{self.score_threshold}"
            + "\n"
            + "\n".join(f"custom_cigar_score '{k}'\t{v}" for k, v in self.my_aligner.get_custom_cigar_score_dict().items())
            + "\n\nscore_matrix (basically the same as match_score and mismatch_score)"
            + "\n"
            + self.matrix2string(self.my_aligner.my_custom_matrix.matrix, digit=3, round=True)
            + "\n"
            + "\nconsensus_settings"
            + "\nsbq_pdf_version\t" + self.consensus_settings["sbq_pdf_version"]
            + "\n"
            + "\nerror_matrix (row:true_base, col:base_calling , val:P(true_base|base_calling))"
            + "\n"
            + self.matrix2string(
                self.consensus_settings["P_N_dict_matrix"], 
                bases=self.consensus_settings["bases"], 
                digit=None, 
                round=False
            )
        )
        with open(save_path, "w") as f:
            f.write(text)
    @staticmethod
    def matrix2string(matrix, bases="ATCG", digit=3, round=True):
        bio = io.BytesIO()
        if round:
            np.savetxt(bio, matrix, fmt=f"%{digit}d")
        else:
            np.savetxt(bio, matrix, fmt="%.5e")
            digit = 11
        matrix_str = bio.getvalue().decode('latin1')
        bases = bases + "*"
        output = (
            " " * (digit + 1)
            + (" " * digit).join(b for b in bases)
        )
        for b, m in zip(bases, matrix_str.split("\n")):
            output += f"\n{b} {m}"
        return output
    def save_consensus(self, save_dir):
        save_path_list = []
        for key, val in self.consensus_dict.items():
            save_path1 = save_dir / Path(key).with_suffix(".fastq")
            consensus_seq, consensus_q_scores, *all_results = val
            consensus_q_scores = "".join([chr(q + 33) for q in consensus_q_scores])
            consensus_fastq_txt = f"@{key}:\n{consensus_seq.upper()}\n+\n{consensus_q_scores}"
            with open(save_path1, "w") as f:
                f.write(consensus_fastq_txt)
            save_path_list.append(save_path1)
        return save_path_list

def draw_distributions(score_summary_df, combined_fastq):
    refseq_idx_dict = OrderedDict()
    for c in score_summary_df.columns:
        m = re.match(r"(.+) \(idx=([0-9]+)\)", c)
        if m is not None:
            refseq_idx_dict[int(m.group(2))] = m.group(1)

    # collect data
    assignment_set_4_read_length = [[] for i in range(len(refseq_idx_dict) + 1)]   # last one is for idx=-1 (not assigned)
    assignment_set_4_q_scores = [[] for i in range(len(refseq_idx_dict) + 1)]   # last one is for idx=-1 (not assigned)
    for i, s in score_summary_df.iterrows():
        seq_id = s["seq_id"]
        assigned_refseq_idx = s["assigned_refseq_idx"]
        assigned = s["assigned"]
        if assigned == 0:
            assigned_refseq_idx = -1
        assignment_set_4_read_length[assigned_refseq_idx].append(len(combined_fastq[seq_id][0]))
        assignment_set_4_q_scores[assigned_refseq_idx].extend(combined_fastq[seq_id][1])

    # plotting parameters
    rows = len(refseq_idx_dict)
    columns = 3 # 4
    fig = plt.figure(figsize=(4 * columns, 2 * rows), clear=True)
    fig.subplots_adjust(hspace=0.05, wspace=0.05)
    widths = [2] + [2 for i in range(columns - 1)]
    # heights = [2] + [3 for i in range(rows - 1)]
    spec = fig.add_gridspec(ncols=columns, nrows=rows, width_ratios=widths)#, height_ratios=heights)

    ##########
    # labels #
    ##########
    column_idx = 0
    text_wrap = 15
    for refseq_idx, refseq_name in refseq_idx_dict.items():
        ax = fig.add_subplot(spec[refseq_idx, column_idx])
        refseq_name_wrapped = "\n".join([refseq_name[i:i+text_wrap] for i in range(0, len(refseq_name), text_wrap)])
        ax.text(0.5, 0.6, refseq_name_wrapped, ha='center', va='center', wrap=True, family="monospace")
        ax.set_axis_off()
    color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
    legend_elements = [
        Patch(facecolor=color_cycle[0], label='Focused'), 
        Patch(facecolor=color_cycle[1], label='Others'), 
        Patch(facecolor="grey", label='not assigned')
    ]
    fig.legend(handles=legend_elements, loc="lower left", borderaxespad=0)

    ############################
    # read length distribution #
    ############################
    column_idx = 1
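    # draw a stacked histogram for each assignment
    bin_unit = 100
    # add one extra edge so that the longest reads fall inside the last bin
    bins = range(0, int(np.ceil(max(max(v) if len(v) > 0 else bin_unit for v in assignment_set_4_read_length) / bin_unit) * bin_unit) + bin_unit, bin_unit)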
    for refseq_idx, refseq_name in refseq_idx_dict.items():
        hist_params = dict(
            x=assignment_set_4_read_length[-2::-1] + assignment_set_4_read_length[-1:], 
            color=[color_cycle[0] if i == refseq_idx else color_cycle[1] for i in range(len(refseq_idx_dict))][::-1] + ["grey"], 
            bins=bins, 
            histtype='bar', 
            stacked=True
        )
        # draw
        ax0 = fig.add_subplot(spec[refseq_idx, column_idx])
        ax0.hist(**hist_params)
        ax0.set_ylabel("count")
        # # log scale
        # ax1 = fig.add_subplot(spec[refseq_idx, column_idx + 1])
        # ax1.hist(**hist_params)
        # ax1.set_yscale("log")
        # ax1.set_ylabel("count")
        if refseq_idx == 0:
            ax0.set_title("read length distribution")
            # ax1.set_title("read length distribution (log)")
        if refseq_idx == len(refseq_idx_dict) - 1:
            ax0.set_xlabel("bp")
            # ax1.set_xlabel("bp")
        else:
            ax0.set_xticklabels([])
            # ax1.set_xticklabels([])

    ########################
    # q_score distribution #
    ########################
    column_idx = 2
    for refseq_idx, refseq_name in refseq_idx_dict.items():
        hist_params = dict(
            x=assignment_set_4_q_scores[-2::-1] + assignment_set_4_q_scores[-1:], 
            color=[color_cycle[0] if i == refseq_idx else color_cycle[1] for i in range(len(refseq_idx_dict))][::-1] + ["grey"], 
            bins=np.arange(42), 
            histtype='bar', 
            stacked=True, 
            density=True
        )
        # draw
        ax0 = fig.add_subplot(spec[refseq_idx, column_idx])
        ax0.hist(**hist_params)

        # labels
        ax0.set_ylabel("density")
        if refseq_idx == 0:
            ax0.set_title("Q-score distribution")
        if refseq_idx == len(refseq_idx_dict) - 1:
            ax0.set_xlabel("Q-score")
        else:
            ax0.set_xticklabels([])

    plt.tight_layout()

def draw_alignment_score_scatter(score_summary_df, score_threshold):
    refseq_idx_dict = OrderedDict()
    for c in score_summary_df.columns:
        m = re.match(r"(.+) \(idx=([0-9]+)\)", c)
        if m is not None:
            refseq_idx_dict[int(m.group(2))] = m.group(1)

    # add one column per reference holding the better of the forward and reverse-complement normalized scores
    for refseq_idx, refseq_name in refseq_idx_dict.items():
        col_name1 = refseq_name + f" (idx={refseq_idx}, normalized)"
        col_name2 = refseq_name + f" (idx={refseq_idx},rc, normalized)"
        score_summary_df[refseq_name] = score_summary_df.apply(lambda row: max(row[col_name1], row[col_name2]), axis=1)

    # plotting parameters
    rows = columns = len(refseq_idx_dict) + 1
    color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
    focused_color1 = color_cycle[0]
    focused_color2 = color_cycle[1]
    not_assigned_color = "grey"
    fig = plt.figure(figsize=(2.5 * columns, 2.5 * rows), clear=True)
    widths = [2] + [3 for i in range(columns - 1)]
    heights = [2] + [3 for i in range(columns - 1)]
    spec = fig.add_gridspec(ncols=columns, nrows=rows, width_ratios=widths, height_ratios=heights)

    # label
    legend_elements = [
        Patch(facecolor=color_cycle[0], label='Focused'), 
        Patch(facecolor=color_cycle[1], label='Others'), 
        Patch(facecolor="grey", label='not assigned')
    ]
    fig.legend(handles=legend_elements, loc="upper left", borderaxespad=0.2)
    # fig.suptitle("alignment score scatter")

    ######################
    # score distribution #
    ######################
    diagonal_axes = []
    other_axes = []
    for (refseq_idx1, refseq_name1), (refseq_idx2, refseq_name2) in product(refseq_idx_dict.items(), refseq_idx_dict.items()):
        ax = fig.add_subplot(spec[refseq_idx1 + 1, refseq_idx2 + 1]) # the origin is at the top left
        if refseq_idx1 == refseq_idx2:
            diagonal_axes.append(ax)
            hist_params = dict(
                x=[
                    score_summary_df.query("(assigned_refseq_idx == @refseq_idx1)&(assigned == 1)")[refseq_name1], 
                    score_summary_df.query("(assigned_refseq_idx != @refseq_idx1)&(assigned == 1)")[refseq_name1], 
                    score_summary_df.query("(assigned == 0)")[refseq_name1]
                ], 
                color=[focused_color1, focused_color2, not_assigned_color], 
                bins=np.linspace(0, 1, 100), 
                histtype='bar', 
                stacked=True, 
                density=True
            )
            ax.hist(**hist_params)
        else:
            other_axes.append(ax)
            scatter_params = dict(
                x=refseq_name2, 
                y=refseq_name1, 
                ax=ax, 
                s=5, 
                alpha=0.3
            )
            plot_params = dict(
                c="k", 
                linestyle="--", 
                linewidth=1
            )
            score_summary_df.query("(assigned_refseq_idx == @refseq_idx2)&(assigned == 1)").plot.scatter(color=focused_color1, **scatter_params)
            score_summary_df.query("(assigned_refseq_idx != @refseq_idx2)&(assigned == 1)").plot.scatter(color=focused_color2, **scatter_params)
            score_summary_df.query("assigned == 0").plot.scatter(color=not_assigned_color, **scatter_params)
            ax.plot((score_threshold, score_threshold), (0, score_threshold), **plot_params)
            ax.plot((0, score_threshold), (score_threshold, score_threshold), **plot_params)
            ax.plot((score_threshold, 1), (score_threshold, 1), **plot_params)
            ax.set_ylim(-0.05, 1.05)
        ax.set_xlim(-0.05, 1.05)
        ax.set_xticks(np.linspace(0, 1, 6))
        ax.set_xticklabels(["0.0", "0.2", "0.4", "0.6", "0.8", "1.0"])
        if refseq_idx2 != 0:
            # ax.yaxis.set_ticks_position('none')
            ax.set(ylabel=None)
            plt.setp(ax.get_yticklabels(), visible=False)
        else:
            if refseq_idx1 == refseq_idx2:
                ax.set_ylabel("density")
            else:
                ax.set_ylabel("normalized alignment score")
                ax.set_yticks(np.linspace(0, 1, 6))
                ax.set_yticklabels(["0.0", "0.2", "0.4", "0.6", "0.8", "1.0"])
        if refseq_idx1 != rows - 2:
            # ax.xaxis.set_ticks_position('none')
            ax.set(xlabel=None)
            plt.setp(ax.get_xticklabels(), visible=False)
        else:
            ax.set_xlabel("normalized alignment score")

    range_max = max(ax.get_ylim()[1] for ax in diagonal_axes)
    for ax in diagonal_axes:
        ax.set_ylim(0, range_max)

    text_wrap = 15
    for refseq_idx, refseq_name in refseq_idx_dict.items():
        ax = fig.add_subplot(spec[0, refseq_idx + 1])
        refseq_name_wrapped = "\n".join([refseq_name[i:i+text_wrap] for i in range(0, len(refseq_name), text_wrap)])
        ax.text(0.5, 0.0, refseq_name_wrapped, ha='center', va='bottom', wrap=True, family="monospace")
        ax.set_axis_off()

        ax = fig.add_subplot(spec[refseq_idx + 1, 0])
        refseq_name_wrapped = "\n".join([refseq_name[i:i+text_wrap] for i in range(0, len(refseq_name), text_wrap)])
        ax.text(0.5, 0.4, refseq_name_wrapped, ha='right', va='center', wrap=True, family="monospace")
        ax.set_axis_off()

    # set aspect after setting the ylim
    # ax = other_axes[0]
    # aspect = (ax.get_xlim()[1] - ax.get_xlim()[0]) / (ax.get_ylim()[1] - ax.get_ylim()[0])
    # for ax in other_axes:
    #     ax.set_aspect(aspect, adjustable='box')
    ax = diagonal_axes[0]
    aspect_diagonal = (ax.get_xlim()[1] - ax.get_xlim()[0]) / (ax.get_ylim()[1] - ax.get_ylim()[0])
    for ax in diagonal_axes:
        ax.set_aspect(aspect_diagonal, adjustable='box')
    fig.subplots_adjust(hspace=0.05, wspace=0.05, left=0.0, right=0.8, bottom=0.2, top=1.0)

alignment_result = AlignmentResult(result_dict, my_aligner)
print("normalizing scores...")
alignment_result.normalize_scores_and_apply_threshold()
print("normalization: DONE")
score_summary_df = alignment_result.get_score_summary_df()

# draw graphical summary
print("drawing figures...")
draw_distributions(score_summary_df, my_aligner.combined_fastq)
draw_alignment_score_scatter(score_summary_df, alignment_result.score_threshold)
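
`alignment_reuslt_list_2_text_list` above wraps each aligned sequence into fixed-width lines and converts flagged CIGAR positions into (block, column) coordinates of the wrapped text. The following is a minimal sketch of that mapping, assuming an integer `linewidth`; the sample sequence and CIGAR string are invented for illustration only.

```python
import re

linewidth = 10
ref_seq = "ATGC" * 5 + "ATG"                        # 23 characters (toy data)
cigar = "=" * 10 + "X" + "=" * 10 + "D" + "="       # same length as ref_seq

# wrap the sequence into chunks of at most `linewidth` characters
chunks = re.findall(fr".{{1,{linewidth}}}", ref_seq)
print(chunks)   # ['ATGCATGCAT', 'GCATGCATGC', 'ATG']

# positions flagged as insertion / mismatch / deletion / soft clip
flagged = [m.start() for m in re.finditer("[IXDS]", cigar)]

# (index of the wrapped block, column inside the wrapped line)
coords = [divmod(p, linewidth) for p in flagged]
print(coords)   # [(1, 0), (2, 1)]
```

In the notebook the block index is further converted into a text row by accounting for the position line, the reference line, and the blank line of each block.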


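The export functions above (`export_as_text`, `save_consensus`) write Q-scores as FASTQ-style characters by adding 33 to each score (`chr(q + 33)`). A minimal sketch of that round trip follows; the helper names `encode_phred33` and `decode_phred33` are illustrative and do not exist in the notebook.

```python
def encode_phred33(q_scores):
    """Convert integer Q-scores into a Phred+33 quality string."""
    return "".join(chr(q + 33) for q in q_scores)

def decode_phred33(quality_string):
    """Convert a Phred+33 quality string back into integer Q-scores."""
    return [ord(c) - 33 for c in quality_string]

q_scores = [0, 10, 20, 30, 40]
encoded = encode_phred33(q_scores)
print(encoded)                   # !+5?I
print(decode_phred33(encoded))   # [0, 10, 20, 30, 40]
```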

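The consensus step (section 4) weighs each candidate true base by a prior and by the likelihood of every observed (base call, Q-score) event, and `calc_consensus_error_rate` reports one minus the posterior of the chosen base. The sketch below reproduces that arithmetic with made-up numbers; `posterior_error`, the two-letter alphabet, and the `likelihood` table are hypothetical stand-ins for the notebook's PDF lookup, not its actual values.

```python
def posterior_error(event_list, true_base, prior, likelihood, bases):
    """Return 1 - P(true_base | events) using Bayes' rule.

    prior[b]               : prior probability that the true base is b
    likelihood[(b, event)] : P(event | true base b) (made-up values here)
    """
    def joint(base):
        p = prior[base]
        for event in event_list:
            p *= likelihood[(base, event)]
        return p

    total = sum(joint(b) for b in bases)
    return 1 - joint(true_base) / total

# toy model with two candidate bases and a single Q-score value
bases = "AT"
prior = {"A": 0.5, "T": 0.5}
likelihood = {
    ("A", ("A", 30)): 0.9, ("A", ("T", 30)): 0.1,
    ("T", ("A", 30)): 0.1, ("T", ("T", 30)): 0.9,
}
events = [("A", 30), ("A", 30)]   # two reads, both calling "A"
err = posterior_error(events, "A", prior, likelihood, bases)
print(f"{err:.4f}")
```

The notebook computes the same quantity as `1 - 1 / sum(...)`, where every term of the sum is divided by the joint probability of the chosen base; the two forms are algebraically equivalent.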
In [None]:
#@title # 4. Calculate consensus

error_rate = 0.0001   #@param {type:"number"}
del_mut_rate = error_rate / 4     # e.g. "A -> T, C, G, del"
ins_rate   = 0.0001 #@param {type:"number"}   # should insertions be treated independently?

##
bases = "ATCG-"
assert bases[-1] == "-"

from collections import defaultdict
default_value = {b_key2:ins_rate / 4 if b_key2 != "-" else 1 - ins_rate for b_key2 in bases}

P_N_dict_dict = defaultdict(
    lambda: default_value, 
    {   # probability that a true base b_key1 is observed as b_key2 (mutation, deletion, etc.)
        b_key1:{b_key2:1 - error_rate if b_key2 == b_key1 else del_mut_rate for b_key2 in bases} for b_key1 in bases[:-1]  # remove "-" from b_key1
    }
)
P_N_dict_dict["-"] = default_value

default_value_2 = {b_key2:0.2 / 4 if b_key2 != "-" else 0.8 for b_key2 in bases}
P_N_dict_dict_2 = defaultdict(
    lambda: default_value_2, 
    {
        b_key1:{b_key2: 0.2 for b_key2 in bases} for b_key1 in bases[:-1]  # "-" is set separately below
    }
)
P_N_dict_dict_2["-"] = default_value_2
letter_code_dict = {
    "ATCG":"N", # Any base
    "TCG":"B",  # Not A
    "ACG":"V",  # Not T
    "ATG":"D",  # Not C
    "ATC":"H",  # Not G
    "TG":"K",   # Keto
    "AC":"M",   # Amino
    "AG":"R",   # Purine
    "CG":"S",   # Strong
    "AT":"W",   # Weak
    "TC":"Y",   # Pyrimidine
    "A":"A", 
    "T":"T", 
    "C":"C", 
    "G":"G", 
}
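def mixed_bases(base_list):
    # return the IUPAC letter code for a set of candidate bases
    if len(base_list) == 1:
        return base_list[0]
    # ignore gaps; build a new list instead of mutating the caller's list
    base_list = [b for b in base_list if b != "-"]
    letters = ""
    for b in bases[:-1]:
        if b in base_list:
            letters += b
    return letter_code_dict[letters]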

def P_N_dict_dict_2_matrix(P_N_dict_dict, bases=bases):
    r_matrix = np.empty((len(bases), len(bases)), dtype=float)
    for r, b_key1 in enumerate(bases):
        for c, b_key2 in enumerate(bases):
            r_matrix[r, c] = P_N_dict_dict[b_key1][b_key2]
    return r_matrix

@contextlib.contextmanager
def fopen(filein, *args, **kwargs):
    if isinstance(filein, (str, Path)):  # filename/Path
        with open(filein, *args, **kwargs) as f:
            yield f
    else:  # file-like object
        yield filein

class MyTextFormat():
    def to_text(self):
        text = ""
        for k, data_type in self.keys:
            text += f"# {k}({data_type})\n"
            v = getattr(self, k)
            if data_type in ("str", "int"):
                string = str(v)
            elif data_type == "ndarray":
                with io.StringIO() as s:
                    np.savetxt(s, v)
                    string = s.getvalue().strip()
            elif data_type == "list":
                string = "\n".join(v)
            elif data_type in ("dict", "OrderedDict"):
                string = "\n".join(f"{key}\t{val}" for key, val in v.items())
            elif data_type == "eval":
                string = v.__str__()
            elif data_type == "df":
                string_io = io.StringIO()
                v.to_csv(string_io, sep="\t")
                string = string_io.getvalue().strip("\n")
            else:
                raise Exception(f"unsupported data type: {type(v)}")
            text += f"{string}\n\n"
        return text
    def save(self, save_path):
        text = self.to_text()
        with open(save_path, "w") as f:
            f.write(text)
    def load(self, load_path):
        added_keys = []
        with fopen(load_path, "r") as f:
            lines = f.readlines()
        cur_k = None
        cur_v = None
        cur_data_type = None
        for l in lines:
            if l.startswith("# "):
                if cur_k is None:   pass
                else:   self.set_attribute(cur_k, cur_v[:-2], cur_data_type)    # strip the two trailing newlines
                m = re.match(r"^(.+)\((.+)\)$", l[2:].strip("\n"))
                cur_k = m.group(1)
                cur_data_type = m.group(2)
                cur_v = ""
                added_keys.append((cur_k, cur_data_type))
            else:
                cur_v += l
        else:
            self.set_attribute(cur_k, cur_v[:-2], cur_data_type)
        return added_keys
    def set_attribute(self, cur_k: str, cur_v: str, cur_data_type: str):
        if isinstance(getattr(type(self), cur_k, None), property):
            if getattr(type(self), cur_k).fset is None:
                return
        if cur_data_type == "str":
            v = cur_v
        elif cur_data_type == "ndarray":
            v = np.array([list(map(float, line.split())) for line in cur_v.split("\n")])
        elif cur_data_type == "list":
            v = self.convert_to_number_if_possible(cur_v.split("\n"), method="all")
        elif cur_data_type == "dict":
            v = {l.split("\t")[0]:l.split("\t")[1] for l in cur_v.split("\n")}
        elif cur_data_type == "OrderedDict":
            v = OrderedDict([l.split("\t") for l in cur_v.split("\n")])
        elif cur_data_type == "eval":
            v = eval(cur_v)
        elif cur_data_type == "df":
            string_io = io.StringIO(cur_v)
            v = pd.read_csv(string_io, sep="\t", index_col=0, dtype=str)
        else:
            raise Exception(f"unsupported data type\n{cur_data_type}")
        setattr(self, cur_k, v)
    def convert_to_number_if_possible(self, values, method):
        new_values = []
        for v in values:
            try:
                new_values.append(float(v))
            except (TypeError, ValueError):
                new_values.append(v)
        if (method == "all") and any(map(lambda x: not isinstance(x, float), new_values)):
            return values
        else:
            return new_values

class SequenceBasecallQscoreLibrary(MyTextFormat):
    def __init__(self, path=None) -> None:
        self.file_version = version
        self.path = path
        self.master_params_dict = None
        self.meta_info = pd.DataFrame()
        self.alignment_summary = pd.DataFrame(columns=["=", "I", "D", "X", "H", "S", "aligned_query_len", "refseq_len", "score"], dtype=object)
        # info for saving
        self.keys = [
            ("file_version", "str"),
            ("master_params_dict", "dict"),
            ("meta_info", "df"), 
            ("alignment_summary", "df"), 
        ]
        self.variable_key_start_idx = 4
        # load
        if self.path is not None:
            self.load(load_path=self.path)
        # when pdf
        if (len(self.keys) > 4) and (self.keys[4][0] == "sum"):
            self.initialize_pdf()
    def copy_key_data(self, lib):
        for k, d_type in self.keys:
            setattr(self, k, copy.deepcopy(getattr(lib, k)))
    def get_sum(self, pdf_params_dict):
        lib_sum = self.__class__()
        lib_sum.copy_key_data(self)
        lib_sum.path = self.path.parent / (self.path.stem + f"_sum{self.path.suffix}")
        # combine all data
        combined_df, loc = self.combine_data(**pdf_params_dict)
        combined_df = combined_df.astype(object).astype(int)
        lib_sum.add_df_by_dict(OrderedDict(
            [("sum", combined_df)]
        ))
        # add params and records
        for k, v in pdf_params_dict.items():
            lib_sum.master_params_dict[k] = v
        lib_sum.alignment_summary["used_for_pdf"] = loc
        return lib_sum
    def add_df_by_dict(self, ordered_dict: OrderedDict):
        for k, v in ordered_dict.items():
            setattr(self, k, v)
            self.keys.append((k, "df"))
    def combine_data(self, threshold, thredhold_type, **kwargs):
        if thredhold_type == "score_over_aligned_query_len":
            extracted_summary = self.alignment_summary.astype(float).query("(score / aligned_query_len) > @threshold")
            loc = self.alignment_summary.astype(float).apply(lambda x: x["score"] / x["aligned_query_len"], axis=1) > threshold
        else:
            raise ValueError(f"unsupported thredhold_type: {thredhold_type}")
        assert len(extracted_summary.index) > 0
        df = pd.DataFrame(0, index=getattr(self, extracted_summary.index[0]).index, columns=getattr(self, extracted_summary.index[0]).columns, dtype=float)
        for key in extracted_summary.index:
            new_df = getattr(self, key)
            assert all(df.index == new_df.index) and all(df.columns == new_df.columns)
            df += new_df.astype(float)
        return df, loc
    def save(self, save_path=None):
        if save_path is None:
            save_path = self.path
        else:
            self.path = save_path
        super().save(save_path)
    def load(self, load_path):
        added_keys = super().load(load_path)
        for i in range(self.variable_key_start_idx):
            assert added_keys[i] == self.keys[i]
        for k in added_keys[self.variable_key_start_idx:]:
            self.keys.append(k)
        # post-processing
        self.meta_info["refseq_info"] = self.meta_info["refseq_info"].apply(lambda x: eval(x))
        for k, v in self.master_params_dict.items():
            try:    self.master_params_dict[k] = int(v) # int
            except (TypeError, ValueError): self.master_params_dict[k] = v      # string
        self.path = load_path
    def register_meta_info(self, fastq, refseq_list, **kwargs):
        meta_string = self.generate_meta_info_key(fastq)
        assert meta_string not in self.meta_info.index.values
        self.meta_info.loc[meta_string, "refseq_info"] = [[f"{refseq.my_hash}:{refseq.path.name}"] for refseq in refseq_list]
        for k, v in kwargs.items():
            self.meta_info.loc[meta_string, k] = v
    def register_alignment_summary(self, summary_info_dict, **kwargs):
        for key, d in summary_info_dict.items():
            for k, v in d.items():
                self.alignment_summary.loc[key, k] = v
    @staticmethod
    def generate_meta_info_key(fastq):
        return f"{fastq.my_hash}:{fastq.path.name}"
    @staticmethod
    def summary_df_2_matrix(summary_df:pd.DataFrame, base_order=None, **kwargs):
        crushed = summary_df.sum(axis=0)
        summary_matrix = np.zeros(shape=(len(base_order), len(base_order)), dtype=float)
        for k, v in crushed.items():
            m = re.match(r"(.+)_(.+)", k)
            ref_base = m.group(1)
            query_base = m.group(2)
            summary_matrix[base_order.index(ref_base), base_order.index(query_base)] = v
        return summary_matrix

    #########################
    # PDF related functions #
    #########################
    def initialize_pdf(self):
        assert self.keys[4][0] == "sum"
        self.sum = self.sum.astype(int)
        # avoid zero probabilities by adding a pseudocount of 1 to empty cells
        for c in self.sum.columns:
            if c.endswith("-"):
                continue
            for i in self.sum.index:
                if i < 2:
                    continue
                if self.sum.at[i, c] == 0:
                    self.sum.at[i, c] += 1
        # denominator: total number of events for each true base
        total_events_when_true_base = {}
        for column_names, values in self.sum.items():
            true_base = column_names.split("_")[0]
            if true_base not in total_events_when_true_base.keys():
                total_events_when_true_base[true_base] = values.sum()
            else:
                total_events_when_true_base[true_base] += values.sum()
        # calc probability
        self.P_base_calling_given_true_refseq_dict = {}
        for column_names, values in self.sum.items():
            true_base = column_names.split("_")[0]
            self.P_base_calling_given_true_refseq_dict[column_names] = values.sum() / total_events_when_true_base[true_base]
        # others
        self.pdf_core = {}
        for column_names, values in self.sum.items():
            assert all(values.index == np.arange(-1, 42))
            values /= values.sum()
            # make index -1 reach the last element (with a gap of 50 entries it should never collide with a real q-score)
            values_list = list(values)[1:] + [0.0 for i in range(50)] + list(values)[:1]
            self.pdf_core[column_names] = values
    # example:  a = self.calc_P_event_given_true_refseq(event=("T", 30), true_refseq="T")
    def calc_P_event_given_true_refseq(self, event, true_refseq):
        readseq, q_score = event
        key = f"{true_refseq}_{readseq}"
        return (
            self.P_base_calling_given_true_refseq_dict[key]
            * self.pdf_core[key][q_score]
        )
    def calc_consensus_error_rate(self, event_list, true_refseq, P_N_dict, bases):
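        # Bayes' rule: returns 1 - P(true_refseq | events), computed as
        # 1 - 1 / sum_b [P(b) / P(true_refseq) * prod_e P(e | b) / P(e | true_refseq)]
        bunbo_bunshi_sum = 0   # sum over candidate bases of (denominator term / numerator)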
        bunshi_list = [self.calc_P_event_given_true_refseq(event, true_refseq) for event in event_list]
        bunshi_P_N = P_N_dict[true_refseq]
        # inside sum
        for base in bases:
            val = P_N_dict[base] / bunshi_P_N
            for event, bunshi in zip(event_list, bunshi_list):
                val *= self.calc_P_event_given_true_refseq(event, base) / bunshi
            bunbo_bunshi_sum += val
        return 1 - 1 / bunbo_bunshi_sum

NanoporeStats_PDF_txt = textwrap.dedent("""
    # file_version(str)
    0.2.0

    # master_params_dict(dict)
    gap_open_penalty	3
    gap_extend_penalty	1
    match_score	1
    mismatch_score	-2
    base_length_2_observe	1
    threshold	0.6
    thredhold_type	score_over_aligned_query_len

    # meta_info(df)
    	refseq_info
    omitted.fastq	['omitted.fasta']

    # alignment_summary(df)
    	=	I	D	X	H	S	aligned_query_len	refseq_len	score	used_for_pdf
    omitted_id	-1	-1	-1	-1	-1	-1	-1	-1	-1	True

    # sum(df)
    	A_A	A_T	A_C	A_G	A_-	T_A	T_T	T_C	T_G	T_-	C_A	C_T	C_C	C_G	C_-	G_A	G_T	G_C	G_G	G_-	-_A	-_T	-_C	-_G	-_-
    -1	0	0	0	0	8999	0	0	0	0	7112	0	0	0	0	11317	0	0	0	0	11289	0	0	0	0	6107251
    0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0
    1	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0
    2	426	10	18	42	0	23	340	26	11	0	62	34	372	40	0	81	22	28	380	0	50	21	30	26	0
    3	1409	36	85	163	0	35	1263	90	49	0	117	87	1477	127	0	224	48	118	1691	0	164	95	115	145	0
    4	3406	72	184	338	0	87	3048	144	73	0	213	181	3645	209	0	484	118	182	4025	0	349	231	260	341	0
    5	6053	75	227	496	0	105	5323	209	104	0	297	234	6258	281	0	709	115	261	7043	0	500	288	367	505	0
    6	8591	85	261	611	0	98	7093	215	106	0	303	252	8927	356	0	870	96	275	9851	0	656	346	455	589	0
    7	11716	87	297	714	0	103	9555	227	94	0	316	263	12075	322	0	992	105	373	13710	0	653	365	451	643	0
    8	12193	76	253	748	0	84	9965	206	86	0	276	255	12926	287	0	1002	95	316	14210	0	609	340	431	571	0
    9	11699	56	223	657	0	56	9736	188	62	0	230	201	12742	258	0	894	78	276	13925	0	571	307	394	518	0
    10	11142	68	200	532	0	52	9166	163	54	0	183	166	11856	220	0	771	56	239	13184	0	494	305	328	442	0
    11	10884	39	152	418	0	34	9128	113	37	0	139	158	11887	169	0	601	57	200	12733	0	454	264	300	400	0
    12	10689	45	113	385	0	43	9039	119	23	0	132	126	11849	143	0	541	34	172	12313	0	398	236	306	360	0
    13	10550	30	96	291	0	33	9135	87	22	0	107	127	11604	128	0	459	34	130	12436	0	323	218	239	324	0
    14	10593	39	63	264	0	26	9015	73	13	0	92	127	11872	103	0	418	35	135	12676	0	310	190	237	299	0
    15	10512	20	84	219	0	17	9192	58	16	0	71	99	12312	71	0	359	31	107	12539	0	257	219	201	271	0
    16	10780	34	65	221	0	21	9392	54	19	0	59	88	12563	66	0	310	29	86	13057	0	232	180	211	202	0
    17	10843	26	59	169	0	17	9289	63	16	0	41	61	12766	57	0	281	17	79	13102	0	227	153	181	195	0
    18	11027	13	45	148	0	17	9761	61	16	0	54	65	12987	63	0	220	18	78	13440	0	210	132	164	199	0
    19	11470	15	45	130	0	11	10142	54	8	0	47	45	13721	46	0	215	13	54	14029	0	177	142	137	192	0
    20	11936	14	36	118	0	5	10610	39	20	0	31	47	14134	38	0	195	18	44	14509	0	154	141	126	167	0
    21	12248	10	30	114	0	7	10970	27	10	0	33	38	14727	37	0	176	16	46	15021	0	133	123	157	158	0
    22	12759	13	25	88	0	9	11309	26	6	0	28	45	15618	36	0	152	16	36	15891	0	131	145	131	142	0
    23	13413	7	23	88	0	12	11814	25	5	0	27	25	16359	35	0	155	9	33	16541	0	128	143	149	136	0
    24	14338	7	19	75	0	8	13022	19	4	0	29	26	17435	23	0	127	4	27	17893	0	104	125	130	116	0
    25	15154	7	23	66	0	4	13434	16	5	0	16	26	18684	21	0	121	8	30	19223	0	98	117	156	105	0
    26	16154	4	13	52	0	5	14733	12	2	0	13	28	20212	14	0	92	7	25	20598	0	101	140	137	101	0
    27	17541	6	8	48	0	1	15844	9	2	0	14	22	21975	15	0	108	4	23	22447	0	112	140	106	105	0
    28	19297	2	11	41	0	3	17711	5	5	0	14	15	24315	17	0	75	2	21	24780	0	119	113	124	98	0
    29	21224	1	15	38	0	3	19865	16	0	0	13	23	27098	10	0	77	4	16	27354	0	99	155	132	83	0
    30	24032	4	9	21	0	3	22599	11	5	0	15	15	30276	13	0	65	1	10	30460	0	101	123	142	116	0
    31	27448	1	5	32	0	0	26257	11	0	0	5	15	33926	11	0	47	2	7	34298	0	94	136	169	109	0
    32	31166	3	3	27	0	1	30015	5	0	0	5	10	38340	7	0	44	7	5	38796	0	108	145	167	109	0
    33	35026	2	4	17	0	2	34652	8	1	0	8	9	42458	6	0	46	2	9	42678	0	115	128	127	111	0
    34	39313	2	6	13	0	0	39155	7	0	0	8	7	47036	7	0	53	0	4	47352	0	117	160	206	116	0
    35	43784	4	5	15	0	0	43501	4	2	0	5	6	51506	8	0	26	3	7	52394	0	113	167	155	111	0
    36	48503	1	3	7	0	2	48846	5	1	0	9	8	56565	3	0	28	1	6	56832	0	125	145	155	106	0
    37	53503	1	3	10	0	1	55312	5	0	0	5	3	62160	7	0	37	1	2	62695	0	121	189	121	102	0
    38	58885	0	4	9	0	0	61504	3	1	0	7	3	67806	2	0	22	4	6	67856	0	126	173	144	124	0
    39	64020	0	4	7	0	3	68218	4	1	0	3	5	73282	4	0	22	1	1	72869	0	135	170	142	105	0
    40	69864	0	1	7	0	1	74745	2	0	0	1	3	79278	0	0	17	1	3	77279	0	124	183	93	96	0
    41	612803	4	3	17	0	0	649004	6	3	0	9	9	673360	3	0	64	3	7	650025	0	684	891	599	538	0
""").strip() + "\n\n"

sbq_pdf = SequenceBasecallQscoreLibrary(io.StringIO(NanoporeStats_PDF_txt))

def calc_consensus(self, sbq_pdf, P_N_dict_dict):
    self.consensus_dict = {}
    for refseq_idx, aligned_result in enumerate(self.aligned_result_list):
        print(f"\nrefseq No. {refseq_idx}")
        consensus_seq = ""
        consensus_q_scores = []
        consensus_seq_all = ""
        consensus_q_scores_all = []
        N_bases = len(aligned_result["refseq_with_insertion"])
        for refbase_idx, refbase in enumerate(aligned_result["refseq_with_insertion"]):
            print(f"\r{refbase_idx + 1} out of {N_bases}", end="")
            seq_base_list = [i[refbase_idx] for i in aligned_result["new_seq_list_with_insertion"]]
            q_score_list = [i[refbase_idx] for i in aligned_result["new_q_scores_list_with_insertion"]]
            event_list = [(i.upper(), j) for i, j in zip(seq_base_list, q_score_list)]

            P_N_dict = P_N_dict_dict[refbase.upper()]
            p_list = [
                sbq_pdf.calc_consensus_error_rate(event_list, true_refseq=B, P_N_dict=P_N_dict, bases=bases)
                for B in bases
            ]
            p = min(p_list)
            # p_idx_list = [i for i, v in enumerate(p_list) if v == p]
            consensus_base_call = mixed_bases([b for b, tmp_p in zip(bases, p_list) if tmp_p == p])

            # register: convert the error probability to a Phred-scaled Q-score (capped at 50)
            if p >= 10 ** (-5):
                q_score = np.round(-10 * np.log10(p)).astype(int)
            elif p < 0:
                raise Exception(f"negative consensus error probability: {p}")
            else:
                q_score = 50
            if consensus_base_call != "-":
                consensus_seq += consensus_base_call
                consensus_q_scores.append(q_score)

            # register "all" results (gap positions included)
            consensus_seq_all += consensus_base_call
            consensus_q_scores_all.append(q_score)

        # register the results for this reference sequence
        self.consensus_dict[self.my_aligner.refseq_list[refseq_idx].path.name] = [
            consensus_seq, 
            consensus_q_scores, 
            consensus_seq_all, 
            consensus_q_scores_all
        ]
    # register settings
    self.consensus_settings = {
        "sbq_pdf_version":sbq_pdf.file_version, 
        "P_N_dict_matrix":P_N_dict_dict_2_matrix(P_N_dict_dict), 
        "bases": bases
    }

alignment_result.integrate_assigned_result_info()
print()
print("integration: DONE")

print("Calculating consensus with prior information...")
calc_consensus(alignment_result, sbq_pdf, P_N_dict_dict)
print("\n\nCalculating consensus without prior information...")
alignment_result_2 = copy.deepcopy(alignment_result)
calc_consensus(alignment_result_2, sbq_pdf, P_N_dict_dict_2)

In [None]:
#@title # 5. Export results
export_alignment_image = False # @ param {type:"boolean"}
Size = namedtuple('Size', ("ax0", "ax1"))
class ATCG_5x5_img():
    dtype = np.uint8
    # bases
    A = np.array([
        [0,1,1,0], 
        [1,0,0,1], 
        [1,1,1,1], 
        [1,0,0,1], 
        [1,0,0,1]
    ], dtype=dtype)
    C = np.array([
        [0,1,1,0], 
        [1,0,0,1], 
        [1,0,0,0], 
        [1,0,0,1], 
        [0,1,1,0]
    ], dtype=dtype)
    G = np.array([
        [0,1,1,1], 
        [1,0,0,0], 
        [1,0,1,1], 
        [1,0,0,1], 
        [0,1,1,0]
    ], dtype=dtype)
    T = np.array([
        [1,1,1,0], 
        [0,1,0,0], 
        [0,1,0,0], 
        [0,1,0,0], 
        [0,1,0,0]
    ], dtype=dtype)
    # special letters
    R = np.array([
        [1,1,1,0], 
        [1,0,0,1], 
        [1,1,1,0], 
        [1,0,1,0], 
        [1,0,0,1]
    ], dtype=dtype)
    E = np.array([
        [1,1,1,1], 
        [1,0,0,0], 
        [1,1,1,0], 
        [1,0,0,0], 
        [1,1,1,1]
    ], dtype=dtype)
    F = np.array([
        [1,1,1,1], 
        [1,0,0,0], 
        [1,1,1,0], 
        [1,0,0,0], 
        [1,0,0,0]
    ], dtype=dtype)
    # numbers
    zero = np.array([
        [0,1,1,0], 
        [1,0,0,1], 
        [1,0,0,1], 
        [1,0,0,1], 
        [0,1,1,0]
    ], dtype=dtype)
    one = np.array([
        [0,1,0,0], 
        [1,1,0,0], 
        [0,1,0,0], 
        [0,1,0,0], 
        [1,1,1,0]
    ], dtype=dtype)
    two = np.array([
        [0,1,1,0], 
        [1,0,0,1], 
        [0,0,1,0], 
        [0,1,0,0], 
        [1,1,1,1]
    ], dtype=dtype)
    three = np.array([
        [0,1,1,0], 
        [1,0,0,1], 
        [0,0,1,0], 
        [1,0,0,1], 
        [0,1,1,0]
    ], dtype=dtype)
    four = np.array([
        [0,0,1,0], 
        [0,1,1,0], 
        [1,0,1,0], 
        [1,1,1,1], 
        [0,0,1,0]
    ], dtype=dtype)
    five = np.array([
        [1,1,1,0], 
        [1,0,0,0], 
        [1,1,1,0], 
        [0,0,0,1], 
        [1,1,1,0]
    ], dtype=dtype)
    six = np.array([
        [0,1,1,0], 
        [1,0,0,0], 
        [1,1,1,0], 
        [1,0,0,1], 
        [0,1,1,0]
    ], dtype=dtype)
    seven = np.array([
        [1,1,1,1], 
        [0,0,0,1], 
        [0,0,1,0], 
        [0,1,0,0], 
        [0,1,0,0]
    ], dtype=dtype)
    eight = np.array([
        [0,1,1,0], 
        [1,0,0,1], 
        [0,1,1,0], 
        [1,0,0,1], 
        [0,1,1,0]
    ], dtype=dtype)
    nine = np.array([
        [0,1,1,0], 
        [1,0,0,1], 
        [0,1,1,1], 
        [0,0,0,1], 
        [0,1,1,0]
    ], dtype=dtype)
    hyphen = np.array([
        [0,0,0,0],
        [0,0,0,0],
        [1,1,1,1],
        [0,0,0,0],
        [0,0,0,0]
    ], dtype=dtype)
    blank = np.array([
        [0,0,0,0],
        [0,0,0,0],
        [0,0,0,0],
        [0,0,0,0],
        [0,0,0,0]
    ], dtype=dtype)
    w2n = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
    s2n = {"-":"hyphen", " ":"blank"}
    def __init__(self, ax0, ax1, ax0_space=2, ax1_space=1):
        # fixed values
        self.letter_size = Size(ax0=5, ax1=4)
        self.bit = int(re.match(r"^uint([0-9]+)$", self.dtype.__name__).group(1))
        self.max_intensity = 2 ** self.bit - 1
        # user-specified values
        self.size = Size(ax0=ax0, ax1=ax1)
        self.space = Size(ax0=ax0_space, ax1=ax1_space)
        self.actual_size = Size(
            ax0=self.size.ax0 * (self.letter_size.ax0 + self.space.ax0) + self.space.ax0, 
            ax1=self.size.ax1 * (self.letter_size.ax1 + self.space.ax1) + self.space.ax1
        )
        self.img_R = np.ones(self.actual_size, dtype=self.dtype) * self.max_intensity
        self.img_G = np.ones(self.actual_size, dtype=self.dtype) * self.max_intensity
        self.img_B = np.ones(self.actual_size, dtype=self.dtype) * self.max_intensity
        self.letter_matrix = np.empty(self.size, dtype=str)
    def write_letters(self, ax0, ax1, letters):
        for letter in letters:
            self.write_letter(ax0, ax1, letter)
            ax1 += 1
    def write_letter(self, ax0, ax1, letter):
        self.letter_matrix[ax0, ax1] = letter
        if letter in "0123456789":
            letter = self.w2n[int(letter)]
        elif letter in "- ":
            letter = self.s2n[letter]
        letter_array = getattr(self, letter)
        actual_ax0 = ax0 * (self.letter_size.ax0 + self.space.ax0) + self.space.ax0
        actual_ax1 = ax1 * (self.letter_size.ax1 + self.space.ax1) + self.space.ax1
        self.img_R[actual_ax0:actual_ax0 + self.letter_size.ax0, actual_ax1:actual_ax1 + self.letter_size.ax1] = (1 - letter_array) * self.max_intensity
        self.img_G[actual_ax0:actual_ax0 + self.letter_size.ax0, actual_ax1:actual_ax1 + self.letter_size.ax1] = (1 - letter_array) * self.max_intensity
        self.img_B[actual_ax0:actual_ax0 + self.letter_size.ax0, actual_ax1:actual_ax1 + self.letter_size.ax1] = (1 - letter_array) * self.max_intensity
    def highlight_letters(self, ax0, ax1, color, length):
        for i in range(length):
            self.highlight_letter(ax0, ax1 + i, color)
    def highlight_letter(self, ax0, ax1, color):
        letter = self.letter_matrix[ax0, ax1]
        # map digits and special characters to glyph attribute names, as write_letter does
        if letter and letter in "0123456789":
            letter = self.w2n[int(letter)]
        elif letter in self.s2n:
            letter = self.s2n[letter]
        letter_array = getattr(self, letter)
        actual_ax0 = ax0 * (self.letter_size.ax0 + self.space.ax0) + self.space.ax0
        actual_ax1 = ax1 * (self.letter_size.ax1 + self.space.ax1) + self.space.ax1
        for img, color_value in zip([self.img_R, self.img_G, self.img_B], color):
            highlighted_letter = (1 - letter_array) * color_value
            img[actual_ax0:actual_ax0 + self.letter_size.ax0, actual_ax1:actual_ax1 + self.letter_size.ax1] = highlighted_letter
    def export_as_img(self, save_path, max_row_per_img):
        rgb_stack = np.stack((self.img_R, self.img_G, self.img_B), axis=2)
        N = np.ceil(self.size[0] / max_row_per_img).astype(int)
        save_paths = []
        for i in range(N):
            save_path_tmp = save_path.parent / f"{save_path.stem}_{i}{save_path.suffix}"
            save_paths.append(save_path_tmp)
            s = i * max_row_per_img       * (self.letter_size.ax0 + self.space.ax0) + self.space.ax0
            e = (i + 1) * max_row_per_img * (self.letter_size.ax0 + self.space.ax0) + self.space.ax0
            Image.fromarray(rgb_stack[s:e,:,:]).save(save_path_tmp)
        return save_paths
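
# --- Illustrative sketch (not part of the original pipeline) ---
# ATCG_5x5_img places each 5x4 glyph on a grid: cell (row, col) starts at
# row * (5 + ax0_space) + ax0_space, col * (4 + ax1_space) + ax1_space, and ink
# pixels are written as 0 on a white (255) canvas. The standalone toy below
# reproduces that mapping; `_sketch_cell_origin` and `_sketch_draw` are
# hypothetical names, and the spacing defaults mirror the constructor above.
import numpy as np

_SKETCH_GLYPH_T = np.array([
    [1, 1, 1, 0],
    [0, 1, 0, 0],
    [0, 1, 0, 0],
    [0, 1, 0, 0],
    [0, 1, 0, 0],
], dtype=np.uint8)

def _sketch_cell_origin(row, col, glyph_shape=(5, 4), space=(2, 1)):
    """Top-left pixel of the glyph cell at (row, col)."""
    return (row * (glyph_shape[0] + space[0]) + space[0],
            col * (glyph_shape[1] + space[1]) + space[1])

def _sketch_draw(canvas, row, col, glyph):
    """Write one glyph into the canvas in place (ink = 0, background = 255)."""
    r0, c0 = _sketch_cell_origin(row, col)
    canvas[r0:r0 + glyph.shape[0], c0:c0 + glyph.shape[1]] = (1 - glyph) * 255
    return canvas

# A canvas for 2 rows x 3 columns of letters: (2 * 7 + 2, 3 * 5 + 1) pixels.
_sketch_canvas = np.full((2 * 7 + 2, 3 * 5 + 1), 255, dtype=np.uint8)
_sketch_draw(_sketch_canvas, 1, 2, _SKETCH_GLYPH_T)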

all_file_paths = []
# export settings
print("Exporting logs...")
header = f"{app_name} ver{version}\n{description}"
save_path_log_1 = pwd / "log_with_prior.txt"

alignment_result.export_log(save_path_log_1, header=header)
all_file_paths.append(save_path_log_1)

save_path_log_2 = pwd / "log_without_prior.txt"
alignment_result_2.export_log(save_path_log_2, header=header)
all_file_paths.append(save_path_log_2)

# export text
print("Exporting alignment results...")
save_path_list_text = []
text_list, save_path_list_text_1 = alignment_result.export_as_text(save_dir=pwd)
for file_path in save_path_list_text_1:
    path = file_path.parent / (file_path.stem + ".alignment_with_prior.txt")
    os.replace(src=file_path, dst=path.as_posix())
    save_path_list_text.append(path)
text_list, save_path_list_text_2 = alignment_result_2.export_as_text(save_dir=pwd)
for file_path in save_path_list_text_2:
    path = file_path.parent / (file_path.stem + ".alignment_without_prior.txt")
    os.replace(src=file_path, dst=path.as_posix())
    save_path_list_text.append(path)
all_file_paths += save_path_list_text

# print("Aligned Sequences")
# for text in text_list:
#     print(text)

# export gif
if export_alignment_image:
    print("Exporting alignment gifs...")
    refseq_name_list, text_list, highlight_pos_list = alignment_result.alignment_reuslt_list_2_text_list(linewidth=250)
    save_path_list_gif = []
    for refseq_name, text, highlight_pos in zip(refseq_name_list, text_list, highlight_pos_list):
        # write the text
        splitted_text = text.split("\n")
        ax0 = len(splitted_text)
        ax1 = max([len(i) for i in splitted_text])
        atcg_img = ATCG_5x5_img(ax0=ax0, ax1=ax1)
        for ax0, t in enumerate(splitted_text):
            atcg_img.write_letters(ax0, 0, t)
        # highlight
        for r, c in highlight_pos:
            atcg_img.highlight_letter(r, c, (255, 100, 100))
        # save (split into several files because the image is large)
        save_path = (pwd / refseq_name).with_suffix(".gif")
        save_paths = atcg_img.export_as_img(save_path=save_path, max_row_per_img=1000)
        save_path_list_gif.extend(save_paths)
    all_file_paths += save_path_list_gif

# export score_summary
print("Exporting summary...")
save_path_summary_score = pwd / "summary_scores.txt"
score_summary = alignment_result.save_score_summary(save_path=save_path_summary_score)
all_file_paths.append(save_path_summary_score)

# print("Score Summary")
# print(score_summary)

# export summary image
print("Exporting summary svg images...")
save_path_summary_dictribution = pwd /"summary_distribution.svg"
save_path_summary_scatter = pwd / "summary_scatter.svg"
score_summary_df = pd.read_csv(save_path_summary_score, sep="\t")
draw_distributions(score_summary_df, combined_fastq)
plt.savefig(save_path_summary_dictribution)
plt.close()
draw_alignment_score_scatter(score_summary_df, score_threshold)
plt.savefig(save_path_summary_scatter)
plt.close()
all_file_paths.extend([save_path_summary_dictribution, save_path_summary_scatter])

# export consensus
print("Exporting consensus fastq files...")
consensus_path_list = []
consensus_path_list_1 = alignment_result.save_consensus(save_dir=pwd)
for file_path in consensus_path_list_1:
    path = file_path.parent / (file_path.stem + ".consensus_with_prior.fastq")
    os.replace(src=file_path, dst=(path).as_posix())
    consensus_path_list.append(path)
consensus_path_list_2 = alignment_result_2.save_consensus(save_dir=pwd)
for file_path in consensus_path_list_2:
    path = file_path.parent / (file_path.stem + ".consensus_without_prior.fastq")
    os.replace(src=file_path, dst=(path).as_posix())
    consensus_path_list.append(path)

all_file_paths += consensus_path_list

# make new folder
idx = 0
results_dir = pwd / "results"
while os.path.exists(results_dir.as_posix()):
    idx += 1
    results_dir = pwd / f"results {idx}"
os.makedirs(results_dir)

# move files
for file_path in all_file_paths:
    os.replace(src=file_path, dst=(results_dir / file_path.name).as_posix())

# compress as zip
os.chdir(results_dir)
zip_path = results_dir.with_suffix(".zip")
with zipfile.ZipFile(zip_path.as_posix(), 'w') as f:
    for file_path in all_file_paths:
        f.write(file_path.name)

print("export: DONE")

if save_to_google_drive == True and drive:
  uploaded = drive.CreateFile({'title': zip_path.name})
  uploaded.SetContentFile(zip_path)
  uploaded.Upload()
  print(f"Uploaded {zip_path} to Google Drive with ID {uploaded.get('id')}")
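
# --- Illustrative sketch (not part of the original pipeline) ---
# The export step above picks the first unused "results", "results 1", ...
# directory and zips its files after os.chdir. The toy below shows the same
# naming scheme and an alternative that avoids changing the working directory
# by passing `arcname` to ZipFile.write. `_sketch_next_free_dir` and
# `_sketch_zip_dir` are hypothetical helpers; everything runs in a temp directory.
import tempfile
import zipfile
from pathlib import Path

def _sketch_next_free_dir(base_dir, name="results"):
    """Return base_dir/name, or base_dir/'name N' for the first unused N."""
    candidate = base_dir / name
    idx = 0
    while candidate.exists():
        idx += 1
        candidate = base_dir / f"{name} {idx}"
    return candidate

def _sketch_zip_dir(src_dir):
    """Zip the files in src_dir next to it, storing bare file names in the archive."""
    zip_path = src_dir.with_suffix(".zip")
    with zipfile.ZipFile(zip_path, "w") as zf:
        for file_path in sorted(src_dir.iterdir()):
            zf.write(file_path, arcname=file_path.name)
    return zip_path

with tempfile.TemporaryDirectory() as _sketch_tmp:
    _sketch_base = Path(_sketch_tmp)
    (_sketch_base / "results").mkdir()                 # simulate an earlier run
    _sketch_dir = _sketch_next_free_dir(_sketch_base)
    _sketch_dir.mkdir()
    (_sketch_dir / "log.txt").write_text("demo\n")
    _sketch_zip = _sketch_zip_dir(_sketch_dir)
    with zipfile.ZipFile(_sketch_zip) as zf:
        _sketch_names = zf.namelist()
    _sketch_dir_name = _sketch_dir.name
    _sketch_zip_name = _sketch_zip.name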


In [None]:
#@title # 6. Visualize results

target_file_path = "_m243mod_Lyn11-FRB-dGFP-PLDs48(280-506)-P2A-PLDs48(1-279)-mCherry-iFKBP_pCAGGS.alignment_without_prior.txt" #@param {type:"string"}
target_position = 3164 #@param {type:"number"}
display_range = 300 #@param {type:"number"}


#@markdown ## 1-1. Upload files
#@markdown `*.alignment_with_prior.txt` or `*.alignment_without_prior.txt`

#@markdown ## 1-2. Select this cell and hit `Runtime` -> `Run after`

import numpy as np
import re
from pathlib import Path
pwd = Path('/content/sample_data/')

class AlignmentViewer():
    def __init__(self, file_path):
        with open(file_path, "r") as f:
            # ref
            ref_line = f.readline()
            m = re.match(r"^(ref {1,}.+ {1,})([ATCGatcg-]+)$", ref_line)
            self.ref_label = m.group(1)
            self.ref_seq = m.group(2)
            len_label = len(self.ref_label)

            # consensus (sequence line, then Phred+33-encoded Q-score line)
            consensus_line_1 = f.readline().strip("\n")
            assert consensus_line_1[:len_label] == "consensus" + " " * (len_label - 9)
            self.consensus_seq = consensus_line_1[len_label:]
            consensus_line_2 = f.readline().strip("\n")
            assert consensus_line_2[:len_label] == "consensus" + " " * (len_label - 9)
            self.consensus_q_scores = [ord(q) - 33 for q in consensus_line_2[len_label:]]

            # query: each read occupies three lines (sequence, CIGAR-like string, Q-scores)
            self.query_list = []
            for line_0 in f:
                line_0 = line_0.strip("\n")
                line_1 = f.readline().strip("\n")
                line_2 = f.readline().strip("\n")
                self.query_list.append(Query(line_0, line_1, line_2))

        # idx etc.
        self.refseq_len = len(self.ref_seq) - self.ref_seq.count("-")
        i = 1
        self.refseq_idx_list = []
        for b in self.ref_seq:
            if b != "-":
                self.refseq_idx_list.append(i)
                i += 1
            else:
                self.refseq_idx_list.append(-1)
        assert max(self.refseq_idx_list) == self.refseq_len
    def print(self, target_position, display_range):    # target_position is 1-based
        assert 0 < target_position <= self.refseq_len
        beg = max(target_position - display_range, 1)
        end = min(target_position + display_range, self.refseq_len)
        b = self.refseq_idx_list.index(beg)
        if end < self.refseq_len:  e = self.refseq_idx_list.index(end + 1)
        else:                      e = len(self.ref_seq)
        t = self.refseq_idx_list.index(target_position)
        # print ref
        printed_refseq = self.print_core(b, e, t, target_position)
        # print query
        for query in self.query_list:
            query.print_core(b, e, t, printed_refseq)
    def print_core(self, b, e, t, target_position):
        # consensus q-scores
        print(' ' * (len(self.ref_label) + 0 - len("consensus q-scores ")) + 
            "consensus q-scores " + "".join(map(lambda x: f"{x:<3}", self.consensus_q_scores[b:e:3]))
        )
        print(' ' * (len(self.ref_label) + 1) + "".join(map(lambda x: f"{x:<3}", self.consensus_q_scores[b+1:e:3])))
        print(' ' * (len(self.ref_label) + 2) + "".join(map(lambda x: f"{x:<3}", self.consensus_q_scores[b+2:e:3])))
        # pre
        print(f"\033[1m{' ' * (len(self.ref_label) - len(str(target_position)) - 6 - 1)}(pos:{target_position}) {' ' * (t - b)}*{' ' * (e - t - 1)} Q-score\033[0m")
        # ref
        printed_refseq = self.ref_seq[b:e]
        print(f"\033[1m{self.ref_label + self.ref_seq[b:t]}\033[48;2;200;200;200m{self.ref_seq[t]}\033[0m\033[1m{self.ref_seq[t+1:e]}\033[0m")
        # consensus
        printed_consensus = ""
        for i, j, idx in zip(self.consensus_seq[b:e], self.ref_seq[b:e], range(b, e)):
            if i.upper() == j.upper():
                printed_consensus += f"\033[1m{i}\033[0m"
            else:
                printed_consensus += f"\033[1m\033[48;2;255;176;176m{i}\033[0m"
            if idx == t:
                if i.upper() == j.upper():
                    bar_color = "0;0;0"
                    bar_letter_color = "255;255;255"
                else:
                    bar_color = "255;176;176"
                    bar_letter_color = "0;0;0"
        # q_scores
        printed_q_scores = f"\033[1m{self.consensus_q_scores[t]: >2}"
        # bar
        printed_bar = f"\033[1m\033[48;2;{bar_color}m\033[38;2;{bar_letter_color}m{'*' * self.consensus_q_scores[t]}\033[0m  "
        # print
        print(f"\033[1mconsensus\033[0m{' ' * (len(self.ref_label) - 9) + printed_consensus} {printed_q_scores} {printed_bar}")
        return printed_refseq

class Query():
    def __init__(self, line_0, line_1, line_2):
        # sequence with insertion
        m = re.match(r"^([0-9]+ {1,}.+ {1,})([ATCGatcg-]+)$", line_0)
        self.query_label = m.group(1)
        self.query_seq = m.group(2)
        len_label = len(self.query_label)
        # my_cigar_string
        assert self.query_label == line_1[:len_label]
        self.my_cigar_string = line_1[len_label:]
        # q_score
        assert self.query_label == line_2[:len_label]
        self.q_scores = [ord(q) - 33 for q in line_2[len_label:]]
    def print_core(self, b, e, t, printed_refseq):
        # label, seq
        printed_query_seq = ""
        previous_L = ""
        bar_color = None
        for i, j, k, idx in zip(self.query_seq[b:e], self.my_cigar_string[b:e], printed_refseq, range(b, e)):
            if (previous_L == " ") and (i == "-"):
                L = " "
            elif j in "HS":
                L = " "
            elif i.upper() == k.upper():
                L = i
            else:
                L = f"\033[48;2;255;176;176m{i}\033[0m"
            if (idx == t) and (L != " "):
                L = f"\033[1m{L}\033[0m"
                if i.upper() == k.upper():
                    bar_color = "0;0;0"
                else:
                    bar_color = "255;176;176"
            printed_query_seq += L
            previous_L = L
        # q-scores
        if (self.q_scores[t] == -1) or (bar_color is None):
            q_score = "  "
            bar = "  "
        else:
            q_score = f"{self.q_scores[t]: >2}"
            bar = f"\033[48;2;{bar_color}m\033[38;2;{bar_color}m{'*' * self.q_scores[t]}\033[0m  "
        printed_q_scores = f"{q_score}"
        # print
        print(f"{self.query_label}{printed_query_seq} {printed_q_scores} {bar}")
        return printed_query_seq

file_path = pwd / target_file_path
if file_path.exists():
    alignment_viewer = AlignmentViewer(file_path)
    alignment_viewer.print(target_position, display_range)
else:
    print(f"file does not exist:\n{file_path}")
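
# --- Illustrative sketch (not part of the original pipeline) ---
# AlignmentViewer maps 1-based reference positions to alignment columns by
# numbering every non-gap character of the gapped reference and marking gap
# columns (insertions in the reads) with -1. Q-score strings are decoded as
# Phred+33 (ord(char) - 33). `_sketch_ref_positions` and `_sketch_decode_phred33`
# are hypothetical names, and the inputs below are made up.
def _sketch_ref_positions(gapped_ref):
    """Return the 1-based reference position of each column, or -1 for gap columns."""
    positions = []
    pos = 0
    for base in gapped_ref:
        if base == "-":
            positions.append(-1)
        else:
            pos += 1
            positions.append(pos)
    return positions

def _sketch_decode_phred33(q_string):
    """Decode a Phred+33 quality string into integer Q-scores."""
    return [ord(q) - 33 for q in q_string]

_sketch_positions = _sketch_ref_positions("AT--CG")
# Column index of reference position 3 (the "C"), skipping the two gap columns.
_sketch_column_of_3 = _sketch_positions.index(3)
_sketch_q_values = _sketch_decode_phred33("I5!")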

