In [1]:
import torch
import torch.cuda.nccl

print(torch.cuda.nccl.version())

AttributeError: module 'torch._C' has no attribute '_nccl_version'

In [2]:
import torch

capi_vitl14_lvd = torch.hub.load('facebookresearch/capi:main', 'capi_vitl14_lvd')

# Simply call the models to encode an image
img = torch.zeros(1, 3, 224, 224)  # example img, replace with your stuff
global_repr, registers, feature_map = capi_vitl14_lvd(img)

Using cache found in C:\Users\axeld/.cache\torch\hub\facebookresearch_capi_main
Can't import sklearnex. If installed, that speeds up scikit-learn 10-100x


In [3]:
import gc
import itertools
import logging
import time
import datetime
import json
from collections import defaultdict, deque
from collections.abc import Callable, Iterable, Sequence
from functools import partial, reduce
from pathlib import Path
from typing import Any, Literal

import numpy as np
import sklearn.metrics, sklearn.linear_model
import torch
import torch.amp
import torch.distributed
from jaxtyping import Float, Int, Num
from torch import Tensor, nn

logger = logging.getLogger(__name__)

def accuracy(y_true, y_pred, ignore_labels: Sequence[int]):
    gt = y_true.flatten().cpu().numpy()
    pred = y_pred.flatten().cpu().numpy()
    mask = ~np.isin(gt, ignore_labels)
    return float(np.mean((gt[mask] == pred[mask]).astype(float)))


def mIoU(y_true, y_pred, ignore_labels: Sequence[int]):
    gt = y_true.flatten().cpu().numpy()
    pred = y_pred.flatten().cpu().numpy()
    mask = ~np.isin(gt, ignore_labels)
    return float(sklearn.metrics.jaccard_score(gt[mask], pred[mask], average="macro"))


metrics_dict = {
    "mIoU": mIoU,
    "acc": accuracy,
}

class MetricLogger:
    def __init__(self, delimiter: str = "  ", output_file: str | Path | None = None):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter
        if isinstance(output_file, str):
            output_file = Path(output_file)
        self.output_file = output_file

    def update(self, **kwargs):
        for k, v in kwargs.items():
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {meter!s}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def dump_in_output_file(self, iteration, iter_time, data_time):
        if self.output_file is None or torch.distributed.get_rank() != 0:
            return
        dict_to_dump = {
            "iteration": iteration,
            "iter_time": iter_time,
            "data_time": data_time,
        }
        dict_to_dump.update({k: v.median for k, v in self.meters.items()})
        with self.output_file.open("a") as f:
            f.write(json.dumps(dict_to_dump) + "\n")

    def log_every(self, iterable, print_freq, header=None, n_iterations=None, start_iteration=0):
        i = start_iteration
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.6f}")
        data_time = SmoothedValue(fmt="{avg:.6f}")

        if n_iterations is None:
            n_iterations = len(iterable)

        space_fmt = ":" + str(len(str(n_iterations))) + "d"

        log_list = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if torch.cuda.is_available():
            log_list += ["max mem: {memory:.0f}MB"]

        log_msg = self.delimiter.join(log_list)
        if i < n_iterations:
            for obj in iterable:
                data_time.update(time.time() - end)
                yield obj
                iter_time.update(time.time() - end)
                if i % print_freq == 0 or i == n_iterations - 1:
                    self.synchronize_between_processes()
                    self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg)
                    eta_seconds = iter_time.global_avg * (n_iterations - i)
                    eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                    if torch.cuda.is_available():
                        logger.info(
                            log_msg.format(
                                i,
                                n_iterations,
                                eta=eta_string,
                                meters=str(self),
                                time=str(iter_time),
                                data=str(data_time),
                                memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                            ),
                        )
                    else:
                        logger.info(
                            log_msg.format(
                                i,
                                n_iterations,
                                eta=eta_string,
                                meters=str(self),
                                time=str(iter_time),
                                data=str(data_time),
                            ),
                        )
                i += 1
                end = time.time()
                if i >= n_iterations:
                    break
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logger.info(f"{header} Total time: {total_time_str} ({total_time / n_iterations:.6f} s / it)")

