# Feature Diffusion Replay

### Setup

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import wandb
from diffusers import UNet1DModel, DDPMScheduler
from copy import deepcopy
from models.unet.unet_1d_condition import UNet1DConditionModel
from diffusers import DPMSolverMultistepScheduler

import math
import pandas as pd
import matplotlib.pyplot as plt


SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

In [9]:
import sys
sys.path.insert(0,'/kaggle/working/diffusers-unet-1d-condition')

In [None]:
# The API key must not live in the notebook: read it from the environment.
# When WANDB_API_KEY is unset, `key=None` lets wandb fall back to its own
# credential lookup / interactive prompt.
# NOTE(review): the key that was previously committed here should be revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

### Configuration

In [12]:
# --- Data ---
DATA_ROOT    = "/kaggle/input/officehome/OfficeHome"
IMG_SIZE     = 224
BATCH_SIZE   = 64

# --- Base classifier training ---
LR           = 1e-4
NUM_EPOCHS   = {
    "BASE_SOURCE_TRAIN": 50,
}
PATIENCE     = 5

# --- Self-training hyper-parameters (recorded in the W&B run config) ---
TH_START = 0.6
TH_END = 0.9
ENT_WEIGHT = 0.1
EMA_DECAY = 0.995

# --- Feature-space diffusion ---
# NOTE(review): the custom diffusion cells use seq_length = 2048; confirm
# whether FEATURE_DIM = 512 is still the intended feature width.
FEATURE_DIM      = 512
DIFFUSION_STEPS  = 1000
DIFFUSION_EPOCHS = 5

# --- Domain order: the source domain first, then the adaptation targets ---
SOURCE_DOM = "RealWorld"
DOWN_DOM = ["Product", "Clipart", "Art"]
# Single definition (an earlier, differently ordered list used to be
# overwritten by this one); derived so the order cannot drift.
DOMAINS = [SOURCE_DOM, *DOWN_DOM]

### Data Loading

In [13]:
## Defining the loaders and transforms

def get_transforms(img_size=IMG_SIZE):
    """Build the training and evaluation image pipelines.

    Both pipelines resize to ``img_size`` x ``img_size``, convert to a tensor
    and normalise each channel with mean 0.5 / std 0.5. The training pipeline
    additionally applies a random horizontal flip right after the resize.

    Returns:
        tuple: ``(train_transform, test_transform)``.
    """
    target_size = (img_size, img_size)

    def _tensor_and_normalise():
        # Fresh instances per call so the two pipelines share no objects.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5]),
        ]

    train_steps = [transforms.Resize(target_size), transforms.RandomHorizontalFlip()]
    eval_steps = [transforms.Resize(target_size)]

    train_t = transforms.Compose(train_steps + _tensor_and_normalise())
    test_t = transforms.Compose(eval_steps + _tensor_and_normalise())
    return train_t, test_t

def load_domain(domain, val_split=0.1, test_split=0.1):
    """Create train/val/test loaders for one domain folder under ``DATA_ROOT``.

    The split is drawn with a generator seeded by ``SEED``, so it is the same
    on every call for a given folder.

    Args:
        domain: Sub-directory name of the domain.
        val_split: Fraction of the images held out for validation.
        test_split: Fraction of the images held out for testing.

    Returns:
        dict: ``{"train", "val", "test"}`` -> ``DataLoader``. Each loader's
        ``.dataset`` is a ``Subset`` whose ``.dataset`` is the underlying
        ``ImageFolder`` (so ``.dataset.dataset.classes`` keeps working).
    """
    from torch.utils.data import Subset

    train_t, test_t = get_transforms()
    path = os.path.join(DATA_ROOT, domain)

    # Two views of the same folder. Previously all three subsets wrapped ONE
    # ImageFolder, so assigning the eval transform to `val_ds.dataset` also
    # replaced the training transform and silently disabled augmentation.
    train_view = datasets.ImageFolder(path, transform=train_t)
    eval_view = datasets.ImageFolder(path, transform=test_t)

    n = len(train_view)
    v = int(val_split * n)
    t = int(test_split * n)
    train_ds, val_part, test_part = random_split(
        train_view,
        [n - v - t, v, t],
        generator=torch.Generator().manual_seed(SEED)
    )
    # ImageFolder lists files in sorted order, so the indices drawn on
    # `train_view` address the same images in `eval_view`.
    val_ds = Subset(eval_view, val_part.indices)
    test_ds = Subset(eval_view, test_part.indices)

    return {
        "train": DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True),
        "val":   DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False),
        "test":  DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False),
    }

In [None]:
## Checking loaders

rw_loaders = load_domain("RealWorld")
print("Found classes:", rw_loaders["train"].dataset.dataset.classes)
imgs, labels = next(iter(rw_loaders["train"]))
print("Batch shape:", imgs.shape, "— labels example:", labels[:5])

### Preliminary EDA

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Build a table of (domain, class, num_images)
records = []
for domain in DOMAINS:
    domain_path = os.path.join(DATA_ROOT, domain)
    for cls in sorted(os.listdir(domain_path)):
        cls_path = os.path.join(domain_path, cls)
        if not os.path.isdir(cls_path):
            continue
        num_imgs = sum(
            1 for fname in os.listdir(cls_path)
            if fname.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))
        )
        records.append({"domain": domain, "class": cls, "num_images": num_imgs})

eda_df = pd.DataFrame(records)

# 1. Check class name consistency
class_sets = {d: set(eda_df[eda_df['domain']==d]['class']) for d in DOMAINS}
common_classes = set.intersection(*class_sets.values())
extra = {d: class_sets[d] - common_classes for d in DOMAINS}
missing = {d: common_classes - class_sets[d] for d in DOMAINS}

print(f"Number of common classes across all domains: {len(common_classes)}")
for d in DOMAINS:
    print(f"- {d}: {len(class_sets[d])} classes, {len(extra[d])} extra, {len(missing[d])} missing")

print("\nExtra classes by domain:")
for d, ex in extra.items():
    print(f"  {d}: {sorted(ex)}")
print("\nMissing classes by domain:")
for d, ms in missing.items():
    print(f"  {d}: {sorted(ms)}")

# 2. Check per-class image count consistency
pivot = eda_df.pivot(index='class', columns='domain', values='num_images').fillna(0).astype(int)
display(pivot.head())

# Identify classes with differing counts
inconsistent = pivot[pivot.nunique(axis=1) > 1]
print(f"\nClasses with differing image counts across domains ({len(inconsistent)}):")
print(inconsistent.index.tolist())

# 3. (Optional) Visualize count distribution for inconsistent classes
if not inconsistent.empty:
    inconsistent.plot.bar(figsize=(10,4))
    plt.title("Image Count per Class for Inconsistent Classes")
    plt.ylabel("Count")
    plt.xlabel("Class")
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.show()


### Base Classifier

In [15]:
### Model Definitions

def build_resnet(num_classes):
    """Return a pretrained ResNet-50 with a new ``num_classes``-way head, on ``device``."""
    backbone = models.resnet50(pretrained=True)
    head_inputs = backbone.fc.in_features
    # Swap the original classification layer for one sized to this task.
    backbone.fc = nn.Linear(head_inputs, num_classes)
    backbone = backbone.to(device)
    return backbone

