## Imports


In [1]:
import os
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from torchvision.models import mobilenet_v2
from PIL import Image

from transformers import DistilBertTokenizerFast, DistilBertModel
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

  from .autonotebook import tqdm as notebook_tqdm


## Configuration


In [2]:
from config import CHECKPOINT_DIR, DATA_ROOT

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 5

# Seed torch's RNGs (CPU and CUDA) so DataLoader shuffling, dropout and the
# randomly initialised classifier heads are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

## Load Dataset


In [3]:
# Load the OCR'd ads and keep only rows every model below can use: a
# required field must be present, and the filtered text must contain
# something other than whitespace.
df = (
    pd.read_csv(f"{DATA_ROOT}/ocr_ads_cleaned.csv")
    .dropna(subset=["image_path", "filtered_text", "label"])
    .loc[lambda frame: frame["filtered_text"].str.strip() != ""]
    .reset_index(drop=True)
)

df.head()

Unnamed: 0,image_path,slogan_text,label,clean_text,filtered_text
0,madverse_data/OnlineAds/baby_products/baby_ess...,chicco baby moments F0 R EVERYDAY MOMENTS 0 F ...,baby_products,chicco moments everyday moments parabens free ...,chicco moments everyday moments parabens free
1,madverse_data/OnlineAds/baby_products/baby_ess...,"""No language can express the power and beauty,...",baby_products,language can express the power and beauty and ...,language can express the power and beauty and ...
2,madverse_data/OnlineAds/baby_products/baby_ess...,(chicco) Baby Care for New-age Parents like Yo...,baby_products,chicco care for new age parents like you momen...,chicco care for new age parents like you momen...
3,madverse_data/OnlineAds/baby_products/baby_ess...,#PARTNER iN PARENting Complete Protection for ...,baby_products,partner parenting complete protection for your...,partner complete protection for your minutes e...
4,madverse_data/OnlineAds/baby_products/baby_ess...,C (chicco) MOMENT OF DeeP CLEANSING AND NOURIS...,baby_products,chicco moment deep cleansing and nourishment c...,chicco moment deep cleansing and nourishment c...


## Encode Labels


In [4]:
# Map each string category to an integer id. LabelEncoder orders classes_
# alphabetically; the fitted encoder `le` is kept for later use.
le = LabelEncoder()
df["label_id"] = le.fit_transform(df["label"])

NUM_CLASSES = le.classes_.size
print("Classes:", le.classes_)

Classes: ['baby_products' 'body_wear' 'cosmetics' 'drinks' 'electronics'
 'financial_institutions' 'food' 'home_essentials' 'sports' 'travel'
 'vehicles']


## Train/Validation Split

In [5]:
# Hold out 20% for validation, stratified on the encoded label so both
# splits keep the same class proportions.
train_df, val_df = train_test_split(
    df,
    random_state=42,
    stratify=df["label_id"],
    test_size=0.2,
)

## Image Transformations


In [6]:
# Shared preprocessing for every image model: fixed 224x224 resize, conversion
# to a float tensor, then per-channel normalisation with the standard ImageNet
# mean/std that the pretrained MobileNetV2 weights expect.
img_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

## Save Model Utility

In [7]:
def save_model(
    model,
    optimizer,
    epoch,
    num_classes,
    le,
    filename,
    text_model_str=None,
    image_encoder_str=None,
    text_encoder_str=None,
):
    """Write a resumable training checkpoint to ``filename``.

    The checkpoint stores the model and optimizer state dicts, the number of
    completed epochs, and the label-encoder classes, so predictions can be
    mapped back to label names. The optional encoder identifiers are stored
    only when they are truthy.

    The data is first written to a sibling ``<filename>.tmp`` file and then
    moved into place with ``os.replace``. An interrupted save (kernel restart,
    full disk, Ctrl-C) therefore leaves the previous checkpoint intact instead
    of a truncated file that would break ``torch.load`` on the next resume.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        epoch: Number of completed epochs; training resumes from this value.
        num_classes: Size of the classifier output.
        le: Fitted label encoder; its ``classes_`` are stored as a list.
        filename: Destination path of the checkpoint.
        text_model_str: Optional name of the text backbone.
        image_encoder_str: Optional name of the image encoder.
        text_encoder_str: Optional name of the text encoder.
    """
    data = {
        "model_state_dict": model.state_dict(),
        "num_classes": num_classes,
        "label_classes": le.classes_.tolist(),
        "epoch": epoch,
        "optimizer_state_dict": optimizer.state_dict(),
    }

    optional = {
        "text_model": text_model_str,
        "image_encoder": image_encoder_str,
        "text_encoder": text_encoder_str,
    }
    data.update({key: value for key, value in optional.items() if value})

    tmp_filename = f"{filename}.tmp"
    try:
        torch.save(data, tmp_filename)
        # Rename into place; replaces any existing checkpoint in one step.
        os.replace(tmp_filename, filename)
    except BaseException:
        # Never leave a half-written temp file behind.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)
        raise

    print(f"Saved model to {filename}")

## Training Utilities


In [8]:
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass of ``model`` over ``loader``.

    Each batch is ``(*inputs, targets)``. All inputs are moved to ``device``
    and passed positionally to ``model``, so both single-input ``(x, y)``
    loaders and multi-input ``(img, text, y)`` loaders are supported.

    Args:
        model: Module producing class logits of shape ``(batch, num_classes)``.
        loader: Iterable of batches as described above.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss function; assumed to average over the batch
            (``reduction="mean"``, the ``nn.CrossEntropyLoss`` default).
        device: Target device for inputs and targets.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. The loss is weighted
        by batch size, so a short final batch does not skew it. Both values
        are ``0.0`` when the loader yields no samples.
    """
    model.train()
    total_loss, correct, seen = 0.0, 0, 0

    for *inputs, y in loader:
        inputs = [t.to(device) for t in inputs]
        y = y.to(device)

        optimizer.zero_grad()
        out = model(*inputs)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        batch_size = y.size(0)
        # criterion returns the batch mean; scale back to a per-batch sum.
        total_loss += loss.item() * batch_size
        correct += (out.argmax(1) == y).sum().item()
        seen += batch_size

    if seen == 0:
        return 0.0, 0.0
    return total_loss / seen, correct / seen


