In [None]:
# Install pysam, which provides the FASTA/BAM readers (FastaFile, AlignmentFile) used below.
# The recorded run installed pysam 0.23.3; pinning it (e.g. `%pip install pysam==0.23.3`)
# would make reruns reproducible.
!pip install pysam

Collecting pysam
  Downloading pysam-0.23.3-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (1.7 kB)
Downloading pysam-0.23.3-cp312-cp312-manylinux_2_28_x86_64.whl (24.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.0/24.0 MB[0m [31m103.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pysam
Successfully installed pysam-0.23.3


In [None]:
import pysam
import torch
import numpy as np
import time
from multiprocessing import Pool, cpu_count

# Canonical nucleotide alphabet; a base's integer code is its index in this string.
BASES = "ACGT"
BASE2IDX = dict(zip(BASES, range(len(BASES))))  # {'A': 0, 'C': 1, 'G': 2, 'T': 3}
N_IDX = len(BASES)  # code shared by 'N' and every other non-ACGT symbol (== 4)

def left_align_indel(chrom, pos, ref, alt, fasta, max_shift=50):
    """Left-align and minimally trim an indel against the reference.

    Implements the usual normalisation loop (Tan, Abecasis & Kang 2015). While both
    alleles end in the same base, that base is dropped. When dropping it would empty an
    allele, the reference base immediately to the left is prepended to both alleles
    first and ``pos`` moves one base left. Finally, shared leading bases are stripped
    while both alleles keep at least one base.

    Args:
        chrom: contig name passed to ``fasta.fetch``.
        pos: 1-based position of the first base of ``ref``.
        ref: reference allele.
        alt: alternate allele. Equal-length alleles (SNVs/MNPs) and empty alleles are
            returned unchanged.
        fasta: object with ``fetch(chrom, start, end)`` using 0-based half-open
            coordinates (e.g. ``pysam.FastaFile``).
        max_shift: maximum number of bases the event may move to the left.

    Returns:
        ``(pos, ref, alt)`` after normalisation, with ``pos`` still 1-based. Bases
        fetched from ``fasta`` are upper-cased; allele comparisons are case-insensitive.
    """
    if not ref or not alt or len(ref) == len(alt):
        return pos, ref, alt

    shifts = 0
    while ref[-1].upper() == alt[-1].upper():
        if min(len(ref), len(alt)) > 1:
            # Shared trailing base with room to spare: just trim it.
            ref, alt = ref[:-1], alt[:-1]
            continue
        # Trimming would empty an allele: re-anchor on the preceding reference base.
        if pos <= 1 or shifts >= max_shift:
            break
        prev_base = fasta.fetch(chrom, pos - 2, pos - 1).upper()  # 1-based pos-1
        if not prev_base:
            break
        ref, alt = prev_base + ref[:-1], prev_base + alt[:-1]
        pos -= 1
        shifts += 1

    # Drop redundant leading context, keeping one anchor base in each allele.
    while len(ref) > 1 and len(alt) > 1 and ref[0].upper() == alt[0].upper():
        ref, alt = ref[1:], alt[1:]
        pos += 1
    return pos, ref, alt

def seq_to_int(seq):
    """Encode a nucleotide string as an int32 tensor (A=0, C=1, G=2, T=3, other=N_IDX).

    Lookup is case-insensitive; any symbol outside ACGT maps to ``N_IDX``.
    """
    codes = []
    for symbol in seq:
        codes.append(BASE2IDX.get(symbol.upper(), N_IDX))
    return torch.tensor(codes, dtype=torch.int32)

def align_smith_waterman_simd(query_sequence, ref_sequence, match_score=2, mismatch_penalty=-1, gap_open=-2, gap_extend=-1):
    """Smith-Waterman local alignment with affine gaps (Gotoh); returns the gapped query.

    Despite the name, no SIMD is used. This is a plain O(len(query) * len(ref))
    dynamic program over Python lists.

    Scoring:
        * ``match_score`` / ``mismatch_penalty`` per aligned pair. Bases are compared
          case-insensitively, and every non-ACGT symbol is treated as 'N' (so N matches N).
        * A gap of length k costs ``gap_open + (k - 1) * gap_extend``.

    Returns:
        The query part of the best local alignment, left to right. Query bases
        (aligned or inserted) appear as the original query characters, and reference
        positions missing from the query appear as '-'.

        Returns "" when either input is empty or nothing scores above zero.

        Ties go to the first maximum found in row-major order. Within a cell, the
        priority is diagonal, then vertical gap, then horizontal gap.
    """
    def _encode(seq):
        # Same alphabet collapse as seq_to_int: keep A/C/G/T, everything else -> 'N'.
        encoded = []
        for ch in seq:
            up = ch.upper()
            encoded.append(up if up in ("A", "C", "G", "T") else "N")
        return encoded

    q = _encode(query_sequence)
    r = _encode(ref_sequence)
    m, n = len(q), len(r)
    if m == 0 or n == 0:
        return ""

    neg = -(1 << 60)  # stands in for -infinity in gap states that cannot exist
    H = [[0] * (n + 1) for _ in range(m + 1)]    # best local score ending at (i, j)
    E = [[neg] * (n + 1) for _ in range(m + 1)]  # ... ending in a query base vs gap (vertical)
    F = [[neg] * (n + 1) for _ in range(m + 1)]  # ... ending in a gap vs reference base (horizontal)

    best, best_i, best_j = 0, 0, 0
    for i in range(1, m + 1):
        qi = q[i - 1]
        h_up, h_row = H[i - 1], H[i]
        e_up, e_row, f_row = E[i - 1], E[i], F[i]
        for j in range(1, n + 1):
            e = max(h_up[j] + gap_open, e_up[j] + gap_extend)
            f = max(h_row[j - 1] + gap_open, f_row[j - 1] + gap_extend)
            diag = h_up[j - 1] + (match_score if qi == r[j - 1] else mismatch_penalty)
            h = max(0, diag, e, f)
            e_row[j], f_row[j], h_row[j] = e, f, h
            if h > best:
                best, best_i, best_j = h, i, j

    # Three-state traceback. Once inside a gap we must stay in the gap matrix until
    # the cell where that gap was opened. Following H alone could splice in moves that
    # are not on the optimal path.
    aligned = []
    i, j, state = best_i, best_j, "H"
    while i > 0 and j > 0:
        if state == "H":
            h = H[i][j]
            if h <= 0:
                break
            sub = match_score if q[i - 1] == r[j - 1] else mismatch_penalty
            if h == H[i - 1][j - 1] + sub:
                aligned.append(query_sequence[i - 1])
                i -= 1
                j -= 1
            elif h == E[i][j]:
                state = "E"
            else:
                state = "F"
        elif state == "E":
            aligned.append(query_sequence[i - 1])
            state = "H" if E[i][j] == H[i - 1][j] + gap_open else "E"
            i -= 1
        else:
            aligned.append("-")
            state = "H" if F[i][j] == H[i][j - 1] + gap_open else "F"
            j -= 1
    return "".join(reversed(aligned))

def realign_read(read, ref_seq, region_start, sw_window=100):
    """Locally realign ``read`` against a window of ``ref_seq`` around its mapped start.

    Args:
        read: alignment exposing ``query_sequence`` and ``reference_start``
            (e.g. ``pysam.AlignedSegment``).
        ref_seq: region reference sequence; index 0 corresponds to ``region_start``.
        region_start: genomic coordinate of ``ref_seq[0]``, in the same system as
            ``read.reference_start``.
        sw_window: extra reference bases included around the read (half of it before
            the mapped start).

    Returns:
        The gapped query from ``align_smith_waterman_simd``. The read's
        ``query_sequence`` is returned unchanged (possibly ``None``/"") when the read
        has no sequence or the reference window is empty.
    """
    query_seq = read.query_sequence
    # Guard before len(): query_sequence can be None when no SEQ is stored.
    if not query_seq:
        return query_seq
    ref_start = max(read.reference_start - region_start - sw_window // 2, 0)
    ref_end = min(ref_start + len(query_seq) + sw_window, len(ref_seq))
    local_ref = ref_seq[ref_start:ref_end]
    if not local_ref:
        return query_seq
    return align_smith_waterman_simd(query_seq, local_ref)

def count_alleles(reads, ref_seq, region_start):
    """Count observed read bases at every position of a reference window.

    Args:
        reads: iterable of alignments exposing ``is_unmapped``, ``is_duplicate``,
            ``query_sequence``, ``reference_start`` and ``cigartuples``
            (e.g. ``pysam.AlignedSegment``).
        ref_seq: reference sequence of the window; only its length is used here.
        region_start: genomic coordinate of ``ref_seq[0]``.

    Returns:
        A list with one ``{base: count}`` dict per position of ``ref_seq``.
        Only aligned (M/=/X) bases are counted.

        Unmapped and duplicate reads, and reads without a stored sequence, are
        skipped. Secondary and supplementary alignments are still counted.
    """
    # CIGAR op codes (SAM spec): which ops consume query and/or reference bases.
    match_ops = {0, 7, 8}        # M, =, X -> consume both
    query_only_ops = {1, 4}      # I, S    -> consume query
    neither_ops = {5, 6}         # H, P    -> consume neither
    # Anything else (D=2, N=3, ...) is treated as consuming the reference only.

    region_len = len(ref_seq)
    counts = [{} for _ in range(region_len)]
    for read in reads:
        if read.is_unmapped or read.is_duplicate:
            continue
        seq = read.query_sequence
        if not seq:
            continue
        ref_pos, query_pos = read.reference_start, 0
        for op, length in read.cigartuples or []:
            if op in match_ops:
                for offset in range(length):
                    pos = ref_pos + offset - region_start
                    qpos = query_pos + offset
                    if 0 <= pos < region_len and qpos < len(seq):
                        base = seq[qpos]
                        counts[pos][base] = counts[pos].get(base, 0) + 1
                ref_pos += length
                query_pos += length
            elif op in query_only_ops:
                query_pos += length
            elif op in neither_ops:
                pass
            else:
                ref_pos += length
    return counts

def detect_candidates(counts, ref_seq, region_start, min_count=1, min_frac=0.01):
    """Return SNV candidates from per-position base counts.

    Args:
        counts: one ``{base: read_count}`` dict per position of ``ref_seq``, as
            produced by ``count_alleles``.
        ref_seq: reference sequence of the region.
        region_start: unused; kept for interface compatibility.
        min_count: minimum number of reads supporting the alternate base.
        min_frac: minimum fraction of covering reads supporting the alternate base.

    Returns:
        A list of ``(offset_in_region, ref_base, alt_base, alt_count, depth)`` tuples.

        Only A/C/G/T alternates are reported: 'N' and other ambiguity codes in reads
        are not variants. Bases are compared to the reference case-insensitively.
        ``depth`` counts every observed symbol at the position.
    """
    valid_alts = {"A", "C", "G", "T"}
    candidates = []
    for i, cnt in enumerate(counts):
        depth = sum(cnt.values())
        if depth == 0:
            continue
        ref_base = ref_seq[i].upper()
        for base, count in cnt.items():
            alt = base.upper()
            if alt == ref_base or alt not in valid_alts:
                continue
            if count >= min_count and count / depth >= min_frac:
                candidates.append((i, ref_base, alt, count, depth))
    return candidates

def process_region(args):
    """Generate SNV candidates for one genomic window (runs inside a worker process).

    Args:
        args: tuple ``(region_idx, total_regions, chrom, start, end, bam_path, fasta_path)``.
            ``start``/``end`` are the 0-based, half-open coordinates passed to
            ``fetch``.

    Returns:
        A list of tab-separated lines ``chrom, pos, ref, alt, alt_count/depth``, where
        ``pos`` is a 0-based genomic coordinate.
    """
    region_idx, total_regions, chrom, start, end, bam_path, fasta_path = args
    start_time = time.time()
    print(f"Processing region {region_idx+1}/{total_regions}: {chrom}:{start}-{end}")

    output_lines = []
    # Context managers close both handles even if a fetch raises inside the worker.
    with pysam.FastaFile(fasta_path) as fasta, pysam.AlignmentFile(bam_path) as bam:
        print(f"  Fetching reference sequence for {chrom}:{start}-{end}")
        ref_seq = fasta.fetch(chrom, start, end)

        print(f"  Fetching reads for {chrom}:{start}-{end}")
        reads = list(bam.fetch(chrom, start, end))
        print(f"  Found {len(reads)} reads in region")

        print(f"  Counting alleles in region")
        counts = count_alleles(reads, ref_seq, start)

        print(f"  Detecting variant candidates")
        candidates = detect_candidates(counts, ref_seq, start, min_count=2, min_frac=0.01)
        print(f"  Found {len(candidates)} candidate variants")

        print(f"  Left-aligning indels")
        for i, (pos, ref_base, alt_base, count, total) in enumerate(candidates):
            if i > 0 and i % 1000 == 0:
                print(f"    Left-aligned {i}/{len(candidates)} indels ({i/len(candidates)*100:.1f}%)")
            # `pos` is an offset into the window; left_align_indel works in 1-based coordinates.
            pos_1based = start + pos + 1
            left, lref, lalt = left_align_indel(chrom, pos_1based, ref_base, alt_base, fasta)
            # Convert back: candidate files store 0-based genomic positions.
            output_lines.append(f"{chrom}\t{left-1}\t{lref}\t{lalt}\t{count}/{total}")

    elapsed_time = time.time() - start_time
    print(f"Completed region {region_idx+1}/{total_regions} in {elapsed_time:.1f} seconds")
    return output_lines

def get_contigs_from_fasta(fasta_path):
    """Return ``[(contig_name, length), ...]`` in ``FastaFile.references`` order.

    The FASTA handle is closed before returning.
    """
    print(f"Reading contigs from {fasta_path}")
    with pysam.FastaFile(fasta_path) as fasta:
        contigs = [(ctg, fasta.get_reference_length(ctg)) for ctg in fasta.references]
    print(f"Found {len(contigs)} contigs in reference")
    return contigs

def region_generator(contigs, window_size=100000):
    """Yield ``(contig, start, end)`` windows (0-based, half-open) that tile each contig.

    The final window of a contig is truncated at the contig end. Because this is a
    generator, the summary lines are printed only once iteration begins.
    """
    def windows_for(length):
        # Ceiling division: number of windows needed to cover `length` bases.
        return (length + window_size - 1) // window_size

    total_regions = sum(windows_for(length) for _, length in contigs)
    print(f"Splitting genome into {total_regions} regions of size {window_size}")
    for ctg, length in contigs:
        print(f"  Contig {ctg}: length={length}, windows={windows_for(length)}")
        for offset in range(0, length, window_size):
            yield ctg, offset, min(offset + window_size, length)

def main(bam_path, fasta_path, output_txt, window_size=100000, threads=None):
    """Generate SNV candidates genome-wide and write them to ``output_txt``.

    The reference is tiled into ``window_size`` windows (``region_generator``), and
    each window is processed by ``process_region`` in a worker pool. Lines are written
    in region order: FASTA contig order, then increasing start. A fixed order keeps the
    output reproducible and position-sorted, which VCF output requires.

    NOTE: a later cell in this notebook defines a different ``main`` (model training).
    Once that cell has run, this name refers to the training entry point.

    Args:
        bam_path: BAM with the reads (region ``fetch`` needs an index).
        fasta_path: reference FASTA.
        output_txt: destination file (overwritten).
        window_size: window length in bases.
        threads: worker processes; defaults to ``cpu_count()``.
    """
    print(f"\n{'='*60}")
    print(f"DeepVariant-style Candidate Generator")
    print(f"{'='*60}")
    print(f"Input BAM:      {bam_path}")
    print(f"Reference:      {fasta_path}")
    print(f"Output file:    {output_txt}")
    print(f"Window size:    {window_size}")

    start_time = time.time()

    # Get contigs from reference
    contigs = get_contigs_from_fasta(fasta_path)

    # Generate regions
    regions = list(region_generator(contigs, window_size))
    total_regions = len(regions)

    # Prepare arguments for parallel processing
    num_threads = threads or cpu_count()
    print(f"Using {num_threads} threads for parallel processing")

    args_list = [(idx, total_regions, ctg, start, end, bam_path, fasta_path)
                 for idx, (ctg, start, end) in enumerate(regions)]

    total_candidates = 0
    print(f"\nStarting variant candidate detection")
    print(f"{'='*60}")

    with Pool(num_threads) as pool, open(output_txt, "w") as fout:
        # imap (not imap_unordered) yields results in submission order, so the file is
        # written in genomic order regardless of which worker finishes first.
        for i, region_lines in enumerate(pool.imap(process_region, args_list, chunksize=1)):
            for line in region_lines:
                fout.write(line + "\n")
            total_candidates += len(region_lines)

            # Progress update
            progress = (i + 1) / total_regions
            elapsed = time.time() - start_time
            estimated_total = elapsed / progress if progress > 0 else 0
            remaining = estimated_total - elapsed

            print(f"Progress: {i+1}/{total_regions} regions processed ({progress*100:.1f}%)")
            print(f"Time elapsed: {elapsed:.1f}s, Estimated remaining: {remaining:.1f}s")
            print(f"Candidates found so far: {total_candidates}")
            print(f"{'-'*40}")

    total_time = time.time() - start_time
    rate = total_candidates / total_time if total_time > 0 else 0.0
    print(f"\n{'='*60}")
    print(f"Completed variant candidate detection")
    print(f"Total time:           {total_time:.1f} seconds")
    print(f"Total candidates:     {total_candidates}")
    print(f"Candidates per second: {rate:.1f}")
    print(f"Output written to:    {output_txt}")
    print(f"{'='*60}\n")


In [None]:
# Mount Google Drive so the reference FASTA under /content/drive/MyDrive/... (used in the
# next cell) is readable. Colab-only; skip or replace outside Colab.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Inputs for the candidate-generation `main` defined above. Paths are Colab-specific.
# NOTE: a later cell redefines `main` (training entry point, different signature);
# re-running this cell after that one would call the wrong function.
op = "osdreb2a_candidates.txt"  # output: tab-separated candidates with 0-based positions
bam = "/content/synthetic_osdreb2a_knockout.bam"  # reads to scan for candidates
fasta = "/content/drive/MyDrive/deepchem stuff/rice new/IRGSP-1.0_genome.fasta"  # IRGSP-1.0 reference on the mounted Drive

main(bam_path=bam, fasta_path= fasta, output_txt= op)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m



  Found 0 reads in region  Detecting variant candidates  Fetching reads for chr11:19400000-19500000  Counting alleles in region


  Found 0 reads in region
  Found 0 reads in region
  Counting alleles in region  Counting alleles in region


  Detecting variant candidates
  Counting alleles in region  Detecting variant candidates
  Detecting variant candidates

  Detecting variant candidates  Found 0 candidate variants  Found 0 candidate variants
  Found 0 candidate variants  Left-aligning indels



  Left-aligning indels
  Left-aligning indels  Detecting variant candidates

  Detecting variant candidatesCompleted region 3357/3740 in 0.1 secondsCompleted region 3359/3740 in 0.1 seconds
Completed region 3361/3740 in 0.1 seconds


Processing region 3369/3740: chr11:19500000-19600000Processing region 3370/3740: chr11:19600000-19700000  Detecting variant candidates


Progress: 3344/3740 regions processed (89.4%)
Time elapse

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision.models import mobilenet_v2
import numpy as np
import pysam
import random
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
import multiprocessing

# ---- CONFIG ----
WINDOW = 221      # pileup image width: reference columns centred on the candidate
HEIGHT = 100      # image rows: up to HEIGHT - 1 reads plus one reference row at the bottom
CHANNELS = 6      # per-pixel feature planes produced by make_pileup_image
BATCH_SIZE = 256
EPOCHS = 1
LR = 1e-3         # Adam learning rate
NUM_CLASSES = 3   # 0 = ref, 1 = het, 2 = hom (names used by run_inference)

_BASE_INTENSITY = {"A": 0.25, "C": 0.5, "G": 0.75, "T": 1.0}


def base_to_intensity(base):
    """Encode a nucleotide as a grey level: A=0.25, C=0.5, G=0.75, T=1.0, anything else 0.0.

    Matching is case-sensitive, so lowercase (e.g. soft-masked reference) bases encode
    as 0.0.
    """
    return _BASE_INTENSITY.get(base, 0.0)

# ---- DATASET ----
class PileupImageDataset(Dataset):
    """Labelled candidates -> ``(pileup image tensor, int label)`` pairs.

    Each candidate is a row from ``read_candidates``: ``chrom, pos, ref, alt, ..., label``,
    with the label in the last column. Images are rendered lazily by
    ``make_pileup_image`` on every access.
    """

    def __init__(self, bam_path, fasta_path, candidates):
        self.bam_path, self.fasta_path, self.candidates = bam_path, fasta_path, candidates

    def __len__(self):
        return len(self.candidates)

    def __getitem__(self, idx):
        chrom, pos, ref, alt, *_, label = self.candidates[idx]
        position, target = int(pos), int(label)
        pileup = make_pileup_image(
            self.bam_path, self.fasta_path, chrom, position, ref, alt, WINDOW, HEIGHT
        )
        return torch.tensor(pileup, dtype=torch.float32), target

class PileupImageUnlabeledDataset(Dataset):
    """Unlabelled candidates -> ``(pileup image tensor, (chrom, pos, ref, alt))`` pairs.

    Candidate rows need at least ``chrom, pos, ref, alt``; extra columns are ignored.
    ``pos`` is returned as an int. Batch these with ``collate_unlabeled``.
    """

    def __init__(self, bam_path, fasta_path, candidates):
        self.bam_path, self.fasta_path, self.candidates = bam_path, fasta_path, candidates

    def __len__(self):
        return len(self.candidates)

    def __getitem__(self, idx):
        chrom, pos, ref, alt, *_ = self.candidates[idx]
        position = int(pos)
        pileup = make_pileup_image(
            self.bam_path, self.fasta_path, chrom, position, ref, alt, WINDOW, HEIGHT
        )
        return torch.tensor(pileup, dtype=torch.float32), (chrom, position, ref, alt)

# ---- PILEUP IMAGE GENERATION ----
def make_pileup_image(bam_path, fasta_path, chrom, pos, ref, alt, width=221, height=100):
    """Render a ``(CHANNELS, height, width)`` float32 pileup image centred on ``pos``.

    Columns cover reference positions ``[pos - width // 2, pos - width // 2 + width)``.
    These are 0-based, the same coordinates as ``get_aligned_pairs``. Columns before
    the contig start or past its end use 'N' as the reference base.

    Row ``height - 1`` is the reference itself. Rows ``0 .. height - 2`` hold up to
    ``height - 1`` reads, sorted by decreasing mapping quality, then forward strand
    first, primary first, then read name. Only aligned (match/mismatch) positions are
    drawn, and reads without a stored sequence are skipped.

    Channels:
        0: base encoded by ``base_to_intensity``
        1: base quality, ``min(q, 40) / 40``; 0.5 when unavailable (reference row: 1.0)
        2: mapping quality, ``min(MQ, 60) / 60`` (reference row: 1.0)
        3: strand, 1.0 forward / 0.0 reverse (reference row: 1.0)
        4: 1.0 where the base equals ``alt`` (case-insensitive)
        5: 1.0 where the base equals the reference base (reference row: 1.0)

    ``ref`` is accepted for interface symmetry but not used.
    """
    start = pos - width // 2
    end = start + width  # exactly `width` columns for odd and even widths alike
    fetch_start = max(0, start)

    pile = np.zeros((CHANNELS, height, width), dtype=np.float32)
    alt_upper = alt.upper()

    # Context managers close both handles even if fetch/parsing raises.
    with pysam.FastaFile(fasta_path) as fasta, pysam.AlignmentFile(bam_path) as bam:
        ref_seq = fasta.fetch(chrom, fetch_start, end)
        # Pad left when the window starts before the contig, right when it runs past it.
        if start < 0:
            ref_seq = "N" * (-start) + ref_seq
        if len(ref_seq) < width:
            ref_seq += "N" * (width - len(ref_seq))

        # Reads with no stored sequence cannot be drawn; drop them before picking rows.
        reads = [r for r in bam.fetch(chrom, fetch_start, end) if r.query_sequence is not None]
        # Sort for reproducibility
        reads.sort(key=lambda r: (-r.mapping_quality, r.is_reverse, r.is_secondary, r.query_name))

        # Reference row at the bottom (row = height - 1).
        ref_row = height - 1
        for col in range(width):
            base = ref_seq[col]
            pile[0, ref_row, col] = base_to_intensity(base)
            pile[4, ref_row, col] = 1.0 if base.upper() == alt_upper else 0.0
        pile[1:4, ref_row, :] = 1.0  # max base quality, max MQ, arbitrary strand
        pile[5, ref_row, :] = 1.0

        for row, read in enumerate(reads[:height - 1]):
            seq = read.query_sequence
            quals = read.query_qualities if read.query_qualities is not None else [20] * len(seq)
            strand = 0.0 if read.is_reverse else 1.0
            mq = min(read.mapping_quality, 60) / 60.0
            for qpos, rpos in read.get_aligned_pairs(matches_only=True):
                if rpos is None or not (start <= rpos < end):
                    continue
                col = rpos - start
                base = "N" if qpos is None or qpos >= len(seq) else seq[qpos]
                pile[0, row, col] = base_to_intensity(base)
                pile[1, row, col] = min(quals[qpos], 40) / 40.0 if qpos is not None and qpos < len(quals) else 0.5
                pile[2, row, col] = mq
                pile[3, row, col] = strand
                pile[4, row, col] = 1.0 if base.upper() == alt_upper else 0.0
                pile[5, row, col] = 1.0 if base.upper() == ref_seq[col].upper() else 0.0
    return pile

# ---- MODEL ----
class MobileNetV2_6ch(nn.Module):
    """Randomly initialised torchvision MobileNetV2 adapted to multi-channel pileups.

    The stem convolution is replaced so it takes ``in_channels`` planes instead of RGB,
    keeping its output width, kernel, stride and padding. The last classifier layer is
    replaced to emit ``num_classes`` logits. ``in_channels`` defaults to the CHANNELS
    config, so the model stays in sync with ``make_pileup_image``.
    """

    def __init__(self, num_classes=NUM_CLASSES, in_channels=CHANNELS):
        super().__init__()
        self.model = mobilenet_v2(weights=None)
        stem = self.model.features[0][0]
        self.model.features[0][0] = nn.Conv2d(
            in_channels,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=False,
        )
        self.model.classifier[1] = nn.Linear(self.model.classifier[1].in_features, num_classes)

    def forward(self, x):
        return self.model(x)

# ---- DATA SPLITTING & BALANCING ----
def read_candidates(txt_path, has_label=True):
    """Parse a whitespace-separated candidate file into lists of string fields.

    Blank lines and lines starting with '#' are ignored. With ``has_label`` a row is
    kept only if it has at least 5 fields and its last field is "0", "1" or "2".
    Otherwise any row with at least 4 fields is kept.
    """
    rows = []
    with open(txt_path) as handle:
        for raw in handle:
            if raw.startswith("#") or not raw.strip():
                continue
            rows.append(raw.split())

    if not has_label:
        return [fields for fields in rows if len(fields) >= 4]
    allowed_labels = {"0", "1", "2"}
    return [fields for fields in rows if len(fields) >= 5 and fields[-1] in allowed_labels]

def balance_candidates(candidates, seed=42):
    """Downsample reference (label "0") candidates to at most the number of variant ones.

    Candidates are bucketed by their last field. Labels "1" and "2" are all kept. A
    random subset of "0" rows no larger than the "1"+"2" count is drawn, and the merged
    list is shuffled. Rows with any other label are dropped.

    A private ``random.Random(seed)`` is used, so the global ``random`` state is not
    touched. The result is identical to seeding the module-level RNG with ``seed``.

    Returns:
        The balanced, shuffled list of candidate rows. Per-label counts before and
        after balancing are printed.
    """
    rng = random.Random(seed)
    # Bucket candidates by label
    bucket = {"0": [], "1": [], "2": []}
    for cand in candidates:
        label = cand[-1]
        if label in bucket:
            bucket[label].append(cand)

    positives = bucket["1"] + bucket["2"]
    negatives = rng.sample(bucket["0"], min(len(bucket["0"]), len(positives)))
    balanced = negatives + positives
    rng.shuffle(balanced)

    original_counts = {label: len(rows) for label, rows in bucket.items()}
    balanced_counts = {label: sum(1 for row in balanced if row[-1] == label) for label in bucket}
    print(f"Original: {original_counts}")
    print(f"Balanced: {balanced_counts}, size={len(balanced)}")
    return balanced

# ---- TRAINING ----
def train_model(model, train_loader, val_loader, device, epochs=EPOCHS, model_save_path="best_model.pth"):
    """Train ``model`` with Adam (lr=LR) and cross-entropy, checkpointing to ``model_save_path``.

    With a validation loader, a checkpoint (model, optimizer, epoch, val_acc, metrics)
    is written whenever validation accuracy improves. The first validated epoch always
    counts as an improvement, so a checkpoint always exists after training. Without a
    validation loader, the latest epoch is saved every epoch.

    Returns:
        ``(model, metrics)``, where ``metrics`` maps epoch index to
        ``{"report", "confmat", "val_acc"}``. It is empty when ``val_loader`` is None.
    """
    optimizer = optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    metrics = {}
    # Start below any achievable accuracy. With 0.0, a model scoring 0.0 on every
    # epoch would never be written to disk.
    best_val_acc = float("-inf")
    best_epoch = -1

    for epoch in range(epochs):
        model.train()
        train_loss, correct, total = 0.0, 0, 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * x.size(0)
            _, pred = torch.max(logits, 1)
            correct += (pred == y).sum().item()
            total += x.size(0)
            pbar.set_postfix(loss=loss.item(), acc=correct/total)
        # Guard against an empty training loader.
        train_acc = correct / total if total else 0.0
        mean_loss = train_loss / total if total else 0.0
        print(f"Train loss: {mean_loss:.4f} | Acc: {train_acc:.4f}")

        if val_loader is not None:
            model.eval()
            y_true, val_pred = [], []
            with torch.no_grad():
                for x, y in val_loader:
                    x, y = x.to(device), y.to(device)
                    _, pred = torch.max(model(x), 1)
                    y_true.extend(y.cpu().numpy())
                    val_pred.extend(pred.cpu().numpy())
            val_acc = sum(1 for t, p in zip(y_true, val_pred) if t == p) / len(y_true)
            report = classification_report(y_true, val_pred, output_dict=True, zero_division=0)
            confmat = confusion_matrix(y_true, val_pred)
            print("Validation metrics:")
            print(classification_report(y_true, val_pred, digits=4, zero_division=0))
            print("Confusion matrix:\n", confmat)
            metrics[epoch] = {"report": report, "confmat": confmat, "val_acc": val_acc}

            # Save best model based on validation accuracy
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_epoch = epoch
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_acc': val_acc,
                    'metrics': metrics[epoch]
                }, model_save_path)
                print(f"✓ Best model saved at epoch {epoch+1} with val_acc={val_acc:.4f}")
        else:
            # If no validation set, save the last model
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_acc': train_acc
            }, model_save_path)
            print(f"✓ Model saved at epoch {epoch+1}")

    if val_loader is not None and best_epoch >= 0:
        print(f"\n[INFO] Best model from epoch {best_epoch+1} with val_acc={best_val_acc:.4f}")

    return model, metrics