def validate(model, loader):
    """Return the top-1 accuracy of ``model`` on ``loader``, in percent.

    The model is switched to eval mode and is NOT switched back; callers that
    continue training must call ``model.train()`` themselves. Batches are
    moved to the module-level ``device``.

    Args:
        model: Classifier returning ``(batch, num_classes)`` logits.
        loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        float: ``100 * correct / total``, or ``0.0`` when ``loader`` yields no
        samples (instead of raising ``ZeroDivisionError``).
    """
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    if total == 0:
        return 0.0
    return 100 * correct / total


In [None]:
wandb.init(
    project="FDR_UCDA",
    name="base_classifier_real_world",
    config={
        "domain": "Real_World",
        "model": "resnet50",
        "batch_size": BATCH_SIZE,
        "lr": LR,
        "epochs": NUM_EPOCHS["BASE_SOURCE_TRAIN"],
        "patience": PATIENCE
    }
)

In [34]:
def train_base(model, loaders, epochs, patience, device):
    """Train the source classifier with early stopping and W&B logging.

    Requires an active W&B run: the learning rate is read from
    ``wandb.config.lr`` and the run is closed with ``wandb.finish()`` before
    returning. The best checkpoint is written to ``best_base_model.pth``.

    Args:
        model: Classifier to optimise. It is never moved by this function, so
            it must already be on ``device``.
        loaders: Mapping with ``"train"``, ``"val"`` and ``"test"`` loaders
            yielding ``(images, labels)`` batches.
        epochs: Maximum number of epochs.
        patience: Stop after this many consecutive epochs without a new best
            validation accuracy.
        device: Device the training/validation batches are moved to. The final
            test pass goes through ``validate``, which uses the module-level
            ``device``.

    Returns:
        ``model`` with the best-validation weights loaded.
    """
    config = wandb.config
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    wandb.watch(model, criterion, log="all", log_freq=50)

    # -inf (not 0.0) so the first epoch always writes a checkpoint; otherwise a
    # run stuck at 0% accuracy would never save and the final load would fail
    # or pick up a stale file.
    best_val_acc = float("-inf")
    patience_counter = 0

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        running_correct = 0
        running_total = 0

        for imgs, labels in tqdm(loaders["train"], desc=f"Train Epoch {epoch}/{epochs}"):
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch loss is a per-sample mean.
            running_loss += loss.item() * imgs.size(0)
            preds = outputs.argmax(dim=1)
            running_correct += (preds == labels).sum().item()
            running_total += labels.size(0)

        train_loss = running_loss / running_total
        train_acc = 100 * running_correct / running_total

        # Validation
        model.eval()
        val_correct, val_total = 0, 0
        val_loss = 0.0
        with torch.no_grad():
            for imgs, labels in loaders["val"]:
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = model(imgs)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * imgs.size(0)
                preds = outputs.argmax(dim=1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)
        val_loss /= val_total
        val_acc = 100 * val_correct / val_total

        # Log to W&B
        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc
        })

        print(f"Epoch {epoch}: Train loss={train_loss:.4f}, acc={train_acc:.2f}% | "
              f"Val loss={val_loss:.4f}, acc={val_acc:.2f}%")

        # Early stopping & checkpoint
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), "best_base_model.pth")
            print(f"New best model saved (val_acc={val_acc:.2f}%)")
            wandb.save("best_base_model.pth")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Stopping early at epoch {epoch} (no improvement in {patience} epochs)")
                break

    # Load best weights (tensors only, mapped onto the training device).
    model.load_state_dict(
        torch.load("best_base_model.pth", map_location=device, weights_only=True)
    )
    final_test_acc = validate(model, loaders["test"])
    print(f"Final Test Accuracy: {final_test_acc:.2f}%")
    wandb.log({"test_acc": final_test_acc})
    wandb.finish()
    return model


In [None]:
# Usage:
rw_loaders = load_domain(SOURCE_DOM)
model = build_resnet(len(rw_loaders["train"].dataset.dataset.classes))
model = train_base(model, rw_loaders, NUM_EPOCHS["BASE_SOURCE_TRAIN"], patience=3, device=device)

### Zero Shot Results

In [None]:
pr_loader   = load_domain("Product")
art_loader  = load_domain("Art")
cart_loader = load_domain("Clipart")

base_model = build_resnet(len(rw_loaders["train"].dataset.dataset.classes))
base_model.load_state_dict(torch.load('best_base_model.pth', weights_only=True))

# for loader in [rw_loaders, pr_loader, art_loader, cart_loader]:
#     base_model.eval()
#     correct = total = 0
#     with torch.no_grad():
#         for x,y in loader['val']:
#             x,y = x.to(device), y.to(device)
#             pred = base_model(x).argmax(1)
#             correct += (pred==y).sum().item(); total += y.size(0)
#     print(100*correct/total)



|Domain | Accuracy (%) |
|----------------|--------------|
| Real World        | **84.60**    |
| Product            | **72.46**    |
| Art       | **58.68**    |
| Clip Art         | **52.98**    |

In [None]:
torch.manual_seed(SEED)

### Downstream Training without Handling Catastrophic Forgetting

In [None]:
wandb.init(
    project="FDR_UCDA",
    name="pseudo_only_adapt",
    config={
        "domains": DOMAINS,
        "lr": 1e-4,
        "epochs_per_domain": 20,
        "patience": 3,
        "pseudo_thresh": 0.9
    }
)

In [17]:
loaders = {i: load_domain(i) for i in DOMAINS}

In [18]:

def get_highconf_pseudolabels(model, imgs, threshold=0.8):
    """Predict hard pseudo-labels and flag the confident ones.

    Args:
        model: Callable returning ``(batch, num_classes)`` logits.
        imgs: Input batch passed to ``model``.
        threshold: A prediction counts as confident when its softmax
            probability is strictly greater than this value.

    Returns:
        tuple: ``(mask, y_hat)`` - a boolean tensor marking confident samples
        and the arg-max class index of every sample.
    """
    with torch.no_grad():
        class_probs = torch.softmax(model(imgs), dim=1)
    top_prob, top_class = torch.max(class_probs, dim=1)
    return top_prob > threshold, top_class

