<a href="https://colab.research.google.com/github/Gemmamarshall/oxrse_unit_conv/blob/main/Group_project_Stats_and_AI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# # Create environment
# conda create -n esmc_env python=3.10 -y
# conda activate esmc_env

# # Install PyTorch CPU version
# pip install torch --index-url https://download.pytorch.org/whl/cpu
# or conda install pytorch-cpu torchvision-cpu -c pytorch

# # Install ESM package (ESMC is part of esm)
# pip install esm

# # Install other utilities
# pip install numpy pandas scikit-learn matplotlib seaborn


import pandas as pd

# Placeholder dataset: replace with 50+ sequences and polarity measurements.
# Each record is (protein_id, sequence, polarity); -50 = back, 50 = front.
records = [
    ("P1", "MSTNPKPQRIT", -40),
    ("P2", "MLAIKAGFAT", 20),
    ("P3", "VDSDDQEQVI", 5),
    ("P4", "MNKQWERTYI", 50),
    ("P5", "MKLFIAGVLA", -10),
]

# Transpose the records into the column-oriented dict the DataFrame is built from.
protein_ids, sequences, polarities = (list(column) for column in zip(*records))
data = {
    "protein_id": protein_ids,
    "sequence": sequences,
    "polarity": polarities,
}

df = pd.DataFrame(data)



##############################################################
###### load ESM and extract embeddings for sequences
import torch
import esm
import numpy as np

# Load the pretrained ESM Cambrian (ESMC) 300M model on CPU.
#
# The `esm` package installed above for ESMC does not offer the fair-esm style
# interface (`esm.pretrained.<name>()` returning `(model, alphabet)`, a batch
# converter, `repr_layers=` / `results["representations"]`). ESMC is driven
# through `ESMC.from_pretrained`, `encode` and `logits` instead.
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

device = torch.device("cpu")
model = ESMC.from_pretrained("esmc_300m").to(device)
model.eval()  # inference mode

# Ask the model to return the final hidden states alongside the logits.
embedding_config = LogitsConfig(sequence=True, return_embeddings=True)

mean_embeddings = []        # per-protein embedding
per_residue_embeddings = [] # per-residue embeddings
sequence_lengths = []

for seq in df["sequence"]:
    protein_tensor = model.encode(ESMProtein(sequence=seq))

    with torch.no_grad():
        output = model.logits(protein_tensor, embedding_config)

    # NOTE(review): embeddings are expected to be (1, L + 2, D), i.e. one
    # start token and one end token around the L residues - confirm against
    # the installed esm version. The slice keeps the L residue rows only.
    token_embeddings = output.embeddings[0, 1:len(seq) + 1].float().cpu()
    if token_embeddings.shape[0] != len(seq):
        raise ValueError(
            f"Expected {len(seq)} residue embeddings, got {token_embeddings.shape[0]}"
        )

    per_residue_embeddings.append(token_embeddings.numpy())
    sequence_lengths.append(len(seq))

    # Per-protein embedding: mean over residues
    mean_embeddings.append(token_embeddings.mean(0).numpy())

X = np.vstack(mean_embeddings)  # shape = n_proteins x embedding_dim
y = df["polarity"].values
print("Embedding shape:", X.shape)




#### train regression model on per-protein embeddings

# You have per-protein embeddings: each protein is summarized by a single vector of numbers (mean_embedding from all residues).

# You train a regression model:
# Polarity=Regressor(mean_embedding)
# Each dimension in the embedding vector is a feature learned by the model (ESM‑2 or ESMC).
# The regression weight tells you:
# “If this embedding dimension increases, how much does the polarity score change?”
# Positive weight → increases polarity (toward +50)
# Negative weight → decreases polarity (toward -50)
# Larger absolute value → more influence on the polarity prediction


from sklearn.linear_model import Lasso
from sklearn.preprocessing import StandardScaler

# Standardize each embedding dimension before fitting
scaler = StandardScaler()
scaler.fit(X)
X_scaled = scaler.transform(X)

