In [None]:
import os
import json
import random
from typing import List, Dict, Any

import numpy as np
import pandas as pd
from tqdm import tqdm

from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

from transformers import CLIPProcessor, CLIPModel


In [None]:
IMAGE_DIR = "bin-images"   # change this
META_DIR  = "metadata" # change this

# Seed every RNG this notebook draws from (Python's `random` for sample
# generation, NumPy for pandas sampling, torch for splitting/shuffling/init)
# so that Restart & Run All reproduces the same data and results.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)


In [None]:
# Flatten the per-image metadata JSON files into one row per (image, ASIN) pair.
rows = []

# Sort the listing: os.listdir() returns entries in arbitrary order, and a
# stable row order keeps the resulting frame identical across runs/machines.
meta_files = sorted(f for f in os.listdir(META_DIR) if f.endswith(".json"))
for meta_file in tqdm(meta_files):
    meta_path = os.path.join(META_DIR, meta_file)
    # Read as UTF-8 explicitly instead of relying on the platform default
    # encoding, which can fail on non-ASCII product names.
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = json.load(f)

    # Strip only the trailing extension; str.replace would also remove an
    # inner ".json" occurring elsewhere in the file name.
    image_id = os.path.splitext(meta_file)[0]
    item_dict = meta.get("BIN_FCSKU_DATA", {})
    expected_total_qty = meta.get("EXPECTED_QUANTITY", None)

    for asin, item in item_dict.items():
        name = item.get("name")
        normalized = item.get("normalizedName")

        # Skip this item if both names are missing or blank: there is no text
        # to pair with the image.
        has_name = any(
            value is not None and str(value).strip() != ""
            for value in (name, normalized)
        )
        if not has_name:
            continue

        rows.append({
            "image_id": image_id,
            "asin": asin,
            "name": name,
            "normalized_name": normalized,
            "bin_quantity": item.get("quantity", 1),
            "expected_total_qty": expected_total_qty,
        })

# Passing the columns explicitly keeps the summary below working even when no
# rows were collected (an empty list would otherwise give a column-less frame).
df_items = pd.DataFrame(
    rows,
    columns=[
        "image_id",
        "asin",
        "name",
        "normalized_name",
        "bin_quantity",
        "expected_total_qty",
    ],
)
print("Rows:", len(df_items))
print("Unique images:", df_items["image_id"].nunique())
print("Unique ASINs:", df_items["asin"].nunique())
df_items.head()


In [None]:
def _usable_name(name, normalized_name):
    """Return the first of the two names that is a non-blank string, else None."""
    for candidate in (name, normalized_name):
        if isinstance(candidate, str) and candidate.strip():
            return candidate
    return None


def _sample_other_asins(rng, candidate_asins, names_by_asin, asins_in_bin, k):
    """Draw up to ``k`` distinct ASINs from ``candidate_asins`` that are not in the bin.

    Uses rejection sampling so the cost per image does not grow with the total
    number of ASINs (materialising "all ASINs minus this bin" for every image
    is quadratic overall).
    """
    blocked = set(asins_in_bin)
    n_available = len(candidate_asins) - sum(1 for a in blocked if a in names_by_asin)

    picked = []
    # Never ask for more than exist, which guarantees the inner loop terminates.
    for _ in range(min(k, n_available)):
        while True:
            candidate = candidate_asins[rng.randrange(len(candidate_asins))]
            if candidate not in blocked:
                break
        blocked.add(candidate)
        picked.append(candidate)
    return picked


