In [1]:
import os, time, numpy as np
from typing import List, Tuple, Dict

import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from PIL import Image
import open_clip

# Repro & device
def set_seed(seed=42):
	"""Seed torch (CPU and every CUDA device) and NumPy's global RNG with the same value."""
	for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
		seeder(seed)

def device_auto():
	"""Return the preferred torch device string: "cuda", then Apple "mps", otherwise "cpu"."""
	mps_backend = getattr(torch.backends, "mps", None)
	if torch.cuda.is_available():
		chosen = "cuda"
	elif mps_backend and mps_backend.is_available():
		chosen = "mps"
	else:
		chosen = "cpu"
	return chosen

# Seed all RNGs first, then pick the accelerator used by every later cell.
set_seed(42)
device = device_auto()
print(f"Device: {device}")

# Collate: keep PILs as a list (avoid default_collate error)
def collate_src(batch):
	"""Collate (PIL image, int label) pairs without stacking the images.

	The default collate cannot batch PIL images, so they are returned as a plain list
	(augmentations are applied later, per sample); labels become a LongTensor.
	"""
	pil_images, targets = tuple(zip(*batch))
	return [*pil_images], torch.tensor(targets, dtype=torch.long)

def collate_imgs(batch):
	"""Collate unlabeled samples (PIL images) into a new plain Python list."""
	return [sample for sample in batch]


  from .autonotebook import tqdm as notebook_tqdm


Device: mps


## 1) Config

In [2]:
# --- Paths ---
# Folder with <corruption>.npy + labels.npy from the official CIFAR-10-C release.
# Can be overridden without editing the notebook via the CIFAR10C_DIR environment variable.
CIFAR10C_DIR = os.environ.get("CIFAR10C_DIR", "./data/cifar/CIFAR-10-C")

# --- Training hparams ---
EPOCHS = 10          # try 10–30
BATCH_SIZE = 128
LR = 1e-6            # small LR (vision encoder only)
WEIGHT_DECAY = 1e-4
TAU = 0.07           # NOTE(review): currently unused — all logits are scaled by clip_model.logit_scale.exp(), not 1/TAU
PL_CONF_THRESH = 0.4 # a debiased pseudo-label enters L_dcp only if its max probability is >= this
LAMBDA_DCP = 0.5     # weight on target debiased consistency
L2_DELTA = 0.0       # NOTE(review): currently unused in this notebook

# --- DataLoader settings ---
NUM_WORKERS = 0 if device == "mps" else 2   # no worker processes on Apple MPS
PIN_MEMORY  = (device == "cuda")            # pinned host memory only matters for CUDA transfers

# --- Domain-aware prompt words ---
# Inserted into the prompt template "a {domain} photo of a {class}".
SRC_DOMAIN = "natural"
TGT_DOMAIN = "corrupted"

# CIFAR-10 classnames (order matches torchvision targets)
classnames = ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"]

# CIFAR-10-C corruption list (one <name>.npy file per entry in CIFAR10C_DIR)
CORRUPTIONS = [
	"gaussian_noise","shot_noise","impulse_noise","defocus_blur","glass_blur",
	"motion_blur","zoom_blur","snow","frost","fog","brightness","contrast",
	"elastic_transform","pixelate","jpeg_compression"
]


## 2) Model and prompt encoder

In [3]:
# Load OpenAI CLIP ViT-B/32 through open_clip and freeze the text transformer blocks.
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
clip_model = clip_model.to(device)
clip_model.transformer.requires_grad_(False)  # Module.requires_grad_ flips every parameter of the text encoder

tokenizer = open_clip.get_tokenizer("ViT-B-32")

@torch.no_grad()
def build_text_matrix(classnames: List[str], domain_word: str):
	"""Encode one prompt per class, "a {domain_word} photo of a {classname}", as unit-norm rows.

	Returns a [C, d] tensor on `device` (row i corresponds to classnames[i]).
	"""
	tokens = tokenizer([f"a {domain_word} photo of a {name}" for name in classnames]).to(device)
	feats = clip_model.encode_text(tokens)
	norms = feats.norm(dim=-1, keepdim=True)
	return feats / norms

# Class-prompt matrices for each domain (rows are unit-norm text embeddings, one per class).
E_text_src = build_text_matrix(classnames, SRC_DOMAIN)
E_text_tgt = build_text_matrix(classnames, TGT_DOMAIN)

