In [None]:
import sys
sys.path.append("../")
#from utils_functions.sort_files import alphanumeric_sort #Function which sort alphanumerically files
#from utils_functions.submit import pred_and_save, count_label_preds
import glob
import os
import numpy as np
from PIL import Image
import torch
import cv2
import albumentations as A
import segmentation_models_pytorch as smp
import pandas as pd
from torch.utils.data import DataLoader, random_split, Subset, Dataset
import matplotlib.pyplot as plt
from tqdm import tqdm
from src.dataset import  CTTestDataset, CTScanDataset
#cv2.setNumThreads(0)  - To avoid slower computation

In [None]:
# Ensure reproducibility: seed NumPy's and PyTorch's global RNGs with one value.
SEED = 26

np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x793b3c64f430>

In [None]:
import re
def alphanumeric_sort(name):
    """Return a natural-order sort key for ``name``.

    The string is cut into alternating text and digit runs; digit runs are
    converted to ``int`` so that, used as ``key=`` for ``sorted``/``list.sort``,
    ``"img2"`` orders before ``"img10"``.

    Args:
        name: The string (e.g. a file name) to build a key for.

    Returns:
        A list whose even positions hold text fragments (possibly ``""``) and
        whose odd positions hold the integer value of each digit run.
    """
    # Raw string: '\d' inside a plain literal is an invalid escape sequence and
    # triggers a warning on recent Python versions.
    parts = re.split(r'(\d+)', name)
    # With a single capturing group, re.split always alternates
    # text, digits, text, ... so the digit runs sit at the odd indices.
    # Converting by position (rather than by str.isdigit) avoids calling int()
    # on text fragments such as '²' that isdigit() accepts but int() rejects.
    return [int(part) if idx % 2 else part for idx, part in enumerate(parts)]


In [None]:

# ---------------------------
# Root data directory and the three locations derived from it.
PATH = "/content/drive/MyDrive/FewCTSeg/data/"
train_path, test_path, labels_path = (
    os.path.join(PATH, leaf)
    for leaf in ("train-images/", "test-images/", "y_train.csv")
)

In [None]:

# Use CUDA when PyTorch reports it as available, otherwise run on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Dataset built without any transform; later cells use its length to size
# the train/validation split.
full_dataset = CTScanDataset(transform=None, mask_csv=labels_path, image_dir=train_path)

In [None]:
# 1.bis Define transforms
def _normalize_and_tensorize():
    """Build a fresh pipeline: per-channel normalization, then tensor conversion."""
    steps = [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        A.ToTensorV2(),
    ]
    return A.Compose(steps)


# Training and validation currently share the same steps (no augmentation);
# each gets its own Compose instance.
train_transform = _normalize_and_tensorize()
val_transform = _normalize_and_tensorize()



In [None]:
# 2. Split into train/val (80% / 20% of the sample indices)
n_samples = len(full_dataset)
train_size = int(0.8 * n_samples)
val_size = n_samples - train_size

# A dedicated, explicitly seeded generator keeps the split identical between runs.
split_generator = torch.Generator().manual_seed(26)
train_indices, val_indices = random_split(
    range(n_samples), [train_size, val_size], generator=split_generator
)

# 2.bis Create train and val datasets
# Two dataset objects over the same files, differing only in their transform,
# each restricted to its own share of the indices.
train_ds = Subset(
    CTScanDataset(image_dir=train_path, mask_csv=labels_path, transform=train_transform),
    train_indices.indices,
)
val_ds = Subset(
    CTScanDataset(image_dir=train_path, mask_csv=labels_path, transform=val_transform),
    val_indices.indices,
)

# Only the training loader shuffles; validation keeps a fixed order.
loader_options = dict(batch_size=8, num_workers=4)
train_loader = DataLoader(train_ds, shuffle=True, **loader_options)
val_loader = DataLoader(val_ds, shuffle=False, **loader_options)





In [None]:

# 3. Segformer model
# -----------------------------------------------------------------------------
segformer_options = dict(
    encoder_name="timm-efficientnet-b7",
    encoder_weights="imagenet",
    in_channels=3,
    classes=55,       # 55 possible classes
    activation=None,  # no final activation: the model outputs logits
)
model = smp.Segformer(**segformer_options).to(DEVICE)

