In [1]:
!pip install -U segmentation-models-pytorch

Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.8->segmentation-models-pytorch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.8->segmentation-models-pytorch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.8->segmentation-models-pytorch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.8->segmentation-models-pytorch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.8->segmentation-models-pytorch)
  Downloading nvidia_cublas_cu12-12.4.5.8-

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


METHODE A LA CHATGPT

In [3]:
import os, re, glob, json
import numpy as np, pandas as pd, cv2
import torch, random
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from tqdm import tqdm

# ------------------ Configuration ------------------
SEED = 26
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
PATH = "/content/drive/MyDrive/FewCTSeg/data/"
IMG_DIR = os.path.join(PATH, 'train-images/')
MASK_CSV = os.path.join(PATH, 'y_train.csv')
MODEL_OUT = PATH
BATCH = 8
NUM_CLASSES = 55
NUM_EPOCHS_PHASE1 = 10
NUM_EPOCHS_PHASE2 = 25
PSEUDO_LABEL_THRESHOLD = 0.6
LAMBDA = 0.5  # weight feature-perturb loss
MU = 0.5      # weight image-level loss
IGNORE_INDEX = 255  # for unlabeled pixels

# Fix randomness
def set_seed(seed=SEED):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed()

def alphanumeric_sort(name):
    parts = re.split(r'(\d+)', name)
    return [int(p) if p.isdigit() else p for p in parts]

