# Chlorella Classification Pipeline - Cloud Notebook

Ez a notebook egy teljes körű osztályozási pipeline holografikus mikroszkópos képekhez.

## Tartalom:
1. Setup és környezet konfiguráció
2. Adatok betöltése és feldolgozása
3. Model definíció (ResNet18/ResNeXt-50/VGG11-BN)
4. Training with K-Fold Cross-Validation
5. Evaluation és metrikák
6. Inference és submission generálás

## Kaggle/Colab útmutató:
- **Kaggle**: Az adatok a `/kaggle/input/your-dataset-name/` mappában várhatóak
- **Colab**: Töltsd fel az adatokat, vagy csatold a Google Drive-ot
- A modell checkpointok a `/kaggle/working/outputs/` vagy `/content/outputs/` mappába kerülnek

## ⚠️ FONTOS - Állítsd be az adatok elérési útját:
A notebook Konfiguráció cellájában (`CONFIG` szótár) módosítsd a `data_root` értékét a saját adataid helyére!

## 1. Setup és Dependencies

In [None]:
# Import alapvet≈ë library-k
import os
import random
import re
import ssl
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional, Callable

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision import models
from PIL import Image
import albumentations as A

from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import (
    fbeta_score, 
    precision_recall_fscore_support, 
    confusion_matrix,
    precision_recall_curve
)

from tqdm import tqdm
import warnings
# Silence every warning category for the rest of the session to keep the
# notebook output short (note: this also hides library deprecation warnings).
warnings.filterwarnings(action='ignore')

# Short environment report: PyTorch version, CUDA availability, GPU name.
_env_report = [
    "‚úì Imports sikeres!",
    f"PyTorch verzi√≥: {torch.__version__}",
    f"CUDA el√©rhet≈ë: {torch.cuda.is_available()}",
]
if torch.cuda.is_available():
    _env_report.append(f"GPU: {torch.cuda.get_device_name(0)}")
print(*_env_report, sep="\n")

## 3. Utility Funkciók

In [None]:
# =============== CONFIGURATION ===============
# IMPORTANT: point CONFIG['data']['data_root'] at your own dataset location.
#   Kaggle example: '/kaggle/input/your-dataset-name'
#   Colab example:  '/content/your-dataset-folder'

CONFIG = dict(
    data=dict(
        data_root='/kaggle/input/itk-nn',      # <- CHANGE THIS
        output_dir='/kaggle/working/outputs',  # Colab: '/content/outputs'
        img_size=224,
        num_workers=2,                         # kept low for cloud runtimes
    ),
    model=dict(
        architecture='resnet18',  # one of 'resnet18', 'resnext50_32x4d', 'vgg11_bn'
        num_classes=5,
        input_channels=4,
        pretrained=True,          # on Kaggle, enable Internet in Settings to download weights
    ),
    training=dict(
        num_folds=5,
        epochs=20,                # kept modest for cloud runtimes
        batch_size=16,
        lr_head=0.001,
        lr_backbone=0.0001,
        weight_decay=0.0001,
        patience=5,
        unfreeze_epoch=5,
    ),
    augmentation=dict(
        rotation_degrees=10,
        horizontal_flip_prob=0.5,
        vertical_flip_prob=0.5,
        brightness=0.2,
        contrast=0.2,
        blur_prob=0.3,
        blur_sigma_min=0.1,
        blur_sigma_max=2.0,
    ),
    reproducibility=dict(
        seed=42,
    ),
)

# Class definitions: label id (position in the list below), canonical name,
# training sub-folder name and the is_priority flag (True only for chlorella).
CLASS_LABELS = [
    dict(label_id=label_id, label_name=name, folder_name=folder, is_priority=priority)
    for label_id, (name, folder, priority) in enumerate([
        ('chlorella', 'class_chlorella', True),
        ('debris', 'class_debris', False),
        ('haematococcus', 'class_haematococcus', False),
        ('small_haematococcus', 'class_small_haemato', False),
        ('small_particle', 'class_small_particle', False),
    ])
]

# Lookup tables derived from CLASS_LABELS.
CLASS_ID_TO_NAME = dict((c['label_id'], c['label_name']) for c in CLASS_LABELS)
FOLDER_TO_CLASS_ID = dict((c['folder_name'], c['label_id']) for c in CLASS_LABELS)

# ImageNet per-channel statistics, used to normalise the first three input channels.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# SSL workaround, intended (per its original comment) for macOS only.
# Replacing the default HTTPS context with the unverified one disables TLS
# certificate verification for every HTTPS request made by this process
# (presumably it was added for the pretrained-weight download). Applying it
# unconditionally weakened security on Kaggle/Colab too, so gate it on macOS.
import sys

if sys.platform == 'darwin':
    ssl._create_default_https_context = ssl._create_unverified_context

# Echo the most frequently edited settings so the active configuration is visible.
print("‚úì Konfigur√°ci√≥ bet√∂ltve!")
_config_summary = (
    f"üìÇ Adatok helye: {CONFIG['data']['data_root']}",
    f"üèóÔ∏è  Architekt√∫ra: {CONFIG['model']['architecture']}",
    f"üì¶ Batch size: {CONFIG['training']['batch_size']}",
    f"üîÑ Epochs: {CONFIG['training']['epochs']}",
)
print(*_config_summary, sep="\n")

## 2. Konfiguráció

In [None]:
# =============== UTILITY FUNCTIONS ===============

