In [1]:
import os
from collections import Counter, defaultdict
from io import StringIO

import pandas as pd
from Bio import SeqIO, AlignIO
from Bio.Align import MultipleSeqAlignment
from matplotlib import pyplot as plt

# Directory layout used throughout the notebook.
genomeTreeDir = "Genome_tree"  # holds aligned.fasta; aligned_cds.fasta is written here
snpDir = "SNPs"                # SNP tables are written here
dataDir = "Data"               # holds reference.gb and info.csv

# exist_ok=True makes the cell safely re-runnable and avoids the
# check-then-create race of `if not os.path.exists(...): os.mkdir(...)`.
os.makedirs(snpDir, exist_ok=True)

# Accession IDs with a special role in the analysis.
exclude = "EPI_ISL_402131"    # record dropped from the alignment
ancestral = "EPI_ISL_402123"  # record whose bases fill the "ref" column of the SNP tables
reference = "EPI_ISL_402125"  # record used to map annotation coordinates onto alignment columns

In [2]:
# Keep every aligned genome except the excluded accession.
outSeqs = [
    record
    for record in SeqIO.parse(os.path.join(genomeTreeDir, "aligned.fasta"), "fasta")
    if record.id != exclude
]

# Round-trip the kept records through an in-memory FASTA buffer so that they
# can be re-read as a single alignment object.
handle = StringIO()
SeqIO.write(outSeqs, handle, "fasta")
fastaText = handle.getvalue()

aligned = AlignIO.read(StringIO(fastaText), "fasta")

# Annotated reference (GenBank); its CDS features are used further down.
annoRef = SeqIO.read(os.path.join(dataDir, "reference.gb"), "gb")

In [3]:
# Map each 0-based position of the (ungapped) reference sequence to the
# 0-based alignment column that holds that base.
ref2aligned = {}

for record in aligned:
    if record.id == reference:
        refIndex = -1
        for alignedIndex, n in enumerate(record):
            if n != '-':
                refIndex += 1
                # Record the column only for real bases. Assigning on gap
                # columns as well would overwrite the entry of the preceding
                # base with the column of a trailing gap (and would create a
                # meaningless -1 key when the row starts with gaps).
                ref2aligned[refIndex] = alignedIndex
        break

In [4]:
# Convert every segment of every CDS feature of the annotated reference into
# alignment-column coordinates. `product` is a running 1-based counter of the
# CDS features, shared by all segments of the same feature.
cdsAlignedIndex = []
product = 0

for f in annoRef.features:
    if f.type != "CDS":
        continue
    product += 1
    cdsAlignedIndex.extend(
        {
            "start": ref2aligned[i.start],
            "end": ref2aligned[i.end],
            "product": product,
        }
        for i in f.location.parts
    )

In [5]:
# Cut each CDS segment out of the genome alignment (all rows, a column range)
# and join the pieces column-wise, in annotation order.
alignedCDS = None

for loc in cdsAlignedIndex:
    cdsSeq = aligned[:, loc["start"]:loc["end"]]
    alignedCDS = cdsSeq if alignedCDS is None else alignedCDS + cdsSeq

# Last expression of the cell: its return value is what the notebook displays.
AlignIO.write(alignedCDS, os.path.join(genomeTreeDir, "aligned_cds.fasta"), "fasta")

1

In [6]:
alignedAA = None

# Look-up table: record id -> its row of the concatenated CDS alignment.
indexedAlignedCDS = {record.id: record for record in alignedCDS}

# Amino-acid translation is currently disabled, so `alignedAA` stays None.
# The earlier per-record attempt is kept here for reference:
# for record in alignedCDS:
#     aa = record.translate(gap='-', id=record.id, description="")
#     print(record.id)
#     if alignedAA is None:
#         alignedAA = MultipleSeqAlignment([aa])
#     else:
#         alignedAA.add_sequence(record.id, str(aa.seq))

In [7]:
# Row of the CDS alignment that belongs to the ancestral accession; its bases
# fill the "ref" column below.
alignedAnc = None
for record in alignedCDS:
    if record.id == ancestral:
        alignedAnc = record
        break

# Collect one record per variable column and build the table once at the end.
# (DataFrame.append was removed in pandas 2.0, and growing a frame row by row
# is quadratic.)
snpRows = []
for i in range(alignedCDS.get_alignment_length()):
    aaSum = Counter(alignedCDS[:, i])
    # Keep a column only if it has more than one state and the most common
    # state is carried by fewer than (number of sequences - 1) sequences,
    # i.e. at least two sequences deviate from it.
    if len(aaSum) > 1 and max(aaSum.values()) < len(alignedCDS) - 1:
        snpRows.append({
            "pos": i + 1,  # 1-based column of the concatenated CDS alignment
            "ref": alignedAnc[i].upper(),
            # The counts are looked up with lower-case keys.
            "A": aaSum["a"], "T": aaSum["t"],
            "G": aaSum["g"], "C": aaSum["c"],
            "gap": aaSum["-"], "unknown": aaSum["n"]
        })

# Passing `columns` keeps the header (and column order) even when no row was found.
snp = pd.DataFrame(
    snpRows,
    columns=["pos", "ref", "A", "T", "G", "C", "gap", "unknown"]
)

