# FixMatch and ABC Training Comparison

This notebook trains the original FixMatch algorithm and its Auxiliary Balanced Classifier (ABC)-enhanced version, then compares their performance. The comparison covers CIFAR-10 and CIFAR-100; the cells shown here configure CIFAR-10 runs with 4,000 labeled examples.

---

## Table of Contents

1. [Setup](#Setup)
2. [Imports and Seed Initialization](#Imports-and-Seed-Initialization)
3. [Dataset Preparation](#Dataset-Preparation)
4. [Model Definitions](#Model-Definitions)
5. [Training Functions](#Training-Functions)
6. [Training Original FixMatch Model](#Training-Original-FixMatch-Model)
7. [Training ABC Model](#Training-ABC-Model)
8. [Results Comparison](#Results-Comparison)
9. [Conclusion](#Conclusion)

---



In [None]:
import os
import sys
import random
import numpy as np
import matplotlib.pyplot as plt
import torch

# Make the parent directory (the project root) importable from this notebook.
project_root = os.path.abspath('..')
if project_root not in sys.path:
    sys.path.append(project_root)

# Project-local modules
from main import set_seed
from experiments import train_fixmatch, train_abc


In [None]:
set_seed(42)
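
The body of `set_seed` lives in the project's `main.py` and is not shown here. As a rough sketch of what a seeding helper of this kind commonly does, the hypothetical `seed_everything` below seeds Python's `random` module and, when they are installed, NumPy and PyTorch. The name and the exact set of generators are assumptions, not the project's implementation.

```python
import random


def seed_everything(seed: int) -> None:
    """Hypothetical stand-in for a helper like `set_seed`; not the project's code."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)  # legacy global NumPy generator
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import torch
        torch.manual_seed(seed)           # default generators
        torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
    except ImportError:
        pass  # PyTorch not installed; nothing to seed


# Seeding twice with the same value reproduces the same random draws.
seed_everything(42)
first = [random.random() for _ in range(3)]
seed_everything(42)
second = [random.random() for _ in range(3)]
print(first == second)
```

Seeding alone does not guarantee bit-identical GPU training runs, so small differences between repeated runs can still appear.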

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

In [None]:



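import main  # project entry point; main.main() is driven through sys.argv below

# Save the notebook's own argv so it can be restored after the run.
original_argv = sys.argv.copy()

sys.argv = [
    'main.py',
    '--experiment', 'fixmatch',
    '--dataset', 'cifar10',
    '--num-labeled', '4000',
    '--batch-size', '64',
    '--mu', '7',
    '--epochs', '100',  # Adjust as needed
    '--lr', '0.03',
    '--threshold', '0.95',
    '--seed', '42',
    '--device', device
]

print("Starting FixMatch Training:")
try:
    main.main()
finally:
    # Restore argv even if training raises.
    sys.argv = original_argv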
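
The save, override, and restore steps around `main.main()` can also be wrapped in a small context manager, so that `sys.argv` is restored even when the call raises. The sketch below is one way to do that; `temporary_argv` is a hypothetical helper name, not part of the project.

```python
import sys
from contextlib import contextmanager


@contextmanager
def temporary_argv(args):
    """Temporarily replace sys.argv, restoring it even if the body raises."""
    saved = sys.argv
    sys.argv = list(args)
    try:
        yield
    finally:
        sys.argv = saved


before = list(sys.argv)
with temporary_argv(['main.py', '--experiment', 'fixmatch', '--seed', '42']):
    inside = list(sys.argv)  # what an argv-driven entry point would see here
after = list(sys.argv)
print(inside[0], after == before)
```

In the training cells, the call to `main.main()` would go inside the `with` block in place of the `inside = ...` line.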



In [None]:


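import importlib

import main

# Save the notebook's own argv so it can be restored after the run.
original_argv = sys.argv.copy()

sys.argv = [
    'main.py',
    '--experiment', 'abc',
    '--dataset', 'cifar10',
    '--num-labeled', '4000',
    '--batch-size', '64',
    '--mu', '7',
    '--epochs', '100',
    '--lr', '0.03',
    '--alpha', '1.0',
    '--lambda-u', '1.0',
    '--threshold', '0.95',
    '--seed', '42',
    '--device', device
]

# Re-execute main's module-level code before the second run
# (modules that main itself imports are not reloaded).
importlib.reload(main)

print("\nStarting ABC Training:")
try:
    main.main()
finally:
    # Restore argv even if training raises.
    sys.argv = original_argv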
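
Both runs pass `--threshold 0.95`, `--mu 7`, and `--batch-size 64`. Assuming these flags carry the meaning they have in the FixMatch paper (the project's own code is not shown here), `mu` is the ratio of unlabeled to labeled examples per step, and the threshold decides which pseudo-labels from the weakly augmented view are confident enough to contribute to the unsupervised loss. The pure-Python sketch below illustrates that selection rule on made-up probabilities; `confident_pseudo_labels` is a hypothetical name.

```python
def confident_pseudo_labels(probs, threshold=0.95):
    """Return (pseudo_label, keep) for each row of class probabilities."""
    results = []
    for row in probs:
        confidence = max(row)
        label = row.index(confidence)  # argmax of the predicted distribution
        results.append((label, confidence >= threshold))
    return results


# Assumed convention: each step uses batch_size labeled and
# batch_size * mu unlabeled examples.
batch_size, mu = 64, 7
unlabeled_per_step = batch_size * mu

# Illustrative (made-up) predictions on weakly augmented unlabeled images.
weak_view_probs = [
    [0.97, 0.02, 0.01],  # confident: pseudo-label kept
    [0.50, 0.30, 0.20],  # below threshold: masked out of the loss
    [0.01, 0.01, 0.98],  # confident: pseudo-label kept
]
selected = confident_pseudo_labels(weak_view_probs)
print(unlabeled_per_step, selected)
```

With a threshold as high as 0.95, few unlabeled examples pass early in training, and the share that passes grows as the model becomes more confident.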

