In [9]:

import os, re, glob, json
import numpy as np, pandas as pd, cv2
import torch, random, copy, math
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from tqdm import tqdm
import matplotlib.pyplot as plt



In [13]:
from src.dataset import LabeledCTScanDataset, UnlabeledPathsDataset, TestCTScanDataset
from src.custom_augment import cutmix
from src.unimatch import pseudo_targets_from_logits, inf_loop, strong_perturbation, weak_perturbation, feature_perturbation
from utils.seed import set_seed
from utils.sort_files import alphanumeric_sort
from utils.submit import pred_and_save
from utils.visualization import denormalize, visualize_test_prediction
from utils.metrics import compute_dice_score, compute_per_class_dice

In [15]:

# ------------------ Configuration ------------------
SEED = 26
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#PATH = "/content/drive/MyDrive/FewCTSeg/data/"
PATH = "./data/"
MODEL_OUT = "./ckpts/"
IMG_DIR = os.path.join(PATH, 'train-images/')
MASK_CSV = os.path.join(PATH, 'y_train.csv')
BATCH = 8
NUM_CLASSES = 55
NUM_EPOCHS_PHASE1 = 1
NUM_EPOCHS_PHASE2 = 1
LAMBDA = 0.5  # weight of the feature-perturbation loss
MU = 0.5      # weight of the image-level (strong-view) loss
# Label used for non-selected pixels: passed as `ignore_index` both to
# pseudo_targets_from_logits and to the unsupervised CrossEntropyLoss.
# NOTE(review): 0 is also a valid class id (classes are 0..NUM_CLASSES-1), so
# pixels pseudo-labelled as class 0 are ignored too — confirm this is intended.
IGNORE_INDEX = 0
# ImageNet normalisation statistics (the encoder uses ImageNet weights).
MEAN = np.array([0.485, 0.456, 0.406])
STD  = np.array([0.229, 0.224, 0.225])

MODEL_PHASE1_FILENAME = 'best_phase1_unimatch.pth'
MODEL_FILENAME = 'best_phase2_unimatch.pth'
CSV_FILENAME = 'submission_unimatch.csv'

# torch.save does not create missing directories: make sure the checkpoint
# folder exists before any training cell writes into it.
os.makedirs(MODEL_OUT, exist_ok=True)

set_seed(SEED)


In [17]:
# ------------------ Transforms ------------------
def _normalize_to_tensor():
    """Return fresh Normalize + ToTensorV2 steps that end every pipeline."""
    return [
        A.Normalize(mean=tuple(MEAN.tolist()), std=tuple(STD.tolist())),
        ToTensorV2(),
    ]


# Normalisation only (used for validation and test images).
base_transform = A.Compose(_normalize_to_tensor())

# Weak view: random resized crop to 256x256, then normalisation.
weak_transform = A.Compose(
    [A.RandomResizedCrop((256, 256), scale=(0.2, 0.8), p=1.0)]
    + _normalize_to_tensor()
)

# Strong view for UniMatch: 3-8 random rectangular holes, then normalisation.
# Intended for de-normalised images (e.g. rebuilt from the weak view).
strong_transform = A.Compose(
    [A.CoarseDropout(num_holes_range=(3, 8),
                     hole_height_range=(0.1, 0.3),
                     hole_width_range=(0.1, 0.3),
                     p=1.0)]
    + _normalize_to_tensor()
)

In [19]:
# ------------------ Prepare splits / dataloaders ------------------
# The untransformed dataset is only built to know how many labelled samples exist.
full_lab = LabeledCTScanDataset(IMG_DIR, MASK_CSV, transform=None)
idxs = torch.randperm(len(full_lab), generator=torch.Generator().manual_seed(SEED)).tolist()
split = int(0.8 * len(full_lab))
train_idxs = idxs[:split]   # 80 % for training
val_idxs   = idxs[split:]   # 20 % for validation

# Datasets for the phase-1 warm-up and for the labelled stream of phase 2
train_ds = LabeledCTScanDataset(IMG_DIR, MASK_CSV, transform=weak_transform, indices=train_idxs)
val_ds   = LabeledCTScanDataset(IMG_DIR, MASK_CSV, transform=base_transform, indices=val_idxs)