# LASSO (L1-penalized linear regression) keeps the weight vector sparse,
# which makes the important dimensions easier to read off
reg = Lasso(alpha=0.001)
reg.fit(X_scaled, y)

# Order dimensions by absolute weight, largest first, and report the first ten
ranked_dims = np.argsort(np.abs(reg.coef_))[::-1]
print("Top embedding dimensions influencing polarity:", ranked_dims[:10])




### Compute each residue's importance
# Now you want residue-level interpretation:
# You know which embedding dimensions are important (from reg.coef_).
# Each residue embedding also has the same dimensions.
# Multiply the per-residue embedding by the regression weights:
# Residue contribution=residue_embedding⋅regression_weights
# This gives a score per residue
# High absolute values → this amino acid strongly drives the polarity prediction

# Per-residue contribution scores, one (L,) array per protein.
#
# `reg` was fit on standardized features, so every residue embedding has to go
# through the same `scaler` before it is multiplied by the regression weights.
# Applying the weights to raw embeddings would combine values on a different
# scale from the one the weights were learned on.
#
# Because the per-protein feature is the mean over residues, the mean of a
# protein's residue scores plus `reg.intercept_` equals the model's prediction
# for that protein.
per_residue_contributions = [
    scaler.transform(emb) @ reg.coef_  # (L, D) @ (D,) -> (L,)
    for emb in per_residue_embeddings
]


### Aggregate across all proteins
# Sequences have different lengths and are not aligned, so residues are pooled
# by concatenation rather than averaged by position.
all_contributions = np.concatenate(per_residue_contributions)
print("All residues contribution shape:", all_contributions.shape)

# # Identify top residues (largest absolute contribution)
# top_indices = np.argsort(np.abs(all_contributions))[::-1][:20]
# print("Top 20 residues driving polarity across all proteins:", top_indices)


### Plot per residue contribution per protein on polarity score
import matplotlib.pyplot as plt

# Bar chart of the per-residue contributions for one protein,
# with the amino-acid letters as x tick labels.
seq_idx = 0  # first protein
seq = df["sequence"][seq_idx]
contrib = per_residue_contributions[seq_idx]
positions = range(len(seq))

fig, ax = plt.subplots(figsize=(8, 3))
ax.bar(positions, contrib)
ax.set_xticks(positions)
ax.set_xticklabels(list(seq))
ax.set_xlabel("Residue")
ax.set_ylabel("Contribution to polarity")
ax.set_title(f"Residue contributions for {df['protein_id'][seq_idx]}")
plt.show()


## plot heatmap of per residue contributions

import matplotlib.pyplot as plt
import seaborn as sns

# Heatmap with one row per protein and one column per residue position.
#
# Calling sns.heatmap once per protein on the same axes draws every row on top
# of the previous one, stacks a colorbar per call and keeps only the last
# protein's tick labels. Instead, build one matrix padded with NaN (seaborn
# leaves NaN cells blank) and draw it in a single call.
n_proteins = len(per_residue_contributions)
max_len = max(len(scores) for scores in per_residue_contributions)

contribution_matrix = np.full((n_proteins, max_len), np.nan)
residue_labels = np.full((n_proteins, max_len), "", dtype=object)
for row, (scores, sequence) in enumerate(zip(per_residue_contributions, df["sequence"])):
    contribution_matrix[row, :len(scores)] = scores
    residue_labels[row, :len(scores)] = list(sequence)[:len(scores)]

# Amino-acid letters inside the cells are only legible for short sequences.
show_letters = max_len <= 50

plt.figure(figsize=(12, 6))
sns.heatmap(
    contribution_matrix,
    cmap="coolwarm",
    center=0,  # diverging colormap: sign of the contribution maps to colour
    cbar=True,
    annot=residue_labels if show_letters else False,
    fmt="",
    xticklabels=list(range(1, max_len + 1)),
    yticklabels=list(df["protein_id"]),
)
plt.xlabel("Residue position")
plt.ylabel("Proteins")
plt.title("Per-residue contributions to polarity across all proteins")
plt.show()