def generate_training_samples(
    df_items: pd.DataFrame,
    max_required_qty_cap: int = 100,
    seed: "int | None" = None,
) -> pd.DataFrame:
    """Build labelled (image, item name, required quantity) samples.

    For every row of ``df_items`` (one item in one image) this emits:

    * positives (label 1): required quantities in ``1..bin_quantity``; all of
      them when ``bin_quantity <= 5``, otherwise 5 distinct random ones;
    * quantity negatives (label 0): up to 3 distinct quantities drawn from
      ``bin_quantity + 1 .. max_required_qty_cap``.

    For every image it additionally emits up to 3 wrong-item negatives
    (label 0): ASINs that are not in that image, with a random quantity in
    ``1..10``.

    Args:
        df_items: Frame with columns ``image_id``, ``asin``, ``name``,
            ``normalized_name`` and ``bin_quantity``.
        max_required_qty_cap: Largest quantity used for quantity negatives.
        seed: When given, all randomness comes from a private
            ``random.Random(seed)`` so the output is reproducible. When None,
            the global ``random`` module state is used.

    Returns:
        DataFrame with columns ``image_id``, ``item_name``, ``asin``,
        ``required_quantity`` and ``label``.

    Notes:
        The item name is ``name`` when it is a non-blank string, otherwise
        ``normalized_name``. Rows where neither is a non-blank string are
        skipped, because a blank or NaN text cannot be tokenised later.
    """
    rng = random.Random(seed) if seed is not None else random

    # Index usable names by ASIN once, instead of filtering df_items for every
    # negative sample (an O(rows) scan per sample in the naive form).
    names_by_asin: Dict[str, List[str]] = {}
    for asin, name, normalized in zip(
        df_items["asin"], df_items["name"], df_items["normalized_name"]
    ):
        usable = _usable_name(name, normalized)
        if usable is not None:
            names_by_asin.setdefault(asin, []).append(usable)
    candidate_asins = list(names_by_asin)

    samples = []

    for image_id, group in tqdm(df_items.groupby("image_id"), desc="Generating samples"):

        asins_in_bin = set(group["asin"])
        wrong_asin_samples = _sample_other_asins(
            rng, candidate_asins, names_by_asin, asins_in_bin, 3
        )

        for asin, raw_name, raw_normalized, raw_qty in zip(
            group["asin"], group["name"], group["normalized_name"], group["bin_quantity"]
        ):
            name = _usable_name(raw_name, raw_normalized)
            if name is None:
                continue
            bin_qty = int(raw_qty)

            # -----------------------------
            # POSITIVE SAMPLES
            # -----------------------------
            if bin_qty <= 5:
                pos_quantities = list(range(1, bin_qty + 1))
            else:
                # sample 5 unique quantities from 1..bin_qty
                pos_quantities = rng.sample(range(1, bin_qty + 1), 5)

            for q in pos_quantities:
                samples.append({
                    "image_id": image_id,
                    "item_name": name,
                    "asin": asin,
                    "required_quantity": q,
                    "label": 1
                })

            # -----------------------------
            # NEGATIVE QUANTITY SAMPLES
            # -----------------------------
            # Any request above what the bin holds cannot be satisfied.
            possible_neg_quantities = range(bin_qty + 1, max_required_qty_cap + 1)
            if len(possible_neg_quantities) > 0:
                neg_quantities = rng.sample(
                    possible_neg_quantities, min(3, len(possible_neg_quantities))
                )

                for q in neg_quantities:
                    samples.append({
                        "image_id": image_id,
                        "item_name": name,
                        "asin": asin,
                        "required_quantity": q,
                        "label": 0
                    })

        # -----------------------------
        # NEGATIVE WRONG-ASIN SAMPLES
        # -----------------------------
        for neg_asin in wrong_asin_samples:
            # One entry per source row, so this matches picking a random row
            # for the ASIN and taking its name.
            neg_name = rng.choice(names_by_asin[neg_asin])

            # random quantity between 1 and 10
            q = rng.randint(1, 10)

            samples.append({
                "image_id": image_id,
                "item_name": neg_name,
                "asin": neg_asin,
                "required_quantity": q,
                "label": 0
            })

    return pd.DataFrame(
        samples,
        columns=["image_id", "item_name", "asin", "required_quantity", "label"],
    )

# Expand the item table into labelled training samples and preview the result.
df_samples = generate_training_samples(df_items=df_items)
print("Total samples:", df_samples.shape[0])
df_samples.head()

In [None]:
class BinOrderDataset(Dataset):
    """Map-style dataset yielding one (image, item text, quantity, label) sample per row.

    Images are read from ``<image_dir>/<image_id>.jpg``. No CLIP preprocessing
    happens here; items carry the raw PIL image and text so that the batch
    collate function can run the processor on a whole batch at once.
    """

    def __init__(self, df_samples: pd.DataFrame, image_dir: str, clip_processor: CLIPProcessor):
        """Store the sample table, image directory and processor.

        Args:
            df_samples: Frame with columns ``image_id``, ``item_name``,
                ``required_quantity`` and ``label``.
            image_dir: Directory containing the ``.jpg`` images.
            clip_processor: Kept on the instance; not used by ``__getitem__``.
        """
        # Reset the index so positional lookups in __getitem__ line up with 0..len-1.
        self.df = df_samples.reset_index(drop=True)
        self.image_dir = image_dir
        self.processor = clip_processor

    def __len__(self):
        """Number of samples."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image for row ``idx`` and return the raw sample as a dict.

        Returns:
            dict with keys ``image`` (RGB PIL image), ``text`` (item name),
            ``required_qty`` (float) and ``label`` (float).
        """
        row = self.df.iloc[idx]

        # f-string rather than `+` so a non-string image id (e.g. after a CSV
        # round-trip) still yields a valid path.
        img_path = os.path.join(self.image_dir, f"{row['image_id']}.jpg")

        # Image.open is lazy and keeps the file open. convert() decodes the
        # pixels into a new image, so the handle can be closed right away,
        # including when decoding raises.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        # We let CLIPProcessor handle the transform later; here we just return
        # the raw image and text.
        return {
            "image": image,
            "text": row["item_name"],
            "required_qty": float(row["required_quantity"]),
            "label": float(row["label"])
        }