def to_tensor(x: Tensor | float | int) -> Tensor:
    if isinstance(x, Tensor):
        return x
    return torch.tensor(x)

class SmoothedValue:
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.window_size = window_size
        self.deque: deque[Tensor | float | int] = deque(maxlen=window_size)
        self.total: Tensor | float | int = 0.0
        self.count: int = 0
        self.fmt = fmt

    def update(self, value: Tensor | float | int):
        self.deque.append(value)
        self.count += 1
        self.total += value

    def synchronize_between_processes(self):
        """Distributed synchronization of the metric"""
        if not torch.distributed.is_initialized():
            return
        logger.debug("Synchronizing values")
        count = to_tensor(self.count).to(dtype=torch.float64, device="cuda").reshape(1)
        total = to_tensor(self.total).to(dtype=torch.float64, device="cuda").reshape(1)
        tensor_deque = torch.tensor(list(self.deque), dtype=torch.float64, device="cuda")
        t = torch.cat([count, total, tensor_deque], dim=0)
        torch.distributed.barrier()
        torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.AVG)
        self.count = int(t[0].cpu().item())
        self.total = t[1]
        self.deque = deque(list(t[2:]), maxlen=self.window_size)

    @property
    def median(self) -> float | int:
        d = torch.tensor(list(self.deque))
        return d.median().cpu().item()

    @property
    def avg(self) -> float | int:
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().cpu().item()

    @property
    def global_avg(self) -> float | int:
        return to_tensor(self.total).cpu().item() / self.count

    @property
    def max(self) -> float | int:
        return torch.tensor(self.deque).max().cpu().item()

    @property
    def value(self) -> float | int:
        v = self.deque[-1]
        return to_tensor(v).cpu().item()

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )

class Classifier:
    hparam_grids: dict[str, Iterable[Any]]
    n_pixels_per_sample: int
    label_dtype: torch.dtype
    inference_bs: int
    train_set_subsampling: int
    ignore_labels: Sequence[int]

    def fit(self, features: Float[Tensor, "n d"], labels: Int[Tensor, "n l"]) -> None:
        self.unfit()
        if self.train_set_subsampling > 1:
            labels = labels[:: self.train_set_subsampling]
            features = features[:: self.train_set_subsampling]
        mask = torch.isin(labels.mode(dim=-1).values, torch.tensor(self.ignore_labels))
        self._fit(features[~mask], labels[~mask])

    def unfit(self) -> None:
        pass

    def upscale(self, labels: Int[Tensor, "n"]) -> Int[Tensor, "n l"]:  # noqa: F821
        """Convert from patch-level to pixel-level"""
        return labels[:, None].expand(-1, self.n_pixels_per_sample)

    def select_hparams(
        self,
        features_train: Float[Tensor, "nt d"],
        labels_train: Int[Tensor, "nt l"],
        features_val: Float[Tensor, "nv d"],
        labels_val: Int[Tensor, "nv l"],
        ignore_labels: Sequence[int],
        metric_name: str = "mIoU",
    ) -> dict[str, float]:
        hparam_names, grids = zip(*self.hparam_grids.items(), strict=False)
        hparam_grid = list(itertools.product(*grids))
        metrics = {}
        r = torch.distributed.get_rank()
        n = torch.distributed.get_world_size()
        rank_results: dict[int, float] = {}
        if len(hparam_grid) > 1:
            # split the grid among ranks
            for hparam_idx, hparam_set in list(enumerate(hparam_grid))[r::n]:
                logger.info(f"Rank {r} testing hparam set {hparam_idx}/{len(hparam_grid)}")
                gc.collect()
                torch.cuda.empty_cache()
                for k, v in zip(hparam_names, hparam_set, strict=False):
                    logger.info(f"Setting {k}={v}")
                    setattr(self, k, v)
                self.fit(features_train, labels_train)
                preds = self.predict(features_val)
                score = metrics_dict[metric_name](labels_val, preds, ignore_labels)
                rank_results[hparam_idx] = score
                logger.info(
                    "Tested "
                    + "_".join(f"{k}={v}" for k, v in zip(hparam_names, hparam_set, strict=False))
                    + f", {metric_name}={score:.3f}",
                )
                self.unfit()
                gc.collect()
                torch.cuda.empty_cache()
            torch.distributed.barrier()
            gathered_results = [{} for _ in range(torch.distributed.get_world_size())]
            torch.distributed.all_gather_object(gathered_results, rank_results)
            all_results = reduce(lambda x, y: {**x, **y}, gathered_results)
            best_hparam_idx, best_score = max(all_results.items(), key=lambda x: x[1])
            best_hparam_set = hparam_grid[best_hparam_idx]
            # log a bit
            for idx, score in all_results.items():
                hruid = f"{metric_name}_" + "_".join(
                    f"{k}={v}" for k, v in zip(hparam_names, hparam_grid[idx], strict=True)
                )
                metrics[hruid] = score
            logger.info(
                "Best hparam set: "
                + "_".join(f"{k}={v}" for k, v in zip(hparam_names, best_hparam_set, strict=True))
                + f" with {metric_name} {best_score:.3f}",
            )
        else:
            logger.info("Grid is length 1, skipping hparam search")
            best_hparam_set = hparam_grid[0]
        # set the new hparams
        for k, v in zip(hparam_names, best_hparam_set, strict=False):
            setattr(self, k, v)
        return metrics

    @torch.no_grad()
    def predict(self, features: Float[Tensor, "n d"]) -> Int[Tensor, "n l"]:
        predictions = torch.zeros(features.shape[0], self.n_pixels_per_sample, dtype=self.label_dtype, device="cpu")
        for i in MetricLogger().log_every(
            range(0, features.shape[0], self.inference_bs),
            print_freq=10,
            header=f"{type(self).__name__} inference",
        ):
            predictions[i : i + self.inference_bs] = self._predict_batch(features[i : i + self.inference_bs])
            gc.collect()
        return predictions

    def _fit(self, features: Float[Tensor, "n d"], labels: Int[Tensor, "n l"]) -> None:
        raise NotImplementedError

    def _predict_batch(self, features: Float[Tensor, "n d"]) -> Int[Tensor, "n l"]:
        raise NotImplementedError