# Other experiments: Clustering / dimensionality reduction of embeddings
# Use per-protein embeddings to:
# Reduce dimensionality (PCA, t-SNE, UMAP)
# Cluster proteins by similarity in embedding space

# Goal: see whether proteins with similar organization, polarity, or subcellular localization naturally cluster

# Insight: groups of sequences may share structural motifs or charged patterns influencing organization.

# Other experiments 2: Sequence motif discovery
# From high-contribution residues, find contiguous motifs or patterns enriched in top residues.
# Methods:
# Scan for charged clusters, hydrophobic stretches, or conserved motifs
# Use alignment tools to find motifs conserved across similar proteins
# Outcome: candidate functional or structural elements responsible for cellular localization or polarity.

# Follow up: Motif enrichment / functional annotation
# Take top motifs identified from per-residue contributions and check:
# Known functional domains (Pfam, InterPro)
# Post-translational modifications (phosphorylation, lipidation)
# Could link sequence patterns to mechanistic drivers of polarity

# Follow up: Map important residues
# PDB structures (or AlphaFold models)
# map sequences flagged up as important to polarity onto structure
# Goal: see whether key residues cluster in certain domains (e.g., membrane-binding face, coiled regions, hydrophobic, charged, IDRs).

# Follow up: Predict organisation of other surface proteins
# With the regression model, you can predict where proteins that have not been quantified localize; e.g., if your model works well, it should predict that the receptor ICAM-2 goes to the uropod too.

ModuleNotFoundError: No module named 'esm'

In [None]:
# Do this first, before anything else
# Install the UniProtMapper package
import sys
!{sys.executable} -m pip install git+https://github.com/David-Araripe/UniProtMapper.git

Collecting git+https://github.com/David-Araripe/UniProtMapper.git
  Cloning https://github.com/David-Araripe/UniProtMapper.git to /tmp/pip-req-build-04bimo2j
  Running command git clone --filter=blob:none --quiet https://github.com/David-Araripe/UniProtMapper.git /tmp/pip-req-build-04bimo2j
  Resolved https://github.com/David-Araripe/UniProtMapper.git to commit d7054f6df37423cbcd4bb24bf91385dbba2528df
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [None]:
# This is to understand the basics behind UniProtMapper, which might be helpful
from UniProtMapper import ProtMapper

# Map three UniProtKB accessions to Ensembl identifiers and show
# both the mapping table and any IDs that could not be mapped.
query_ids = ["P30542", "Q16678", "Q02880"]

mapper = ProtMapper()
result, failed = mapper.get(ids=query_ids, from_db="UniProtKB_AC-ID", to_db="Ensembl")

print(result.head())
print("failed:", failed)


Fetched: 3 / 3
     From                  To
0  P30542  ENSG00000163485.18
1  Q16678  ENSG00000138061.13
2  Q02880  ENSG00000077097.17
failed: []


In [None]:
#!/usr/bin/env python3
"""
get_uniprot_by_gene.py

Ask the user for a gene name (e.g. "CD45" or "PTPRC"), treat that input as a gene name,
query UniProt for the top human (9606) reviewed hit, and print:
  - gene name (what user entered)
  - UniProt accession (primary accession)
  - entry name (e.g. PTPRC_HUMAN)
  - sequence (FASTA-style wrapped at 60 chars)

Requires: requests
Install: pip install requests
"""

from typing import Optional, Dict
import requests
import sys

UNIPROT_SEARCH_URL = "https://rest.uniprot.org/uniprotkb/search"
UNIPROT_ENTRY_URL = "https://rest.uniprot.org/uniprotkb/"

def build_gene_query(gene: str) -> str:
    """Return a UniProt query string for a human, reviewed entry of *gene*.

    Surrounding whitespace is stripped from *gene*. The name is matched through
    both the `gene_exact` and `gene` fields so that official names and synonyms
    can hit, and results are restricted to organism 9606 (human) and to
    reviewed (Swiss-Prot) entries.
    """
    symbol = gene.strip()
    name_clause = f"(gene_exact:{symbol} OR gene:{symbol})"
    return " AND ".join((name_clause, "organism_id:9606", "reviewed:true"))