In [None]:
class CLIPQuantityMatcher(nn.Module):
    """Binary classifier over (image, item text, required quantity).

    A CLIP model embeds the image and the text; a small MLP embeds the scalar
    quantity; the three embeddings are concatenated and passed through an MLP
    head that outputs one logit per sample.
    """

    def __init__(
        self,
        clip_model_name: str = "openai/clip-vit-base-patch32",
        quantity_dim: int = 32,
        hidden_dim: int = 256,
        freeze_clip: bool = True,
    ):
        """Load CLIP and build the quantity MLP and classifier head.

        Args:
            clip_model_name: Name passed to ``CLIPModel.from_pretrained``.
            quantity_dim: Width of the quantity embedding.
            hidden_dim: Width of the classifier's hidden layer.
            freeze_clip: When True, CLIP parameters get ``requires_grad=False``
                and CLIP is kept in eval mode even while the model trains.
        """
        super().__init__()
        # Remembered so train() can keep a frozen backbone in eval mode.
        self.freeze_clip = freeze_clip

        self.clip = CLIPModel.from_pretrained(clip_model_name)
        self.clip.eval()
        self.config = self.clip.config

        if freeze_clip:
            for p in self.clip.parameters():
                p.requires_grad = False

        embed_dim = self.config.projection_dim  # CLIP projection dim (e.g. 512)

        self.quantity_mlp = nn.Sequential(
            nn.Linear(1, quantity_dim),
            nn.ReLU(),
            nn.Linear(quantity_dim, quantity_dim),
            nn.ReLU(),
        )

        self.classifier = nn.Sequential(
            nn.Linear(embed_dim * 2 + quantity_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # binary logit
        )

    def train(self, mode: bool = True):
        """Set training mode, keeping a frozen CLIP backbone in eval mode.

        ``nn.Module.train`` recurses into every submodule, so without this
        override a plain ``model.train()`` would undo the ``self.clip.eval()``
        done in ``__init__``.
        """
        super().train(mode)
        if self.freeze_clip:
            self.clip.eval()
        return self

    def forward(self, pixel_values, input_ids, attention_mask, quantities):
        """Return one logit per sample, shape (B,).

        Args:
            pixel_values, input_ids, attention_mask: Forwarded to CLIP unchanged.
            quantities: Required quantities, shape (B,).
        """
        # CLIP forward
        outputs = self.clip(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        image_embeds = outputs.image_embeds   # (B, D)
        text_embeds  = outputs.text_embeds    # (B, D)

        # quantities: shape (B,) → (B,1)
        q = quantities.unsqueeze(-1)  # normalize if you want, e.g. q/10
        q_emb = self.quantity_mlp(q)  # (B, quantity_dim)

        x = torch.cat([image_embeds, text_embeds, q_emb], dim=-1)
        logits = self.classifier(x).squeeze(-1)  # (B,)
        return logits


In [None]:
clip_model_name = "openai/clip-vit-base-patch32"
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)

dataset = BinOrderDataset(df_samples, IMAGE_DIR, clip_processor)

# Train/val/test split, done per IMAGE rather than per sample.
# Every image contributes several samples (same picture, different quantities
# and item names). Splitting individual samples at random would put the same
# image into train and test and inflate the reported metrics.
val_ratio = 0.1
test_ratio = 0.2
SPLIT_SEED = 42  # fixed so the split is identical on every run

image_ids = dataset.df["image_id"].unique()
shuffled_ids = np.random.default_rng(SPLIT_SEED).permutation(image_ids)

n_val_images = int(len(shuffled_ids) * val_ratio)
n_test_images = int(len(shuffled_ids) * test_ratio)
val_ids = set(shuffled_ids[:n_val_images])
test_ids = set(shuffled_ids[n_val_images:n_val_images + n_test_images])

# dataset.df has a fresh 0..N-1 index, so positions here are valid dataset indices.
sample_image_ids = dataset.df["image_id"]
is_val = sample_image_ids.isin(val_ids).to_numpy()
is_test = sample_image_ids.isin(test_ids).to_numpy()

val_indices = np.flatnonzero(is_val).tolist()
test_indices = np.flatnonzero(is_test).tolist()
train_indices = np.flatnonzero(~(is_val | is_test)).tolist()

train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)
test_dataset = torch.utils.data.Subset(dataset, test_indices)