def eval_epoch(model, loader, device):
    """Return the classification accuracy of ``model`` on ``loader``.

    The model is switched to eval mode and no gradients are tracked. Each
    batch is ``(*inputs, targets)``; inputs are moved to ``device`` and passed
    positionally to ``model``, so single- and multi-input loaders both work.

    Returns:
        Fraction of correctly classified samples, or ``0.0`` if the loader
        yields no samples.
    """
    model.eval()
    correct, seen = 0, 0

    with torch.no_grad():
        for *inputs, y in loader:
            out = model(*(t.to(device) for t in inputs))
            y = y.to(device)
            correct += (out.argmax(1) == y).sum().item()
            seen += y.size(0)

    return correct / seen if seen else 0.0

# Image-Only Model (MobileNetV2)


## Image Dataset


In [9]:
class ImageDataset(Dataset):
    """Ad images as preprocessed tensors, paired with integer label ids.

    Args:
        df: DataFrame with ``image_path`` and ``label_id`` columns. The index
            is reset so positional ``idx`` lookups are contiguous.
        transform: Optional callable applied to each RGB PIL image. When
            ``None`` (the default), the notebook-level ``img_transform`` is
            used, looked up at access time.
    """

    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # The context manager closes the file handle deterministically; PIL
        # opens files lazily and may otherwise keep the handle open (e.g. for
        # multi-frame formats), leaking descriptors across many epochs.
        with Image.open(row["image_path"]) as raw:
            img = raw.convert("RGB")
        transform = self.transform if self.transform is not None else img_transform
        return transform(img), row["label_id"]

## MobileNetV2 Image Model


