# Environment

## Import packages

In [29]:
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import scanpy as sc
from src.peak2vec import PeakDataset, peak2vec_collate, prepare_adata, get_sampling_distributions

## Load data

In [30]:
adata = sc.read_h5ad("data/pbmc10k_atac_filtered.h5ad")
adata

KeyboardInterrupt: 

In [None]:
prepare_adata(adata)
adata

AnnData object with n_obs × n_vars = 10006 × 260822
    obs: 'cisTopic_nr_frag', 'cisTopic_log_nr_frag', 'cisTopic_nr_acc', 'cisTopic_log_nr_acc', 'sample_id', 'barcode_rank', 'total_fragments_count', 'log10_total_fragments_count', 'unique_fragments_count', 'log10_unique_fragments_count', 'total_fragments_in_peaks_count', 'log10_total_fragments_in_peaks_count', 'unique_fragments_in_peaks_count', 'log10_unique_fragments_in_peaks_count', 'fraction_of_fragments_in_peaks', 'duplication_count', 'duplication_ratio', 'nucleosome_signal', 'tss_enrichment', 'pdf_values_for_tss_enrichment', 'pdf_values_for_fraction_of_fragments_in_peaks', 'pdf_values_for_duplication_ratio', 'barcode', 'celltype', 'n_features_per_cell', 'total_fragment_counts'
    var: 'Chromosome', 'Start', 'End', 'Width', 'cisTopic_nr_frag', 'cisTopic_log_nr_frag', 'cisTopic_nr_acc', 'cisTopic_log_nr_acc', 'n_cells_per_feature', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'Center'
    layers: 'binary'

# Train

## Model

In [None]:
class Peak2Vec(nn.Module):
    """Skip-gram model over peaks with negative sampling."""
    def __init__(self, n_peaks, embedding_dim=64, pos_weight=1.0, sparse=True):
        super(Peak2Vec, self).__init__()
        self.dim = embedding_dim
        self.pos_weight = pos_weight
        self.embedding = nn.Embedding(n_peaks, embedding_dim, sparse=sparse)
        self.reset_params()
    
    def reset_params(self):
        nn.init.uniform_(self.embedding.weight, -0.5 / self.dim, 0.5 / self.dim)

    def forward(self, peaks, peak_pairs, negatives):
        """
        Compute the Skip-gram with Negative Sampling loss.
        peaks: (B,) LongTensor of peak indices
        peak_pairs: (B,) LongTensor of positive context peak indices
        negatives: (B, K) LongTensor of negative sample peak indices
        """
        # Embeddings
        peak_emb = self.embedding(peaks)               # (B, D)
        pair_emb = self.embedding(peak_pairs)          # (B, D)
        neg_emb = self.embedding(negatives)            # (B, K, D)

        # Compute similarity scores
        pos_score = torch.sum(peak_emb * pair_emb, dim=1)                   # (B)
        neg_score = torch.bmm(neg_emb, peak_emb.unsqueeze(2)).squeeze(2)    # (B, K)

        # Loss
        poss_loss = F.softplus(-pos_score)         # -log(sigmoid(x)) = softplus(-x)
        neg_loss = F.softplus(neg_score).sum(1)      # -log(1 - sigmoid(x)) = softplus(x)
        loss = (self.pos_weight * poss_loss + neg_loss).mean()

        with torch.no_grad():
            stats = {
                "pos_score_mean": pos_score.mean().detach(),
                "neg_score_mean": neg_score.mean().detach(),
                "pos_loss_mean":  poss_loss.mean().detach(),
                "neg_loss_mean":  neg_loss.mean().detach(),
            }
        return loss, stats
    
    @torch.no_grad()
    def get_peak_embeddings(self, normalize=True):
        embeddings = self.embedding.weight.detach().cpu()
        if normalize:
            embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings
    
    @torch.no_grad()
    def most_similar(self, peak_idx, topk=10):
        embeddings = self.get_peak_embeddings(normalize=True)  # (N, D)
        peak_emb = embeddings[peak_idx:peak_idx+1]             # (1, D)
        similarities = (embeddings @ peak_emb.t()).squeeze(1)  # (N)
        values, indices = torch.topk(similarities, topk + 1, embeddings.size(0), largest=True)
        # drop self if in topk
        if indices and indices[0] == peak_idx:
            indices, values = indices[1:], values[1:]
        return indices[:topk], values[:topk]

## Training

In [None]:
seed = 4
n_pairs = 20
n_negative = 20
samples_per_epoch = 20000
batch_size = 512
embedding_dim = 128
trans_fraction=0.2
cis_window=500000
same_chr_negative_prob=0.5
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

Using device: cuda


In [None]:
neg, keep = get_sampling_distributions(adata, t=5e-7, power=0.75)

In [None]:
dataset = PeakDataset(
    X=adata.X,
    chr=adata.var["Chromosome"].values,
    centers=adata.var["Center"].values,
    neg_distribution=neg,
    keep_distribution=keep,
    samples_per_epoch=samples_per_epoch,
    n_pairs=n_pairs,
    n_negative=n_negative,
    seed=seed,
    trans_fraction=trans_fraction,
    cis_window=cis_window,
    same_chr_negative_prob=same_chr_negative_prob,
)

loader = DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=peak2vec_collate,
    num_workers=0,
    pin_memory=True,
)

