In [1]:
import os
import time
import random
import collections
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
import torchvision
from torchvision.transforms import ToPILImage
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

In [2]:
def fix_all_seeds(seed):
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
fix_all_seeds(2025)

In [3]:
TRAIN_CSV = "/kaggle/input/kaust-vs-kku-tournament-round-3/cells_segmentation/train.csv"
TRAIN_PATH = "/kaggle/input/kaust-vs-kku-tournament-round-3/cells_segmentation/train"
TEST_PATH = "/kaggle/input/kaust-vs-kku-tournament-round-3/cells_segmentation/test"
UNLABELED_PATH = "/kaggle/input/kaust-vs-kku-tournament-round-3/cells_segmentation/unlabeled_additional_data"

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

WIDTH = 704
HEIGHT = 520
PCT_IMAGES_VALIDATION = 0.1

BATCH_SIZE = 2  # Reduced from 4
GRADIENT_ACCUMULATION_STEPS = 2  # Effective batch size = 2 * 2 = 4
NUM_EPOCHS = 30
LEARNING_RATE = 5e-4
WEIGHT_DECAY = 1e-4
BOX_DETECTIONS_PER_IMG = 100  # Reduced from 150
WARMUP_EPOCHS = 3

# Enable mixed precision training
USE_AMP = True

BOXES_CONF = 0.1
MASK_THRESHOLD = 0.5

In [4]:
def get_transform(train=True, height=HEIGHT, width=WIDTH):
    if train:
        return A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.Transpose(p=0.5),
            
            A.OneOf([
                A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1.0),
                A.RandomGamma(gamma_limit=(80, 120), p=1.0),
                A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=1.0),
            ], p=0.7),
            
            A.OneOf([
                A.GaussNoise(var_limit=(10.0, 50.0), p=1.0),
                A.GaussianBlur(blur_limit=(3, 7), p=1.0),
                A.MedianBlur(blur_limit=5, p=1.0),
            ], p=0.3),
            
            A.OneOf([
                A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0),
                A.GridDistortion(num_steps=5, distort_limit=0.3, p=1.0),
                A.OpticalDistortion(distort_limit=0.5, shift_limit=0.5, p=1.0),
            ], p=0.3),
            
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, 
                             border_mode=cv2.BORDER_CONSTANT, value=0, p=0.5),
            
            A.Resize(height, width, always_apply=True),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], always_apply=True),
            ToTensorV2(),
        ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']))
    else:
        return A.Compose([
            A.Resize(height, width, always_apply=True),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], always_apply=True),
            ToTensorV2(),
        ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']))

In [5]:
def rle_decode(mask_rle, shape, color=1):
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.float32)
    for lo, hi in zip(starts, ends):
        img[lo : hi] = color
    return img.reshape(shape)

def rle_encoding(x):
    dots = np.where(x.flatten() == 1)[0]
    run_lengths = []
    prev = -2
    for b in dots:
        if (b > prev + 1): 
            run_lengths.extend((b + 1, 0))
        run_lengths[-1] += 1
        prev = b
    return ' '.join(map(str, run_lengths))

def remove_overlapping_pixels(mask, other_masks):
    for other_mask in other_masks:
        if np.sum(np.logical_and(mask, other_mask)) > 0:
            mask[np.logical_and(mask, other_mask)] = 0
    return mask