# Sizes are in samples; the ratios above apply to images, so these are only
# approximately 70/10/20 of the sample count.
train_size = len(train_dataset)
val_size = len(val_dataset)
test_size = len(test_dataset)
print("Train size:", len(train_dataset), "Val size:", len(val_dataset), "Test size:", len(test_dataset))

def collate_fn(batch: List[Dict[str, Any]]):
    """Collate raw dataset items into one tensor batch for the model.

    Args:
        batch: Items with keys ``image``, ``text``, ``required_qty`` and
            ``label``.

    Returns:
        dict with ``pixel_values``, ``input_ids`` and ``attention_mask`` (from
        the module-level ``clip_processor``), plus float32 tensors
        ``quantities`` and ``labels``.

    Raises:
        ValueError: If any item's text is not a ``str`` (e.g. a NaN from a
            missing product name).
    """
    images = [item["image"] for item in batch]
    texts  = [item["text"] for item in batch]

    # Fail fast with a message that names the offending value, instead of
    # printing it and then letting the tokenizer fail with an unrelated error.
    bad_texts = [t for t in texts if not isinstance(t, str)]
    if bad_texts:
        example = bad_texts[0]
        raise ValueError(
            f"collate_fn received {len(bad_texts)} non-string item name(s), "
            f"e.g. {example!r} (type {type(example).__name__}); "
            "clean the sample table before building the dataset."
        )

    qtys   = torch.tensor([item["required_qty"] for item in batch], dtype=torch.float32)
    labels = torch.tensor([item["label"] for item in batch], dtype=torch.float32)

    # Use CLIP processor: tokenizes the texts (padded to the longest in the
    # batch, truncated if needed) and preprocesses the images.
    encoding = clip_processor(
        text=texts,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    batch_dict = {
        "pixel_values": encoding["pixel_values"],
        "input_ids": encoding["input_ids"],
        "attention_mask": encoding["attention_mask"],
        "quantities": qtys,
        "labels": labels,
    }
    return batch_dict

BATCH_SIZE = 256

# Use at most 16 worker processes, but never more than the machine has cores:
# oversubscribing slows loading down and makes DataLoader emit a warning.
NUM_WORKERS = min(16, os.cpu_count() or 1)

# Settings shared by all three loaders; only the dataset and shuffling differ.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
)

# Shuffle only the training data; keep val/test order fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)



In [None]:
# Build the matcher and move it to the selected device.
model = CLIPQuantityMatcher(
    clip_model_name=clip_model_name,
    quantity_dim=32,
    hidden_dim=256,
    freeze_clip=True,   # you can set to False later to fine-tune
)
model = model.to(DEVICE)

# The model outputs a single raw logit per sample, hence the with-logits loss.
criterion = nn.BCEWithLogitsLoss()

# Hand the optimizer only the parameters that are still trainable
# (the frozen CLIP weights have requires_grad=False).
trainable_params = (p for p in model.parameters() if p.requires_grad)
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)