# Reuse CLIP's mean/std for normalization after our own augmentations. Find the Normalize step by
# type rather than assuming it is the last entry of clip_preprocess.transforms.
_clip_norm_steps = [t for t in clip_preprocess.transforms if isinstance(t, transforms.Normalize)]
if not _clip_norm_steps:
	raise RuntimeError("clip_preprocess has no Normalize transform; cannot recover CLIP mean/std")
MEAN = _clip_norm_steps[-1].mean
STD  = _clip_norm_steps[-1].std
# NOTE(review): Resize here uses torchvision's default (bilinear) interpolation, while CLIP's own
# preprocessing resizes bicubically — consistent between train and eval, but not identical to CLIP's pipeline.
to_clip = transforms.Compose([transforms.Resize(224), transforms.ToTensor(),
							  transforms.Normalize(mean=MEAN, std=STD)])





## 3) Datasets and Loaders

In [4]:
# CIFAR-10 source: labeled train split for adaptation, test split for clean evaluation.
src_train_base = datasets.CIFAR10(root="./data", train=True,  download=True)
src_test_base  = datasets.CIFAR10(root="./data", train=False, download=True)

# Both loaders share everything except shuffling.
_src_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
						  pin_memory=PIN_MEMORY, collate_fn=collate_src)
src_train_loader = DataLoader(src_train_base, shuffle=True,  **_src_loader_kwargs)
src_test_loader  = DataLoader(src_test_base,  shuffle=False, **_src_loader_kwargs)

# CIFAR-10-C unlabeled target (all corruptions, severities 1–5) — return ONLY PIL images
class CIFAR10C_Unlabeled(Dataset):
	"""Unlabeled CIFAR-10-C target pool returning one PIL image per (corruption, severity, row).

	Each ``<corruption>.npy`` stacks the severities back to back, ``n_per_severity`` rows each
	(10,000 in the standard release), so severity ``s`` occupies rows
	``[(s-1)*n_per_severity, s*n_per_severity)``.

	Files are opened with ``mmap_mode="r"``: construction only reads each file's header, and
	``__getitem__`` pages in just the requested row instead of holding whole arrays in RAM.

	Raises:
		ValueError: if a requested severity falls outside the rows present in a file.
	"""
	def __init__(self, root, corruptions, severities=(1,2,3,4,5), n_per_severity=10000):
		self.root = root
		self.corruptions = list(corruptions)
		self.severities = list(severities)
		self.n_per_severity = int(n_per_severity)
		# Flat (corruption, severity, row) index over every requested sample.
		self.index = []
		for c in self.corruptions:
			# Memory-map only to read the row count; the handle is not kept, so the dataset
			# stays cheap to copy into DataLoader workers.
			n_rows = np.load(self._path(c), mmap_mode="r").shape[0]
			for s in self.severities:
				start, end = (s - 1) * self.n_per_severity, s * self.n_per_severity
				if s < 1 or end > n_rows:
					raise ValueError(
						f"severity {s} needs rows [{start}, {end}) but {c}.npy has {n_rows} rows")
				self.index.extend((c, s, i) for i in range(start, end))
		self._cache = {}  # corruption -> read-only memmap, opened lazily (per worker process)

	def _path(self, corruption):
		"""Path of the .npy file holding all severities of ``corruption``."""
		return os.path.join(self.root, f"{corruption}.npy")

	def __len__(self):
		return len(self.index)

	def __getitem__(self, i):
		c, s, idx = self.index[i]
		arr = self._cache.get(c)
		if arr is None:
			arr = self._cache[c] = np.load(self._path(c), mmap_mode="r")
		# Copy the row out of the memmap into a regular uint8 array for PIL.
		return Image.fromarray(np.array(arr[idx]))

# Unlabeled target pool over every corruption at severities 1..5, reshuffled each epoch.
tgt_train_base = CIFAR10C_Unlabeled(CIFAR10C_DIR, CORRUPTIONS, severities=tuple(range(1, 6)))
tgt_train_loader = DataLoader(
	tgt_train_base,
	batch_size=BATCH_SIZE,
	shuffle=True,
	num_workers=NUM_WORKERS,
	pin_memory=PIN_MEMORY,
	collate_fn=collate_imgs,
)

