Merged
24 changes: 24 additions & 0 deletions Dockerfile
@@ -0,0 +1,24 @@
# Use an official PyTorch runtime with CUDA support as the parent image (https://hub.docker.com/r/pytorch/pytorch/)
FROM pytorch/pytorch:2.3.1-cuda11.8-cudnn8-devel as base

ENV HOST=docker
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
ENV TZ=Europe/Paris
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone

# TO CHANGE IF NEEDED
RUN useradd -m -U -s /bin/bash user
RUN apt-get update && apt-get install -y git

# Switch to the non-root user: TO CHANGE IF NEEDED
USER user

# Disable pip cache: https://stackoverflow.com/questions/45594707/what-is-pips-no-cache-dir-good-for
ENV PIP_NO_CACHE_DIR=1
# Copy the Python requirements into the image
COPY requirements.txt /home/user/requirements.txt

# Install Python dependencies from requirements.txt
RUN pip install --no-cache-dir -r /home/user/requirements.txt

WORKDIR /home/user
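
The image above only copies requirements.txt into the container, so the repository itself still has to be mounted or copied in when the container runs. Below is a minimal build-and-run sketch (not part of this PR), assuming Docker with the NVIDIA runtime is available; the image tag and mount path are hypothetical.

# Hypothetical helper for building and running the image defined by this Dockerfile.
# The tag, mount target and --gpus flag are assumptions, not taken from the PR.
import subprocess

def build_and_run(tag: str = "protmamba:dev", repo_dir: str = ".") -> None:
    # Build from the directory containing the Dockerfile.
    subprocess.run(["docker", "build", "-t", tag, repo_dir], check=True)
    # Run interactively with GPU access, mounting the source tree since the
    # image only ships the installed requirements.
    subprocess.run(
        ["docker", "run", "--rm", "-it", "--gpus", "all",
         "-v", f"{repo_dir}:/home/user/ProtMamba-ssm", tag, "bash"],
        check=True,
    )

if __name__ == "__main__":
    build_and_run()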
1 change: 0 additions & 1 deletion ProtMamba_ssm/__init__.py
@@ -1,5 +1,4 @@
__version__ = "0.0.1"
from .core import *
from .dataloaders import *
from .fim import *
from .modules import *
173 changes: 0 additions & 173 deletions ProtMamba_ssm/_modidx.py

This file was deleted.

25 changes: 8 additions & 17 deletions ProtMamba_ssm/dataloaders.py
@@ -1,9 +1,3 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_dataloaders.ipynb.

# %% auto 0
__all__ = ['Uniclust30_Dataset', 'make_dataloader', 'DataCollatorForUniclust30Dataset']

# %% ../nbs/03_dataloaders.ipynb 3
from .utils import AA_TO_ID
from .fim import NoFIM, SingleSpanFIM, MultipleSpanFIM
import pickle
@@ -14,8 +8,6 @@
from dataclasses import dataclass
from typing import Dict, Sequence


# %% ../nbs/03_dataloaders.ipynb 4
# Make dataset
class Uniclust30_Dataset(Dataset):
"""
@@ -78,14 +70,14 @@ def __init__(self, filename="encoded_MSAs_train.pkl",
def __len__(self):
return len(self.cluster_names)

def __getitem__(self, idx):
def __getitem__(self, idx, shuffle=True):
# get all the sequences in the cluster
sequences = self.get_sequences(idx)
# get total number of sequences in the cluster and choose how many to sample
orig_num_sequences = len(self.get_index_start_of_sequences(sequences))
num_sequences = np.random.randint(1, orig_num_sequences + 1) if self.sample else orig_num_sequences
# sample the sequences
sequences, position_ids = self.sample_sequences(sequences, num_sequences)
sequences, position_ids = self.sample_sequences(sequences, num_sequences, shuffle=shuffle)
# with probability 0.5, reverse the sequences and move the last token to the front
if self.reverse and np.random.rand() > 0.5:
    sequences, position_ids = self.reverse_sequences(sequences, position_ids)
@@ -128,7 +120,7 @@ def reverse_sequences(self, sequence, position_ids=None):
return np.concatenate([sequence[-1:], sequence[:-1]]), np.concatenate(
[position_ids[-1:], position_ids[:-1]]) if position_ids is not None else None

def sample_sequences(self, sequences, num_sequences, shuffle=True):
def sample_sequences(self, sequences, num_sequences, shuffle=True, which_seqs=None):
"""Sample `num_sequences` from the sequences in the cluster."""
L = len(sequences)
# get the indexes of the start of each sequence
@@ -137,10 +129,11 @@ def sample_sequences(self, sequences, num_sequences, shuffle=True):
assert len(inds) > 0, "No sequences found in cluster."
assert len(inds) >= num_sequences, "Not enough sequences in cluster."
# sample n_sequences randomly from the sequences
if shuffle:
which_seqs = np.random.choice(np.arange(len(inds)), num_sequences, replace=False)
else:
which_seqs = np.arange(len(inds))[-num_sequences:]
if which_seqs is None:
if shuffle:
which_seqs = np.random.choice(np.arange(len(inds)), num_sequences, replace=False)
else:
which_seqs = np.arange(len(inds))[-num_sequences:]
# get the tuples of start and end indexes of the sequences
tuples = [(inds[i], inds[i + 1]) if i < len(inds) - 1 else (inds[i], L) for i in which_seqs]
if self.troubleshoot:
@@ -154,9 +147,7 @@ def make_dataloader(dataset):
"""Basic function to make a dataloader.
"""
dataloader = DataLoader(dataset)
return dataloader

# %% ../nbs/03_dataloaders.ipynb 5
@dataclass
class DataCollatorForUniclust30Dataset(object):
"""
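
The new which_seqs argument of sample_sequences lets a caller bypass the internal np.random.choice and pin exactly which cluster members are tokenized, which is useful for reproducible evaluation. A minimal sketch (not part of the PR), assuming dataset is an already-constructed Uniclust30_Dataset; the helper name is hypothetical.

import numpy as np

def get_item_with_fixed_sequences(dataset, idx, seq_indices):
    """Re-sample cluster `idx` with a fixed, reproducible set of sequence indices."""
    sequences = dataset.get_sequences(idx)   # concatenated cluster as stored in the pickle
    which = np.asarray(seq_indices)
    # Passing which_seqs skips the random-choice branch added in this diff.
    sampled, position_ids = dataset.sample_sequences(
        sequences, num_sequences=len(which), which_seqs=which
    )
    return sampled, position_ids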
7 changes: 0 additions & 7 deletions ProtMamba_ssm/fim.py
@@ -1,13 +1,6 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/04_fim.ipynb.

# %% auto 0
__all__ = ['AbstractFIM', 'NoFIM', 'SingleSpanFIM', 'MultipleSpanFIM']

# %% ../nbs/04_fim.ipynb 3
from .utils import MASK_TO_ID, AA_TO_ID
import numpy as np

# %% ../nbs/04_fim.ipynb 4
class AbstractFIM(object):
def __init__(self,
max_patches=5,
19 changes: 4 additions & 15 deletions ProtMamba_ssm/modules.py
@@ -1,19 +1,7 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_modules.ipynb.

# %% auto 0
__all__ = ['MambaConfig', 'sample_safe', 'decode_safe', 'GenerationMixinSafe', 'CheckpointedModule', 'create_block',
'MixerModelSafe', 'MambaLMHeadModelSafe', 'MixerModelWithPosids', 'MixerModelWith2DPosids',
'MambaLMHeadModelwithPosids', 'MambaLMHeadModelwith2DPosids', 'load_model']

# %% ../nbs/01_modules.ipynb 3
import torch
import torch.nn as nn
import torch.nn as nn

import json
import os
from collections import namedtuple
from dataclasses import field, dataclass
from functools import partial

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.modules.block import Block
@@ -722,10 +710,11 @@ def protected_forward(self, input_ids, position_ids=None, inference_params=None,
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
lm_logits = self.lm_head(hidden_states)
CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits", "hidden_states"])
if len(save_layer) > 0:
CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits", "hidden_states"])
return CausalLMOutput(loss=None, logits=lm_logits, hidden_states=hidden_states)
return CausalLMOutput(loss=None, logits=lm_logits, hidden_states=None)
CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
return CausalLMOutput(loss=None, logits=lm_logits)

@classmethod
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, checkpoint_mixer=False, **kwargs):
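
After this change the namedtuple returned by protected_forward only carries a hidden_states field when save_layer is non-empty, so callers should not assume the attribute exists. A small wrapper sketch (not part of the PR), assuming model is a loaded ProtMamba language model exposing protected_forward as above; the helper and the layer argument are illustrative.

import torch

def forward_with_optional_hidden(model, input_ids: torch.Tensor, layer=None):
    """Return (logits, hidden_states or None) for either output shape."""
    if layer is None:
        out = model.protected_forward(input_ids)   # namedtuple with (loss, logits)
        return out.logits, None
    out = model.protected_forward(input_ids, save_layer=[layer])
    # hidden_states is only part of the namedtuple when save_layer is non-empty
    return out.logits, out.hidden_states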
8 changes: 1 addition & 7 deletions ProtMamba_ssm/trainer.py
@@ -1,10 +1,4 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_trainer.ipynb.

# %% auto 0
__all__ = ['PREFIX_CHECKPOINT_DIR', 'MambaTrainer', 'get_last_checkpoint', 'EarlyStoppingCallback']

# %% ../nbs/02_trainer.ipynb 3
from transformers import Trainer, TrainerCallback, TrainerState, TrainerControl
from transformers import Trainer, TrainerCallback
from .utils import *
import re
import torch
6 changes: 0 additions & 6 deletions ProtMamba_ssm/utils.py
@@ -1,13 +1,9 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/99_utils.ipynb.

# %% auto 0
__all__ = ['AA_TO_ID', 'MASK_TO_ID', 'ID_TO_AA', 'encode_sequence', 'decode_sequence', 'clean_sequence', 'tokenizer',
'reorder_masked_sequence', 'load_from_file', 'generate_sequence', 'prepare_dataset_for_fim_generation',
'prepare_tokens', 'prepare_target', 'load_tensorboard_data', 'filter_datapoints', 'save_to_tensorboard',
'merge_loggings', 'concatenate_loggings', 'print_number_of_parameters', 'find_fim_indices',
'compute_metrics']

# %% ../nbs/99_utils.ipynb 3
# Constants
AA_TO_ID = {'<cls>': 0,
'<pad>': 1,
@@ -53,7 +49,6 @@

ID_TO_AA = {v: k for k, v in AA_TO_ID.items()}

# %% ../nbs/99_utils.ipynb 4
import numpy as np
import torch
from Bio import SeqIO
@@ -271,7 +266,6 @@ def prepare_target(target, use_fim=None):
assert new_target.shape[1] == new_pos_ids.shape[1]
return new_target, new_pos_ids, is_fim_dict

# %% ../nbs/99_utils.ipynb 5
from tensorboard.backend.event_processing import event_accumulator
from tensorboard.backend.event_processing.event_accumulator import ScalarEvent
from torch.utils.tensorboard import SummaryWriter
16 changes: 8 additions & 8 deletions configs/default_config.yaml
@@ -1,12 +1,12 @@
data_dir: ""
output_dir: "results/"
data_dir: "/home/user/data/" #"/home/malbrank/nas/protmamba2/data/" and "/nvme1/common/OpenProteinSet/"
output_dir: "/home/user/results/"
namedir: "test0"
train_dataset_path: "encoded_MSAs_train.pkl"
train_dataset_path: "encoded_MSAs_subset-100.pkl"
eval_dataset_path: '...'
batch_size: 4 # mamba trained with total 0.5M tokens per batch
batch_size: 2 # mamba trained with total 0.5M tokens per batch
d_model: 1024
gradient_accumulation_steps: 8
checkpoint_mixer: True
gradient_accumulation_steps: 1
checkpoint_mixer: False
learning_rate: 0.0006 # mamba default (x5), decrease to 0.0006 for sizes > 100M params and 0.0002 for sizes > 1B params
weight_decay: 0.1 # mamba default
beta1: 0.9 # mamba default
@@ -25,9 +25,9 @@ seed_sequence_sampling: 42
seed_datasets: 0
save_steps: 250
eval_steps: 50
size_validation: 192
size_validation: 4
logging_steps: 10
eval_accumulation_steps: 200
eval_accumulation_steps: 10
save_total_limit: 50
dtype: "bfloat16"
fim_strategy: "multiple_span" #["no-scramble", "one_span", "multiple_span"]
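
The new defaults lower the effective batch to batch_size * gradient_accumulation_steps = 2 * 1 = 2 and shrink validation to 4 samples, consistent with a small local or debug run rather than full training. A hedged sketch of reading the file with PyYAML (the training entry point that actually consumes it is not shown in this diff).

import yaml

with open("configs/default_config.yaml") as fh:
    cfg = yaml.safe_load(fh)

# Values reflect the new defaults shown above.
effective_batch = cfg["batch_size"] * cfg["gradient_accumulation_steps"]  # 2 * 1 = 2
print(effective_batch, cfg["learning_rate"], cfg["dtype"])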