def set_seed(seed: int = 42):
    """Seed every random number generator the pipeline uses.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (plus all CUDA
    devices when CUDA is available). It also switches cuDNN to deterministic
    kernels with benchmark autotuning off, so runs are reproducible.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def ensure_dir(directory: str) -> Path:
    """Create ``directory`` (including missing parents) if needed and return it as a Path.

    Raises FileExistsError if the path exists but is not a directory.
    """
    target = Path(directory)
    if not target.is_dir():
        target.mkdir(parents=True, exist_ok=True)
    return target

def parse_subject_id(filename: str) -> str:
    """Return the subject ID encoded in an image filename.

    The stem (file name without directory and extension) is expected to look
    like ``<subject>_amp``, ``<subject>_phase`` or ``<subject>_mask``; the part
    before the modality suffix is returned. A stem without such a suffix is
    returned unchanged.
    """
    stem = Path(filename).stem
    found = re.match(r'^(.+?)_(?:amp|phase|mask)$', stem)
    if found is None:
        return stem
    return found.group(1)

def discover_subjects(data_root: str, split: str = 'train') -> Dict[str, Dict]:
    """Find the images of a split and group them by subject ID and modality.

    Train split: expects ``<data_root>/train/<class folder>/<subject>_<modality>.png``.
    Only class folders listed in FOLDER_TO_CLASS_ID are used, and modality must be
    ``amp``, ``phase`` or ``mask``. Files without such a suffix are skipped, and
    how many were skipped is reported.

    Any other split (e.g. test): every ``*.png`` directly in the split folder is
    its own subject (falling back to a recursive search if there are none). The
    same file is used for all three modalities, and class_label/class_name are None.

    Returns:
        Mapping subject_id -> {'subject_id', 'class_label', 'class_name',
        'modalities': {modality: Path}, 'split'}.

    Raises:
        FileNotFoundError: if ``<data_root>/<split>`` does not exist.
    """
    split_dir = Path(data_root) / split

    if not split_dir.exists():
        raise FileNotFoundError(f"Split k√∂nyvt√°r nem tal√°lhat√≥: {split_dir}")

    subjects = {}

    if split == 'train':
        # One anchored match yields both the subject ID and the modality. It uses
        # the same suffix rule as parse_subject_id. A plain substring test such as
        # "'_amp' in stem" would also accept names like "cell_amp_1", whose
        # subject ID cannot be parsed; that produced stray single-modality subjects.
        name_pattern = re.compile(r'^(.+?)_(amp|phase|mask)$')
        skipped_files = 0
        conflicting_ids = set()

        for class_folder in sorted(split_dir.iterdir()):
            if not class_folder.is_dir() or class_folder.name not in FOLDER_TO_CLASS_ID:
                continue

            class_id = FOLDER_TO_CLASS_ID[class_folder.name]
            class_name = CLASS_ID_TO_NAME[class_id]

            for img_path in sorted(class_folder.glob('*.png')):
                match = name_pattern.match(img_path.stem)
                if match is None:
                    skipped_files += 1
                    continue
                subject_id, modality = match.groups()

                entry = subjects.get(subject_id)
                if entry is None:
                    entry = subjects[subject_id] = {
                        'subject_id': subject_id,
                        'class_label': class_id,
                        'class_name': class_name,
                        'modalities': {},
                        'split': split
                    }
                elif entry['class_label'] != class_id and subject_id not in conflicting_ids:
                    # Subjects are keyed by ID alone, so an ID reused in another
                    # class folder is merged into the first class's entry.
                    conflicting_ids.add(subject_id)
                    print(f"Warning: subject '{subject_id}' appears in both "
                          f"'{entry['class_name']}' and '{class_name}'; "
                          f"keeping label '{entry['class_name']}'.")

                entry['modalities'][modality] = img_path

        if skipped_files:
            print(f"Warning: skipped {skipped_files} PNG file(s) without an "
                  f"_amp/_phase/_mask suffix.")

    else:  # test
        # One file = one subject (no per-modality files).
        png_files = list(split_dir.glob('*.png'))

        if not png_files:
            png_files = list(split_dir.glob('**/*.png'))

        print(f"‚úì {len(png_files)} test k√©p tal√°lhat√≥")

        for img_path in sorted(png_files):
            # Subject ID = file name without the .png extension.
            subject_id = img_path.stem

            # The single test image is used for all three modalities.
            subjects[subject_id] = {
                'subject_id': subject_id,
                'class_label': None,
                'class_name': None,
                'modalities': {
                    'amp': img_path,
                    'phase': img_path,
                    'mask': img_path
                },
                'split': split
            }

    return subjects

def create_subject_folds(subject_ids: List[str], class_labels: List[int], 
                         n_splits: int = 5, seed: int = 42):
    """Build K-fold splits at subject level with StratifiedGroupKFold.

    ``class_labels`` drive the stratification and ``subject_ids`` serve as the
    groups, so a subject never lands in both train and validation of one fold.
    The splitter shuffles with ``seed``.

    Returns:
        List of ``n_splits`` tuples ``(train_subject_ids, val_subject_ids)``.
    """
    splitter = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    positions = np.arange(len(subject_ids))
    labels = np.array(class_labels)
    groups = np.array(subject_ids)

    def _ids_at(indices):
        return [subject_ids[i] for i in indices]

    return [
        (_ids_at(train_idx), _ids_at(val_idx))
        for train_idx, val_idx in splitter.split(positions, labels, groups)
    ]

# Apply the configured seed right away so every later cell is reproducible.
set_seed(seed=CONFIG['reproducibility']['seed'])
print("‚úì Utility funkci√≥k bet√∂ltve √©s seed be√°ll√≠tva!")

In [None]:
# =============== MODEL BUILDING ===============

def build_backbone(architecture: str = 'resnet18', pretrained: bool = True):
    """Load a torchvision model and strip its final classification layer.

    Args:
        architecture: 'resnet18', 'resnext50_32x4d' or 'vgg11_bn'.
        pretrained: load the ImageNet-1k V1 weights when True, random init otherwise.

    Returns:
        ``(backbone, feature_dim)``.
        For the ResNet/ResNeXt variants the backbone is an ``nn.Sequential`` of
        all children except ``fc`` (it still ends with the global average pool),
        and feature_dim is ``fc.in_features``. For vgg11_bn it is
        ``model.features`` and feature_dim is 512.

    Raises:
        ValueError: for an unsupported architecture (checked before any download).
    """
    # architecture -> name of its torchvision weights enum. IMAGENET1K_V1 is the
    # checkpoint the legacy ``pretrained=True`` flag mapped to for all three.
    weight_enum_names = {
        'resnet18': 'ResNet18_Weights',
        'resnext50_32x4d': 'ResNeXt50_32X4D_Weights',
        'vgg11_bn': 'VGG11_BN_Weights',
    }
    if architecture not in weight_enum_names:
        raise ValueError(f"Nem t√°mogatott architekt√∫ra: {architecture}")

    builder = getattr(models, architecture)
    weights_enum = getattr(models, weight_enum_names[architecture], None)
    if weights_enum is None:
        # torchvision < 0.13 has no weights enums; only the legacy flag exists.
        model = builder(pretrained=pretrained)
    else:
        # ``pretrained=`` is deprecated since torchvision 0.13 in favour of ``weights=``.
        model = builder(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    if architecture == 'vgg11_bn':
        return model.features, 512

    feature_dim = model.fc.in_features
    return nn.Sequential(*list(model.children())[:-1]), feature_dim

def adapt_first_conv_for_4ch(model: nn.Module, architecture: str = 'resnet18',
                             in_channels: int = 4):
    """Replace the backbone's first conv with one that accepts ``in_channels`` inputs.

    For every supported architecture the first conv is ``model[0]``: ResNet/ResNeXt
    ``conv1`` and the first layer of VGG ``features``, as built by build_backbone.
    The new conv keeps the old kernel size, stride, padding, dilation, padding
    mode and bias. The old filters are copied into the first
    ``min(old_in_channels, in_channels)`` input channels. Any additional input
    channels get small random weights (``randn * 0.01``) so they start with
    little influence.

    Args:
        model: indexable backbone whose element 0 is an ``nn.Conv2d``.
        architecture: 'resnet18', 'resnext50_32x4d' or 'vgg11_bn'.
        in_channels: input channels of the new conv (default 4, as before).

    Returns:
        The same ``model`` object with ``model[0]`` replaced.

    Raises:
        ValueError: for an unsupported architecture. Previously the model was
        returned unchanged, which hid the mismatch until the first forward pass.
    """
    if architecture not in ('resnet18', 'resnext50_32x4d', 'vgg11_bn'):
        raise ValueError(f"Nem t√°mogatott architekt√∫ra: {architecture}")

    old_conv = model[0]
    new_conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        dilation=old_conv.dilation,
        bias=old_conv.bias is not None,
        padding_mode=old_conv.padding_mode,
    )

    n_copied = min(old_conv.in_channels, in_channels)
    with torch.no_grad():
        new_conv.weight[:, :n_copied] = old_conv.weight[:, :n_copied]
        if in_channels > n_copied:
            extra = new_conv.weight[:, n_copied:]
            new_conv.weight[:, n_copied:] = torch.randn_like(extra) * 0.01
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)

    model[0] = new_conv
    return model

def replace_classifier_head(model: nn.Module, architecture: str, feature_dim: int, num_classes: int = 5):
    """Attach a freshly initialised classification head to a backbone.

    Returns ``nn.Sequential(backbone, head)`` where the head is:

    * resnet18 / resnext50_32x4d: 1x1 adaptive avg pool -> flatten ->
      Linear(feature_dim, num_classes);
    * vgg11_bn: 7x7 adaptive avg pool -> flatten -> two 4096-unit
      Linear+ReLU+Dropout(0.5) layers -> Linear(4096, num_classes).

    Any other architecture string returns ``model`` unchanged.
    """
    if architecture == 'vgg11_bn':
        head_layers = [
            nn.AdaptiveAvgPool2d((7, 7)),
            nn.Flatten(),
            nn.Linear(feature_dim * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(4096, num_classes),
        ]
    elif architecture in ('resnet18', 'resnext50_32x4d'):
        head_layers = [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(feature_dim, num_classes),
        ]
    else:
        return model
    return nn.Sequential(model, nn.Sequential(*head_layers))

class ChlorellaClassifier(nn.Module):
    """End-to-end classifier: adapted 4-channel backbone plus a fresh head.

    ``self.model`` is ``nn.Sequential(backbone, head)`` as returned by
    replace_classifier_head. Index 0 is the backbone and index 1 the head; the
    parameter accessors rely on this so the two parts can use different
    learning rates.
    """

    def __init__(self, architecture: str = 'resnet18', num_classes: int = 5, 
                 input_channels: int = 4, pretrained: bool = True):
        super().__init__()
        self.architecture, self.num_classes, self.input_channels = (
            architecture, num_classes, input_channels)

        # NOTE(review): input_channels is stored but not forwarded; the first-conv
        # adapter below is called with its default channel count.
        features, feature_dim = build_backbone(architecture, pretrained)
        features = adapt_first_conv_for_4ch(features, architecture)
        self.model = replace_classifier_head(features, architecture, feature_dim, num_classes)

    def forward(self, x):
        """Run backbone and head on ``x``."""
        return self.model(x)

    def get_backbone_params(self):
        """Parameters of the backbone (``self.model[0]``), for discriminative fine-tuning."""
        backbone = self.model[0]
        return backbone.parameters()

    def get_classifier_params(self):
        """Parameters of the classification head (``self.model[1]``)."""
        head = self.model[1]
        return head.parameters()

print("‚úì Model architekt√∫ra defini√°lva!")

In [None]:
# =============== TRAINING LOOP ===============

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``dataloader``.

    Samples with a negative label (the dataset uses -1 for "no label") are
    dropped, and batches without any labelled sample are skipped.

    Returns:
        ``(mean loss per labelled sample, accuracy)``; both 0.0 if no sample was labelled.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch_x, batch_y in tqdm(dataloader, desc='Training'):
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        keep = batch_y >= 0
        if not keep.any():
            continue
        batch_x, batch_y = batch_x[keep], batch_y[keep]

        optimizer.zero_grad()
        logits = model(batch_x)
        loss = criterion(logits, batch_y)
        loss.backward()
        optimizer.step()

        batch_n = batch_y.size(0)
        loss_sum += loss.item() * batch_n
        n_seen += batch_n
        _, batch_pred = logits.max(1)
        n_correct += batch_pred.eq(batch_y).sum().item()

    if n_seen == 0:
        return 0.0, 0.0
    return loss_sum / n_seen, n_correct / n_seen

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Negative labels are ignored, as in training.

    Returns:
        ``(avg_loss, accuracy, chlorella_f0_5, y_true, y_pred, y_probs)``, where
        the arrays hold the labels, argmax predictions and softmax
        probabilities of every labelled sample, and ``chlorella_f0_5`` is the
        F0.5 score of class 0.
    """
    model.eval()
    loss_sum = 0.0
    labels_seen, preds_seen, probs_seen = [], [], []

    with torch.no_grad():
        for batch_x, batch_y in tqdm(dataloader, desc='Validation'):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            keep = batch_y >= 0
            if not keep.any():
                continue
            batch_x = batch_x[keep]
            batch_y = batch_y[keep]

            logits = model(batch_x)
            loss_sum += criterion(logits, batch_y).item() * batch_x.size(0)
            _, batch_pred = logits.max(1)
            batch_prob = torch.softmax(logits, dim=1)

            labels_seen.extend(batch_y.cpu().numpy())
            preds_seen.extend(batch_pred.cpu().numpy())
            probs_seen.extend(batch_prob.cpu().numpy())

    y_true = np.array(labels_seen)
    y_pred = np.array(preds_seen)
    y_probs = np.array(probs_seen)

    n_samples = len(y_true)
    avg_loss = loss_sum / n_samples if n_samples > 0 else 0.0
    accuracy = (y_true == y_pred).mean() if n_samples > 0 else 0.0
    chlorella_f0_5 = compute_fbeta_score(y_true, y_pred, beta=0.5, class_id=0)

    return avg_loss, accuracy, chlorella_f0_5, y_true, y_pred, y_probs

