In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from transformers import AutoTokenizer, AutoModel, BertConfig, BertTokenizer, BertModel

"""
Title: CLIP-like Genotype-Phenotype Pipeline
Author: [Your Name]
Date: [Date]

Description:
  This script demonstrates a production-ready pipeline that:
    1. Pulls variant data from ClinVar via E-utilities or public FTP.
    2. Extracts a 100bp flanking sequence around each variant using Ensembl REST API.
    3. Cleans disease/phenotype names using a simple dictionary approach.
    4. Creates a dataset of (DNA_sequence, phenotype_text) pairs in JSONL format.
    5. Builds a CLIP-like contrastive learning model with DNABERT (DNA encoder) and
       BioBERT (text encoder) from Hugging Face.
    6. Trains the model on the dataset with a standard InfoNCE contrastive loss.
    7. Implements a basic evaluation loop that checks retrieval accuracy.
    8. Packages the final model for inference with a simple Flask API:
       - Route: POST /predict
         Input JSON:  { "dna_sequence": "ACGT..." }
         Output JSON: { "top_phenotype": "Disease/Phenotype" }

Environment & Requirements:
  - Python 3.9+
  - pip install -r requirements.txt
    Where requirements.txt might include:
      requests
      tqdm
      flask
      torch
      transformers
      huggingface_hub
      sentencepiece
      # (Additional packages as needed)

Usage:
  1. Adjust the configuration variables (API endpoints, file paths) as necessary.
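  2. Call run_training() (for example from a notebook cell) to download the data,
     build the dataset, train, and evaluate. No command-line flags are defined
     in this file.
  3. Optional: serve the Flask `app` object (for example with app.run()) to expose
     POST /predict.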

Note:
  - This script uses simplified approaches and mock examples for demonstration.
  - Production usage will require robust error-checking, secure credential handling,
    and more advanced data cleaning/validation logic.
  - Actual data retrieval from ClinVar or Ensembl may require specialized query
    parameters and handling large files. This script provides a template for
    demonstration, not a fully validated clinical pipeline.
"""

import os
import sys
import json
import argparse
import logging
import requests
from tqdm import tqdm
from typing import Dict, Any, List

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, RandomSampler

from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoConfig,
    BertConfig
)
from huggingface_hub import snapshot_download

from flask import Flask, request, jsonify

##############################################################################
#                                CONFIGURATION
##############################################################################

# Logging configuration
# Logging configuration: root logger at INFO; each record shows timestamp,
# level, logger name and message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
)

# Directories (adjust paths as needed)
DATA_DIR = "./data"  # downloaded ClinVar VCF and the generated dataset
DATA_JSONL = os.path.join(DATA_DIR, "clinvar_dataset.jsonl")  # one {"dna_sequence", "phenotype_text"} object per line
MODEL_DIR = "./model_output"  # per-epoch checkpoints (checkpoint_epoch_<n>.pt)

# Remote endpoints (these may change, check the official docs):
CLINVAR_FTP_URL = "https://ftp.ncbi.nlm.nih.gov/pub/clinvar/vcf_GRCh38/"  # directory holding the GRCh38 ClinVar VCF
ENSEMBL_REST_URL = "https://rest.ensembl.org/sequence/region/human"  # "/<chrom>:<start>..<end>" is appended per request

# Simplified alias -> canonical name table for disease name normalization.
# Keys are compared case-insensitively by normalize_phenotype_name().
DISEASE_NORMALIZATION_DICT = {
    "Breast cancer": "Breast Cancer",
    "Cystic fibrosis": "Cystic Fibrosis",
    "CF": "Cystic Fibrosis"
}

# Hugging Face checkpoint identifiers for the two encoders.
DNABERT_CKPT = "zhihan1996/DNABERT-2-117M"    # DNA encoder (loaded with trust_remote_code=True)
BIOBERT_CKPT = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"  # phenotype-text encoder


##############################################################################
#                          STEP 1: DATA DOWNLOAD
##############################################################################

def download_clinvar_data(output_vcf_path: str = "./data/clinvar_variants.vcf.gz"):
    """
    Downloads the latest ClinVar GRCh38 VCF file from the NCBI server.

    The file is streamed to ``<output_vcf_path>.part`` and moved into place only
    after the transfer finished, so an interrupted download never leaves a
    truncated archive at the final path.

    Args:
        output_vcf_path (str): Destination of the gzipped VCF. Its parent
            directory is created if missing.

    Raises:
        requests.exceptions.RequestException: On connection problems, timeouts
            or a non-success HTTP status.
        OSError: If the destination cannot be written.
    """
    os.makedirs(DATA_DIR, exist_ok=True)
    target_dir = os.path.dirname(output_vcf_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    logging.info("Downloading ClinVar data...")

    # Build the file URL from the configured directory URL (single source of truth).
    clinvar_url = f"{CLINVAR_FTP_URL.rstrip('/')}/clinvar.vcf.gz"
    partial_path = output_vcf_path + ".part"

    try:
        # (connect timeout, read timeout) so a stalled server cannot hang the pipeline.
        with requests.get(clinvar_url, stream=True, timeout=(10, 60)) as response:
            # Fail loudly instead of saving an HTML error page as the VCF.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(partial_path, 'wb') as f, tqdm(
                desc="Downloading ClinVar VCF",
                total=total_size or None,  # unknown length => open-ended progress bar
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                # 1 MiB chunks: far fewer Python-level iterations than 1 KiB chunks.
                for data in response.iter_content(chunk_size=1024 * 1024):
                    size = f.write(data)
                    pbar.update(size)

        os.replace(partial_path, output_vcf_path)
    finally:
        # Only present if the download failed before the rename above.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    logging.info(f"ClinVar download completed => {output_vcf_path}")


##############################################################################
#                       STEP 2: EXTRACT FLANKING SEQUENCES
##############################################################################

def get_flanking_sequence(chrom: str, start: int, ref: str, alt: str, flank_size: int = 50) -> str:
    """
    Fetches the reference sequence surrounding a variant from the Ensembl REST API.

    The requested region is ``start - flank_size .. start + flank_size`` with the
    lower bound clamped to 1, i.e. up to ``2 * flank_size + 1`` bases if Ensembl
    treats both bounds as inclusive (NOTE(review): confirm against the API docs).

    Args:
        chrom (str): Chromosome identifier, e.g. '1', 'X', ...
        start (int): 1-based genomic position of the variant
        ref (str): Reference allele (currently unused)
        alt (str): Alternate allele (currently unused)
        flank_size (int): Number of bp requested on each side of ``start``

    Returns:
        str: The upper-cased sequence, or "" if the request failed, returned a
        non-200 status, or the payload was not a plain letter sequence.
    """
    region_start = max(start - flank_size, 1)  # positions are 1-based; never go below 1
    region_end = start + flank_size
    region_str = f"{chrom}:{region_start}..{region_end}"

    # The header and the query parameter must ask for the same format. The
    # original sent "application/json" in the header while requesting text/plain
    # in the URL, which leaves the response format up to the server's precedence
    # rules.
    headers = {"Content-Type": "text/plain"}
    url = f"{ENSEMBL_REST_URL}/{region_str}?content-type=text/plain"

    try:
        response = requests.get(url, headers=headers, timeout=10)
    except requests.exceptions.RequestException as e:
        logging.error(f"Request failed for region {region_str}: {str(e)}")
        return ""

    if response.status_code != 200:
        logging.warning(f"Ensembl REST API error {response.status_code} for region {region_str}")
        return ""

    sequence = response.text.strip().upper()
    # Guard against JSON/HTML (or an empty body) ending up in the dataset as "DNA".
    if not sequence.isalpha():
        logging.warning(f"Unexpected non-sequence payload from Ensembl for region {region_str}")
        return ""
    return sequence


##############################################################################
#                   STEP 3: DISEASE/PHENOTYPE NORMALIZATION
##############################################################################

def normalize_phenotype_name(raw_name: str) -> str:
    """
    Cleans disease/phenotype names using a simple dictionary approach.
    In real production, might integrate with HPO/MedGen/MeSH.

    Each key of DISEASE_NORMALIZATION_DICT is searched for in ``raw_name``
    case-insensitively and as a whole word/phrase. Underscores are treated as
    spaces so names such as "Cystic_fibrosis" match the multi-word keys, and a
    short alias such as "CF" no longer fires inside unrelated words
    (e.g. "MacFarlane").

    Args:
        raw_name (str): Phenotype name as found in the source data.

    Returns:
        str: The canonical name of the first matching key (dictionary order),
        or ``raw_name`` unchanged when nothing matches.
    """
    import re

    # Underscores -> spaces, collapse whitespace, lower-case for comparison only.
    comparable = " ".join(raw_name.replace("_", " ").split()).lower()
    for alias, canonical in DISEASE_NORMALIZATION_DICT.items():
        # \b on both sides: the alias must be a complete word/phrase.
        if re.search(r"\b" + re.escape(alias.lower()) + r"\b", comparable):
            return canonical
    return raw_name  # If no match, return original


##############################################################################
#              STEP 4: CREATE DATASET (DNA_sequence, phenotype_text)
##############################################################################

def _parse_clinvar_vcf_line(line: str):
    """Turn one VCF data line into a variant dict, or return None if it must be skipped."""
    import re

    fields = line.strip().split('\t')
    if len(fields) < 8:  # CHROM..INFO are required
        return None

    try:
        pos = int(fields[1])
    except ValueError:
        return None  # malformed POS column

    # INFO is "KEY=VALUE;KEY=VALUE;FLAG"; split on the first '=' only so values
    # that themselves contain '=' are not truncated.
    info = {}
    for item in fields[7].split(';'):
        key, sep, value = item.partition('=')
        if sep:
            info[key] = value

    clndn = info.get('CLNDN')
    clnsig = info.get('CLNSIG')
    if not clndn or not clnsig:
        return None

    # CLNSIG can combine several terms ("Pathogenic/Likely_pathogenic",
    # "Pathogenic|risk_factor", ...). Compare whole terms: a substring test would
    # also accept "Conflicting_interpretations_of_pathogenicity".
    sig_terms = {term.strip('_').lower() for term in re.split(r'[/|,]', clnsig)}
    if not sig_terms & {'pathogenic', 'likely_pathogenic'}:
        return None

    # Underscores stand in for spaces in the INFO column; placeholders such as
    # "not_provided" are not disease names and make useless training labels.
    names = [name.replace('_', ' ').strip() for name in clndn.split('|')]
    names = [name for name in names if name and name.lower() not in ('not provided', 'not specified')]
    if not names:
        return None

    return {
        "chrom": fields[0],
        "pos": pos,
        "ref": fields[3],
        "alt": fields[4],
        "phenotype": names[0]  # Take first informative disease name if multiple
    }


def create_dataset_from_clinvar_vcf(vcf_path: str, jsonl_path: str,
                                    max_variants: int = 1000, flank_size: int = 50):
    """
    Creates a dataset from a gzipped ClinVar VCF file by:
    1. Parsing variants and keeping pathogenic / likely pathogenic ones that
       carry a disease name
    2. Getting flanking sequences from Ensembl
    3. Writing (DNA_sequence, phenotype_text) pairs to JSONL

    Args:
        vcf_path (str): Path to the gzipped VCF.
        jsonl_path (str): Output file; overwritten if it exists.
        max_variants (int): Stop scanning the VCF after this many qualifying
            variants (keeps the demonstration small).
        flank_size (int): Bases requested on each side of the variant position.

    Variants whose sequence lookup returns an empty string are skipped, so the
    output can contain fewer than ``max_variants`` records.
    """
    logging.info("Creating dataset from ClinVar VCF...")

    os.makedirs(DATA_DIR, exist_ok=True)
    out_dir = os.path.dirname(jsonl_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Read VCF file (gzipped)
    import gzip
    variants = []

    with gzip.open(vcf_path, 'rt') as f:
        for line in f:
            if line.startswith('#'):  # header / meta lines
                continue
            variant = _parse_clinvar_vcf_line(line)
            if variant is None:
                continue
            variants.append(variant)
            # Limit dataset size for demonstration
            if len(variants) >= max_variants:
                break

    # Get flanking sequences and write to JSONL
    written = 0
    with open(jsonl_path, 'w', encoding='utf-8') as fw:
        for variant in tqdm(variants, desc="Processing variants"):
            flanking_seq = get_flanking_sequence(
                chrom=variant["chrom"],
                start=variant["pos"],
                ref=variant["ref"],
                alt=variant["alt"],
                flank_size=flank_size
            )
            if not flanking_seq:
                continue  # lookup failed; already logged by the helper

            record = {
                "dna_sequence": flanking_seq,
                "phenotype_text": normalize_phenotype_name(variant["phenotype"])
            }
            fw.write(json.dumps(record) + "\n")
            written += 1

    logging.info(f"Dataset creation completed. Output => {jsonl_path}")
    logging.info(f"Wrote {written} of {len(variants)} selected variants")


##############################################################################
#                       STEP 5: MODEL DEFINITIONS
##############################################################################

def mean_pooling(model_output, attention_mask):
    """
    Average the token embeddings of each sequence, ignoring masked-out positions.

    ``model_output[0]`` holds the per-token embeddings; ``attention_mask`` is
    broadcast over the embedding dimension and used as 0/1 weights. The token
    count is clamped to 1e-9 so a fully masked row yields zeros instead of NaN.
    """
    hidden_states = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    weighted_sum = (hidden_states * weights).sum(1)
    token_counts = weights.sum(1).clamp(min=1e-9)
    return weighted_sum / token_counts

class ContrastiveModel(torch.nn.Module):
    """
    Dual-encoder model that projects DNA sequences and phenotype text into a
    shared ``proj_dim``-dimensional space for contrastive training.

    Args:
        dna_model_name: Hugging Face checkpoint for the DNA encoder (loaded with
            ``trust_remote_code=True``).
        text_model_name: Hugging Face checkpoint for the text encoder (loaded as
            a ``BertModel``).
        proj_dim: Output size of both projection heads.
        freeze_encoders: When True, only the projection heads stay trainable.
    """

    def __init__(self, dna_model_name, text_model_name, proj_dim=128, freeze_encoders=True):
        super().__init__()

        # Load DNABERT with explicit configuration
        dna_config = BertConfig.from_pretrained(dna_model_name)
        # NOTE(review): presumably read by the checkpoint's remote model code to
        # size its ALiBi bias - confirm against bert_layers.py of the checkpoint.
        dna_config.alibi_starting_size = 512
        self.dna_encoder = AutoModel.from_pretrained(
            dna_model_name,
            config=dna_config,
            trust_remote_code=True
        )

        # Load BioBERT directly as a BertModel
        self.text_encoder = BertModel.from_pretrained(text_model_name)

        # Projection heads. Input widths come from the loaded configs rather than
        # a hard-coded 768, so a text encoder of another size still works.
        self.dna_proj = torch.nn.Linear(dna_config.hidden_size, proj_dim)
        self.text_proj = torch.nn.Linear(self.text_encoder.config.hidden_size, proj_dim)

        # Freeze encoders if needed
        if freeze_encoders:
            for encoder in (self.dna_encoder, self.text_encoder):
                for param in encoder.parameters():
                    param.requires_grad = False

    def forward(self, dna_input, text_input):
        """
        Encode either or both modalities.

        Args:
            dna_input: Tokenizer output for the DNA encoder, or None to skip it.
            text_input: Tokenizer output for the text encoder (must contain
                'attention_mask'), or None to skip it.

        Returns:
            tuple: (dna_projection, text_projection); an entry is None when the
            corresponding input was None.
        """
        # DNA branch: first-token (CLS) embedding -> projection head
        if dna_input is not None:
            dna_outputs = self.dna_encoder(**dna_input)
            dna_out = dna_outputs[0][:, 0, :]
            dna_proj_out = self.dna_proj(dna_out)
        else:
            dna_proj_out = None

        # Text branch: attention-masked mean pooling -> projection head
        if text_input is not None:
            text_outputs = self.text_encoder(**text_input)
            text_out = mean_pooling(text_outputs, text_input['attention_mask'])
            text_proj_out = self.text_proj(text_out)
        else:
            text_proj_out = None

        return dna_proj_out, text_proj_out

##############################################################################
#                   STEP 6: TRAINING & INFO-NCE CONTRASTIVE LOSS
##############################################################################

class DNAPhenotypeDataset(Dataset):
    """
    A basic PyTorch Dataset for (DNA_sequence, phenotype_text) pairs in JSONL.

    Every non-blank line of ``jsonl_path`` must be a JSON object with the keys
    "dna_sequence" and "phenotype_text". All records are loaded into memory at
    construction time; tokenization happens lazily in ``__getitem__``.

    Args:
        jsonl_path: Path to the JSONL file.
        dna_tokenizer: Callable tokenizer applied to the DNA sequence.
        text_tokenizer: Callable tokenizer applied to the phenotype text.
        max_dna_length: Padded/truncated token length for DNA.
        max_text_length: Padded/truncated token length for text.
    """

    def __init__(self, jsonl_path: str, dna_tokenizer, text_tokenizer, max_dna_length=128, max_text_length=32):
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
            self.records = [json.loads(line) for line in f if line.strip()]
        self.dna_tokenizer = dna_tokenizer
        self.text_tokenizer = text_tokenizer
        self.max_dna_length = max_dna_length
        self.max_text_length = max_text_length

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        """Return ``(dna_encoding, text_encoding)`` dicts with the batch axis removed."""
        rec = self.records[idx]

        dna_enc = self.dna_tokenizer(
            rec["dna_sequence"],
            truncation=True,
            padding="max_length",
            max_length=self.max_dna_length,
            return_tensors="pt"
        )

        text_enc = self.text_tokenizer(
            rec["phenotype_text"],
            truncation=True,
            padding="max_length",
            max_length=self.max_text_length,
            return_tensors="pt"
        )

        # The tokenizers are called with return_tensors="pt" on a single string;
        # drop the leading axis so the DataLoader can stack samples into a batch.
        dna_enc = {k: v.squeeze(0) for k, v in dna_enc.items()}
        text_enc = {k: v.squeeze(0) for k, v in text_enc.items()}

        return dna_enc, text_enc


def info_nce_loss(dna_emb, text_emb, temperature=0.07):
    """
    Symmetric InfoNCE contrastive loss over a batch of paired embeddings.

      - dna_emb, text_emb: [batch_size, proj_dim]; row i of each is a positive pair
      - temperature: divisor applied to the cosine-similarity logits

    Returns the mean of the DNA->text and text->DNA cross-entropy terms.
    """
    # Unit-length rows turn the dot product into cosine similarity.
    unit_dna = nn.functional.normalize(dna_emb, dim=-1)
    unit_text = nn.functional.normalize(text_emb, dim=-1)

    # [batch_size, batch_size]; the diagonal holds the matching pairs.
    similarity = (unit_dna @ unit_text.T) / temperature
    targets = torch.arange(dna_emb.size(0)).to(dna_emb.device)

    dna_to_text = nn.functional.cross_entropy(similarity, targets)
    text_to_dna = nn.functional.cross_entropy(similarity.T, targets)
    return (dna_to_text + text_to_dna) / 2.0


def train_model(args):
    """
    High-level training routine.

    Builds the tokenizers, dataset and ContrastiveModel, then optimizes the
    InfoNCE loss with AdamW and writes one checkpoint per epoch to
    ``MODEL_DIR/checkpoint_epoch_<n>.pt``.

    Args:
        args: Object exposing ``batch_size``, ``proj_dim``, ``lr``,
            ``num_epochs`` and ``temperature``.

    Returns:
        None. If the dataset file holds no records a warning is logged and no
        training takes place.
    """
    logging.info("Loading tokenizers and building dataset...")

    dna_tokenizer = AutoTokenizer.from_pretrained(DNABERT_CKPT, trust_remote_code=True)
    text_tokenizer = AutoTokenizer.from_pretrained(BIOBERT_CKPT)

    dataset = DNAPhenotypeDataset(
        jsonl_path=DATA_JSONL,
        dna_tokenizer=dna_tokenizer,
        text_tokenizer=text_tokenizer,
        max_dna_length=128,
        max_text_length=32
    )

    # RandomSampler rejects an empty dataset; bail out with a clear message
    # before the (expensive) encoders are loaded.
    if len(dataset) == 0:
        logging.warning(f"No training records found in {DATA_JSONL}. Training aborted.")
        return

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=RandomSampler(dataset),
        num_workers=0
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ContrastiveModel(dna_model_name=DNABERT_CKPT,
                             text_model_name=BIOBERT_CKPT,
                             proj_dim=args.proj_dim,
                             freeze_encoders=False)
    model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    os.makedirs(MODEL_DIR, exist_ok=True)  # once, not once per epoch

    logging.info(f"Starting training for {args.num_epochs} epochs...")

    pbar_epochs = tqdm(range(args.num_epochs), desc="Training epochs")
    for epoch in pbar_epochs:
        model.train()
        total_loss = 0.0
        avg_loss = 0.0  # defined even if the loader yields no batch

        pbar_batches = tqdm(enumerate(dataloader), desc=f"Epoch {epoch+1}",
                            total=len(dataloader), leave=False)

        for step, (dna_enc, text_enc) in pbar_batches:
            # Move to device
            dna_enc = {k: v.to(device) for k, v in dna_enc.items()}
            text_enc = {k: v.to(device) for k, v in text_enc.items()}

            dna_emb, text_emb = model(dna_enc, text_enc)
            loss = info_nce_loss(dna_emb, text_emb, temperature=args.temperature)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            avg_loss = total_loss / (step + 1)  # running mean over this epoch

            pbar_batches.set_postfix({'loss': f'{avg_loss:.4f}'})

        pbar_epochs.set_postfix({'avg_loss': f'{avg_loss:.4f}'})

        # Save checkpoint each epoch
        ckpt_path = os.path.join(MODEL_DIR, f"checkpoint_epoch_{epoch+1}.pt")
        torch.save(model.state_dict(), ckpt_path)
        logging.info(f"Saved checkpoint => {ckpt_path}")


##############################################################################
#             STEP 7: BASIC EVALUATION (RETRIEVAL ACCURACY)
##############################################################################

def evaluate_model(args):
    """
    Simple retrieval test: given a DNA snippet, do we retrieve the correct phenotype
    text among a set of distractors?

    For every record in DATA_JSONL the DNA embedding is compared with the
    embedding of its own phenotype text and of up to four *different* phenotype
    texts (the first distinct ones in file order). A record counts as correct
    when its own phenotype scores highest.

    Args:
        args: Object exposing ``proj_dim`` and ``num_epochs`` (the latter selects
            ``checkpoint_epoch_<num_epochs>.pt``).

    Returns:
        float | None: Retrieval accuracy in [0, 1]; 0.0 for an empty dataset;
        None when the checkpoint file does not exist.
    """
    logging.info("Starting evaluation...")

    # Check for the checkpoint first so a missing file does not cost an encoder load.
    best_ckpt = os.path.join(MODEL_DIR, f"checkpoint_epoch_{args.num_epochs}.pt")
    if not os.path.exists(best_ckpt):
        logging.warning(f"No checkpoint found at {best_ckpt}. Evaluation aborted.")
        return

    with open(DATA_JSONL, 'r', encoding='utf-8') as f:
        dataset = [json.loads(line) for line in f if line.strip()]

    total = len(dataset)
    if total == 0:
        logging.warning(f"No records found in {DATA_JSONL}. Nothing to evaluate.")
        return 0.0

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ContrastiveModel(dna_model_name=DNABERT_CKPT,
                             text_model_name=BIOBERT_CKPT,
                             proj_dim=args.proj_dim,
                             freeze_encoders=True)
    model.load_state_dict(torch.load(best_ckpt, map_location=device))
    model.to(device)
    model.eval()

    dna_tokenizer = AutoTokenizer.from_pretrained(DNABERT_CKPT, trust_remote_code=True)
    text_tokenizer = AutoTokenizer.from_pretrained(BIOBERT_CKPT)

    # Distinct phenotype texts in file order. Five are enough to always find four
    # that differ from any given target.
    distractor_pool = list(dict.fromkeys(r["phenotype_text"] for r in dataset))[:5]

    text_emb_cache = {}

    def embed_text(text):
        """Return the normalized text embedding, encoding each distinct string once."""
        if text not in text_emb_cache:
            text_enc = text_tokenizer(text, truncation=True, padding="max_length",
                                      max_length=32, return_tensors="pt")
            text_enc = {k: v.to(device) for k, v in text_enc.items()}
            _, text_e = model(dna_input=None, text_input=text_enc)
            text_emb_cache[text] = nn.functional.normalize(text_e, dim=-1)
        return text_emb_cache[text]

    correct = 0

    with torch.no_grad():
        pbar = tqdm(enumerate(dataset), desc="Evaluating", total=total)
        for idx, record in pbar:
            target_phenotype = record["phenotype_text"]

            dna_enc = dna_tokenizer(record["dna_sequence"], truncation=True, padding="max_length",
                                    max_length=128, return_tensors="pt")
            dna_enc = {k: v.to(device) for k, v in dna_enc.items()}

            dna_emb, _ = model(dna_enc, text_input=None)
            dna_emb = nn.functional.normalize(dna_emb, dim=-1)

            # A distractor equal to the target would tie with it and be counted
            # as a hit, so only genuinely different texts compete.
            distractors = [p for p in distractor_pool if p != target_phenotype][:4]
            candidates = [target_phenotype] + distractors

            candidate_embs = torch.cat([embed_text(c) for c in candidates], dim=0)
            sims = torch.matmul(dna_emb, candidate_embs.T).squeeze(0)
            if int(torch.argmax(sims)) == 0:  # index 0 is the true phenotype
                correct += 1

            # Update progress bar with current accuracy
            accuracy = correct / (idx + 1)
            pbar.set_postfix({'accuracy': f'{accuracy*100:.2f}%'})

    final_accuracy = correct / total
    logging.info(f"Final Retrieval Accuracy: {final_accuracy*100:.2f}%")
    return final_accuracy


##############################################################################
#        STEP 8: DEPLOYMENT WITH FLASK (SIMPLE INFERENCE ENDPOINT)
##############################################################################

# Flask application object; the /predict route below is registered on it.
app = Flask(__name__)

# Process-wide cache for the inference objects. All entries start as None and
# are filled by load_inference_objects() the first time it is called, so
# importing this module does not load any model weights.
MODEL_CACHE = {
    "model": None,
    "dna_tokenizer": None,
    "text_tokenizer": None
}


def load_inference_objects():
    """
    Lazy-load the model and tokenizers for inference.

    The first call builds the tokenizers and a ContrastiveModel (proj_dim=256,
    encoders frozen), loads ``MODEL_DIR/checkpoint_epoch_1.pt`` if present, and
    stores everything in MODEL_CACHE; later calls return the cached objects.

    Returns:
        tuple: (model, dna_tokenizer, text_tokenizer)
    """
    if MODEL_CACHE["model"] is None:
        logging.info("Loading model and tokenizers for inference...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        dna_tokenizer = AutoTokenizer.from_pretrained(DNABERT_CKPT, trust_remote_code=True)
        text_tokenizer = AutoTokenizer.from_pretrained(BIOBERT_CKPT)

        model = ContrastiveModel(dna_model_name=DNABERT_CKPT,
                                 text_model_name=BIOBERT_CKPT,
                                 proj_dim=256,
                                 freeze_encoders=True)
        # Load final checkpoint (this is a placeholder—replace with your best checkpoint)
        best_ckpt = os.path.join(MODEL_DIR, "checkpoint_epoch_1.pt")
        if os.path.exists(best_ckpt):
            model.load_state_dict(torch.load(best_ckpt, map_location=device))
        else:
            # Without a checkpoint the projection heads are randomly initialized,
            # so predictions are meaningless - make that visible in the logs.
            logging.warning(
                f"No checkpoint found at {best_ckpt}; serving a model with "
                "untrained projection heads."
            )
        model.to(device)
        model.eval()

        # Fill the cache only after everything loaded, so a failure above
        # leaves it untouched and the next call retries from scratch.
        MODEL_CACHE["dna_tokenizer"] = dna_tokenizer
        MODEL_CACHE["text_tokenizer"] = text_tokenizer
        MODEL_CACHE["model"] = model

    return MODEL_CACHE["model"], MODEL_CACHE["dna_tokenizer"], MODEL_CACHE["text_tokenizer"]


@app.route("/predict", methods=["POST"])
def predict():
    r"""
    Rank a fixed list of phenotypes against a DNA sequence and return the best one.

    Example:
      curl -X POST -H "Content-Type: application/json" \
           -d '{"dna_sequence": "ACGTACGTACG..."}' \
           http://localhost:5000/predict
    Returns:
      JSON with "top_phenotype" = "...", or a JSON error with status 400 when
      the body is not a JSON object or "dna_sequence" is missing, not a string,
      or empty.
    """
    # silent=True: malformed or non-JSON bodies yield None instead of an HTML error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "dna_sequence" not in data:
        return jsonify({"error": "Missing 'dna_sequence' in request"}), 400

    dna_seq = data["dna_sequence"]
    if not isinstance(dna_seq, str) or not dna_seq.strip():
        return jsonify({"error": "'dna_sequence' must be a non-empty string"}), 400
    # The dataset sequences are stored stripped and upper-cased; match that here.
    dna_seq = dna_seq.strip().upper()

    model, dna_tok, text_tok = load_inference_objects()
    # Use the device the cached model actually lives on.
    device = next(model.parameters()).device

    # Encode DNA
    dna_enc = dna_tok(dna_seq, truncation=True, padding="max_length",
                      max_length=128, return_tensors="pt")
    dna_enc = {k: v.to(device) for k, v in dna_enc.items()}

    # For demonstration, pick from dataset or a small dictionary of known phenotypes
    phenotypes = ["Breast Cancer", "Cystic Fibrosis", "Unknown Condition"]
    # Tokenize and encode all candidates in a single batch.
    text_enc = text_tok(phenotypes, truncation=True, padding="max_length",
                        max_length=32, return_tensors="pt")
    text_enc = {k: v.to(device) for k, v in text_enc.items()}

    with torch.no_grad():
        dna_emb, _ = model(dna_enc, text_input=None)
        _, text_emb = model(dna_input=None, text_input=text_enc)
        dna_emb = nn.functional.normalize(dna_emb, dim=-1)
        text_emb = nn.functional.normalize(text_emb, dim=-1)
        # One cosine-similarity score per candidate phenotype.
        sims = torch.matmul(dna_emb, text_emb.T).squeeze(0)

    # argmax returns the first maximum, so ties resolve to the earlier phenotype.
    best_pheno = phenotypes[int(torch.argmax(sims))]

    return jsonify({"top_phenotype": best_pheno})


##############################################################################
#                              MAIN SCRIPT
##############################################################################

class ModelConfig:
    """
    Hyper-parameters consumed by train_model() and evaluate_model().

    Called without arguments it yields the demonstration defaults; every value
    can be overridden by keyword.

    Args:
        num_epochs: Number of training epochs; also selects which checkpoint
            evaluate_model() loads.
        batch_size: Samples per training batch.
        lr: Learning rate passed to the optimizer.
        proj_dim: Output size of the projection heads.
        temperature: Divisor applied to the similarity logits in the loss.
        log_interval: Stored for callers; not read by any function in this file.
    """

    def __init__(self, *, num_epochs=1, batch_size=2, lr=1e-4, proj_dim=256,
                 temperature=0.07, log_interval=1):
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.proj_dim = proj_dim
        self.temperature = temperature
        self.log_interval = log_interval

    def __repr__(self):
        fields = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}({fields})"

def setup_environment():
    """Create the working directories, fetch the ClinVar VCF and build the JSONL dataset."""
    for directory in (DATA_DIR, MODEL_DIR):
        os.makedirs(directory, exist_ok=True)

    # Fetch the VCF to the download helper's default location ...
    download_clinvar_data()

    # ... and turn that same file into (DNA, phenotype) records.
    create_dataset_from_clinvar_vcf(vcf_path="./data/clinvar_variants.vcf.gz",
                                    jsonl_path=DATA_JSONL)

def run_training():
    """
    Run the whole pipeline: environment setup, training, then evaluation.

    Returns:
        ModelConfig: The default configuration used for both stages.
    """
    setup_environment()

    config = ModelConfig()
    # Train first, then evaluate with the very same settings.
    for stage in (train_model, evaluate_model):
        stage(config)

    return config

# For Jupyter notebook usage, you can now just run:
# args = run_training()


In [2]:
# In your notebook:
args = run_training()

Downloading ClinVar VCF: 100%|██████████| 100M/100M [00:03<00:00, 31.3MiB/s] 
Processing variants: 100%|██████████| 1000/1000 [09:41<00:00,  1.72it/s]


tokenizer_config.json:   0%|          | 0.00/158 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/168k [00:00<?, ?B/s]



tokenizer_config.json:   0%|          | 0.00/412 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/669k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/904 [00:00<?, ?B/s]

bert_layers.py:   0%|          | 0.00/40.7k [00:00<?, ?B/s]

bert_padding.py:   0%|          | 0.00/6.10k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/zhihan1996/DNABERT-2-117M:
- bert_padding.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


flash_attn_triton.py:   0%|          | 0.00/42.7k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/zhihan1996/DNABERT-2-117M:
- flash_attn_triton.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/zhihan1996/DNABERT-2-117M:
- bert_layers.py
- bert_padding.py
- flash_attn_triton.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


pytorch_model.bin:   0%|          | 0.00/468M [00:00<?, ?B/s]

Some weights of BertModel were not initialized from the model checkpoint at zhihan1996/DNABERT-2-117M and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


config.json:   0%|          | 0.00/691 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/433M [00:00<?, ?B/s]

Training epochs:   0%|          | 0/1 [00:00<?, ?it/s]
Epoch 1:   0%|          | 0/500 [00:00<?, ?it/s][A
Epoch 1:   0%|          | 0/500 [00:00<?, ?it/s, loss=1.0624][A
Epoch 1:   0%|          | 1/500 [00:00<07:51,  1.06it/s, loss=1.0624][A
Epoch 1:   0%|          | 1/500 [00:01<07:51,  1.06it/s, loss=0.8997][A
Epoch 1:   0%|          | 2/500 [00:01<03:47,  2.19it/s, loss=0.8997][A
Epoch 1:   0%|          | 2/500 [00:01<03:47,  2.19it/s, loss=0.8762][A
Epoch 1:   1%|          | 3/500 [00:01<02:28,  3.35it/s, loss=0.8762][A
Epoch 1:   1%|          | 3/500 [00:01<02:28,  3.35it/s, loss=0.8855][A
Epoch 1:   1%|          | 4/500 [00:01<01:51,  4.44it/s, loss=0.8855][A
Epoch 1:   1%|          | 4/500 [00:01<01:51,  4.44it/s, loss=0.8365][A
Epoch 1:   1%|          | 5/500 [00:01<01:31,  5.44it/s, loss=0.8365][A
Epoch 1:   1%|          | 5/500 [00:01<01:31,  5.44it/s, loss=0.8926][A
Epoch 1:   1%|          | 6/500 [00:01<01:18,  6.28it/s, loss=0.8926][A
Epoch 1:   1%|          |