# For evaluation on CIFAR-10-C with labels
class CIFAR10C_Labeled(Dataset):
	"""Labeled CIFAR-10-C split for one (corruption, severity) pair; yields (PIL image, int label).

	Severity ``s`` occupies rows ``[(s-1)*n_per_severity, s*n_per_severity)`` of both
	``<corruption>.npy`` and ``labels.npy`` (10,000 rows per severity in the standard release).
	The image file is memory-mapped, so rows of other severities are never read from disk.

	Raises:
		ValueError: if the severity falls outside the rows present in the files.
	"""
	def __init__(self, root, corruption, severity, n_per_severity=10000):
		start, end = (severity - 1) * n_per_severity, severity * n_per_severity
		X_all = np.load(os.path.join(root, f"{corruption}.npy"), mmap_mode="r")
		Y_all = np.load(os.path.join(root, "labels.npy"))
		n_rows = min(len(X_all), len(Y_all))
		if severity < 1 or end > n_rows:
			raise ValueError(
				f"severity {severity} needs rows [{start}, {end}) but only {n_rows} rows are available")
		self.X = X_all[start:end]                  # read-only memmap slice
		self.Y = Y_all[start:end].astype(int)

	def __len__(self): return len(self.Y)

	def __getitem__(self, i):
		# Copy the row out of the memmap into a regular uint8 array for PIL.
		return Image.fromarray(np.array(self.X[i])), int(self.Y[i])

def make_c10c_eval_loader(corr, sev, bs=BATCH_SIZE):
	"""Unshuffled labeled loader over one CIFAR-10-C corruption at one severity.

	Batches are (list of PIL images, LongTensor labels) via collate_src.
	"""
	return DataLoader(
		CIFAR10C_Labeled(CIFAR10C_DIR, corr, sev),
		batch_size=bs,
		shuffle=False,
		num_workers=NUM_WORKERS,
		pin_memory=PIN_MEMORY,
		collate_fn=collate_src,
	)


## 4) Augmentations (weak/strong)

In [5]:
# Weak/strong stochastic views of a PIL image; CLIP normalization is applied afterwards via to_clip.
# Both views: RandomResizedCrop(224) -> [extra ops] -> horizontal flip (p=0.5).
def _make_view(crop_scale, *extra_ops):
	"""Compose crop(224, crop_scale) -> extra_ops -> horizontal flip."""
	return transforms.Compose([
		transforms.RandomResizedCrop(224, scale=crop_scale),
		*extra_ops,
		transforms.RandomHorizontalFlip(p=0.5),
	])

aug_weak = _make_view((0.8, 1.0))
aug_strong = _make_view((0.5, 1.0), transforms.RandAugment(num_ops=2, magnitude=9))

def encode_images(px):
	"""Encode a batch of CLIP-normalized images and L2-normalize each embedding.

	Gradient tracking follows the caller's autograd mode. Under the previous
	``@torch.no_grad()`` decorator every embedding was graph-free, so the training losses could only
	reach ``logit_scale`` and the vision encoder never received gradients. Callers that only need
	features (evaluation, statistics, pseudo-labels) should wrap the call in ``torch.no_grad()``.

	Args:
		px: image batch produced by ``to_clip`` (presumably [B, 3, 224, 224] — confirm for other pipelines).
	Returns:
		Unit-norm image embeddings, one row per image.
	"""
	z = clip_model.encode_image(px)
	return z / z.norm(dim=-1, keepdim=True)

def logits_from_image_and_text(z_img, E_text):
	"""Image-vs-class logits: dot products scaled by CLIP's learned temperature exp(logit_scale).

	z_img is [B, d], E_text is [C, d]; the result is [B, C].
	"""
	sims = z_img @ E_text.t()
	return clip_model.logit_scale.exp() * sims

def symmetric_clip_loss(z_img, z_txt):
	"""Symmetric CLIP (InfoNCE) loss: row i of z_img is paired with row i of z_txt.

	Averages image->text and text->image cross-entropy. NOTE: every other row counts as a negative,
	including rows that share the same class prompt.
	"""
	scale = clip_model.logit_scale.exp()
	logits_per_image = scale * (z_img @ z_txt.t())
	matches = torch.arange(z_img.size(0), device=z_img.device)
	loss_i2t = F.cross_entropy(logits_per_image, matches)
	loss_t2i = F.cross_entropy(logits_per_image.t(), matches)
	return 0.5 * (loss_i2t + loss_t2i)


