# AlexNet + BERT End-to-End Fusion (Using Trained BERT `.pt` Weights)

This notebook does **not** expect precomputed BERT features.

Instead it:

1. Rebuilds the same BERT classifier architecture you used in `DM_BERT.ipynb`.
2. **Loads your trained BERT weights** from the `.pt` file (state_dict).
3. Uses that frozen BERT model inside a fusion network together with AlexNet image features.

You get:
- An **AlexNet-only** image classifier baseline.
- A **fusion model**: AlexNet (images) + BERT (text) trained together.


In [2]:
# ====== Cell 1: Imports ======
import os

import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

from sklearn.model_selection import train_test_split

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from transformers import AutoTokenizer, AutoModel


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# ====== Cell 2: Paths and basic config ======
# TODO: update these paths to match your local setup
CSV_PATH = "fake_job_postings.csv"                 # your CSV
IMG_DIR = "images"                                 # folder containing 1.png, 2.png, ...
BERT_WEIGHTS_PATH = "bert_fake_job_classifier.pt"  # .pt file from DM_BERT.ipynb

MODEL_NAME = "bert-base-uncased"
MAX_LEN = 256

BATCH_SIZE = 32
VAL_FRACTION = 0.2
LR_ALEXNET = 1e-4
LR_FUSION = 1e-4
MAX_EPOCHS_ALEXNET = 8
MAX_EPOCHS_FUSION = 8

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [4]:
# ====== Cell 3: Load CSV and construct text/labels ======
df = pd.read_csv(CSV_PATH)
print("Columns:", df.columns.tolist())
print("Rows:", len(df))

# Rebuild the 'text' column similar to DM_BERT notebook
df.fillna("", inplace=True)
df["text"] = (
    df["title"].astype(str) + " " +
    df["location"].astype(str) + " " +
    df["department"].astype(str) + " " +
    df["company_profile"].astype(str) + " " +
    df["description"].astype(str) + " " +
    df["requirements"].astype(str) + " " +
    df["benefits"].astype(str)
)

df["label"] = df["fraudulent"].astype(int)

df[["job_id", "label", "text"]].head()

Columns: ['job_id', 'title', 'location', 'department', 'salary_range', 'company_profile', 'description', 'requirements', 'benefits', 'telecommuting', 'has_company_logo', 'has_questions', 'employment_type', 'required_experience', 'required_education', 'industry', 'function', 'fraudulent']
Rows: 17880


Unnamed: 0,job_id,label,text
0,1,0,"Marketing Intern US, NY, New York Marketing We..."
1,2,0,"Customer Service - Cloud Video Production NZ, ..."
2,3,0,"Commissioning Machinery Assistant (CMA) US, IA..."
3,4,0,"Account Executive - Washington DC US, DC, Wash..."
4,5,0,"Bill Review Manager US, FL, Fort Worth SpotSo..."


In [5]:
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

# class_weights[0] = weight for label 0, class_weights[1] = weight for label 1
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.array([0, 1]),
    y=df["label"].values,
)
print("Class weights (0, 1):", class_weights)

Class weights (0, 1): [ 0.52544963 10.32332564]


In [6]:
# ====== Cell 4: Train/validation split ======
train_df, val_df = train_test_split(
    df,
    test_size=VAL_FRACTION,
    stratify=df["label"],
    random_state=42,
)

print("Train size:", len(train_df))
print("Val size:", len(val_df))

Train size: 14304
Val size: 3576


In [7]:
# ====== Cell 5: Tokenizer ======
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer

BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [8]:
# ====== Cell 6: Dataset classes ======
class JobImageDataset(Dataset):
    """Image-only dataset for AlexNet baseline."""
    def __init__(self, df, img_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform

        mask = self.df["job_id"].apply(
            lambda x: os.path.exists(os.path.join(img_dir, f"{x}.png"))
        )
        self.df = self.df[mask].reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        job_id = int(row["job_id"])
        label = int(row["label"])

        img_path = os.path.join(self.img_dir, f"{job_id}.png")
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        return image, label, job_id


class JobImageTextDataset(Dataset):
    """Image + text dataset for fusion model."""
    def __init__(self, df, img_dir, tokenizer, max_len, transform=None):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.transform = transform

        mask = self.df["job_id"].apply(
            lambda x: os.path.exists(os.path.join(img_dir, f"{x}.png"))
        )
        self.df = self.df[mask].reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        job_id = int(row["job_id"])
        label = int(row["label"])
        text = str(row["text"])

        img_path = os.path.join(self.img_dir, f"{job_id}.png")
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        return image, input_ids, attention_mask, label, job_id

In [9]:
# ====== Cell 7: Transforms and dataloaders for AlexNet baseline ======
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

train_img_dataset = JobImageDataset(train_df, IMG_DIR, transform=train_transform)
val_img_dataset   = JobImageDataset(val_df,   IMG_DIR, transform=val_transform)

print("AlexNet train samples:", len(train_img_dataset))
print("AlexNet val samples:", len(val_img_dataset))

train_img_loader = DataLoader(train_img_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_img_loader   = DataLoader(val_img_dataset,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

AlexNet train samples: 14238
AlexNet val samples: 3558


In [10]:
# ====== Cell 8: AlexNet LightningModule (image-only classifier) ======
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    precision_recall_fscore_support,
    balanced_accuracy_score,
)

class AlexNetClassifier(pl.LightningModule):
    def __init__(
        self,
        num_classes: int = 2,
        lr: float = 1e-4,
        class_weights: Optional[np.ndarray] = None,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=["class_weights"])

        # --- Model ---
        self.alexnet = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
        in_feats = self.alexnet.classifier[-1].in_features
        self.alexnet.classifier[-1] = nn.Linear(in_feats, num_classes)

        # --- Loss weights ---
        if class_weights is not None:
            cw = torch.as_tensor(class_weights, dtype=torch.float32)
            self.register_buffer("class_weights", cw)
        else:
            self.class_weights = None

        self.lr = lr

        # storage for validation epoch
        self.val_logits = []
        self.val_labels = []

    def forward(self, x):
        return self.alexnet(x)

    # ---------- TRAIN ----------
    def training_step(self, batch, batch_idx):
        images, labels, _ = batch
        logits = self(images)
        loss = F.cross_entropy(
            logits,
            labels,
            weight=self.class_weights if self.class_weights is not None else None,
        )

        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()

        self.log("train_loss", loss, prog_bar=True, on_epoch=True, on_step=False)
        self.log("train_acc", acc, prog_bar=True, on_epoch=True, on_step=False)
        return loss

    # ---------- VALIDATION ----------
    def validation_step(self, batch, batch_idx):
        images, labels, _ = batch
        logits = self(images)
        loss = F.cross_entropy(
            logits,
            labels,
            weight=self.class_weights if self.class_weights is not None else None,
        )

        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()

        self.log("val_loss", loss, prog_bar=True, on_epoch=True, on_step=False)
        self.log("val_acc",  acc,  prog_bar=True, on_epoch=True, on_step=False)

        # stash for epoch-level metrics
        self.val_logits.append(logits.detach().cpu())
        self.val_labels.append(labels.detach().cpu())

        return loss

    def on_validation_epoch_end(self):
        if not self.val_logits:
            return

        logits = torch.cat(self.val_logits, dim=0)
        labels = torch.cat(self.val_labels, dim=0)

        self.val_logits.clear()
        self.val_labels.clear()

        probs_pos = torch.softmax(logits, dim=1)[:, 1].numpy()
        y_true = labels.numpy()
        y_pred = (probs_pos >= 0.5).astype("int64")

        # --- Metrics focused on the fraud (positive) class ---
        try:
            roc_auc = roc_auc_score(y_true, probs_pos)
        except ValueError:
            roc_auc = float("nan")

        try:
            pr_auc = average_precision_score(y_true, probs_pos)
        except ValueError:
            pr_auc = float("nan")

        try:
            prec, rec, f1, _ = precision_recall_fscore_support(
                y_true, y_pred, average="binary", pos_label=1, zero_division=0
            )
            bal_acc = balanced_accuracy_score(y_true, y_pred)
        except Exception:
            prec = rec = f1 = bal_acc = float("nan")

        # log these per epoch
        self.log("PR_AUC",      pr_auc,  prog_bar=True,  on_epoch=True)
        self.log("ROC_AUC",     roc_auc, prog_bar=True, on_epoch=True)
        self.log("F1_fraud",    f1,      prog_bar=True,  on_epoch=True)
        self.log("Prec_fraud",  prec,    prog_bar=True, on_epoch=True)
        self.log("Rec_fraud",   rec,     prog_bar=True, on_epoch=True)
        self.log("BalancedAcc", bal_acc, prog_bar=False, on_epoch=True)

    # ---------- OPTIMIZER ----------
    def configure_optimizers(self):
        return optim.AdamW(self.parameters(), lr=self.lr)

In [11]:
# ====== Cell 9: Train AlexNet baseline ======
checkpoint_cb = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    filename="alexnet_image_only-{epoch:02d}-{val_loss:.4f}",
)
early_stop_cb = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=3,
)

