In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from training_utils import classification_training, contrastive_training
from models import MambaPooled, MambaCLS, CrossAttentionTransformer

In [None]:
class ReviewDataset(Dataset):
    """Map-style dataset over three parallel, index-aligned sequences.

    Item ``i`` is the tuple ``(contexts[i], inputs[i], targets[i])``.

    Args:
        contexts: Indexable, sized collection of per-example contexts.
        inputs: Indexable, sized collection of per-example inputs.
        targets: Indexable, sized collection of per-example targets.

    Raises:
        AssertionError: If the three collections differ in length.
    """

    def __init__(self, contexts, inputs, targets):
        # All three collections must line up one-to-one.
        assert len(contexts) == len(inputs) == len(targets)
        self.contexts, self.inputs, self.targets = contexts, inputs, targets

    def __len__(self):
        """Return the number of examples (the length of ``targets``)."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Return ``(context, input, target)`` for position ``idx``."""
        return self.contexts[idx], self.inputs[idx], self.targets[idx]

In [None]:
# Three placeholder examples; contexts, inputs and targets are index-aligned.
contexts = [f"context {i}" for i in range(1, 4)]
inputs = [f"input {i}" for i in range(1, 4)]
targets = [0, 1, 0]

# Wrap the aligned lists in a dataset and serve them in shuffled batches of two.
dataset = ReviewDataset(contexts, inputs, targets)
train_loader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)

In [2]:
# Hyperparameters for the pooled Mamba model. They stay module-level so that
# later cells can build matching inputs from the same values.
d_input, d_context = 512, 512
d_model = 256
num_layers = 3

# The remaining settings (d_state, d_discr, ker_size, parallel) are passed as
# literals; their meaning is defined by models.MambaPooled.
mamba_pooled = MambaPooled(
    num_layers=num_layers,
    d_input=d_input,
    d_context=d_context,
    d_model=d_model,
    d_state=16,
    d_discr=None,
    ker_size=4,
    parallel=True,
)

In [3]:
# Smoke test for `mamba_pooled`: run one random batch through the model and
# print the shape of each output.
batch_size = 2
seq_len = 10
# `d_input` and `d_context` are reused from the cell that built `mamba_pooled`
# (rather than re-declared here) so the random batch always matches the widths
# the model was configured with.

# Random inputs
x = torch.randn(batch_size, seq_len, d_input)    # input sequence: (batch_size, seq_len, d_input)
context = torch.randn(batch_size, d_context)     # context vector: (batch_size, d_context)

# Forward passes. A shape check needs no gradients, so avoid building the
# autograd graph.
with torch.no_grad():
    logits = mamba_pooled(x, context)            # classifier output
    # Third positional argument True requests the embeddings instead of the
    # classifier output (the flag's parameter name lives in models.py).
    embeddings = mamba_pooled(x, context, True)

# Shapes recorded from the last run of this cell:
#   logits     -> torch.Size([2, 1])    i.e. (batch_size, 1)
#   embeddings -> torch.Size([2, 512])  i.e. (batch_size, 512)
print("Logits shape:", logits.shape)
print("Embeddings shape:", embeddings.shape)

Logits shape: torch.Size([2, 1])
Embeddings shape: torch.Size([2, 512])


In [4]:
# Hyperparameters for the CLS-style Mamba model. They stay module-level so
# that later cells can build matching inputs from the same values.
d_input = d_context = 512
d_model = 256
num_layers = 3

# The remaining settings (d_state, d_discr, ker_size, parallel) are passed as
# literals; their meaning is defined by models.MambaCLS.
mamba_cls = MambaCLS(
    num_layers=num_layers,
    d_input=d_input,
    d_model=d_model,
    d_context=d_context,
    d_state=16,
    d_discr=None,
    ker_size=4,
    parallel=True,
)

In [5]:
# Smoke test for `mamba_cls`: run one random batch through the model and
# print the shape of each output.
batch_size = 2
seq_len = 10
# `d_input` and `d_context` are reused from the cell that built `mamba_cls`
# (rather than re-declared here) so the random batch always matches the widths
# the model was configured with.

# Random inputs
x = torch.randn(batch_size, seq_len, d_input)    # input sequence: (batch_size, seq_len, d_input)
context = torch.randn(batch_size, d_context)     # context vector: (batch_size, d_context)

# Forward passes. A shape check needs no gradients, so avoid building the
# autograd graph.
with torch.no_grad():
    logits = mamba_cls(x, context)               # classifier output
    # Third positional argument True requests the embeddings instead of the
    # classifier output (the flag's parameter name lives in models.py).
    embeddings = mamba_cls(x, context, True)

# Shapes recorded from the last run of this cell:
#   logits     -> torch.Size([2, 1])    i.e. (batch_size, 1)
#   embeddings -> torch.Size([2, 512])  i.e. (batch_size, 512)
print("Logits shape:", logits.shape)
print("Embeddings shape:", embeddings.shape)

Logits shape: torch.Size([2, 1])
Embeddings shape: torch.Size([2, 512])


In [2]:
# Hyperparameters for the cross-attention transformer. They stay module-level
# so that later cells can build matching inputs from the same values.
d_input, d_context = 512, 512
d_model = 256
d_layers, d_heads = 3, 8
dropout = 0.1

# Every setting is forwarded by keyword; see models.CrossAttentionTransformer
# for what each one controls.
cross_attn_model = CrossAttentionTransformer(
    d_model=d_model,
    d_input=d_input,
    d_context=d_context,
    d_layers=d_layers,
    d_heads=d_heads,
    dropout=dropout,
)

In [4]:
# Smoke test for `cross_attn_model`: run one random batch through the model
# and print the shape of each output.
batch_size = 2
seq_len = 10
# `d_input` and `d_context` are reused from the cell that built
# `cross_attn_model` (rather than re-declared here) so the random batch always
# matches the widths the model was configured with.

# Random inputs
x = torch.randn(batch_size, seq_len, d_input)    # input sequence: (batch_size, seq_len, d_input)
context = torch.randn(batch_size, d_context)     # context vector: (batch_size, d_context)

# Forward passes. A shape check needs no gradients, so avoid building the
# autograd graph.
with torch.no_grad():
    logits = cross_attn_model(x, context)        # classifier output
    # Third positional argument True requests the embeddings instead of the
    # classifier output (the flag's parameter name lives in models.py).
    embeddings = cross_attn_model(x, context, True)

# Shapes recorded from the last run of this cell (outputs are per position):
#   logits     -> torch.Size([2, 10, 1])    i.e. (batch_size, seq_len, 1)
#   embeddings -> torch.Size([2, 10, 256])  i.e. (batch_size, seq_len, d_model)
print("Logits shape:", logits.shape)
print("Embeddings shape:", embeddings.shape)

Logits shape: torch.Size([2, 10, 1])
Embeddings shape: torch.Size([2, 10, 256])
