In [None]:
import os
import re
import copy
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

from torchvision import models, transforms
from torch.utils.data import DataLoader, ConcatDataset, Dataset
from torchvision.datasets import VOCSegmentation, VOCDetection
from torchvision.models.segmentation import deeplabv3_resnet50
from torchvision.models.detection import fasterrcnn_resnet50_fpn

import albumentations as A
from albumentations.pytorch import ToTensorV2

from torchmetrics import JaccardIndex
from torchvision.ops import box_iou

# Runtime configuration: compute device, dataset root and the BYOL checkpoint location.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
ROOT = "VOCdevkit"
CHECKPOINT = "checkpoints/resnet50_byol_imagenet2012.pth.tar"

# The 20 Pascal VOC object categories, in the canonical (alphabetical) VOC order.
VOC_CLASSES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]
# Category name -> 0-based index into VOC_CLASSES.
CLASS_TO_IDX = dict(zip(VOC_CLASSES, range(len(VOC_CLASSES))))

def smart_load_byol_checkpoint_to_resnet50(ckpt_path: str, device='cpu'):
    """Build a torchvision ResNet-50 trunk (``fc`` = Identity) from a BYOL checkpoint.

    Steps:
    1. Unpickle the checkpoint.
    2. Take ``'online_backbone'`` if present, otherwise the top level.
    3. Look inside a few common wrapper keys for the actual state dict.
    4. Strip wrapper prefixes such as ``module.``, repeatedly for nested wrappers.
    5. Drop projector/predictor heads and ``num_batches_tracked`` buffers.
    6. Load the remaining tensors into ``torchvision.models.resnet50``.

    Args:
        ckpt_path: path to the checkpoint file.
        device: device the returned model is moved to (default ``'cpu'``).

    Returns:
        ``torchvision.models.resnet50`` with ``fc`` replaced by ``nn.Identity``.

    Raises:
        AssertionError: if ``ckpt_path`` does not exist.
        RuntimeError: if any of the following holds. A silently random-initialised "BYOL"
            backbone would invalidate the comparison.
            - No tensor state dict can be found in the checkpoint.
            - A tensor has the wrong shape.
            - Any ResNet-50 parameter/buffer is not provided by the checkpoint.
    """
    ckpt_path = str(ckpt_path)
    assert os.path.exists(ckpt_path), f"Checkpoint not found: {ckpt_path}"
    # NOTE(security): weights_only=False unpickles arbitrary Python objects - only load
    # checkpoints from a trusted source.
    raw = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    # The checkpoint format targeted here keeps the backbone under 'online_backbone';
    # fall back to the top level for checkpoints saved without that wrapper.
    ckpt = raw.get('online_backbone', raw) if isinstance(raw, dict) else raw

    candidate = ckpt
    if isinstance(ckpt, dict):
        for k in ('state_dict', 'model', 'net', 'online_encoder', 'online_network', 'encoder', 'backbone'):
            if k in ckpt:
                candidate = ckpt[k]
                break

    if not (isinstance(candidate, dict) and any(isinstance(v, torch.Tensor) for v in candidate.values())):
        # Message (in Russian): could not extract a state_dict; see the conversion utility
        # in the yaox12/BYOL-PyTorch repository.
        raise RuntimeError("Не удалось извлечь state_dict из checkpoint (структура неожиданна). "
                           "Посмотрите yaox12/utils/load_and_convert.py в репо для инструкции. "
                           "Ссылка: https://github.com/yaox12/BYOL-PyTorch (см. utils).")

    prefixes = ('module.', 'online_encoder.', 'online_network.', 'net.', 'encoder.')

    def strip_prefix(key):
        # Strip repeatedly so nested wrappers such as 'module.encoder.' are fully removed.
        while True:
            for p in prefixes:
                if key.startswith(p):
                    key = key[len(p):]
                    break
            else:
                return key

    new_sd = {}
    for k, v in candidate.items():
        newk = strip_prefix(k)
        # BYOL heads are not part of the ResNet trunk; batch counters are irrelevant.
        if newk.startswith(('predictor', 'projector')) or 'num_batches_tracked' in newk:
            continue
        new_sd[newk] = v

    model = models.resnet50(weights=None)
    model.fc = nn.Identity()

    # strict=False tolerates leftover non-backbone tensors (reported below); shape
    # mismatches still raise RuntimeError from load_state_dict.
    result = model.load_state_dict(new_sd, strict=False)
    missing = [k for k in result.missing_keys if 'num_batches_tracked' not in k]
    if missing:
        raise RuntimeError(
            f"BYOL checkpoint does not provide {len(missing)} ResNet-50 tensors "
            f"(e.g. {missing[:5]}); the checkpoint keys probably need a different prefix mapping."
        )
    print("Loaded checkpoint into torchvision resnet50 (strict=False).")
    if result.unexpected_keys:
        print(f"Ignored {len(result.unexpected_keys)} unexpected checkpoint keys, "
              f"e.g. {result.unexpected_keys[:5]}")
    return model.to(device)


