In [1]:
#!pip install -U segmentation-models-pytorch

In [2]:
import os, re, glob, json
import numpy as np, pandas as pd, cv2
import torch, random
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from tqdm import tqdm

from src.dataset import LabeledCTScanDataset, UnlabeledCTScanDataset, PseudoCTScanDataset, TestCTScanDataset
from src.custom_augment import cutmix
from utils_functions.seed import set_seed
from utils_functions.submit import pred_and_save


  check_for_updates()


In [3]:
#from google.colab import drive
#drive.mount('/content/drive')

In [4]:

# ------------------ Configuration ------------------
SEED = 26
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#PATH = "/content/drive/MyDrive/FewCTSeg/data/"
PATH ="./data/"
IMG_DIR = os.path.join(PATH, 'train-images/')
MASK_CSV = os.path.join(PATH, 'y_train.csv')
MODEL_OUT = PATH
BATCH = 8
NUM_CLASSES = 55
NUM_EPOCHS_PHASE1 = 10
NUM_EPOCHS_PHASE2 = 2
PSEUDO_LABEL_THRESHOLD = 0.6
LAMBDA = 0.5  # weight feature-perturb loss
MU = 0.5      # weight image-level loss
IGNORE_INDEX = 255  # for unlabeled pixels



In [5]:


# ------------------ Transforms ------------------
base_transform = A.Compose([
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    ToTensorV2()])

weak_transform = A.Compose([
    A.RandomResizedCrop((256,256), scale=(0.2,1.0), p=1.0),
    #A.RandomRotate90(p=0.5),
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    ToTensorV2()])

strong_transform = A.Compose([
    A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
    ToTensorV2()])  # CutMix applied in batch

In [6]:

# ------------------ DataLoaders ------------------
# Phase1 : Split Train/val (reprodcuible)

full_lab = LabeledCTScanDataset(IMG_DIR, MASK_CSV, base_transform)
idxs = torch.randperm(len(full_lab), generator=torch.Generator().manual_seed(SEED)).tolist()
#idxs = np.random.default_rng(seed=SEED).permutation(len(full_lab))
split = int(0.8*len(full_lab))
train_ds = Subset(full_lab, idxs[:split])
val_ds   = Subset(full_lab, idxs[split:])
gen = torch.Generator(); gen.manual_seed(SEED)
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, num_workers=0, generator=gen)
val_loader   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False, num_workers=0)

unlab_ds = UnlabeledCTScanDataset(IMG_DIR, MASK_CSV, weak_transform)
unlab_loader = DataLoader(unlab_ds, batch_size=BATCH, shuffle=False, num_workers=0)


In [7]:
model = smp.Segformer(
    encoder_name='timm-efficientnet-b7', encoder_weights='imagenet',
    in_channels=3, classes=NUM_CLASSES
).to(DEVICE)
sup_loss = smp.losses.DiceLoss(mode='multiclass', from_logits=True)
opt1 = Adam(model.parameters(), lr=1e-3)

In [8]:
# ------------------ Phase1: Warmup ------------------
best_val=1e6
for epoch in range(NUM_EPOCHS_PHASE1):
    model.train(); train_loss=0
    for imgs, masks in tqdm(train_loader):
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        opt1.zero_grad()
        logits = model(imgs)
        loss = sup_loss(logits, masks.long())
        loss.backward(); opt1.step()
        train_loss += loss.item()*imgs.size(0)
    val_loss=0; model.eval()
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs,masks=imgs.to(DEVICE),masks.to(DEVICE)
            val_loss+=sup_loss(model(imgs), masks.long()).item()*imgs.size(0)
    val_loss /= len(val_ds)
    print(f"Epoch1 {epoch+1}/{NUM_EPOCHS_PHASE1} — val: {val_loss:.4f}")
    if val_loss<best_val:
        best_val=val_loss
        torch.save(model.state_dict(), os.path.join(MODEL_OUT,'best_phase1_unimatch_3006.pth'))



100%|██████████████████████████████████████████████████████████████████████████████████| 76/76 [02:35<00:00,  2.05s/it]


Epoch1 1/10 — val: 0.3506


100%|██████████████████████████████████████████████████████████████████████████████████| 76/76 [02:28<00:00,  1.95s/it]


Epoch1 2/10 — val: 0.3046


