In [None]:
#!pip install -U segmentation-models-pytorch

In [None]:
import os, re, glob, json
import numpy as np, pandas as pd, cv2
import torch, random
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from tqdm import tqdm

#from src.dataset import LabeledCTScanDataset, UnlabeledCTScanDataset, PseudoCTScanDataset, CTTestDataset
#from utils_functions.sort_files import alphanumeric_sort
#from utils_functions.seed import set_seed
#from utils_functions.submit import pred_and_save


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:

# ------------------ Configuration ------------------
SEED = 26
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
PATH = "/content/drive/MyDrive/FewCTSeg/data/"
IMG_DIR = os.path.join(PATH, 'train-images/')
MASK_CSV = os.path.join(PATH, 'y_train.csv')
MODEL_OUT = PATH
BATCH = 8
NUM_CLASSES = 55
NUM_EPOCHS_PHASE1 = 10
NUM_EPOCHS_PHASE2 = 25
PSEUDO_LABEL_THRESHOLD = 0.6
LAMBDA = 0.5  # weight feature-perturb loss
MU = 0.5      # weight image-level loss
IGNORE_INDEX = 255  # for unlabeled pixels

# Fix randomness to get some reproducibility in results (mettre le package random dans ce code et enlever l'autre)
###utils_functions.seed
def set_seed(seed=SEED):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(seed = SEED)
###utils_functions.sort_files
def alphanumeric_sort(name):
    parts = re.split(r'(\d+)', name)
    return [int(p) if p.isdigit() else p for p in parts]



