In [1]:
import argparse
import os
import pickle
import time
from pathlib import Path
from typing import Dict

import numpy as np
import pandas as pd
import tqdm
import json
import abc
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from transformers import PreTrainedTokenizer

from DatModule.datamodule import DataModule, get_transforms, MammoDataset_Mapper
from breastclip.model import BreastClip, MammoClassification, MammoEfficientNet, load_image_encoder



In [2]:
def mammo_factor_model(model_config: Dict, loss_config: Dict, tokenizer: PreTrainedTokenizer = None) -> nn.Module:
    name = model_config["name"].lower()
    if name == "clip_custom":
        return BreastClip(model_config, loss_config, tokenizer)
    elif name == "finetune_classification":
        model_type = model_config["image_encoder"].get("model_type", "vit")
        return MammoClassification(model_config, model_type)
    elif name == "pretrained_classifier":
        return MammoEfficientNet(model_config)
    else:
        raise KeyError(f"Not supported model: {model_config['name']}")

In [3]:
class Mapper_model(torch.nn.Module):
    def __init__(self, ckpt, lang_emb: int, emb_dim: int, one_proj: bool, adapter: bool, attr_embs):
        super(Mapper_model, self).__init__()
        self.image_encoder = load_image_encoder(ckpt["config"]["model"]["image_encoder"])
        image_encoder_weights = {}
        for k in ckpt["model"].keys():
            if k.startswith("image_encoder."):
                image_encoder_weights[".".join(k.split(".")[1:])] = ckpt["model"][k]
        self.image_encoder.load_state_dict(image_encoder_weights, strict=True)
        self.image_encoder_type = ckpt["config"]["model"]["image_encoder"]["model_type"]
        for param in self.image_encoder.parameters():
            param.requires_grad = False

        self.emb_dim = emb_dim
        self.lang_emb = lang_emb
        self.one_proj = one_proj
        self.adapter = adapter

        if self.one_proj:
            self.num_proj = 1
        else:
            self.num_proj = len(attr_embs)
        self.pool = torch.nn.ModuleList(
            [
                torch.nn.Sequential(
                    torch.nn.Linear(self.emb_dim, self.emb_dim),
                    torch.nn.ReLU(),
                    torch.nn.Linear(self.emb_dim, self.lang_emb),
                )
                for i in range(self.num_proj)
            ]
        )

    def encode_image(self, input):
        image_features, raw_features = self.image_encoder(input)
        return image_features, raw_features

    def forward(self, sample: dict):
        out_dict = {}

        img_vector = sample["img"].to(torch.float32).to("cuda")
        if len(img_vector.size()) == 5:
            img_vector = img_vector.squeeze(1).permute(0, 3, 1, 2)
        input = {"image": img_vector}

        image_features, raw_features = self.encode_image(input)
        bs = raw_features.size(0)
        channel_dim = raw_features.size(1)
        raw_features_flatten = raw_features.view(bs, channel_dim, -1)

        out_img_a = []
        for i in range(self.num_proj):
            pool = self.pool[i](raw_features_flatten)
            if self.adapter:
                pool = 0.2 * pool + 0.8 * raw_features_flatten
            out_img_a.append(pool)

        region_proj_embs = torch.cat(out_img_a, dim=1).view(
            -1, self.num_proj, self.lang_emb
        )
        out_dict["region_proj_embs"] = region_proj_embs
        out_dict["num_regions"] = torch.tensor(channel_dim)
        out_dict["image_features"] = image_features
        out_dict["raw_features"] = raw_features
        return out_dict


In [4]:
class BaseLoss(torch.nn.Module):
    __metaclass__ = abc.ABC

    def __init__(self):
        super().__init__()

        self.iteration = 0
        self.running_loss = 0
        self.mean_running_loss = 0

    def forward(self, input):
        return input

    def update_running_loss(self, loss):
        self.iteration += 1
        self.running_loss += loss.item()
        self.mean_running_loss = self.running_loss / self.iteration