# 4. Define loss and optimizer
# -----------------------------------------------------------------------------
# from_logits=True matches the activation=None setting above.
loss_fn = smp.losses.DiceLoss(from_logits=True, mode='multiclass')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 5. Training
# -----------------------------------------------------------------------------
num_epochs = 60
best_val_loss = float('inf')  # any finite validation loss improves on this



config.json:   0%|          | 0.00/94.0 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/267M [00:00<?, ?B/s]

In [None]:


# File the best checkpoint (lowest validation loss so far) is written to.
# With PATH ending in "/", this is the same path the later
# `model.load_state_dict(torch.load(...))` cell reads from.
best_ckpt_path = os.path.join(PATH, "best_first800_efficientnet-b7_0105.pth")

for epoch in range(1, num_epochs + 1):
    # --- Training ---
    model.train()
    train_loss = 0.0
    for imgs, masks in tqdm(train_loader):
        imgs = imgs.to(DEVICE)
        masks = masks.to(DEVICE)

        optimizer.zero_grad()
        preds = model(imgs)                   # (B,55,H,W)
        loss = loss_fn(preds, masks.long())
        loss.backward()
        optimizer.step()

        # Weight each batch loss by its batch size so that dividing by the
        # number of samples below stays correct when the last batch is smaller.
        train_loss += loss.item() * imgs.size(0)
    train_loss /= train_size

    # --- Validation ---
    model.eval()
    val_loss = 0.0
    with torch.no_grad():  # no gradients are needed to evaluate
        for imgs, masks in tqdm(val_loader):
            imgs = imgs.to(DEVICE)
            masks = masks.to(DEVICE)
            preds = model(imgs)
            loss = loss_fn(preds, masks.long())
            val_loss += loss.item() * imgs.size(0)
    val_loss /= val_size

    print(f"Epoch {epoch}/{num_epochs} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f}")

    # Save the best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Actually write the checkpoint: the message below claims it was saved,
        # and the later loading cell depends on this file existing.
        torch.save(model.state_dict(), best_ckpt_path)
        print(f"--> New best model saved (val_loss={best_val_loss:.4f})")

print("Finished training")




100%|██████████| 80/80 [00:44<00:00,  1.80it/s]
100%|██████████| 20/20 [00:03<00:00,  5.15it/s]


Epoch 1/60 - Train Loss: 0.4109 - Val Loss: 0.3901
--> Nouveau meilleur modèle enregistré (val_loss=0.3901)


100%|██████████| 80/80 [00:43<00:00,  1.85it/s]
100%|██████████| 20/20 [00:03<00:00,  5.51it/s]


Epoch 2/60 - Train Loss: 0.3253 - Val Loss: 0.3345
--> Nouveau meilleur modèle enregistré (val_loss=0.3345)


100%|██████████| 80/80 [00:43<00:00,  1.84it/s]
100%|██████████| 20/20 [00:04<00:00,  4.83it/s]


Epoch 3/60 - Train Loss: 0.2773 - Val Loss: 0.3172
--> Nouveau meilleur modèle enregistré (val_loss=0.3172)


100%|██████████| 80/80 [00:43<00:00,  1.85it/s]
100%|██████████| 20/20 [00:03<00:00,  5.45it/s]


Epoch 4/60 - Train Loss: 0.2569 - Val Loss: 0.3010
--> Nouveau meilleur modèle enregistré (val_loss=0.3010)


100%|██████████| 80/80 [00:43<00:00,  1.85it/s]
100%|██████████| 20/20 [00:04<00:00,  4.91it/s]


Epoch 5/60 - Train Loss: 0.2401 - Val Loss: 0.2889
--> Nouveau meilleur modèle enregistré (val_loss=0.2889)


100%|██████████| 80/80 [00:43<00:00,  1.85it/s]
100%|██████████| 20/20 [00:03<00:00,  5.40it/s]


Epoch 6/60 - Train Loss: 0.2259 - Val Loss: 0.2815
--> Nouveau meilleur modèle enregistré (val_loss=0.2815)


100%|██████████| 80/80 [00:43<00:00,  1.84it/s]
100%|██████████| 20/20 [00:04<00:00,  4.95it/s]


Epoch 7/60 - Train Loss: 0.2128 - Val Loss: 0.2879