In [4]:
###src.custom_augment
# ------------------ Custom CutMix ------------------
def cutmix(images, masks):
    # images, masks: tensors shape (B,C,H,W),(B,H,W)
    B, C, H, W = images.shape
    lam = np.random.beta(1.0, 1.0)
    rx = np.random.randint(W)
    ry = np.random.randint(H)
    rw = int(W * np.sqrt(1 - lam))
    rh = int(H * np.sqrt(1 - lam))
    x1 = np.clip(rx - rw // 2, 0, W)
    y1 = np.clip(ry - rh // 2, 0, H)
    x2 = np.clip(rx + rw // 2, 0, W)
    y2 = np.clip(ry + rh // 2, 0, H)
    perm = torch.randperm(B)
    mixed_images = images.clone()
    mixed_masks = masks.clone()
    mixed_images[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    mixed_masks[:, y1:y2, x1:x2] = masks[perm, y1:y2, x1:x2]
    return mixed_images, mixed_masks



# ------------------ Transforms ------------------
base_transform = A.Compose([
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    ToTensorV2()])

weak_transform = A.Compose([
    A.RandomResizedCrop((256,256), scale=(0.2,1.0), p=1.0),
    #A.RandomRotate90(p=0.5),
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    ToTensorV2()])

strong_transform = A.Compose([
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    ToTensorV2()])  # CutMix applied in batch

In [6]:

# ------------------ Datasets ------------------
class LabeledCTScanDataset(Dataset):
    def __init__(self, img_dir, mask_csv, transform):
        paths = sorted(glob.glob(os.path.join(img_dir, '*.png')), key=alphanumeric_sort)
        masks = pd.read_csv(mask_csv, index_col=0).T.values.reshape(-1,256,256).astype(np.uint8)
        valid = [m.sum()>0 for m in masks]
        self.image_paths = [p for p,v in zip(paths,valid) if v]
        self.masks = masks[np.array(valid)]
        self.transform = transform
    def __len__(self): return len(self.image_paths)
    def __getitem__(self, idx):
        img = cv2.cvtColor(cv2.imread(self.image_paths[idx]), cv2.COLOR_BGR2RGB)
        mask = self.masks[idx]
        aug = self.transform(image=img, mask=mask)
        return aug['image'], aug['mask']

class UnlabeledCTScanDataset(Dataset):
    def __init__(self, img_dir, mask_csv, weak_transform):
        paths = sorted(glob.glob(os.path.join(img_dir, '*.png')), key=alphanumeric_sort)
        masks = pd.read_csv(mask_csv, index_col=0).T.values.reshape(-1,256,256).astype(np.uint8)
        valid = [m.sum()==0 for m in masks]
        self.image_paths = [p for p,v in zip(paths,valid) if v]
        self.transform = weak_transform
    def __len__(self): return len(self.image_paths)
    def __getitem__(self, idx):
        img = cv2.cvtColor(cv2.imread(self.image_paths[idx]), cv2.COLOR_BGR2RGB)
        aug = self.transform(image=img)
        return aug['image'], self.image_paths[idx]

class PseudoCTScanDataset(Dataset):
    def __init__(self, image_paths, masks):
        self.image_paths, self.masks = image_paths, masks
    def __len__(self): return len(self.masks)
    def __getitem__(self, idx):
        img = cv2.cvtColor(cv2.imread(self.image_paths[idx]), cv2.COLOR_BGR2RGB)
        aug = base_transform(image=img, mask=self.masks[idx])
        #aug = strong_transform(image=img, mask=self.masks[idx])
        return aug['image'], aug['mask']



# Test Dataset
class CTTestDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, "*.png")), key=alphanumeric_sort)
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = cv2.imread(self.image_paths[idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        name = os.path.basename(self.image_paths[idx])
        if self.transform:
            img = self.transform(image=img)['image']
        return img, name



In [7]:

# ------------------ DataLoaders ------------------
# Phase1 : Split Train/val (reprodcuible)

full_lab = LabeledCTScanDataset(IMG_DIR, MASK_CSV, base_transform)
idxs = torch.randperm(len(full_lab), generator=torch.Generator().manual_seed(SEED)).tolist()
#idxs = np.random.default_rng(seed=SEED).permutation(len(full_lab))
split = int(0.8*len(full_lab))
train_ds = Subset(full_lab, idxs[:split])
val_ds   = Subset(full_lab, idxs[split:])
gen = torch.Generator(); gen.manual_seed(SEED)
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=4, generator=gen)
val_loader   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False, num_workers=4)

unlab_ds = UnlabeledCTScanDataset(IMG_DIR, MASK_CSV, weak_transform)
unlab_loader = DataLoader(unlab_ds, batch_size=BATCH, shuffle=False, num_workers=4)




In [8]:
model = smp.Segformer(
    encoder_name='timm-efficientnet-b7', encoder_weights='imagenet',
    in_channels=3, classes=NUM_CLASSES
).to(DEVICE)
sup_loss = smp.losses.DiceLoss(mode='multiclass', from_logits=True)
opt1 = Adam(model.parameters(), lr=1e-3)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/106 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/267M [00:00<?, ?B/s]

In [9]:
# ------------------ Phase1: Warmup ------------------
best_val=1e6
for epoch in range(NUM_EPOCHS_PHASE1):
    model.train(); train_loss=0
    for imgs, masks in tqdm(train_loader):
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        opt1.zero_grad()
        logits = model(imgs)
        loss = sup_loss(logits, masks.long())
        loss.backward(); opt1.step()
        train_loss += loss.item()*imgs.size(0)
    val_loss=0; model.eval()
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs,masks=imgs.to(DEVICE),masks.to(DEVICE)
            val_loss+=sup_loss(model(imgs), masks.long()).item()*imgs.size(0)
    val_loss /= len(val_ds)
    print(f"Epoch1 {epoch+1}/{NUM_EPOCHS_PHASE1} — val: {val_loss:.4f}")
    if val_loss<best_val:
        best_val=val_loss
        torch.save(model.state_dict(), os.path.join(MODEL_OUT,'best_phase1_unimatch_3006.pth'))



100%|██████████| 76/76 [00:40<00:00,  1.87it/s]


Epoch1 1/10 — val: 0.3573


100%|██████████| 76/76 [00:11<00:00,  6.51it/s]


Epoch1 2/10 — val: 0.3031


100%|██████████| 76/76 [00:11<00:00,  6.54it/s]


Epoch1 3/10 — val: 0.2905


100%|██████████| 76/76 [00:14<00:00,  5.14it/s]


Epoch1 4/10 — val: 0.2850


100%|██████████| 76/76 [00:11<00:00,  6.69it/s]


Epoch1 5/10 — val: 0.2683


100%|██████████| 76/76 [00:11<00:00,  6.70it/s]


Epoch1 6/10 — val: 0.2687


100%|██████████| 76/76 [00:11<00:00,  6.47it/s]


Epoch1 7/10 — val: 0.2716


100%|██████████| 76/76 [00:12<00:00,  5.90it/s]


Epoch1 8/10 — val: 0.2732


100%|██████████| 76/76 [00:11<00:00,  6.54it/s]


Epoch1 9/10 — val: 0.2651


100%|██████████| 76/76 [00:15<00:00,  4.90it/s]


Epoch1 10/10 — val: 0.2670


In [29]:

model.load_state_dict(torch.load(os.path.join(MODEL_OUT,'best_phase1_unimatch_3006.pth')))
model.eval()
pseudo_paths, pseudo_masks = [], []

In [12]:
# ------------------ Pseudo-labeling ------------------
with torch.no_grad():
    for imgs, paths in tqdm(unlab_loader):
        imgs = imgs.to(DEVICE)
        pw = torch.softmax(model(imgs), dim=1)
        conf, pred = pw.max(1)
        for i,pth in enumerate(paths):
            conf_i, pred_i = conf[i].cpu().numpy(), pred[i].cpu().numpy().astype(np.uint8)
            mask_i = np.where(conf_i>=PSEUDO_LABEL_THRESHOLD, pred_i, 0)
            if (conf_i>=PSEUDO_LABEL_THRESHOLD).mean()>=LAMBDA:
                pseudo_paths.append(pth); pseudo_masks.append(mask_i)


100%|██████████| 156/156 [00:08<00:00, 17.67it/s]


In [13]:

if len(pseudo_masks)==0:
    raise ValueError("No pseudo-labels found, adjust thresholds.")
    
pseudo_masks = np.stack(pseudo_masks,0)
pseudo_ds = PseudoCTScanDataset(pseudo_paths, pseudo_masks)


In [15]:

# ------------------ Phase2: UniMatch ------------------
joint_ds = ConcatDataset([train_ds, pseudo_ds])
joint_loader = DataLoader(joint_ds, batch_size=BATCH, shuffle=True, num_workers=4)

In [32]:
# ------------------ Hyperparameters of UniMatch ------------------
opt2 = Adam(model.parameters(), lr=1e-4)
scheduler = ReduceLROnPlateau(opt2, 'min', factor=0.5, patience=3)
drop2d = nn.Dropout2d(p=0.5)
dice = smp.losses.DiceLoss(mode='multiclass', from_logits=True)
ce = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
std, mean = np.array([0.485, 0.456, 0.406]), np.array([0.229, 0.224, 0.225])

In [34]:
# ------------------ Phase2: UniMatch ------------------

for epoch in range(1, NUM_EPOCHS_PHASE2+1):

    model.train(); train_loss=0
    for imgs, masks in tqdm(joint_loader):
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        
        # CutMix per batch
        imgs, masks = cutmix(imgs, masks)
        
        # Forward weak
        feat_w = model.encoder(imgs)
        logits_w = model.segmentation_head(model.decoder(feat_w))
        # Feature perturbation
        last = feat_w[-1]; feat_fp = list(feat_w)
        feat_fp[-1] = drop2d(last)
        logits_fp = model.segmentation_head(model.decoder(feat_fp))

        # Two strong streams
        # (we reuse mix outputs as strong; for demo)
        logits_s1 = model(imgs); logits_s2 = model(imgs)
        # Pseudo mask from weak augmentation (detach)
        pseudo = logits_w.argmax(1).detach()

        #
        """
        Mettre dans une fonction
        # a) récupérer numpy non-normalisé
        imgs_np = imgs.cpu().permute(0,2,3,1).numpy()
        imgs_np = (imgs_np * std + mean) * 255.0  # repasser en 0–255

        # b) deux augmentations indépendantes
        xs1 = torch.stack([strong_transform(image=img.astype(np.uint8))['image']
                            for img in imgs_np]).to(DEVICE)
        xs2 = torch.stack([strong_transform(image=img.astype(np.uint8))['image']
                            for img in imgs_np]).to(DEVICE)
        """

        # Losses
        loss_sup = sup_loss(logits_w, masks.long())

        loss_fp = ce(logits_fp, pseudo)
        loss_s = 0.5*(ce(logits_s1, pseudo) + ce(logits_s2, pseudo))

        loss_u = LAMBDA*loss_fp + MU*loss_s
        loss = 0.5*(loss_sup + loss_u)
        opt2.zero_grad(); loss.backward(); opt2.step()
        train_loss += loss.item()*imgs.size(0)
    train_loss /= len(joint_ds)

    # Validation phase2
    model.eval(); val_loss=0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            val_loss += sup_loss(model(imgs), masks.long()).item()*imgs.size(0)
    val_loss /= len(val_ds)
    scheduler.step(val_loss)
    print(f"Phase2 E{epoch}/{NUM_EPOCHS_PHASE2} — train:{train_loss:.4f} val:{val_loss:.4f}")
    if epoch%5==0:
        torch.save(model.state_dict(), os.path.join(MODEL_OUT, f'phase2_e{epoch}_unimatch_drp05_th0.9_3006.pth'))
    if val_loss<best_val:
        best_val=val_loss
        torch.save(model.state_dict(), os.path.join(MODEL_OUT,'best_phase2_unimatch_drp05_th0.9_3006.pth'))




100%|██████████| 231/231 [01:32<00:00,  2.51it/s]


Phase2 E1/25 — train:0.2645 val:0.2534


100%|██████████| 231/231 [01:31<00:00,  2.51it/s]


Phase2 E2/25 — train:0.2568 val:0.2597


100%|██████████| 231/231 [01:32<00:00,  2.51it/s]


Phase2 E3/25 — train:0.2563 val:0.2542


100%|██████████| 231/231 [01:31<00:00,  2.51it/s]


Phase2 E4/25 — train:0.2555 val:0.2565


100%|██████████| 231/231 [01:32<00:00,  2.51it/s]


Phase2 E5/25 — train:0.2529 val:0.2612


100%|██████████| 231/231 [01:31<00:00,  2.53it/s]


Phase2 E6/25 — train:0.2498 val:0.2628


100%|██████████| 231/231 [01:32<00:00,  2.50it/s]


Phase2 E7/25 — train:0.2497 val:0.2627


  3%|▎         | 8/231 [00:03<01:48,  2.05it/s]


KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), os.path.join(MODEL_OUT,'best_phase2_final_unimatch_3006.pth'))