print("‚úì Training loop funkci√≥k defini√°lva!")

In [None]:
# =============== TRAINING UTILITIES ===============

class EarlyStopping:
    """Signal when a monitored metric (here: F0.5) stops improving.

    The first score seen becomes the baseline. After that, a score counts as an
    improvement only if it beats the best by more than ``delta``: upwards for
    ``mode='max'``, downwards for any other mode. After ``patience``
    consecutive non-improving calls ``should_stop`` becomes True and stays True.
    """

    def __init__(self, patience: int = 5, mode: str = 'max', delta: float = 0.0):
        self.patience = patience
        self.mode = mode
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.should_stop = False

    def _is_improvement(self, score: float) -> bool:
        """True if ``score`` beats ``best_score`` by more than ``delta``."""
        if self.mode == 'max':
            return score > self.best_score + self.delta
        return score < self.best_score - self.delta

    def __call__(self, score: float) -> bool:
        """Record ``score``; return True once training should stop."""
        if self.best_score is None:
            self.best_score = score
            return False

        if self._is_improvement(score):
            self.best_score, self.counter = score, 0
            return self.should_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.should_stop = True
        return self.should_stop

def freeze_backbone(model: nn.Module):
    """Turn off gradients for the backbone parameters.

    Does nothing if ``model`` has no ``get_backbone_params`` method.
    """
    if not hasattr(model, 'get_backbone_params'):
        return
    for p in model.get_backbone_params():
        p.requires_grad_(False)

