In [3]:
import sys
import os

sys.path.append('../src')

In [28]:
import argparse
import torch
import pytorch_warmup as warmup
import wandb
from tqdm import tqdm
import yaml
import sys
import os

import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np

from mamba_ssm.models.config_mamba import MambaConfig

from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding
from torch.utils.data import DataLoader

from utils import print_model_size, fix_seed
from models.MambaWithEmbeddings import MambaLMHeadModelWithEmbeddings
from train_yelp_reviews import add_special_token

In [45]:
dataset = load_dataset("yelp_polarity")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

if tokenizer.pad_token is None:
    # tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

def tokenize_function(examples):
    return tokenizer(examples["text"], padding=True, truncation=True, max_length=512)

tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

train_dataset = tokenized_datasets["train"]
test_dataset = tokenized_datasets["test"]

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=data_collator)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=data_collator)

In [46]:
seed = 42
model_name = "state-spaces/mamba-130m"

In [47]:
fix_seed(seed)
model = MambaLMHeadModelWithEmbeddings.from_pretrained(model_name, num_labels=2)

In [48]:
model.freeze_layers()

In [49]:
model.backbone.layers[-1].mixer.out_proj.weight

Parameter containing:
tensor([[ 0.1919, -0.0347,  0.0941,  ..., -0.1357,  0.1383, -0.0218],
        [ 0.0317,  0.0470, -0.2262,  ..., -0.1767,  0.0739,  0.0479],
        [-0.0528,  0.1756,  0.1555,  ...,  0.2199, -0.1916, -0.0371],
        ...,
        [-0.1702, -0.0906, -0.0830,  ..., -0.0135, -0.0842,  0.0020],
        [ 0.0432,  0.0286, -0.1470,  ...,  0.0472,  0.0670, -0.1609],
        [ 0.1678,  0.0724,  0.0571,  ..., -0.1032, -0.0364,  0.0401]])

In [50]:
model.classification_head.weight

Parameter containing:
tensor([[ 0.0009,  0.0087,  0.0269,  ...,  0.0269, -0.0181,  0.0178],
        [-0.0118,  0.0232,  0.0009,  ...,  0.0321, -0.0164, -0.0025]],
       requires_grad=True)

In [51]:
gpu_number = 0
device = torch.device(f'cuda:{gpu_number}' if torch.cuda.is_available() else 'cpu')

In [52]:
num_epochs = 3
learning_rate = 1
tokens_num = 100
period = 50
warmup_percent = 0.05

In [55]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

special_token = torch.randn(1, tokens_num, model.config.d_model, requires_grad=True, device=device)
model.to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()

optimizer = optim.AdamW(
    [{'params': [special_token], 'lr': learning_rate},
     {'params': model.parameters(), 'lr': learning_rate}],
)

total_steps = num_epochs * len(train_dataloader)
warmup_steps = int(warmup_percent * total_steps)  # Calculate warmup steps as a percentage of total steps
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=5e-6)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_steps)

fix_seed(seed)

i = 0

for epoch in range(num_epochs):
    model.train()  # Set model to evaluation mode
    train_loss = 0.0
    correct_train = 0
    total_train = 0

    pbar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} - Training")
    for batch in pbar:
        i += 1

        if i == 10:
            break
        batch_start_time = time.time()
        torch.cuda.synchronize()
        inputs, labels = batch['input_ids'], batch['labels']
        
        # Move data to the specified device
        start = time.time()
        torch.cuda.synchronize()
        inputs, labels = inputs.to(device), labels.to(device)
        torch.cuda.synchronize()
        print(f"Data Transfer: {time.time() - start:.4f}s")

        # Convert inputs to embeddings without tracking gradients
        start = time.time()
        torch.cuda.synchronize()
        with torch.no_grad():
            embedded_inputs = model.backbone.embedding(inputs)
        torch.cuda.synchronize()
        print(f"Embedding Conversion: {time.time() - start:.4f}s")

        # Add special token
        start = time.time()
        torch.cuda.synchronize()
        embedded_with_special = add_special_token(embedded_inputs, special_token, period, tokens_num)
        torch.cuda.synchronize()
        print(f"Add Special Token: {time.time() - start:.4f}s")
        
        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        start = time.time()
        torch.cuda.synchronize()
        outputs = model(embedded_with_special, is_embeds=True, num_last_tokens=1)
        torch.cuda.synchronize()
        print(f"Forward Pass: {time.time() - start:.4f}s")
        
        logits = outputs.logits[:, 0, :]
        loss = criterion(logits, labels)

        # Backward pass and optimization
        start = time.time()
        torch.cuda.synchronize()
        loss.backward()
        optimizer.step()
        torch.cuda.synchronize()
        print(f"Backward Pass & Optimization: {time.time() - start:.4f}s")
        
        # Apply warmup and scheduler updates
        start = time.time()
        torch.cuda.synchronize()
        with warmup_scheduler.dampening():
            scheduler.step()
        torch.cuda.synchronize()
        print(f"Scheduler Update: {time.time() - start:.4f}s")

        # Accumulate training loss and accuracy
        train_loss += loss.item()
        _, predicted = logits.max(1)
        total_train += labels.size(0)
        correct_train += predicted.eq(labels).sum().item()

        # Calculate local train accuracy
        train_accuracy_local = 100 * correct_train / total_train
        
        # Display the current loss for each training batch
        pbar.set_postfix({"Train Loss (batch)": loss.item()})

        torch.cuda.synchronize()
        print(f"Total Batch Time: {time.time() - batch_start_time:.4f}s")
    break

    # ---- Epoch-Based Logging ----
    train_accuracy_epoch = 100 * correct_train / total_train

    # ---- Validation phase after each epoch ----
    val_batch_losses_local, val_accuracy = inference(
        model, val_loader, device, criterion=criterion, num_last_tokens=1, special_token=special_token, period=period
    )
    val_loss = sum(val_batch_losses_local) / len(val_batch_losses_local)