In [None]:
class VOCSegDataset(Dataset):
    """Pascal VOC semantic-segmentation dataset yielding ``(image, mask)`` pairs.

    Wraps ``torchvision.datasets.VOCSegmentation`` (no download). Image and mask are
    converted to numpy arrays. When albumentations-style ``transforms`` are given, they are
    called as ``transforms(image=..., mask=...)``. The transformed mask is cast to int64.
    Without transforms the raw numpy arrays are returned.
    """

    def __init__(self, root, year='2007', image_set='trainval', transforms=None):
        self.ds = VOCSegmentation(root=root, year=year, image_set=image_set, download=False)
        self.transforms = transforms

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        pil_image, pil_mask = self.ds[idx]
        image, mask = np.array(pil_image), np.array(pil_mask)
        if not self.transforms:
            return image, mask
        transformed = self.transforms(image=image, mask=mask)
        return transformed['image'], transformed['mask'].long()

class VOCDatasetForDet(Dataset):
    """Pascal VOC detection dataset producing torchvision-detection style targets.

    Each item is ``(image, target)``. ``target`` has these keys:
    - ``boxes``: float32 ``(N, 4)`` xyxy pixel boxes.
    - ``labels``: int64, 1-based indices into ``VOC_CLASSES``; 0 is left for background.
    - ``image_id``: the dataset index.
    - ``area``.
    - ``iscrowd``: always 0.
    - ``difficult``: bool, the VOC "difficult" flag, used by VOC-style evaluation.

    ``transforms`` is applied to the image only. Geometric augmentation that must also move
    the boxes is therefore done here through ``hflip_prob``.
    """

    def __init__(self, root, year='2007', image_set='trainval', transforms=None, hflip_prob=0.0):
        self.ds = VOCDetection(root=root, year=year, image_set=image_set, download=False)
        self.transforms = transforms
        # Probability of a horizontal flip applied jointly to the image and its boxes.
        # (An image-only flip inside `transforms` would leave the boxes un-flipped.)
        self.hflip_prob = hflip_prob

    def __len__(self):
        return len(self.ds)

    @staticmethod
    def _hflip(img):
        """Mirror ``img`` left-right; return ``(flipped_img, width)``.

        Supports torch tensors (width = last dim) and PIL images.
        """
        if isinstance(img, torch.Tensor):
            return img.flip(-1), img.shape[-1]
        # Pillow >= 9.1 exposes the constant on Image.Transpose; older versions on Image.
        flip_lr = getattr(Image, 'Transpose', Image).FLIP_LEFT_RIGHT
        return img.transpose(flip_lr), img.width

    def __getitem__(self, idx):
        img, ann = self.ds[idx]
        objs = ann['annotation'].get('object', [])
        if isinstance(objs, dict):  # a single <object> is parsed as a dict, not a list
            objs = [objs]
        boxes, labels, difficult = [], [], []
        for o in objs:
            bb = o['bndbox']
            boxes.append([float(bb['xmin']), float(bb['ymin']), float(bb['xmax']), float(bb['ymax'])])
            labels.append(CLASS_TO_IDX[o['name']] + 1)
            difficult.append(int(o.get('difficult', 0)) == 1)
        # reshape keeps the (0, 4) shape for images without objects
        boxes = torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        if self.hflip_prob > 0 and torch.rand(1).item() < self.hflip_prob:
            img, width = self._hflip(img)
            boxes = torch.stack(
                [width - boxes[:, 2], boxes[:, 1], width - boxes[:, 0], boxes[:, 3]], dim=1
            )

        target = {
            'boxes': boxes,
            'labels': torch.tensor(labels, dtype=torch.int64),
            'image_id': torch.tensor([idx]),
            'area': (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]),
            'iscrowd': torch.zeros(len(labels), dtype=torch.int64),
            'difficult': torch.tensor(difficult, dtype=torch.bool),
        }
        if self.transforms:
            img = self.transforms(img)
        return img, target

def det_collate_fn(batch):
    """Collate detection samples into ``(list_of_images, list_of_targets)``.

    Images can have different sizes, so they are kept as a list instead of being stacked,
    which is the input format torchvision detection models take.
    """
    images, targets = [], []
    for sample in batch:
        images.append(sample[0])
        targets.append(sample[1])
    return images, targets

