In [1]:
# Cloner le dépôt
!git clone https://github.com/mahmoodlab/UNI.git
# Entrez dans le répertoire de l'entrepôt
%cd UNI

Cloning into 'UNI'...
remote: Enumerating objects: 158, done.[K
remote: Counting objects: 100% (65/65), done.[K
remote: Compressing objects: 100% (38/38), done.[K
remote: Total 158 (delta 36), reused 37 (delta 27), pack-reused 93 (from 3)[K
Receiving objects: 100% (158/158), 7.09 MiB | 10.48 MiB/s, done.
Resolving deltas: 100% (60/60), done.
/content/UNI


In [2]:
!pip install -e .

Obtaining file:///content/UNI
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting timm==0.9.8 (from uni==0.1.0)
  Downloading timm-0.9.8-py3-none-any.whl.metadata (59 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.3/59.3 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
Downloading timm-0.9.8-py3-none-any.whl (2.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m24.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: timm, uni
  Attempting uninstall: timm
    Found existing installation: timm 1.0.13
    Uninstalling timm-1.0.13:
      Successfully uninstalled timm-1.0.13
  Running setup.py develop for uni
Successfully installed timm-0.9.8 uni-0.1.0


In [3]:
import torch
import torchvision
import os
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np

# loading all packages here to start
from uni import get_encoder
from uni.downstream.extract_patch_features import extract_patch_features_from_dataloader
from uni.downstream.eval_patch_features.linear_probe import eval_linear_probe
from uni.downstream.eval_patch_features.fewshot import eval_knn, eval_fewshot
from uni.downstream.eval_patch_features.protonet import ProtoNet, prototype_topk_vote
from uni.downstream.eval_patch_features.metrics import get_eval_metrics, print_metrics
from uni.downstream.utils import concat_images
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Downloading UNI weights + Creating Model

Nous pouvons également télécharger les poids du modèle vers un emplacement de point de contrôle spécifié dans votre répertoire local. La bibliothèque timm est toujours utilisée pour définir l'architecture du modèle ViT-L/16. Les poids pré-entraînés et les transformations d'image pour UNI doivent être chargés et définis manuellement.

In [6]:
import os
import torch
from torchvision import transforms
import timm
from huggingface_hub import login, hf_hub_download

# Connexion au compte Hugging Face
login()

# Définir le chemin local pour stocker les poids
local_dir = "../assets/ckpts/uni2-h/"
os.makedirs(local_dir, exist_ok=True)

# Télécharger le fichier des poids
hf_hub_download("MahmoodLab/UNI2-h", filename="pytorch_model.bin", local_dir=local_dir, force_download=True)

# Créer le modèle
model = timm.create_model(
    "hf-hub:MahmoodLab/UNI2-h",  # Nom du modèle
    img_size=224,
    patch_size=14,
    depth=24,
    num_heads=24,
    init_values=1e-5,
    embed_dim=1536,
    mlp_ratio=2.66667 * 2,
    num_classes=0,
    no_embed_class=True,
    mlp_layer=timm.layers.SwiGLUPacked,
    act_layer=torch.nn.SiLU,
    reg_tokens=8,
    dynamic_img_size=True,
)

# Charger les poids
model.load_state_dict(
    torch.load(os.path.join(local_dir, "pytorch_model.bin"), map_location="cpu"),
    strict=True,
)

# Mettre le modèle en mode évaluation
model.eval()
model.to("cpu")  # Ou spécifier un appareil spécifique, par exemple "cuda"

# Définir le prétraitement des données
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

pytorch_model.bin:   0%|          | 0.00/2.73G [00:00<?, ?B/s]

config.json:   0%|          | 0.00/587 [00:00<?, ?B/s]

  torch.load(os.path.join(local_dir, "pytorch_model.bin"), map_location="cpu"),


In [7]:
from uni import get_encoder
model, transform = get_encoder(enc_name='uni2-h', device=device)

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

pytorch_model.bin:   0%|          | 0.00/2.73G [00:00<?, ?B/s]

  state_dict = torch.load(ckpt_path, map_location="cpu")


### ROI Feature Extraction

Remarque : En raison de la grande quantité de données et du temps de traitement long, seule une petite partie des données est utilisée ici.

In [9]:
from uni.downstream.extract_patch_features import extract_patch_features_from_dataloader

# get path to example data
dataroot = '/content/drive/MyDrive/UNI/tcga_luadlusc'

# create some image folder datasets for train/test and their data laoders
train_dataset = torchvision.datasets.ImageFolder(j_(dataroot, 'train'), transform=transform)
test_dataset = torchvision.datasets.ImageFolder(j_(dataroot, 'test'), transform=transform)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=False)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False)

# extract patch features from the train and test datasets (returns dictionary of embeddings and labels)
train_features = extract_patch_features_from_dataloader(model, train_dataloader)
test_features = extract_patch_features_from_dataloader(model, test_dataloader)

# convert these to torch
train_feats = torch.Tensor(train_features['embeddings'])
train_labels = torch.Tensor(train_features['labels']).type(torch.long)
test_feats = torch.Tensor(test_features['embeddings'])
test_labels = torch.Tensor(test_features['labels']).type(torch.long)

100%|██████████| 33/33 [13:59<00:00, 25.45s/it]
100%|██████████| 15/15 [06:12<00:00, 24.84s/it]


### ROI Linear Probe Evaluation.

In [10]:
from uni.downstream.eval_patch_features.linear_probe import eval_linear_probe

linprobe_eval_metrics, linprobe_dump = eval_linear_probe(
    train_feats = train_feats,
    train_labels = train_labels,
    valid_feats = None ,
    valid_labels = None,
    test_feats = test_feats,
    test_labels = test_labels,
    max_iter = 1000,
    verbose= True,
)

print_metrics(linprobe_eval_metrics)

Linear Probe Evaluation: Train shape torch.Size([130, 1536])
Linear Probe Evaluation: Test shape torch.Size([60, 1536])
Linear Probe Evaluation (Train Time): Best cost = 30.720
Linear Probe Evaluation (Train Time): Using only train set for evaluation. Train Shape:  torch.Size([130, 1536])
(Before Training) Loss: 0.693
(After Training) Loss: 0.016
Linear Probe Evaluation (Test Time): Test Shape torch.Size([60, 1536])
Linear Probe Evaluation: Time taken 0.24
Test lin_acc: 1.000
Test lin_bacc: 1.000
Test lin_kappa: 1.000
Test lin_weighted_f1: 1.000
Test lin_auroc: 1.000


### ROI Few-Shot Evaluation (based on ProtoNet)

In [11]:
from uni.downstream.eval_patch_features.fewshot import eval_fewshot

fewshot_episodes, fewshot_dump = eval_fewshot(
    train_feats = train_feats,
    train_labels = train_labels,
    test_feats = test_feats,
    test_labels = test_labels,
    n_iter = 500, # draw 500 few-shot episodes
    n_way = 2, # use all class examples
    n_shot = 4, # 4 examples per class (as we don't have that many)
    n_query = test_feats.shape[0], # evaluate on all test samples
    center_feats = True,
    normalize_feats = True,
    average_feats = True,
)

# how well we did picking 4 random examples per class
display(fewshot_episodes)

# summary
display(fewshot_dump)

100%|██████████| 500/500 [00:06<00:00, 73.73it/s]


Unnamed: 0,Kw4s_acc,Kw4s_bacc,Kw4s_kappa,Kw4s_weighted_f1
0,0.916667,0.916667,0.833333,0.916084
1,0.683333,0.683333,0.366667,0.648040
2,1.000000,1.000000,1.000000,1.000000
3,1.000000,1.000000,1.000000,1.000000
4,1.000000,1.000000,1.000000,1.000000
...,...,...,...,...
495,1.000000,1.000000,1.000000,1.000000
496,1.000000,1.000000,1.000000,1.000000
497,1.000000,1.000000,1.000000,1.000000
498,1.000000,1.000000,1.000000,1.000000


{'Kw4s_acc_avg': 0.9361333333333334,
 'Kw4s_bacc_avg': 0.9361333333333334,
 'Kw4s_kappa_avg': 0.8722666666666666,
 'Kw4s_weighted_f1_avg': 0.9289430547276676,
 'Kw4s_acc_std': 0.1157779106944201,
 'Kw4s_bacc_std': 0.1157779106944201,
 'Kw4s_kappa_std': 0.2315558213888402,
 'Kw4s_weighted_f1_std': 0.1362420163223126}