def query_uniprot_first_hit_by_gene(gene: str, size: int = 1) -> Optional[Dict]:
    """Search UniProtKB for *gene* (human, reviewed) and return the first hit.

    Args:
        gene: Gene name; turned into a query by `build_gene_query`.
        size: Value sent as the `size` request parameter.

    Returns:
        The first element of the response's "results" list, or None when that
        list is missing or empty.

    Raises:
        requests.HTTPError: If the response status signals an error
            (via `raise_for_status`).
    """
    response = requests.get(
        UNIPROT_SEARCH_URL,
        params={"query": build_gene_query(gene), "format": "json", "size": size},
        timeout=15,
    )
    response.raise_for_status()
    hits = response.json().get("results", [])
    return hits[0] if hits else None

def extract_accession_and_entry_name(result: Dict) -> tuple[Optional[str], Optional[str]]:
    """Extract the accession and entry name from a UniProt search result.

    Args:
        result: One entry of a UniProt search response, as a dict.

    Returns:
        An `(accession, entry_name)` pair. The accession is the first non-empty
        value among the keys "primaryAccession" and "accession"; the entry name
        is the first non-empty value among "uniProtkbId", "id" and "entryName".
        Either element is None when none of its keys holds a non-empty value.
    """
    def first_present(*keys: str) -> Optional[str]:
        # Empty strings and None count as missing, so fall through to the next key.
        return next((result[key] for key in keys if result.get(key)), None)

    accession = first_present("primaryAccession", "accession")
    entry_name = first_present("uniProtkbId", "id", "entryName")
    return accession, entry_name

def get_sequence_from_result(result: Dict) -> Optional[str]:
    """Return the sequence embedded in a search result, if there is one.

    The sequence is read from `result["sequence"]["value"]`. None is returned
    when "sequence" is absent, is not a dict, or has no "value" key.
    """
    sequence_field = result.get("sequence")
    if not isinstance(sequence_field, dict):
        return None
    return sequence_field.get("value")

def fetch_fasta_by_accession(accession: str) -> Optional[str]:
    """Download the FASTA record for *accession* and return the bare sequence.

    Returns:
        All lines after the first (header) line joined into one string, or
        None when the response status is not 200 or the body is empty.
    """
    response = requests.get(f"{UNIPROT_ENTRY_URL}{accession}.fasta", timeout=15)
    if response.status_code != 200:
        return None

    fasta_lines = response.text.strip().splitlines()
    if not fasta_lines:
        return None

    # The first line is the FASTA header; everything after it is sequence.
    _header, *sequence_lines = fasta_lines
    return "".join(sequence_lines)

def wrap_seq(seq: str, width: int = 60) -> str:
    """Break *seq* into consecutive chunks of *width* characters, one per line."""
    pieces = [seq[start:start + width] for start in range(0, len(seq), width)]
    return "\n".join(pieces)

def pretty_print(gene: str, accession: str, entry_name: str, sequence: str):
    """Print a lookup summary followed by a FASTA-style record.

    Args:
        gene: The gene name as entered by the user.
        accession: UniProt accession to display.
        entry_name: UniProt entry name; its summary line is skipped when empty.
        sequence: Amino-acid sequence, printed wrapped at 60 characters per line.
    """
    report = [
        "\n--- UniProt (human, reviewed) lookup result ---",
        f"Query (treated as gene name): {gene}",
    ]
    if entry_name:
        report.append(f"Entry name:  {entry_name}")
    report += [
        f"Accession:   {accession}",
        f"Sequence length: {len(sequence)} aa\n",
        f">{accession} | {entry_name} | gene:{gene}",
        wrap_seq(sequence, 60),
        "\n-----------------------------------------------\n",
    ]
    print("\n".join(report))

def _search_human_hits_any_review_status(gene: str) -> list:
    """Search UniProtKB for *gene* in human without the reviewed filter; return the "results" list."""
    response = requests.get(
        UNIPROT_SEARCH_URL,
        params={
            "query": f'(gene_exact:{gene} OR gene:{gene}) AND organism_id:9606',
            "format": "json",
            "size": 1,
        },
        timeout=15,
    )
    response.raise_for_status()
    return response.json().get("results", [])