# ---- TESTING FUNCTION ----
def test_model(model_path, test_txt, test_bam, fasta_file, device=None):
    """Evaluate a saved checkpoint on a labelled candidate file and print metrics.

    Args:
        model_path: checkpoint written by ``train_model`` (needs ``model_state_dict``).
        test_txt: labelled candidate file (last column 0/1/2).
        test_bam: BAM the pileup images are built from.
        fasta_file: reference FASTA.
        device: torch device; defaults to CUDA when available, else CPU.

    Returns:
        A dict with ``accuracy``, ``report`` (sklearn dict), ``confusion_matrix``,
        ``y_true`` and ``y_pred``.
    """
    device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"\n[INFO] Loading model from {model_path}")
    model = MobileNetV2_6ch()
    # weights_only=False lets torch unpickle arbitrary objects (the checkpoint also
    # stores sklearn/numpy metrics). Only load checkpoints from a trusted source.
    state = torch.load(model_path, map_location=device, weights_only=False)
    model.load_state_dict(state['model_state_dict'])
    model.to(device)
    model.eval()

    print(f"[INFO] Loading test data from {test_txt}")
    labelled = read_candidates(test_txt, has_label=True)
    loader = DataLoader(
        PileupImageDataset(test_bam, fasta_file, labelled),
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=min(8, multiprocessing.cpu_count()),
        pin_memory=True,
    )

    print(f"[INFO] Running evaluation on {len(labelled)} test samples...")
    truths, preds = [], []
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Testing"):
            logits = model(images.to(device))
            predicted = logits.max(dim=1).indices
            truths.extend(labels.cpu().numpy())
            preds.extend(predicted.cpu().numpy())

    hits = sum(int(t == p) for t, p in zip(truths, preds))
    test_acc = hits / len(truths)
    report = classification_report(truths, preds, digits=4, zero_division=0, output_dict=True)
    confmat = confusion_matrix(truths, preds)

    rule = "=" * 60
    print("\n" + rule)
    print("TEST SET RESULTS")
    print(rule)
    print(f"Test Accuracy: {test_acc:.4f}")
    print("\nClassification Report:")
    print(classification_report(truths, preds, digits=4, zero_division=0))
    print("\nConfusion Matrix:")
    print(confmat)
    print(rule)

    return {
        "accuracy": test_acc,
        "report": report,
        "confusion_matrix": confmat,
        "y_true": truths,
        "y_pred": preds,
    }