class KNNClassifier(Classifier):
    num_neighbors: int
    distance: str

    def __init__(
        self,
        inference_bs: int = 1024,
        train_set_chunk_size: int | None = 262144,
        train_set_subsampling: int = 1,
        device: str = "cuda",
        dtype: str = "float32",
        num_neighbors: Sequence[int] = (1, 3, 10, 30),
        distance: Sequence[str] = ("cosine", "L2"),
        ignore_labels: Sequence[int] = (255,),
    ):
        super().__init__()
        self.device = torch.device(device)
        self.dtype = getattr(torch, dtype)
        self.inference_bs = inference_bs
        self.train_set_chunk_size = train_set_chunk_size
        self.train_set_subsampling = train_set_subsampling
        self.ignore_labels = ignore_labels
        # try to make this grid divisible by 8 if possible
        self.hparam_grids = {
            "num_neighbors": num_neighbors,
            "distance": distance,
        }

    def unfit(self):
        if hasattr(self, "train_X"):
            del self.train_X
        if hasattr(self, "train_y"):
            del self.train_y

    def _fit(self, features: Float[Tensor, "n d"], labels: Int[Tensor, "n l"]) -> None:
        self.train_X = features.to(self.dtype).to(self.device, non_blocking=True)
        self.train_y = labels.to(self.device, non_blocking=True)
        self.n_pixels_per_sample = labels.shape[-1]
        self.label_dtype = labels.dtype

    # @torch.compile
    def _cdist(self, a: Float[Tensor, "n d"], b: Float[Tensor, "m d"]) -> Float[Tensor, "n m"]:
        # WARN L1 and Linf are horribly slow
        if self.distance == "L2":
            return torch.cdist(a, b, p=2)
        if self.distance == "cosine":
            a = a / torch.norm(a, dim=-1)[:, None]
            b = b / torch.norm(b, dim=-1)[:, None]
            return 1 - a @ b.T
        if self.distance == "L1":
            return torch.cdist(a, b, p=1)
        if self.distance == "Linf":
            return torch.cdist(a, b, p=float("inf"))
        if self.distance == "inner_product":
            return -a @ b.T
        raise NotImplementedError

    @torch.compile(dynamic=True)
    def _find_closest_chunk(
        self,
        queries: Float[Tensor, "n d"],
        keys: Float[Tensor, "m d"],
        values: Num[Tensor, "m l"],
    ) -> tuple[Float[Tensor, "n k"], Num[Tensor, "n k l"]]:
        dists = self._cdist(queries, keys)
        # get top k closest neighbors
        k = min(self.num_neighbors, dists.shape[-1])
        distances, indices = torch.topk(dists, k, dim=-1, largest=False)
        closest_values = torch.gather(
            values,
            0,
            indices.flatten()[..., None].expand(indices.numel(), values.shape[1]),
        ).reshape(*indices.shape, *values.shape[1:])
        return distances, closest_values

    def _predict_batch(self, features: Float[Tensor, "n d"]) -> Int[Tensor, "n l"]:
        features = features.to(self.dtype).to(self.device, non_blocking=True)
        chunk_size = self.train_set_chunk_size
        if chunk_size is None:
            chunk_size = self.train_X.shape[0]
        aggregated_distances = []
        aggregated_labels = []
        for i in range(0, self.train_X.shape[0], chunk_size):
            distances, closest_values = self._find_closest_chunk(
                features,
                self.train_X[i : i + chunk_size].to(self.dtype).to(self.device),
                self.train_y[i : i + chunk_size].to(self.device),
            )
            aggregated_distances.append(distances.cpu())
            aggregated_labels.append(closest_values.cpu())
            del distances
            del closest_values
        _, aggregation_indices = torch.topk(
            torch.cat(aggregated_distances, dim=1),
            self.num_neighbors,
            dim=1,
            largest=False,
        )
        del aggregated_distances
        neighbor_labels = torch.gather(
            torch.cat(aggregated_labels, dim=1),
            1,
            aggregation_indices[..., None].expand(*aggregation_indices.shape, self.train_y.shape[1]),
        )
        del aggregated_labels
        # get the most common label
        return neighbor_labels.mode(dim=1).values


In [17]:
import einops
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import VisionDataset

def dump_metrics(results_dict: dict, results_path: Path, cfg: Any) -> None:
    if torch.distributed.get_rank() == 0:
        logger.info(f"Saving results to {results_path}")
        results_path.write_text(json.dumps({"results": results_dict, "config": cfg}))

def all_gather_and_flatten(tensor_rank: Tensor):
    tensor_all_ranks = torch.empty(
        torch.distributed.get_world_size(),
        *tensor_rank.shape,
        dtype=tensor_rank.dtype,
        device=tensor_rank.device,
    )
    torch.distributed.all_gather(list(tensor_all_ranks.unbind(0)), tensor_rank.contiguous())
    return tensor_all_ranks.flatten(end_dim=1)