def run_epoch(loader, model, optimizer=None):
    """Run one pass over ``loader`` and return ``(avg_loss, accuracy)``.

    Args:
        loader: Iterable of batches, each a dict with ``pixel_values``,
            ``input_ids``, ``attention_mask``, ``quantities`` and ``labels``.
        model: Module called as ``model(pixel_values, input_ids,
            attention_mask, quantities)`` and returning logits.
        optimizer: When given, the model is put in train mode and updated after
            every batch; when None, the model is put in eval mode and no
            gradients are computed.

    Returns:
        Sample-weighted mean loss and the fraction of correct predictions.

    Raises:
        ValueError: If ``loader`` yields no samples.

    Uses the module-level ``criterion`` and ``DEVICE``.
    """
    is_train = optimizer is not None
    model.train(is_train)

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for batch in tqdm(loader, desc="Train" if is_train else "Val"):
        pixel_values = batch["pixel_values"].to(DEVICE)
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        quantities = batch["quantities"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        with torch.set_grad_enabled(is_train):
            logits = model(pixel_values, input_ids, attention_mask, quantities)
            loss = criterion(logits, labels)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Same decision rule as evaluate_test_set (probability >= 0.5 means
        # positive), so train/val and test accuracy are comparable.
        preds = (torch.sigmoid(logits.detach()) >= 0.5).float()

        batch_size = labels.size(0)
        # Weight by batch size so a smaller final batch does not skew the mean.
        total_loss += loss.item() * batch_size
        total_correct += (preds == labels).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError(
            "run_epoch received a loader with no samples; check the dataset split sizes."
        )

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples
    return avg_loss, accuracy

EPOCHS = 20
best_val_loss = float("inf")

# Train for a fixed number of epochs, checkpointing the head weights whenever
# the validation loss reaches a new minimum.
for epoch in range(1, EPOCHS + 1):
    print(f"\nEpoch {epoch}/{EPOCHS}")

    train_loss, train_acc = run_epoch(train_loader, model, optimizer)
    print(f"  Train loss: {train_loss:.4f}, acc: {train_acc:.4f}")

    val_loss, val_acc = run_epoch(val_loader, model, optimizer=None)
    print(f"  Val   loss: {val_loss:.4f}, acc: {val_acc:.4f}")

    # Nothing to save unless this epoch strictly improved on the best so far.
    if not (val_loss < best_val_loss):
        continue

    best_val_loss = val_loss

    # Only the two heads are stored; the CLIP weights are not part of the file.
    head_state = {
        "quantity_mlp": model.quantity_mlp.state_dict(),
        "classifier": model.classifier.state_dict(),
    }
    torch.save(head_state, "head_weights.pt")

    print("  ✅ Saved new best model.")


In [None]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def evaluate_test_set(model, test_loader, device):
    """Score ``model`` on ``test_loader`` without computing gradients.

    Args:
        model: Module called as ``model(pixel_values, input_ids,
            attention_mask, quantities)`` and returning logits.
        test_loader: Iterable of batch dicts holding those four tensors plus
            ``labels``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy, precision, recall, f1)``. The loss is the
        sample-weighted mean of the module-level ``criterion``; a sample is
        predicted positive when its sigmoid probability is >= 0.5.
    """
    model.eval()

    loss_sum = 0
    n_seen = 0
    predicted = []
    expected = []

    tensor_keys = ("pixel_values", "input_ids", "attention_mask", "quantities", "labels")

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            on_device = {key: batch[key].to(device) for key in tensor_keys}
            labels = on_device["labels"]

            logits = model(
                on_device["pixel_values"],
                on_device["input_ids"],
                on_device["attention_mask"],
                on_device["quantities"],
            )
            batch_loss = criterion(logits, labels)

            # Weight each batch by its size so the final mean is per sample.
            n_batch = labels.size(0)
            loss_sum += batch_loss.item() * n_batch
            n_seen += n_batch

            # Threshold the probabilities into hard 0/1 predictions.
            probabilities = torch.sigmoid(logits).cpu().numpy()
            predicted.extend((probabilities >= 0.5).astype(int))
            expected.extend(labels.cpu().numpy())

    # Compute metrics
    avg_loss = loss_sum / n_seen
    acc = accuracy_score(expected, predicted)
    precision, recall, f1, _ = precision_recall_fscore_support(
        expected, predicted, average="binary"
    )

    return avg_loss, acc, precision, recall, f1

# Evaluate the checkpoint that was selected on validation loss rather than
# whatever weights the final training epoch left in memory.
# The checkpoint holds only the two heads. That is the complete trained state
# as long as CLIP stays frozen (freeze_clip=True); confirm before relying on it
# with a fine-tuned backbone.
BEST_HEADS_PATH = "head_weights.pt"
if os.path.exists(BEST_HEADS_PATH):
    best_heads = torch.load(BEST_HEADS_PATH, map_location=DEVICE)
    model.quantity_mlp.load_state_dict(best_heads["quantity_mlp"])
    model.classifier.load_state_dict(best_heads["classifier"])
else:
    print(f"No checkpoint found at {BEST_HEADS_PATH}; evaluating the current in-memory weights.")

test_loss, test_acc, test_prec, test_rec, test_f1 = evaluate_test_set(
    model, test_loader, DEVICE
)

print("\n====== TEST RESULTS ======")
print(f"Test Loss:      {test_loss:.4f}")
print(f"Test Accuracy:  {test_acc:.4f}")
print(f"Precision:      {test_prec:.4f}")
print(f"Recall:         {test_rec:.4f}")
print(f"F1 Score:       {test_f1:.4f}")
print("==========================\n")