def seg_collate_fn(batch):
    """Stack variable-size (CHW image, HW mask) pairs into padded batch tensors.

    Every sample is padded on the right/bottom up to the largest height and width in the
    batch. Images are padded with 0 and masks with 255, the label ignored by the loss and
    the mIoU metric.
    """
    images, masks = zip(*batch)
    target_h = max(image.shape[1] for image in images)
    target_w = max(image.shape[2] for image in images)

    def pad_spec(image):
        # F.pad order: (left, right) for the last dim, then (top, bottom) for the one before.
        return (0, target_w - image.shape[2], 0, target_h - image.shape[1])

    batch_images = torch.stack(
        [F.pad(image, pad_spec(image), mode='constant', value=0) for image in images], 0
    )
    batch_masks = torch.stack(
        [F.pad(mask, pad_spec(image), mode='constant', value=255)
         for image, mask in zip(images, masks)], 0
    )
    return batch_images, batch_masks

In [None]:
# ImageNet channel statistics, matching the ImageNet-pretrained backbones used here.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Segmentation: albumentations applies geometric ops jointly to image and mask.
# (albumentations has no `RandomHorizontalFlip`; its flip transform is `HorizontalFlip`.)
SEG_TRAIN_TRANSFORM = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])
SEG_TEST_TRANSFORM = A.Compose([
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

# Detection: these torchvision transforms only ever see the image, never its boxes.
# An image-only RandomHorizontalFlip would therefore mirror the picture while leaving the
# boxes in place, corrupting the labels of about half the training samples. Any flip
# must be applied together with the boxes. No Normalize here: torchvision detection
# models normalise their inputs internally.
DET_TRAIN_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
])
DET_TEST_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
])

In [None]:
def build_deeplab_with_backbone(backbone: nn.Module, num_classes: int = 21):
    """Create a randomly initialised DeepLabV3-ResNet50 and copy ``backbone``'s weights in.

    Every tensor of ``backbone.state_dict()`` whose name and shape match an entry of the
    DeepLab backbone is transferred. Other entries, such as the classifier head, keep their
    fresh initialisation.

    Raises:
        RuntimeError: if, after loading, ``conv1.weight`` differs from ``backbone``'s.
    """
    model = deeplabv3_resnet50(weights=None, weights_backbone=None, num_classes=num_classes)
    merged = model.backbone.state_dict()
    transferable = {
        name: tensor
        for name, tensor in backbone.state_dict().items()
        if name in merged and merged[name].shape == tensor.shape
    }
    merged.update(transferable)
    model.backbone.load_state_dict(merged, strict=False)

    # Sanity check on the first convolution: it must now equal the source weights.
    reference = backbone.conv1.weight
    with torch.no_grad():
        loaded = model.backbone.state_dict()["conv1.weight"].to(reference.device)
        weights_match = torch.allclose(reference, loaded)
    if not weights_match:
        raise RuntimeError("Backbone weights were NOT loaded correctly")
    return model

def build_fasterrcnn_with_backbone(backbone: nn.Module, num_classes: int = 21):
    """Create a Faster R-CNN ResNet-50-FPN whose ResNet trunk is copied from ``backbone``.

    The detector is built without any pretrained weights. ``backbone``'s state dict is
    loaded into ``model.backbone.body``, the ResNet part feeding the FPN. Extra source
    tensors, e.g. a classification head, are ignored.

    Raises:
        RuntimeError: if any tensor of the detector's ResNet body is not provided by
            ``backbone``, or on a shape mismatch (raised by ``load_state_dict``). Previously
            such failures were swallowed and training silently started from random weights.
    """
    # `weights` / `weights_backbone` (same API as build_deeplab_with_backbone) replace the
    # deprecated `pretrained` / `pretrained_backbone` flags. weights_backbone must be passed
    # explicitly: the backbone is initialised from `backbone` below, not from ImageNet.
    model = fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None, num_classes=num_classes)
    result = model.backbone.body.load_state_dict(backbone.state_dict(), strict=False)
    if result.missing_keys:
        raise RuntimeError(
            f"Backbone is missing {len(result.missing_keys)} tensors required by the "
            f"Faster R-CNN body, e.g. {result.missing_keys[:5]}"
        )
    return model