def adapt_wo_replay(model, loaders, epochs, threshold_start, threshold_end,
    lr=1e-5, weight_decay=1e-4, alpha=0.5, ema_decay=0.999, device='cuda', patience = 3):
    """Adapt ``model`` to an unlabeled domain with an EMA teacher and no replay.

    Each step trains the student on a KL loss against the teacher's soft
    predictions (confident samples only) plus entropy minimisation on all
    samples. Training labels are ignored; validation labels are used only for
    model selection. Every new best epoch is saved to ``best_model_{epoch}.pth``.

    Args:
        model: Student classifier, updated in place.
        loaders: Mapping with ``"train"`` and ``"val"`` loaders.
        epochs: Maximum number of epochs.
        threshold_start: Confidence threshold at the first step.
        threshold_end: Confidence threshold reached at the last step (linear
            ramp over ``epochs * len(loaders["train"])`` steps).
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        alpha: Weight of the pseudo-label loss; the entropy term gets
            ``1 - alpha``.
        ema_decay: Decay of the teacher's exponential moving average.
        device: Device the batches are moved to.
        patience: Training stops once MORE than ``patience`` consecutive epochs
            pass without a new best validation accuracy.

    Returns:
        ``model`` loaded with the best-validation checkpoint. This now also
        happens when all epochs run (previously only on early stop).
    """
    teacher = deepcopy(model)
    for p in teacher.parameters():
        p.requires_grad = False

    optimizer = optim.Adam(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )
    criterion = nn.KLDivLoss(reduction="batchmean")  # for soft labels

    total_steps = epochs * len(loaders["train"])
    step = 0

    counter = 0
    # -inf so the first epoch always produces a checkpoint; this guarantees
    # `best_model_name` is set before it is ever loaded.
    max_val_acc = float("-inf")
    best_model_name = ""

    for epoch in range(1, epochs+1):
        model.train()
        teacher.train()  # teacher stays in train mode for batchnorm stats
        for imgs, _ in loaders["train"]:
            imgs = imgs.to(device)
            step += 1

            # 1) Teacher predictions -> soft pseudo-labels
            with torch.no_grad():
                t_logits = teacher(imgs)
                t_probs  = F.softmax(t_logits, dim=1).float()

            # 2) Student predictions
            s_logits = model(imgs)
            s_logprobs = F.log_softmax(s_logits, dim=1).float()

            # 3) Compute dynamic threshold
            thresh = threshold_start + (threshold_end - threshold_start) * (step / total_steps)

            # 4) Mask & pseudo-label loss on high-conf samples
            conf, _ = t_probs.max(dim=1)
            mask = conf > thresh
            if mask.any():
                loss_pseudo = criterion(
                    s_logprobs[mask], t_probs[mask]
                )
            else:
                loss_pseudo = torch.tensor(0.0, device=device)

            # 5) Entropy minimization on all samples
            loss_ent = -(F.softmax(s_logits, dim=1) * s_logprobs).sum(dim=1).mean()

            loss = alpha * loss_pseudo + (1-alpha) * loss_ent

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # 6) EMA update of teacher (parameters only; buffers such as
            #    batch-norm running statistics are not averaged here).
            with torch.no_grad():
                for tp, sp in zip(teacher.parameters(), model.parameters()):
                    tp.data.mul_(ema_decay).add_(sp.data, alpha=1-ema_decay)

        # Validation
        val_acc = validate(model, loaders["val"])
        print(f"Epoch {epoch}: Val Acc={val_acc:.2f}%")

        if val_acc > max_val_acc:
            best_model_name = f"best_model_{epoch}.pth"
            torch.save(model.state_dict(), best_model_name)
            counter = 0
            max_val_acc = val_acc
        else:
            counter += 1

        if counter > patience:
            break

    # Restore the best checkpoint on every exit path.
    if best_model_name:
        model.load_state_dict(
            torch.load(best_model_name, map_location=device, weights_only=True)
        )
    return model


In [None]:

# Run adaptation to Product
base_model = adapt_wo_replay(
    base_model, 
    loaders["Product"], 
    epochs=5,
    threshold_start=0.7,
    threshold_end=0.9,
    device=device,
)


In [None]:
for loader in [rw_loaders, pr_loader, art_loader, cart_loader]:
    base_model.eval()
    correct = total = 0
    with torch.no_grad():
        for x,y in loader['test']:
            x,y = x.to(device), y.to(device)
            pred = base_model(x).argmax(1)
            correct += (pred==y).sum().item(); total += y.size(0)
    print(100*correct/total)

Accuracy clearly increases on the domain we just adapted to, but it decreases on the previously seen data because of catastrophic forgetting.

In [None]:

# Run adaptation to Art
base_model = adapt_wo_replay(
    base_model, 
    loaders["Art"], 
    epochs=5,
    threshold_start=0.7,
    threshold_end=0.9,
    device=device,
)


In [None]:
for loader in [rw_loaders, pr_loader, art_loader, cart_loader]:
    base_model.eval()
    correct = total = 0
    with torch.no_grad():
        for x,y in loader['test']:
            x,y = x.to(device), y.to(device)
            pred = base_model(x).argmax(1)
            correct += (pred==y).sum().item(); total += y.size(0)
    print(100*correct/total)

### Feature-space Diffusion Model

In [None]:
class ConditionalFeatureDDPM(nn.Module):
    """Noise predictor over 1-D feature vectors, conditioned on class and domain.

    Conditioning is additive: learned class and domain embeddings of width
    ``feat_dim`` are added to the noisy input before it enters the 1-D U-Net.
    """

    def __init__(self, feat_dim, num_classes, num_domains):
        super().__init__()
        self.class_embed = nn.Embedding(num_classes, feat_dim)
        self.domain_embed = nn.Embedding(num_domains, feat_dim)

        unet_kwargs = dict(
            sample_size=feat_dim,
            in_channels=1,
            out_channels=1,
            block_out_channels=(64, 128, 128, 64),
            layers_per_block=2,
            down_block_types=("DownBlock1D", "DownBlock1D", "AttnDownBlock1D", "DownBlock1D"),
            up_block_types=("UpBlock1D", "AttnUpBlock1D", "UpBlock1D", "UpBlock1D"),
        )
        self.unet = UNet1DModel(**unet_kwargs)

    def forward(self, x, t, class_ids, domain_ids):
        """Return the ``.sample`` output of the U-Net, used as the predicted noise.

        Args:
            x: Noisy features, ``[B, 1, feat]``.
            t: Timesteps, ``[B]``.
            class_ids: Class indices, ``[B]``.
            domain_ids: Domain indices, ``[B]``.
        """
        class_vec = self.class_embed(class_ids).unsqueeze(1)     # [B, 1, feat]
        domain_vec = self.domain_embed(domain_ids).unsqueeze(1)  # [B, 1, feat]
        conditioned = x + class_vec + domain_vec
        unet_out = self.unet(conditioned, t)
        return unet_out.sample

@torch.no_grad()
def sample_replay_feats(ddpm, scheduler, class_ids, domain_ids, device):
    """Draw replay features by running the full reverse diffusion chain.

    Runs under ``torch.no_grad()``: the result is used as a fixed input, and
    tracking gradients through every denoising step would only retain a very
    large autograd graph.

    Args:
        ddpm: Callable ``ddpm(x, t, class_ids, domain_ids)`` returning the
            predicted noise; must expose ``ddpm.unet.config.sample_size``.
        scheduler: Scheduler providing ``num_train_timesteps``,
            ``set_timesteps``, ``timesteps`` and ``step(...).prev_sample``.
        class_ids: Class index per sample, length ``B``.
        domain_ids: Domain index per sample, length ``B``.
        device: Device on which the initial noise is created.

    Returns:
        Tensor of shape ``[B, feat_dim]`` that does not require grad.
    """
    B = len(class_ids)
    x = torch.randn(B, 1, ddpm.unet.config.sample_size, device=device)
    # One inference step per training step, i.e. the whole chain.
    scheduler.set_timesteps(scheduler.num_train_timesteps)
    for t in scheduler.timesteps:
        eps = ddpm(x, t.expand(B), class_ids, domain_ids)
        x = scheduler.step(eps, t, x).prev_sample
    return x.squeeze(1)  # [B, feat_dim]