class DatasetWithEnumeratedTargets(VisionDataset):
    """If pad_dataset is set, pads based on torch's DistributedSampler implementation, which
    with drop_last=False pads the last batch to be a multiple of the world size.
    https://github.com/pytorch/pytorch/blob/main/torch/utils/data/distributed.py#L91
    """

    def __init__(self, dataset: VisionDataset, pad_dataset: bool = False, num_replicas: int | None = None):
        self._dataset = dataset
        self._size = len(self._dataset)
        self._padded_size = self._size
        self._pad_dataset = pad_dataset
        if self._pad_dataset:
            assert num_replicas is not None, "num_replicas should be set if pad_dataset is True"
            self._padded_size = num_replicas * ((len(dataset) + num_replicas - 1) // num_replicas)

    def __getitem__(self, index: int) -> tuple[Any, tuple[int, int]]:
        image, target = self._dataset[index % self._size]
        if index >= self._size:
            assert self._pad_dataset
            return image, (-1, target)
        target = index if target is None else target
        return image, (index, target)

    def __len__(self) -> int:
        return self._padded_size
    

import torch
from torch.utils.data import IterableDataset
from torchvision.datasets.vision import VisionDataset
from typing import Any, Iterator


class IterableDatasetWithEnumeratedTargets(IterableDataset):
    """Iterable version of DatasetWithEnumeratedTargets.
    Pads dataset if pad_dataset is True, ensuring the last batch is a multiple of world size.
    """

    def __init__(self, dataset: VisionDataset, pad_dataset: bool = False, num_replicas: int | None = None):
        super().__init__()
        self._dataset = dataset
        self._size = len(self._dataset)
        self._pad_dataset = pad_dataset
        self._padded_size = self._size

        if self._pad_dataset:
            assert num_replicas is not None, "num_replicas should be set if pad_dataset is True"
            self._padded_size = num_replicas * ((len(dataset) + num_replicas - 1) // num_replicas)

    def __iter__(self) -> Iterator[tuple[Any, tuple[int, int]]]:
        index = 0
        while index < self._padded_size:
            image, target = self._dataset[index % self._size]
            if index >= self._size:
                assert self._pad_dataset
                yield image, (-1, target)  # Padding case
            else:
                target = index if target is None else target
                yield image, (index, target)
            index += 1
    
    def __len__(self) -> int:
        return self._padded_size
    
    
def make_data_loader(
    *,
    dataset,
    batch_size: int=4,
    num_workers: int=1,
    shuffle: bool = True,
    seed: int = 0,
    sampler_advance: int = 0,
    drop_last: bool = True,
    persistent_workers: bool = False,
    collate_fn: Callable[[list], Any] | None = None,
) -> DataLoader:
    logger.info("Using PyTorch data loader")
    if isinstance(dataset, torch.utils.data.IterableDataset):
        logger.info("Dataset is iterable, not using a sampler")
        sampler = None
    elif False:
        logger.info("Using DistributedSampler")
        sampler = torch.utils.data.DistributedSampler(
            dataset,
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
        )
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=drop_last,
        persistent_workers=persistent_workers,
        collate_fn=collate_fn,
    )
    logger.info(f"batch size: {batch_size}")
    try:
        logger.info(f"# of batches: {len(data_loader):,d}")
    except TypeError:  # data loader has no length
        logger.info("infinite data loader")
    return data_loader


@torch.inference_mode()
def extract_features(
    model: nn.Module,
    dts_len: int,
    dataset: VisionDataset,
    *,
    gather_on_cpu: bool = False,
) -> tuple[Float[Tensor, "len(dataset) ih iw d"], Int[Tensor, "len(dataset) ih iw ps**2"]]:
    """Featurize the dataset."""
    bs = ih = iw = dim = 0
    ps = model.patch_size
    dataset_with_enumerated_targets = DatasetWithEnumeratedTargets(dataset)
    iterabledataset_with_enumerated_targets = IterableDatasetWithEnumeratedTargets(dataset_with_enumerated_targets)
    data_loader = make_data_loader(
        dataset=iterabledataset_with_enumerated_targets,
        shuffle=False,
        drop_last=False,
    )
    gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda")
    features, all_labels = None, None
    for samples, (index, labels_rank) in MetricLogger().log_every(data_loader, 10, header="Extracting features"):
        samples = samples.cuda(non_blocking=True)
        index = index.cuda(non_blocking=True)
        _, _, featmap = model(samples)
        bs, ih, iw, dim = featmap.shape
        features_rank = featmap.reshape(bs * ih * iw, dim).float()
        if len(labels_rank.shape) == 3:
            # segmentation
            labels_rank = einops.rearrange(
                labels_rank,
                "bs (ih ph) (iw pw) -> (bs ih iw) (ph pw)",
                ih=ih,
                iw=iw,
                ph=ps,
                pw=ps,
            )
        elif len(labels_rank.shape) == 1:
            # classification
            labels_rank = labels_rank[:, None, None].expand(bs, ih * iw, ps**2).flatten(0, 1)
        else:
            raise NotImplementedError
        labels_rank = labels_rank.cuda(non_blocking=True)
        sample_count = dts_len * ih * iw

        # init storage feature matrix
        if features is None or all_labels is None:
            features = torch.zeros(
                sample_count,
                dim,
                device=gather_device,
                dtype=features_rank.dtype,
            )
            labels_shape = list(labels_rank.shape)
            labels_shape[0] = sample_count
            all_labels = torch.full(
                labels_shape,
                fill_value=-1,
                device=gather_device,
                dtype=labels_rank.dtype,
            )
            logger.info(f"Storing features into tensor of shape {features.shape}")

        # share indexes, features and labels between processes
        pos_rank = torch.arange(ih * iw, device=features_rank.device)[None, :].expand(bs, ih * iw).flatten()
        index = index[:, None].expand(bs, ih * iw).flatten()
        index = index * ih * iw + pos_rank.to(index.device).to(index.dtype)
        index_all = all_gather_and_flatten(index).to(gather_device)
        features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device)
        labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device)

        # update storage feature matrix
        if len(index_all) > 0:
            features.index_copy_(0, index_all, features_all_ranks)
            all_labels.index_copy_(0, index_all, labels_all_ranks)

    del data_loader
    torch.cuda.empty_cache()
    gc.collect()
    assert features is not None and all_labels is not None
    return features.reshape(dts_len, ih, iw, dim), all_labels.reshape(dts_len, ih, iw, ps**2)