# ------------------ Custom CutMix ------------------
def cutmix(images, masks):
    # images, masks: tensors shape (B,C,H,W),(B,H,W)
    B, C, H, W = images.shape
    lam = np.random.beta(1.0, 1.0)
    rx = np.random.randint(W)
    ry = np.random.randint(H)
    rw = int(W * np.sqrt(1 - lam))
    rh = int(H * np.sqrt(1 - lam))
    x1 = np.clip(rx - rw // 2, 0, W)
    y1 = np.clip(ry - rh // 2, 0, H)
    x2 = np.clip(rx + rw // 2, 0, W)
    y2 = np.clip(ry + rh // 2, 0, H)
    perm = torch.randperm(B)
    mixed_images = images.clone()
    mixed_masks = masks.clone()
    mixed_images[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    mixed_masks[:, y1:y2, x1:x2] = masks[perm, y1:y2, x1:x2]
    return mixed_images, mixed_masks



In [4]:

# ------------------ Transforms ------------------
base_transform = A.Compose([
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    ToTensorV2()])

weak_transform = A.Compose([
    A.RandomResizedCrop((256,256), scale=(0.2,1.0), p=1.0),
    #A.RandomRotate90(p=0.5),
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    ToTensorV2()])



In [5]:
strong_transform = A.Compose([
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    ToTensorV2()])  # CutMix applied in batch

In [6]:

# ------------------ Datasets ------------------
class LabeledCTScanDataset(Dataset):
    def __init__(self, img_dir, mask_csv, transform):
        paths = sorted(glob.glob(os.path.join(img_dir, '*.png')), key=alphanumeric_sort)
        masks = pd.read_csv(mask_csv, index_col=0).T.values.reshape(-1,256,256).astype(np.uint8)
        valid = [m.sum()>0 for m in masks]
        self.image_paths = [p for p,v in zip(paths,valid) if v]
        self.masks = masks[np.array(valid)]
        self.transform = transform
    def __len__(self): return len(self.image_paths)
    def __getitem__(self, idx):
        img = cv2.cvtColor(cv2.imread(self.image_paths[idx]), cv2.COLOR_BGR2RGB)
        mask = self.masks[idx]
        aug = self.transform(image=img, mask=mask)
        return aug['image'], aug['mask']

class UnlabeledCTScanDataset(Dataset):
    def __init__(self, img_dir, mask_csv, weak_transform):
        paths = sorted(glob.glob(os.path.join(img_dir, '*.png')), key=alphanumeric_sort)
        masks = pd.read_csv(mask_csv, index_col=0).T.values.reshape(-1,256,256).astype(np.uint8)
        valid = [m.sum()==0 for m in masks]
        self.image_paths = [p for p,v in zip(paths,valid) if v]
        self.transform = weak_transform
    def __len__(self): return len(self.image_paths)
    def __getitem__(self, idx):
        img = cv2.cvtColor(cv2.imread(self.image_paths[idx]), cv2.COLOR_BGR2RGB)
        aug = self.transform(image=img)
        return aug['image'], self.image_paths[idx]

class PseudoCTScanDataset(Dataset):
    def __init__(self, image_paths, masks):
        self.image_paths, self.masks = image_paths, masks
    def __len__(self): return len(self.masks)
    def __getitem__(self, idx):
        img = cv2.cvtColor(cv2.imread(self.image_paths[idx]), cv2.COLOR_BGR2RGB)
        aug = base_transform(image=img, mask=self.masks[idx])
        #aug = strong_transform(image=img, mask=self.masks[idx])
        return aug['image'], aug['mask']


In [7]:

# ------------------ DataLoaders ------------------
# Phase1 : Split Train/val (reprodcuible)

full_lab = LabeledCTScanDataset(IMG_DIR, MASK_CSV, base_transform)
idxs = torch.randperm(len(full_lab), generator=torch.Generator().manual_seed(SEED)).tolist()
#idxs = np.random.default_rng(seed=SEED).permutation(len(full_lab))
split = int(0.8*len(full_lab))
train_ds = Subset(full_lab, idxs[:split])
val_ds   = Subset(full_lab, idxs[split:])
gen = torch.Generator(); gen.manual_seed(SEED)
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=4, generator=gen)
val_loader   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False, num_workers=4)

unlab_ds = UnlabeledCTScanDataset(IMG_DIR, MASK_CSV, weak_transform)
unlab_loader = DataLoader(unlab_ds, batch_size=BATCH, shuffle=False, num_workers=4)




In [8]:
model = smp.Segformer(
    encoder_name='timm-efficientnet-b7', encoder_weights='imagenet',
    in_channels=3, classes=NUM_CLASSES
).to(DEVICE)
sup_loss = smp.losses.DiceLoss(mode='multiclass', from_logits=True)
opt1 = Adam(model.parameters(), lr=1e-3)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/106 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/267M [00:00<?, ?B/s]

In [9]:
# ------------------ Phase1: Warmup ------------------
best_val=1e6
for epoch in range(NUM_EPOCHS_PHASE1):
    model.train(); train_loss=0
    for imgs, masks in tqdm(train_loader):
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        opt1.zero_grad()
        logits = model(imgs)
        loss = sup_loss(logits, masks.long())
        loss.backward(); opt1.step()
        train_loss += loss.item()*imgs.size(0)
    val_loss=0; model.eval()
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs,masks=imgs.to(DEVICE),masks.to(DEVICE)
            val_loss+=sup_loss(model(imgs), masks.long()).item()*imgs.size(0)
    val_loss /= len(val_ds)
    print(f"Epoch1 {epoch+1}/{NUM_EPOCHS_PHASE1} — val: {val_loss:.4f}")
    if val_loss<best_val:
        best_val=val_loss
        torch.save(model.state_dict(), os.path.join(MODEL_OUT,'best_phase1_unimatch_3006.pth'))



100%|██████████| 76/76 [00:40<00:00,  1.87it/s]


Epoch1 1/10 — val: 0.3573


100%|██████████| 76/76 [00:11<00:00,  6.51it/s]


Epoch1 2/10 — val: 0.3031


100%|██████████| 76/76 [00:11<00:00,  6.54it/s]


Epoch1 3/10 — val: 0.2905


100%|██████████| 76/76 [00:14<00:00,  5.14it/s]


Epoch1 4/10 — val: 0.2850


100%|██████████| 76/76 [00:11<00:00,  6.69it/s]


Epoch1 5/10 — val: 0.2683


100%|██████████| 76/76 [00:11<00:00,  6.70it/s]


Epoch1 6/10 — val: 0.2687


100%|██████████| 76/76 [00:11<00:00,  6.47it/s]


Epoch1 7/10 — val: 0.2716


100%|██████████| 76/76 [00:12<00:00,  5.90it/s]


Epoch1 8/10 — val: 0.2732


100%|██████████| 76/76 [00:11<00:00,  6.54it/s]


Epoch1 9/10 — val: 0.2651


100%|██████████| 76/76 [00:15<00:00,  4.90it/s]


Epoch1 10/10 — val: 0.2670


In [29]:

model.load_state_dict(torch.load(os.path.join(MODEL_OUT,'best_phase1_unimatch_3006.pth')))
model.eval()
pseudo_paths, pseudo_masks = [], []

In [12]:
# ------------------ Pseudo-labeling ------------------
with torch.no_grad():
    for imgs, paths in tqdm(unlab_loader):
        imgs = imgs.to(DEVICE)
        pw = torch.softmax(model(imgs), dim=1)
        conf, pred = pw.max(1)
        for i,pth in enumerate(paths):
            conf_i, pred_i = conf[i].cpu().numpy(), pred[i].cpu().numpy().astype(np.uint8)
            mask_i = np.where(conf_i>=PSEUDO_LABEL_THRESHOLD, pred_i, 0)
            if (conf_i>=PSEUDO_LABEL_THRESHOLD).mean()>=LAMBDA:
                pseudo_paths.append(pth); pseudo_masks.append(mask_i)


100%|██████████| 156/156 [00:08<00:00, 17.67it/s]


In [13]:
# Si vide --> skip phase 2
if len(pseudo_masks)==0:
    raise ValueError("No pseudo-labels found, adjust thresholds.")
pseudo_masks = np.stack(pseudo_masks,0)
pseudo_ds = PseudoCTScanDataset(pseudo_paths, pseudo_masks)


In [14]:
len(pseudo_ds)

1241

In [15]:

# ------------------ Phase2: UniMatch ------------------
joint_ds = ConcatDataset([train_ds, pseudo_ds])
joint_loader = DataLoader(joint_ds, batch_size=BATCH, shuffle=True, num_workers=4)

In [32]:
# ------------------ Hyperparameters of UniMatch ------------------
opt2 = Adam(model.parameters(), lr=1e-4)
scheduler = ReduceLROnPlateau(opt2, 'min', factor=0.5, patience=3)
drop2d = nn.Dropout2d(p=0.5)
dice = smp.losses.DiceLoss(mode='multiclass', from_logits=True)
ce = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
std, mean = np.array([0.485, 0.456, 0.406]), np.array([0.229, 0.224, 0.225])

In [33]:
#PSEUDO_LABEL_THRESHOLD = 0.6

In [34]:


# Création loader de validation pour phase2 identique à val_loader
for epoch in range(1, NUM_EPOCHS_PHASE2+1):

    model.train(); train_loss=0
    for imgs, masks in tqdm(joint_loader):
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        # CutMix sur batch supervisé
        imgs, masks = cutmix(imgs, masks)
        # Forward weak
        feat_w = model.encoder(imgs)
        logits_w = model.segmentation_head(model.decoder(feat_w))
        # Feature perturbation
        last = feat_w[-1]; feat_fp = list(feat_w)
        feat_fp[-1] = drop2d(last)
        logits_fp = model.segmentation_head(model.decoder(feat_fp))

        # Two strong streams
        # (we reuse mix outputs as strong; for demo)
        logits_s1 = model(imgs); logits_s2 = model(imgs)
        # Pseudo mask from weak augmentation (detach)
        pseudo = logits_w.argmax(1).detach()

        # 2) Flux fort
        """
        Mettre dans une fonction
        # a) récupérer numpy non-normalisé
        imgs_np = imgs.cpu().permute(0,2,3,1).numpy()
        imgs_np = (imgs_np * std + mean) * 255.0  # repasser en 0–255

        # b) deux augmentations indépendantes
        xs1 = torch.stack([strong_transform(image=img.astype(np.uint8))['image']
                            for img in imgs_np]).to(DEVICE)
        xs2 = torch.stack([strong_transform(image=img.astype(np.uint8))['image']
                            for img in imgs_np]).to(DEVICE)
        """

        # Losses
        loss_sup = sup_loss(logits_w, masks.long())

        loss_fp = ce(logits_fp, pseudo)
        loss_s = 0.5*(ce(logits_s1, pseudo) + ce(logits_s2, pseudo))

        loss_u = LAMBDA*loss_fp + MU*loss_s
        loss = 0.5*(loss_sup + loss_u)
        opt2.zero_grad(); loss.backward(); opt2.step()
        train_loss += loss.item()*imgs.size(0)
    train_loss /= len(joint_ds)

    # Validation phase2
    model.eval(); val_loss=0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            val_loss += sup_loss(model(imgs), masks.long()).item()*imgs.size(0)
    val_loss /= len(val_ds)
    scheduler.step(val_loss)
    print(f"Phase2 E{epoch}/{NUM_EPOCHS_PHASE2} — train:{train_loss:.4f} val:{val_loss:.4f}")
    if epoch%5==0:
        torch.save(model.state_dict(), os.path.join(MODEL_OUT, f'phase2_e{epoch}_unimatch_drp05_th0.9_3006.pth'))
    if val_loss<best_val:
        best_val=val_loss
        torch.save(model.state_dict(), os.path.join(MODEL_OUT,'best_phase2_unimatch_drp05_th0.9_3006.pth'))




100%|██████████| 231/231 [01:32<00:00,  2.51it/s]


Phase2 E1/25 — train:0.2645 val:0.2534


100%|██████████| 231/231 [01:31<00:00,  2.51it/s]


Phase2 E2/25 — train:0.2568 val:0.2597


100%|██████████| 231/231 [01:32<00:00,  2.51it/s]


Phase2 E3/25 — train:0.2563 val:0.2542


100%|██████████| 231/231 [01:31<00:00,  2.51it/s]


Phase2 E4/25 — train:0.2555 val:0.2565


100%|██████████| 231/231 [01:32<00:00,  2.51it/s]


Phase2 E5/25 — train:0.2529 val:0.2612


100%|██████████| 231/231 [01:31<00:00,  2.53it/s]


Phase2 E6/25 — train:0.2498 val:0.2628


100%|██████████| 231/231 [01:32<00:00,  2.50it/s]


Phase2 E7/25 — train:0.2497 val:0.2627


  3%|▎         | 8/231 [00:03<01:48,  2.05it/s]


KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), os.path.join(MODEL_OUT,'best_phase2_final_unimatch_3006.pth'))

# Version 27/06 - CHATGPT

In [None]:
import os, re, glob, json
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import albumentations as A
import segmentation_models_pytorch as smp
from tqdm import tqdm

# ------------------ Configuration ------------------
SEED = 26
np.random.seed(SEED); torch.manual_seed(SEED)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Paths
PATH = "/content/drive/MyDrive/FewCTSeg/data/"
IMG_DIR = os.path.join(PATH, 'train-images/')
MASK_CSV = os.path.join(PATH, 'y_train.csv')
MODEL_OUT = PATH

# Hyperparameters
BATCH = 8
NUM_CLASSES = 55
NUM_EPOCHS_PHASE1 = 20
NUM_EPOCHS_PHASE2 = 20
#PSEUDO_LABEL_THRESHOLD = 0.6
PSEUDO_LABEL_THRESHOLD = 0.9
LAMBDA = 0.5    # weight for feature perturbation loss
MU = 0.5        # weight for image-level loss




In [None]:

# ------------------ Utilities ------------------
def alphanumeric_sort(name):
    parts = re.split(r'(\d+)', name)
    return [int(p) if p.isdigit() else p for p in parts]

# ------------------ Datasets ------------------
class LabeledCTScanDataset(Dataset):
    def __init__(self, img_dir, mask_csv, transform):
        self.paths = sorted(glob.glob(os.path.join(img_dir, '*.png')), key=alphanumeric_sort)
        masks = pd.read_csv(mask_csv, index_col=0).T.values.reshape(-1,256,256).astype(np.uint8)
        valid = [m.sum()>0 for m in masks]
        self.image_paths = [p for p,v in zip(self.paths, valid) if v]
        self.masks = masks[np.array(valid)]
        self.transform = transform

    def __len__(self): return len(self.image_paths)
    def __getitem__(self, idx):
        img = cv2.cvtColor(cv2.imread(self.image_paths[idx]), cv2.COLOR_BGR2RGB)
        mask = self.masks[idx]
        aug = self.transform(image=img, mask=mask)
        return aug['image'], aug['mask']

class UnlabeledCTScanDataset(Dataset):
    def __init__(self, img_dir, mask_csv, weak_transform):
        self.paths = sorted(glob.glob(os.path.join(img_dir, '*.png')), key=alphanumeric_sort)
        masks = pd.read_csv(mask_csv, index_col=0).T.values.reshape(-1,256,256).astype(np.uint8)
        valid = [m.sum()==0 for m in masks]
        self.image_paths = [p for p,v in zip(self.paths, valid) if v]
        self.transform = weak_transform

    def __len__(self): return len(self.image_paths)
    def __getitem__(self, idx):
        path = self.image_paths[idx]
        img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
        aug = self.transform(image=img)
        return aug['image'], os.path.basename(path)

class PseudoCTScanDataset(Dataset):
    def __init__(self, image_paths, masks, transform):
        self.image_paths = image_paths
        self.masks = masks
        self.transform = transform
    def __len__(self): return len(self.masks)
    def __getitem__(self, idx):
        img = cv2.cvtColor(cv2.imread(self.image_paths[idx]), cv2.COLOR_BGR2RGB)
        aug = self.transform(image=img, mask=self.masks[idx])
        return aug['image'], aug['mask']


In [None]:
# ------------------ Transforms ------------------
base_transform = A.Compose([
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    A.ToTensorV2()])

weak_transform = A.Compose([
    A.RandomResizedCrop((256,256),scale = (0.2,1), p=1.0),
    A.RandomRotate90(p=0.5),
    #A.OneOf([
        #A.ElasticTransform(alpha=300, sigma=10, p=0.5),
        #A.GridDistortion(distort_limit=0.2, num_steps=5, p=0.5)],p=1.0),
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    A.ToTensorV2()])

strong_transform = A.Compose([
    #A.RandomResizedCrop((256,256),scale = (0.2,1), p=1.0),
    A.CoarseDropout(num_holes_range=(4, 8), hole_height_range=(0.1, 0.25),hole_width_range=(0.1, 0.25), fill_value=0, p=1.0),
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    A.ToTensorV2()])




  A.CoarseDropout(num_holes_range=(4, 8), hole_height_range=(0.1, 0.25),hole_width_range=(0.1, 0.25), fill_value=0, p=1.0),


In [None]:

# ------------------ DataLoaders ------------------
# Phase 1: supervised
full_lab = LabeledCTScanDataset(IMG_DIR, MASK_CSV, base_transform)
idxs = np.random.permutation(len(full_lab))
#idxs = np.random.default_rng(seed=SEED).permutation(len(full_lab))
split = int(0.8*len(full_lab))
train_ds = Subset(full_lab, idxs[:split])
val_ds   = Subset(full_lab, idxs[split:])
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False, num_workers=4)