# ---- INFERENCE ----
def run_inference(model, test_loader, device, label_map=None):
    """Predict a class for every candidate served by ``test_loader``.

    Args:
        model: classifier returning per-class logits.
        test_loader: yields ``(images, meta)``, where ``meta`` is
            ``(chroms, positions, refs, alts)``, as built by ``collate_unlabeled``.
        device: torch device for the forward pass.
        label_map: class index -> name. Defaults to ``{0: "ref", 1: "het", 2: "hom"}``.
            It is created per call instead of as a shared mutable default.

    Returns:
        A list of ``(chrom, pos, ref, alt, predicted_index, label_name)`` tuples.
        Unknown indices are named by their string form.
    """
    if label_map is None:
        label_map = {0: "ref", 1: "het", 2: "hom"}
    model.eval()
    results = []
    with torch.no_grad():
        for images, meta in tqdm(test_loader, desc="Inference"):
            outputs = model(images.to(device))
            preds = torch.argmax(outputs, dim=1).cpu().tolist()
            chroms, positions, refs, alts = meta[0], meta[1], meta[2], meta[3]
            for chrom, pos, ref, alt, pred in zip(chroms, positions, refs, alts, preds):
                results.append((chrom, pos, ref, alt, pred, label_map.get(pred, str(pred))))
    return results

