In [2]:
# Mount Google Drive (Colab only). DATASETS_ROOT, defined later in this notebook,
# points at a folder under /content/drive/MyDrive, so the mount must succeed first.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Fetch the Crane project sources; the project-local packages imported further
# down (models, dataset, utils) live in this checkout.
!git clone https://github.com/YasinSharifbeigy/Crane.git

Cloning into 'Crane'...
remote: Enumerating objects: 358, done.[K
remote: Counting objects: 100% (25/25), done.[K
remote: Compressing objects: 100% (21/21), done.[K
remote: Total 358 (delta 7), reused 14 (delta 4), pack-reused 333 (from 1)[K
Receiving objects: 100% (358/358), 9.14 MiB | 16.48 MiB/s, done.
Resolving deltas: 100% (156/156), done.


In [4]:
# Work from the repository root so that `import models`, `from dataset...` and
# `from utils...` resolve, and so relative paths such as ./checkpoints land in the repo.
%cd Crane

/content/Crane


In [5]:
!pip install humanhash3
!pip install torchmetrics
!pip install ftfy

Collecting humanhash3
  Downloading humanhash3-0.0.6.tar.gz (5.4 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: humanhash3
  Building wheel for humanhash3 (setup.py) ... [?25l[?25hdone
  Created wheel for humanhash3: filename=humanhash3-0.0.6-py3-none-any.whl size=5739 sha256=37a1869326a5baf5e72aad226a11b4a79cb3efc41d4e3f4bea914b2043317931
  Stored in directory: /root/.cache/pip/wheels/24/9c/a6/944e159dadbcc308d4b0b12e536503849bf4b9bbcc4036fae6
Successfully built humanhash3
Installing collected packages: humanhash3
Successfully installed humanhash3-0.0.6
Collecting torchmetrics
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ

In [6]:
# Sanity check: list the repository root to confirm the clone and the working directory.
!ls

assets	     environment.yml  models	    runtime.sh	      test.py	utils
checkpoints  __init__.py      README.md     segment_anything  test.sh
dataset      LICENSE	      reproduce.sh  setup.sh	      train.py


In [7]:
!mkdir '.cache'

In [8]:
!mkdir '.cache/sam'

In [9]:

# !wget -O .cache/sam/sam_vit_b_01ec64.pth \
# https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

In [10]:
# Root folder (on the mounted Drive) that contains the dataset sub-folders.
# NOTE(review): the imports cell later runs `from __init__ import DATASETS_ROOT`,
# which rebinds this name to whatever the repository's __init__.py defines.
# Keep the two values in sync, or this assignment has no effect downstream.
DATASETS_ROOT = '/content/drive/MyDrive/Colposcopy/Data/Patient records'

In [11]:
# !bash reproduce.sh Colpo_model 0 1

In [12]:
# ==========================================================
# 📦 1. Imports
# ==========================================================
import os, sys, subprocess
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import confusion_matrix

# your project-specific imports
import models
from models import Crane
from models.prompt_ensemble import PromptLearner
from dataset.dataset import Dataset
from __init__ import DATASETS_ROOT
from utils.transform import get_transform
from utils.logger import get_logger
from utils.similarity import calc_similarity_logits
from utils import (
    setup_seed,
    seed_worker,
    turn_gradient_off,
    str2bool,
)
from utils.loss import FocalLoss


In [13]:
# ==========================================================
# ⚙️ 2. Argument setup (for interactive control)
# ==========================================================
class Args:
    """Static stand-in for the training script's command-line arguments.

    Every option is a class attribute, so ``Args()`` yields an object whose
    fields are read as ``args.<name>`` by the cells that follow. Options that
    no visible cell reads are still forwarded to project code (the object is
    passed as ``kwargs=args`` to the dataset and as ``"others"`` to the model).
    """

    # --- data locations and run identity --------------------------------
    datasets_root_dir = str(DATASETS_ROOT)
    dataset = ["Cropped Folder"]
    # One training root per dataset name, each with a trailing slash.
    train_data_path = [f"{DATASETS_ROOT}/{name}/" for name in dataset]
    model_name = "Colpo_model_crane"
    # Checkpoints are written below this folder by the training loop.
    save_path = f"./checkpoints/trained_on_{dataset[0]}_{model_name}"
    type = "train"
    why = "Jupyter experiment"

    # --- optimisation ---------------------------------------------------
    seed = 111            # fed to setup_seed() and to the DataLoader generator
    save_freq = 1         # a checkpoint is saved every `save_freq` epochs
    device = 0            # not read by the visible cells (device is chosen separately)
    epoch = 2             # number of passes over the training loader
    learning_rate = 1e-3  # Adam learning rate for the prompt learner
    batch_size = 8

    # --- data handling (consumed by project code, not by the visible cells)
    aug_rate = 0.0
    k_shot = 0
    portion = 1
    image_size = 518
    interpolation = "nearest"

    # --- model / prompt configuration -----------------------------------
    features_list = [24]  # forwarded to model.encode_image()
    depth = 9             # sent as "learnabel_text_embedding_depth"
    n_ctx = 12            # sent as "Prompt_length"
    t_n_ctx = 4           # sent as "learnabel_text_embedding_length"
    train_with_img_cls_prob = 1
    train_with_img_cls_type = "pad_suffix"
    dino_model = "dinov2"         # "none" skips model.use_DAttn()
    both_eattn_dattn = True
    use_scorebase_pooling = True  # enables the score-base pooling branch in training
    attn_type = "qq+kk+vv"        # forwarded to model.visual.replace_with_EAttn()


# Instantiate the options and seed via the project's setup_seed helper before any data/model work.
args = Args()
setup_seed(args.seed)


In [14]:
# ==========================================================
# 🧠 3. Load Dataset and Model
# ==========================================================
# Build the image/target transforms from the run options, then the training
# dataset over every root listed in args.train_data_path.
preprocess, target_transform = get_transform(args)
train_data = Dataset(
    roots=args.train_data_path,
    transform=preprocess,
    target_transform=target_transform,
    dataset_name=args.dataset,
    kwargs=args
)

# Shuffling uses a generator seeded with args.seed and every worker runs
# seed_worker on start-up, so the batch order is tied to the seed.
# NOTE(review): num_workers=8 may exceed the CPU count of a Colab runtime — confirm.
train_loader = DataLoader(
    train_data,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    prefetch_factor=2,
    generator=torch.Generator().manual_seed(args.seed),
    worker_init_fn=seed_worker
)

# The device is chosen here from CUDA availability; args.device is not consulted.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"‚úÖ Device: {device}, Dataset length: {len(train_data)}")

# -------------------- Load Model --------------------
# Design details handed to the model loader and to the prompt learner.
# The key spellings (including "learnabel") are runtime strings and are kept
# exactly as written; presumably project code looks them up verbatim — confirm.
crane_parameters = {
    "Prompt_length": args.n_ctx,
    "learnabel_text_embedding_depth": args.depth,
    "learnabel_text_embedding_length": args.t_n_ctx,
    "others": args
}
model, _ = models.load("ViT-L/14@336px", device=device, design_details=crane_parameters)
# The recorded output reports gradients turned off for both encoders, so the
# backbone itself is not trained.
model = turn_gradient_off(model)
# to_layer=20 is the same value later passed as self_cor_attn_layers=20 to
# encode_image in the training loop — presumably they must agree; confirm.
model.visual.replace_with_EAttn(to_layer=20, type=args.attn_type)
if args.dino_model != "none":
    model.use_DAttn(args.dino_model)

# The prompt learner is constructed with the model moved to CPU; afterwards
# both the model and the prompt learner are moved to the selected device.
prompt_learner = PromptLearner(model.to("cpu"), crane_parameters)
sbp = Crane.ScoreBasePooling()
model.to(device)
prompt_learner.to(device)

# Only the prompt learner's parameters are given to the optimizer.
optimizer = torch.optim.Adam(prompt_learner.parameters(), lr=args.learning_rate, betas=(0.6, 0.999))





number of samples: 318
‚úÖ Device: cuda, Dataset length: 318
name ViT-L/14@336px


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 934M/934M [00:51<00:00, 18.1MiB/s]


Turning off gradients in both the image and the text encoder
Parameters to be updated: set()
Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip




Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vitb14_reg4_pretrain.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 330M/330M [00:01<00:00, 336MB/s]


In [17]:
class FocalLoss(torch.nn.Module):
    """Multi-class focal loss computed from raw (unnormalised) logits.

    For each sample the loss is ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``,
    where ``p_t`` is the softmax probability of the target class and
    ``alpha_t`` is the optional per-class weight of that class.

    Note: this notebook-level definition rebinds the ``FocalLoss`` name that
    the imports cell brought in from ``utils.loss``.

    Args:
        weight: optional 1-D tensor (or sequence) holding one weight per
            class; ``None`` weights every class equally.
        gamma: focusing exponent applied to ``(1 - p_t)``.
        reduction: ``"mean"`` or ``"sum"``; any other value returns the
            unreduced per-sample loss.
    """

    def __init__(self, weight=None, gamma=2.0, reduction="mean"):
        super().__init__()
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, target):
        """Return the focal loss of ``logits`` against class indices ``target``.

        Args:
            logits: tensor whose last dimension enumerates the classes; all
                leading dimensions are flattened into one sample axis.
            target: integer class indices, one per flattened sample
                (e.g. shape ``[B]`` or ``[B, 1]``).

        Returns:
            A scalar for ``"mean"``/``"sum"`` reduction, otherwise a 1-D
            tensor with one loss value per sample.

        Raises:
            ValueError: if the number of targets differs from the number of
                flattened samples.
        """
        # reshape (not view) so that non-contiguous logits are accepted.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_target = target.reshape(-1).to(device=flat_logits.device, dtype=torch.long)

        # Guard against silently scoring only a prefix of the batch.
        if flat_target.numel() != flat_logits.size(0):
            raise ValueError(
                f"FocalLoss got {flat_target.numel()} targets for "
                f"{flat_logits.size(0)} samples"
            )

        # Log-probability (and probability) of the correct class per sample.
        log_probs = F.log_softmax(flat_logits, dim=-1)
        log_pt = log_probs.gather(1, flat_target.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()

        # Per-sample class weight; moved to the logits' device so a weight
        # created on CPU still works with CUDA inputs.
        if self.weight is not None:
            at = torch.as_tensor(self.weight, device=flat_logits.device)[flat_target]
        else:
            at = 1.0

        loss = -at * ((1 - pt) ** self.gamma) * log_pt

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss


In [30]:
# Alternative kept for reference: plain class-weighted cross-entropy.
# criterion = torch.nn.CrossEntropyLoss(weight=torch.tensor([1.0, 10.0], device=device))
# Active loss: focal loss with class index 1 weighted 100x relative to index 0.
# NOTE(review): the recorded training output shows confusion matrices whose
# first column is all zeros (no sample is ever predicted as class 0); this
# heavy weight on class 1 may be contributing — confirm by re-balancing.
criterion = FocalLoss(weight=torch.tensor([1, 100.0], device=device), gamma=2.0)
# Alternative kept for reference: unweighted focal loss.
# criterion = FocalLoss(gamma=2.0)


In [31]:
# ==========================================================
# 🚀 4. Training Loop (with confusion matrix logging)
# ==========================================================
# Trains the prompt learner for args.epoch epochs, prints a windowed confusion
# matrix every `log_interval` batches, and checkpoints every args.save_freq epochs.
log_interval = 10  # show confusion matrix every N batches
# NOTE(review): the recorded output shows every epoch log line six times, which
# suggests the logger ends up with duplicate handlers — verify get_logger.
logger = get_logger(args.save_path)

# The backbone stays in eval mode; only the prompt learner is in train mode
# (and only its parameters were handed to the optimizer).
model.eval()
prompt_learner.train()

for epoch in tqdm(range(args.epoch), desc="Epochs"):
    loss_list = []                    # every batch loss of this epoch
    all_preds, all_labels = [], []    # predictions/labels since the last report

    for batch_idx, items in enumerate(train_loader):
        # Each batch is a mapping; the image-level label is read from "anomaly".
        label = items["anomaly"].to(device)
        image = items["img"].to(device)

        # Encode
        # encode_image returns a global image feature plus patch features that
        # are iterated as a sequence below.
        image_features, patch_features = model.encode_image(image, args.features_list, self_cor_attn_layers=20)
        image_features = F.normalize(image_features, dim=-1)

        # Text prompts
        # The prompt learner receives the image embedding and reports through
        # its last return value which of the two encoding paths to take.
        prompts, tokenized_prompts, compound_prompts_text, is_train_with_img_cls = prompt_learner(img_emb=image_features)
        if is_train_with_img_cls:
            # Two prompt sets (index 0 / index 1) are encoded separately and stacked on dim 1.
            text_features_nrm = model.encode_text_learn(prompts[0], tokenized_prompts[0], compound_prompts_text)
            text_features_anm = model.encode_text_learn(prompts[1], tokenized_prompts[1], compound_prompts_text)
            text_features = torch.stack([text_features_nrm, text_features_anm], dim=1)
        else:
            # A single encoding pass; a leading dimension is added afterwards.
            text_features = model.encode_text_learn(prompts, tokenized_prompts, compound_prompts_text).unsqueeze(dim=0)
        text_features = F.normalize(text_features, dim=-1).float()

        # Score-base pooling (optional)
        if args.use_scorebase_pooling:
            # One similarity result per entry of patch_features (temperature 0.07).
            sms = [calc_similarity_logits(pf, text_features, temp=0.07) for pf in patch_features]
            patch_features = torch.stack(patch_features, dim=0)  # FIX (notebook-local): stack the per-layer entries into one tensor before pooling
            clustered_feature = sbp.forward(patch_features, sms)
            # Equal-weight blend of the pooled feature and the global feature, then re-normalise.
            image_features = 0.5 * clustered_feature + 0.5 * image_features
            image_features = F.normalize(image_features, dim=1)


        # Classification
        # Image-level logits use a sharper temperature (0.01) than the patch-level ones (0.07).
        image_logits = calc_similarity_logits(image_features, text_features, temp=0.01)
        loss = criterion(image_logits, label.long())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record loss and predictions
        loss_list.append(loss.item())
        preds = image_logits.argmax(dim=-1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(label.cpu().numpy())

        # Show confusion matrix every N batches
        # The printed loss is the mean over the whole epoch so far (loss_list is
        # only reset per epoch), whereas the confusion matrix and accuracy cover
        # only the batches since the previous report.
        if (batch_idx + 1) % log_interval == 0:
            cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
            acc = np.trace(cm) / np.sum(cm)
            print(f"\nEpoch {epoch+1}, Batch {batch_idx+1}: Loss={np.mean(loss_list):.4f}, Acc={acc:.3f}")
            print(cm)
            all_preds, all_labels = [], []  # reset window

    # End of epoch
    # Any batches after the last multiple of log_interval are never reported
    # in a confusion matrix; only their loss enters the epoch average.
    avg_loss = np.mean(loss_list)
    logger.info(f"Epoch [{epoch+1}/{args.epoch}] - loss: {avg_loss:.4f}")

    # Save model checkpoint
    # Only the prompt learner's state dict is stored, under the "prompt_learner" key.
    if (epoch + 1) % args.save_freq == 0:
        os.makedirs(args.save_path, exist_ok=True)
        ckpt_path = os.path.join(args.save_path, f"epoch_{epoch+1}.pth")
        torch.save({"prompt_learner": prompt_learner.state_dict()}, ckpt_path)
        print(f"‚úÖ Saved: {ckpt_path}")





Epoch 1, Batch 10: Loss=0.7123, Acc=0.600
[[ 0 32]
 [ 0 48]]

Epoch 1, Batch 20: Loss=0.7224, Acc=0.650
[[ 0 28]
 [ 0 52]]

Epoch 1, Batch 30: Loss=0.7650, Acc=0.588
[[ 0 33]
 [ 0 47]]

Epoch 1, Batch 40: Loss=0.7211, Acc=0.692
[[ 0 24]
 [ 0 54]]


25-11-15 11:52:47.628 - INFO: Epoch [1/2] - loss: 0.7211
25-11-15 11:52:47.628 - INFO: Epoch [1/2] - loss: 0.7211
25-11-15 11:52:47.628 - INFO: Epoch [1/2] - loss: 0.7211
25-11-15 11:52:47.628 - INFO: Epoch [1/2] - loss: 0.7211
25-11-15 11:52:47.628 - INFO: Epoch [1/2] - loss: 0.7211
25-11-15 11:52:47.628 - INFO: Epoch [1/2] - loss: 0.7211
Epochs:  50%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 1/2 [04:27<04:27, 267.85s/it]

‚úÖ Saved: ./checkpoints/trained_on_Cropped Folder_Colpo_model_crane/epoch_1.pth





Epoch 2, Batch 10: Loss=0.7362, Acc=0.625
[[ 0 30]
 [ 0 50]]

Epoch 2, Batch 20: Loss=0.6987, Acc=0.662
[[ 0 27]
 [ 0 53]]

Epoch 2, Batch 30: Loss=0.7151, Acc=0.613
[[ 0 31]
 [ 0 49]]

Epoch 2, Batch 40: Loss=0.6981, Acc=0.628
[[ 0 29]
 [ 0 49]]


25-11-15 11:57:15.399 - INFO: Epoch [2/2] - loss: 0.6981
25-11-15 11:57:15.399 - INFO: Epoch [2/2] - loss: 0.6981
25-11-15 11:57:15.399 - INFO: Epoch [2/2] - loss: 0.6981
25-11-15 11:57:15.399 - INFO: Epoch [2/2] - loss: 0.6981
25-11-15 11:57:15.399 - INFO: Epoch [2/2] - loss: 0.6981
25-11-15 11:57:15.399 - INFO: Epoch [2/2] - loss: 0.6981
Epochs: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [08:55<00:00, 267.81s/it]

‚úÖ Saved: ./checkpoints/trained_on_Cropped Folder_Colpo_model_crane/epoch_2.pth





In [32]:
# Run the repository's reproduce.sh; the recorded output shows it testing the
# "Colpo_model_crane" version on the Cropped Folder data and printing metrics.
# NOTE(review): the meaning of the positional arguments (Colpo_model, 0, 2) is
# defined inside reproduce.sh, which is not visible here — confirm before changing.
!bash reproduce.sh Colpo_model 0 2

The name for base version (Crane) is: Colpo_model_crane
Testing on dataset from: ['/content/drive/MyDrive/Colposcopy/Data/Patient records/Cropped Folder/']
Results will be saved to: [32m./results/trained_on_Cropped Folder_Colpo_model_crane/test_on_Cropped Folder/salami-nuts[0m
name ViT-L/14@336px
Turning off gradients in both the image and the text encoder
Parameters to be updated: set()
number of samples: 106
process_dataset, Batch size: 8
Processing test samples: 100% 14/14 [01:26<00:00,  6.17s/it]
calculating metrics for coloposcopy
Best threshold=0.5350, F1=0.8235
Confusion Matrix (labels=[0,1], threshold=0.53):
[[TN=   6, FP=  27],
 [FN=   3, TP=  70]]
Best threshold=0.5350, F1=0.8235
Confusion Matrix (labels=[0,1], threshold=0.53):
[[TN=   6, FP=  27],
 [FN=   3, TP=  70]]
Best threshold=0.5350, F1=0.8235
Confusion Matrix (labels=[0,1], threshold=0.53):
[[TN=   6, FP=  27],
 [FN=   3, TP=  70]]
only one class present, can not calculate pixel metrics
only one class present, can 