# Unlabeled loader for pseudo-label generation
unlab_ds = UnlabeledCTScanDataset(IMG_DIR, MASK_CSV, weak_transform)
unlab_loader = DataLoader(unlab_ds, batch_size=BATCH, shuffle=False, num_workers=4)


In [None]:
idxs

In [None]:

# ------------------ Model & Loss ------------------
model = smp.Segformer(
    encoder_name='timm-efficientnet-b7', encoder_weights='imagenet',
    in_channels=3, classes=NUM_CLASSES, activation=None
).to(DEVICE)

loss_fn = smp.losses.DiceLoss(mode='multiclass', from_logits=True)
optimizer = Adam(model.parameters(), lr=1e-3)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/106 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/267M [00:00<?, ?B/s]

In [None]:

# ------------------ Phase 1: Warm-up ----------------
best_val = float('inf')
for epoch in range(NUM_EPOCHS_PHASE1):
    model.train()
    total_loss = 0
    for imgs, masks in tqdm(train_loader, desc=f"Phase1 Epoch {epoch+1}"):
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = loss_fn(logits, masks.long())
        loss.backward(); optimizer.step()
        total_loss += loss.item()*imgs.size(0)
    val_loss = 0
    model.eval()
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            val_loss += loss_fn(model(imgs), masks.long()).item()*imgs.size(0)
    val_loss /= len(val_ds)
    print(f"Val Loss: {val_loss:.4f}")
    if val_loss < best_val:
        best_val = val_loss
        torch.save(model.state_dict(), os.path.join(MODEL_OUT,'best_phase1_2706.pth'))



Phase1 Epoch 1: 100%|██████████| 76/76 [00:34<00:00,  2.22it/s]


Val Loss: 0.3410


Phase1 Epoch 2: 100%|██████████| 76/76 [00:12<00:00,  6.21it/s]


Val Loss: 0.3139


Phase1 Epoch 3: 100%|██████████| 76/76 [00:12<00:00,  6.31it/s]


Val Loss: 0.2845


Phase1 Epoch 4: 100%|██████████| 76/76 [00:15<00:00,  5.01it/s]


Val Loss: 0.2718


Phase1 Epoch 5: 100%|██████████| 76/76 [00:11<00:00,  6.50it/s]


Val Loss: 0.2695


Phase1 Epoch 6: 100%|██████████| 76/76 [00:15<00:00,  5.00it/s]


Val Loss: 0.2707


Phase1 Epoch 7: 100%|██████████| 76/76 [00:12<00:00,  6.03it/s]


Val Loss: 0.2629


Phase1 Epoch 8: 100%|██████████| 76/76 [00:15<00:00,  4.79it/s]


Val Loss: 0.2588


Phase1 Epoch 9: 100%|██████████| 76/76 [00:15<00:00,  5.01it/s]


Val Loss: 0.2700


Phase1 Epoch 10: 100%|██████████| 76/76 [00:12<00:00,  6.07it/s]


Val Loss: 0.2627


Phase1 Epoch 11: 100%|██████████| 76/76 [00:11<00:00,  6.61it/s]


Val Loss: 0.2621


Phase1 Epoch 12: 100%|██████████| 76/76 [00:11<00:00,  6.56it/s]


Val Loss: 0.2711


Phase1 Epoch 13: 100%|██████████| 76/76 [00:11<00:00,  6.57it/s]


Val Loss: 0.2623


Phase1 Epoch 14: 100%|██████████| 76/76 [00:11<00:00,  6.54it/s]


Val Loss: 0.2751


Phase1 Epoch 15: 100%|██████████| 76/76 [00:11<00:00,  6.55it/s]


Val Loss: 0.2809


Phase1 Epoch 16: 100%|██████████| 76/76 [00:11<00:00,  6.54it/s]


Val Loss: 0.2807


Phase1 Epoch 17: 100%|██████████| 76/76 [00:12<00:00,  6.32it/s]


Val Loss: 0.2824


Phase1 Epoch 18: 100%|██████████| 76/76 [00:11<00:00,  6.45it/s]


Val Loss: 0.2875


Phase1 Epoch 19: 100%|██████████| 76/76 [00:11<00:00,  6.55it/s]


Val Loss: 0.2788


Phase1 Epoch 20: 100%|██████████| 76/76 [00:11<00:00,  6.59it/s]


Val Loss: 0.3005


In [None]:
# ------------------ Generate Pseudo-Labels ------------------
model.load_state_dict(torch.load(os.path.join(MODEL_OUT,'best_phase1_2706.pth')))
model.eval()
pass