In [6]:
class CellDataset(Dataset):
    def __init__(self, image_dir, df, transforms=None):
        self.transforms = transforms
        self.image_dir = image_dir
        self.df = df
        self.height = HEIGHT
        self.width = WIDTH
        
        self.image_info = collections.defaultdict(dict)
        temp_df = self.df.groupby('id')['annotation'].agg(lambda x: list(x)).reset_index()
        for index, row in temp_df.iterrows():
            self.image_info[index] = {
                'image_id': row['id'],
                'image_path': os.path.join(self.image_dir, row['id'] + '.png'),
                'annotations': row["annotation"]
            }
    
    def get_box(self, a_mask):
        pos = np.where(a_mask)
        if len(pos[0]) == 0:
            return [0, 0, 1, 1]
        xmin = np.min(pos[1])
        xmax = np.max(pos[1])
        ymin = np.min(pos[0])
        ymax = np.max(pos[0])
        return [xmin, ymin, xmax, ymax]

    def __getitem__(self, idx):
        info = self.image_info[idx]
        img = Image.open(info['image_path']).convert("RGB")
        img_np = np.array(img)

        masks = []
        boxes = []
        for rle in info['annotations']:
            mask = rle_decode(rle, (self.height, self.width)).astype('uint8')
            if mask.sum() == 0:
                continue
            masks.append(mask)
            boxes.append(self.get_box(mask))

        if len(masks) == 0:
            masks = [np.zeros((self.height, self.width), dtype=np.uint8)]
            boxes = [[0, 0, 1, 1]]

        labels = [1] * len(masks)

        if self.transforms:
            try:
                augmented = self.transforms(
                    image=img_np,
                    masks=masks,
                    bboxes=boxes,
                    labels=labels
                )
                img = augmented['image']
                masks = [torch.as_tensor(m, dtype=torch.uint8) for m in augmented['masks']]
                boxes = torch.as_tensor(augmented['bboxes'], dtype=torch.float32)
                labels = torch.as_tensor(augmented['labels'], dtype=torch.int64)
            except:
                img = torchvision.transforms.ToTensor()(img_np)
                masks = [torch.as_tensor(m, dtype=torch.uint8) for m in masks]
                boxes = torch.as_tensor(boxes, dtype=torch.float32)
                labels = torch.as_tensor(labels, dtype=torch.int64)
        else:
            img = torchvision.transforms.ToTensor()(img_np)
            masks = [torch.as_tensor(m, dtype=torch.uint8) for m in masks]
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)

        if len(masks) == 0:
            masks = [torch.zeros((self.height, self.width), dtype=torch.uint8)]
            boxes = torch.tensor([[0, 0, 1, 1]], dtype=torch.float32)
            labels = torch.tensor([1], dtype=torch.int64)

        masks = torch.stack(masks)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((labels.shape[0],), dtype=torch.int64)

        target = {
            'boxes': boxes,
            'labels': labels,
            'masks': masks,
            'image_id': image_id,
            'area': area,
            'iscrowd': iscrowd
        }

        return img, target

    def __len__(self):
        return len(self.image_info)

In [7]:
def collate_fn(batch):
    return tuple(zip(*batch))

df_base = pd.read_csv(TRAIN_CSV)
unique_ids = df_base['id'].unique()
train_ids, val_ids = train_test_split(
    unique_ids,
    test_size=PCT_IMAGES_VALIDATION,
    random_state=42,
    shuffle=True
)

df_train = df_base[df_base['id'].isin(train_ids)]
df_val = df_base[df_base['id'].isin(val_ids)]

print(f"Train images: {len(train_ids)}, Val images: {len(val_ids)}")
print(f"Train annotations: {len(df_train)}, Val annotations: {len(df_val)}")

ds_train = CellDataset(TRAIN_PATH, df_train, transforms=get_transform(train=True))
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, 
                      num_workers=2, collate_fn=collate_fn, pin_memory=True)

ds_val = CellDataset(TRAIN_PATH, df_val, transforms=get_transform(train=False))
dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False, 
                    num_workers=2, collate_fn=collate_fn, pin_memory=True)

Train images: 381, Val images: 43
Train annotations: 48918, Val annotations: 4147