class MeanTeacher:
    """Pair a trainable student with a frozen exponential-moving-average teacher.

    The teacher starts as a deep copy of the student, is put in eval mode and
    has gradients disabled on all of its parameters.
    """

    def __init__(self, student, ema_decay=0.99):
        self.student = student

        frozen_copy = deepcopy(student)
        frozen_copy.eval()
        for weight in frozen_copy.parameters():
            weight.requires_grad = False
        self.teacher = frozen_copy

        self.ema_decay = ema_decay

    @torch.no_grad()
    def update_teacher(self):
        """Update in place: ``teacher = decay * teacher + (1 - decay) * student``.

        Only parameters are averaged; buffers are left as copied at construction.
        """
        decay = self.ema_decay
        paired = zip(self.teacher.parameters(), self.student.parameters())
        for teacher_w, student_w in paired:
            teacher_w.data.mul_(decay)
            teacher_w.data.add_(student_w.data, alpha=1-decay)

# ---------------------
# 3) BASE CLASSIFIER
# ---------------------
def build_classifier(num_classes):
    """Return a pretrained ResNet-18 with a new ``num_classes``-way head (not moved to any device)."""
    net = models.resnet18(pretrained=True)
    head_inputs = net.fc.in_features
    net.fc = nn.Linear(head_inputs, num_classes)
    return net

# ---------------------
# 4) TRAINING LOOP
# ---------------------
def train_with_diffusion_replay(
    student, ddpm, scheduler,
    source_feats, source_labels,    # tensors on CPU
    domain_order,                   # e.g. ["RealWorld","Art","Clipart","Product"]
    data_loaders,                   # dict domain->{"train","val","test"} loaders
    device,
    epochs_per_domain=3,
    pseudo_thresh=0.8,
    lr=1e-4
):
    """Continually adapt ``student`` across domains using diffusion feature replay.

    Phase 1 fits ``ddpm`` on source features: ``epochs_per_domain`` optimiser
    steps, each on up to 512 randomly chosen source features, all tagged with
    domain index 0.

    Phase 2 visits every domain after the first in ``domain_order``. Per batch
    the student loss is: cross-entropy on confident teacher pseudo-labels,
    plus cross-entropy of ``student.fc`` on replayed features sampled from
    ``ddpm``, plus ``0.1 *`` prediction entropy. The EMA teacher is updated
    after every step.

    NOTE(review): ``ddpm`` is only trained with domain index 0 here, while
    replay conditions on ``domain_id - 1``, which is >= 1 from the second
    target domain onward - confirm this is intended.

    Args:
        student: Classifier exposing an ``fc`` head that accepts the replayed
            feature vectors.
        ddpm: Conditional noise predictor ``ddpm(x, t, class_ids, domain_ids)``.
        scheduler: Diffusion scheduler used for ``add_noise`` and sampling.
        source_feats: Source-domain feature tensor, indexed along dim 0.
        source_labels: Integer class labels aligned with ``source_feats``.
        domain_order: Domain names; the first entry is treated as the source
            and skipped in phase 2.
        data_loaders: ``domain -> {"train": loader, ...}``; only ``"train"``
            is read and its labels are ignored.
        device: Device used for models and batches.
        epochs_per_domain: DDPM warm-up steps AND epochs per target domain.
        pseudo_thresh: Teacher confidence must exceed this for a pseudo-label
            to be used.
        lr: Learning rate for both Adam optimisers.

    Returns:
        tuple: ``(student, teacher)`` after adaptation.
    """
    # prepare
    mt = MeanTeacher(student.to(device), ema_decay=0.995)
    optimizer = optim.Adam(mt.student.parameters(), lr=lr)
    ce = nn.CrossEntropyLoss()
    mse = nn.MSELoss()

    # train ddpm initially on source
    ddpm = ddpm.to(device)
    optD = optim.Adam(ddpm.parameters(), lr=lr)
    for epoch in range(epochs_per_domain):
        idx = torch.randperm(len(source_feats))[:512]
        h0 = source_feats[idx].to(device)
        y0 = source_labels[idx].to(device)
        d0 = torch.zeros_like(y0, device=device)  # source domain index=0

        h0 = h0.unsqueeze(1)
        noise = torch.randn_like(h0)
        t = torch.randint(0, scheduler.num_train_timesteps, (h0.size(0),), device=device)

        noised = scheduler.add_noise(h0, noise, t)
        pred  = ddpm(noised, t, y0, d0)
        lossD = mse(pred, noise)
        optD.zero_grad(); lossD.backward(); optD.step()

    # now continual adaptation
    for domain_id, domain in enumerate(domain_order[1:], start=1):
        for epoch in range(epochs_per_domain):
            mt.student.train()
            for imgs, _ in tqdm(data_loaders[domain]["train"], desc=f"[{domain}] Ep{epoch+1}"):
                imgs = imgs.to(device)
                B = imgs.size(0)

                # -- (1) Teacher pseudo-labels --
                with torch.no_grad():
                    t_logits = mt.teacher(imgs)
                    t_probs  = F.softmax(t_logits, dim=1)
                    conf, y_pseudo = t_probs.max(1)
                    mask = conf > pseudo_thresh

                # -- (2) Student supervised step on pseudo-labels --
                if mask.any():
                    out_s = mt.student(imgs[mask])
                    # y_pseudo already lives on `device` (computed from imgs).
                    loss_pl = ce(out_s, y_pseudo[mask])
                else:
                    loss_pl = torch.tensor(0., device=device)

                # -- (3) Diffusion replay --
                # Uniformly random classes, conditioned on the PREVIOUS domain
                # index (domain_id - 1).
                cls_ids = torch.randint(0, student.fc.out_features, (B,), device=device)
                dom_ids = torch.full((B,), domain_id-1, dtype=torch.long, device=device)
                # The DDPM is not optimised in this loop, so sampling needs no
                # autograd graph; without no_grad every denoising step would be
                # retained and back-propagated through.
                with torch.no_grad():
                    feats_replay = sample_replay_feats(ddpm, scheduler, cls_ids, dom_ids, device)
                out_r = mt.student.fc(feats_replay)
                loss_re = ce(out_r, cls_ids)

                # -- (4) Entropy regularization on all imgs --
                out_all = mt.student(imgs)
                logp_all = F.log_softmax(out_all, dim=1)
                loss_ent = -(F.softmax(out_all, dim=1) * logp_all).sum(1).mean()

                # total loss
                loss = loss_pl + loss_re + 0.1*loss_ent
                optimizer.zero_grad(); loss.backward(); optimizer.step()

                # EMA update
                mt.update_teacher()

        # end of domain epoch

    return mt.student, mt.teacher


### Adaptation Model and Loop

### Evaluation

### Final Adaptive Model

In [None]:
### Configuration

In [None]:
### Creating Feature Extractor
from torchvision.models.feature_extraction import create_feature_extractor
return_nodes = {"avgpool": "features"}
feature_extractor = create_feature_extractor(base_model, return_nodes)