Epoch 1/3 - Training:   0%|                                                                                            | 0/35000 [00:00<?, ?it/s]

Data Transfer: 0.0002s
Embedding Conversion: 0.0004s
Add Special Token: 0.0007s
Forward Pass: 0.4747s


Epoch 1/3 - Training:   0%|                                                        | 1/35000 [00:01<10:50:33,  1.12s/it, Train Loss (batch)=1.17]

Backward Pass & Optimization: 0.5788s
Scheduler Update: 0.0001s
Total Batch Time: 1.0602s
Data Transfer: 0.0001s
Embedding Conversion: 0.0002s
Add Special Token: 0.0005s
Forward Pass: 0.3824s


Epoch 1/3 - Training:   0%|                                                        | 2/35000 [00:02<10:01:09,  1.03s/it, Train Loss (batch)=1.48]

Backward Pass & Optimization: 0.5732s
Scheduler Update: 0.0001s
Total Batch Time: 0.9592s
Data Transfer: 0.0001s
Embedding Conversion: 0.0002s
Add Special Token: 0.0004s
Forward Pass: 0.3825s


Epoch 1/3 - Training:   0%|                                                         | 3/35000 [00:03<9:45:24,  1.00s/it, Train Loss (batch)=1.32]

Backward Pass & Optimization: 0.5745s
Scheduler Update: 0.0001s
Total Batch Time: 0.9602s
Data Transfer: 0.0001s
Embedding Conversion: 0.0001s
Add Special Token: 0.0004s
Forward Pass: 0.3825s


Epoch 1/3 - Training:   0%|                                                         | 4/35000 [00:04<9:38:16,  1.01it/s, Train Loss (batch)=1.11]

Backward Pass & Optimization: 0.5757s
Scheduler Update: 0.0001s
Total Batch Time: 0.9612s
Data Transfer: 0.0001s
Embedding Conversion: 0.0001s
Add Special Token: 0.0004s
Forward Pass: 0.3811s


Epoch 1/3 - Training:   0%|                                                         | 5/35000 [00:05<9:34:43,  1.01it/s, Train Loss (batch)=1.04]

Backward Pass & Optimization: 0.5763s
Scheduler Update: 0.0001s
Total Batch Time: 0.9605s
Data Transfer: 0.0001s
Embedding Conversion: 0.0001s
Add Special Token: 0.0004s
Forward Pass: 0.3849s


Epoch 1/3 - Training:   0%|                                                         | 6/35000 [00:05<9:32:28,  1.02it/s, Train Loss (batch)=0.95]

Backward Pass & Optimization: 0.5748s
Scheduler Update: 0.0001s
Total Batch Time: 0.9628s
Data Transfer: 0.0001s
Embedding Conversion: 0.0001s
Add Special Token: 0.0004s
Forward Pass: 0.3831s


Epoch 1/3 - Training:   0%|                                                        | 7/35000 [00:06<9:30:58,  1.02it/s, Train Loss (batch)=0.892]

Backward Pass & Optimization: 0.5759s
Scheduler Update: 0.0001s
Total Batch Time: 0.9625s
Data Transfer: 0.0001s
Embedding Conversion: 0.0002s
Add Special Token: 0.0004s
Forward Pass: 0.3843s


Epoch 1/3 - Training:   0%|                                                        | 8/35000 [00:07<9:30:15,  1.02it/s, Train Loss (batch)=0.776]

Backward Pass & Optimization: 0.5760s
Scheduler Update: 0.0001s
Total Batch Time: 0.9635s
Data Transfer: 0.0001s
Embedding Conversion: 0.0001s
Add Special Token: 0.0004s
Forward Pass: 0.3839s


Epoch 1/3 - Training:   0%|                                                         | 9/35000 [00:08<9:38:08,  1.01it/s, Train Loss (batch)=1.27]

Backward Pass & Optimization: 0.5764s
Scheduler Update: 0.0001s
Total Batch Time: 0.9633s





In [57]:
special_token.size()

torch.Size([1, 100, 768])

In [58]:
print_model_size(model)

Model parameters number: 129136898
Model size: 492.62 MB
Model size: 0.48 GB


In [62]:
print_model_size(model.classification_head)

Model parameters number: 1538
Model size: 0.01 MB
Model size: 0.00 GB


In [66]:
len(train_dataset)

560000

In [67]:
len(test_dataset)

38000