snp.to_csv(os.path.join(snpDir, "all.csv"), index=False)

In [8]:
# Per-sample metadata table; the grouping step below reads its
# "Accession ID" and "Area" columns.
info = pd.read_csv(os.path.join(dataDir, "info.csv"))

In [9]:
# Area -> list of accession IDs that are present in the CDS alignment.
groupByArea = defaultdict(list)
accessions = list(indexedAlignedCDS.keys())

# Walk the two columns in parallel instead of using iterrows(), and test
# membership against the dict itself (constant time) rather than scanning the
# `accessions` list for every metadata row.
for accession, area in zip(info["Accession ID"], info["Area"]):
    if accession in indexedAlignedCDS:
        groupByArea[area].append(accession)

In [10]:
# Area -> table of the SNP columns at which at least one sequence of that area
# differs from the ancestral base. Areas without such a column are left out.
snpByArea = {}

for area, acList in groupByArea.items():
    # Sub-alignment holding only the sequences sampled in this area.
    localAlignedCDS = MultipleSeqAlignment([indexedAlignedCDS[ac] for ac in acList])

    # Collect records and build the frame once: DataFrame.append was removed
    # in pandas 2.0 and appending row by row is quadratic.
    localRows = []
    for pos in snp["pos"]:
        # `pos` is 1-based; int() guarantees a plain Python integer for the
        # column index whatever dtype the "pos" column has.
        i = int(pos) - 1
        aaSum = Counter(localAlignedCDS[:, i])
        # A Counter returns 0 for a missing key, so this single comparison
        # also covers the case where the ancestral base is absent from the area.
        if aaSum[alignedAnc[i]] < len(localAlignedCDS):
            localRows.append({
                "pos": i + 1, "ref": alignedAnc[i].upper(),
                "A": aaSum["a"], "T": aaSum["t"],
                "G": aaSum["g"], "C": aaSum["c"],
                "gap": aaSum["-"], "unknown": aaSum["n"]
            })

    localSNP = pd.DataFrame(
        localRows,
        columns=["pos", "ref", "A", "T", "G", "C", "gap", "unknown"]
    )
    if not localSNP.empty:
        snpByArea[area] = localSNP

In [30]:
def snpSummary(areaSNP):
    """Summarise the non-reference bases observed at each SNP position.

    Parameters
    ----------
    areaSNP : pandas.DataFrame
        SNP table with the columns "pos", "ref", "A", "T", "G" and "C"
        (further columns are ignored). "ref" names the reference state of the
        row and the base columns hold per-base counts.

    Returns
    -------
    collections.defaultdict(list)
        Maps each "pos" value to the list of bases (among "A", "T", "G", "C")
        whose count is non-zero once the reference base's own count has been
        discarded. Positions without any such base are absent.

    Notes
    -----
    The input frame is left untouched: the reference counts are zeroed on a
    private copy, so the caller's table keeps its original values (and never
    gains an extra column when "ref" is not one of the existing columns).
    """
    counts = areaSNP.copy()
    # Blank out the count of the reference base on every row.
    for index, row in counts.iterrows():
        counts.loc[index, row["ref"]] = 0

    # Long format: one (pos, base, count) row per base column.
    area = counts[["pos", "A", "T", "G", "C"]].melt(id_vars="pos")
    area = area[area["value"] != 0]

    res = defaultdict(list)
    for pos, base in zip(area["pos"], area["variable"]):
        res[pos].append(base)
    return res

In [31]:
# Non-reference bases seen in Hubei, per position.
hubeiSNP = snpSummary(snpByArea["Hubei"])

# (position, base) -> areas other than Hubei that carry this base at this
# position although Hubei does not.
parallelSNP = defaultdict(list)

for area, localSNP in snpByArea.items():
    if area == "Hubei":
        continue
    areaSNP = snpSummary(localSNP)
    for pos, bases in areaSNP.items():
        # .get() does not insert a key; a position unknown in Hubei simply
        # has no base to exclude.
        seenInHubei = hubeiSNP.get(pos, ())
        for n in bases:
            if n not in seenInHubei:
                parallelSNP[(pos, n)].append(area)

In [35]:
# Report substitutions shared by at least two areas at positions with
# pos % 3 == 2 (presumably the second codon position - confirm that the
# concatenated CDS alignment stays in frame).
for (pos, n), areas in parallelSNP.items():
    if len(areas) < 2 or pos % 3 != 2:
        continue
    print(pos, n, areas)

28337 T ['Guangdong', 'USA']
25865 T ['Australia', 'USA', 'France', 'Taiwan']


In [32]:
# Report substitutions shared by at least two areas at positions with
# pos % 3 == 1 (presumably the first codon position - confirm that the
# concatenated CDS alignment stays in frame).
for (pos, n), areas in parallelSNP.items():
    if len(areas) < 2 or pos % 3 != 1:
        continue
    print(pos, n, areas)

21436 T ['Guangdong', 'USA']
27574 C ['USA', 'Vietnam', 'Sichuan']
28786 T ['Guangdong', 'Japan']