In [None]:
wandb.init(
    project="FDR_UCDA",
    name="mt_diffusion_replay_1d",
    config=dict(
        lr=LR, batch=BATCH_SIZE, epochs=6,
        diff_steps=DIFFUSION_STEPS, ema=EMA_DECAY,
        th_start=TH_START, th_end=TH_END,
        ent_weight=ENT_WEIGHT
    )
)

In [None]:
### DIFFUSION SETUP
# Noise schedule; DIFFUSION_STEPS is the configured constant (this cell used
# the undefined name DIFF_STEPS).
scheduler = DPMSolverMultistepScheduler(
    num_train_timesteps=DIFFUSION_STEPS, prediction_type="epsilon"
)

# NOTE(review): `student` is not defined in any cell visible here - it must
# exist before this cell runs; confirm which classifier it should refer to.
# `loaders` is keyed by domain name, so the class list is read from the
# source domain's train split.
ddpm = UNet1DConditionModel(
    sample_size=student.fc.in_features,  # 512 for ResNet‑18
    in_channels=1, out_channels=1,
    layers_per_block=2,
    block_out_channels=(64,64,128,128,256),
    down_block_types=("DownBlock1D","AttnDownBlock1D","ResnetDownsampleBlock1D","DownBlock1D","DownBlock1D"),
    up_block_types=("UpBlock1D","ResnetUpsampleBlock1D","AttnUpBlock1D","UpBlock1D","UpBlock1D"),
    num_class_embeds=len(loaders[SOURCE_DOM]["train"].dataset.dataset.classes),
    class_embeddings_concat=True,
    encoder_hid_dim=64
).to(device)
opt_ddpm = optim.Adam(ddpm.parameters(), lr=LR)


### Custom 1d diffusion Implementation

In [19]:
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar positions (e.g. diffusion timesteps).

    Uses ``dim // 2`` geometrically spaced frequencies from 1 down to 1/10000;
    the output concatenates all sines followed by all cosines.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):  # x: (batch,)
        half_dim = self.dim // 2
        # Log-spacing between consecutive frequencies.
        log_step = math.log(10000) / (half_dim - 1)
        positions = torch.arange(half_dim, device=x.device)
        freqs = torch.exp(positions * -log_step)
        angles = x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)  # (batch, 2 * (dim // 2))

class ConditionalUNet1D(nn.Module):
    """1-D convolutional U-Net noise predictor conditioned on time, domain and class.

    Encoder: one conv block + 2x average pooling per entry of ``channel_mults``.
    Decoder: mirrors the encoder with transposed-conv upsampling and skip
    concatenation. Conditioning is applied once, on the last decoder feature
    map, as a per-channel scale ``h * (1 + cond)`` where ``cond`` sums linear
    projections of the time, domain and class embeddings.

    The input length must be divisible by ``2 ** len(channel_mults)`` so the
    upsampled maps line up with their skip connections.

    Args:
        input_channels: Channels of the input and of the predicted noise.
        base_channels: Width of the first conv; stage widths are
            ``base_channels * mult``.
        channel_mults: Width multiplier per resolution stage.
        time_emb_dim: Size of the timestep embedding.
        domain_vocab_size: Number of distinct domain ids.
        class_vocab_size: Number of distinct class ids.
        cond_emb_dim: Size of the domain and class embeddings.
    """
    def __init__(self,
                 input_channels=1,
                 base_channels=64,
                 channel_mults=(1, 2, 4),
                 time_emb_dim=128,
                 domain_vocab_size=10,
                 class_vocab_size=10,
                 cond_emb_dim=32):
        super().__init__()
        # time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim * 2),
            nn.SiLU(),
            nn.Linear(time_emb_dim * 2, time_emb_dim)
        )
        # domain & class embedding
        self.domain_embed = nn.Embedding(domain_vocab_size, cond_emb_dim)
        self.class_embed = nn.Embedding(class_vocab_size, cond_emb_dim)
        # initial conv
        chs = base_channels
        self.initial_conv = nn.Conv1d(input_channels, chs, kernel_size=3, padding=1)

        # downsample
        self.downs = nn.ModuleList()
        in_ch = chs
        for mult in channel_mults:
            out_ch = base_channels * mult
            block = nn.Sequential(
                nn.SiLU(),
                nn.Conv1d(in_ch, out_ch, 3, padding=1),
                nn.SiLU(),
                nn.Conv1d(out_ch, out_ch, 3, padding=1)
            )
            down = nn.Module()
            down.block = block
            down.pool = nn.AvgPool1d(2)
            self.downs.append(down)
            in_ch = out_ch

        # bottleneck
        self.mid_block1 = nn.Sequential(
            nn.SiLU(),
            nn.Conv1d(in_ch, in_ch, 3, padding=1)
        )
        self.mid_block2 = nn.Sequential(
            nn.SiLU(),
            nn.Conv1d(in_ch, in_ch, 3, padding=1)
        )

        # upsample
        self.ups = nn.ModuleList()
        for mult in reversed(channel_mults):
            out_ch = base_channels * mult
            # block expects concatenated skip(=out_ch) and upsampled h(=out_ch)
            block = nn.Sequential(
                nn.SiLU(),
                nn.Conv1d(out_ch * 2, out_ch, 3, padding=1),
                nn.SiLU(),
                nn.Conv1d(out_ch, out_ch, 3, padding=1)
            )
            up = nn.Module()
            up.block = block
            # transpose conv upsamples h from in_ch to out_ch channels
            up.upsample = nn.ConvTranspose1d(in_ch, out_ch, kernel_size=2, stride=2)
            self.ups.append(up)
            in_ch = out_ch

        # The decoder ends with `in_ch` = base_channels * channel_mults[0]
        # channels. Sizing the head from `in_ch` (rather than base_channels)
        # keeps it valid when channel_mults[0] != 1 and is identical otherwise.
        self.final_conv = nn.Sequential(
            nn.SiLU(),
            nn.Conv1d(in_ch, input_channels, 1)
        )

        # embed to modulate via FiLM
        self.time_cond = nn.Linear(time_emb_dim, in_ch)
        self.domain_cond = nn.Linear(cond_emb_dim, in_ch)
        self.class_cond = nn.Linear(cond_emb_dim, in_ch)

    def forward(self, x, t, domain_id, class_id):
        """Predict the noise in ``x``.

        Args:
            x: ``(batch, channels, length)`` noisy input.
            t: ``(batch,)`` timesteps.
            domain_id: ``(batch,)`` domain indices.
            class_id: ``(batch,)`` class indices.

        Returns:
            Tensor with ``input_channels`` channels and the input's length.
        """
        t_emb = self.time_mlp(t)
        d_emb = self.domain_embed(domain_id)
        c_emb = self.class_embed(class_id)

        h = self.initial_conv(x)
        # One skip per encoder stage, taken before pooling.
        skips = []
        for down in self.downs:
            h = down.block(h)
            skips.append(h)
            h = down.pool(h)

        h = self.mid_block1(h)
        h = self.mid_block2(h)

        # Upsample before concatenation to match skip dimensions
        for up in self.ups:
            skip = skips.pop()
            h = up.upsample(h)
            h = torch.cat([h, skip], dim=1)
            h = up.block(h)

        # Per-channel multiplicative modulation, broadcast along the length.
        cond = self.time_cond(t_emb) + self.domain_cond(d_emb) + self.class_cond(c_emb)
        cond = cond.unsqueeze(-1)
        h = h * (1 + cond)
        return self.final_conv(h)