def unfreeze_backbone(model: nn.Module):
    """Turn gradients back on for the backbone parameters.

    Does nothing if ``model`` has no ``get_backbone_params`` method.
    """
    if not hasattr(model, 'get_backbone_params'):
        return
    for p in model.get_backbone_params():
        p.requires_grad_(True)

def get_discriminative_optimizer(model: nn.Module, lr_head: float = 1e-3, 
                                 lr_backbone: float = 1e-4, weight_decay: float = 1e-4):
    """Build an Adam optimizer with separate backbone and head learning rates.

    If ``model`` exposes both ``get_backbone_params`` and
    ``get_classifier_params``, two parameter groups are created: the backbone
    at ``lr_backbone`` first, then the head at ``lr_head``. Otherwise all
    parameters share ``lr_head``. ``weight_decay`` applies everywhere.
    """
    has_split = all(
        hasattr(model, name) for name in ('get_backbone_params', 'get_classifier_params')
    )
    if not has_split:
        return optim.Adam(model.parameters(), lr=lr_head, weight_decay=weight_decay)

    param_groups = [
        {'params': model.get_backbone_params(), 'lr': lr_backbone},
        {'params': model.get_classifier_params(), 'lr': lr_head},
    ]
    return optim.Adam(param_groups, weight_decay=weight_decay)

def compute_fbeta_score(y_true: np.ndarray, y_pred: np.ndarray, beta: float = 0.5, class_id: int = 0):
    """One-vs-rest F-beta score of a single class.

    Both label arrays are turned into 0/1 indicators of ``class_id`` and scored
    with sklearn's ``fbeta_score`` using ``zero_division=0.0``.
    Returns a plain Python float.
    """
    indicator_true = (y_true == class_id).astype(int)
    indicator_pred = (y_pred == class_id).astype(int)
    return float(
        fbeta_score(indicator_true, indicator_pred, beta=beta, zero_division=0.0)
    )

print("‚úì Training utility funkci√≥k defini√°lva!")

## 5. Training Utilities

In [None]:
# =============== DATASET CLASS ===============