class Mapper_loss(BaseLoss):
    def __init__(self, temp: float, one_proj: bool, attr_to_embs):
        super().__init__()

        self.temperature = temp
        self.one_proj = one_proj

        self.attr_embs = []
        for a in attr_to_embs:
            self.attr_embs.append(attr_to_embs[a])
        self.attr_embs = torch.tensor(np.stack(self.attr_embs)).cuda().to(torch.float32)

    def forward(self, pred: dict, sample: dict):
        anchor_img = torch.nn.functional.normalize(
            pred["region_proj_embs"].float(), dim=2
        )
        labels = sample["labels"].to(int)

        attr_ids = labels.sum(0).nonzero().flatten().tolist()
        batch_size = labels.shape[0]

        loss = torch.tensor(0.0).cuda()
        num_loss_terms = 0

        if self.one_proj:
            txt_emb = self.attr_embs[attr_ids, :].T.unsqueeze(0)
            sim = (anchor_img @ txt_emb).squeeze() / self.temperature
        else:
            reg_emb = anchor_img[:, attr_ids, :]
            txt_emb = self.attr_embs[attr_ids, :].unsqueeze(0)
            sim = (reg_emb * txt_emb).sum(2) / self.temperature

        split = torch.split(sim, pred["num_regions"].tolist(), dim=0)
        vals, _ = zip(*map(torch.max, split, [0] * batch_size))

        sim = torch.stack(vals)  # batch_size x len(attr_ids)
        true_label = labels[:, attr_ids].cuda()
        inv_true_label = (~true_label.bool()).to(int)

        # Compute final contrastive loss
        denom = torch.exp(sim) + torch.exp(sim * inv_true_label).sum(0, keepdims=True)
        loss = ((-torch.log(torch.exp(sim) / denom)) * true_label).sum(1, keepdims=True)
        num_loss_terms = true_label.sum()
        loss = loss.sum() / num_loss_terms
        self.update_running_loss(loss)
        return loss


In [5]:
ATTRIBUTES = [
    "mass",
    "suspicious_calcification",
]

In [6]:
def generate_attribute_embs(out_dir, breast_clip_path, model_name):
    def get_prompts(attr):
        if attr == "mass":
            return [
                # your mass prompts here
            ]
        elif attr == "suspicious_calcification":
            return [
                # your calcification prompts here
            ]
        return []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    breast_clip_path = Path(breast_clip_path)
    ckpt = torch.load(breast_clip_path, map_location=device)
    cfg = ckpt["config"]

    datamodule = DataModule(
        data_config=cfg["data_train"],
        dataloader_config=cfg["dataloader"],
        tokenizer_config=cfg.get("tokenizer"),
        loss_config=cfg["loss"],
        transform_config=cfg["transform"],
        mean=cfg["base"]["mean"],
        std=cfg["base"]["std"],
        image_encoder_type=cfg["model"]["image_encoder"]["model_type"],
        cur_fold=cfg["base"]["fold"],
    )

    clip = mammo_factor_model(cfg["model"], cfg["loss"], datamodule.tokenizer)
    clip = clip.to(device)
    clip.load_state_dict(ckpt["model"], strict=True)
    clip.eval()

    attr_embs = []
    with torch.no_grad():
        for attr in ATTRIBUTES:
            prompts = get_prompts(attr)
            tokens = datamodule.tokenizer(
                prompts,
                padding="longest",
                truncation=True,
                return_tensors="pt",
                max_length=256,
            ).to(device)
            txt_feats = clip.encode_text(tokens)
            if hasattr(clip, "projection") and clip.projection:
                txt_feats = clip.text_projection(txt_feats)
            avg_emb = txt_feats.mean(dim=0, keepdim=True)
            avg_emb = avg_emb / avg_emb.norm(dim=-1, keepdim=True)
            attr_embs.append(avg_emb)

    attr_embs = torch.cat(attr_embs, dim=0).cpu().numpy()
    attr_to_emb = dict(zip(ATTRIBUTES, attr_embs))

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    save_path = out_dir / f"attr_embs_{model_name}.pth"
    torch.save(attr_to_emb, save_path)
    print(f"Saved {len(attr_to_emb)} attribute embeddings to {save_path}")