class GaussianDiffusion1D(nn.Module):
    """DDPM forward/reverse process over ``(batch, 1, seq_length)`` signals.

    Uses a linear beta schedule. The schedule tensors are registered as
    buffers (``betas``, ``alphas_cum``, ``sqrt_alphas_cum``,
    ``sqrt_one_minus_ac``), so they are part of the state dict and follow
    ``.to(...)``.

    Args:
        model: Noise predictor called as ``model(x, t, domain_id, class_id)``.
        seq_length: Length of the generated signal.
        timesteps: Number of diffusion steps.
        beta_start: First beta of the linear schedule.
        beta_end: Last beta of the linear schedule.
        device: Device on which the schedule buffers are created.
    """
    def __init__(self, model, seq_length, timesteps=1000, beta_start=1e-4, beta_end=0.02, device='cpu'):
        super().__init__()
        self.model = model
        self.seq_length = seq_length
        self.timesteps = timesteps
        self.device = device
        # linear schedule, created on `device` so the buffers can be indexed
        # by timestep tensors on that device without an extra .to() call.
        betas = torch.linspace(beta_start, beta_end, timesteps, device=device)
        alphas_cum = torch.cumprod(1.0 - betas, dim=0)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cum', alphas_cum)
        self.register_buffer('sqrt_alphas_cum', torch.sqrt(alphas_cum))
        self.register_buffer('sqrt_one_minus_ac', torch.sqrt(1 - alphas_cum))

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to step ``t``: ``sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise``."""
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_ac = self.sqrt_alphas_cum[t].view(-1, 1, 1)
        sqrt_om = self.sqrt_one_minus_ac[t].view(-1, 1, 1)
        return sqrt_ac * x_start + sqrt_om * noise

    def p_losses(self, x_start, t, domain_id, class_id, noise=None):
        """Return the MSE between the model's noise prediction and the true noise."""
        # `noise or ...` would evaluate the tensor's truth value, which raises
        # for any tensor with more than one element.
        if noise is None:
            noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start, t, noise)
        predicted = self.model(x_noisy, t, domain_id, class_id)
        return F.mse_loss(predicted, noise)

    @torch.no_grad()
    def sample(self, batch_size, domain_id, class_id):
        """Generate ``batch_size`` samples by ancestral sampling from pure noise.

        Each step applies the DDPM posterior mean
        ``(x_t - beta_t / sqrt(1 - a_bar_t) * eps) / sqrt(alpha_t)`` and, for
        ``t > 0``, adds ``sqrt(beta_t)``-scaled Gaussian noise.

        Returns:
            Tensor of shape ``(batch_size, 1, seq_length)``.
        """
        # Follow the buffers rather than the constructor argument, so sampling
        # still works after the module has been moved with .to(...).
        dev = self.betas.device
        x = torch.randn(batch_size, 1, self.seq_length, device=dev)
        for i in reversed(range(self.timesteps)):
            t = torch.full((batch_size,), i, device=dev, dtype=torch.long)
            noise_pred = self.model(x, t, domain_id, class_id)
            beta = self.betas[t].view(-1, 1, 1)
            alpha = 1 - beta
            # Predicted-noise coefficient is beta / sqrt(1 - a_bar), not sqrt(beta).
            eps_coef = beta / self.sqrt_one_minus_ac[t].view(-1, 1, 1)
            x = (x - eps_coef * noise_pred) / alpha.sqrt()
            if i > 0:
                x = x + torch.sqrt(beta) * torch.randn_like(x)
        return x



In [None]:
wandb.init(project="1d_conditional_diffusion")

In [20]:

def _features_as_sequences(extractor, imgs):
    """Return the extractor's "features" output for ``imgs`` reshaped to ``(B, 1, L)``."""
    with torch.no_grad():
        feats = extractor(imgs)["features"]
    # flatten(1) keeps the batch axis even for a single-sample batch, where a
    # bare squeeze() would also drop it and corrupt the (B, 1, L) layout.
    return feats.flatten(1).unsqueeze(1)


def _run_diffusion_pass(diffusion, extractor, loader, domain_label, device, optimizer=None):
    """Return the per-sample mean denoising loss over ``loader``; step ``optimizer`` if given."""
    is_training = optimizer is not None
    summed = 0.0
    with torch.set_grad_enabled(is_training):
        for imgs, class_ids in loader:
            imgs = imgs.to(device)
            class_ids = class_ids.to(device)
            feats = _features_as_sequences(extractor, imgs)

            bsz = feats.size(0)
            t = torch.randint(0, diffusion.timesteps, (bsz,), device=device).long()
            domain_ids = torch.full((bsz,), domain_label, device=device, dtype=torch.long)
            loss = diffusion.p_losses(feats, t, domain_ids, class_ids)

            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Weight by batch size so the result is a per-sample mean.
            summed += loss.item() * bsz
    return summed / len(loader.dataset)


