This is the official PyTorch implementation of "Sampling Control for Imbalanced Calibration in Semi-Supervised Learning", accepted at AAAI 2026.
Senmao Tian, Xiang Wei, Shunli Zhang.
Beijing Jiaotong University (BJTU)
Class imbalance remains a critical challenge in semi-supervised learning (SSL), especially when distributional mismatches between labeled and unlabeled data lead to biased classification. Although existing methods address this issue by adjusting logits based on the estimated class distribution of unlabeled data, they often handle model imbalance in a coarse-grained manner, conflating data imbalance with bias arising from varying class-specific learning difficulties. To address this, we propose a unified framework, SC-SSL, which suppresses model bias through decoupled sampling control. During training, we identify the key variables for sampling control under ideal conditions. By introducing a classifier with explicit expansion capability and adaptively adjusting sampling probabilities across different data distributions, SC-SSL mitigates feature-level imbalance for minority classes. In the inference phase, we further analyze the weight imbalance of the linear classifier and apply post-hoc sampling control with an optimized bias vector to directly calibrate the logits. Extensive experiments across various benchmark datasets and distribution settings validate the consistency and state-of-the-art performance of SC-SSL.
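To give a rough sense of the inference-time calibration described above, the following is a minimal sketch, not the actual implementation (which lives in semilearn/imb_algorithms/scssl). It assumes the bias vector has already been estimated and is simply subtracted from the raw logits before prediction; the placeholder values are for illustration only.

import torch

def calibrate_logits(logits, bias):
    # Subtract a per-class bias vector from the raw logits before taking the argmax.
    # logits: (batch_size, num_classes); bias: (num_classes,)
    return logits - bias

# Hypothetical usage with a placeholder bias vector; the real vector is
# obtained by optimization, as described in the paper.
num_classes = 10
logits = torch.randn(4, num_classes)
bias = torch.zeros(num_classes)
preds = calibrate_logits(logits, bias).argmax(dim=-1)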
This repo is based on the public and widely-used codebase USB.
The core code implementation is located at semilearn/imb_algorithms/scssl.
We have also made some modifications to semilearn/nets/. Some backbone implementations may differ from USB; this is to better align with prior works in long-tailed semi-supervised learning (LTSSL).
To avoid redundancy, we have removed most of the config files. If you need to reproduce any results, you can customize the config according to the corresponding paper or find one in the USB codebase.
To install the required packages, you can create a conda environment:
conda create --name SCSSL python=3.10
For PyTorch, we highly recommend installing it through pip, for example:
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
You can also select a different version from the PyTorch Previous Versions page.
For other packages, you can install them by:
pip install -r requirements.txt
The directory structure for datasets looks like:
./data
├── cifar-10
├── cifar-100
├── stl-10
├── imagenet32
└── imagenet64
All datasets except ImageNet can be prepared automatically. You can also see this for more details. We provide the preprocessing scripts for ImageNet-127 at preprocess/imagenet127.
For example, if you want to use GPU 0 to run on CIFAR-10-LT:
CUDA_VISIBLE_DEVICES=0 python train.py --c config/classic_cv_imb/fixmatch_scssl/fixmatch_scssl_cifar10_lb500_100_ulb4000_-100_0.yaml
- Note: Although USB claims to support multi-GPU parallelism, we noticed that their implementation is rather rudimentary and has not been updated to align with PyTorch best practices. Moreover, our experiments show that USB’s multi-GPU setup does not improve training speed and leads to unbalanced memory usage. Therefore, we have modified parts of the original code to achieve more efficient single-GPU training. Except for training on the ImageNet-127 dataset, which may require 40GB or more of GPU memory, the models generally have low memory usage on other datasets.
The model will be automatically evaluated every 1024 iterations during training. After training, the last two lines in saved_models/classic_cv_imb/fixmatch_scssl_cifar10_lb500_100_ulb4000_-100_0/log.txt will tell you the best accuracy.
For example,
[2024-11-10 10:45:39,670 INFO] model saved: ./saved_models/classic_cv_imb/SCSSL/fixmatch_scssl_cifar10_lb500_100_ulb4000_-100_0/latest_model.pth
[2024-11-10 10:45:39,674 INFO] Model result - eval/best_acc : 0.8655
[2024-11-10 10:45:39,674 INFO] Model result - eval/best_it : 228351
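If you want to extract these numbers programmatically, a small sketch that parses the tail of log.txt could look like the snippet below (the path follows the example command above; adjust it for your own run):

import re

log_path = "saved_models/classic_cv_imb/fixmatch_scssl_cifar10_lb500_100_ulb4000_-100_0/log.txt"
with open(log_path) as f:
    lines = f.readlines()

# The final "Model result" lines report eval/best_acc and eval/best_it.
for line in lines[-2:]:
    match = re.search(r"eval/(best_acc|best_it)\s*:\s*([\d.]+)", line)
    if match:
        print(match.group(1), "=", match.group(2))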
Our code is based on the implementations of USB, CPE, and CDMAD. We thank the authors for making their code publicly available.