## 5) CFM/DCM (source & target) and PAD debiasor

In [6]:
def cfm_and_dcm_for_batch(imgs_PIL):
	"""Encode weak/strong/plain views of a batch and derive the CFM weight and DCM momentum.

	Args:
		imgs_PIL: list of PIL images (one mini-batch).
	Returns:
		lambda_cfm: float in [0, 1]: 1 - (sum of the mean pairwise squared distances between the
			weak (w), strong (s) and plain (o) embeddings) / 3, clamped to [0, 1].
		momentum_dcm: float in [-1, 1]: mean cosine similarity between weak and strong embeddings.
		z_w: weak-view embeddings, computed without autograd (only used for statistics/pseudo-labels).
		z_s: strong-view embeddings; keeps whatever gradient tracking ``encode_images`` provides so a
			consistency loss on it can update the encoder.
		z_o: plain-view embeddings, computed without autograd.
	"""
	# Build three views (same augmentation/RNG order as before: weak, then strong).
	px_w = torch.stack([to_clip(aug_weak(img))  for img in imgs_PIL]).to(device)
	px_s = torch.stack([to_clip(aug_strong(img)) for img in imgs_PIL]).to(device)
	# to_clip already resizes to 224, so the plain view needs no extra Resize.
	px_o = torch.stack([to_clip(img) for img in imgs_PIL]).to(device)

	with torch.no_grad():
		z_o = encode_images(px_o)
		# The weak view only feeds scalar statistics and no-grad pseudo-labels: skip its autograd graph.
		z_w = encode_images(px_w)
	z_s = encode_images(px_s)

	# The statistics are returned as plain floats, so no graph is needed for them either.
	with torch.no_grad():
		d_wo = (z_w - z_o).pow(2).sum(dim=-1)
		d_so = (z_s - z_o).pow(2).sum(dim=-1)
		d_ws = (z_w - z_s).pow(2).sum(dim=-1)

		# Simple bounded proxy for "forgetting"
		lam = 1.0 - ((d_wo + d_so + d_ws).mean() / 6.0) * 2.0
		lam = float(torch.clamp(lam, 0.0, 1.0))

		m = float(F.cosine_similarity(z_w, z_s, dim=-1).mean().clamp(-1,1))
	return lam, m, z_w, z_s, z_o

class PseudoLabelDebiasor:
	"""Pseudo-label debiasing against a running class prior (logit adjustment).

	Keeps an exponential moving average ``p_prime`` of the mean predicted class distribution and
	removes it from pseudo-labels in log space: ``softmax(log q - lambda_t * log p_prime)``.
	With a uniform prior or ``lambda_t == 0`` the input probabilities are returned unchanged.
	"""
	def __init__(self, num_classes):
		self.C = num_classes
		# Uniform initial prior, placed on the notebook's global `device`.
		self.p_prime = torch.full((self.C,), 1.0/self.C, device=device)

	def step(self, probs_batch, m_t):
		"""Update the prior: p' <- m * p' + (1 - m) * mean(probs_batch), clamped and renormalized.

		Args:
			probs_batch: [B, C] class probabilities (detached; treated as constants).
			m_t: momentum; clamped to [0, 1] so the update stays a convex combination.
		"""
		m = min(max(float(m_t), 0.0), 1.0)
		batch_mean = probs_batch.mean(dim=0).detach()
		self.p_prime = (m * self.p_prime + (1.0 - m) * batch_mean).clamp_(1e-6, 1.0)
		self.p_prime = self.p_prime / self.p_prime.sum()

	def debias(self, q, lambda_t):
		"""Return debiased soft labels for class probabilities ``q`` ([B, C]).

		The prior is subtracted from log-probabilities. Subtracting it from raw probabilities and
		re-softmaxing would squash every row towards uniform (max prob <= e / (e + C - 1), about 0.23
		for C = 10 at lambda_t = 0), so a confidence threshold such as 0.4 would reject every sample.
		"""
		log_q = q.clamp_min(1e-12).log()
		q_prime = log_q - lambda_t * (self.p_prime + 1e-6).log().unsqueeze(0)
		return F.softmax(q_prime, dim=-1)  # renormalize