100%|██████████████████████████████████████████████████████████████████████████████████| 76/76 [02:34<00:00,  2.04s/it]


Epoch1 3/10 — val: 0.2766


100%|██████████████████████████████████████████████████████████████████████████████████| 76/76 [02:37<00:00,  2.07s/it]


Epoch1 4/10 — val: 0.2768


100%|██████████████████████████████████████████████████████████████████████████████████| 76/76 [02:40<00:00,  2.11s/it]


Epoch1 5/10 — val: 0.2642


100%|██████████████████████████████████████████████████████████████████████████████████| 76/76 [02:34<00:00,  2.03s/it]


Epoch1 6/10 — val: 0.2599


100%|██████████████████████████████████████████████████████████████████████████████████| 76/76 [02:33<00:00,  2.02s/it]


Epoch1 7/10 — val: 0.2647


100%|██████████████████████████████████████████████████████████████████████████████████| 76/76 [02:34<00:00,  2.03s/it]


Epoch1 8/10 — val: 0.2594


100%|██████████████████████████████████████████████████████████████████████████████████| 76/76 [02:32<00:00,  2.01s/it]


Epoch1 9/10 — val: 0.2624


100%|██████████████████████████████████████████████████████████████████████████████████| 76/76 [02:42<00:00,  2.14s/it]


Epoch1 10/10 — val: 0.2780


In [8]:

model.load_state_dict(torch.load(os.path.join(MODEL_OUT,'best_phase1_unimatch_3006.pth')))
model.eval()
pseudo_paths, pseudo_masks = [], []

  model.load_state_dict(torch.load(os.path.join(MODEL_OUT,'best_phase1_unimatch_3006.pth')))


In [9]:
# ------------------ Pseudo-labeling ------------------
with torch.no_grad():
    for imgs, paths in tqdm(unlab_loader):
        imgs = imgs.to(DEVICE)
        pw = torch.softmax(model(imgs), dim=1)
        conf, pred = pw.max(1)
        for i,pth in enumerate(paths):
            conf_i, pred_i = conf[i].cpu().numpy(), pred[i].cpu().numpy().astype(np.uint8)
            mask_i = np.where(conf_i>=PSEUDO_LABEL_THRESHOLD, pred_i, 0)
            if (conf_i>=PSEUDO_LABEL_THRESHOLD).mean()>=LAMBDA:
                pseudo_paths.append(pth); pseudo_masks.append(mask_i)


100%|████████████████████████████████████████████████████████████████████████████████| 156/156 [00:28<00:00,  5.53it/s]


In [10]:

if len(pseudo_masks)==0:
    raise ValueError("No pseudo-labels found, adjust thresholds.")
    
pseudo_masks = np.stack(pseudo_masks,0)
pseudo_ds = PseudoCTScanDataset(pseudo_paths, pseudo_masks, transform=strong_transform)


In [11]:

# ------------------ Phase2: UniMatch ------------------
joint_ds = ConcatDataset([train_ds, pseudo_ds])
joint_loader = DataLoader(joint_ds, batch_size=BATCH, shuffle=True, num_workers=0)

In [12]:
# ------------------ Hyperparameters of UniMatch ------------------
opt2 = Adam(model.parameters(), lr=1e-4)
scheduler = ReduceLROnPlateau(opt2, 'min', factor=0.5, patience=3)
drop2d = nn.Dropout2d(p=0.5)
dice = smp.losses.DiceLoss(mode='multiclass', from_logits=True)
ce = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
std, mean = np.array([0.485, 0.456, 0.406]), np.array([0.229, 0.224, 0.225])

In [13]:
# ------------------ Phase2: UniMatch ------------------

