In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from training_utils import classification_training, contrastive_training
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, confusion_matrix
from models import MambaPooled, MambaCLS, CrossAttentionTransformer, NomicEmbedder
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
import numpy as np
import torch.nn as nn
from tqdm import tqdm
from datetime import datetime
import glob
import os

seed = 42

In [2]:
training_ratio = 0.7
validation_ratio = 0.2
test_ratio = 0.1

vectors_df = pd.read_pickle(r"C:\Users\ian\Desktop\Coding\ReviewClassification\model\data\processed\vectorised_locations.pkl")
vectors_map = dict(zip(vectors_df["gmap_id"], vectors_df["vector"]))

positive_reviews = pd.read_csv(r"C:\Users\ian\Desktop\Coding\ReviewClassification\model\data\processed\positive_reviews.csv")       # label 1
ads_reviews = pd.read_csv(r"C:\Users\ian\Desktop\Coding\ReviewClassification\model\data\processed\ads_reviews.csv")                 # label 0
indirect_reviews = pd.read_csv(r"C:\Users\ian\Desktop\Coding\ReviewClassification\model\data\processed\indirect_reviews.csv")       # label 0
irelevant_reviews = pd.read_csv(r"C:\Users\ian\Desktop\Coding\ReviewClassification\model\data\processed\irelevant_reviews.csv")     # label 0

positive_reviews["label"] = 1
ads_reviews["label"] = 0
indirect_reviews["label"] = 0
irelevant_reviews["label"] = 0

all_reviews = pd.concat([positive_reviews, ads_reviews, indirect_reviews, irelevant_reviews], ignore_index=True)
all_reviews["context_vector"] = all_reviews["gmap_id"].map(vectors_map)
all_reviews = all_reviews.dropna(subset=["context_vector"]).reset_index(drop=True)

gmap_id_labels = all_reviews.groupby('gmap_id')['label'].agg(lambda x: x.mode()[0])
gmap_id_labels = gmap_id_labels.reset_index(name='majority_label')

all_gmap_ids = gmap_id_labels['gmap_id'].values
majority_labels = gmap_id_labels['majority_label'].values

train_ids, temp_ids, _, temp_labels = train_test_split(
    all_gmap_ids, majority_labels,
    test_size=validation_ratio + test_ratio,
    stratify=majority_labels,
    random_state=seed
)

val_ids, test_ids, _, _ = train_test_split(
    temp_ids, temp_labels,
    test_size=test_ratio / (validation_ratio + test_ratio),
    stratify=temp_labels,
    random_state=seed
)

train_df = all_reviews[all_reviews["gmap_id"].isin(train_ids)]
val_df = all_reviews[all_reviews["gmap_id"].isin(val_ids)]
test_df = all_reviews[all_reviews["gmap_id"].isin(test_ids)]


In [3]:
# Verification prints
print("--- Final Dataset Sizes ---")
print(f"Training set size: {len(train_df)}")
print(f"Validation set size: {len(val_df)}")
print(f"Test set size: {len(test_df)}")

print("\n--- Final Class Distribution by Review (Correct) ---")
print(f"Training set distribution:\n{train_df['label'].value_counts(normalize=True)}")
print(f"Validation set distribution:\n{val_df['label'].value_counts(normalize=True)}")
print(f"Test set distribution:\n{test_df['label'].value_counts(normalize=True)}")

train_ids_set = set(train_ids)
val_ids_set = set(val_ids)
test_ids_set = set(test_ids)

# Check for overlaps between the ID sets
assert len(train_ids_set.intersection(val_ids_set)) == 0, "gmap_id leakage between training and validation sets!"
assert len(train_ids_set.intersection(test_ids_set)) == 0, "gmap_id leakage between training and test sets!"
assert len(val_ids_set.intersection(test_ids_set)) == 0, "gmap_id leakage between validation and test sets!"

print("Congratulations! The gmap_id lists are mutually exclusive.")

train_indices = set(train_df.index)
val_indices = set(val_df.index)
test_indices = set(test_df.index)