def save_inference_results(results, outfile):
    """Write inference tuples to ``outfile`` as TSV (with a header row), then log the path."""
    header = ("chrom", "pos", "ref", "alt", "predicted_label", "label_name")
    with open(outfile, "w") as handle:
        handle.write("\t".join(header) + "\n")
        handle.writelines("\t".join(str(value) for value in row) + "\n" for row in results)
    print(f"[INFO] Inference results saved to {outfile}")

# ---- INFERENCE FROM SAVED MODEL ----
def inference_from_saved_model(model_path, test_txt, test_bam, fasta_file, output_file="inference_results.txt", device=None):
    """Classify unlabelled candidates with a saved checkpoint and write a TSV of predictions.

    Args:
        model_path: checkpoint written by ``train_model`` (needs ``model_state_dict``).
        test_txt: candidate file without labels (chrom, pos, ref, alt, ...).
        test_bam: BAM the pileup images are built from.
        fasta_file: reference FASTA.
        output_file: destination TSV (see ``save_inference_results``).
        device: torch device; defaults to CUDA when available, else CPU.
    """
    device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"\n[INFO] Loading model from {model_path}")
    model = MobileNetV2_6ch()
    # weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    state = torch.load(model_path, map_location=device, weights_only=False)
    model.load_state_dict(state['model_state_dict'])
    model.to(device)
    model.eval()

    print(f"[INFO] Loading data from {test_txt}")
    candidates = read_candidates(test_txt, has_label=False)
    loader = DataLoader(
        PileupImageUnlabeledDataset(test_bam, fasta_file, candidates),
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=min(8, multiprocessing.cpu_count()),
        pin_memory=True,
        collate_fn=collate_unlabeled,
    )

    print(f"[INFO] Running inference on {len(candidates)} samples...")
    save_inference_results(run_inference(model, loader, device), output_file)