In [None]:
def get_lr(step):
    """Learning rate for optimiser step ``step``: linear warm-up followed by step decay.

    - ``step < LR_WARMUP_ITERS``: linear ramp from ``BASE_LR * LR_WARMUP_FACTOR`` to ``BASE_LR``.
    - afterwards: ``BASE_LR`` multiplied by ``STEP_LR_GAMMA`` once for every milestone in
      ``STEP_LR_MILESTONES`` (assumed ascending) that ``step`` has reached. Any number of
      milestones is supported, including none.

    The module-level hyper-parameters are read at call time.
    """
    from bisect import bisect_right

    if step < LR_WARMUP_ITERS:
        alpha = step / LR_WARMUP_ITERS
        return BASE_LR * (LR_WARMUP_FACTOR * (1 - alpha) + alpha)
    lr = BASE_LR
    # Multiply once per passed milestone (not gamma ** n) so the result equals the chained
    # product BASE_LR * gamma * gamma ... exactly.
    for _ in range(bisect_right(STEP_LR_MILESTONES, step)):
        lr *= STEP_LR_GAMMA
    return lr

In [None]:
def evaluate_segmentation_miou(model, test_loader, device, num_classes=21, ignore_index=255):
    """Mean IoU (torchmetrics ``JaccardIndex``, multiclass) of ``model`` over ``test_loader``.

    ``model(images)['out']`` must return per-class logits. Predictions are the argmax over
    the class dimension. Target pixels equal to ``ignore_index`` are excluded.
    Leaves the model in eval mode; callers that keep training must switch it back.
    """
    model.eval()
    jaccard = JaccardIndex(task='multiclass', num_classes=num_classes, ignore_index=ignore_index)
    jaccard = jaccard.to(device)
    with torch.no_grad():
        for batch_images, batch_masks in tqdm(test_loader, desc="Seg Eval"):
            logits = model(batch_images.to(device))['out']
            jaccard.update(logits.argmax(1), batch_masks.to(device))
    return jaccard.compute().item()

In [None]:
def train_segmentation_iters(
    model, train_loader, val_loader, device,
    max_iters=24000, accumulation_steps=4,
    get_lr=None, val_every=2000,
):
    """Iteration-based SGD training of a segmentation model with gradient accumulation.

    Each optimiser step consumes ``accumulation_steps`` batches from ``train_loader``; the
    loader is restarted whenever it is exhausted. ``model(images)`` must return a dict with
    an ``'out'`` logits tensor. Pixels labelled 255 are ignored by the loss.

    Args:
        model: segmentation network, e.g. the DeepLabV3 built in this notebook.
        train_loader: loader yielding ``(images, masks)`` training batches.
        val_loader: loader used for the periodic mIoU evaluation.
        device: device to train on. Mixed precision (autocast + GradScaler) is used only
            when this is a CUDA device.
        max_iters: number of optimiser steps.
        accumulation_steps: batches whose gradients are accumulated per optimiser step.
        get_lr: callable ``step -> learning rate``; when None the constant ``BASE_LR`` is used.
        val_every: evaluate every this many steps; the final step is always evaluated.

    Returns:
        ``(model, (train_losses, val_mious))``. ``train_losses`` holds one accumulated loss
        per step. ``val_mious`` is a list of ``(step, mIoU)`` pairs.
    """
    model = model.to(device)
    model.train()
    criterion = nn.CrossEntropyLoss(ignore_index=255)
    optimizer = torch.optim.SGD(model.parameters(), lr=BASE_LR, momentum=MOMENTUM, weight_decay=1e-4)
    # AMP must follow the requested device, not mere CUDA availability: training on CPU on a
    # CUDA machine would otherwise feed CPU losses to a CUDA GradScaler.
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler() if use_amp else None
    lr_fn = get_lr if get_lr is not None else (lambda _step: BASE_LR)

    train_iter = iter(train_loader)
    train_losses = []
    val_mious = []
    step = 0
    with tqdm(total=max_iters, desc='Segmentation train') as pbar:
        while step < max_iters:
            # The LR depends only on the optimiser step, so set it once per step.
            lr_now = lr_fn(step)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_now
            optimizer.zero_grad()
            loss_accum = 0.0
            for _ in range(accumulation_steps):
                try:
                    imgs, masks = next(train_iter)
                except StopIteration:
                    train_iter = iter(train_loader)
                    imgs, masks = next(train_iter)
                imgs = imgs.to(device)
                masks = masks.to(device)
                with torch.cuda.amp.autocast(enabled=use_amp):
                    output = model(imgs)['out']
                    # Divide so the accumulated gradient equals the mean over the big batch.
                    loss = criterion(output, masks) / accumulation_steps
                if scaler is not None:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                loss_accum += loss.item()
            if scaler is not None:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            train_losses.append(loss_accum)
            step += 1
            pbar.update(1)
            pbar.set_postfix(loss=loss_accum)
            if step % val_every == 0 or step == max_iters:
                miou = evaluate_segmentation_miou(model, val_loader, device, num_classes=21)
                model.train()  # evaluation switched the model to eval mode
                val_mious.append((step, miou))
    return model, (train_losses, val_mious)

