In [1]:
import time
from typing import Tuple, Dict

import numpy as np
import pandas as pd
import torch
from torch import optim, nn
from tqdm import tqdm

from src.data_utils import (
    DataConfig,
    build_train_loader, build_test_loader, build_unlabeled_loader
)
from src.train_utils import (
    TrainConfig, make_optimizer,
    train_one_epoch, eval_one_epoch, eval_on_gold
)
from src.model import (
    build_resnet_binary, unfreeze_for_finetune, load_checkpoint
)

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
def fit(
    train_loader,
    val_loader,
    cfg: TrainConfig,
    device: torch.device | None = None,
    model_name : str = "resnet50"
) -> Tuple[nn.Module, Dict[str, float]]:
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    save_path = f"best_{model_name}_binary.pt"

    model = build_resnet_binary(model_name, freeze_backbone=True)
    model.to(device)

    # Loss (pos_weight полезен при дисбалансе)
    if cfg.pos_weight is not None:
        pos_w = torch.tensor([cfg.pos_weight], dtype=torch.float32, device=device)
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_w)
    else:
        criterion = nn.BCEWithLogitsLoss()

    scaler = torch.cuda.amp.GradScaler() if (cfg.use_amp and device.type == "cuda") else None

    best = {"val_f1": -1.0, "val_loss": float("inf")}
    best_state = None

    # --- Stage 1: train head only
    optimizer = make_optimizer(model, lr=cfg.lr_head, weight_decay=cfg.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(cfg.epochs_head, 1))

    print(f"Device: {device}. Stage1(head): epochs={cfg.epochs_head}, lr={cfg.lr_head}")
    for epoch in range(cfg.epochs_head):
        t0 = time.time()
        tr = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler, cfg)
        va = eval_one_epoch(model, val_loader, criterion, device, cfg)
        scheduler.step()

        print(f"[head {epoch+1:02d}/{cfg.epochs_head}] "
              f"train loss={tr['loss']:.4f} f1={tr['f1']:.3f} | "
              f"val loss={va['loss']:.4f} f1={va['f1']:.3f} "
              f"({time.time()-t0:.1f}s)")

        if va["f1"] > best["val_f1"]:
            best["val_f1"] = va["f1"]
            best["val_loss"] = va["loss"]
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    # --- Stage 2: finetune верхние слои
    unfreeze_for_finetune(model, cfg)
    optimizer = make_optimizer(model, lr=cfg.lr_finetune, weight_decay=cfg.weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(cfg.epochs_finetune, 1))

    print(f"\nStage2(finetune): epochs={cfg.epochs_finetune}, lr={cfg.lr_finetune}, unfreeze={cfg.unfreeze}")
    for epoch in range(cfg.epochs_finetune):
        t0 = time.time()
        tr = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler, cfg)
        va = eval_one_epoch(model, val_loader, criterion, device, cfg)
        scheduler.step()

        print(f"[ft   {epoch+1:02d}/{cfg.epochs_finetune}] "
              f"train loss={tr['loss']:.4f} f1={tr['f1']:.3f} | "
              f"val loss={va['loss']:.4f} f1={va['f1']:.3f} "
              f"({time.time()-t0:.1f}s)")

        if va["f1"] > best["val_f1"]:
            best["val_f1"] = va["f1"]
            best["val_loss"] = va["loss"]
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    # restore + save best
    if best_state is not None:
        model.load_state_dict(best_state)
        torch.save(model.state_dict(), save_path)
        print(f"\nSaved best checkpoint to: {save_path} (best val_f1={best['val_f1']:.3f})")
    else:
        print("\nWarning: best_state is None (unexpected)")

    return model, best

In [4]:
data_cfg = DataConfig(
    image_size=224,
    batch_size=384,
    num_workers=12,
)

train_loader, val_loader, train_df, val_df, info = build_train_loader(
    "dataset/metadata_processed_clean.csv",
    data_cfg,
    val_ratio=0.1,
    seed=42,
    use_weighted_sampler=True
)

gold_loader, gold_df = build_test_loader(
    "golden/metadata_processed_clean.csv",
    data_cfg,
    return_meta=True,
)

