In [None]:
!pip install lightning torch  ninja einops triton transformers causal_conv1d>=1.1.0 wandb Datasets mamba-ssm

In [None]:
# Run the Weights & Biases CLI login through a shell escape.
!wandb login
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

# Logger instance that the training cell hands to the Trainer via `logger=`.
# NOTE(review): log_model="all" presumably uploads every checkpoint as it is
# written — confirm against the WandbLogger documentation.
wandb_logger = WandbLogger(log_model="all")

In [None]:
from typing import Dict, List, Tuple
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.callbacks import DeviceStatsMonitor
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel



In [None]:
# Keyword configuration for the Mamba language model built by LanguageModel.
# NOTE(review): vocab_size is hard-coded; it presumably counts the four
# nucleotide characters plus "<eos>" — confirm against the tokenized dataset.
tiny_model = {
    "d_model": 768,
    "n_layer": 24,
    "vocab_size": 5,
    "ssm_cfg": {},
    # Python booleans are capitalised; the JSON-style `true` raises NameError.
    "rms_norm": True,
    "residual_in_fp32": True,
    "fused_add_norm": True,
    "pad_vocab_size_multiple": 8,
}

In [None]:
def tokenize(dataset: Dataset) -> Tuple[Tensor, "Dictionary"]:
    """Character-tokenize every training sequence into one flat id tensor.

    Each entry of ``dataset['train']['sequence']`` is split into its individual
    characters and terminated with an ``"<eos>"`` token. Ids are assigned in
    order of first appearance across the whole split.

    Args:
        dataset: Object indexable as ``dataset['train']['sequence']``, yielding
            an iterable of strings. NOTE(review): annotated as ``Dataset``, but
            only this indexing is relied on — confirm the intended type.

    Returns:
        ``(ids, dictionary)``: a 1-D ``torch.int64`` tensor with all sequences
        concatenated, and the ``Dictionary`` mapping tokens to ids. The tensor
        is empty when the split contains no sequences.
    """
    dictionary = Dictionary()
    # Read the column once; the vocabulary and the ids are built in the same pass.
    sequences = dataset['train']['sequence']

    chunks: List[Tensor] = []
    for sequence in sequences:
        # add_word returns the (possibly newly assigned) id of the token.
        ids = [dictionary.add_word(token) for token in sequence]
        ids.append(dictionary.add_word("<eos>"))
        chunks.append(torch.tensor(ids, dtype=torch.int64))

    if not chunks:
        # torch.cat rejects an empty list; return an empty id tensor instead.
        return torch.empty(0, dtype=torch.int64), dictionary
    return torch.cat(chunks), dictionary

class Dictionary:
    """Two-way vocabulary: token string -> integer id, and id -> token string."""

    def __init__(self) -> None:
        # A token's position in ``idx2word`` is its id; ``word2idx`` is the inverse map.
        self.idx2word: List[str] = []
        self.word2idx: Dict[str, int] = {}

    def add_word(self, word: str) -> int:
        """Return the id of ``word``, registering it under the next free id if unseen."""
        known_id = self.word2idx.get(word)
        if known_id is not None:
            return known_id
        new_id = len(self.idx2word)
        self.idx2word.append(word)
        self.word2idx[word] = new_id
        return new_id

    def __len__(self) -> int:
        """Number of distinct tokens registered so far."""
        return len(self.idx2word)
    
class VirusDataset(Dataset):
    """Map-style dataset of fixed-length token blocks cut from tokenized sequences.

    The named dataset is loaded with ``load_dataset``, passed through
    ``tokenize`` to obtain one flat id tensor plus its vocabulary, and then
    exposed as consecutive, non-overlapping blocks of ``block_size`` ids.
    A trailing remainder shorter than ``block_size`` is not exposed.

    Args:
        dataset_name: Identifier passed to ``load_dataset``.
        block_size: Number of token ids per item; must be positive.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """

    def __init__(self, dataset_name:str,  block_size: int = 1000) -> None:
        super().__init__()
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")
        # Imported here because nothing else in the notebook brings
        # `load_dataset` into scope; without it this line raises NameError.
        from datasets import load_dataset

        self.dataset = load_dataset(dataset_name)
        self.data, self.dictionary = tokenize(self.dataset)
        self.block_size = block_size

    @property
    def vocab_size(self) -> int:
        """Number of distinct tokens in the dataset's vocabulary."""
        return len(self.dictionary)

    def __len__(self) -> int:
        """Number of complete blocks available."""
        # Items are plain blocks with no look-ahead target, so every full block
        # is usable; floor division also keeps the result non-negative.
        return len(self.data) // self.block_size

    def __getitem__(self, index: int) -> Tensor:
        """Return block ``index`` as a 1-D tensor of ``block_size`` token ids.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
        """
        if not 0 <= index < len(self):
            # Slicing past the end would silently yield a short or empty tensor.
            raise IndexError(f"block index {index} out of range for {len(self)} blocks")
        start = index * self.block_size
        return self.data[start:start + self.block_size]



In [None]:


class LanguageModel(pl.LightningModule):
    """Lightning module that trains a Mamba LM-head model with a next-token loss.

    Args:
        vocab_size: Optional override for ``tiny_model["vocab_size"]``. When
            ``None`` the configured value is used unchanged.
        lr: Learning rate passed to Adam.
            NOTE(review): the 0.1 default is kept from the original code but is
            unusually high for Adam — confirm it is intentional.
    """

    def __init__(self, vocab_size=None, lr=0.1):
        super().__init__()
        # Work on a copy so the module-level config dict is never mutated.
        config = dict(tiny_model)
        if vocab_size is not None:
            config["vocab_size"] = vocab_size
        self.lr = lr

        try:
            from mamba_ssm.models.config_mamba import MambaConfig
        except ImportError:
            # Releases without MambaConfig take the fields as keyword arguments.
            self.model = MambaLMHeadModel(**config)
        else:
            # Releases that ship MambaConfig expect a single config object.
            self.model = MambaLMHeadModel(MambaConfig(**config))

    def training_step(self, batch, batch_idx=None):
        """Compute and log the cross-entropy loss for one batch of token ids.

        ``batch`` is indexed as ``[:, 1:]``, i.e. a 2-D tensor of ids; the same
        tensor is used as both input and (shifted) target.
        """
        tokens = batch
        lm_logits = self.model(tokens).logits
        # Position t predicts token t+1: drop the last logit and the first label.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = tokens[:, 1:].to(lm_logits.device).contiguous()

        loss = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
        )

        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# Fix the global random seed (42) before the dataset and model are created in the next cell.
pl.seed_everything(42)


In [None]:
# Data
dataset  = VirusDataset(dataset_name= 'Hack90/chikungunya')
train_dataloader = DataLoader(dataset, batch_size=512, num_workers=7)

# Model
# LanguageModel.__init__ declares a vocab_size parameter; calling it with no
# arguments raises TypeError, so pass the size of the tokenized vocabulary.
model = LanguageModel(vocab_size=dataset.vocab_size)

# Trainer
trainer = pl.Trainer(accelerator="cuda", devices=1, logger=wandb_logger)
trainer.fit(model, train_dataloader)
trainer.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")