for epoch in range(1, NUM_EPOCHS_PHASE2+1):

    model.train(); train_loss=0
    for imgs, masks in tqdm(joint_loader):
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        
        # CutMix per batch
        imgs, masks = cutmix(imgs, masks)
        
        # Forward weak
        feat_w = model.encoder(imgs)
        logits_w = model.segmentation_head(model.decoder(feat_w))
        # Feature perturbation
        last = feat_w[-1]; feat_fp = list(feat_w)
        feat_fp[-1] = drop2d(last)
        logits_fp = model.segmentation_head(model.decoder(feat_fp))

        # Two strong streams
        # (we reuse mix outputs as strong; for demo)
        logits_s1 = model(imgs); logits_s2 = model(imgs)
        # Pseudo mask from weak augmentation (detach)
        pseudo = logits_w.argmax(1).detach()

        #
        """
        Mettre dans une fonction
        # a) récupérer numpy non-normalisé
        imgs_np = imgs.cpu().permute(0,2,3,1).numpy()
        imgs_np = (imgs_np * std + mean) * 255.0  # repasser en 0–255

        # b) deux augmentations indépendantes
        xs1 = torch.stack([strong_transform(image=img.astype(np.uint8))['image']
                            for img in imgs_np]).to(DEVICE)
        xs2 = torch.stack([strong_transform(image=img.astype(np.uint8))['image']
                            for img in imgs_np]).to(DEVICE)
        """

        # Losses
        loss_sup = sup_loss(logits_w, masks.long())

        loss_fp = ce(logits_fp, pseudo)
        loss_s = 0.5*(ce(logits_s1, pseudo) + ce(logits_s2, pseudo))

        loss_u = LAMBDA*loss_fp + MU*loss_s
        loss = 0.5*(loss_sup + loss_u)
        opt2.zero_grad(); loss.backward(); opt2.step()
        train_loss += loss.item()*imgs.size(0)
    train_loss /= len(joint_ds)

    # Validation phase2
    model.eval(); val_loss=0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            val_loss += sup_loss(model(imgs), masks.long()).item()*imgs.size(0)
    val_loss /= len(val_ds)
    scheduler.step(val_loss)
    print(f"Phase2 E{epoch}/{NUM_EPOCHS_PHASE2} — train:{train_loss:.4f} val:{val_loss:.4f}")
    if epoch%5==0:
        torch.save(model.state_dict(), os.path.join(MODEL_OUT, f'phase2_e{epoch}_unimatch_drp05_th0.9_3006.pth'))
    if val_loss<best_val:
        best_val=val_loss
        torch.save(model.state_dict(), os.path.join(MODEL_OUT,'best_phase2_unimatch_drp05_th0.9_3006.pth'))




  0%|                                                                                          | 0/231 [00:12<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 220.00 MiB. GPU 0 has a total capacity of 8.00 GiB of which 0 bytes is free. Of the allocated memory 21.88 GiB is allocated by PyTorch, and 500.24 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
torch.save(model.state_dict(), os.path.join(MODEL_OUT,'best_phase2_final_unimatch_3006.pth'))

## Prediction phase

In [None]:
test_ds = TestCTScanDataset(image_dir=os.path.join(PATH, "test-images"), transform=base_transform)
test_loader = DataLoader(test_ds, batch_size=8, shuffle=False, num_workers=4)

In [35]:
model.load_state_dict(torch.load(PATH +'best_phase2_unimatch_drp05_th0.9_3006.pth', map_location=DEVICE,weights_only=True))

<All keys matched successfully>

In [36]:
# Load the best model
model.eval()

In [None]:
output_filename = 'best_phase2_unimatch_drp05_th0.9_3006.csv'
pred_and_save(test_loader, model,  MASK_CSV ,output_filename)

# Se débarasser de çela après

In [37]:
labels_train = pd.read_csv(MASK_CSV, index_col=0, header=0).T

In [38]:
# Inference and save CSV
all_preds = []
filenames = []

with torch.no_grad():
    for imgs, names in tqdm(test_loader):
        imgs = imgs.to(DEVICE)
        logits = model(imgs)
        preds = torch.argmax(logits, dim=1).cpu().numpy()  # (B,H,W)
        for p, n in zip(preds, names):
            all_preds.append(p.flatten())
            filenames.append(n)



100%|██████████| 63/63 [00:04<00:00, 13.72it/s]


In [39]:
# Create the submission DataFrame
df = pd.DataFrame(np.stack(all_preds, axis=0), columns=labels_train.columns) #
df = df.T
df.columns = filenames

# Save CSV
output_csv = os.path.join(PATH,  'best_phase2_unimatch_drp05_th0.9_3006.csv')
df.to_csv(output_csv, index=True)
print(f"Test predictions saved to {output_csv}")

Test predictions saved to /content/drive/MyDrive/FewCTSeg/data/best_phase2_unimatch_drp05_th0.9_3006.csv