## Prediction phase

In [None]:
test_ds = CTTestDataset(image_dir=os.path.join(PATH, "test-images"), transform=base_transform)
test_loader = DataLoader(test_ds, batch_size=8, shuffle=False, num_workers=4)

In [35]:
model.load_state_dict(torch.load(PATH +'best_phase2_unimatch_drp05_th0.9_3006.pth', map_location=DEVICE,weights_only=True))

<All keys matched successfully>

In [36]:
# Load the best model
model.eval()

In [None]:
#output_filename = 'best_phase2_unimatch_drp05_th0.9_3006.csv'
#pred_and_save(test_loader, model,  MASK_CSV ,output_filename)

# Se débarasser de çela après

In [37]:
labels_train = pd.read_csv(MASK_CSV, index_col=0, header=0).T

In [38]:
# Inference and save CSV
all_preds = []
filenames = []

with torch.no_grad():
    for imgs, names in tqdm(test_loader):
        imgs = imgs.to(DEVICE)
        logits = model(imgs)
        preds = torch.argmax(logits, dim=1).cpu().numpy()  # (B,H,W)
        for p, n in zip(preds, names):
            all_preds.append(p.flatten())
            filenames.append(n)



100%|██████████| 63/63 [00:04<00:00, 13.72it/s]


In [39]:
# Create the submission DataFrame
df = pd.DataFrame(np.stack(all_preds, axis=0), columns=labels_train.columns) #
df = df.T
df.columns = filenames

# Save CSV
output_csv = os.path.join(PATH,  'best_phase2_unimatch_drp05_th0.9_3006.csv')
df.to_csv(output_csv, index=True)
print(f"Test predictions saved to {output_csv}")

Test predictions saved to /content/drive/MyDrive/FewCTSeg/data/best_phase2_unimatch_drp05_th0.9_3006.csv