In [10]:
class ImageModel(nn.Module):
    """ImageNet-pretrained MobileNetV2 with a freshly initialised linear head.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = mobilenet_v2(weights="IMAGENET1K_V1")
        # Swap torchvision's head for a pass-through so the backbone returns
        # its 1280-d feature vector instead of ImageNet logits.
        self.backbone.classifier = nn.Identity()
        feature_dim = 1280
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        return self.classifier(self.backbone(x))

## Train Image-Only Model


In [11]:
# Train (or resume) the image-only MobileNetV2 classifier. A checkpoint is
# written after every epoch, so re-running this cell continues from the last
# completed epoch instead of starting over.
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "mobilenetv2_image_model.pth")

train_loader = DataLoader(ImageDataset(train_df), batch_size=32, shuffle=True)
val_loader = DataLoader(ImageDataset(val_df), batch_size=32)

image_model = ImageModel(NUM_CLASSES).to(DEVICE)
optimizer = torch.optim.Adam(image_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
start_epoch = 0

if os.path.exists(CHECKPOINT_PATH):
    print("Image model checkpoint found. Loading weights...")

    ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    image_model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    start_epoch = ckpt["epoch"]  # number of epochs already completed
else:
    print("No checkpoint found. Starting training...")

for epoch in range(start_epoch, EPOCHS):
    train_loss, train_acc = train_epoch(
        image_model, train_loader, optimizer, criterion, DEVICE
    )

    val_acc = eval_epoch(image_model, val_loader, DEVICE)

    # Report progress against the configured EPOCHS, not a hard-coded total.
    print(
        f"[Image] Epoch {epoch + 1}/{EPOCHS} | "
        f"Loss: {train_loss:.3f} | "
        f"Train Acc: {train_acc:.3f} | "
        f"Val Acc: {val_acc:.3f}"
    )

    # Save trained model
    save_model(image_model, optimizer, epoch + 1, NUM_CLASSES, le, CHECKPOINT_PATH)

No checkpoint found. Starting training...
[Image] Epoch 1/5 | Loss: 0.907 | Train Acc: 0.708 | Val Acc: 0.793
Saved model to checkpoints/mobilenetv2_image_model.pth
[Image] Epoch 2/5 | Loss: 0.466 | Train Acc: 0.848 | Val Acc: 0.817
Saved model to checkpoints/mobilenetv2_image_model.pth
[Image] Epoch 3/5 | Loss: 0.262 | Train Acc: 0.919 | Val Acc: 0.836
Saved model to checkpoints/mobilenetv2_image_model.pth
[Image] Epoch 4/5 | Loss: 0.144 | Train Acc: 0.955 | Val Acc: 0.831
Saved model to checkpoints/mobilenetv2_image_model.pth
[Image] Epoch 5/5 | Loss: 0.084 | Train Acc: 0.977 | Val Acc: 0.849
Saved model to checkpoints/mobilenetv2_image_model.pth


# Text-Only Model (DistilBERT)


## Load DistilBERT


In [12]:
# DistilBERT serves only as a frozen feature extractor for the text models;
# eval() disables its dropout so embeddings are deterministic.
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
distilbert = DistilBertModel.from_pretrained("distilbert-base-uncased").to(DEVICE)
distilbert.eval()

# Width of the [CLS] embedding that the text heads consume (hidden size of
# distilbert-base-uncased).
TEXT_DIM = 768

## Text Dataset


In [13]:
class TextDataset(Dataset):
    """Frozen-DistilBERT [CLS] embeddings paired with integer label ids.

    The encoder is only ever run under ``torch.no_grad``, so its output for a
    given text never changes. The embeddings are therefore computed once, in
    batches, when the dataset is built, rather than per item on every epoch.

    Args:
        df: DataFrame with ``filtered_text`` and ``label_id`` columns.
        batch_size: Number of texts encoded per DistilBERT forward pass.
    """

    def __init__(self, df, batch_size=64):
        self.texts = df["filtered_text"].tolist()
        self.labels = df["label_id"].tolist()
        self.embeddings = self._encode(self.texts, batch_size)

    @staticmethod
    def _encode(texts, batch_size):
        """Return a ``(len(texts), TEXT_DIM)`` tensor of [CLS] embeddings on DEVICE."""
        chunks = []
        with torch.no_grad():
            for start in range(0, len(texts), batch_size):
                # Same tokenisation as per-item encoding: every sequence is
                # padded/truncated to 32 tokens, with padding masked out.
                enc = tokenizer(
                    texts[start : start + batch_size],
                    truncation=True,
                    padding="max_length",
                    max_length=32,
                    return_tensors="pt",
                )
                outputs = distilbert(
                    input_ids=enc["input_ids"].to(DEVICE),
                    attention_mask=enc["attention_mask"].to(DEVICE),
                )
                chunks.append(outputs.last_hidden_state[:, 0, :])  # CLS
        if not chunks:
            return torch.empty(0, TEXT_DIM, device=DEVICE)
        return torch.cat(chunks)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]

## Text Classifier


In [14]:
class TextModel(nn.Module):
    """Two-layer MLP that classifies a TEXT_DIM-wide sentence embedding.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        hidden_dim = 256
        layers = [
            nn.Linear(TEXT_DIM, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.fc(x)
        return logits

## Train Text-Only Model


In [15]:
# Train (or resume) the text-only head on DistilBERT [CLS] embeddings. The
# checkpoint is rewritten after every epoch, so re-running resumes training.
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "distilbert_text_model.pth")

train_loader = DataLoader(TextDataset(train_df), shuffle=True, batch_size=64)
val_loader = DataLoader(TextDataset(val_df), batch_size=64)

text_model = TextModel(NUM_CLASSES).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(text_model.parameters(), lr=1e-3)

start_epoch = 0
if not os.path.exists(CHECKPOINT_PATH):
    print("No checkpoint found. Starting training...")
else:
    print("Text model checkpoint found. Loading weights...")

    ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    text_model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    start_epoch = ckpt["epoch"]  # epochs already completed

for epoch in range(start_epoch, EPOCHS):
    loss, acc = train_epoch(text_model, train_loader, optimizer, criterion, DEVICE)
    val_acc = eval_epoch(text_model, val_loader, DEVICE)
    print(f"[Text] Epoch {epoch+1} | Loss {loss:.3f} | Train {acc:.3f} | Val {val_acc:.3f}")

    save_model(
        model=text_model,
        optimizer=optimizer,
        epoch=epoch + 1,
        num_classes=NUM_CLASSES,
        le=le,
        filename=CHECKPOINT_PATH,
        text_model_str="distilbert-base-uncased",
    )

No checkpoint found. Starting training...
[Text] Epoch 1 | Loss 1.446 | Train 0.526 | Val 0.609
Saved model to checkpoints/distilbert_text_model.pth
[Text] Epoch 2 | Loss 1.155 | Train 0.618 | Val 0.640
Saved model to checkpoints/distilbert_text_model.pth
[Text] Epoch 3 | Loss 1.076 | Train 0.643 | Val 0.658
Saved model to checkpoints/distilbert_text_model.pth
[Text] Epoch 4 | Loss 1.021 | Train 0.661 | Val 0.664
Saved model to checkpoints/distilbert_text_model.pth
[Text] Epoch 5 | Loss 0.979 | Train 0.672 | Val 0.677
Saved model to checkpoints/distilbert_text_model.pth


# Feature-Level Fusion (MobileNetV2 + DistilBERT)


## Multimodal Dataset


In [16]:
class MultiModalDataset(Dataset):
    """(image tensor, DistilBERT [CLS] embedding, label id) triples for fusion.

    Text embeddings come from the frozen DistilBERT, which is run under
    ``torch.no_grad``. They are computed once, in batches, at construction
    time instead of per item on every epoch. Images are still loaded lazily
    in ``__getitem__``.

    Args:
        df: DataFrame with ``image_path``, ``filtered_text`` and ``label_id``
            columns. The index is reset so positional lookups are contiguous.
        batch_size: Number of texts encoded per DistilBERT forward pass.
    """

    def __init__(self, df, batch_size=64):
        self.df = df.reset_index(drop=True)
        self.text_embs = self._encode_texts(
            self.df["filtered_text"].tolist(), batch_size
        )

    @staticmethod
    def _encode_texts(texts, batch_size):
        """Return a ``(len(texts), TEXT_DIM)`` tensor of [CLS] embeddings on DEVICE."""
        chunks = []
        with torch.no_grad():
            for start in range(0, len(texts), batch_size):
                enc = tokenizer(
                    texts[start : start + batch_size],
                    truncation=True,
                    padding="max_length",
                    max_length=32,
                    return_tensors="pt",
                )
                chunks.append(
                    distilbert(
                        input_ids=enc["input_ids"].to(DEVICE),
                        attention_mask=enc["attention_mask"].to(DEVICE),
                    ).last_hidden_state[:, 0, :]
                )
        if not chunks:
            return torch.empty(0, TEXT_DIM, device=DEVICE)
        return torch.cat(chunks)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Close the file handle deterministically; PIL opens files lazily and
        # may otherwise keep the handle open (e.g. for multi-frame formats).
        with Image.open(row["image_path"]) as raw:
            img = img_transform(raw.convert("RGB"))

        return img, self.text_embs[idx], row["label_id"]

## Fusion Model


In [None]:
class FusionModel(nn.Module):
    """Concatenation fusion of MobileNetV2 image features and projected text embeddings.

    The image features (1280-d) and the text embedding projected to 256-d
    are concatenated and classified by a small MLP.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        image_dim, text_dim, hidden_dim = 1280, 256, 256

        self.image_encoder = mobilenet_v2(weights="IMAGENET1K_V1")
        # Drop torchvision's ImageNet head so the encoder outputs features.
        self.image_encoder.classifier = nn.Identity()

        self.text_proj = nn.Linear(TEXT_DIM, text_dim)

        self.classifier = nn.Sequential(
            nn.Linear(image_dim + text_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, img, text):
        features = torch.cat((self.image_encoder(img), self.text_proj(text)), dim=1)
        return self.classifier(features)

## Train Fusion Model


In [None]:
# Train (or resume) the fusion model. MobileNetV2, the text projection and the
# classifier head are optimised jointly; text embeddings come from
# MultiModalDataset, computed under torch.no_grad.
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "multimodal_fusion_model.pth")

train_loader = DataLoader(MultiModalDataset(train_df), batch_size=32, shuffle=True)
val_loader = DataLoader(MultiModalDataset(val_df), batch_size=32)

fusion_model = FusionModel(NUM_CLASSES).to(DEVICE)
optimizer = torch.optim.Adam(fusion_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
start_epoch = 0

if os.path.exists(CHECKPOINT_PATH):
    print("Fusion model checkpoint found. Loading weights...")
    ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    fusion_model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    start_epoch = ckpt["epoch"]  # number of epochs already completed
else:
    print("No checkpoint found. Starting training...")

for epoch in range(start_epoch, EPOCHS):
    fusion_model.train()
    correct, total_loss = 0, 0.0

    for img, text, y in train_loader:
        img, text, y = img.to(DEVICE), text.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        out = fusion_model(img, text)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        # criterion averages over the batch; weight by batch size so the epoch
        # loss is a per-sample mean, matching the other training cells' report.
        total_loss += loss.item() * y.size(0)
        correct += (out.argmax(1) == y).sum().item()

    fusion_model.eval()
    val_correct = 0
    with torch.no_grad():
        for img, text, y in val_loader:
            img, text, y = img.to(DEVICE), text.to(DEVICE), y.to(DEVICE)
            val_correct += (fusion_model(img, text).argmax(1) == y).sum().item()

    n_train, n_val = len(train_loader.dataset), len(val_loader.dataset)
    print(
        f"[Fusion] Epoch {epoch+1} | Loss {total_loss/n_train:.3f} | "
        f"Train {correct/n_train:.3f} | Val {val_correct/n_val:.3f}"
    )

    save_model(
        fusion_model,
        optimizer,
        epoch + 1,
        NUM_CLASSES,
        le,
        CHECKPOINT_PATH,
        image_encoder_str="mobilenet_v2",
        text_encoder_str="distilbert-base-uncased",
    )