In [None]:
!pip install -q timm torchmetrics albumentations

In [None]:
import os, glob, random, cv2
import multiprocessing
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from copy import deepcopy
from sklearn.metrics import precision_recall_curve, auc
from sklearn.model_selection import train_test_split

# PyTorch imports
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, RAdam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision.transforms as transforms
import timm
from torchmetrics.classification import MultilabelAUROC

# Other utilities
from albumentations import Compose, VerticalFlip, HorizontalFlip, Rotate, GridDistortion
from albumentations.pytorch import ToTensorV2 # Important for PyTorch
from tqdm.notebook import tqdm

# Set seeds for reproducibility
def set_seed(seed=10):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker processes).
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may pick different kernels from run to run, which defeats
    # the deterministic flag above.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once, before the data split and model creation further down.
set_seed(10)

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Directories containing the test and training images.
test_imgs_dir = '../input/understanding_cloud_organization/test_images/'
train_imgs_dir = '../input/understanding_cloud_organization/train_images/'
# CPU core count; used below as the DataLoader worker count.
num_cores = multiprocessing.cpu_count()

In [None]:
train_df = pd.read_csv('../input/understanding_cloud_organization/train.csv')
# Keep only the (image, class) rows that actually have a mask. The explicit
# .copy() makes the column assignments below act on an independent frame
# instead of a filtered view of the raw table.
train_df = train_df[~train_df['EncodedPixels'].isnull()].copy()
# 'Image_Label' is '<image>_<class>': the part before the first underscore is
# the image name, the part after it is the class.
train_df['Image'] = train_df['Image_Label'].map(lambda x: x.split('_')[0])
train_df['Class'] = train_df['Image_Label'].map(lambda x: x.split('_')[1])
# Order of first appearance; this fixes the column order of the label vectors.
classes = train_df['Class'].unique()
# One row per image, holding the set of classes present in it.
train_df = train_df.groupby('Image')['Class'].agg(set).reset_index()
for class_name in classes:
    train_df[class_name] = train_df['Class'].map(lambda x: 1 if class_name in x else 0)

# Dictionary for fast access: image name -> one-hot vector ordered like `classes`.
# Columns are selected by name so the lookup does not depend on column positions.
img_2_ohe_vector = dict(zip(train_df['Image'], train_df[list(classes)].values))

train_df.head()

In [None]:
# Stratify on each image's exact combination of classes so the 80/20 split
# keeps the same mix of label sets on both sides.
label_combination = train_df['Class'].map(lambda class_set: str(sorted(class_set)))
train_imgs, val_imgs = train_test_split(
    train_df['Image'].values,
    test_size=0.2,
    stratify=label_combination,
    random_state=2019,
)