In [18]:
import glob
import os

train_img_dir = "./images/training_data"  # Example
train_mask_dir = "./images/training_gt"

image_paths = sorted(glob.glob(os.path.join(train_img_dir, "*.png")))
print("Found image files:", image_paths)

Found image files: []


In [19]:
import functools
import socket
import torch.distributed as dist

def _get_available_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # A "" host address means INADDR_ANY i.e. binding to all interfaces.
        # Note this is not compatible with IPv6.
        s.bind(("", 0))
        return s.getsockname()[1]

@functools.lru_cache
def enable_distributed(
    *,
    set_cuda_current_device: bool = True,
    backend: str = "nccl",
    nccl_async_error_handling: bool = False,
):
    if "MASTER_ADDR" not in os.environ:
        # Environment is not set, assume single gpu
        logger.info("Dist init for single-gpu training")
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(_get_available_port())
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        os.environ["LOCAL_RANK"] = "0"
        os.environ["LOCAL_WORLD_SIZE"] = "1"
        os.environ["USE_LIBUV"] = "0"
    else:
        logger.info("Dist init from preset env")
    if set_cuda_current_device:
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    if nccl_async_error_handling:
        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"  # "TORCH_" prefix added in PyTorch 2.2
    os.environ["USE_LIBUV"] = "0"
    dist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=30))
    logger.info(f"{os.environ['LOCAL_RANK']=}")
    dist.barrier(device_ids=[int(os.environ["LOCAL_RANK"])])