In [None]:
def _voc07_gather_class_gt(gt, label):
    """Select the ground truth of one class (1-based ``label``) from the per-image ``gt``.

    Returns ``(class_gt, npos)``:
    - ``class_gt[img_id]`` holds that image's boxes and difficult flags for the class.
    - ``npos`` is the number of non-difficult instances, the recall denominator.
    """
    class_gt, npos = {}, 0
    for img_id, entry in gt.items():
        keep = entry['labels'] == label
        difficult = entry['difficult'][keep]
        class_gt[img_id] = {'boxes': entry['boxes'][keep], 'difficult': difficult}
        npos += int((~difficult).sum())
    return class_gt, npos


def _voc07_class_ap(preds, class_gt, npos, iou_thresh):
    """VOC2007 11-point interpolated AP of one class at one IoU threshold.

    ``preds`` are dicts with ``'img_id'``, ``'score'`` and ``'box'`` (xyxy tensor). Match
    flags are created fresh on every call, so matches made at one IoU threshold never leak
    into another. Returns 0.0 when the class has no non-difficult ground truth.
    """
    if npos == 0:
        return 0.0
    preds = sorted(preds, key=lambda p: p['score'], reverse=True)
    detected = {img_id: [False] * len(entry['boxes']) for img_id, entry in class_gt.items()}
    tp = np.zeros(len(preds))
    fp = np.zeros(len(preds))
    for i, p in enumerate(preds):
        entry = class_gt.get(p['img_id'])
        if entry is None or len(entry['boxes']) == 0:
            fp[i] = 1
            continue
        ious = box_iou(p['box'].unsqueeze(0), entry['boxes'])[0]
        best = int(ious.argmax())
        if ious[best].item() < iou_thresh:
            fp[i] = 1
        elif entry['difficult'][best]:
            # VOC protocol: a match to a "difficult" object is neither a TP nor an FP.
            continue
        elif not detected[p['img_id']][best]:
            tp[i] = 1
            detected[p['img_id']][best] = True
        else:
            fp[i] = 1  # duplicate detection of an already-matched object
    tp_c = np.cumsum(tp)
    fp_c = np.cumsum(fp)
    rec = tp_c / float(npos)
    prec = tp_c / np.maximum(tp_c + fp_c, np.finfo(np.float64).eps)
    ap = 0.0
    for t in np.linspace(0, 1, 11):
        p_vals = prec[rec >= t]
        ap += (np.max(p_vals) if p_vals.size > 0 else 0.0) / 11.0
    return ap


def voc2007_evaluate(model, data_loader, device, iou_threshs=(0.5, 0.75, 1.0), score_thresh=0.05):
    """Pascal-VOC2007-style detection mAP (11-point AP) at several IoU thresholds.

    Args:
        model: torchvision-style detector; in eval mode it returns, per image, a dict with
            ``boxes`` / ``labels`` (1-based) / ``scores``.
        data_loader: yields ``(images, targets)``. Targets carry ``boxes``, ``labels``,
            ``image_id`` and optionally a boolean ``difficult`` tensor. Difficult objects are
            ignored as in the VOC protocol; without the key all objects count.
        device: device the images are moved to.
        iou_threshs: IoU thresholds; a detection is a TP when its best IoU is >= threshold.
        score_thresh: detections scoring below this are discarded.

    Returns:
        List of mAP values (mean AP over ``VOC_CLASSES``), one per IoU threshold, in order.
        The model is left in eval mode.
    """
    model.eval()
    num_classes = len(VOC_CLASSES)
    gt = {}
    preds_by_class = {c: [] for c in range(num_classes)}
    with torch.no_grad():
        for imgs, targets in tqdm(data_loader, desc="Gather predictions"):
            outputs = model([im.to(device) for im in imgs])
            for out, t in zip(outputs, targets):
                img_id = int(t['image_id'].item())
                gt_labels = t['labels'].cpu()
                difficult = t.get('difficult')
                if difficult is None:
                    difficult = torch.zeros(len(gt_labels), dtype=torch.bool)
                gt[img_id] = {
                    'boxes': t['boxes'].cpu().float(),
                    'labels': gt_labels,
                    'difficult': difficult.cpu().bool(),
                }
                boxes = out.get('boxes', torch.empty((0, 4))).cpu().float()
                labels = out.get('labels', torch.empty((0,), dtype=torch.int64)).cpu()
                scores = out.get('scores', torch.empty((0,))).cpu()
                for b, lab, s in zip(boxes, labels, scores):
                    cls = int(lab.item()) - 1  # model labels are 1-based (0 = background)
                    if cls < 0 or cls >= num_classes or s < score_thresh:
                        continue
                    preds_by_class[cls].append({'img_id': img_id, 'score': float(s.item()), 'box': b})

    # Ground truth per class is threshold-independent: gather it once.
    class_gts = [_voc07_gather_class_gt(gt, cls + 1) for cls in range(num_classes)]
    mAPs = []
    for iou_thresh in iou_threshs:
        ap_per_class = [
            _voc07_class_ap(preds_by_class[cls], class_gt, npos, iou_thresh)
            for cls, (class_gt, npos) in enumerate(class_gts)
        ]
        mAPs.append(float(np.mean(ap_per_class)))
    return mAPs