In [None]:
class CloudDataset(Dataset):
    """Dataset yielding resized RGB image tensors and, outside test mode, label vectors.

    Args:
        images_list: Image file names, resolved relative to ``folder_imgs``.
        folder_imgs: Directory that contains the image files.
        transform: Optional callable invoked as ``transform(image=img)``; the
            transformed array is read from the ``'image'`` key of its result.
        resized_height: Height passed to ``cv2.resize``.
        resized_width: Width passed to ``cv2.resize``.
        is_test: When True, no labels are looked up and ``__getitem__``
            returns only the image tensor.
    """

    def __init__(self, images_list, folder_imgs=train_imgs_dir,
                 transform=None, resized_height=260, resized_width=260, is_test=False):
        # Copy so later changes to the caller's list do not affect the dataset.
        self.images_list = deepcopy(images_list)
        self.folder_imgs = folder_imgs
        self.transform = transform
        self.resized_height = resized_height
        self.resized_width = resized_width
        self.is_test = is_test
        # Labels come from the module-level ``img_2_ohe_vector`` lookup.
        self.labels = None if is_test else [img_2_ohe_vector[img] for img in self.images_list]

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.images_list)

    def __getitem__(self, idx):
        """Load, resize and scale image ``idx``.

        Returns:
            The float32 image tensor in (C, H, W) layout scaled to [0, 1];
            outside test mode, a ``(image_tensor, label_tensor)`` pair.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        image_name = self.images_list[idx]
        path = os.path.join(self.folder_imgs, image_name)
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if img is None:
            raise FileNotFoundError(f"Could not read image (missing or unreadable file): {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # cv2.resize takes the target size as (width, height).
        img = cv2.resize(img, (self.resized_width, self.resized_height))

        if self.transform:
            augmented = self.transform(image=img)
            img = augmented['image']

        # PyTorch expects C, H, W format
        img = img.astype(np.float32) / 255.0
        img = img.transpose(2, 0, 1)

        img_tensor = torch.tensor(img, dtype=torch.float32)

        if self.is_test:
            return img_tensor

        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return img_tensor, label

In [None]:
# Training-only augmentations; each one is applied with probability 0.5.
train_augmentations = [
    VerticalFlip(p=0.5),
    HorizontalFlip(p=0.5),
    Rotate(limit=20, p=0.5),
    GridDistortion(p=0.5),
]
albumentations_train = Compose(train_augmentations, p=1)

# Datasets: augmentation is applied to the training split only.
train_dataset = CloudDataset(train_imgs, transform=albumentations_train)
val_dataset = CloudDataset(val_imgs)

# DataLoaders: both share batch size and worker count; only training shuffles.
BATCH_SIZE = 32
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=num_cores)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
def dice_loss(y_pred, y_true, smooth=1.):
    """Soft Dice loss computed from raw logits.

    Args:
        y_pred: Logits; a sigmoid is applied inside this function.
        y_true: Targets with the same shape as ``y_pred``.
        smooth: Constant added to numerator and denominator.

    Returns:
        Scalar tensor: one minus the Dice score averaged over the first dimension.
    """
    # Everything after the first dimension is collapsed so the score is per sample.
    probs = y_pred.sigmoid().flatten(1)
    targets = y_true.flatten(1)
    overlap = (targets * probs).sum(1)
    denominator = targets.sum(1) + probs.sum(1) + smooth
    per_sample_dice = (2. * overlap + smooth) / denominator
    return 1. - per_sample_dice.mean()

def bce_dice_loss(y_pred, y_true):
    """Sum of binary cross-entropy (on logits) and soft Dice loss.

    Args:
        y_pred: Raw logits.
        y_true: Targets with the same shape as ``y_pred``.

    Returns:
        Scalar tensor with the combined loss.
    """
    # The logits variant of BCE is used for numerical stability; dice_loss
    # applies its own sigmoid, so both terms receive the raw logits.
    bce_criterion = nn.BCEWithLogitsLoss()
    bce_term = bce_criterion(y_pred, y_true)
    dice_term = dice_loss(y_pred, y_true)
    return bce_term + dice_term

In [None]:
def get_model(model_name='efficientnet_b0', num_classes=4, pretrained=True):
    """Create a timm model with the requested number of output classes.

    Args:
        model_name: Architecture name passed to ``timm.create_model``.
        num_classes: Value forwarded as ``num_classes``.
        pretrained: Value forwarded as ``pretrained``.

    Returns:
        The model object returned by ``timm.create_model``.
    """
    return timm.create_model(
        model_name,
        pretrained=pretrained,
        num_classes=num_classes,
    )

# Pick the GPU when one is available, otherwise fall back to the CPU,
# then build the model and move it to that device.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = get_model()
model.to(device)
print(f"Model loaded on {device}")

In [None]:
# Training configuration
EPOCHS = 20
LEARNING_RATE = 1e-3
CHECKPOINTS_PATH = 'checkpoints/'
BEST_MODEL_PATH = os.path.join(CHECKPOINTS_PATH, 'best_model.pth')
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)

# Setup optimizer, loss, and scheduler
optimizer = RAdam(model.parameters(), lr=LEARNING_RATE)
criterion = bce_dice_loss
# mode='max': the monitored value (validation PR-AUC) is expected to increase.
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)


def pr_auc_metric(preds, target):
    """Macro-averaged area under the precision-recall curve.

    The value tracked in this cell is named PR-AUC everywhere (history keys,
    log messages, plot labels), so it is computed from the precision-recall
    curve rather than from the ROC curve.

    Args:
        preds: Tensor of predicted probabilities, indexed as [sample, class].
        target: Tensor of 0/1 labels with the same shape as ``preds``.

    Returns:
        0-dim tensor holding the mean of the per-class PR-AUC values
        (a tensor so callers can keep using ``.item()``).
    """
    scores = preds.detach().cpu().numpy()
    truth = target.detach().cpu().numpy()
    per_class = []
    for class_idx in range(truth.shape[1]):
        precision, recall, _ = precision_recall_curve(truth[:, class_idx], scores[:, class_idx])
        per_class.append(auc(recall, precision))
    return torch.tensor(float(np.mean(per_class)))


# For plotting and tracking
# ('train_pr_auc' is never filled in; the key is kept so the history layout is unchanged.)
history = {
    'train_loss': [], 'val_loss': [],
    'train_pr_auc': [], 'val_pr_auc': []
}
best_pr_auc = -float('inf')
early_stopping_patience = 5
patience_counter = 0

# --- Main Training Loop ---
for epoch in range(EPOCHS):
    print(f"--- Epoch {epoch+1}/{EPOCHS} ---")

    # Training phase
    model.train()
    train_loss = 0.0
    for images, labels in tqdm(train_loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Mean of the per-batch losses.
    avg_train_loss = train_loss / len(train_loader)
    history['train_loss'].append(avg_train_loss)

    # Validation phase
    model.eval()
    val_loss = 0.0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validation"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            # Store probabilities and labels on the CPU for the epoch-level metric
            all_preds.append(outputs.sigmoid().cpu())
            all_labels.append(labels.cpu())

    avg_val_loss = val_loss / len(val_loader)
    history['val_loss'].append(avg_val_loss)

    # Calculate PR-AUC over the whole validation set
    val_preds = torch.cat(all_preds)
    val_labels = torch.cat(all_labels)

    pr_auc_val = pr_auc_metric(val_preds, val_labels).item()
    history['val_pr_auc'].append(pr_auc_val)

    print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val PR-AUC: {pr_auc_val:.4f}")

    # LR Scheduler
    scheduler.step(pr_auc_val)

    # Model Checkpointing: keep only the weights of the best epoch so far
    if pr_auc_val > best_pr_auc:
        print(f"Validation PR-AUC improved from {best_pr_auc:.4f} to {pr_auc_val:.4f}. Saving model...")
        best_pr_auc = pr_auc_val
        patience_counter = 0 # Reset patience
        torch.save(model.state_dict(), BEST_MODEL_PATH)
    else:
        patience_counter += 1

    # Early Stopping
    if patience_counter >= early_stopping_patience:
        print(f"Early stopping triggered after {early_stopping_patience} epochs with no improvement.")
        break

# Load the best model for inference. map_location keeps the load working when
# the checkpoint was written on a different device than the current one.
model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))

In [None]:
fs = 10
fnt1 = 12
fnt2 = 15


def _plot_history(series_keys, y_label, legend_labels, title, out_file):
    """Plot the given ``history`` series on a new square figure, save it and show it."""
    plt.figure(figsize=(fs, fs))
    for key in series_keys:
        plt.plot(history[key])
    plt.xlabel('Epoch', fontsize=fnt1)
    plt.ylabel(y_label, fontsize=fnt1)
    plt.legend(legend_labels)
    plt.title(title, fontsize=fnt2)
    plt.savefig(out_file)
    plt.show()


# Validation metric per epoch.
_plot_history(['val_pr_auc'], 'Mean PR-AUC', ['Validation'],
              'Validation PR-AUC', 'pr_auc_hist.png')

# Training and validation loss per epoch.
_plot_history(['train_loss', 'val_loss'], 'Loss (BCE + Dice)', ['Train', 'Validation'],
              'Training & Validation Loss', 'loss_hist.png')

In [None]:
# Prediction function
def predict(model, loader, device):
    """Run the model over a loader and return sigmoid probabilities.

    Args:
        model: Module producing logits for a batch of images.
        loader: Iterable of batches. A batch is either an image tensor or a
            list/tuple whose first element is the image tensor (for example
            ``(images, labels)``); any further elements are ignored.
        device: Device the image batches are moved to before the forward pass.

    Returns:
        NumPy array with one row of probabilities per input sample, in loader
        order. An empty ``(0, 0)`` float32 array if the loader yields no batches.
    """
    model.eval()
    all_predictions = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Predicting"):
            # Labelled loaders yield (images, labels); only the images are needed.
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            images = images.to(device)
            outputs = model(images)
            all_predictions.append(outputs.sigmoid().cpu().numpy())
    if not all_predictions:
        # np.vstack raises on an empty list; return an empty result instead.
        return np.empty((0, 0), dtype=np.float32)
    return np.vstack(all_predictions)

# Create test dataloader (sorted so the image order is the same on every run)
test_image_list = sorted(os.listdir(test_imgs_dir))
test_dataset = CloudDataset(images_list=test_image_list, folder_imgs=test_imgs_dir, is_test=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_cores)

# Get predictions
y_pred_test = predict(model, test_loader, device)

# Validation predictions used to pick the thresholds. An image-only dataset
# over the same validation images (same order, no augmentation) is used so
# every batch is a plain image tensor, the format `predict` iterates over.
val_pred_dataset = CloudDataset(val_imgs, is_test=True)
val_pred_loader = DataLoader(val_pred_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_cores)
val_pred_for_threshold = predict(model, val_pred_loader, device)
# Ground truth comes straight from the label lookup; there is no need to
# decode every validation image again just to read its labels.
y_true = np.array([img_2_ohe_vector[img] for img in val_imgs], dtype=np.float32)


# Thresholding logic
# Column i of the label vectors (and therefore of the model outputs) belongs
# to classes[i], so the names are taken from `classes` instead of a
# hard-coded list that could drift out of order.
class_names = list(classes)
recall_thresholds = {}
precision_thresholds = {}

# A simple thresholding strategy: per class, take the highest threshold whose
# validation recall is still at least `target_recall`.
target_recall = 0.94
for i, class_name in enumerate(class_names):
    precision, recall, thresholds = precision_recall_curve(y_true[:, i], val_pred_for_threshold[:, i])
    valid_indices = np.where(recall >= target_recall)[0]
    if len(valid_indices) > 0:
        # `recall` has one more entry than `thresholds`; the final recall
        # value has no threshold attached, so guard the index.
        last_valid = valid_indices[-1]
        if last_valid < len(thresholds):
            best_threshold = thresholds[last_valid]
        else: # if the best recall is the last one
            best_threshold = 0.99
    else:
        best_threshold = 0.5 # Default fallback
    recall_thresholds[class_name] = best_threshold
    print(f"Threshold for {class_name} at recall > {target_recall}: {recall_thresholds[class_name]:.3f}")

# SUBMISSION
# Collect every '<image>_<class>' label whose predicted probability falls
# below that class's threshold; those masks are treated as empty.
image_labels_empty = set()
for img, predictions in zip(test_image_list, y_pred_test):
    for class_i, class_name in enumerate(class_names):
        if predictions[class_i] < recall_thresholds[class_name]:
            image_labels_empty.add(f'{img}_{class_name}')

# Assuming segmentation results are from another model/source
# If you don't have this file, you'll need to create a placeholder or generate it
try:
    submission = pd.read_csv('../input/densenet201cloudy/densenet201.csv')
except FileNotFoundError:
    print("Warning: '../input/densenet201cloudy/densenet201.csv' not found.")
    print("Classifier predictions have been made, but submission file cannot be post-processed.")
else:
    # Only the read is guarded, so a failure while writing is not misreported
    # as a missing input file.
    predictions_nonempty = set(submission.loc[~submission['EncodedPixels'].isnull(), 'Image_Label'].values)
    print(f'{len(image_labels_empty.intersection(predictions_nonempty))} masks would be removed')
    submission.loc[submission['Image_Label'].isin(image_labels_empty), 'EncodedPixels'] = np.nan
    submission.to_csv('submission.csv', index=False)
    print("Submission file created.")