In [None]:
pseudo_paths, pseudo_masks = [], []
with torch.no_grad():
    for imgs, names in tqdm(unlab_loader, desc="Pseudo-labeling"):
        imgs = imgs.to(DEVICE)
        # Dual strong views
        xs1 = torch.stack([strong_transform(image=img.cpu().permute(1,2,0).numpy())['image'] for img in imgs]).to(DEVICE)
        xs2 = torch.stack([strong_transform(image=img.cpu().permute(1,2,0).numpy())['image'] for img in imgs]).to(DEVICE)
        # Weak features
        pw_logits = model(imgs)
        pw = torch.softmax(pw_logits, dim=1)
        conf, mask_pred = pw.max(dim=1)
        for i,name in enumerate(names):
            conf_i = conf[i].cpu().numpy()
            pred_i = mask_pred[i].cpu().numpy().astype(np.uint8)
            mask_i = np.where(conf_i>=PSEUDO_LABEL_THRESHOLD, pred_i, 0)
            keep_ratio = (conf_i>=PSEUDO_LABEL_THRESHOLD).mean()

            if keep_ratio >= LAMBDA:
                pseudo_paths.append(unlab_ds.image_paths[i])
                pseudo_masks.append(mask_i)

            #pseudo_paths.append(unlab_ds.image_paths[i])
            #pseudo_paths.append(os.path.join(IMG_DIR,name))
            #pseudo_masks.append(mask_i)


Pseudo-labeling: 100%|██████████| 156/156 [00:16<00:00,  9.50it/s]


In [None]:
pseudo_masks = np.stack(pseudo_masks,axis=0)

In [None]:
pseudo_masks.shape

(1241, 256, 256)

In [None]:
len(pseudo_paths)

1241

In [None]:
# Build pseudo dataset
pseudo_ds = PseudoCTScanDataset(pseudo_paths, pseudo_masks, strong_transform)


In [None]:
# ------------------ Phase 2: UniMatch ------------------
# Combine datasets
joint_ds = ConcatDataset([train_ds, pseudo_ds])
joint_loader = DataLoader(joint_ds, batch_size=BATCH, shuffle=True, num_workers=4)
optimizer = Adam(model.parameters(), lr=5e-4)
#scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=3)


In [None]:
dropout2d =nn.Dropout2d(p=0.3)
dice_loss = smp.losses.DiceLoss(mode='multiclass', from_logits=True)

In [None]:

for epoch in range(1,11):
    model.train(); tot_loss=0
    for imgs, masks in tqdm(joint_loader, desc=f"Phase2 Epoch {epoch}"):
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        optimizer.zero_grad()
        # Weak stream
        feat_w = model.encoder(imgs)
        #p_w = model.decoder(feat_w)
        p_w  = model.segmentation_head(model.decoder(feat_w))

        # Feature perturb stream
        #feat_fp = nn.Dropout2d(p=0.5)(feat_w)
        """

        last_feat = feat_w[-1]            # par ex. la feature la plus “profond”
        perturbed = drop(last_feat)
        p_fp = model.decoder([*feat_w[:-1], perturbed])
        """
        last_w  = feat_w[-1]                # map finale (B, C, h, w)
        last_fp = dropout2d(last_w)          # feature perturbée
        feat_fp = feat_w.copy()
        feat_fp[-1] = last_fp
        #p_fp = model.decoder(feat_fp)
        p_fp  = model.segmentation_head(model.decoder(feat_fp))

        # Dual strong image streams
        xs1 = torch.stack([strong_transform(image=img.cpu().permute(1,2,0).numpy())['image'] for img in imgs]).to(DEVICE)
        xs2 = torch.stack([strong_transform(image=img.cpu().permute(1,2,0).numpy())['image'] for img in imgs]).to(DEVICE)

        p_s1 = model(xs1); p_s2 = model(xs2)
        # Pseudo mask from p_w
        pseudo_mask = p_w.argmax(dim=1).detach()
        # Losses

        #ce = nn.CrossEntropyLoss()
        #Après tester avec la cross entropy

        loss_fp = dice_loss(p_fp, pseudo_mask.long())
        loss_s = (dice_loss(p_s1, pseudo_mask.long()) + dice_loss(p_s2, pseudo_mask.long())) * 0.5
        loss_u = LAMBDA * loss_fp + MU * loss_s
        loss = loss_u
        loss.backward(); optimizer.step()
        tot_loss += loss.item()*imgs.size(0)
    avg_loss = tot_loss/len(joint_ds)

    print(f"Phase2 Train Loss: {avg_loss:.4f}")
    if epoch %5 ==0:

      torch.save(model.state_dict(), os.path.join(MODEL_OUT, f'best_phase2_epoch_{epoch}_2706.pth'))





Phase2 Epoch 1: 100%|██████████| 231/231 [01:43<00:00,  2.23it/s]


Phase2 Train Loss: 0.2275


Phase2 Epoch 2: 100%|██████████| 231/231 [01:42<00:00,  2.25it/s]


Phase2 Train Loss: 0.1800


Phase2 Epoch 3: 100%|██████████| 231/231 [01:43<00:00,  2.24it/s]


Phase2 Train Loss: 0.1671


Phase2 Epoch 4: 100%|██████████| 231/231 [01:43<00:00,  2.24it/s]


Phase2 Train Loss: 0.1597


Phase2 Epoch 5: 100%|██████████| 231/231 [01:42<00:00,  2.25it/s]


Phase2 Train Loss: 0.1604


Phase2 Epoch 6: 100%|██████████| 231/231 [01:42<00:00,  2.25it/s]


Phase2 Train Loss: 0.1421


Phase2 Epoch 7: 100%|██████████| 231/231 [01:42<00:00,  2.25it/s]


Phase2 Train Loss: 0.1331


Phase2 Epoch 8: 100%|██████████| 231/231 [01:42<00:00,  2.25it/s]


Phase2 Train Loss: 0.1327


Phase2 Epoch 9: 100%|██████████| 231/231 [01:42<00:00,  2.24it/s]


Phase2 Train Loss: 0.1337


Phase2 Epoch 10: 100%|██████████| 231/231 [01:42<00:00,  2.25it/s]


Phase2 Train Loss: 0.1149


In [None]:
# Reesayer comme ça

In [None]:
from PIL import Image
from torchvision import transforms as T
import cv2
import torch, torch.nn as nn
from torch.utils.data import DataLoader, random_split, Subset, Dataset, ConcatDataset
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import matplotlib.pyplot as plt
import json, os, glob, re
import numpy as np, pandas as pd
import albumentations as A
import segmentation_models_pytorch as smp
#cv2.setNumThreads(0)  - To avoid slower computation

In [None]:
SEED = 26
np.random.seed(SEED); torch.manual_seed(SEED)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
PATH = "/content/drive/MyDrive/FewCTSeg/data/"
IMG_DIR = "/content/drive/MyDrive/FewCTSeg/data/train-images/"
MASK_CSV = "/content/drive/MyDrive/FewCTSeg/data/y_train.csv"
BATCH = 8
NUM_CLASSES = 55
NUM_EPOCHS = 1
PSEUDO_LABEL_THRESHOLD = 0.9

LAMBDA = 0.5
MU = 0.5

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
def alphanumeric_sort(name):
    parts = re.split(r'(\d+)', name)
    return [int(p) if p.isdigit() else p for p in parts]

In [None]:
class LabeledCTScanDataset(Dataset):
    def __init__(self, img_dir, mask_csv, transform):
        all_paths = sorted(glob.glob(os.path.join(img_dir, "*.png")), key=alphanumeric_sort)
        masks = pd.read_csv(mask_csv, index_col=0).T.values.reshape(-1,256,256).astype(np.uint8)
        valid_idx = np.array([m.sum() > 0 for m in masks])
        # on filtre image_paths et masks
        filtered_paths = [p for p, valid in zip(all_paths, valid_idx) if valid]
        filtered_masks = masks[valid_idx]

        self.image_paths = filtered_paths
        self.masks = filtered_masks
        self.transform = transform

    def __len__(self): return len(self.image_paths)
    def __getitem__(self, i):
        img = cv2.cvtColor(cv2.imread(self.image_paths[i]), cv2.COLOR_BGR2RGB)
        mask = self.masks[i]
        aug = self.transform(image=img, mask=mask)
        return aug['image'], aug['mask']

        return img_t, torch.zeros((256,256), dtype=torch.long)