In [5]:
train_cfg = TrainConfig(
    epochs_head=20,
    epochs_finetune=20,
    lr_head=1e-3,
    lr_finetune=3e-4,
    weight_decay=1e-2,
    use_amp=True,
    pos_weight=None,
    unfreeze=True,
    unfreeze_layer4=True
)

#model, best = fit(train_loader, val_loader, train_cfg, model_name="resnet50")

In [6]:
#gold_metrics = eval_on_gold(model, gold_loader, device, threshold=0.5)
#gold_metrics

In [7]:
@torch.no_grad()
def predict_probs(model: nn.Module, loader, device: torch.device):
    model.eval()
    all_probs = []
    all_ids = []
    all_paths = []
    all_urls = []

    for x, item_id, path, url in tqdm(loader, desc="infer", leave=False):
        x = x.to(device, non_blocking=True)
        logits = model(x).view(-1)
        probs = torch.sigmoid(logits).cpu().numpy()

        all_probs.append(probs)
        all_ids.extend(list(item_id))
        all_paths.extend(list(path))
        all_urls.extend(list(url))

    probs = np.concatenate(all_probs, axis=0)
    return probs, all_ids, all_paths, all_urls


def make_pseudo_labels(
    unlabeled_meta_csv: str,
    out_csv: str,
    checkpoints: dict,
    batch_size: int = 128,
    num_workers: int = 8
):
    """
    checkpoints: dict вида
      {
        "resnet18": "path/to/best_resnet18_binary.pt",
        "resnet34": "path/to/best_resnet34_binary.pt",
        "resnet50": "path/to/best_resnet50_binary.pt",
      }
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    loader, _ = build_unlabeled_loader(
        unlabeled_meta_csv,
        data_cfg,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    result_df = None

    for arch, ckpt in checkpoints.items():
        model = build_resnet_binary(arch)
        model = load_checkpoint(model, ckpt, device)
        model.to(device)

        probs, ids, paths, urls = predict_probs(model, loader, device)

        col = f"prob_{arch}"
        cur = pd.DataFrame({
            "id": ids,
            "processed_path": paths,
            "url": urls,
            col: probs.astype(np.float32),
        })

        if result_df is None:
            result_df = cur
        else:
            # смёржим по processed_path (стабильнее), можно и по id
            result_df = result_df.merge(cur[["processed_path", col]], on="processed_path", how="left")

    result_df.to_csv(out_csv, index=False)
    print(f"Saved: {out_csv} (rows={len(result_df)})")
    return result_df

In [8]:
checkpoints = {
    "resnet18": "best_resnet18_binary.pt",
    "resnet34": "best_resnet34_binary.pt",
    "resnet50": "best_resnet50_binary.pt",
}

pseudo_df = make_pseudo_labels(
    unlabeled_meta_csv="unlabeled/metadata_processed_clean.csv",
    out_csv="unlabeled/pseudo_labels.csv",
    checkpoints=checkpoints,
    batch_size=128,
    num_workers=8
)

pseudo_df.head()

                                                        

Saved: unlabeled/pseudo_labels.csv (rows=33044)


Unnamed: 0,id,processed_path,url,prob_resnet18,prob_resnet34,prob_resnet50
0,tensor(0),unlabeled/processed/id_0_9e38212d63.jpg,https://new-projects-team-public.s3.yandex.net...,0.933317,0.982012,0.999828
1,tensor(1),unlabeled/processed/id_1_1fb97f0d7d.jpg,https://new-projects-team-public.s3.yandex.net...,0.649001,0.745515,0.973895
2,tensor(2),unlabeled/processed/id_2_187cb66874.jpg,https://new-projects-team-public.s3.yandex.net...,0.151606,0.459432,0.22005
3,tensor(3),unlabeled/processed/id_3_0d7763cddd.jpg,https://new-projects-team-public.s3.yandex.net...,0.874225,0.870741,0.99023
4,tensor(4),unlabeled/processed/id_4_d696a6cfe9.jpg,https://new-projects-team-public.s3.yandex.net...,9.7e-05,1.3e-05,0.005488