gen = torch.Generator(); gen.manual_seed(SEED)
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=0, generator=gen)
val_loader   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False, num_workers=0)

# Unlabelled pool: training images whose mask in the CSV is entirely zero
# (images are loaded online from their paths).
paths_all = sorted(glob.glob(os.path.join(IMG_DIR, '*.png')), key=alphanumeric_sort)
masks_all = pd.read_csv(MASK_CSV, index_col=0).T.values.reshape(-1, 256, 256).astype(np.uint8)
# Images and masks are paired by position below; zip() would silently truncate
# on a length mismatch and build a wrong unlabelled set, so fail loudly instead.
assert len(paths_all) == len(masks_all), (
    f"{len(paths_all)} images in {IMG_DIR} but {len(masks_all)} masks in {MASK_CSV}"
)
valid_unlab = [m.sum() == 0 for m in masks_all]   # True when the mask is empty
unlab_paths = [p for p, v in zip(paths_all, valid_unlab) if v]
unlabeled_ds = UnlabeledPathsDataset(unlab_paths)


In [20]:
# ------------------ Model / Loss / Opt ------------------
# smp Segformer with an ImageNet-pretrained `timm-efficientnet-b7` encoder,
# 3-channel input and one output channel per class.
segformer_kwargs = dict(
    encoder_name='timm-efficientnet-b7',
    encoder_weights='imagenet',
    in_channels=3,
    classes=NUM_CLASSES,
)
model = smp.Segformer(**segformer_kwargs).to(DEVICE)

In [21]:
# ------------------ Losses ------------------
# Supervised objective: multi-class Dice computed directly on raw logits.
sup_loss = smp.losses.DiceLoss(mode='multiclass', from_logits=True)
# Unsupervised objective on pseudo-labels; pixels equal to IGNORE_INDEX are skipped.
ce = nn.CrossEntropyLoss(reduction='mean', ignore_index=IGNORE_INDEX)

# ------------------ Phase1: Warmup (supervised) ------------------
best_val = 1e6        # sentinel: any real validation loss is lower
train_history = []    # one (train_loss, val_loss) tuple per phase-1 epoch
opt1 = Adam(model.parameters(), lr=1e-3)



In [None]:

print("=== Phase1: Supervised phase ===")
for epoch in range(NUM_EPOCHS_PHASE1):
    # --- supervised training pass ---
    model.train()
    epoch_sum = 0.0
    for batch_imgs, batch_masks in tqdm(train_loader, desc=f"Phase1 E{epoch+1}", leave=False):
        batch_imgs = batch_imgs.to(DEVICE)
        batch_masks = batch_masks.to(DEVICE).long()
        opt1.zero_grad()
        batch_loss = sup_loss(model(batch_imgs), batch_masks)
        batch_loss.backward()
        opt1.step()
        # weight by batch size so the epoch value is a per-sample mean
        epoch_sum += batch_loss.item() * batch_imgs.size(0)
    train_loss = epoch_sum / len(train_ds)

    # --- validation pass (no gradients) ---
    model.eval()
    val_sum = 0.0
    with torch.no_grad():
        for batch_imgs, batch_masks in val_loader:
            batch_imgs = batch_imgs.to(DEVICE)
            batch_masks = batch_masks.to(DEVICE).long()
            val_sum += sup_loss(model(batch_imgs), batch_masks).item() * batch_imgs.size(0)
    val_loss = val_sum / len(val_ds)

    train_history.append((train_loss, val_loss))
    print(f"Epoch1 {epoch+1}/{NUM_EPOCHS_PHASE1} — train:{train_loss:.4f} val:{val_loss:.4f}")
    # checkpoint only when the validation Dice loss improves
    if val_loss < best_val:
        best_val = val_loss
        torch.save(model.state_dict(), os.path.join(MODEL_OUT, MODEL_PHASE1_FILENAME))


=== Phase1: Supervised phase ===


Phase1 E1:  12%|████████▌                                                               | 9/76 [03:42<26:38, 23.86s/it]

# Phase 2: Semi-supervised training with UniMatch