def train_diffusion(
    diffusion: GaussianDiffusion1D,
    model: nn.Module,
    extractor: nn.Module,
    loaders: dict,
    domain_label: int,
    epochs: int = 10,
    lr: float = 1e-4,
    device: str = 'cuda'
):
    """Train ``diffusion`` to denoise features extracted from one domain.

    Logs ``train_loss`` / ``val_loss`` per epoch and a final ``test_loss`` to
    the active W&B run, and prints the same values.

    Args:
        diffusion: Diffusion wrapper; ``diffusion.model`` is what gets optimised.
        model: Unused (kept for call compatibility).
        extractor: Frozen feature extractor returning a dict with a
            ``"features"`` entry; it is switched to eval mode.
        loaders: Mapping with ``"train"``, ``"val"`` and ``"test"`` loaders
            yielding ``(images, class_ids)``.
        domain_label: Domain index attached to every sample.
        epochs: Number of training epochs.
        lr: Adam learning rate.
        device: Device used for the diffusion module and the batches.
    """
    optimizer = torch.optim.Adam(diffusion.model.parameters(), lr=lr)
    diffusion.to(device)
    # The extractor only supplies training targets: eval mode keeps its
    # features deterministic and leaves its normalisation statistics untouched.
    extractor.eval()

    for epoch in range(1, epochs+1):
        # Training phase
        diffusion.model.train()
        train_loss = _run_diffusion_pass(
            diffusion, extractor, loaders['train'], domain_label, device, optimizer
        )

        # Validation phase
        diffusion.model.eval()
        val_loss = _run_diffusion_pass(
            diffusion, extractor, loaders['val'], domain_label, device
        )

        # Log metrics to W&B
        wandb.log({"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss})
        print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

    # Testing (optional)
    diffusion.model.eval()
    test_loss = _run_diffusion_pass(
        diffusion, extractor, loaders['test'], domain_label, device
    )
    wandb.log({"test_loss": test_loss})
    print(f"Test Loss = {test_loss:.4f}")

In [21]:
domain_vocab_size = 4
class_vocab_size = 65
seq_length = 2048

In [None]:
model = ConditionalUNet1D(
    input_channels=1,
    base_channels=64,
    channel_mults=(1, 2, 4),
    time_emb_dim=128,
    domain_vocab_size=domain_vocab_size,
    class_vocab_size=class_vocab_size,
    cond_emb_dim=32
).to(device)

diffusion = GaussianDiffusion1D(
    model,
    seq_length=seq_length,
    timesteps=1000,
    device=device
)

# Choose which domain to train on (0 through 3)
domain_label = 0  # e.g. first domain

# Start training
# train_diffusion(
#     diffusion=diffusion,
#     model=model,
#     extractor=feature_extractor,
#     loaders=rw_loaders,
#     domain_label=domain_label,
#     epochs=20,
#     lr=2e-4,
#     device=device
# )

diffusion.load_state_dict(torch.load("diffusion.pth", weights_only=True))

In [22]:
torch.save(diffusion.state_dict(), "diffusion.pth")

In [None]:
# Smoke test: draw 64 features for class 1 of domain 0.
domain_ids = torch.full((64,), 0, device=device, dtype=torch.long)
class_ids = torch.full((64,), 1, device=device, dtype=torch.long)
# `torch_ids` was an undefined name; the tensor built above is `domain_ids`.
print(domain_ids.shape, class_ids.shape)
diffusion.sample(64, domain_ids, class_ids)

### Final Combination

We now have both the trained base classifier and the trained diffusion model.

Preliminary results on different domains

In [23]:
loader_dict = {"RealWorld":rw_loaders, "Product": pr_loader, "Art": art_loader, "ClipArt": cart_loader}

In [42]:
def test_accuracy(model, loader_dict):
    """Print the top-1 test accuracy of `model` for every entry of `loader_dict`.

    Args:
        model: network to evaluate; `model(x)` must return per-sample class scores,
            since predictions are taken with `argmax(1)`.
        loader_dict: mapping from a display name to a dict whose 'test' entry
            yields `(inputs, labels)` batches.

    Batches are moved to the notebook-level `device`. The model is switched to eval
    mode and left there. Nothing is returned; accuracies are printed as percentages.
    """
    model.eval()  # loop-invariant, so set once instead of per domain
    for key, loader in loader_dict.items():
        correct = total = 0
        with torch.no_grad():
            for x, y in loader['test']:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        if total == 0:
            # An empty test split would otherwise raise ZeroDivisionError below.
            print(f"Accuracy on {key}: n/a (no test samples)")
            continue
        print(f"Accuracy on {key}:", 100*correct/total)
    return

In [None]:
# Per-domain test accuracy of base_model, the starting point that the adaptation
# runs below copy from.
test_accuracy(base_model, loader_dict)

### Adapting without Diffusion Replay

In [95]:
# Free memory before the adaptation runs: collect unreachable Python objects first,
# then ask PyTorch to release the cached GPU memory it is no longer using.
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Work on an independent copy so base_model keeps its original weights.
noreplay = deepcopy(base_model)

# Step 1 of the no-replay sequence: adapt on Product.
noreplay = adapt_wo_replay(
    noreplay,
    loader_dict["Product"],
    epochs=20,
    threshold_start=0.5,
    threshold_end=0.8,
    lr=1e-5,
    weight_decay=1e-4,
    alpha=0.7,
    ema_decay=0.99,
    device='cuda',
    patience=5,
)

In [None]:
# Optional (left disabled): restore a saved checkpoint before evaluating.
# noreplay.load_state_dict(torch.load("best_model_7.pth", weights_only=True))
# Print the per-domain test accuracy of the current noreplay model.
test_accuracy(noreplay, loader_dict)

In [None]:
# Next step of the no-replay sequence: continue adapting the same model on Art.
noreplay = adapt_wo_replay(
    noreplay,
    loader_dict["Art"],
    epochs=20,
    threshold_start=0.6,
    threshold_end=0.8,
    lr=1e-5,
    weight_decay=1e-4,
    alpha=0.7,
    ema_decay=0.99,
    device='cuda',
    patience=5,
)

In [None]:
# Restore the checkpoint stored in best_model_6.pth, then print per-domain accuracy.
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU also loads when the notebook has fallen back to CPU.
noreplay.load_state_dict(
    torch.load("best_model_6.pth", map_location=device, weights_only=True)
)
test_accuracy(noreplay, loader_dict)

In [None]:
# Next step of the no-replay sequence: continue adapting the same model on ClipArt.
noreplay = adapt_wo_replay(
    noreplay,
    loader_dict["ClipArt"],
    epochs=20,
    threshold_start=0.6,
    threshold_end=0.8,
    lr=1e-5,
    weight_decay=1e-4,
    alpha=0.7,
    ema_decay=0.99,
    device='cuda',
    patience=5,
)

In [None]:
# Restore the checkpoint stored in best_model_4.pth, then print per-domain accuracy.
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU also loads when the notebook has fallen back to CPU.
noreplay.load_state_dict(
    torch.load("best_model_4.pth", map_location=device, weights_only=True)
)
test_accuracy(noreplay, loader_dict)

In [None]:
# Separate run: restart from a fresh copy of base_model.
noreplay = deepcopy(base_model)

# Adapt directly on ClipArt, with a wider threshold schedule and more epochs than
# the sequential runs above.
noreplay = adapt_wo_replay(
    noreplay,
    loader_dict["ClipArt"],
    epochs=30,
    threshold_start=0.2,
    threshold_end=0.9,
    lr=1e-5,
    weight_decay=1e-4,
    alpha=0.5,
    ema_decay=0.995,
    device='cuda',
    patience=5,
)

In [None]:
# Print the per-domain test accuracy of the current noreplay model.
test_accuracy(noreplay, loader_dict)

### With Replay

In [46]:
def _ema_update(teacher, student, decay):
    """Move `teacher` parameters toward `student`: t <- decay * t + (1 - decay) * s."""
    with torch.no_grad():
        for t_param, s_param in zip(teacher.parameters(), student.parameters()):
            t_param.mul_(decay).add_(s_param, alpha=1 - decay)


def adapt_with_diffusion_replay(
    model,
    loaders,
    epochs,
    threshold_start,
    threshold_end,
    prev_domains,
    diffusion,
    diffusion_threshold=0.9,
    diffusion_ratio=0.3,
    lr=1e-5,
    weight_decay=1e-4,
    alpha=0.5,
    ema_decay=0.999,
    device='cuda',
    patience=3,
):
    """
    Continual adaptation with an EMA teacher-student pair and diffusion feature replay.

    `model` is split into a backbone (every child module except the last) and a
    classifier (`model.fc`, which must expose `in_features` / `out_features`).
    For every training batch, a fraction of the samples is replaced by features drawn
    from `diffusion`; the teacher pseudo-labels both the real images and the replayed
    features, and the student is trained with a KL term on confident samples plus an
    entropy term on all samples.

    Args:
        model: network to adapt. Its modules are trained in place.
        loaders: dict with 'train' and 'val' DataLoaders. Training labels are ignored.
        epochs: maximum number of epochs.
        threshold_start: initial teacher-confidence threshold for real data.
        threshold_end: final threshold for real data; the threshold is interpolated
            linearly over all training steps.
        prev_domains: list of domain ids previously seen; replayed samples cycle
            through them. An empty list disables replay.
        diffusion: object providing `.sample(batch_size, domain_ids, class_ids)`.
            Its output is flattened to one feature vector per sample.
        diffusion_threshold: fixed confidence threshold for replayed samples.
        diffusion_ratio: fraction of each batch replaced by replayed features.
        lr, weight_decay: Adam hyperparameters.
        alpha: weight of the pseudo-label loss; the entropy loss gets `1 - alpha`.
        ema_decay: decay of the teacher's exponential moving average.
        device: device used for the modules and all tensors.
        patience: training stops once validation accuracy has failed to improve for
            more than this many epochs.

    Returns:
        A deep copy of `model` carrying the weights of the epoch with the best
        validation accuracy (the final weights if no epoch ever improved on 0).
    """
    # The backbone wraps the model's own modules (no copy), so training it also
    # updates `model`; `classifier` is the very same object as `model.fc`.
    backbone = nn.Sequential(*list(model.children())[:-1]).to(device)
    classifier = model.fc.to(device)
    feature_dim = classifier.in_features
    num_classes = classifier.out_features

    # EMA teacher: frozen copies that are only changed through `_ema_update`.
    teacher_backbone = deepcopy(backbone)
    teacher_classifier = deepcopy(classifier)
    for p in teacher_backbone.parameters():
        p.requires_grad = False
    for p in teacher_classifier.parameters():
        p.requires_grad = False

    optimizer = optim.Adam(
        list(backbone.parameters()) + list(classifier.parameters()),
        lr=lr, weight_decay=weight_decay
    )
    criterion = nn.KLDivLoss(reduction='batchmean')
    total_steps = epochs * len(loaders['train'])
    step = 0

    best_val_acc = 0.0
    patience_counter = 0
    best_ckpt = None

    for epoch in range(1, epochs + 1):
        backbone.train(); classifier.train()
        # NOTE(review): the teacher is kept in train mode as before; consider eval()
        # if its normalisation statistics should stay fixed.
        teacher_backbone.train(); teacher_classifier.train()

        for imgs, _ in loaders['train']:
            imgs = imgs.to(device)
            batch_size = imgs.size(0)
            step += 1

            # Split the batch; without previous domains there is nothing to replay.
            n_diff = int(batch_size * diffusion_ratio) if len(prev_domains) > 0 else 0
            n_real = batch_size - n_diff
            real_imgs = imgs[:n_real]

            # Replayed features, always shaped (n_diff, features).
            if n_diff > 0:
                domain_ids = torch.tensor(
                    [prev_domains[i % len(prev_domains)] for i in range(n_diff)],
                    device=device, dtype=torch.long)
                class_ids = torch.randint(
                    0, num_classes, (n_diff,), device=device)
                # Sampling needs no autograd graph: the diffusion model is not trained here.
                with torch.no_grad():
                    diff_feats = diffusion.sample(n_diff, domain_ids, class_ids)
                # Flatten any extra (e.g. channel) axis so that dim=1 below is the
                # class axis; softmax over a singleton axis would return all ones.
                diff_feats = diff_feats.reshape(n_diff, -1)
            else:
                diff_feats = torch.empty((0, feature_dim), device=device)

            # Teacher predictions (pseudo-label targets) come from the EMA weights.
            with torch.no_grad():
                t_feats_real = teacher_backbone(real_imgs).view(n_real, -1)
                t_probs_real = F.softmax(teacher_classifier(t_feats_real), dim=1)
                t_probs_diff = F.softmax(teacher_classifier(diff_feats), dim=1)
                t_probs = torch.cat([t_probs_real, t_probs_diff], dim=0)

            # Student predictions
            s_feats_real = backbone(real_imgs).view(n_real, -1)
            s_logprobs_real = F.log_softmax(classifier(s_feats_real), dim=1)
            s_logprobs_diff = F.log_softmax(classifier(diff_feats), dim=1)
            s_logprobs = torch.cat([s_logprobs_real, s_logprobs_diff], dim=0)

            # Confidence masks: linearly scheduled threshold for real data, fixed
            # threshold for replayed data.
            dynamic_thresh = threshold_start + (threshold_end - threshold_start) * (step / total_steps)
            conf = t_probs.max(dim=1).values
            mask_real = conf[:n_real] > dynamic_thresh
            mask_diff = conf[n_real:] > diffusion_threshold

            # Pseudo-label loss on confident samples only.
            loss_pseudo = torch.tensor(0.0, device=device)
            if mask_real.any():
                loss_pseudo = loss_pseudo + criterion(
                    s_logprobs_real[mask_real], t_probs_real[mask_real])
            if mask_diff.any():
                loss_pseudo = loss_pseudo + criterion(
                    s_logprobs_diff[mask_diff], t_probs_diff[mask_diff])

            # Mean prediction entropy of the student over the whole batch.
            loss_ent = -(s_logprobs.exp() * s_logprobs).sum(dim=1).mean()
            loss = alpha * loss_pseudo + (1 - alpha) * loss_ent

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # EMA update of the teacher from the freshly stepped student.
            _ema_update(teacher_backbone, backbone, ema_decay)
            _ema_update(teacher_classifier, classifier, ema_decay)

        # Validate through the full model; it shares its modules with backbone/classifier.
        model.fc = classifier
        val_acc = validate(model, loaders['val'])
        print(f"Epoch {epoch}: Val Acc={val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            best_ckpt = {
                'backbone': deepcopy(backbone.state_dict()),
                'classifier': deepcopy(classifier.state_dict())
            }
        else:
            patience_counter += 1

        if patience_counter > patience:
            print("Patience exceeded, loading best weights.")
            break

    # Restore the best validated weights whether training stopped early or ran to
    # completion. best_ckpt stays None when no epoch ever beat 0% accuracy.
    if best_ckpt is not None:
        backbone.load_state_dict(best_ckpt['backbone'])
        classifier.load_state_dict(best_ckpt['classifier'])

    # `model` already holds these weights (shared modules); hand back an independent copy.
    return deepcopy(model)


In [None]:
# NOTE(review): `print9` is not assigned in any visible cell, so running this cell
# presumably raises NameError — looks like a stray typo; confirm and remove.
print9

In [32]:
# Move the diffusion wrapper onto the active device before the replay run below
# samples from it.
diffusion = diffusion.to(device)

In [None]:
# Adapt a fresh copy of the base model on Product while replaying diffusion features
# for domain id 0.
new_model = deepcopy(base_model)
new_model = adapt_with_diffusion_replay(
    new_model,
    # Per-domain loaders live in `loader_dict`, as in the no-replay cells above.
    loader_dict["Product"],
    epochs=30,
    threshold_start=0.5,
    threshold_end=0.9,
    prev_domains=[0],
    diffusion=diffusion,
    diffusion_threshold=0.9,
    diffusion_ratio=0.3,
    lr=1e-5,
    weight_decay=1e-4,
    alpha=0.5,
    ema_decay=0.995,
    # Use the notebook-level device so the run follows the CUDA/CPU choice made at setup.
    device=device,
    patience=3
)

test_accuracy(new_model, loader_dict)

In [None]:
# Per-domain test accuracy of the replay-adapted model. `test_accuracy` expects the
# name -> loaders mapping, which is `loader_dict` in this notebook.
test_accuracy(new_model, loader_dict)