In [8]:
def get_model():
    NUM_CLASSES = 2
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(
        pretrained=True,
        box_detections_per_img=BOX_DETECTIONS_PER_IMG,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,  # Increased for better filtering
        rpn_score_thresh=0.05,
        rpn_nms_thresh=0.7,
        rpn_pre_nms_top_n_train=1000,  # Reduced from 2000
        rpn_pre_nms_top_n_test=500,   # Reduced from 1000
        rpn_post_nms_top_n_train=1000, # Reduced from 2000
        rpn_post_nms_top_n_test=500    # Reduced from 1000
    )
    
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, NUM_CLASSES)

    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, NUM_CLASSES)
    
    # Enable gradient checkpointing to save memory
    model.backbone.body.requires_grad_(True)
    
    return model

# Single GPU setup (no DataParallel)
model = get_model()
model = model.to(DEVICE)
print(f"→ Using single GPU with batch size {BATCH_SIZE} and gradient accumulation {GRADIENT_ACCUMULATION_STEPS}")

Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth
100%|██████████| 170M/170M [00:00<00:00, 196MB/s]


→ Using single GPU with batch size 2 and gradient accumulation 2


In [9]:
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
    def f(x):
        if x >= warmup_iters:
            return 1
        alpha = float(x) / warmup_iters
        return warmup_factor * (1 - alpha) + alpha
    return torch.optim.lr_scheduler.LambdaLR(optimizer, f)

def cosine_lr_scheduler(optimizer, total_epochs, warmup_epochs):
    warmup_iters = warmup_epochs * len(dl_train)
    warmup_scheduler = warmup_lr_scheduler(optimizer, warmup_iters, 0.1)
    main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=total_epochs - warmup_epochs, eta_min=LEARNING_RATE * 0.01
    )
    return warmup_scheduler, main_scheduler

warmup_scheduler, main_scheduler = cosine_lr_scheduler(optimizer, NUM_EPOCHS, WARMUP_EPOCHS)