In [22]:
# ------------------ Parameters ------------------
TAU = 0.95       # confidence threshold for pixel-wise masking
FP_DROP_P = 0.5  # Dropout2d probability on last feature map

# Channel-wise dropout handed to feature_perturbation as the feature-level perturbation.
drop2d = nn.Dropout2d(p=FP_DROP_P)
# Phase-2 optimiser: learning rate 10x lower than phase 1, halved by the
# scheduler after `patience=3` epochs without validation-loss improvement.
opt2 = Adam(model.parameters(), lr=1e-4)
scheduler = ReduceLROnPlateau(opt2, mode='min', factor=0.5, patience=3)


In [23]:
# ------------------ Configuration of dataloaders (iterators) for mixing labeled/unlabeled in each step ------------------
labeled_loader_phase2 = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=0, drop_last=True)
unlabeled_loader_phase2 = DataLoader(unlabeled_ds, batch_size=BATCH, shuffle=True, num_workers=0, drop_last=True)

# With drop_last=True a dataset smaller than BATCH yields zero batches; cycling
# over an empty loader can never produce a batch (hang or error depending on
# inf_loop), so check up front.
assert len(labeled_loader_phase2) > 0, f"need at least {BATCH} labelled training images, found {len(train_ds)}"
assert len(unlabeled_loader_phase2) > 0, f"need at least {BATCH} unlabelled images, found {len(unlabeled_ds)}"

# presumably endless iterators over each loader — TODO confirm inf_loop semantics
lab_iter = inf_loop(labeled_loader_phase2)
unlab_iter = inf_loop(unlabeled_loader_phase2)


In [24]:
# ------------------ Reload best phase1 weights ------------------
ckpt_phase1 = os.path.join(MODEL_OUT, MODEL_PHASE1_FILENAME)
# best_val still holds the 1e6 sentinel when phase 1 saved no checkpoint in this
# session (e.g. an interrupted run): the file on disk then comes from an older run.
if best_val >= 1e6:
    print(f"WARNING: phase 1 saved no checkpoint in this session; loading {ckpt_phase1} from a previous run")
# map_location lets a GPU-saved checkpoint load on a CPU-only machine;
# weights_only avoids unpickling arbitrary objects (same options as the inference cell).
model.load_state_dict(torch.load(ckpt_phase1, map_location=DEVICE, weights_only=True))
model.eval()
print("Phase1  complete — best val:", best_val)


Phase1  complete — best val: 1000000.0


In [None]:
best_val_phase2 = 1e6


def _unsup_ce(logits, target):
    """Pseudo-label cross-entropy that stays finite when every pixel is ignored.

    `ce` uses reduction='mean', which divides by the number of pixels whose
    target differs from IGNORE_INDEX. When no pixel survives (e.g. nothing
    passes TAU), that division gives NaN, and the next optimiser step would
    corrupt the weights. In that case a zero loss is returned that keeps the
    autograd graph connected.
    """
    if not (target != IGNORE_INDEX).any():
        return logits.sum() * 0.0
    return ce(logits, target)


# One phase-2 epoch = one pass worth of labelled batches (loop-invariant).
steps_per_epoch = len(labeled_loader_phase2)