In [7]:
class Mapper_model(nn.Module):
    def __init__(self, ckpt, lang_emb: int, emb_dim: int, one_proj: bool, adapter: bool, attr_embs):
        super().__init__()
        # load frozen image encoder
        self.image_encoder = load_image_encoder(ckpt["config"]["model"]["image_encoder"])
        enc_w = {
            ".".join(k.split(".")[1:]): v
            for k, v in ckpt["model"].items()
            if k.startswith("image_encoder.")
        }
        self.image_encoder.load_state_dict(enc_w, strict=True)
        for p in self.image_encoder.parameters():
            p.requires_grad = False

        self.lang_emb = lang_emb
        self.emb_dim = emb_dim
        self.one_proj = one_proj
        self.adapter = adapter
        self.num_proj = 1 if one_proj else len(attr_embs)

        # projection heads
        self.pool = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.emb_dim, self.emb_dim),
                nn.ReLU(),
                nn.Linear(self.emb_dim, self.lang_emb),
            )
            for _ in range(self.num_proj)
        ])

    def encode_image(self, x):
        return self.image_encoder(x)

    def forward(self, sample: dict):
        img = sample["img"].to(torch.float32).to(next(self.parameters()).device)
        if img.ndim == 5:
            img = img.squeeze(1).permute(0, 3, 1, 2)

        image_feats, raw = self.encode_image({"image": img})
        bs, c, h, w = raw.shape
        regions = raw.view(bs, c, -1)  # (B, C, N)

        proj_outs = []
        for i in range(self.num_proj):
            p = self.pool[i](regions)  # (B, C, lang_emb)
            if self.adapter:
                p = 0.2 * p + 0.8 * regions
            proj_outs.append(p)

        region_proj_embs = torch.stack(proj_outs, dim=1)  # (B, num_proj, lang_emb)
        return {
            "region_proj_embs": region_proj_embs,
            "num_regions": torch.tensor(c),
            "image_features": image_feats,
            "raw_features": raw,
        }


In [8]:
class Mapper_loss(nn.Module):
    def __init__(self, temp: float, one_proj: bool, attr_to_embs):
        super().__init__()
        self.temperature = temp
        self.one_proj = one_proj
        embs = [attr_to_embs[a] for a in ATTRIBUTES]
        self.attr_embs = torch.tensor(np.stack(embs), dtype=torch.float32).cuda()

    def forward(self, pred: dict, sample: dict):
        anchor = nn.functional.normalize(pred["region_proj_embs"], dim=2)
        labels = sample["labels"].to(torch.int)
        attr_ids = labels.sum(0).nonzero().flatten().tolist()
        batch = labels.size(0)

        if self.one_proj:
            txt = self.attr_embs[attr_ids].T.unsqueeze(0)
            sim = (anchor @ txt).squeeze() / self.temperature
        else:
            reg = anchor[:, attr_ids, :]
            txt = self.attr_embs[attr_ids].unsqueeze(0)
            sim = (reg * txt).sum(2) / self.temperature

        # max over regions
        sims = []
        for i in range(batch):
            sims.append(sim[i].max(0)[0])
        sim = torch.stack(sims)

        true = labels[:, attr_ids].cuda()
        inv_true = (~true.bool()).int()
        denom = torch.exp(sim) + torch.exp(sim * inv_true).sum(0, keepdim=True)
        loss = (-torch.log(torch.exp(sim) / denom) * true).sum() / true.sum()
        return loss


In [9]:
def get_dataloaders(args):
    df = pd.read_csv(Path(args.data_dir) / args.csv_file).fillna(0)
    df = df[(df["Mass"] == 1) | (df["Suspicious_Calcification"] == 1)]
    train_df = df[df["split"] == "training"].reset_index(drop=True)
    valid_df = df[df["split"] == "test"].reset_index(drop=True)

    train_ds = MammoDataset_Mapper(args, train_df, transform=get_transforms(args))
    valid_ds = MammoDataset_Mapper(args, valid_df, transform=None)

    if args.balanced_dataloader.lower() == "y":
        w_path = Path(args.output_path) / f"weights_fold{args.cur_fold}.pkl"
        if w_path.exists():
            weights = pickle.load(open(w_path, "rb"))
        else:
            pos_w = args.sampler_weights[f"fold{args.cur_fold}"]["pos_wt"]
            neg_w = args.sampler_weights[f"fold{args.cur_fold}"]["neg_wt"]
            train_df["w"] = train_df["cancer"].map({1: pos_w, 0: neg_w})
            weights = train_df["w"].values
            pickle.dump(weights, open(w_path, "wb"))

        sampler = WeightedRandomSampler(weights.tolist(), len(weights), replacement=True)
        train_loader = DataLoader(train_ds, batch_size=args.batch_size, sampler=sampler,
                                  num_workers=args.num_workers, pin_memory=True, drop_last=True)
    else:
        train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                                  num_workers=args.num_workers, pin_memory=True, drop_last=True)

    valid_loader = DataLoader(valid_ds, batch_size=args.batch_size, shuffle=False,
                              num_workers=args.num_workers, pin_memory=True)
    return train_loader, valid_loader