def train_detection_iters(
    model, train_loader, val_loader, device,
    max_iters=24000, accumulation_steps=4,
    get_lr=None, val_every=2000
):
    """Iteration-based SGD training of a torchvision-style detector with gradient accumulation.

    In train mode ``model(images, targets)`` must return a dict of losses, which are summed.
    Each optimiser step consumes ``accumulation_steps`` batches; the loader is restarted when
    exhausted.

    Args:
        model: detector (e.g. the Faster R-CNN built in this notebook).
        train_loader: yields ``(list_of_images, list_of_target_dicts)``.
        val_loader: passed to ``voc2007_evaluate`` for the periodic mAP evaluation.
        device: device to train on. AMP (autocast + GradScaler) is used only on CUDA devices.
        max_iters: number of optimiser steps.
        accumulation_steps: batches accumulated per optimiser step.
        get_lr: callable ``step -> learning rate``; when None the constant ``BASE_LR`` is used.
        val_every: evaluate every this many steps; the final step is always evaluated, as in
            ``train_segmentation_iters``.

    Returns:
        ``(model, losses_per_step, val_maps)``, where ``val_maps`` is a list of
        ``(step, [mAP per IoU threshold])`` pairs.
    """
    model.to(device)
    model.train()
    optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad],
                                lr=BASE_LR, momentum=MOMENTUM, weight_decay=1e-4)
    # AMP must follow the requested device, not mere CUDA availability.
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler() if use_amp else None
    lr_fn = get_lr if get_lr is not None else (lambda _step: BASE_LR)

    losses_per_step = []
    val_maps = []
    step = 0
    train_iter = iter(train_loader)
    with tqdm(total=max_iters, desc='Detection train') as pbar:
        while step < max_iters:
            lr_now = lr_fn(step)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_now
            optimizer.zero_grad()
            loss_accum = 0.0
            for _ in range(accumulation_steps):
                try:
                    imgs, targets = next(train_iter)
                except StopIteration:
                    train_iter = iter(train_loader)
                    imgs, targets = next(train_iter)
                imgs = [im.to(device) for im in imgs]
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
                with torch.cuda.amp.autocast(enabled=use_amp):
                    loss_dict = model(imgs, targets)
                    loss = sum(loss_dict.values()) / accumulation_steps
                if scaler is not None:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                loss_accum += loss.item()
            if scaler is not None:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            losses_per_step.append(loss_accum)
            step += 1
            pbar.update(1)
            pbar.set_postfix(loss=loss_accum)
            if step % val_every == 0 or step == max_iters:
                # voc2007_evaluate switches to eval mode and disables autograd itself.
                maps = voc2007_evaluate(model, val_loader, device)
                val_maps.append((step, maps))
                model.train()
    return model, losses_per_step, val_maps

In [None]:
def show_segmentation_sample(
    img,
    gt_mask,
    pred_mask=None,
    class_colors=None,
    class_names=None,
    norm_mean=(0.485, 0.456, 0.406),
    norm_std=(0.229, 0.224, 0.225),
    ignore_index=255,
):
    """Plot an image next to its ground-truth mask and, optionally, a predicted mask.

    Args:
        img: normalised image, either a CHW tensor, as produced by the segmentation
            transforms, or an HWC array.
        gt_mask: HxW integer class map (tensor or array).
        pred_mask: optional HxW integer class map (tensor or array).
        class_colors: optional sequence of matplotlib colours, where colour ``k`` is used
            for class ``k``. By default 21 distinct colours are used: the 20 of tab20 plus
            one from tab20b.
        class_names: currently unused; kept for API compatibility.
        norm_mean, norm_std: per-channel statistics used to undo the normalisation.
        ignore_index: mask value (VOC object boundaries / collate padding) drawn as blank
            instead of being confused with a real class.
    """
    from matplotlib.colors import ListedColormap

    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
        if img.ndim == 3 and img.shape[0] == 3:
            img = img.transpose(1, 2, 0)  # CHW -> HWC for imshow
    img = np.clip(np.asarray(img) * np.asarray(norm_std) + np.asarray(norm_mean), 0, 1)

    def _to_masked(mask):
        if isinstance(mask, torch.Tensor):
            mask = mask.detach().cpu().numpy()
        return np.ma.masked_equal(np.asarray(mask), ignore_index)

    if class_colors is None:
        class_colors = list(plt.get_cmap('tab20').colors) + [plt.get_cmap('tab20b').colors[0]]
    cmap = ListedColormap(class_colors)
    # Centre each integer class on its own colour bin so that value k maps to colour k.
    mask_kwargs = dict(cmap=cmap, vmin=-0.5, vmax=cmap.N - 0.5, interpolation='nearest')

    panels = [('Image', img, {}), ('GT', _to_masked(gt_mask), mask_kwargs)]
    if pred_mask is not None:
        panels.append(('Prediction', _to_masked(pred_mask), mask_kwargs))

    fig, axes = plt.subplots(1, len(panels), figsize=(4 * len(panels), 4))
    for ax, (title, data, kwargs) in zip(axes, panels):
        ax.imshow(data, **kwargs)
        ax.set_title(title)
        ax.axis('off')
    plt.show()