## 6) Training loop

In [7]:
# Train only the vision encoder (the text side stays out of the optimizer).
# open_clip names every image-tower parameter "visual.*", including the ViT blocks under
# "visual.transformer.*". A `"transformer" not in n` filter would drop those blocks and instead pick
# up logit_scale. Selecting by prefix keeps the whole ViT trainable and leaves the CLIP temperature
# (logit_scale) un-optimized and un-decayed.
params = [p for n, p in clip_model.named_parameters() if p.requires_grad and n.startswith("visual.")]
assert params, "no trainable vision-encoder parameters found"
opt = torch.optim.AdamW(params, lr=LR, weight_decay=WEIGHT_DECAY)
# train_epoch runs min(len(src), len(tgt)) steps per epoch; size the cosine schedule to match.
steps_per_epoch = max(1, min(len(src_train_loader), len(tgt_train_loader)))
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS * steps_per_epoch)

# Running CFM-derived loss weights, updated in train_epoch.
lambda_s_run = 0.0
lambda_t_run = 0.0
pad = PseudoLabelDebiasor(num_classes=len(classnames))

@torch.no_grad()
def zero_shot_probs(z_img, E_text):
	"""Softmax over CLIP image/class logits, computed without autograd (used as pseudo-label targets)."""
	return logits_from_image_and_text(z_img, E_text).softmax(dim=-1)

def train_epoch():
	"""Run one adaptation epoch over paired source/target mini-batches.

	Each step minimizes
	  lam_sup * L_sup + LAMBDA_DCP * L_dcp
	where
	  * L_sup is the symmetric CLIP loss between weak views of labeled CIFAR-10 images and their
	    gold-class source-domain prompts, weighted by the running CFM factor ``lambda_s_run``;
	  * L_dcp is a confidence-masked cross-entropy pulling strong-view target predictions towards
	    debiased zero-shot pseudo-labels computed from the weak view.

	The epoch length is ``min(len(src_train_loader), len(tgt_train_loader))``.

	Returns:
		(mean L_sup, mean L_dcp) over the epoch's steps; (nan, nan) if there are no steps.
	"""
	global lambda_s_run, lambda_t_run
	clip_model.train()
	it_src = iter(src_train_loader)
	it_tgt = iter(tgt_train_loader)
	steps = min(len(src_train_loader), len(tgt_train_loader))
	if steps == 0:
		return float("nan"), float("nan")
	sup_meter, unsup_meter = 0.0, 0.0
	print()
	for st in range(steps):
		print(f"Step {st+1}/{steps}", end="\r")
		# ===== Source (labeled CIFAR-10) =====
		try:
			imgs_s, y = next(it_src)
		except StopIteration:
			it_src = iter(src_train_loader); imgs_s, y = next(it_src)
		y = y.to(device)
		# Gold-label text embeddings. The prompts are exactly those of E_text_src
		# ("a {SRC_DOMAIN} photo of a {class}") and the text encoder is frozen, so index the
		# precomputed matrix instead of re-tokenizing and re-encoding B prompts every step.
		z_txt = E_text_src[y]

		lam_s, m_s, z_w_s, z_s_s, z_o_s = cfm_and_dcm_for_batch(imgs_s)
		# Only the scalars are used on the source side; release the embeddings and any graph they hold.
		del z_w_s, z_s_s, z_o_s
		# NOTE(review): m weights the *new* value here (run = m*new + (1-m)*run), the opposite of
		# PseudoLabelDebiasor.step (run = m*old + (1-m)*new) — confirm which convention is intended.
		lambda_s_run = m_s * lam_s + (1 - m_s) * lambda_s_run
		lam_sup = float(lambda_s_run)

		# supervised uses weak-view images
		px_sup = torch.stack([to_clip(aug_weak(img)) for img in imgs_s]).to(device)
		z_img_sup = encode_images(px_sup)
		L_sup = symmetric_clip_loss(z_img_sup, z_txt)

		# ===== Target (unlabeled CIFAR-10-C) =====
		try:
			imgs_t = next(it_tgt)
		except StopIteration:
			it_tgt = iter(tgt_train_loader); imgs_t = next(it_tgt)

		lam_t, m_t, z_w_t, z_s_t, z_o_t = cfm_and_dcm_for_batch(imgs_t)
		del z_o_t  # plain view is only needed inside cfm_and_dcm_for_batch
		lambda_t_run = m_t * lam_t + (1 - m_t) * lambda_t_run
		lam_debias = float(lambda_t_run)

		q_weak = zero_shot_probs(z_w_t, E_text_tgt)          # [B,10]
		pad.step(q_weak, m_t)                                # update running class prior
		q_deb = pad.debias(q_weak, lam_debias).detach()      # debiased soft labels (stop-gradient target)

		# consistency on strong-view predictions (masked by confidence)
		logits_strong = logits_from_image_and_text(z_s_t, E_text_tgt)
		logp_strong = F.log_softmax(logits_strong, dim=-1)
		conf = q_deb.max(dim=1).values
		mask = (conf >= PL_CONF_THRESH).float()
		n_conf = mask.sum()
		if n_conf > 0:
			loss_vec = -(q_deb * logp_strong).sum(dim=1)
			L_dcp = (loss_vec * mask).sum() / (n_conf + 1e-6)
		else:
			L_dcp = torch.tensor(0.0, device=device)

		# ===== Combine & step =====
		loss = lam_sup * L_sup + LAMBDA_DCP * L_dcp
		opt.zero_grad(set_to_none=True)
		loss.backward()
		opt.step(); sched.step()

		sup_meter  += L_sup.item()
		unsup_meter += L_dcp.item()

	print()

	return sup_meter/steps, unsup_meter/steps