In [12]:
def train_loop(model, loss_fn, opt, train_loader, valid_loader, epochs, chk_pt_path, device):
    scaler = torch.cuda.amp.GradScaler()
    best_val = float("inf")
    no_improve = 0

    for epoch in range(1, epochs + 1):
        model.train()
        t0 = time.time()
        for sample in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"):
            opt.zero_grad()
            with torch.cuda.amp.autocast():
                pred = model(sample)
                loss = loss_fn(pred, sample)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

        print(f"Train epoch {epoch} done in {int(time.time() - t0)}s")
        torch.save(model.state_dict(), Path(chk_pt_path) / f"epoch_{epoch}.pth")

        # validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for sample in tqdm(valid_loader, desc="Validating"):
                pred = model(sample)
                val_loss += loss_fn(pred, sample).item()
        val_loss /= len(valid_loader)
        print(f"Validation loss: {val_loss:.4f}")

        if val_loss < best_val:
            best_val = val_loss
            no_improve = 0
            torch.save(model.state_dict(), Path(chk_pt_path) / "best.pth")
        else:
            no_improve += 1

        if no_improve >= 5:
            print("Early stopping.")
            break



In [None]:
data_dir           = "/path/to/data"
csv_file           = "./dataset/metadata.csv"
output_dir         = "Outputs"
clip_chk_pt_path   = ""
attr_embs_path     = ""

lang_emb           = 512
img_emb            = 1024
batch_size         = 16
num_workers        = 4
lr                 = 5e-5
epochs             = 20

balanced_dataloader = False  # Set to True to use WeightedRandomSampler
sampler_weights     = None
cur_fold            = 0

if balanced_dataloader:
    with open("[Data Here]", "r") as f:
        sampler_weights = json.load(f)

def get_Paths(output_dir):
    base        = Path(output_dir)
    chkpt_dir   = base / "checkpoints"
    results_dir = base / "results"
    tb_logs_dir = base / "tb_logs"
    for d in (chkpt_dir, results_dir, tb_logs_dir):
        d.mkdir(parents=True, exist_ok=True)
    return str(chkpt_dir), str(results_dir), str(tb_logs_dir)

chkpt_path, results_path, tb_logs_path = get_Paths(output_dir)

device    = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}, torch version: {torch.__version__}")

ckpt      = torch.load(clip_chk_pt_path, map_location="cpu")
attr_embs = torch.load(attr_embs_path)


model = Mapper_model(ckpt, lang_emb=lang_emb, emb_dim=img_emb, one_proj=False, adapter=False, attr_embs=attr_embs).to(device)

loss_fn = Mapper_loss(temp=0.07, one_proj=False, attr_to_embs=attr_embs).to(device)


class Args: pass
args = Args()
args.data_dir            = data_dir
args.csv_file            = csv_file
args.batch_size          = batch_size
args.num_workers         = num_workers
args.balanced_dataloader = "y" if balanced_dataloader else "n"
args.sampler_weights     = sampler_weights
args.cur_fold            = cur_fold
args.output_path         = results_path

train_loader, valid_loader = get_dataloaders(args)

train_loop(
    model,
    loss_fn,
    opt,
    train_loader,
    valid_loader,
    epochs=epochs,
    chk_pt_path=chkpt_path,
    device=device
)

Using device: cpu, torch version: 2.6.0+cpu


FileNotFoundError: [Errno 2] No such file or directory: './Mammo_CLIP/checkpoints/'