In [10]:
class AverageMeter:
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def train_one_epoch(model, optimizer, data_loader, device, epoch, warmup_scheduler=None):
    model.train()
    loss_meter = AverageMeter()
    mask_loss_meter = AverageMeter()
    
    scaler = torch.cuda.amp.GradScaler() if USE_AMP else None
    optimizer.zero_grad()
    
    for batch_idx, (images, targets) in enumerate(tqdm(data_loader, desc=f"Epoch {epoch}")):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        if USE_AMP:
            with torch.cuda.amp.autocast():
                loss_dict = model(images, targets)
                loss = sum(loss for loss in loss_dict.values())
                loss = loss / GRADIENT_ACCUMULATION_STEPS
        else:
            loss_dict = model(images, targets)
            loss = sum(loss for loss in loss_dict.values())
            loss = loss / GRADIENT_ACCUMULATION_STEPS
        
        if USE_AMP:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        
        if (batch_idx + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
            if USE_AMP:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
                optimizer.step()
            
            optimizer.zero_grad()
            
            if warmup_scheduler is not None and epoch <= WARMUP_EPOCHS:
                warmup_scheduler.step()
        
        loss_meter.update(loss.item() * GRADIENT_ACCUMULATION_STEPS, len(images))
        if 'loss_mask' in loss_dict:
            mask_loss_meter.update(loss_dict['loss_mask'].item(), len(images))
        
        # Clear cache every 10 batches
        if batch_idx % 10 == 0:
            torch.cuda.empty_cache()
    
    return loss_meter.avg, mask_loss_meter.avg
            

def validate_one_epoch(model, data_loader, device):
    model.train()  # Keep in train mode to get losses
    loss_meter = AverageMeter()
    mask_loss_meter = AverageMeter()
    
    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(tqdm(data_loader, desc="Validation")):
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            
            if USE_AMP:
                with torch.cuda.amp.autocast():
                    loss_dict = model(images, targets)
                    loss = sum(loss for loss in loss_dict.values())
            else:
                loss_dict = model(images, targets)
                loss = sum(loss for loss in loss_dict.values())
            
            loss_meter.update(loss.item(), len(images))
            if 'loss_mask' in loss_dict:
                mask_loss_meter.update(loss_dict['loss_mask'].item(), len(images))
            
            # Clear cache every 5 batches during validation
            if batch_idx % 5 == 0:
                torch.cuda.empty_cache()
    
    return loss_meter.avg, mask_loss_meter.avg

In [11]:
# Clear GPU cache before training
torch.cuda.empty_cache()
import gc
gc.collect()

print(f"GPU Memory before training: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

GPU Memory before training: 0.17 GB


In [12]:
best_val_loss = float('inf')
patience = 8
patience_counter = 0
train_losses = []
val_losses = []

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_mask_loss = train_one_epoch(
        model, optimizer, dl_train, DEVICE, epoch, 
        warmup_scheduler if epoch <= WARMUP_EPOCHS else None
    )
    
    val_loss, val_mask_loss = validate_one_epoch(model, dl_val, DEVICE)
    
    if epoch > WARMUP_EPOCHS:
        main_scheduler.step()
    
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
    print(f"Epoch {epoch:2d}/{NUM_EPOCHS} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f}")
    print(f"                 Train Mask: {train_mask_loss:.4f} - Val Mask: {val_mask_loss:.4f}")
    print(f"                 LR: {optimizer.param_groups[0]['lr']:.6f}")
    
    # In training loop, replace the save lines with:
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        model_to_save = model.module if hasattr(model, 'module') else model
        torch.save({
            'epoch': epoch,
            'model_state_dict': model_to_save.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        }, 'best_model.pth')
        print(f"                 New best model saved!")
    
    # And for regular checkpoints:
    model_to_save = model.module if hasattr(model, 'module') else model
    torch.save({
        'epoch': epoch,
        'model_state_dict': model_to_save.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
    }, f'model_epoch_{epoch}.pth')
    
    print("-" * 50)

Epoch 1: 100%|██████████| 191/191 [05:22<00:00,  1.69s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.89it/s]


Epoch  1/30 - Train Loss: 2.3805 - Val Loss: 1.6677
                 Train Mask: 0.7614 - Val Mask: 0.4154
                 LR: 0.000125
                 New best model saved!
--------------------------------------------------


Epoch 2: 100%|██████████| 191/191 [05:25<00:00,  1.70s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.95it/s]


Epoch  2/30 - Train Loss: 1.9291 - Val Loss: 1.6101
                 Train Mask: 0.5365 - Val Mask: 0.4116
                 LR: 0.000199
                 New best model saved!
--------------------------------------------------


Epoch 3: 100%|██████████| 191/191 [05:04<00:00,  1.60s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.92it/s]


Epoch  3/30 - Train Loss: 1.8890 - Val Loss: 1.5868
                 Train Mask: 0.5177 - Val Mask: 0.4117
                 LR: 0.000274
                 New best model saved!
--------------------------------------------------


Epoch 4: 100%|██████████| 191/191 [04:41<00:00,  1.47s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.88it/s]


Epoch  4/30 - Train Loss: 1.8956 - Val Loss: 1.5829
                 Train Mask: 0.5242 - Val Mask: 0.3924
                 LR: 0.000273
                 New best model saved!
--------------------------------------------------


Epoch 5: 100%|██████████| 191/191 [04:59<00:00,  1.57s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.94it/s]


Epoch  5/30 - Train Loss: 1.8505 - Val Loss: 1.5296
                 Train Mask: 0.5106 - Val Mask: 0.3898
                 LR: 0.000270
                 New best model saved!
--------------------------------------------------


Epoch 6: 100%|██████████| 191/191 [05:28<00:00,  1.72s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.91it/s]


Epoch  6/30 - Train Loss: 1.8140 - Val Loss: 1.5658
                 Train Mask: 0.5120 - Val Mask: 0.3959
                 LR: 0.000266
--------------------------------------------------


Epoch 7: 100%|██████████| 191/191 [05:26<00:00,  1.71s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.94it/s]


Epoch  7/30 - Train Loss: 1.8013 - Val Loss: 1.5328
                 Train Mask: 0.5133 - Val Mask: 0.3912
                 LR: 0.000260
--------------------------------------------------


Epoch 8: 100%|██████████| 191/191 [05:05<00:00,  1.60s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.90it/s]


Epoch  8/30 - Train Loss: 1.7749 - Val Loss: 1.5627
                 Train Mask: 0.5163 - Val Mask: 0.4348
                 LR: 0.000252
--------------------------------------------------


Epoch 9: 100%|██████████| 191/191 [05:11<00:00,  1.63s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.93it/s]


Epoch  9/30 - Train Loss: 1.7718 - Val Loss: 1.5381
                 Train Mask: 0.5051 - Val Mask: 0.4019
                 LR: 0.000242
--------------------------------------------------


Epoch 10: 100%|██████████| 191/191 [04:40<00:00,  1.47s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.95it/s]


Epoch 10/30 - Train Loss: 1.7521 - Val Loss: 1.4662
                 Train Mask: 0.5011 - Val Mask: 0.3970
                 LR: 0.000232
                 New best model saved!
--------------------------------------------------


Epoch 11: 100%|██████████| 191/191 [05:06<00:00,  1.60s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.94it/s]


Epoch 11/30 - Train Loss: 1.7284 - Val Loss: 1.4840
                 Train Mask: 0.4942 - Val Mask: 0.3997
                 LR: 0.000220
--------------------------------------------------


Epoch 12: 100%|██████████| 191/191 [05:06<00:00,  1.60s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.95it/s]


Epoch 12/30 - Train Loss: 1.7383 - Val Loss: 1.4673
                 Train Mask: 0.4989 - Val Mask: 0.3955
                 LR: 0.000207
--------------------------------------------------


Epoch 13: 100%|██████████| 191/191 [05:36<00:00,  1.76s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.91it/s]


Epoch 13/30 - Train Loss: 1.7224 - Val Loss: 1.4826
                 Train Mask: 0.4939 - Val Mask: 0.3794
                 LR: 0.000193
--------------------------------------------------


Epoch 14: 100%|██████████| 191/191 [05:10<00:00,  1.63s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.92it/s]


Epoch 14/30 - Train Loss: 1.7167 - Val Loss: 1.5578
                 Train Mask: 0.4921 - Val Mask: 0.4165
                 LR: 0.000178
--------------------------------------------------


Epoch 15: 100%|██████████| 191/191 [05:15<00:00,  1.65s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.89it/s]


Epoch 15/30 - Train Loss: 1.6888 - Val Loss: 1.4610
                 Train Mask: 0.4992 - Val Mask: 0.3881
                 LR: 0.000163
                 New best model saved!
--------------------------------------------------


Epoch 16: 100%|██████████| 191/191 [05:30<00:00,  1.73s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.94it/s]


Epoch 16/30 - Train Loss: 1.6766 - Val Loss: 1.4798
                 Train Mask: 0.4838 - Val Mask: 0.3828
                 LR: 0.000147
--------------------------------------------------


Epoch 17: 100%|██████████| 191/191 [05:14<00:00,  1.65s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.92it/s]


Epoch 17/30 - Train Loss: 1.6784 - Val Loss: 1.3993
                 Train Mask: 0.4795 - Val Mask: 0.3521
                 LR: 0.000132
                 New best model saved!
--------------------------------------------------


Epoch 18: 100%|██████████| 191/191 [04:54<00:00,  1.54s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.88it/s]


Epoch 18/30 - Train Loss: 1.6566 - Val Loss: 1.4244
                 Train Mask: 0.4715 - Val Mask: 0.3481
                 LR: 0.000116
--------------------------------------------------


Epoch 19: 100%|██████████| 191/191 [05:01<00:00,  1.58s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.92it/s]


Epoch 19/30 - Train Loss: 1.6618 - Val Loss: 1.4524
                 Train Mask: 0.4843 - Val Mask: 0.3808
                 LR: 0.000101
--------------------------------------------------


Epoch 20: 100%|██████████| 191/191 [04:57<00:00,  1.56s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.84it/s]


Epoch 20/30 - Train Loss: 1.6299 - Val Loss: 1.4239
                 Train Mask: 0.4772 - Val Mask: 0.3480
                 LR: 0.000086
--------------------------------------------------


Epoch 21: 100%|██████████| 191/191 [05:42<00:00,  1.79s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.83it/s]


Epoch 21/30 - Train Loss: 1.6253 - Val Loss: 1.4031
                 Train Mask: 0.4690 - Val Mask: 0.3409
                 LR: 0.000072
--------------------------------------------------


Epoch 22: 100%|██████████| 191/191 [05:20<00:00,  1.68s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.87it/s]


Epoch 22/30 - Train Loss: 1.5950 - Val Loss: 1.4365
                 Train Mask: 0.4714 - Val Mask: 0.3364
                 LR: 0.000059
--------------------------------------------------


Epoch 23: 100%|██████████| 191/191 [05:35<00:00,  1.75s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.90it/s]


Epoch 23/30 - Train Loss: 1.6141 - Val Loss: 1.4235
                 Train Mask: 0.4752 - Val Mask: 0.3576
                 LR: 0.000047
--------------------------------------------------


Epoch 24: 100%|██████████| 191/191 [05:10<00:00,  1.63s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.89it/s]


Epoch 24/30 - Train Loss: 1.6000 - Val Loss: 1.3906
                 Train Mask: 0.4676 - Val Mask: 0.3507
                 LR: 0.000036
                 New best model saved!
--------------------------------------------------


Epoch 25: 100%|██████████| 191/191 [05:24<00:00,  1.70s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.90it/s]


Epoch 25/30 - Train Loss: 1.5828 - Val Loss: 1.4358
                 Train Mask: 0.4719 - Val Mask: 0.3494
                 LR: 0.000027
--------------------------------------------------


Epoch 26: 100%|██████████| 191/191 [04:56<00:00,  1.55s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.90it/s]


Epoch 26/30 - Train Loss: 1.5707 - Val Loss: 1.4109
                 Train Mask: 0.4557 - Val Mask: 0.3519
                 LR: 0.000019
--------------------------------------------------


Epoch 27: 100%|██████████| 191/191 [05:25<00:00,  1.70s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.92it/s]


Epoch 27/30 - Train Loss: 1.5691 - Val Loss: 1.4077
                 Train Mask: 0.4542 - Val Mask: 0.3537
                 LR: 0.000013
--------------------------------------------------


Epoch 28: 100%|██████████| 191/191 [04:57<00:00,  1.56s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.92it/s]


Epoch 28/30 - Train Loss: 1.5722 - Val Loss: 1.4147
                 Train Mask: 0.4601 - Val Mask: 0.3629
                 LR: 0.000009
--------------------------------------------------


Epoch 29: 100%|██████████| 191/191 [05:13<00:00,  1.64s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.93it/s]


Epoch 29/30 - Train Loss: 1.5777 - Val Loss: 1.4294
                 Train Mask: 0.4634 - Val Mask: 0.3752
                 LR: 0.000006
--------------------------------------------------


Epoch 30: 100%|██████████| 191/191 [05:15<00:00,  1.65s/it]
Validation: 100%|██████████| 22/22 [00:11<00:00,  1.95it/s]


Epoch 30/30 - Train Loss: 1.5727 - Val Loss: 1.4342
                 Train Mask: 0.4619 - Val Mask: 0.3769
                 LR: 0.000005
--------------------------------------------------


In [13]:
class CellTestDataset(Dataset):
    def __init__(self, image_dir, transforms=None):
        self.transforms = transforms
        self.image_dir = image_dir
        self.image_ids = [fname[:-4] for fname in os.listdir(self.image_dir) 
                          if fname.endswith('.png')]
    
    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_path = os.path.join(self.image_dir, image_id + '.png')
        img = Image.open(image_path).convert("RGB")
        img_np = np.array(img)

        if self.transforms:
            augmented = self.transforms(image=img_np, bboxes=[], labels=[])
            img_tensor = augmented['image']
        else:
            img_tensor = torchvision.transforms.ToTensor()(img)

        return {'image': img_tensor, 'image_id': image_id}

    def __len__(self):
        return len(self.image_ids)

ds_test = CellTestDataset(TEST_PATH, transforms=get_transform(train=False))
test_loader = DataLoader(ds_test, batch_size=1, shuffle=False, num_workers=2, collate_fn=lambda x: x)

In [14]:
checkpoint = torch.load('best_model.pth')
model = get_model()
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(DEVICE)

# No DataParallel for inference (single batch)
model.eval()

print(f"Loaded best model from epoch {checkpoint['epoch']} with val_loss: {checkpoint['val_loss']:.4f}")

Loaded best model from epoch 24 with val_loss: 1.3906


In [15]:
def apply_nms_to_masks(boxes, scores, masks, iou_threshold=0.3):
    keep = torchvision.ops.nms(boxes, scores, iou_threshold)
    return boxes[keep], scores[keep], masks[keep]

def post_process_masks(masks, min_area=50):
    processed_masks = []
    for mask in masks:
        mask_np = mask.cpu().numpy()
        if mask_np.sum() < min_area:
            continue
        processed_masks.append(mask_np)
    return processed_masks

submission = []

model.eval()
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Inference"):
        sample = batch[0]
        img = sample['image'].to(DEVICE)
        image_id = sample['image_id']
        
        outputs = model([img])[0]
        
        if len(outputs['masks']) == 0:
            submission.append((image_id, "-1"))
            continue
        
        boxes = outputs['boxes']
        scores = outputs['scores']
        masks = outputs['masks']
        
        mask_above_threshold = scores > BOXES_CONF
        if not mask_above_threshold.any():
            submission.append((image_id, "-1"))
            continue
            
        boxes = boxes[mask_above_threshold]
        scores = scores[mask_above_threshold]
        masks = masks[mask_above_threshold]
        
        boxes, scores, masks = apply_nms_to_masks(boxes, scores, masks)
        
        any_mask = False
        prev_masks = []
        
        for mask, score in zip(masks, scores):
            mask_np = mask.cpu().numpy()
            bin_mask = mask_np[0] > MASK_THRESHOLD
            
            if bin_mask.sum() < 20:
                continue
                
            bin_mask = remove_overlapping_pixels(bin_mask, prev_masks)
            
            if bin_mask.sum() < 10:
                continue
                
            prev_masks.append(bin_mask)
            rle = rle_encoding(bin_mask.astype(np.uint8))
            
            if rle:
                submission.append((image_id, rle))
                any_mask = True
        
        if not any_mask:
            submission.append((image_id, "-1"))

df_sub = pd.DataFrame(submission, columns=['id', 'annotation'])
df_sub["idx"] = range(len(df_sub))
df_sub = df_sub[["idx", "id", "annotation"]].replace({"": "-1"})
df_sub.to_csv("submission.csv", index=False)

print(f"Submission created with {len(df_sub)} entries")
print(f"Images with predictions: {len(df_sub[df_sub['annotation'] != '-1'])}")
print(f"Images without predictions: {len(df_sub[df_sub['annotation'] == '-1'])}")

Inference: 100%|██████████| 182/182 [03:13<00:00,  1.07s/it]

Submission created with 10391 entries
Images with predictions: 10391
Images without predictions: 0