# Accuracy helpers: usable as zero-shot baselines before training and after adaptation.
@torch.no_grad()
def _prompt_accuracy(loader, E_text):
	"""Top-1 accuracy of prompt-based CLIP classification against the class-prompt matrix E_text.

	``loader`` must yield (list of PIL images, LongTensor labels), as produced by collate_src.
	Returns 0.0 for an empty loader.
	"""
	clip_model.eval()
	correct = total = 0
	for imgs, y in loader:
		y = y.to(device)
		px = torch.stack([to_clip(img) for img in imgs]).to(device)
		pred = logits_from_image_and_text(encode_images(px), E_text).argmax(1)
		correct += (pred == y).sum().item()
		total += y.size(0)
	return correct / max(total, 1)

@torch.no_grad()
def eval_clean(loader):
	"""Accuracy on clean CIFAR-10 using the source-domain (clean) prompts."""
	return _prompt_accuracy(loader, E_text_src)

@torch.no_grad()
def eval_c10c_mean(severities=(1, 2, 3, 4, 5)):
	"""Mean CIFAR-10-C accuracy using the target-domain prompts.

	Accuracy is averaged over ``severities`` within each corruption, then over CORRUPTIONS.
	"""
	accs = []
	for corr in CORRUPTIONS:
		print(f"Evaluating {corr}...")
		corr_acc = []
		for sev in severities:
			print(f"Evaluating {corr} at severity {sev}...")
			corr_acc.append(_prompt_accuracy(make_c10c_eval_loader(corr, sev), E_text_tgt))
		accs.append(np.mean(corr_acc))
	return float(np.mean(accs))


In [None]:
# Set to True to also measure the un-adapted (zero-shot) model first; the CIFAR-10-C pass builds
# 15 corruptions x 5 severities = 75 eval loaders, so it is slow.
RUN_ZERO_SHOT_BASELINE = False
if RUN_ZERO_SHOT_BASELINE:
	print("Zero-shot (pre-train) — CIFAR-10 clean acc:",  eval_clean(src_test_loader))
	print("Zero-shot (pre-train) — CIFAR-10-C mean acc:", eval_c10c_mean())

for ep in range(1, EPOCHS+1):
	t0 = time.time()
	Ls, Lt = train_epoch()
	print(f"[Epoch {ep:02d}] L_sup={Ls:.4f}  L_dcp={Lt:.4f}  ({time.time() - t0:.0f}s)")

print("Adapted — CIFAR-10 clean acc:",  eval_clean(src_test_loader))
print("Adapted — CIFAR-10-C mean acc:", eval_c10c_mean())



Step 88/391