# ---- MULTI-FILE MAIN ----
def main(txt_files, bam_files, fasta_file, val_ratio=0.1, model_save_path="best_model.pth"):
    """Train the pileup classifier on one or more (labelled candidate file, BAM) pairs.

    Each candidate file is class-balanced (``balance_candidates``) and split into
    train/validation parts. The parts from all files are then concatenated.

    NOTE: this shadows the candidate-generation ``main`` defined in an earlier cell.

    Args:
        txt_files: labelled candidate files (last column 0/1/2).
        bam_files: BAM for each candidate file, in the same order.
        fasta_file: reference FASTA shared by all inputs.
        val_ratio: fraction of each balanced file held out for validation. ``<= 0``
            trains on everything with no validation set.
        model_save_path: checkpoint destination.
    """
    assert len(txt_files) == len(bam_files), "txt_files and bam_files must have the same length"
    all_train_sets = []
    all_val_sets = []

    for txt_path, bam_path in zip(txt_files, bam_files):
        print(f"\nProcessing {txt_path} ...")
        candidates = read_candidates(txt_path, has_label=True)
        balanced = balance_candidates(candidates)
        if val_ratio > 0:
            split = int(len(balanced) * (1 - val_ratio))
            train, val = balanced[:split], balanced[split:]
        else:
            # No validation requested: train on every balanced candidate.
            train, val = balanced, []
        all_train_sets.append(PileupImageDataset(bam_path, fasta_file, train))
        if val:
            all_val_sets.append(PileupImageDataset(bam_path, fasta_file, val))

    combined_train = ConcatDataset(all_train_sets)
    combined_val = ConcatDataset(all_val_sets) if all_val_sets else None
    num_workers = min(8, multiprocessing.cpu_count())
    train_loader = DataLoader(combined_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(combined_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=True) if combined_val else None

    model = MobileNetV2_6ch()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n[INFO] Training on device: {device}")
    print(f"[INFO] Model will be saved to: {model_save_path}")

    train_model(model, train_loader, val_loader, device, model_save_path=model_save_path)

    print(f"\n[INFO] Training complete! Best model saved to {model_save_path}")
    print("[INFO] To test the model, use: test_model(model_path, test_txt, test_bam, fasta_file)")
    print("[INFO] To run inference, use: inference_from_saved_model(model_path, test_txt, test_bam, fasta_file, output_file)")

def collate_unlabeled(batch):
    """Collate ``(image, (chrom, pos, ref, alt))`` samples into ``(images, meta)``.

    The default collate would try to tensorise the string metadata, so the metadata is
    regrouped by field instead: ``meta == (chroms, positions, refs, alts)``, each a
    list. Images are stacked along a new batch dimension.
    """
    images, metas = zip(*batch)
    columns = [[] for _ in range(4)]
    for meta in metas:
        for field_idx in range(4):
            columns[field_idx].append(meta[field_idx])
    return torch.stack(images), tuple(columns)




In [None]:
# Inference on unlabeled data: classify every candidate written by the candidate-generation
# step (0-based positions) with the trained checkpoint, and save a TSV of predictions.
inference_from_saved_model(
    model_path="/content/best_model (1).pth",
    test_txt="/content/osdreb2a_candidates.txt",
    test_bam="/content/synthetic_osdreb2a_knockout.bam",
    fasta_file="/content/drive/MyDrive/deepchem stuff/rice new/IRGSP-1.0_genome.fasta",
    output_file="predictions_osdreb2a.txt"
)


[INFO] Loading model from /content/best_model (1).pth
[INFO] Loading data from /content/osdreb2a_candidates.txt
[INFO] Running inference on 1155 samples...


Inference: 100%|██████████| 5/5 [00:09<00:00,  1.92s/it]

[INFO] Inference results saved to predictions_osdreb2a.txt





In [None]:
import pandas as pd

# Convert model predictions into a single-sample VCF.
# The prediction TSV carries 0-based positions, inherited from the candidate file
# (process_region writes `left - 1`). VCF POS is 1-based, so 1 is added on output.
PREDICTIONS_TSV = '/content/predictions_osdreb2a.txt'
VCF_PATH = 'variants_output_osdreb2a.vcf'

# Predicted class -> (GT, FILTER). Reference calls are kept but flagged, rather than
# being reported as PASS homozygous-alt sites.
GT_FILTER_BY_LABEL = {
    "ref": ("0/0", "RefCall"),
    "het": ("0/1", "PASS"),
    "hom": ("1/1", "PASS"),
}

# Read your variant file
df = pd.read_csv(PREDICTIONS_TSV, sep='\t')
# VCF records must be sorted by contig and position, and the candidate file is not
# guaranteed to be in genomic order. Lexicographic order of the zero-padded chr01..chr12
# names matches the contig order in the header below.
calls = df.sort_values(["chrom", "pos"], kind="stable")

# Create VCF header
vcf_header = """##fileformat=VCFv4.2
##reference=IRGSP-1.0
##contig=<ID=chr01,length=43270923>
##contig=<ID=chr02,length=35937250>
##contig=<ID=chr03,length=36413819>
##contig=<ID=chr04,length=35502694>
##contig=<ID=chr05,length=29958434>
##contig=<ID=chr06,length=31248787>
##contig=<ID=chr07,length=29697621>
##contig=<ID=chr08,length=28443022>
##contig=<ID=chr09,length=23012720>
##contig=<ID=chr10,length=23207287>
##contig=<ID=chr11,length=29021106>
##contig=<ID=chr12,length=27531856>
##FILTER=<ID=RefCall,Description="Model predicted a homozygous-reference genotype">
##INFO=<ID=PRED,Number=1,Type=String,Description="Predicted Label Name">
##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">
#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE
"""

# Write VCF file
with open(VCF_PATH, 'w') as f:
    f.write(vcf_header)
    for row in calls.itertuples(index=False):
        gt, vcf_filter = GT_FILTER_BY_LABEL.get(row.label_name, ("./.", "."))
        vcf_pos = int(row.pos) + 1  # 0-based -> 1-based
        # QUAL is '.' (missing): run_inference emits only an argmax label, not a score.
        f.write(f"{row.chrom}\t{vcf_pos}\t.\t{row.ref}\t{row.alt}\t.\t{vcf_filter}\tPRED={row.label_name}\t{gt}\n")

print(f"VCF file created: {VCF_PATH}")

VCF file created: variants_output.vcf


In [None]:
import pandas as pd

# Variant to annotate. `pos` is a 1-based genomic coordinate (VCF / RAP-DB convention).
# NOTE: the candidate and prediction TSVs written earlier store 0-based positions
# (process_region writes `left - 1`); add 1 when copying a position from them.
chrom = "chr01"
pos = 3357507
ref = "G"
alt = "A"

# Gene structure from RAP-DB
gene_id = "Os01g0165000"
transcript_id = "Os01t0165000-01"
gene_name = "DREB2A"
strand = "+"

# CDS coordinates (1-based, inclusive)
cds_exon1_start = 3356782
cds_exon1_end = 3356828
cds_exon2_start = 3357445
cds_exon2_end = 3358222

# CDS segments in transcript (5'->3') order. The offset arithmetic below assumes a
# plus-strand transcript; a minus-strand gene would need reversed exon order and
# complemented bases.
cds_exons = [(cds_exon1_start, cds_exon1_end), (cds_exon2_start, cds_exon2_end)]
if strand != "+":
    raise NotImplementedError("CDS position mapping is only implemented for '+' strand transcripts.")

# Calculate CDS position (1-based), adding the lengths of all upstream CDS exons.
cds_pos = None
bases_before = 0
for exon_number, (exon_start, exon_end) in enumerate(cds_exons, start=1):
    if exon_start <= pos <= exon_end:
        cds_pos = bases_before + (pos - exon_start + 1)
        print(f"Variant in CDS Exon {exon_number}")
        break
    bases_before += exon_end - exon_start + 1
else:
    print("Variant not in CDS")

print(f"CDS position: {cds_pos}")

# Codon number and 1-based position within the codon (None when outside the CDS).
if cds_pos is None:
    codon_num = codon_pos = None
else:
    codon_num = (cds_pos - 1) // 3 + 1
    codon_pos = (cds_pos - 1) % 3 + 1

print(f"Codon number: {codon_num}")
print(f"Position in codon: {codon_pos}")

Variant in CDS Exon 2
CDS position: 110
Codon number: 37
Position in codon: 2


In [None]:
# Predict the protein-level consequence of the SNV described in the previous cell.
# Uses `ref`, `alt`, `cds_pos`, `codon_num`, `codon_pos` and the CDS exon coordinates
# defined there (plus-strand transcript).
if cds_pos is None:
    raise ValueError("Variant is not in the CDS; there is no codon to evaluate.")

# Read the CDS sequence
with open('osdreb2a_cds_clean.fasta', 'r') as f:
    cds_seq = ''.join(line.strip() for line in f if not line.startswith('>')).upper()

print(f"CDS length: {len(cds_seq)} bp")

# Sanity check: the CDS file should be the sequence described by the exon coordinates.
expected_cds_len = (cds_exon1_end - cds_exon1_start + 1) + (cds_exon2_end - cds_exon2_start + 1)
if len(cds_seq) != expected_cds_len:
    print(f"WARNING: CDS file length ({len(cds_seq)} bp) does not match the exon coordinates "
          f"({expected_cds_len} bp); codon numbering may refer to a different transcript.")

# Extract the affected codon (0-based slice into the CDS).
codon_start = (codon_num - 1) * 3
ref_codon = cds_seq[codon_start:codon_start + 3]
if len(ref_codon) != 3:
    raise ValueError(f"Codon {codon_num} lies beyond the end of the CDS ({len(cds_seq)} bp).")

print(f"Reference codon {codon_num}: {ref_codon}")

# Sanity check: the CDS base at the variant must equal the variant's REF allele.
if ref_codon[codon_pos - 1] != ref.upper():
    print(f"WARNING: CDS base at codon position {codon_pos} is {ref_codon[codon_pos - 1]}, "
          f"but the variant REF is {ref}.")

# Apply the ALT allele at its position within the codon.
codon_list = list(ref_codon)
codon_list[codon_pos - 1] = alt.upper()
alt_codon = ''.join(codon_list)

print(f"Alternate codon {codon_num}: {alt_codon}")

# Genetic code
genetic_code = {
    'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
    'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
    'TAT': 'Y', 'TAC': 'Y', 'TAA': '*', 'TAG': '*',
    'TGT': 'C', 'TGC': 'C', 'TGA': '*', 'TGG': 'W',
    'CTT': 'L', 'CTC': 'L', 'CTA': 'L', 'CTG': 'L',
    'CCT': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
    'CAT': 'H', 'CAC': 'H', 'CAA': 'Q', 'CAG': 'Q',
    'CGT': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R',
    'ATT': 'I', 'ATC': 'I', 'ATA': 'I', 'ATG': 'M',
    'ACT': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
    'AAT': 'N', 'AAC': 'N', 'AAA': 'K', 'AAG': 'K',
    'AGT': 'S', 'AGC': 'S', 'AGA': 'R', 'AGG': 'R',
    'GTT': 'V', 'GTC': 'V', 'GTA': 'V', 'GTG': 'V',
    'GCT': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
    'GAT': 'D', 'GAC': 'D', 'GAA': 'E', 'GAG': 'E',
    'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
}

ref_aa = genetic_code.get(ref_codon, '?')
alt_aa = genetic_code.get(alt_codon, '?')

print(f"Reference amino acid: {ref_aa}")
print(f"Alternate amino acid: {alt_aa}")

# Determine variant type (Sequence Ontology terms, SnpEff-style impact).
if '?' in (ref_aa, alt_aa):
    variant_type, impact = "unknown", "MODIFIER"  # codon contains a non-ACGT base
elif ref_aa == alt_aa:
    variant_type = "stop_retained_variant" if ref_aa == '*' else "synonymous_variant"
    impact = "LOW"
elif codon_num == 1 and ref_aa == 'M':
    variant_type, impact = "start_lost", "HIGH"
elif alt_aa == '*':
    variant_type, impact = "stop_gained", "HIGH"
elif ref_aa == '*':
    variant_type, impact = "stop_lost", "HIGH"
else:
    variant_type, impact = "missense_variant", "MODERATE"

print(f"\nVariant type: {variant_type}")
print(f"Impact: {impact}")
print(f"Amino acid change: {ref_aa}{codon_num}{alt_aa}")

CDS length: 2790 bp
Reference codon 37: TGG
Alternate codon 37: TAG
Reference amino acid: W
Alternate amino acid: *

Variant type: stop_gained
Impact: HIGH
Amino acid change: W37*