def plot_pairwise_curves(results, metric_name='val_metric', title='Comparison', ylabel='mAP'):
    """Plot training-loss curves (left) and validation-metric curves (right) for each run.

    Args:
        results: mapping run name -> dict with ``'train_loss'`` (one loss per iteration),
            ``metric_name`` (metric values) and optionally ``f'{metric_name}_x'`` (the
            steps at which the metric was measured; defaults to ``0..len-1``).
        metric_name: key of the metric series inside each run's dict.
        title: prefix for both subplot titles.
        ylabel: label of the metric axis.
    """
    fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(10, 5))

    for run_name, run in results.items():
        losses = run['train_loss']
        loss_ax.plot(range(len(losses)), losses, label=f'{run_name} train loss')
    loss_ax.set_title(f"{title}\nTrain loss")
    loss_ax.set_xlabel('Iteration')
    loss_ax.set_ylabel('Loss')
    loss_ax.grid()
    loss_ax.legend()

    for run_name, run in results.items():
        values = run[metric_name]
        steps = run.get(f'{metric_name}_x', range(len(values)))
        metric_ax.plot(steps, values, marker='o', label=f'{run_name}')
    metric_ax.set_title(f"{title}\nTest {ylabel}")
    metric_ax.set_xlabel('Step')
    metric_ax.set_ylabel(ylabel)
    metric_ax.legend()

    fig.tight_layout()
    metric_ax.grid()
    plt.show()

def save_results_json(filename, results):
    """Write ``results`` to ``filename`` as pretty-printed UTF-8 JSON (non-ASCII kept as-is).

    The data is serialised before the file is touched. It is then written to a temporary
    sibling file that atomically replaces ``filename``. A serialisation error (e.g. a
    non-JSON value) or an interrupted write therefore leaves any previous results file
    intact instead of truncating it.

    Raises:
        TypeError / ValueError: if ``results`` is not JSON-serialisable.
    """
    payload = json.dumps(results, ensure_ascii=False, indent=4)
    tmp_path = f"{filename}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            f.write(payload)
        os.replace(tmp_path, filename)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

In [None]:
# Optimisation schedule, read by get_lr and the training loops.
BASE_LR = 0.01
MOMENTUM = 0.9
LR_WARMUP_FACTOR = 0.333            # warm-up starts at BASE_LR * LR_WARMUP_FACTOR
LR_WARMUP_ITERS = 1000              # length of the linear warm-up, in optimiser steps
STEP_LR_MILESTONES = [8000, 10000]  # steps at which the LR is multiplied by STEP_LR_GAMMA
STEP_LR_GAMMA = 0.1

MAX_ITERS = 12000                   # total number of optimiser steps
EFFECTIVE_BATCH = 16                # images contributing to one optimiser step
REAL_BATCH = 8                      # DataLoader batch size (images per forward pass)
# Gradient accumulation turns REAL_BATCH-sized forward passes into one EFFECTIVE_BATCH
# step; a non-divisible pair would silently shrink the effective batch.
if EFFECTIVE_BATCH % REAL_BATCH != 0:
    raise ValueError(
        f"EFFECTIVE_BATCH ({EFFECTIVE_BATCH}) must be a multiple of REAL_BATCH ({REAL_BATCH})"
    )
accumulation_steps = EFFECTIVE_BATCH // REAL_BATCH
VAL_EVERY = 1000                    # evaluate every this many optimiser steps

In [None]:
TASK = "segmentation"   # или "detection"

MODE = "supervised"   # "supervised" или "byol"

results = {}