class UnlabeledCTScanDataset(Dataset):
    def __init__(self, img_dir, mask_csv, transform):
        all_paths = sorted(glob.glob(os.path.join(img_dir,"*.png")), key=alphanumeric_sort)


        masks = pd.read_csv(mask_csv, index_col=0).T.values.reshape(-1,256,256).astype(np.uint8)
        valid_idx = np.array([m.sum() == 0 for m in masks])
        # on filtre image_paths et masks
        filtered_paths = [p for p, valid in zip(all_paths, valid_idx) if valid]
        filtered_masks = masks[valid_idx]

        self.image_paths = filtered_paths
        self.transform = transform

    def __len__(self): return len(self.image_paths)

    def __getitem__(self, i):
        path = self.image_paths[i]
        img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
        img_t = self.transform(image=img)['image']
        name = os.path.basename(path)
        return img_t, torch.zeros((256,256), dtype=torch.long), name

class PseudoCTScanDataset(Dataset):
    def __init__(self, image_paths, masks, transform):
        assert len(image_paths) == len(masks), "paths and masks should have the same length"
        self.image_paths = image_paths
        self.masks       = masks
        self.transform   = transform

    def __len__(self):
        return len(self.masks)

    def __getitem__(self, i):
        img = cv2.cvtColor(cv2.imread(self.image_paths[i]), cv2.COLOR_BGR2RGB)
        aug = self.transform(image=img, mask=self.masks[i])
        return aug['image'], aug['mask']



In [None]:
#  Define transformations

base_transform = A.Compose([
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    A.ToTensorV2(),
])


weak_augment = A.Compose([
    A.RandomCrop(height=192, width=192, p=1.0),
    #ou A.RandomRotate90 ou A.Transpose
    A.OneOf([
        A.ElasticTransform(alpha=300, sigma=10, p=0.5),
        A.GridDistortion(distort_limit=0.2, num_steps=5, p=0.5)],p=1.0)
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)), A.ToTensorV2(),])

strong_augment = A.Compose([
        #A.CoarseDropout(num_holes_range=(1, 8), hole_height_range=(0.1, 0.4),hole_width_range=(0.1, 0.4), fill_value=0, p=1.0),
        A.CoarseDropout(num_holes_range=(4, 8), hole_height_range=(0.1, 0.25),hole_width_range=(0.1, 0.25), fill_value=0, p=1.0),
        A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
        A.ToTensorV2(),])


#Define Dropout
dropout2d = nn.Dropout2d(0.3)

In [None]:
#Split train/val

full_lab = LabeledCTScanDataset(IMG_DIR, MASK_CSV, base_transform)
idxs = np.random.permutation(len(full_lab))
split = int(0.8*len(full_lab))
train_ds = Subset(full_lab, idxs[:split])
val_ds   = Subset(full_lab, idxs[split:])
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False, num_workers=4)

unlab_ds = UnlabeledCTScanDataset(IMG_DIR, MASK_CSV, weak_augment)
unlab_loader = DataLoader(unlab_ds, batch_size=BATCH, shuffle=False, num_workers=4)


In [None]:
model = smp.Segformer(
    encoder_name="timm-efficientnet-b7",
    encoder_weights="imagenet",
    in_channels=3, classes=NUM_CLASSES, activation=None
).to(DEVICE)

loss_fn = smp.losses.DiceLoss(mode='multiclass', from_logits=True)
optimizer = Adam(model.parameters(), lr=1e-3)

## First phase of Unimatch's method :

In [None]:
best_val = float('inf')
for epoch in range(1, NUM_EPOCHS + 1):
    # train
    model.train(); total=0.;total_dice_loss=0.0
    for imgs, msks in tqdm(train_loader):
        imgs, msks = imgs.to(DEVICE), msks.to(DEVICE)
        optimizer.zero_grad()
        logits = model(imgs)
        dice_loss = loss_fn(logits, msks.long())
        dice_loss.backward()
        optimizer.step()
        total_dice_loss +=dice_loss.item()*imgs.size(0)


    train_dice_loss = total_dice_loss/ len(train_ds)
    # val
    model.eval(); vtot=0.;vtot_dice_loss=0.0
    with torch.no_grad():
        for imgs, msks in val_loader:
            imgs, msks = imgs.to(DEVICE), msks.to(DEVICE)
            vtot_dice_loss+=(loss_fn(model(imgs), msks.long())).item()*imgs.size(0)

    val_dice_loss = vtot_dice_loss/len(val_ds)
    print(f"[Epoch {epoch}] Train Loss: {train_dice_loss:.4f}  - Val Loss : {val_dice_loss:4f}")
    if val_dice_loss < best_val:
        best_val = val_dice_loss
        print("→ Meilleur modèle Phase 1 sauvegardé")
        torch.save(model.state_dict(), PATH + "best_phase_unimatch_threshold06_dropout05_2506.pth")


In [None]:
#Load the model
model.load_state_dict(torch.load(PATH + "best_phase_unimatch_threshold06_dropout05_2506.pth"))
model.eval()
pass

In [None]:
"""
#Pseudo-labeling phase

def pseudo_labeling(unlab_loader, strong_transform, model):

  pseudo_masks = []
  pseudo_paths = []

  for imgs, _,names in tqdm(unlab_loader, desc="Pseudo-labeling"):  # imgs: (B,3,256,256)
      imgs = imgs.to(DEVICE)

      # 1) weak view
      x_w = imgs

      # 2) strong views, image-by-image
      x_s1_list, x_s2_list = [], []
      for img_tensor in x_w:
          # passe de (3,256,256) à (256,256,3) numpy
          img_np = img_tensor.permute(1,2,0).cpu().numpy()
          s1 = strong_transform(image=img_np)['image']
          s2 = strong_transform(image=img_np)['image']
          x_s1_list.append(s1)
          x_s2_list.append(s2)
      x_s1 = torch.stack(x_s1_list).to(DEVICE)  # (B,3,256,256)
      x_s2 = torch.stack(x_s2_list).to(DEVICE)  # idem

      # 3) encodeur weak → features
      feats_w = model.encoder(x_w)         # liste de N maps
      last_w  = feats_w[-1]                # map finale (B, C, h, w)
      last_fp = dropout2d(last_w)          # feature perturbée
      feats_fp = feats_w.copy()
      feats_fp[-1] = last_fp

      # 4) head segmentation
      p_w  = model.segmentation_head(model.decoder(feats_w))   # (B,55,256,256)
      p_fp = model.segmentation_head(model.decoder(feats_fp))

      # 5) strong views through full model
      #    (concat + chunk ou deux appels séparés)
      ps   = model(torch.cat([x_s1, x_s2], dim=0))             # (2B,55,256,256)
      ps1, ps2 = ps.chunk(2, dim=0)                             # deux views

      # 6) calcul du pseudo-mask
      conf_w, mask_w = torch.max(torch.softmax(p_w, dim=1), dim=1)  # (B,256,256)
      mask_w = mask_w.cpu().numpy()

      # 7) seuillage
      for c, m in zip(conf_w.cpu(), mask_w):
          m[c < 0.6] = 0
          pseudo_masks.append(m)

      for name in names:
            pseudo_paths.append(os.path.join(train_path,name))


  pseudo_masks = np.stack(pseudo_masks, axis=0)#concatenate

  return pseudo_masks, p_w, p_fp, ps1, ps2

"""

In [None]:
optimizer = Adam(model.parameters(), lr=1e-3)
best_val = float('inf')
no_improve,patience = 0,10

In [None]:
pseudo_masks = []
pseudo_paths = []

In [None]:

for epoch in range(1, NUM_EPOCHS+1):
    model.train()
    sup_loss = 0.0
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        optimizer.zero_grad()
        preds = model(imgs)
        l_sup = loss_fn(preds, masks.long())
        (l_sup).backward()
        optimizer.step()
        sup_loss += l_sup.item() * imgs.size(0)
    sup_loss /= len(train_loader.dataset)

    unsup_loss = 0.0

    for imgs, _,names in unlab_loader:  # imgs: (B,3,256,256)
      with torch.no_grad():

        imgs = imgs.to(DEVICE)

        # 1) weak view
        x_w = imgs.float()

        # 2) strong views, image-by-image
        x_s1_list, x_s2_list = [], []
        for img_tensor in x_w:
            # passe de (3,256,256) à (256,256,3) numpy
            img_np = img_tensor.permute(1,2,0).cpu().numpy()
            s1 = strong_augment(image=img_np)['image']
            s2 = strong_augment(image=img_np)['image']
            x_s1_list.append(s1)
            x_s2_list.append(s2)
        x_s1 = torch.stack(x_s1_list).to(DEVICE)  # (B,3,256,256)
        x_s2 = torch.stack(x_s2_list).to(DEVICE)  # idem

        # 3) encodeur weak → features
        feats_w = model.encoder(x_w)         # liste de N maps
        last_w  = feats_w[-1]                # map finale (B, C, h, w)
        last_fp = dropout2d(last_w)          # feature perturbée
        feats_fp = feats_w.copy()
        feats_fp[-1] = last_fp

        # 4) head segmentation
        p_w  = model.segmentation_head(model.decoder(feats_w))   # (B,55,256,256)
        p_fp = model.segmentation_head(model.decoder(feats_fp))

        # 5) strong views through full model
        #    (concat + chunk ou deux appels séparés)
        ps  = model(torch.cat([x_s1, x_s2], dim=0))             # (2B,55,256,256)
        ps1, ps2 = ps.chunk(2, dim=0)                             # deux views

        # 6) calcul du pseudo-mask
        conf_w, masks_w = torch.max(torch.softmax(p_w, dim=1), dim=1)  # (B,256,256)
        #mask_w = mask_w.cpu().numpy()

        # 7) seuillage
        for c, m in zip(conf_w.cpu(), masks_w.cpu().numpy()):
            m[c < 0.6] = 0
            pseudo_masks.append(m)

        for name in names:
              pseudo_paths.append(os.path.join(IMG_DIR,name))

      #pseudo_masks=np.stack(pseudo_masks,axis=0)


      masks_w = masks_w.to(DEVICE)

      optimizer.zero_grad()
      loss_u = LAMBDA * loss_fn(p_fp, masks_w.long()) + 0.5 * MU * (loss_fn(ps1, masks_w.long()) + loss_fn(ps2, masks_w.long()))
      loss_u.backward()
      optimizer.step()
      unsup_loss += loss_u.item() * imgs.size(0)
    unsup_loss /= len(unlab_loader.dataset)

    total_loss =0.5*(sup_loss +  unsup_loss)
    print(f"[Epoch {epoch}] sup={sup_loss:.4f} pseudo={unsup_loss:.4f} tot={total_loss:.4f}")
   #print(f"[Finetune {epoch}] Train: {train_l:.4f} – Val: {val_l:.4f}")


    if epoch %5 ==0:
      torch.save(model.state_dict(), PATH +  f"best_final_phase_unimatch_threshold06_dropout05_2506.pth_{epoch}.pth")

      print("→ Meilleur modèle Final sauvegardé")

print("Pipeline semi-supervisé terminée.")

pseudo_masks = np.stack(pseudo_masks, axis=0)


In [None]:
optimizer = Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

best_val = float('inf')
no_improve,patience = 0,10

In [None]:
len(pseudo_ds)

In [None]:
len(full_lab)

In [None]:
len(mixed_loader)

In [None]:
#Final phase of UniMatch

for epoch in range(1,NUM_EPOCHS + 1):
    model.train(); tot=0.;tot_dice=0.0
    for imgs, msks in tqdm(mixed_loader):
        imgs, msks = imgs.to(DEVICE), msks.to(DEVICE)
        optimizer.zero_grad()
        logits = model(imgs)
        #loss = loss_fn(logits, msks.long()) + criterion_sup(logits, msks.long())
        dice_loss = loss_fn(logits, msks.long())
        #loss.backward(); optimizer.step()
        dice_loss.backward(); optimizer.step()
        #tot += loss.item()*imgs.size(0)
        tot_dice += dice_loss.item()*imgs.size(0)
    #scheduler.step(tot/len(mixed_ds))
    #train_loss = tot / len(mixed_ds)
    train_dice_loss = tot_dice / len(mixed_ds)
    # évaluation sur val_loader identique à Phase 1
    model.eval(); vtot=0.;vtot_dice=0.0
    with torch.no_grad():
        for imgs, msks in val_loader:
            imgs, msks = imgs.to(DEVICE), msks.to(DEVICE)
            #vtot += (loss_fn(model(imgs), msks.long())+criterion_sup(model(imgs), msks.long())).item()*imgs.size(0)
            vtot_dice += (loss_fn(model(imgs), msks.long())).item()*imgs.size(0)

    #val_loss = vtot / len(val_ds); scheduler.step(val_loss)
    val_dice_loss = vtot_dice / len(val_ds); scheduler.step(val_dice_loss)
    #print(f"[Epoch {epoch}] Train: {train_loss:.4f} – Val: {val_loss:.4f}")
    print(f"[Epoch {epoch}] Train: {train_dice_loss:.4f} – Val: {val_dice_loss:.4f}")
    if val_dice_loss < best_val:
        best_val = val_dice_loss
        print("→ Meilleur modèle phase finale sauvegardé")
        torch.save(model.state_dict(), PATH + "best_final_unimatch_threshold06_dropout03_1306.pth")
        no_improve = 0
    else:
        no_improve +=1
        if no_improve>=patience:
            print(f"--> Early stopping at epoch {epoch}")
            break


In [None]:
#0.0782 à epoch 29

## Prediction phase

In [None]:
#Save best_simple_segformer_0105.pth
#torch.save(model.state_dict(), PATH+ "best_phase_final_lastepoch_unimatch.pth")

In [35]:
#model.load_state_dict(torch.load(PATH +'best_phase2_epoch_1_2706.pth', map_location=DEVICE,weights_only=True))
model.load_state_dict(torch.load(PATH +'best_phase2_unimatch_drp05_th0.9_3006.pth', map_location=DEVICE,weights_only=True))

<All keys matched successfully>

In [36]:
# 5. Inférence sur test set
PATH = "/content/drive/MyDrive/FewCTSeg/data/"
# -----------------------------------------------------------------------------
# Charger le meilleur modèle\ n
model.eval()