def main():
    """Prompt for a gene name, look it up in UniProt and print the top hit.

    The reviewed human entry is tried first; if there is none, the search is
    repeated without the reviewed filter. The sequence is taken from the search
    result when present, otherwise fetched as FASTA by accession. HTTP, network
    and keyboard-interrupt errors are reported and swallowed; any other
    exception is reported and re-raised.
    """
    try:
        gene_input = input("Enter gene name (treated as human gene name, e.g. CD45 or PTPRC): ").strip()
        if not gene_input:
            print("No gene name entered. Exiting.")
            return

        # First choice: the top human, reviewed hit for this gene.
        result = query_uniprot_first_hit_by_gene(gene_input)
        if not result:
            # Some rare genes only have unreviewed entries, so retry without the filter.
            print("No reviewed human entries found. Trying without reviewed filter...")
            fallback_hits = _search_human_hits_any_review_status(gene_input)
            if not fallback_hits:
                print(f"No UniProt hits found for gene '{gene_input}' in human (9606).")
                return
            result = fallback_hits[0]

        accession, entry_name = extract_accession_and_entry_name(result)
        if not accession:
            print("Couldn't extract accession from UniProt result. Raw result:")
            print(result)
            return

        # Prefer the sequence embedded in the search result; otherwise download the FASTA.
        sequence = get_sequence_from_result(result) or fetch_fasta_by_accession(accession)
        if not sequence:
            print(f"Could not retrieve sequence for accession {accession}.")
            return

        pretty_print(gene_input, accession, entry_name or "", sequence)

    except requests.HTTPError as e:
        print("HTTP error while querying UniProt:", e)
    except requests.RequestException as e:
        print("Network error while querying UniProt:", e)
    except KeyboardInterrupt:
        print("\nCancelled.")
    except Exception as e:
        print("An unexpected error occurred:", str(e))
        raise

if __name__ == "__main__":
    main()



Enter gene name (treated as human gene name, e.g. CD45 or PTPRC): CD45

--- UniProt (human, reviewed) lookup result ---
Query (treated as gene name): CD45
Entry name:  PTPRC_HUMAN
Accession:   P08575
Sequence length: 1306 aa

>P08575 | PTPRC_HUMAN | gene:CD45
MTMYLWLKLLAFGFAFLDTEVFVTGQSPTPSPTGLTTAKMPSVPLSSDPLPTHTTAFSPA
STFERENDFSETTTSLSPDNTSTQVSPDSLDNASAFNTTGVSSVQTPHLPTHADSQTPSA
GTDTQTFSGSAANAKLNPTPGSNAISDVPGERSTASTFPTDPVSPLTTTLSLAHHSSAAL
PARTSNTTITANTSDAYLNASETTTLSPSGSAVISTTTIATTPSKPTCDEKYANITVDYL
YNKETKLFTAKLNVNENVECGNNTCTNNEVHNLTECKNASVSISHNSCTAPDKTLILDVP
PGVEKFQLHDCTQVEKADTTICLKWKNIETFTCDTQNITYRFQCGNMIFDNKEIKLENLE
PEHEYKCDSEILYNNHKFTNASKIIKTDFGSPGEPQIIFCRSEAAHQGVITWNPPQRSFH
NFTLCYIKETEKDCLNLDKNLIKYDLQNLKPYTKYVLSLHAYIIAKVQRNGSAAMCHFTT
KSAPPSQVWNMTVSMTSDNSMHVKCRPPRDRNGPHERYHLEVEAGNTLVRNESHKNCDFR
VKDLQYSTDYTFKAYFHNGDYPGEPFILHHSTSYNSKALIAFLAFLIIVTSIALLVVLYK
IYDLHKKRSCNLDEQQELVERDDEKQLMNVEPIHADILLETYKRKIADEGRLFLAEFQSI
PRVFSKFPIKEARKPFNQNKNRYVDILPYDYNRVELSEINGDAGSNYINASYIDGFKEPR
KYIAAQGP