if TASK == "segmentation":
    seg_train_12 = VOCSegDataset(root=ROOT, year='2012', image_set='train', transforms=SEG_TRAIN_TRANSFORM)
    seg_train = ConcatDataset([seg_train_12])
    seg_test = VOCSegDataset(root=ROOT, year='2012', image_set='val', transforms=SEG_TEST_TRANSFORM)
    seg_train_loader = DataLoader(seg_train, batch_size=REAL_BATCH, shuffle=True, num_workers=4, collate_fn=seg_collate_fn, drop_last=True)
    seg_test_loader = DataLoader(seg_test, batch_size=REAL_BATCH, shuffle=False, num_workers=4, collate_fn=seg_collate_fn)

    if MODE == "byol":
        backbone = smart_load_byol_checkpoint_to_resnet50(CHECKPOINT)
    elif MODE == "supervised":
        backbone = models.resnet50(weights="IMAGENET1K_V1")
        backbone.fc = nn.Identity()
    else:
        raise ValueError()
    backbone = backbone.to(DEVICE)
    seg_model = build_deeplab_with_backbone(backbone, num_classes=21)
    seg_model.to(DEVICE)
    seg_model, (seg_train_losses, seg_val_mious) = train_segmentation_iters(
        seg_model, seg_train_loader, seg_test_loader,
        DEVICE,
        max_iters=MAX_ITERS,
        accumulation_steps=accumulation_steps,
        get_lr=get_lr,
        val_every=VAL_EVERY,
    )
    mious_x = [x[0] for x in seg_val_mious]
    mious_y = [x[1] for x in seg_val_mious]
    results[MODE] = {
        'seg_train_losses': seg_train_losses,
        'seg_val_mious': mious_y,
        'seg_val_mious_x': mious_x,
    }
    save_results_json("segmentation_results.json", results)

    plot_pairwise_curves(
        {MODE: {
            'train_loss': results[MODE]['seg_train_losses'],
            'val_metric': results[MODE]['seg_val_mious'],
            'val_metric_x': results[MODE]['seg_val_mious_x'],
        }},
        metric_name='val_metric',
        title=f'Segmentation ({MODE}) loss and mIoU',
        ylabel='mIoU'
    )

    imgs, masks = next(iter(seg_test_loader))
    seg_model.eval()
    with torch.no_grad():
        out = seg_model(imgs.to(DEVICE))['out'].cpu()
        preds = out.argmax(1)
    for i in range(min(3, imgs.shape[0])):
        show_segmentation_sample(imgs[i], masks[i], preds[i])

elif TASK == "detection":
    det_train_07 = VOCDatasetForDet(root=ROOT, year='2007', image_set='trainval', transforms=DET_TRAIN_TRANSFORM)
    det_train_12 = VOCDatasetForDet(root=ROOT, year='2012', image_set='trainval', transforms=DET_TRAIN_TRANSFORM)
    det_train = ConcatDataset([det_train_07, det_train_12])
    det_test = VOCDatasetForDet(root=ROOT, year='2007', image_set='test', transforms=DET_TEST_TRANSFORM)
    det_train_loader = DataLoader(det_train, batch_size=REAL_BATCH, shuffle=True, num_workers=8, collate_fn=det_collate_fn, drop_last=True)
    det_test_loader = DataLoader(det_test, batch_size=1, shuffle=False, num_workers=4, collate_fn=det_collate_fn)
    if MODE == "byol":
        backbone_det = smart_load_byol_checkpoint_to_resnet50(CHECKPOINT)
    elif MODE == "supervised":
        backbone_det = models.resnet50(weights="IMAGENET1K_V1")
        backbone_det.fc = nn.Identity()
    else:
        raise ValueError()
    backbone_det = backbone_det.to(DEVICE)
    det_model = build_fasterrcnn_with_backbone(backbone_det, num_classes=21)
    det_model.to(DEVICE)
    det_model, det_train_losses, det_val_maps = train_detection_iters(
        det_model, det_train_loader, det_test_loader,
        device=DEVICE,
        max_iters=MAX_ITERS,
        accumulation_steps=accumulation_steps,
        get_lr=get_lr,
        val_every=VAL_EVERY
    )
    xm = [x[0] for x in det_val_maps]
    ym = [x[1][0] for x in det_val_maps]
    results[MODE] = {
        'det_train_losses': det_train_losses,
        'det_val_map': ym,
        'det_val_map_x': xm,
    }
    save_results_json("detection_results.json", results)
    plot_pairwise_curves(
        {MODE: {
            'train_loss': results[MODE]['det_train_losses'],
            'val_metric': results[MODE]['det_val_map'],
            'val_metric_x': results[MODE]['det_val_map_x'],
        }},
        metric_name='val_metric',
        title=f'Detection ({MODE}) loss and mAP@0.5',
        ylabel='mAP@0.5'
    )