100%|██████████| 80/80 [00:43<00:00,  1.85it/s]
100%|██████████| 20/20 [00:03<00:00,  5.48it/s]


Epoch 8/60 - Train Loss: 0.2028 - Val Loss: 0.2818


100%|██████████| 80/80 [00:43<00:00,  1.85it/s]
100%|██████████| 20/20 [00:03<00:00,  5.30it/s]


Epoch 9/60 - Train Loss: 0.1971 - Val Loss: 0.2775
--> Nouveau meilleur modèle enregistré (val_loss=0.2775)


100%|██████████| 80/80 [00:43<00:00,  1.84it/s]
100%|██████████| 20/20 [00:03<00:00,  5.41it/s]


Epoch 10/60 - Train Loss: 0.1887 - Val Loss: 0.2793


 79%|███████▉  | 63/80 [00:34<00:09,  1.81it/s]


KeyboardInterrupt: 

In [None]:
#Save best_simple_segformer_0105.pth
#torch.save(model.state_dict(), "/content/drive/MyDrive/FewCTSeg/data/"+ "best_simple_segformer_0105.pth")

In [None]:
# Write the model's current weights (state dict only) to a separate "final" file.
checkpoint_dir = "/content/drive/MyDrive/FewCTSeg/data/"
final_checkpoint_name = "best_first800_efficientnet-b7_final_0105.pth"
torch.save(model.state_dict(), checkpoint_dir + final_checkpoint_name)

In [None]:
# Restore weights from the "best" checkpoint file, mapped onto the active device.
# weights_only=True restricts unpickling to tensors and plain containers.
best_checkpoint_file = "/content/drive/MyDrive/FewCTSeg/data/" + "best_first800_efficientnet-b7_0105.pth"
model.load_state_dict(
    torch.load(best_checkpoint_file, map_location=DEVICE, weights_only=True)
)

<All keys matched successfully>

In [None]:
# 6. Inference on test set
# -----------------------------------------------------------------------------
PATH = "/content/drive/MyDrive/FewCTSeg/data/"

# Normalization followed by tensor conversion, with the same constants as the
# train/val transforms.
test_transform = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    A.ToTensorV2(),
])

test_images_dir = os.path.join(PATH, "test-images")
test_ds = CTTestDataset(transform=test_transform, image_dir=test_images_dir)
# shuffle=False keeps the loader's iteration order fixed.
test_loader = DataLoader(test_ds, batch_size=8, num_workers=4, shuffle=False)

# Accumulators for the prediction loop (one entry per test image)
filenames, all_preds = [], []

model.eval()


In [None]:
# Read the training label table (first column as index, first row as header)
# and transpose it, so the CSV's row labels become `labels_train.columns`.
labels_train = pd.read_csv(labels_path, header=0, index_col=0).transpose()

In [None]:
# Start from empty accumulators inside this cell: the lists were originally
# created in an earlier cell, so re-running only this cell would append a second
# copy of every prediction and break the column assignment below.
all_preds = []
filenames = []

with torch.no_grad():  # inference only, no gradients needed
    for imgs, names in tqdm(test_loader):
        imgs = imgs.to(DEVICE)
        logits = model(imgs)
        # Pick the highest-scoring class per pixel
        preds = torch.argmax(logits, dim=1).cpu().numpy()  # (B,H,W)
        for p, n in zip(preds, names):
            all_preds.append(p.flatten())  # one flat vector per image
            filenames.append(n)

# One row per test image, columns named after labels_train's columns; then
# transposed so that each test image becomes a column named by its file name.
df = pd.DataFrame(np.stack(all_preds, axis=0), columns=labels_train.columns).T
df.columns = filenames




100%|██████████| 63/63 [00:11<00:00,  5.66it/s]


In [None]:
# Save CSV: write the prediction table, index included, into the data directory
csv_name = "best_first800_efficientnet-b7_0105.csv"
output_csv = os.path.join(PATH, csv_name)
df.to_csv(path_or_buf=output_csv, index=True)
print(f"Test predictions saved to {output_csv}")

Test predictions saved to /content/drive/MyDrive/FewCTSeg/data/best_first800_efficientnet-b7_dataaug_0105.csv


In [None]:
#pred_and_save(test_loader, model,  labels_path,output_filename)