In [None]:
iterator = iter(loader)

In [None]:
peaks, peak_pairs, negatives = next(iterator)
peaks.size(), peak_pairs.size(), negatives.size(), dataset.counter, peaks, peak_pairs

(torch.Size([512]),
 torch.Size([512]),
 torch.Size([512, 20]),
 defaultdict(int,
             {'chr12': 31,
              'chr2': 39,
              'chr14': 12,
              'chr5': 24,
              'chr17': 32,
              'chr1': 48,
              'chr4': 16,
              'chr3': 26,
              'chr19': 45,
              'chr22': 16,
              'chr16': 19,
              'chr6': 35,
              'chr11': 30,
              'chr10': 16,
              'chr18': 5,
              'chr8': 18,
              'chr7': 22,
              'chr20': 12,
              'chr9': 29,
              'chrX': 8,
              'chr21': 6,
              'chr15': 18,
              'chr13': 5}),
 tensor([153976,  33045, 175417, 171519,  71453, 214272,  23239,  65453,  50967,
         224073, 251459, 186153, 202136, 204147, 190148,  71024, 225572,  84389,
         225439,  75257, 141176, 203246, 139675, 128623, 154168, 227520, 156306,
         216574, 160885, 225418, 160174, 208095,  83971, 111800, 1

In [None]:
model = Peak2Vec(adata.n_vars, embedding_dim=embedding_dim, pos_weight=1.0, sparse=True).to(device)
optimizer = torch.optim.SparseAdam(model.parameters(), lr=2e-3)

In [None]:
history = list()

In [37]:
import random
import numpy as np

with wandb.init(
    entity="claptar",
    project="peak2vec",
    dir="./data/wandb",
    id="peak2vec-2",
    name="peak2vec-test-run2",
    notes="Testing Peak2Vec with sparse embeddings and modified loss",
    tags=["test", "sparse_embeddings", "first_run"],
    config={
        'seed': seed,
        'n_pairs': n_pairs,
        'n_negative': n_negative,
        'samples_per_epoch': samples_per_epoch,
        'batch_size': batch_size,
        'embedding_dim': embedding_dim,
        'trans_fraction': trans_fraction,
        'cis_window': cis_window,
        'same_chr_negative_prob': same_chr_negative_prob,
        'device': device,
        'embedding_dim': embedding_dim,
        'pos_weight': 1.0,
        'sparse': True,
        'optimizer': 'SparseAdam',
        'learning_rate': 2e-3,
    },
    mode="online",
    resume="allow",
) as run:
    scores = []
    for epoch in range(1, 4):
        score = random.random()
        scores.append(score)
        embeddings = np.random.random((100, 10)).tolist()
        run.log({"score": score})
        table = wandb.Table(data=embeddings, columns=[f"dim_{i}" for i in range(10)])
        table.add_column("color", random.choices(["red", "green", "blue"], k=100))
    run.log(
        {"table": table}
    )
    run.summary["final_score"] = np.mean(scores)
    run.finish()

0,1
score,█▁▆

0,1
final_score,0.3809
score,0.46581


In [None]:
num_epochs = 10

for epoch in range(1, num_epochs + 1):
    model.train()

    running = 0.0
    running_pos = 0.0
    running_neg = 0.0
    #dataset.set_epoch(epoch)
    for step, (peaks, peak_pairs, negatives) in enumerate(loader, 1):
        peaks = peaks.to(device, non_blocking=True)
        peak_pairs = peak_pairs.to(device, non_blocking=True)
        negatives = negatives.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        loss, stats = model(peaks, peak_pairs, negatives)
        loss.backward()
        
        optimizer.step()

        running += float(loss.detach().cpu())
        running_pos += float(stats["pos_loss_mean"].cpu())
        running_neg += float(stats["neg_loss_mean"].cpu())

        #print(f"Epoch {epoch} | Step {step:03d} | Loss: {running / step:.4f} | Pos Loss: {stats['pos_loss_mean']:.4f} | Neg Loss: {stats['neg_loss_mean']:.4f} | Pos Score: {stats['pos_score_mean']:.4f} | Neg Score: {stats['neg_score_mean']:.4f}")
    
    # Log epoch stats
    stats["running_loss"] = running / step
    stats["pos_loss_mean"] = running_pos / step
    stats["neg_loss_mean"] = running_neg / step
    history.append(stats)
    print(f"Epoch {epoch} | Step {step:03d} | Loss: {running / step:.6f} | Pos Loss: {stats['pos_loss_mean']:.6f} | Neg Loss: {stats['neg_loss_mean']:.6f} | Pos Score: {stats['pos_score_mean']:.6f} | Neg Score: {stats['neg_score_mean']:.6f}")

Epoch 1 | Step 040 | Loss: 14.556078 | Pos Loss: 0.693123 | Neg Loss: 13.862955 | Pos Score: 0.000639 | Neg Score: -0.000009
Epoch 2 | Step 040 | Loss: 14.551709 | Pos Loss: 0.688833 | Neg Loss: 13.862877 | Pos Score: 0.020975 | Neg Score: 0.000092
Epoch 3 | Step 040 | Loss: 14.540368 | Pos Loss: 0.678216 | Neg Loss: 13.862152 | Pos Score: 0.061966 | Neg Score: 0.000160


KeyboardInterrupt: 