# Check for overlaps between the sets
assert len(train_indices.intersection(val_indices)) == 0, "Data leakage found between training and validation sets!"
assert len(train_indices.intersection(test_indices)) == 0, "Data leakage found between training and test sets!"
assert len(val_indices.intersection(test_indices)) == 0, "Data leakage found between validation and test sets!"

print("Congratulations! All data leakage checks passed.")
print("Your datasets are now correctly partitioned by gmap_id.")

--- Final Dataset Sizes ---
Training set size: 36441
Validation set size: 11166
Test set size: 4765

--- Final Class Distribution by Review (Correct) ---
Training set distribution:
label
1    0.511731
0    0.488269
Name: proportion, dtype: float64
Validation set distribution:
label
1    0.549973
0    0.450027
Name: proportion, dtype: float64
Test set distribution:
label
0    0.573347
1    0.426653
Name: proportion, dtype: float64
Congratulations! The gmap_id lists are mutually exclusive.
Congratulations! All data leakage checks passed.
Your datasets are now correctly partitioned by gmap_id.


In [4]:
print("--- Data Consistency and Label Check ---")
print(f"Total number of reviews: {len(all_reviews)}")
pos_count = all_reviews[all_reviews['label'] == 1].shape[0]
neg_count = all_reviews[all_reviews['label'] == 0].shape[0]

print(f"Number of 'positive' reviews (label 1): {pos_count}")
print(f"Number of 'other' reviews (label 0): {neg_count}")
assert pos_count == len(positive_reviews), "Positive label count does not match original file size."
assert neg_count == len(ads_reviews) + len(indirect_reviews) + len(irelevant_reviews), "Negative label count does not match original files size."

print("\n--- Class Distribution Check ---")
label_counts = all_reviews['label'].value_counts()
print(f"Overall Label Distribution:\n{label_counts}")
print(f"Overall Label Ratio:\n{all_reviews['label'].value_counts(normalize=True)}")

unexpected_labels = all_reviews['label'].isin([0, 1])
if not unexpected_labels.all():
    print("Warning: Found labels other than 0 or 1.")

print("All initial data consistency and label checks passed.")

--- Data Consistency and Label Check ---
Total number of reviews: 52372
Number of 'positive' reviews (label 1): 26822
Number of 'other' reviews (label 0): 25550

--- Class Distribution Check ---
Overall Label Distribution:
label
1    26822
0    25550
Name: count, dtype: int64
Overall Label Ratio:
label
1    0.512144
0    0.487856
Name: proportion, dtype: float64
All initial data consistency and label checks passed.


In [5]:
class ReviewDataset(Dataset):
    def __init__(self, contexts, inputs, targets):
        assert len(contexts) == len(inputs) == len(targets)
        self.contexts = contexts
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        context = torch.tensor(self.contexts[idx], dtype=torch.float)
        inp = self.inputs[idx]
        target = torch.tensor(self.targets[idx], dtype=torch.float)
        return context, inp, target

In [6]:
batch_size = 96

train_dataset = ReviewDataset(
    contexts=train_df["context_vector"].tolist(),
    inputs=train_df["text"].tolist(),
    targets=train_df["label"].tolist()
)

val_dataset = ReviewDataset(
    contexts=val_df["context_vector"].tolist(),
    inputs=val_df["text"].tolist(),
    targets=val_df["label"].tolist()
)