class SubjectDataset(Dataset):
    """Multi-modal holographic microscopy dataset, one item per subject.

    Each item is ``(tensor, label)``:

    * ``tensor``: float32 with shape (4, H, W), where H and W are the transform's
      output size. Channels 0-2 are the amp, phase and mask images scaled to
      [0, 1] and normalised with IMAGENET_MEAN / IMAGENET_STD. Channel 3 is a
      constant plane holding the fraction of the three modalities that were
      actually loaded.
    * ``label``: the subject's ``class_label``, or -1 when it is None.

    Missing or unreadable modalities are zero-filled.

    Args:
        subjects: mapping subject_id -> subject dict with 'modalities'
            (modality name -> image path) and 'class_label'.
        transform: albumentations-style callable taking ``image``, ``phase``
            and ``mask`` keyword arrays; defaults to get_val_transforms(img_size).
        img_size: size of zero-filled modalities (and of the default transform).
    """

    # Modality order = channel order of the output tensor.
    _MODALITIES = ('amp', 'phase', 'mask')

    def __init__(self, subjects: Dict[str, Dict], transform=None, img_size: int = 224):
        self.subjects = list(subjects.values())
        self.transform = transform if transform else get_val_transforms(img_size)
        self.img_size = img_size

    def __len__(self) -> int:
        return len(self.subjects)

    def __getitem__(self, idx: int):
        subject = self.subjects[idx]
        modalities = subject['modalities']

        # Load each modality, or zeros (present=False) if missing/unreadable.
        amp_img, amp_present = self._load_modality(modalities, 'amp')
        phase_img, phase_present = self._load_modality(modalities, 'phase')
        mask_img, mask_present = self._load_modality(modalities, 'mask')

        # By default, albumentations' Compose rejects a call whose image and mask
        # targets differ in height/width. Zero-fills are img_size x img_size,
        # while loaded files keep their own size, so a subject with a missing
        # modality would crash. Align all three to the first present modality
        # first; the transform resizes afterwards.
        amp_img, phase_img, mask_img = self._align_shapes(
            (amp_img, phase_img, mask_img),
            (amp_present, phase_present, mask_present),
        )

        # Apply the augmentation/resize pipeline to all modalities in one call.
        if self.transform:
            transformed = self.transform(image=amp_img, phase=phase_img, mask=mask_img)
            amp_img = transformed['image']
            phase_img = transformed['phase']
            mask_img = transformed['mask']

        # Convert any PIL results to NumPy arrays.
        if isinstance(amp_img, Image.Image):
            amp_img = np.array(amp_img)
        if isinstance(phase_img, Image.Image):
            phase_img = np.array(phase_img)
        if isinstance(mask_img, Image.Image):
            mask_img = np.array(mask_img)

        # Collapse any channel axis so every modality is 2-D.
        amp_img = self._ensure_grayscale(amp_img)
        phase_img = self._ensure_grayscale(phase_img)
        mask_img = self._ensure_grayscale(mask_img)

        # Scale 8-bit intensities to [0, 1].
        amp_img = amp_img.astype(np.float32) / 255.0
        phase_img = phase_img.astype(np.float32) / 255.0
        mask_img = mask_img.astype(np.float32) / 255.0

        # Stack into 3 channels: amp, phase, mask.
        img_3ch = np.stack([amp_img, phase_img, mask_img], axis=0)

        # ImageNet normalisation of the first three channels.
        for i in range(3):
            img_3ch[i] = (img_3ch[i] - IMAGENET_MEAN[i]) / IMAGENET_STD[i]

        # Channel 4: fraction of modalities that were present, as a constant plane.
        mask_indicator = np.array([
            float(amp_present), 
            float(phase_present), 
            float(mask_present)
        ], dtype=np.float32).mean()

        mask_indicator_ch = np.full(
            (1, img_3ch.shape[1], img_3ch.shape[2]), 
            mask_indicator, 
            dtype=np.float32
        )

        # Assemble the 4-channel tensor.
        img_4ch = np.concatenate([img_3ch, mask_indicator_ch], axis=0)
        tensor = torch.from_numpy(img_4ch).float()

        # -1 marks unlabelled (test) subjects.
        label = subject['class_label'] if subject['class_label'] is not None else -1

        return tensor, label

    def _align_shapes(self, images, present):
        """Return the three modality arrays with a common (H, W).

        The target size is that of the first present modality. Missing
        modalities become zero arrays of that size. Present ones with a
        different size are resized with PIL: nearest-neighbour for the mask so
        its values are not interpolated, bilinear otherwise. If no modality is
        present, ``images`` is returned unchanged.
        """
        reference = next((img for img, ok in zip(images, present) if ok), None)
        if reference is None:
            return images
        height, width = reference.shape[:2]

        aligned = []
        for name, img, ok in zip(self._MODALITIES, images, present):
            if img.shape[:2] == (height, width):
                aligned.append(img)
            elif not ok:
                aligned.append(np.zeros((height, width), dtype=img.dtype))
            else:
                resample = Image.NEAREST if name == 'mask' else Image.BILINEAR
                resized = Image.fromarray(img).resize((width, height), resample)
                aligned.append(np.array(resized))
        return aligned

    def _load_modality(self, modalities: Dict[str, Path], modality_type: str):
        """Load one modality as a grayscale uint8 array.

        Returns ``(array, True)`` on success. Returns an img_size x img_size zero
        array with ``False`` if the modality is missing or fails to load; load
        failures also print a warning.
        """
        if modality_type in modalities:
            img_path = modalities[modality_type]
            try:
                img = Image.open(img_path).convert('L')
                return np.array(img), True
            except Exception as e:
                print(f"Figyelmeztet√©s: {img_path} bet√∂lt√©se sikertelen: {e}")
                return np.zeros((self.img_size, self.img_size), dtype=np.uint8), False
        else:
            return np.zeros((self.img_size, self.img_size), dtype=np.uint8), False

    def _ensure_grayscale(self, img: np.ndarray) -> np.ndarray:
        """Average a 3-D (H, W, C) array over its last axis; pass 2-D arrays through.

        Raises ValueError for any other rank.
        """
        if img.ndim == 3:
            img = img.mean(axis=2)
        elif img.ndim != 2:
            raise ValueError(f"V√°ratlan k√©p m√©ret: {img.shape}")
        return img

print("‚úì Dataset oszt√°ly defini√°lva!")

In [None]:
# =============== AUGMENTATIONS ===============

def get_train_transforms(img_size: int = 224):
    """Training augmentation pipeline, parameterised by CONFIG['augmentation'].

    Steps:
    1. Resize to img_size x img_size.
    2. Random rotation (p=0.5, border_mode=0, which is cv2.BORDER_CONSTANT).
    3. Horizontal and vertical flips.
    4. Brightness/contrast jitter (p=0.3).
    5. Gaussian blur.

    ``phase`` is registered as an extra image target and ``mask`` as a mask
    target, so one call transforms amp (passed as ``image``), phase and mask
    together.
    """
    aug = CONFIG['augmentation']
    steps = [
        A.Resize(height=img_size, width=img_size),
        A.Rotate(limit=aug['rotation_degrees'], p=0.5, border_mode=0),
        A.HorizontalFlip(p=aug['horizontal_flip_prob']),
        A.VerticalFlip(p=aug['vertical_flip_prob']),
        A.ColorJitter(brightness=aug['brightness'], contrast=aug['contrast'], p=0.3),
        A.GaussianBlur(
            blur_limit=(3, 7),
            sigma_limit=(aug['blur_sigma_min'], aug['blur_sigma_max']),
            p=aug['blur_prob'],
        ),
    ]
    return A.Compose(steps, additional_targets={'phase': 'image', 'mask': 'mask'})

def get_val_transforms(img_size: int = 224):
    """Deterministic validation/test pipeline: resize only, same targets as training."""
    resize_only = [A.Resize(height=img_size, width=img_size)]
    return A.Compose(resize_only, additional_targets={'phase': 'image', 'mask': 'mask'})

print("‚úì Augment√°ci√≥k defini√°lva!")

## 4. Dataset és Augmentációk

In [None]:
# =============== FOLD TRAINING ===============

def _run_training_stage(model, train_loader, val_loader, criterion, optimizer, device,
                        epoch_indices, total_epochs, patience, fold_id, config,
                        checkpoint_path, best):
    """Train and validate for each epoch index in ``epoch_indices`` (one stage).

    Saves a checkpoint to ``checkpoint_path`` whenever the validation chlorella
    F0.5 beats ``best['f0_5']``, or on the first evaluated epoch while it is
    still None, and updates ``best`` in place. Uses a fresh EarlyStopping for
    this stage.

    Returns:
        True if early stopping fired during this stage, else False.
    """
    early_stopping = EarlyStopping(patience=patience, mode='max')

    for epoch in epoch_indices:
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, chlorella_f0_5, y_true, y_pred, y_probs = validate(
            model, val_loader, criterion, device
        )

        print(f"[Fold {fold_id}] Epoch {epoch+1}/{total_epochs} | "
              f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F0.5: {chlorella_f0_5:.4f}")

        # The first evaluated epoch always checkpoints. Previously a fold whose
        # F0.5 stayed at 0.0 never saved a checkpoint at all.
        if best['f0_5'] is None or chlorella_f0_5 > best['f0_5']:
            best['f0_5'] = chlorella_f0_5
            best['epoch'] = epoch + 1

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'metric_value': chlorella_f0_5,
                'config': config
            }, checkpoint_path)

            best['val_predictions'] = {'y_true': y_true, 'y_pred': y_pred, 'y_probs': y_probs}
            print(f"[Fold {fold_id}] √öj legjobb F0.5: {chlorella_f0_5:.4f} ‚Üí Checkpoint mentve")

        if early_stopping(chlorella_f0_5):
            print(f"[Fold {fold_id}] Early stopping triggered at epoch {epoch+1}")
            return True

    return False