enable_distributed()

RuntimeError: Distributed package doesn't have NCCL built in

In [20]:
torch.cuda.nccl.is_available(torch.randn(1).cuda())
torch.cuda.nccl.version()

AttributeError: module 'torch._C' has no attribute '_nccl_version'

In [21]:
import os
import glob
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2 as T

import segmentation_models_pytorch as smp


# -------------------
# 1. Custom Dataset
# -------------------
class ISICSegmentationDataset(Dataset):
    """
    Expects:
      - image_dir: folder with images like ISIC_XXXXXXX.jpg
      - mask_dir:  folder with corresponding masks like ISIC_XXXXXXX_Segmentation.png
    """
    def __init__(self, image_dir, mask_dir, transform=None, target_size=(512, 512)):
        super().__init__()
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.target_size = target_size
        
        # Look for .jpg images
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, "*.jpg")))
        if len(self.image_paths) == 0:
            print(f"[Warning] No .jpg images found in {image_dir}")
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        # Get image path and corresponding mask path
        img_path = self.image_paths[idx]
        filename = os.path.basename(img_path)  # e.g. "ISIC_0000000.jpg"
        mask_name = filename.replace(".jpg", "_Segmentation.png")
        mask_path = os.path.join(self.mask_dir, mask_name)
        
        # Open images
        image = Image.open(img_path).convert("RGB")
        mask  = Image.open(mask_path).convert("L")  # single channel for mask
        
        # Resize both image and mask to target_size
        image = image.resize(self.target_size, Image.BILINEAR)
        mask  = mask.resize(self.target_size, Image.NEAREST)  # NEAREST preserves mask values
        
        # Convert to numpy arrays and scale
        image = np.array(image, dtype=np.float32) / 255.0
        mask  = np.array(mask, dtype=np.float32) / 255.0
        
        # Convert shape from H x W x C to C x H x W
        image = np.transpose(image, (2, 0, 1))
        mask  = np.expand_dims(mask, axis=0)
        
        # Optionally apply further transforms if provided
        if self.transform:
            image, mask = self.transform(image, mask)
        
        return torch.tensor(image, dtype=torch.float32), torch.tensor(mask, dtype=torch.float32)


# ----------------------------
# 2. Create Datasets & Loaders
# ----------------------------
from torch.utils.data import random_split

def get_train_val_dataloaders(train_img_dir, train_mask_dir, batch_size=4, val_split=0.2):
    full_dataset = ISICSegmentationDataset(train_img_dir, train_mask_dir)
    total_samples = len(full_dataset)
    val_size = int(total_samples * val_split)
    train_size = total_samples - val_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader


# ------------------------
# 3. Define the Model
# ------------------------
def create_model(num_classes=1):
    # Example: U-Net with a ResNet34 encoder
    # For a single class segmentation, set classes=1 and activation=None (we'll handle it in the loss)
    model = smp.Unet(
        encoder_name="resnet34",
        encoder_weights="imagenet",
        in_channels=3,
        classes=num_classes
    )
    return model


# -----------------------
# 4. Training Loop
# -----------------------
def train_one_epoch(model, loader, loss_fn, optimizer, device="cuda"):
    model.train()
    epoch_loss = 0
    
    for images, masks in loader:
        images = images.to(device)
        masks  = masks.to(device)
        
        # Forward
        outputs = model(images)
        
        # Calculate loss
        loss = loss_fn(outputs, masks)
        
        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    return epoch_loss / len(loader)


def evaluate(model, loader, loss_fn, device="cuda"):
    model.eval()
    eval_loss = 0
    
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            masks  = masks.to(device)
            
            outputs = model(images)
            loss = loss_fn(outputs, masks)
            eval_loss += loss.item()
    
    return eval_loss / len(loader)