test_dataset = ReviewDataset(
    contexts=test_df["context_vector"].tolist(),
    inputs=test_df["text"].tolist(),
    targets=test_df["label"].tolist()
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [7]:
print("\n--- Data Leakage Check ---")
# Use the index of the original DataFrame to check for overlaps
train_indices = set(train_df.index)
val_indices = set(val_df.index)
test_indices = set(test_df.index)

print("No data leakage detected in the train/val/test split.")

# --- Split Ratio and Class Distribution Check ---

print("\n--- Split Ratio and Class Distribution Check ---")
total_size = len(all_reviews)
print(f"Total dataset size: {total_size}")
print(f"Training set size: {len(train_df)} ({len(train_df)/total_size:.2%})")
print(f"Validation set size: {len(val_df)} ({len(val_df)/total_size:.2%})")
print(f"Test set size: {len(test_df)} ({len(test_df)/total_size:.2%})")

print("\nTraining set distribution:\n", train_df['label'].value_counts(normalize=True))
print("Validation set distribution:\n", val_df['label'].value_counts(normalize=True))
print("Test set distribution:\n", test_df['label'].value_counts(normalize=True))

print("\n--- DataLoader Sanity Check ---")

# Inspect a single batch from the training loader
try:
    context_batch, input_batch, target_batch = next(iter(train_loader))
    
    print(f"Batch loaded successfully from train_loader.")
    print(f"Contexts batch shape: {context_batch.shape}")
    print(f"Inputs batch (first 5): {input_batch[:5]}")
    print(f"Targets batch shape: {target_batch.shape}")
    print(f"Target values (first 5): {target_batch[:5].tolist()}")
    
    assert context_batch.shape[0] == len(input_batch) == target_batch.shape[0], "Batch size mismatch."
    assert context_batch.dtype == torch.float, "Context tensor has incorrect dtype."
    assert target_batch.dtype == torch.float, "Target tensor has incorrect dtype."
    
except Exception as e:
    print(f"Error loading a batch from train_loader: {e}")

# You can do the same for val_loader and test_loader if you wish.
print("\nAll DataLoader sanity checks passed.")

def get_label_distribution(data_loader: DataLoader, dataset_name: str):
    """
    Iterates through a DataLoader and calculates the distribution of positive and negative labels.

    Args:
        data_loader (DataLoader): The DataLoader object to iterate over.
        dataset_name (str): The name of the dataset (e.g., 'Training', 'Validation').
    """
    print(f"\n--- Checking Label Distribution for {dataset_name} Set ---")
    positive_count = 0
    negative_count = 0
    total_count = 0

    # Ensure the loader is not empty before starting
    if len(data_loader.dataset) == 0:
        print(f"The {dataset_name} dataset is empty.")
        return

    # Use no_grad to improve efficiency, as we are only reading data
    with torch.no_grad():
        for _, _, targets in data_loader:
            targets = targets.cpu().numpy()
            
            positive_count += (targets == 1).sum()
            negative_count += (targets == 0).sum()
            total_count += targets.shape[0]

    # Verify that the total count matches the dataset size
    assert total_count == len(data_loader.dataset), \
           f"Count mismatch! {dataset_name} counted {total_count} samples, but dataset size is {len(data_loader.dataset)}"

    # Calculate percentages
    if total_count > 0:
        positive_percentage = (positive_count / total_count) * 100
        negative_percentage = (negative_count / total_count) * 100
    else:
        positive_percentage = 0
        negative_percentage = 0
        
    print(f"Total samples: {total_count}")
    print(f"Positive labels (1): {positive_count} ({positive_percentage:.2f}%)")
    print(f"Negative labels (0): {negative_count} ({negative_percentage:.2f}%)")


# Now, call the function for each of your DataLoaders
# NOTE: You must run this AFTER the code that defines train_loader, val_loader, and test_loader.
# If those objects are not yet defined, this script will fail.

get_label_distribution(train_loader, "Training")
get_label_distribution(val_loader, "Validation")
get_label_distribution(test_loader, "Test")




--- Data Leakage Check ---
No data leakage detected in the train/val/test split.

--- Split Ratio and Class Distribution Check ---
Total dataset size: 52372
Training set size: 36441 (69.58%)
Validation set size: 11166 (21.32%)
Test set size: 4765 (9.10%)

Training set distribution:
 label
1    0.511731
0    0.488269
Name: proportion, dtype: float64
Validation set distribution:
 label
1    0.549973
0    0.450027
Name: proportion, dtype: float64
Test set distribution:
 label
0    0.573347
1    0.426653
Name: proportion, dtype: float64

--- DataLoader Sanity Check ---
Batch loaded successfully from train_loader.
Contexts batch shape: torch.Size([96, 768])
Inputs batch (first 5): ('The soil quality from Green Acres Sod was disappointing, with a high clay content that made it difficult to work with.', "A friend of mine shared that the salad shop's service is top-notch, with their staff being attentive and knowledgeable about the ingredients and combinations.", 'A colleague mentioned that t

In [8]:
def classification_training(
    embedder,
    decoder: nn.Module,
    optimizer: torch.optim.Optimizer,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    num_epochs: int,
    checkpoint_interval: int, 
    path: str,
    continue_checkpoint: bool = False,
    writer: torch.utils.tensorboard.SummaryWriter | None = None,
    device: torch.device = None,
):
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    embedder.to(device)
    decoder.to(device)
    decoder.train()

    if not continue_checkpoint:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        run_name = f"{decoder.__class__.__name__}_{timestamp}"
        log_dir = os.path.join(path, run_name)
        os.makedirs(log_dir, exist_ok=True)
        start_epoch = 1
    else:
        all_ckpts = glob.glob(os.path.join(path, f"{decoder.__class__.__name__}_*", "checkpoint_epoch_*.pt"))
        if not all_ckpts:
            raise FileNotFoundError("No checkpoint found to continue from.")
        last_ckpt = max(all_ckpts, key=os.path.getctime)
        log_dir = os.path.dirname(last_ckpt)
        checkpoint = torch.load(last_ckpt, map_location=device)
        decoder.load_state_dict(checkpoint["decoder_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        print(f"Resuming from checkpoint {last_ckpt}, starting at epoch {start_epoch}")

    if writer is None:
        writer_dir = os.path.join(log_dir, "writer")
        os.makedirs(writer_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=writer_dir)

    loss_fn = nn.BCELoss()

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_epochs,
        eta_min=1e-6
    )

    for _ in range(start_epoch - 1):
        scheduler.step()

    for epoch in range(start_epoch, num_epochs + 1):
        running_loss = 0.0
        all_targets = []
        all_preds = []
        num_batches = len(train_loader)

        for batch_idx, (contexts_embed, inputs, targets) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch}")):
            targets = targets.float().to(device)

            contexts_embed = contexts_embed.to(device)
            # STUDY ON HOW CONTEXT AFFECTS ACCURACY
            # batch_size = targets.shape[0]
            # contexts_embed = torch.zeros(batch_size, 768, device=device)

            inputs_embed = embedder.embed(inputs, return_tokens=True)

            optimizer.zero_grad()
            outputs = decoder(inputs_embed, contexts_embed).squeeze(-1)
            loss = loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            global_step = (epoch - 1) * num_batches + batch_idx
            writer.add_scalar('Train/Batch_Loss', loss.item(), global_step)

            all_targets.append(targets.detach().cpu())
            all_preds.append(outputs.detach().cpu())

            del inputs_embed, outputs, targets, contexts_embed
            torch.cuda.empty_cache()

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        writer.add_scalar("Train/LR", current_lr, epoch)

        all_targets = torch.cat(all_targets).numpy()
        all_preds = torch.cat(all_preds).numpy()
        pred_labels = (all_preds >= 0.5).astype(int)

        avg_loss = running_loss / num_batches
        accuracy = (pred_labels == all_targets).mean()
        f1 = f1_score(all_targets, pred_labels)

        writer.add_scalar("Train/Epoch_Avg_Loss", avg_loss, epoch)
        writer.add_scalar("Train/Accuracy", accuracy, epoch)
        writer.add_scalar("Train/F1", f1, epoch)

        print(f"Epoch {epoch} - Loss: {avg_loss:.4f}, Acc: {accuracy:.4f}, F1: {f1:.4f}")

        decoder.eval()
        val_targets, val_preds = [], []
        val_loss = 0.0

        with torch.no_grad():
            for contexts_embed, inputs, targets in tqdm(val_loader, desc=f"Epoch {epoch} - Validation"):
                targets = targets.float().to(device)
                contexts_embed = contexts_embed.to(device)
                inputs_embed = embedder.embed(inputs, return_tokens=True)

                outputs = decoder(inputs_embed, contexts_embed).squeeze(-1)
                loss = loss_fn(outputs, targets)
                val_loss += loss.item()

                val_targets.append(targets.detach().cpu())
                val_preds.append(outputs.detach().cpu())

                del inputs_embed, outputs, targets, contexts_embed
                torch.cuda.empty_cache()

        val_targets = torch.cat(val_targets).numpy()
        val_preds = torch.cat(val_preds).numpy()
        val_pred_labels = (val_preds >= 0.5).astype(int)

        val_avg_loss = val_loss / len(val_loader)
        val_accuracy = (val_pred_labels == val_targets).mean()
        val_f1 = f1_score(val_targets, val_pred_labels)

        writer.add_scalar("Val/Loss", val_avg_loss, epoch)
        writer.add_scalar("Val/Accuracy", val_accuracy, epoch)
        writer.add_scalar("Val/F1", val_f1, epoch)

        print(f"Validation - Loss: {val_avg_loss:.4f}, Acc: {val_accuracy:.4f}, F1: {val_f1:.4f}")
        decoder.train()

        if (epoch % checkpoint_interval == 0) or (epoch == start_epoch):
            checkpoint_file = os.path.join(log_dir, f"checkpoint_epoch_{epoch}.pt")
            torch.save({
                "epoch": epoch,
                "decoder_state_dict": decoder.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": avg_loss
            }, checkpoint_file)
            print(f"Saved checkpoint: {checkpoint_file}")

    return avg_loss, writer

In [9]:
def test_model_binary(embedder, model, test_loader, device='cuda'):
    model.eval()
    model.to(device)
    embedder.to(device)
    all_preds = []
    all_labels = []

    wrong_examples = []
    correct_positive_examples = []
    correct_negative_examples = []

    with torch.no_grad():
        for batch_idx, (contexts_embed, inputs, targets) in enumerate(test_loader):
            targets = targets.float().to(device)
            contexts_embed = contexts_embed.to(device)
            inputs_embed = embedder.embed(inputs, return_tokens=True)

            # Forward pass (already sigmoid)
            outputs = model(inputs_embed, contexts_embed).squeeze(-1)
            preds = (outputs > 0.5).long()

            # Store wrong examples (first 10)
            if batch_idx > 5:
                for i in range(len(targets)):
                    if preds[i] != targets[i] and len(wrong_examples) < 150:
                        wrong_examples.append({
                            'input': inputs[i],
                            'label': targets[i].item(),
                            'pred': preds[i].item(),
                            'prob': outputs[i].item()
                        })
                    elif preds[i] == targets[i] and preds[i] == 1 and len(correct_positive_examples) < 150:
                        correct_positive_examples.append({
                            'input': inputs[i],
                            'label': targets[i].item(),
                            'pred': preds[i].item(),
                            'prob': outputs[i].item()
                        })
                    elif preds[i] == targets[i] and preds[i] == 0 and len(correct_negative_examples) < 150:
                        correct_negative_examples.append({
                            'input': inputs[i],
                            'label': targets[i].item(),
                            'pred': preds[i].item(),
                            'prob': outputs[i].item()
                        })
                

            all_preds.append(preds.cpu())
            all_labels.append(targets.cpu())

    all_preds = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()

    # Compute metrics
    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, zero_division=0)
    rec = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    cm = confusion_matrix(all_labels, all_preds)
    tn, fp, fn, tp = cm.ravel()

    print("=== Test Stats ===")
    print(f"Total samples: {len(all_labels)}")
    print(f"Accuracy: {acc*100:.2f}%")
    print(f"Precision: {prec*100:.2f}%")
    print(f"Recall: {rec*100:.2f}%")
    print(f"F1-score: {f1*100:.2f}%")
    print(f"True positives: {tp}, True negatives: {tn}")
    print(f"False positives: {fp}, False negatives: {fn}")
    print("\nConfusion Matrix:")
    print(cm)
    print("\nExamples of wrong predictions:")
    for example in wrong_examples:
        print(example)

    print("\nExamples of correct positive predictions:")
    for example in correct_positive_examples:
        print(example)
    
    print("\nExamples of correct negative predictions:")
    for example in correct_negative_examples:
        print(example)

    return {
        'accuracy': acc,
        'precision': prec,
        'recall': rec,
        'f1_score': f1,
        'true_positives': tp,
        'true_negatives': tn,
        'false_positives': fp,
        'false_negatives': fn,
        'wrong_examples': wrong_examples
    }

In [10]:
d_input = 768
d_context = 768
d_model = 256
num_layers = 3

mamba_pooled = MambaPooled(
    num_layers=num_layers,
    d_input=d_input,
    d_model=d_model,
    d_context=d_context,
    d_state=16,
    d_discr=None,
    ker_size=4,
    parallel=False
)
embedder = NomicEmbedder()
optimizer = torch.optim.AdamW(mamba_pooled.parameters(), lr=1e-4, weight_decay=1e-2)

<All keys matched successfully>


In [None]:
classification_training(
    embedder,
    decoder=mamba_pooled,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=100,
    checkpoint_interval=5, 
    path=r"C:\Users\ian\Desktop\Coding\ReviewClassification\model\weights\mamba_pooled",
    continue_checkpoint=False,
)

Epoch 1: 100%|██████████| 380/380 [03:52<00:00,  1.63it/s]


Epoch 1 - Loss: 0.0732, Acc: 0.9740, F1: 0.9746


Epoch 1 - Validation: 100%|██████████| 117/117 [00:32<00:00,  3.59it/s]


Validation - Loss: 0.1212, Acc: 0.9542, F1: 0.9578
Saved checkpoint: C:\Users\ian\Desktop\Coding\ReviewClassification\model\weights\mamba_pooled\MambaPooled_20250828_235436\checkpoint_epoch_1.pt


Epoch 2: 100%|██████████| 380/380 [03:54<00:00,  1.62it/s]


Epoch 2 - Loss: 0.0164, Acc: 0.9943, F1: 0.9945


Epoch 2 - Validation: 100%|██████████| 117/117 [00:32<00:00,  3.61it/s]


Validation - Loss: 0.1071, Acc: 0.9660, F1: 0.9694


Epoch 3: 100%|██████████| 380/380 [04:06<00:00,  1.54it/s]


Epoch 3 - Loss: 0.0083, Acc: 0.9971, F1: 0.9972


Epoch 3 - Validation: 100%|██████████| 117/117 [00:34<00:00,  3.41it/s]


Validation - Loss: 0.1379, Acc: 0.9605, F1: 0.9646


Epoch 4: 100%|██████████| 380/380 [04:06<00:00,  1.54it/s]


Epoch 4 - Loss: 0.0065, Acc: 0.9977, F1: 0.9977


Epoch 4 - Validation: 100%|██████████| 117/117 [00:34<00:00,  3.43it/s]


Validation - Loss: 0.1534, Acc: 0.9644, F1: 0.9682


Epoch 5: 100%|██████████| 380/380 [04:05<00:00,  1.55it/s]


Epoch 5 - Loss: 0.0051, Acc: 0.9985, F1: 0.9986


Epoch 5 - Validation: 100%|██████████| 117/117 [00:34<00:00,  3.41it/s]


Validation - Loss: 0.1274, Acc: 0.9667, F1: 0.9698
Saved checkpoint: C:\Users\ian\Desktop\Coding\ReviewClassification\model\weights\mamba_pooled\MambaPooled_20250828_235436\checkpoint_epoch_5.pt


Epoch 6: 100%|██████████| 380/380 [04:05<00:00,  1.55it/s]


Epoch 6 - Loss: 0.0046, Acc: 0.9984, F1: 0.9984


Epoch 6 - Validation: 100%|██████████| 117/117 [00:32<00:00,  3.61it/s]


Validation - Loss: 0.0978, Acc: 0.9727, F1: 0.9751


Epoch 7: 100%|██████████| 380/380 [04:00<00:00,  1.58it/s]


Epoch 7 - Loss: 0.0042, Acc: 0.9984, F1: 0.9984


Epoch 7 - Validation: 100%|██████████| 117/117 [00:33<00:00,  3.48it/s]


Validation - Loss: 0.1374, Acc: 0.9715, F1: 0.9742


Epoch 8: 100%|██████████| 380/380 [04:07<00:00,  1.53it/s]


Epoch 8 - Loss: 0.0024, Acc: 0.9993, F1: 0.9993


Epoch 8 - Validation: 100%|██████████| 117/117 [00:33<00:00,  3.47it/s]


Validation - Loss: 0.1031, Acc: 0.9747, F1: 0.9774


Epoch 9: 100%|██████████| 380/380 [03:53<00:00,  1.63it/s]


Epoch 9 - Loss: 0.0031, Acc: 0.9988, F1: 0.9988


Epoch 9 - Validation: 100%|██████████| 117/117 [00:32<00:00,  3.63it/s]


Validation - Loss: 0.0907, Acc: 0.9785, F1: 0.9803


Epoch 10: 100%|██████████| 380/380 [04:04<00:00,  1.56it/s]


Epoch 10 - Loss: 0.0037, Acc: 0.9986, F1: 0.9987


Epoch 10 - Validation: 100%|██████████| 117/117 [00:32<00:00,  3.58it/s]


Validation - Loss: 0.0809, Acc: 0.9790, F1: 0.9809
Saved checkpoint: C:\Users\ian\Desktop\Coding\ReviewClassification\model\weights\mamba_pooled\MambaPooled_20250828_235436\checkpoint_epoch_10.pt


Epoch 11:  25%|██▍       | 94/380 [00:57<02:54,  1.64it/s]


KeyboardInterrupt: 

: 

In [39]:
test_model_binary(embedder, mamba_pooled, test_loader, device='cuda')

=== Test Stats ===
Total samples: 4969
Accuracy: 97.46%
Precision: 97.02%
Recall: 97.44%
F1-score: 97.23%
True positives: 2211, True negatives: 2632
False positives: 68, False negatives: 58

Confusion Matrix:
[[2632   68]
 [  58 2211]]

Examples of wrong predictions:
{'input': 'the food is absolutely awe-inspiring...every bite brings you several waves of experiences.. while each one is distinct, the combination is surprisingly perfect. true craftmanship...\n\nVery charming and intimate store run by a father and son', 'label': 1.0, 'pred': 0, 'prob': 6.803496216889471e-05}
{'input': 'Best quality of ingredients and perfect selection of food choices when you choose the tasting menu.\n\nSpectacular food presentation reminds me of a New York restaurant in RI.', 'label': 1.0, 'pred': 0, 'prob': 5.729731856263243e-05}
{'input': 'True artisan in the kitchen.', 'label': 1.0, 'pred': 0, 'prob': 0.003234368748962879}
{'input': "Great music but not as lively as it should be, perhaps it's poorly p

{'accuracy': 0.9746427852686658,
 'precision': np.float64(0.9701623519087319),
 'recall': np.float64(0.9744380784486558),
 'f1_score': np.float64(0.9722955145118733),
 'true_positives': np.int64(2211),
 'true_negatives': np.int64(2632),
 'false_positives': np.int64(68),
 'false_negatives': np.int64(58),
 'wrong_examples': [{'input': 'the food is absolutely awe-inspiring...every bite brings you several waves of experiences.. while each one is distinct, the combination is surprisingly perfect. true craftmanship...\n\nVery charming and intimate store run by a father and son',
   'label': 1.0,
   'pred': 0,
   'prob': 6.803496216889471e-05},
  {'input': 'Best quality of ingredients and perfect selection of food choices when you choose the tasting menu.\n\nSpectacular food presentation reminds me of a New York restaurant in RI.',
   'label': 1.0,
   'pred': 0,
   'prob': 5.729731856263243e-05},
  {'input': 'True artisan in the kitchen.',
   'label': 1.0,
   'pred': 0,
   'prob': 0.003234368

In [27]:
mamba_pooled = MambaPooled(
    num_layers=num_layers,
    d_input=d_input,
    d_model=d_model,
    d_context=d_context,
    d_state=16,
    d_discr=None,
    ker_size=4,
    parallel=False
)

# Path to checkpoint
checkpoint_file = r"C:\Users\ian\Desktop\Coding\ReviewClassification\model\weights\mamba_pooled\MambaPooled_20250827_221728\checkpoint_epoch_20.pt"

# Load checkpoint
checkpoint = torch.load(checkpoint_file, map_location='cpu')  # or 'cuda' if GPU

# Load state dict into model
mamba_pooled.load_state_dict(checkpoint['decoder_state_dict'])

# Set model to eval mode
mamba_pooled.eval()

print(f"Loaded checkpoint from epoch {checkpoint['epoch']}, loss={checkpoint.get('loss', None)}")


Loaded checkpoint from epoch 20, loss=1.3210939095404728e-06


  checkpoint = torch.load(checkpoint_file, map_location='cpu')  # or 'cuda' if GPU


In [9]:
d_input = 768
d_context = 768
d_model = 256
d_layers = 3       
d_heads = 8 
dropout = 0.1

cross_attn_model = CrossAttentionTransformer(
    d_input=d_input,
    d_context=d_context,
    d_model=d_model,
    d_layers=d_layers,
    d_heads=d_heads,
    dropout=dropout
)
embedder = NomicEmbedder()
optimizer = torch.optim.AdamW(cross_attn_model.parameters(), lr=1e-4, weight_decay=1e-2)

<All keys matched successfully>


In [3]:
# Dimensions
batch_size = 2
seq_len = 10
d_input = 512
d_context = 512

# Random inputs
x = torch.randn(batch_size, seq_len, d_input)      # input sequence
context = torch.randn(batch_size, d_context)      # context vector

# Forward pass
logits = mamba_pooled(x, context)                # classifier output
embeddings = mamba_pooled(x, context, True)      # embeddings

print("Logits shape:", logits.shape)             # (batch_size, seq_len, d_model)
print("Embeddings shape:", embeddings.shape)     # (batch_size, seq_len, d_model)

Logits shape: torch.Size([2, 1])
Embeddings shape: torch.Size([2, 512])


In [4]:
d_input = 512        
d_context = 512      
d_model = 256        
num_layers = 3       

mamba_cls = MambaCLS(
    num_layers=num_layers,
    d_input=d_input,
    d_context=d_context,
    d_model=d_model,
    d_state=16,       
    d_discr=None,     
    ker_size=4,       
    parallel=True    
)

In [5]:
# Dimensions
batch_size = 2
seq_len = 10
d_input = 512
d_context = 512

# Random inputs
x = torch.randn(batch_size, seq_len, d_input)      # input sequence
context = torch.randn(batch_size, d_context)      # context vector

# Forward pass
logits = mamba_cls(x, context)                # classifier output
embeddings = mamba_cls(x, context, True)      # embeddings

print("Logits shape:", logits.shape)             # (batch_size, seq_len, d_model)
print("Embeddings shape:", embeddings.shape)     # (batch_size, seq_len, d_model)

Logits shape: torch.Size([2, 1])
Embeddings shape: torch.Size([2, 512])


In [4]:
# Dimensions
batch_size = 2
seq_len = 10
d_input = 512
d_context = 512

# Random inputs
x = torch.randn(batch_size, seq_len, d_input)      # input sequence
context = torch.randn(batch_size, d_context)      # context vector

# Forward pass
logits = cross_attn_model(x, context)                # classifier output
embeddings = cross_attn_model(x, context, True)      # embeddings

print("Logits shape:", logits.shape)             # (batch_size, seq_len, d_model)
print("Embeddings shape:", embeddings.shape)     # (batch_size, seq_len, d_model)

Logits shape: torch.Size([2, 10, 1])
Embeddings shape: torch.Size([2, 10, 256])
