In [None]:
!pip install lightning

In [None]:
!git clone https://github.com/state-spaces/mamba.git

In [None]:
!pip install Datasets

In [None]:
!pip show pytorch_lightning

In [None]:
!cd mamba && pip install .

In [1]:
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning.demos import Transformer, WikiText2
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks import DeviceStatsMonitor

import pytorch_lightning as pl

# Initialize the model, define the training step, and configure the optimizer.
class LanguageModel(pl.LightningModule):
    """Lightning module that trains a ``MambaLMHeadModel`` with a token-level
    cross-entropy objective.

    Args:
        vocab_size: Vocabulary size forwarded to ``MambaLMHeadModel``.
        n_layer: Number of layers forwarded to ``MambaLMHeadModel``
            (default 16, the value that used to be hard-coded).
        d_model: Model width forwarded to ``MambaLMHeadModel``
            (default 1024, the value that used to be hard-coded).
        lr: Learning rate handed to Adam. Defaults to 1e-3 (Adam's own
            default); the previous hard-coded 0.1 is an SGD-scale step size
            and is far too large for Adam. Pass ``lr=0.1`` to reproduce the
            old setting.
    """

    def __init__(self, vocab_size, n_layer=16, d_model=1024, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.model = MambaLMHeadModel(
            vocab_size=vocab_size,
            n_layer=n_layer,
            d_model=d_model,
        )

    def training_step(self, batch):
        """Compute and log the cross-entropy loss for one batch.

        Args:
            batch: An ``(inputs, target)`` pair of token-id tensors.

        Returns:
            The scalar training loss, also logged as ``train_loss``.
        """
        # `tokens` rather than `input`, to avoid shadowing the builtin.
        tokens, target = batch

        # Keep the raw logits: cross_entropy expects unnormalised scores,
        # so no softmax/argmax is applied here.
        logits = self.model(tokens).logits

        # Collapse every leading dimension so that each row of `logits` lines
        # up with one entry of `target`. `reshape` (unlike `view`) also copes
        # with non-contiguous tensors.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            target.reshape(-1),
        )

        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# Fix the random seed (42) through Lightning's helper so that runs are repeatable.
pl.seed_everything(42)





In [None]:
# Run configuration (tune here rather than inside the calls below).
BATCH_SIZE = 128
# Lightning silently falls back to 1000 epochs (and warns) when `max_epochs`
# is omitted; state the limit explicitly. Lower it for quick experiments.
MAX_EPOCHS = 1000

# Data
dataset = WikiText2()  # placeholder corpus; the project's target data is virus DNA
train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE)

# Model
model = LanguageModel(vocab_size=dataset.vocab_size)

# Trainer
trainer = pl.Trainer(accelerator="cuda", devices=1, max_epochs=MAX_EPOCHS)

# The peak-memory counter is cumulative for the whole process; reset it so the
# figure printed below reflects this run only, even when the cell is re-executed.
torch.cuda.reset_peak_memory_stats()

trainer.fit(model, train_dataloader)
trainer.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type             | Params
-------------------------------------------
0 | model | MambaLMHeadModel | 140 M 
-------------------------------------------
140 M     Trainable params
0         Non-trainable params
140 M     Total params
563.085   Total estimated model params size (MB)
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/train

Training: |          | 0/? [00:00<?, ?it/s]

In [None]:
# from datasets import load_dataset
# dataset = load_dataset('Hack90/virus_dna_dedup_minihash_0.9_kmer_7')
# virus = ''
# for x in list(dataset['train']['sequence_x']):
#   virus = virus +  x + '\n'
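# # Insert a separator after each chunk, using a loop.
# # NOTE(review): the loop below advances 10 characters per step but keeps only
# # 2 of them, so most of the sequence is dropped - confirm the intended stride.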
# separator = " "
# res = ""
# for i in range(0, len(virus), 10):
#     res += virus[i:i + 2] + separator


In [None]:
# class WikiText2(Dataset):
#     """WikiText2-style block dataset that tokenizes /content/virus.txt instead of WikiText2."""

#     def __init__(self,  block_size: int = 35, download: bool = False) -> None:
#         super().__init__()
#         self.path = Path("/content/virus.txt")
#         self.data, self.dictionary = tokenize(self.path)
#         self.block_size = block_size

#     @property
#     def vocab_size(self) -> int:
#         return len(self.dictionary)

#     def __len__(self) -> int:
#         return len(self.data) // self.block_size - 1

#     def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
#         start = index * self.block_size
#         end = start + self.block_size
#         inputs = self.data[start:end]
#         target = self.data[(start + 1) : (end + 1)]
#         return inputs, target

#     @staticmethod
#     def download(destination: Path) -> None:
#         os.makedirs(destination.parent, exist_ok=True)
#         url = "https://raw.githubusercontent.com/pytorch/examples/main/word_language_model/data/wikitext-2/train.txt"
#         if os.path.exists(destination):
#             return
#         with open(destination, "w") as f:
#             f.write(requests.get(url).text)

In [None]:
# dataset = WikiText2()