def train_one_fold(model, train_loader, val_loader, fold_id, config, device, output_dir):
    """Train one CV fold in two stages and keep the best checkpoint.

    Stage 1 trains only the classifier head (backbone frozen) for the first
    ``min(unfreeze_epoch, epochs)`` epochs. If stage 1 did not early-stop and
    ``unfreeze_epoch < epochs``, stage 2 unfreezes the backbone and trains the
    whole network for the remaining epochs. It uses discriminative learning
    rates: ``lr_backbone`` for the backbone and ``lr_head`` for the head. Each
    stage has its own early stopping on the validation chlorella F0.5.

    The best epoch (by F0.5) is saved to ``<output_dir>/fold_<fold_id>_best.pth``
    as a dict with 'epoch' (0-based), 'model_state_dict', 'metric_value' and
    'config'. The model itself is left in its final-epoch state.

    Returns:
        dict with:
        - 'best_f0_5': 0.0 if no epoch ran.
        - 'best_epoch': 1-based, 0 if no epoch ran.
        - 'val_predictions': y_true/y_pred/y_probs of the best epoch, or None.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = output_dir / f'fold_{fold_id}_best.pth'

    train_cfg = config['training']
    epochs = train_cfg['epochs']
    unfreeze_epoch = train_cfg['unfreeze_epoch']
    patience = train_cfg['patience']
    lr_head = train_cfg['lr_head']
    lr_backbone = train_cfg['lr_backbone']
    weight_decay = train_cfg['weight_decay']

    criterion = nn.CrossEntropyLoss()
    best = {'f0_5': None, 'epoch': 0, 'val_predictions': None}
    stage_args = dict(
        total_epochs=epochs, patience=patience, fold_id=fold_id, config=config,
        checkpoint_path=checkpoint_path, best=best,
    )

    # Stage 1: head-only training. It is skipped entirely when there are no
    # stage-1 epochs; previously `epoch` was then undefined (NameError).
    stage1_epochs = range(min(unfreeze_epoch, epochs))
    stopped = False
    if stage1_epochs:
        print(f"\n[Fold {fold_id}] Stage 1: Classifier head training (backbone frozen)")
        freeze_backbone(model)
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()), 
            lr=lr_head, 
            weight_decay=weight_decay
        )
        stopped = _run_training_stage(
            model, train_loader, val_loader, criterion, optimizer, device,
            stage1_epochs, **stage_args
        )

    # Stage 2: full network fine-tuning for the remaining epochs.
    if not stopped and unfreeze_epoch < epochs:
        print(f"\n[Fold {fold_id}] Stage 2: Full network fine-tuning (backbone unfrozen)")
        unfreeze_backbone(model)
        optimizer = get_discriminative_optimizer(
            model, lr_head=lr_head, lr_backbone=lr_backbone, weight_decay=weight_decay
        )
        _run_training_stage(
            model, train_loader, val_loader, criterion, optimizer, device,
            range(unfreeze_epoch, epochs), **stage_args
        )

    best_f0_5 = best['f0_5'] if best['f0_5'] is not None else 0.0
    best_epoch = best['epoch']
    print(f"\n[Fold {fold_id}] Training befejezve. Legjobb F0.5: {best_f0_5:.4f} at epoch {best_epoch}")

    return {
        'best_f0_5': best_f0_5,
        'best_epoch': best_epoch,
        'val_predictions': best['val_predictions']
    }

print("‚úì Fold training funkci√≥ defini√°lva!")

## 7. Training Loop (Összes Fold)

## 6. Adatok Betöltése és K-Fold Splits

In [None]:
# =============== K-FOLD SPLITS ===============

# In this notebook the subject-discovery cell appears *after* this cell, so a
# top-to-bottom run reaches this point with `all_subjects` undefined (NameError).
# In that case, discover the training subjects here from the configured path.
if 'all_subjects' not in globals():
    all_subjects = discover_subjects(CONFIG['data']['data_root'], split='train')

# Subject IDs and their labels, in matching order.
subject_ids = list(all_subjects.keys())
class_labels = [all_subjects[sid]['class_label'] for sid in subject_ids]

# Build the K-fold splits.
num_folds = CONFIG['training']['num_folds']
print(f"\nüîÄ {num_folds}-fold cross-validation splits l√©trehoz√°sa...")

folds = create_subject_folds(
    subject_ids, 
    class_labels, 
    n_splits=num_folds, 
    seed=CONFIG['reproducibility']['seed']
)

print(f"‚úì {len(folds)} fold elk√©sz√ºlt!")

# Per-fold size statistics.
print("\nüìà Fold-ok m√©rete:")
for fold_id, (train_sids, val_sids) in enumerate(folds):
    train_count = len(train_sids)
    val_count = len(val_sids)
    split_ratio = (val_count / (train_count + val_count)) * 100
    print(f"  Fold {fold_id}: {train_count} train, {val_count} val ({split_ratio:.1f}% val)")

In [None]:
# =============== SUBJECT DISCOVERY ===============

print("\nüîç Training subjects felfedez√©se...")
try:
    # Read the path from CONFIG: the `data_root` variable is only defined by the
    # data-location cell, which appears after this one in the notebook.
    all_subjects = discover_subjects(CONFIG['data']['data_root'], split='train')
    print(f"‚úì {len(all_subjects)} subject tal√°lva!")
    if not all_subjects:
        # Fail here (the except branch prints the expected layout) instead of
        # dividing by zero in the coverage statistics below.
        raise ValueError("No training subjects found under the configured data_root.")

    # Class distribution.
    class_counts = {}
    for subject in all_subjects.values():
        class_name = subject['class_name']
        class_counts[class_name] = class_counts.get(class_name, 0) + 1

    print("\nüìä Oszt√°ly eloszl√°s:")
    for class_name, count in sorted(class_counts.items()):
        print(f"  {class_name}: {count}")

    # Modality coverage.
    modality_counts = {'amp': 0, 'phase': 0, 'mask': 0}
    for subject in all_subjects.values():
        for modality in subject['modalities']:
            modality_counts[modality] += 1

    print("\nüé≠ Modalit√°s lefedetts√©g:")
    for modality, count in modality_counts.items():
        percentage = (count / len(all_subjects)) * 100
        print(f"  {modality}: {count}/{len(all_subjects)} ({percentage:.1f}%)")

except Exception as e:
    print(f"\n‚ùå Hiba az adatok bet√∂lt√©sekor: {e}")
    print("\nEllen≈ërizd az adatok strukt√∫r√°j√°t!")
    print("V√°rt strukt√∫ra:")
    print("  data_root/")
    print("    train/")
    print("      class_chlorella/")
    print("        123_amp.png")
    print("        123_phase.png")
    print("        123_mask.png")
    raise

In [None]:
# =============== DATA LOCATION CHECK ===============
# Show the top-level folders of data_root so a wrong path is easy to spot;
# abort with a hint if the directory does not exist.

data_root = CONFIG['data']['data_root']
print(f"üìÇ Adatok keres√©se: {data_root}")

if not os.path.exists(data_root):
    print(f"\n‚ùå FIGYELEM: A '{data_root}' k√∂nyvt√°r nem tal√°lhat√≥!")
    print("\nüí° √Åll√≠tsd be a helyes el√©r√©si utat:")
    print("  - Kaggle: /kaggle/input/your-dataset-name/")
    print("  - Colab: /content/your-dataset-folder/")
    print("\nM√≥dos√≠tsd a 6. cell√°ban (Konfigur√°ci√≥) a CONFIG['data']['data_root'] √©rt√©k√©t!")
    raise FileNotFoundError(f"Az adatok k√∂nyvt√°ra nem tal√°lhat√≥: {data_root}")

print("\nüìÅ El√©rhet≈ë k√∂nyvt√°rak:")
for entry in sorted(os.listdir(data_root)):
    if os.path.isdir(os.path.join(data_root, entry)):
        print(f"  ‚îú‚îÄ {entry}")

In [None]:
# =============== TRAINING ALL FOLDS ===============
# For each fold: build train/val datasets and loaders from that fold's
# subjects, train a fresh model with train_one_fold and collect its result
# dict in `fold_results`. Checkpoints go to <output_dir>/checkpoints.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nEszk√∂z: {device}")

output_dir = ensure_dir(CONFIG['data']['output_dir'])
checkpoints_dir = ensure_dir(output_dir / 'checkpoints')

img_size = CONFIG['data']['img_size']
batch_size = CONFIG['training']['batch_size']
num_workers = CONFIG['data']['num_workers']
# Pinned memory only when CUDA is available (same rule as before, without
# the redundant `True if ... else False`).
pin_memory = torch.cuda.is_available()

fold_results = []

for fold_id, (train_sids, val_sids) in enumerate(folds):
    print(f"\n{'='*60}")
    print(f"FOLD {fold_id} TRAINING")
    print(f"{'='*60}")

    # Restrict the subject dict to this fold's train / validation IDs.
    train_subjects = {sid: all_subjects[sid] for sid in train_sids}
    val_subjects = {sid: all_subjects[sid] for sid in val_sids}

    train_dataset = SubjectDataset(
        train_subjects,
        transform=get_train_transforms(img_size),
        img_size=img_size
    )
    val_dataset = SubjectDataset(
        val_subjects,
        transform=get_val_transforms(img_size),
        img_size=img_size
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    model = ChlorellaClassifier(
        architecture=CONFIG['model']['architecture'],
        num_classes=CONFIG['model']['num_classes'],
        input_channels=CONFIG['model']['input_channels'],
        pretrained=CONFIG['model']['pretrained']
    ).to(device)

    fold_result = train_one_fold(
        model, 
        train_loader, 
        val_loader, 
        fold_id, 
        CONFIG, 
        device, 
        checkpoints_dir
    )
    fold_results.append(fold_result)

    # Release per-fold objects before the next fold allocates its own.
    del model, train_loader, val_loader, train_dataset, val_dataset
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print(f"\n{'='*60}")
print("√ñSSZES FOLD TRAINING BEFEJEZVE")
print(f"{'='*60}")

## 🎉 Kész!

A pipeline sikeresen lefutott. Az eredmények:

### 📁 Mentett fájlok:
- **Checkpointok**: `/kaggle/working/outputs/checkpoints/fold_X_best.pth`
- **Training summary**: `/kaggle/working/outputs/reports/training_summary.json`
- **Submission**: `/kaggle/working/outputs/submissions/submission.csv`

### 📊 Következő lépések:
1. Töltsd le a `submission.csv` fájlt
2. Ellenőrizd a training metrikákat
3. Kísérletezz a hiperparaméterekkel (epochs, learning rate, augmentációk)
4. Próbálj más architektúrát (resnext50_32x4d, vgg11_bn)

In [None]:
# =============== TEST SET INFERENCE ===============
# Ensemble inference on the test split: each fold's best checkpoint produces
# softmax probabilities, these are averaged across the loaded folds, and the
# argmax class per subject is written to submissions/submission.csv.
#
# Uses names from earlier cells: CONFIG, discover_subjects, SubjectDataset,
# get_val_transforms, ChlorellaClassifier, CLASS_ID_TO_NAME, ensure_dir, and
# device / output_dir / checkpoints_dir / fold_results from the training cell.

print("\nüîÆ Test set inference elkezd√©se...")

# Discover the test subjects. Deliberately best-effort: any failure is printed
# and inference is skipped rather than aborting the notebook.
try:
    test_root = Path(CONFIG['data']['data_root']) / 'test'
    print(f"üîç Test mapp√°t keresem: {test_root}")

    if not test_root.exists():
        raise FileNotFoundError(f"Test mappa nem tal√°lhat√≥: {test_root}")

    # Informational only: count top-level PNGs, falling back to a recursive
    # search when the images live in subfolders.
    png_count = len(list(test_root.glob('*.png')))
    if png_count == 0:
        png_count = len(list(test_root.glob('**/*.png')))

    print(f"üìä {png_count} PNG f√°jl tal√°lhat√≥ a test mapp√°ban")

    test_subjects = discover_subjects(CONFIG['data']['data_root'], split='test')
    print(f"‚úì {len(test_subjects)} test subject tal√°lva!")
except Exception as e:
    print(f"‚ö†Ô∏è Test set nem tal√°lhat√≥: {e}")
    print("Kihagyom a test inference-t.")
    test_subjects = {}

if test_subjects:
    test_dataset = SubjectDataset(
        test_subjects,
        transform=get_val_transforms(CONFIG['data']['img_size']),
        img_size=CONFIG['data']['img_size']
    )

    # shuffle=False keeps predictions in dataset order so they can be paired
    # with the subject ids below.
    test_loader = DataLoader(
        test_dataset,
        batch_size=CONFIG['training']['batch_size'],
        shuffle=False,
        num_workers=CONFIG['data']['num_workers'],
        pin_memory=torch.cuda.is_available()
    )

    # One probability array per successfully loaded fold model.
    all_predictions = []
    # NOTE(review): pairing by position assumes SubjectDataset yields exactly
    # one sample per subject, in test_subjects' key order — confirm.
    all_subject_ids = list(test_subjects.keys())

    print(f"\nü§ñ Ensemble prediction {len(fold_results)} modellel...")

    for fold_id in range(len(fold_results)):
        checkpoint_path = checkpoints_dir / f'fold_{fold_id}_best.pth'

        if not checkpoint_path.exists():
            print(f"‚ö†Ô∏è Fold {fold_id} checkpoint nem tal√°lhat√≥: {checkpoint_path}")
            continue

        # pretrained=False: the weights are loaded from the checkpoint below.
        model = ChlorellaClassifier(
            architecture=CONFIG['model']['architecture'],
            num_classes=CONFIG['model']['num_classes'],
            input_channels=CONFIG['model']['input_channels'],
            pretrained=False
        ).to(device)

        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()

        # Per-sample softmax probabilities for this fold.
        fold_probs = []
        with torch.no_grad():
            for inputs, _ in tqdm(test_loader, desc=f'Fold {fold_id} inference'):
                inputs = inputs.to(device)
                outputs = model(inputs)
                probs = torch.softmax(outputs, dim=1)
                fold_probs.extend(probs.cpu().numpy())

        all_predictions.append(np.array(fold_probs))

        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if not all_predictions:
        # Every checkpoint was missing: averaging an empty list yields NaN and
        # the argmax over axis 1 would then fail with an unhelpful error.
        print("\nNo fold checkpoint could be loaded - submission file was not created.")
    else:
        # Ensemble: average the class probabilities across folds, then argmax.
        ensemble_probs = np.mean(all_predictions, axis=0)
        ensemble_preds = np.argmax(ensemble_probs, axis=1)

        # Predictions are matched to subjects purely by position; fail loudly
        # if the counts disagree instead of producing a misaligned CSV.
        if len(ensemble_preds) != len(all_subject_ids):
            raise ValueError(
                f"Got {len(ensemble_preds)} predictions for "
                f"{len(all_subject_ids)} test subjects; the test dataset does "
                "not map one-to-one onto test_subjects."
            )

        submission_df = pd.DataFrame({
            'subject_id': all_subject_ids,
            'predicted_class': ensemble_preds
        })

        submissions_dir = ensure_dir(output_dir / 'submissions')
        submission_path = submissions_dir / 'submission.csv'
        submission_df.to_csv(submission_path, index=False)

        print("\n‚úÖ Submission elk√©sz√ºlt!")
        print(f"üíæ Mentve: {submission_path}")
        print(f"üìä {len(submission_df)} predikci√≥")

        # Distribution of predicted classes.
        pred_counts = submission_df['predicted_class'].value_counts().sort_index()
        print("\nüìà Predikci√≥ eloszl√°s:")
        for class_id, count in pred_counts.items():
            class_name = CLASS_ID_TO_NAME.get(class_id, 'unknown')
            percentage = (count / len(submission_df)) * 100
            print(f"   ‚Ä¢ {class_name}: {count} ({percentage:.1f}%)")

        # Preview of the first rows.
        print("\nüëÄ Submission el≈ën√©zet (els≈ë 10 sor):")
        print(submission_df.head(10).to_string(index=False))
else:
    print("\n‚ö†Ô∏è Test set nem el√©rhet≈ë - submission f√°jl nem k√©sz√ºlt.")

## 9. Test Set Inference és Submission

In [None]:
# =============== RESULTS SUMMARY ===============
# Aggregates the per-fold best Chlorella F0.5 scores from `fold_results`
# (filled by the training cell), writes them to reports/training_summary.json
# and prints a short report.

import json

if fold_results:
    # Each fold result must provide 'best_f0_5' and 'best_epoch'.
    f0_5_scores = [result['best_f0_5'] for result in fold_results]
    best_epochs = [result['best_epoch'] for result in fold_results]

    # Cast to built-in float/int so json can serialize NumPy scalars.
    # np.std uses ddof=0 (population standard deviation) by default.
    summary = {
        'num_folds': len(fold_results),
        'architecture': CONFIG['model']['architecture'],
        'avg_f0_5': float(np.mean(f0_5_scores)),
        'std_f0_5': float(np.std(f0_5_scores)),
        'min_f0_5': float(np.min(f0_5_scores)),
        'max_f0_5': float(np.max(f0_5_scores)),
        'fold_scores': [float(score) for score in f0_5_scores],
        'fold_best_epochs': [int(epoch) for epoch in best_epochs]
    }

    reports_dir = ensure_dir(output_dir / 'reports')
    summary_path = reports_dir / 'training_summary.json'

    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)

    print("\n" + "="*60)
    print("üìä TRAINING EREDM√âNYEK √ñSSZEFOGLAL√ÅSA")
    print("="*60)
    print(f"\nüèóÔ∏è  Architekt√∫ra: {summary['architecture']}")
    print(f"üìÅ Fold-ok sz√°ma: {summary['num_folds']}")
    print("\nüéØ Chlorella F0.5 Score:")
    print(f"   ‚Ä¢ √Åtlag: {summary['avg_f0_5']:.4f} ¬± {summary['std_f0_5']:.4f}")
    print(f"   ‚Ä¢ Min:   {summary['min_f0_5']:.4f}")
    print(f"   ‚Ä¢ Max:   {summary['max_f0_5']:.4f}")
    print("\nüìà Fold-onk√©nti eredm√©nyek:")
    for fold_id, (score, epoch) in enumerate(zip(f0_5_scores, best_epochs)):
        print(f"   ‚Ä¢ Fold {fold_id}: F0.5 = {score:.4f} (epoch {epoch})")

    print(f"\nüíæ Eredm√©nyek mentve: {summary_path}")
    print("="*60)
else:
    # np.min / np.max raise on an empty list, so report and skip instead.
    print("\nNo fold results available - training summary was not created.")

## 8. Eredmények Összegzése