alexnet_model = AlexNetClassifier(
    num_classes=2,
    lr=LR_ALEXNET,
    class_weights=class_weights,
)

trainer = Trainer(
    max_epochs=MAX_EPOCHS_ALEXNET,
    accelerator="auto",
    callbacks=[checkpoint_cb, early_stop_cb],
    log_every_n_steps=10,
)

trainer.fit(alexnet_model, train_img_loader, val_img_loader)

best_alexnet_ckpt = checkpoint_cb.best_model_path
best_alexnet_ckpt

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
/opt/anaconda3/envs/dmenv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name    | Type    | Params | Mode  | FLOPs
----------------------------------------------------
0 | alexnet | AlexNet | 57.0 M | train | 0    
----------------------------------------------------
57.0 M    Trainable params
0         Non-trainable params
57.0 M    Total params
228.048   Total estimated model params size (MB)
24        Modules in train mode
0         Modules i

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/anaconda3/envs/dmenv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


                                                                           

/opt/anaconda3/envs/dmenv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 445/445 [04:01<00:00,  1.85it/s, v_num=6]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/112 [00:00<?, ?it/s][A
Validation DataLoader 0:   1%|          | 1/112 [00:00<00:03, 34.45it/s][A
Validation DataLoader 0:   2%|▏         | 2/112 [00:00<00:31,  3.45it/s][A
Validation DataLoader 0:   3%|▎         | 3/112 [00:01<00:41,  2.62it/s][A
Validation DataLoader 0:   4%|▎         | 4/112 [00:01<00:45,  2.37it/s][A
Validation DataLoader 0:   4%|▍         | 5/112 [00:02<00:47,  2.24it/s][A
Validation DataLoader 0:   5%|▌         | 6/112 [00:02<00:48,  2.17it/s][A
Validation DataLoader 0:   6%|▋         | 7/112 [00:03<00:49,  2.11it/s][A
Validation DataLoader 0:   7%|▋         | 8/112 [00:03<00:50,  2.06it/s][A
Validation DataLoader 0:   8%|▊         | 9/112 [00:04<00:50,  2.02it/s][A
Validation DataLoader 0:   9%|▉         | 10/112 [00:04<00:50,  2.00it/s][A
Valid

`Trainer.fit` stopped: `max_epochs=8` reached.


Epoch 7: 100%|██████████| 445/445 [05:02<00:00,  1.47it/s, v_num=6, val_loss=0.481, val_acc=0.883, PR_AUC=0.406, ROC_AUC=0.842, F1_fraud=0.307, Prec_fraud=0.215, Rec_fraud=0.532, train_loss=0.568, train_acc=0.852]


'/Users/savitajkaur/Documents/DM_Project/lightning_logs/version_6/checkpoints/alexnet_image_only-epoch=07-val_loss=0.4807.ckpt'

In [19]:
# Load the best checkpoint
best_model = AlexNetClassifier.load_from_checkpoint(
    best_alexnet_ckpt,
    num_classes=2,
    lr=LR_ALEXNET,
    class_weights=class_weights
)

# Extract the pure PyTorch state_dict
alexnet_state_dict = best_model.state_dict()

# Save it as a .pt file
torch.save(alexnet_state_dict, "alexnet_best.pt")

print("Saved best model weights to alexnet_best.pt")

Saved best model weights to alexnet_best.pt


In [21]:
# ====== Cell 10: Rebuild BERT classifier and load your .pt weights ======
class BERTClassifier(nn.Module):
    """Matches the architecture you used in DM_BERT.ipynb."""
    def __init__(self, n_classes):
        super().__init__()
        self.bert = AutoModel.from_pretrained(MODEL_NAME)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.pooler_output
        x = self.drop(pooled)
        return self.out(x)

    def encode(self, input_ids, attention_mask):
        """Return pooled BERT features (no dropout or final layer)."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        
        return outputs.pooler_output

# instantiate and load state_dict
bert_model = BERTClassifier(n_classes=2)
state_dict = torch.load(BERT_WEIGHTS_PATH, map_location="cpu")
bert_model.load_state_dict(state_dict)
bert_model.to(device)
bert_model.eval()

# freeze BERT for fusion (you can unfreeze later if you want)
for p in bert_model.parameters():
    p.requires_grad = False

print("Loaded BERT weights from", BERT_WEIGHTS_PATH)

Loaded BERT weights from bert_fake_job_classifier.pt


In [31]:
# ====== Cell 11: Fusion datasets and dataloaders ======
train_fusion_dataset = JobImageTextDataset(
    train_df, IMG_DIR, tokenizer=tokenizer, max_len=MAX_LEN, transform=train_transform
)
val_fusion_dataset = JobImageTextDataset(
    val_df, IMG_DIR, tokenizer=tokenizer, max_len=MAX_LEN, transform=val_transform
)

print("Fusion train samples:", len(train_fusion_dataset))
print("Fusion val samples:", len(val_fusion_dataset))

train_fusion_loader = DataLoader(train_fusion_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_fusion_loader   = DataLoader(val_fusion_dataset,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

Fusion train samples: 14238
Fusion val samples: 3558


In [37]:
# ====== Cell 12: AlexNet + BERT fusion LightningModule ======
class AlexNetBertFusion(pl.LightningModule):
    def __init__(
        self,
        bert_model,
        num_classes: int = 2,
        lr: float = 1e-4,
        class_weights: np.ndarray | None = None,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=["bert_model", "class_weights"])

        # ---- AlexNet backbone as feature extractor ----
        self.alexnet = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
        num_ftrs = self.alexnet.classifier[6].in_features  # usually 4096
        self.alexnet.classifier[6] = nn.Identity()        # output = feature vector

        # ---- BERT encoder (already loaded + frozen outside) ----
        self.bert_model = bert_model
        bert_dim = self.bert_model.bert.config.hidden_size

        fusion_dim = num_ftrs + bert_dim

        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

        # class weights (for fraud emphasis)
        if class_weights is not None:
            cw = torch.as_tensor(class_weights, dtype=torch.float32)
            self.register_buffer("class_weights", cw)
        else:
            self.class_weights = None

        self.lr = lr

        # store logits/labels per val epoch
        self.val_logits = []
        self.val_labels = []

    def forward(self, images, input_ids, attention_mask):
        alex_feat = self.alexnet(images)  # [B, 4096]
        bert_feat = self.bert_model.encode(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )                                  # [B, hidden]
        fusion = torch.cat([alex_feat, bert_feat], dim=1)
        logits = self.classifier(fusion)
        return logits

    # ---------- TRAIN ----------
    def training_step(self, batch, batch_idx):
        images, input_ids, attention_mask, labels, _ = batch
        images = images.to(self.device)
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        labels = labels.to(self.device)

        logits = self(images, input_ids, attention_mask)
        loss = F.cross_entropy(
            logits,
            labels,
            weight=self.class_weights if self.class_weights is not None else None,
        )

        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()

        self.log("train_loss", loss, prog_bar=True, on_epoch=True, on_step=False)
        self.log("train_acc",  acc,  prog_bar=True, on_epoch=True, on_step=False)

        return loss

    # ---------- VAL STEP ----------
    def validation_step(self, batch, batch_idx):
        images, input_ids, attention_mask, labels, _ = batch
        images = images.to(self.device)
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        labels = labels.to(self.device)

        logits = self(images, input_ids, attention_mask)
        loss = F.cross_entropy(
            logits,
            labels,
            weight=self.class_weights if self.class_weights is not None else None,
        )

        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()

        self.log("val_loss", loss, prog_bar=True, on_epoch=True, on_step=False)
        self.log("val_acc",  acc,  prog_bar=True, on_epoch=True, on_step=False)

        # stash for epoch-level metrics
        self.val_logits.append(logits.detach().cpu())
        self.val_labels.append(labels.detach().cpu())

        return loss

    def on_validation_epoch_end(self):
        if not self.val_logits:
            return

        logits = torch.cat(self.val_logits, dim=0)
        labels = torch.cat(self.val_labels, dim=0)

        self.val_logits.clear()
        self.val_labels.clear()

        probs_pos = torch.softmax(logits, dim=1)[:, 1].numpy()
        y_true = labels.numpy()
        y_pred = (probs_pos >= 0.5).astype("int64")

        # ---- metrics for fraud (positive) class ----
        try:
            roc_auc = roc_auc_score(y_true, probs_pos)
        except ValueError:
            roc_auc = float("nan")

        try:
            pr_auc = average_precision_score(y_true, probs_pos)
        except ValueError:
            pr_auc = float("nan")

        try:
            prec, rec, f1, _ = precision_recall_fscore_support(
                y_true, y_pred,
                average="binary",
                pos_label=1,
                zero_division=0,
            )
            bal_acc = balanced_accuracy_score(y_true, y_pred)
        except Exception:
            prec = rec = f1 = bal_acc = float("nan")

        # log for each epoch
        self.log("Fusion_PR_AUC",      pr_auc,  prog_bar=True,  on_epoch=True)
        self.log("Fusion_ROC_AUC",     roc_auc, prog_bar=False, on_epoch=True)
        self.log("Fusion_F1_fraud",    f1,      prog_bar=True,  on_epoch=True)
        self.log("Fusion_Prec_fraud",  prec,    prog_bar=True,  on_epoch=True)
        self.log("Fusion_Rec_fraud",   rec,     prog_bar=True,  on_epoch=True)
        self.log("Fusion_BalancedAcc", bal_acc, prog_bar=False, on_epoch=True)

    # ---------- OPTIMIZER ----------
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

In [39]:
ALEXNET_PT_PATH = "alexnet_best.pt"

In [40]:
# ====== Cell 13: Train fusion model ======
# ====== Init fusion model and load pretrained AlexNet weights ======
fusion_model = AlexNetBertFusion(
    bert_model=bert_model,
    num_classes=2,
    lr=LR_FUSION,
    class_weights=class_weights, 
)

alexnet_state = torch.load("alexnet_best.pt", map_location="cpu")

backbone_state = {
    k.replace("alexnet.", ""): v
    for k, v in alexnet_state.items()
    if k.startswith("alexnet.")
}

missing, unexpected = fusion_model.alexnet.load_state_dict(backbone_state, strict=False)
print("Loaded AlexNet backbone from alexnet_best.pt")
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

fusion_trainer = Trainer(
    max_epochs=MAX_EPOCHS_FUSION,
    accelerator="auto",
    callbacks=[fusion_checkpoint_cb, fusion_early_stop_cb],
    log_every_n_steps=10,
)

fusion_trainer.fit(fusion_model, train_fusion_loader, val_fusion_loader)

best_fusion_ckpt = fusion_checkpoint_cb.best_model_path
best_fusion_ckpt

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores


Loaded AlexNet backbone from alexnet_best.pt
Missing keys: []
Unexpected keys: ['classifier.6.weight', 'classifier.6.bias']



  | Name       | Type           | Params | Mode  | FLOPs
--------------------------------------------------------------
0 | alexnet    | AlexNet        | 57.0 M | train | 0    
1 | bert_model | BERTClassifier | 109 M  | eval  | 0    
2 | classifier | Sequential     | 2.5 M  | train | 0    
--------------------------------------------------------------
59.5 M    Trainable params
109 M     Non-trainable params
168 M     Total params
675.918   Total estimated model params size (MB)
29        Modules in train mode
231       Modules in eval mode
0         Total Flops


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/anaconda3/envs/dmenv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


                                                                           

/opt/anaconda3/envs/dmenv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 445/445 [05:05<00:00,  1.46it/s, v_num=10]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/112 [00:00<?, ?it/s][A
Validation DataLoader 0:   1%|          | 1/112 [00:00<01:00,  1.84it/s][A
Validation DataLoader 0:   2%|▏         | 2/112 [00:01<01:32,  1.19it/s][A
Validation DataLoader 0:   3%|▎         | 3/112 [00:02<01:42,  1.06it/s][A
Validation DataLoader 0:   4%|▎         | 4/112 [00:03<01:47,  1.00it/s][A
Validation DataLoader 0:   4%|▍         | 5/112 [00:05<01:50,  0.97it/s][A
Validation DataLoader 0:   5%|▌         | 6/112 [00:06<01:50,  0.96it/s][A
Validation DataLoader 0:   6%|▋         | 7/112 [00:07<01:51,  0.94it/s][A
Validation DataLoader 0:   7%|▋         | 8/112 [00:08<01:50,  0.94it/s][A
Validation DataLoader 0:   8%|▊         | 9/112 [00:09<01:50,  0.93it/s][A
Validation DataLoader 0:   9%|▉         | 10/112 [00:10<01:49,  0.93it/s][A
Vali

'/Users/savitajkaur/Documents/DM_Project/lightning_logs/version_7/checkpoints/fusion_alexnet_bert-epoch=02-val_loss=0.1851.ckpt'

In [45]:
best_fusion_model = AlexNetBertFusion.load_from_checkpoint(
    best_fusion_ckpt,
    bert_model=bert_model,
    num_classes=2,
    lr=LR_FUSION,
    class_weights=class_weights,
)
best_fusion_model.eval()
best_fusion_model.to(device)

AlexNetBertFusion(
  (alexnet): AlexNet(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU(inplace=True)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
    (classifier): Sequential(
      (0): Dropout(p=0.5, inplace=Fals

In [47]:
from tqdm import tqdm

all_logits = []
all_labels = []

best_fusion_model.eval()
with torch.no_grad():
    for batch in tqdm(val_fusion_loader):
        images, input_ids, attention_mask, labels, _ = batch
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        logits = best_fusion_model(images, input_ids, attention_mask)

        all_logits.append(logits.cpu())
        all_labels.append(labels.cpu())

logits = torch.cat(all_logits, dim=0)
labels = torch.cat(all_labels, dim=0)

probs = torch.softmax(logits, dim=1)[:, 1].numpy()
y_true = labels.numpy()
y_pred = (probs >= 0.5).astype(int)

100%|██████████| 112/112 [04:06<00:00,  2.20s/it]


In [49]:
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    precision_recall_fscore_support,
    balanced_accuracy_score,
    accuracy_score,
)

metrics = {}

try:
    metrics["ROC_AUC"] = roc_auc_score(y_true, probs)
except:
    metrics["ROC_AUC"] = float("nan")

try:
    metrics["PR_AUC"] = average_precision_score(y_true, probs)
except:
    metrics["PR_AUC"] = float("nan")

prec, rec, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="binary", zero_division=0
)

metrics["Precision_fraud"] = prec
metrics["Recall_fraud"] = rec
metrics["F1_fraud"] = f1
metrics["Accuracy"] = accuracy_score(y_true, y_pred)
metrics["Balanced_Accuracy"] = balanced_accuracy_score(y_true, y_pred)

In [51]:
print("\n=== Fusion Best Checkpoint Metrics ===")
for k, v in metrics.items():
    print(f"{k:20s}: {v:.4f}")


=== Fusion Best Checkpoint Metrics ===
ROC_AUC             : 0.9890
PR_AUC              : 0.9378
Precision_fraud     : 0.7488
Recall_fraud        : 0.9133
F1_fraud            : 0.8229
Accuracy            : 0.9809
Balanced_Accuracy   : 0.9488


In [55]:
best_fusion_model = AlexNetBertFusion.load_from_checkpoint(
    best_fusion_ckpt,
    bert_model=bert_model,        
    num_classes=2,
    lr=LR_FUSION,
    class_weights=class_weights,
)
best_fusion_model.eval()

AlexNetBertFusion(
  (alexnet): AlexNet(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU(inplace=True)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
    (classifier): Sequential(
      (0): Dropout(p=0.5, inplace=Fals

In [57]:
fusion_state_dict = best_fusion_model.state_dict()

In [65]:
torch.save(fusion_state_dict, "fusion_best.pt")
print("Saved fusion model weights to fusion_best.pt")

Saved fusion model weights to fusion_best.pt


In [66]:
#for loading later
fusion_model = AlexNetBertFusion(
    bert_model=bert_model,
    num_classes=2,
    lr=LR_FUSION,
    class_weights=class_weights,
)

fusion_model.load_state_dict(torch.load("fusion_best.pt", map_location="cpu"))
fusion_model.eval()

AlexNetBertFusion(
  (alexnet): AlexNet(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU(inplace=True)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
    (classifier): Sequential(
      (0): Dropout(p=0.5, inplace=Fals