# --------------------
# 5. Putting it all together
# --------------------
def main():
    # Directories (update these paths with your own)
    train_img_dir  = "images/training_data"
    train_mask_dir = "images/training_gt"
    test_img_dir   = "images/test_data"
    test_mask_dir  = "images/test_gt"
    
    # Hyperparameters
    batch_size = 4
    lr = 1e-4
    num_epochs = 10
    val_split = 0.2  # 20% of training data for validation
    
    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Get data loaders: split the training dataset into train and val
    model = torch.hub.load('facebookresearch/capi:main', 'capi_vitl14_lvd')
    train_loader, val_loader = get_train_val_dataloaders(train_img_dir, train_mask_dir, batch_size, val_split)
    train_features, train_labels = extract_features(model, len(train_img_dir), train_loader)
    train_features = train_features.flatten(0, -2)
    train_labels = train_labels.flatten(0, -2)
    val_features, val_labels = extract_features(model, int(len(train_img_dir)*val_split), val_loader)
    val_features = val_features.flatten(0, -2)
    val_labels = val_labels.flatten(0, -2)
    test_dataset = ISICSegmentationDataset(test_img_dir, test_mask_dir)
    test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    test_features, test_labels = extract_features(model, len(test_img_dir), test_loader)
    test_features = test_features.flatten(0, -2)
    test_labels = test_labels.flatten(0, -2)
    model.cpu()
    for k in list(model._parameters.keys()):
            del model._parameters[k]
    for k in list(model._modules.keys()):
            del model._modules[k]
    classifier = KNNClassifier(ignore_labels=(255,))
    hparam_metrics = classifier.select_hparams(
                train_features,
                train_labels,
                val_features,
                val_labels,
                (255,),
            )
    results_dict = {}
    dump_predictions = True
    dump_classifier = True
    classifier_name = "kNN"
    output_dir = ""
    for k, v in hparam_metrics.items():
                results_dict[f"hparam_fitting.kNN.{k}"] = v
    if torch.distributed.get_rank() == 0:
                classifier.fit(
                    torch.cat([train_features, val_features]),
                    torch.cat([train_labels, val_labels]),
                )
                preds = classifier.predict(test_features)
                logger.info(f"Predictions shape: {preds.shape}")
    if dump_predictions:
                    torch.save(preds, Path(output_dir) / f"preds_{classifier_name}.pth")
    for metric_name, metric in metrics_dict.items():
                    result_name = f"labels_{classifier_name}_{metric_name}"
                    results_dict[result_name] = float(metric(test_labels, preds, (255,)))
                    logger.info(f"{result_name}: {results_dict[result_name]:.4g}")
    if dump_classifier and hasattr(classifier, "estimator"):
                    torch.save(
                        {
                            "coef_": torch.tensor(classifier.estimator.coef_),
                            "intercept_": torch.tensor(classifier.estimator.intercept_),
                        },
                        Path(output_dir) / "classifier.pth",
                    )
            # dump partial results
    metric_dumper = partial(dump_metrics, results_path="", cfg={})
    metric_dumper(results_dict)
    torch.distributed.barrier()
    del classifier
    gc.collect()
    torch.cuda.empty_cache()
    
    # Define loss & optimizer
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # Training loop
    """for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer, device)
        val_loss   = evaluate(model, val_loader, loss_fn, device)
        
        print(f"[Epoch {epoch+1}/{num_epochs}]",
              f"Train Loss: {train_loss:.4f} |",
              f"Val Loss: {val_loss:.4f}")
    
    # Save the fine-tuned model
    torch.save(model.state_dict(), "isic_unet_resnet34_finetuned.pth")
    test_loss = evaluate(model, test_loader, loss_fn, device)
    print(f"Test Loss: {test_loss:.4f}")
    print("Model saved.")"""


if __name__ == "__main__":
    main()

Using cache found in C:\Users\axeld/.cache\torch\hub\facebookresearch_capi_main


RuntimeError: DataLoader worker (pid(s) 17432) exited unexpectedly