for epoch in range(1, NUM_EPOCHS_PHASE2+1):
    model.train()
    running_loss = 0.0
    fractions_this_epoch = []
    for step in tqdm(range(steps_per_epoch), desc=f"Phase2 E{epoch}", leave=False):
        # 1) labeled batch (weak augmentation already applied by train_ds)
        imgs_lab, masks_lab = next(lab_iter)
        imgs_lab, masks_lab = imgs_lab.to(DEVICE), masks_lab.to(DEVICE).long()

        # 2) unlabeled raw images -> normalized weak view, then two different
        #    strong views built from that weak view
        imgs_unlab_raw, _ = next(unlab_iter)
        batch_w = weak_perturbation(imgs_unlab_raw, weak_transform, DEVICE)
        batch_s1, batch_s2 = strong_perturbation(batch_w, strong_transform, MEAN, STD, DEVICE)

        # 3) supervised forward on labeled
        logits_lab = model(imgs_lab)  # (Bl, C, H, W)
        loss_sup = sup_loss(logits_lab, masks_lab)

        # 4) unlabeled streams
        # Pseudo-labels come from the clean weak view. They are targets, so no
        # gradient flows through them.
        with torch.no_grad():
            logits_w = model(batch_w)
        # weak view with feature-level (Dropout2d) perturbation
        logits_fp = feature_perturbation(model, batch_w, drop2d)
        # two strong predictions (image-level strong perturbations)
        logits_s1 = model(batch_s1)
        logits_s2 = model(batch_s2)

        # 5) build pseudo targets (pixel-wise mask using TAU - confidence threshold)
        target_u, mask_conf = pseudo_targets_from_logits(logits_w, tau=TAU, ignore_index=IGNORE_INDEX)
        fraction_kept = mask_conf.float().mean().item()
        fractions_this_epoch.append(fraction_kept)

        # 6) unsupervised losses: CE with ignore_index, NaN-safe
        logits_s_cat = torch.cat([logits_s1, logits_s2], dim=0)
        target_u_cat = torch.cat([target_u, target_u], dim=0)
        loss_s = _unsup_ce(logits_s_cat, target_u_cat)
        loss_fp = _unsup_ce(logits_fp, target_u)

        # NOTE(review): loss_s already averages the two strong views, so each
        # view is weighted MU * 0.5 / 2 = 0.125. The reference UniMatch code
        # weights each view 0.25 — confirm the extra 0.5 is intended.
        loss_unsup = LAMBDA * loss_fp + MU * 0.5 * loss_s

        # 7) final combined loss
        loss = 0.5 * (loss_sup + loss_unsup)

        opt2.zero_grad()
        loss.backward()
        opt2.step()

        running_loss += loss.item()

    # end epoch: validation (same per-sample Dice loss as phase 1)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            val_loss += sup_loss(model(imgs), masks.long()).item() * imgs.size(0)
    val_loss /= len(val_ds)
    scheduler.step(val_loss)

    # end epoch: stats
    avg_loss = running_loss / steps_per_epoch
    mean_fraction = float(np.mean(fractions_this_epoch)) if len(fractions_this_epoch)>0 else 0.0
    print(f"Phase2 E{epoch}/{NUM_EPOCHS_PHASE2} — train_avg_loss:{avg_loss:.4f} | val:{val_loss:.4f} | mean_fraction_kept:{mean_fraction:.3f}")

    # keep the checkpoint with the lowest validation loss
    if val_loss < best_val_phase2:
        best_val_phase2 = val_loss
        torch.save(model.state_dict(), os.path.join(MODEL_OUT, MODEL_FILENAME))

print("Phase2 complete — best val:", best_val_phase2)

# End of script


Phase2 E1:   0%|                                                                                | 0/75 [00:00<?, ?it/s]

# Evaluation phase

In [None]:
# Score the best phase-2 checkpoint (lowest validation loss, and the weights the
# submission cell loads) rather than whatever the last training epoch left in memory.
model.load_state_dict(torch.load(os.path.join(MODEL_OUT, MODEL_FILENAME), map_location=DEVICE, weights_only=True))
model.eval()
compute_per_class_dice(model, val_loader, NUM_CLASSES)

# Prediction phase

In [None]:
# Load the best phase-2 weights for inference. The trailing ';' stops the
# notebook from echoing the model repr that eval() returns.
best_ckpt_path = os.path.join(MODEL_OUT, MODEL_FILENAME)
state_dict = torch.load(best_ckpt_path, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.eval();

In [None]:
# Test images get the same normalisation-only pipeline as validation.
test_img_dir = os.path.join(PATH, "test-images")
test_ds = TestCTScanDataset(img_dir=test_img_dir, transform=base_transform)
# shuffle=False keeps batches in dataset order so predictions line up with test_ds.
test_loader = DataLoader(test_ds, batch_size=BATCH, shuffle=False, num_workers=0)

In [None]:
# Predict on the test set and write the submission to CSV_FILENAME.
# MASK_CSV is forwarded too (presumably as a layout template — TODO confirm).
pred_and_save(
    test_loader,
    model,
    MASK_CSV,
    CSV_FILENAME,
)