# Dataset test (sans masques)
class CTTestDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png")), key=alphanumeric_sort)
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = cv2.imread(self.image_paths[idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        name = os.path.basename(self.image_paths[idx])
        if self.transform:
            img = self.transform(image=img)['image']
        return img, name

# Pas d'augmentation, juste conversion
test_transform = A.Compose([
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    A.ToTensorV2(),
])

test_ds = CTTestDataset(image_dir=os.path.join(PATH, "test-images"), transform=test_transform)
test_loader = DataLoader(test_ds, batch_size=8, shuffle=False, num_workers=4)

# Inférence et sauvegarde CSV
all_preds = []
filenames = []


In [37]:
labels_train = pd.read_csv("/content/drive/MyDrive/FewCTSeg/data/y_train.csv", index_col=0, header=0).T

In [38]:
with torch.no_grad():
    for imgs, names in tqdm(test_loader):
        imgs = imgs.to(DEVICE)
        logits = model(imgs)
        preds = torch.argmax(logits, dim=1).cpu().numpy()  # (B,H,W)
        for p, n in zip(preds, names):
            all_preds.append(p.flatten())
            filenames.append(n)

# Créer DataFrame transposé
df = pd.DataFrame(np.stack(all_preds, axis=0), columns=labels_train.columns) #
df = df.T
# nom des colonnes
df.columns = filenames




100%|██████████| 63/63 [00:04<00:00, 13.72it/s]


In [39]:
# Sauvegarder CSV
#output_csv = os.path.join(PATH,  'best_phase2_epoch_10_2606.csv')
output_csv = os.path.join(PATH,  'best_phase2_unimatch_drp05_th0.9_3006.csv')
df.to_csv(output_csv, index=True)
print(f"Test predictions saved to {output_csv}")

Test predictions saved to /content/drive/MyDrive/FewCTSeg/data/best_phase2_unimatch_drp05_th0.9_3006.csv


In [None]:
model.load_state_dict(torch.load(PATH +'best_phase2_final_unimatch_3006.pth', map_location=DEVICE,weights_only=True))

In [None]:
# 5. Inférence sur test set
PATH = "/content/drive/MyDrive/FewCTSeg/data/"
# -----------------------------------------------------------------------------
# Charger le meilleur modèle\ n
model.eval()

# Dataset test (sans masques)
class CTTestDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png")), key=alphanumeric_sort)
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = cv2.imread(self.image_paths[idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        name = os.path.basename(self.image_paths[idx])
        if self.transform:
            img = self.transform(image=img)['image']
        return img, name

# Pas d'augmentation, juste conversion
test_transform = A.Compose([
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    A.ToTensorV2(),
])

test_ds = CTTestDataset(image_dir=os.path.join(PATH, "test-images"), transform=test_transform)
test_loader = DataLoader(test_ds, batch_size=8, shuffle=False, num_workers=4)

# Inférence et sauvegarde CSV
all_preds = []
filenames = []


In [None]:
with torch.no_grad():
    for imgs, names in tqdm(test_loader):
        imgs = imgs.to(DEVICE)
        logits = model(imgs)
        preds = torch.argmax(logits, dim=1).cpu().numpy()  # (B,H,W)
        for p, n in zip(preds, names):
            all_preds.append(p.flatten())
            filenames.append(n)

# Créer DataFrame transposé
df = pd.DataFrame(np.stack(all_preds, axis=0), columns=labels_train.columns) #
df = df.T
# nom des colonnes
df.columns = filenames


In [None]:
# Sauvegarder CSV
#output_csv = os.path.join(PATH,  'best_phase2_epoch_10_2606.csv')
output_csv = os.path.join(PATH,  'best_phase2_final_unimatch_3006.csv')
df.to_csv(output_csv, index=True)
print(f"Test predictions saved to {output_csv}")

# CRF Improvement Attempt

## OTHERS VERSIONS

In [None]:
#LA PREMIERE VERSION
# Chargement Phase1
model.load_state_dict(torch.load(PATH + "best_phase_unimatch.pth"))
model.train()

# Génération pseudo-masks
pseudo_masks = []
with torch.no_grad():
    for imgs, _ in tqdm(unlab_loader):
        imgs = imgs.to(DEVICE)
        # weak / strong augmentations
        x_w = imgs
        x_s1, x_s2 = x_w.clone(), x_w.clone()
        x_s1 = strong_transform(image=x_s1.permute(0,2,3,1).cpu().numpy())['image'].to(DEVICE)
        x_s2 = strong_transform(image=x_s2.permute(0,2,3,1).cpu().numpy())['image'].to(DEVICE)
        # forward
        feat_w = model.encoder(x_w)

        #feat_fp = nn.Dropout2d(0.5)(feat_w)
        feat_fp = feat_w
        p_w, p_fp = model.decoder(torch.cat([feat_w, feat_fp],1)).chunk(2)
        ps1, ps2 = model(torch.cat([x_s1, x_s2],0)).chunk(2)
        mask = p_w.argmax(1).cpu().numpy()
        pseudo_masks.append(mask)
pseudo_masks = np.concatenate(pseudo_masks,0)



In [None]:
#LA SECONDE
import torch.nn as nn

dropout2d = nn.Dropout2d(0.5)

pseudo_masks = []
with torch.no_grad():
    for imgs, _ in tqdm(unlab_loader, desc="Pseudo-labeling"):
        imgs = imgs.to(DEVICE)

        # weak + deux vues strong (ici on clone, mais tu peux appliquer des strong transforms)
        x_w = imgs
        x_s1 = imgs.clone()
        x_s2 = imgs.clone()
        # … ici tu peux appliquer des strong augmentations à x_s1 / x_s2 …
        x_s1 = strong_transform(image=x_s1.permute(0,2,3,1).cpu().numpy())['image'].to(DEVICE)
        x_s2 = strong_transform(image=x_s2.permute(0,2,3,1).cpu().numpy())['image'].to(DEVICE)

        # 1) Passage encodeur sur la vue weak
        feats_w = model.encoder(x_w)        # liste de [f1, f2, ..., fN]
        last_w  = feats_w[-1]               # on prend la dernière map
        last_fp = dropout2d(last_w)         # feature perturbée
        # on remplace la dernière map par la version dropoutée
        feats_fp = feats_w.copy()
        feats_fp[-1] = last_fp

        # 2) Prédictions p_w et p_fp via le décodeur + tête de segmentation
        #    (chaque entrée doit être une liste de features)
        p_w  = model.segmentation_head( model.decoder(feats_w) )
        p_fp = model.segmentation_head( model.decoder(feats_fp) )

        # 3) Prédictions sur les deux vues fortes
        #    (si ton modèle ne sépare pas facilement, tu peux aussi faire:
        #     ps = model(torch.cat([x_s1, x_s2], dim=0)) puis chunk())
        ps1 = model(x_s1)
        ps2 = model(x_s2)

        # 4) Construction du pseudo‐mask (view weak)
        conf_w, mask_w = torch.max(torch.softmax(p_w, dim=1), dim=1)
        mask_w = mask_w.cpu().numpy()

        # 5) Seuillage et collecte
        #    (ici on ne garde que sur confiance>0.9, sinon label=0)
        for cw, mw in zip(conf_w.cpu(), mask_w):
            mw[cw<0.9] = 0
            pseudo_masks.append(mw)



In [None]:
#Avec corrections du pseudo labeling
"""
# on part du principe que :
# - Vous avez déjà chargé `filename_to_annotated` depuis annotated_labels.json
# - Vous avez un dict `true_masks` qui mappe nom de fichier → masque numpy (uniquement pour les images annotées)
# - `unlab_loader` renvoie (img_tensor, _, name)
for imgs, _, names in tqdm(unlab_loader, desc="Pseudo-labeling"):
    imgs = imgs.to(DEVICE)

    # -- weak + perturbed views as before --
    x_w = imgs
    x_s1_list, x_s2_list = [], []
    for img_tensor in x_w:
        img_np = img_tensor.permute(1,2,0).cpu().numpy()
        x_s1_list.append(strong_transform(image=img_np)['image'])
        x_s2_list.append(strong_transform(image=img_np)['image'])
    x_s1 = torch.stack(x_s1_list).to(DEVICE)
    x_s2 = torch.stack(x_s2_list).to(DEVICE)

    feats_w = model.encoder(x_w)
    last_w  = feats_w[-1]
    feats_fp = feats_w.copy(); feats_fp[-1] = dropout2d(last_w)

    p_w = model.segmentation_head(model.decoder(feats_w))  # (B, C, H, W)

    # softmax + top-5
    probs  = torch.softmax(p_w, dim=1)
    top1   = probs.topk(k=1, dim=1)
    confs, labels = top1.values.cpu().numpy(), top1.indices.cpu().numpy()
    B, _, H, W = labels.shape

    for b in range(B):
        name = names[b]
        # 1) déterminer les labels réellement ABSENTS
        #    (parmi ceux annotés pour cette image)
        absent_labels = {
            L for L in filename_to_annotated[name]
            if np.sum(true_masks[name] == L) == 0
        }
        # 2) on autorise tout autre label
        allowed = set(range(NUM_CLASSES)) - absent_labels

        # 3) parcours du top-5 pour choisir le 1er candidat autorisé
        m_img = np.zeros((H, W), dtype=np.uint8)
        for k in range(1):
            cand = labels[b, k]    # shape (H,W)
            take = (m_img == 0) & np.isin(cand, list(allowed))
            m_img[take] = cand[take]
            if np.all(m_img != 0):
                break

        # 4) appliquer votre seuil de confiance sur top-1
        low_conf = confs[b,0] < 0.6
        m_img[low_conf] = 0

        pseudo_masks.append(m_img)
        pseudo_paths.append(os.path.join(IMG_DIR, name))
""""

In [None]:
#Better leverage of json
"""
import os
import numpy as np
import torch
from tqdm import tqdm

# on part du principe que :
# - Vous avez déjà chargé `filename_to_annotated` depuis annotated_labels.json
# - Vous avez un dict `true_masks` qui mappe nom de fichier → masque numpy (uniquement pour les images annotées)
# - `unlab_loader` renvoie (img_tensor, _, name)
# - `strong_transform`, `dropout2d`, `model`, `DEVICE`, `IMG_DIR`, `NUM_CLASSES` sont définis

pseudo_masks = []
pseudo_paths = []

for imgs, _, names in tqdm(unlab_loader, desc="Pseudo-labeling"):
    imgs = imgs.to(DEVICE)

    # -- weak + perturbed views as before --
    x_w = imgs
    x_s1_list, x_s2_list = [], []
    for img_tensor in x_w:
        img_np = img_tensor.permute(1,2,0).cpu().numpy()
        x_s1_list.append(strong_transform(image=img_np)['image'])
        x_s2_list.append(strong_transform(image=img_np)['image'])
    x_s1 = torch.stack(x_s1_list).to(DEVICE)
    x_s2 = torch.stack(x_s2_list).to(DEVICE)

    feats_w = model.encoder(x_w)
    last_w  = feats_w[-1]
    feats_fp = feats_w.copy(); feats_fp[-1] = dropout2d(last_w)

    p_w = model.segmentation_head(model.decoder(feats_w))  # (B, C, H, W)

    # softmax + top-5
    probs  = torch.softmax(p_w, dim=1)
    top5   = probs.topk(k=5, dim=1)
    confs, labels = top5.values.cpu().numpy(), top5.indices.cpu().numpy()
    B, _, H, W = labels.shape

    for b in range(B):
        name = names[b]
        # 1) déterminer les labels réellement ABSENTS
        #    (parmi ceux annotés pour cette image)
        absent_labels = {
            L for L in filename_to_annotated[name]
            if np.sum(true_masks[name] == L) == 0
        }
        # 2) on autorise tout autre label
        allowed = set(range(NUM_CLASSES)) - absent_labels

        # 3) parcours du top-5 pour choisir le 1er candidat autorisé
        m_img = np.zeros((H, W), dtype=np.uint8)
        for k in range(5):
            cand = labels[b, k]    # shape (H,W)
            take = (m_img == 0) & np.isin(cand, list(allowed))
            m_img[take] = cand[take]
            if np.all(m_img != 0):
                break

        # 4) appliquer votre seuil de confiance sur top-1
        low_conf = confs[b,0] < 0.6
        m_img[low_conf] = 0

        pseudo_masks.append(m_img)
        pseudo_paths.append(os.path.join(IMG_DIR, name))

# Empiler et créer le dataset
pseudo_masks = np.stack(pseudo_masks, axis=0)
class PseudoCTScanDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, masks, transform):
        assert len(image_paths)==len(masks)
        self.image_paths = image_paths
        self.masks       = masks
        self.transform   = transform
    def __len__(self):
        return len(self.masks)
    def __getitem__(self, i):
        img = cv2.cvtColor(cv2.imread(self.image_paths[i]), cv2.COLOR_BGR2RGB)
        aug = self.transform(image=img, mask=self.masks[i])
        return aug['image'], aug['mask']

pseudo_ds    = PseudoCTScanDataset(pseudo_paths, pseudo_masks, base_transform)
mixed_ds     = ConcatDataset([full_lab, pseudo_ds])
mixed_loader = DataLoader(mixed_ds, batch_size=BATCH, shuffle=True, num_workers=4)
"""

In [None]:
#CRF
"""
!pip install -U segmentation-models-pytorch pydensecrf -q

import os
import glob
import re
import json
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import segmentation_models_pytorch as smp
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax, create_pairwise_gaussian, create_pairwise_bilateral
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import albumentations as A

# ---------------------------
#  Réglages et chemins
# ---------------------------
SEED        = 26
DEVICE      = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PATH_DATA   = "/content/drive/MyDrive/FewCTSeg/data/"
IMG_DIR     = os.path.join(PATH_DATA, "train-images")
TEST_DIR    = os.path.join(PATH_DATA, "test-images")
MASK_CSV    = os.path.join(PATH_DATA, "y_train.csv")
MODEL_PATH  = os.path.join(PATH_DATA, "best_final_unimatch_threshold06_dropout03_1306.pth")
BATCH       = 8
NUM_CLASSES = 55

# ---------------------------
#  Helpers
# ---------------------------
def alphanumeric_sort(name):
    parts = re.split(r"(\d+)", name)
    return [int(p) if p.isdigit() else p for p in parts]

def apply_crf(image_np, prob_map_np,
              gaussian_sxy=(3,3), bilateral_sxy=(80,80), bilateral_srgb=(13,13,13),
              num_iters=5):
    """
    image_np : HxWx3 uint8
    prob_map_np : CxHxW float32 (softmax probabilities)
    """
    C, H, W = prob_map_np.shape
    d = dcrf.DenseCRF2D(W, H, C)
    unary = unary_from_softmax(prob_map_np)
    d.setUnaryEnergy(unary)
    # spatial term
    feats = create_pairwise_gaussian(sdims=gaussian_sxy, shape=(H,W))
    d.addPairwiseEnergy(feats, compat=3)
    # bilateral term
    feats = create_pairwise_bilateral(sdims=bilateral_sxy,
                                      schan=bilateral_srgb,
                                      img=image_np,
                                      chdim=2)
    d.addPairwiseEnergy(feats, compat=10)
    Q = d.inference(num_iters)
    return np.array(Q).reshape((C, H, W))

# ---------------------------
#  Dataset test
# ---------------------------
class CTTestDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir   = image_dir
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png")),
                                   key=alphanumeric_sort)
        self.transform   = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        img  = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
        name = os.path.basename(path)
        img_t = self.transform(image=img)["image"] if self.transform else img
        return img_t, name

# ---------------------------
#  Pré‑traitements
# ---------------------------
test_transform = A.Compose([
    A.Normalize(mean=(0.485,0.456,0.406),
                std=(0.229,0.224,0.225)),
    A.ToTensorV2(),
], seed=SEED)

# ---------------------------
#  Chargement du modèle
# ---------------------------
model = smp.Segformer(
    encoder_name="timm-efficientnet-b7",
    encoder_weights=None,
    in_channels=3, classes=NUM_CLASSES, activation=None
).to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True))
model.eval()

# ---------------------------
#  DataLoader test
# ---------------------------
test_ds     = CTTestDataset(TEST_DIR, transform=test_transform)
test_loader = DataLoader(test_ds, batch_size=BATCH, shuffle=False, num_workers=4)

# ---------------------------
#  Inférence + CRF + CSV
# ---------------------------
all_preds = []
filenames = []

# On récupère les colonnes de sortie (la même disposition que y_train.csv)
labels_train = pd.read_csv(MASK_CSV, index_col=0).T
columns = labels_train.columns

with torch.no_grad():
    for imgs, names in tqdm(test_loader, desc="Test w/ CRF"):
        imgs = imgs.to(DEVICE)                       # (B,3,H,W)
        logits = model(imgs)                         # (B,C,H,W)
        probs  = F.softmax(logits, dim=1).cpu().numpy()  # (B,C,H,W)

        B, C, H, W = probs.shape
        for b in range(B):
            # 1) lire l'image originale (uint8) pour le CRF
            orig_path  = os.path.join(TEST_DIR, names[b])
            img_np     = cv2.cvtColor(cv2.imread(orig_path), cv2.COLOR_BGR2RGB)
            # 2) appliquer CRF
            crf_refined = apply_crf(img_np, probs[b])
            mask_crf    = np.argmax(crf_refined, axis=0)  # (H,W)
            # 3) aplatir et stocker
            all_preds.append(mask_crf.flatten())
            filenames.append(names[b])

# Assemblage du DataFrame et sauvegarde
df = pd.DataFrame(np.stack(all_preds, axis=0), columns=columns)
df = df.T
df.columns = filenames

output_csv = os.path.join(PATH_DATA, "best_final_unimatch_threshold06_crf.csv")
df.to_csv(output_csv, index=True)
print(f"Test predictions with CRF saved to {output_csv}")
"""