In [None]:
# Setup
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel
from datasets import load_dataset, Dataset, DatasetDict
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights
from matplotlib import pyplot as plt
from tqdm import tqdm

# Seed torch's global RNG (CPU and CUDA) so the freshly initialised classifier head,
# dropout masks and any DataLoader shuffling are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Compute device. Falls back to CPU; edit this line to target another backend
# (e.g. 'mps' on Apple silicon, or a TPU).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
PRETRAINED = "google/ncsnpp-celebahq-256"     # Hugging Face Hub id of the score-SDE (VE) pipeline
DATASET_SOURCE = "Ryan-sjtu/celebahq-caption" # Hub id of the dataset (rows have 'image' and 'text' columns)

In [None]:
# Load the pretrained NCSN++ score-SDE (VE) pipeline and the captioned image dataset.
unconditional_pipeline = ScoreSdeVePipeline.from_pretrained(PRETRAINED)
# Reuse the pipeline's own components instead of calling from_pretrained a second
# time: this avoids downloading/instantiating (and keeping in memory) a second copy
# of the UNet weights. Note: these are the SAME objects as
# unconditional_pipeline.scheduler / unconditional_pipeline.unet, not copies.
scheduler: ScoreSdeVeScheduler = unconditional_pipeline.scheduler
unet: UNet2DModel = unconditional_pipeline.unet
dataset: DatasetDict = load_dataset(DATASET_SOURCE)

In [None]:
# Train a (to-be) time-dependent gender classifier.
# NOTE: noise-level conditioning is not implemented yet (see the TODO at the end of
# this cell); for now this is a plain image classifier trained on clean images.

BATCH_SIZE = 128
N_EPOCHS = 10

# MobileNetV3-Small backbone with pretrained weights and a fresh single-output head.
weights = MobileNet_V3_Small_Weights.DEFAULT
backbone = mobilenet_v3_small(weights=weights)
# Width of the pooled feature vector, read from the stock classifier's first Linear
# layer rather than hard-coded (it is 576 for MobileNetV3-Small).
n_features = backbone.classifier[0].in_features
model = nn.Sequential(
    backbone.features,
    backbone.avgpool,
    nn.Flatten(),
    nn.Dropout(0.2, inplace=True),
    nn.Linear(in_features=n_features, out_features=1),
    nn.Sigmoid(),  # output in (0, 1): predicted probability that label == 1 ("woman")
)
model = model.to(device=DEVICE)
preprocess = weights.transforms()


def to_classifier_example(example):
    """Convert one torch-formatted dataset row into a classifier training example.

    Returns a dict with:
      - 'input': the image passed through the backbone's preprocessing transform.
      - 'label': a shape-(1,) float tensor, 1.0 if the caption contains "woman"
        (case-insensitive) and 0.0 otherwise.
    """
    # Assumes the torch formatter yields images as (H, W, C), as the plt.imshow in the
    # inspection cell implies -- TODO confirm for your `datasets` version.
    # permute -> (C, H, W). A transpose(-1, -3) would give (C, W, H), i.e. a
    # diagonally mirrored image, which does not match what the pretrained weights saw.
    image_chw = example['image'].permute(2, 0, 1)
    # Case-insensitive so a caption starting with "Woman ..." is not labelled 0.
    is_woman = 'woman' in example['text'].lower()
    return {
        'input': preprocess(image_chw),
        'label': torch.tensor([1.0 if is_woman else 0.0], device=DEVICE),
    }


# Dataset: preprocessed images + binary labels. The raw columns are dropped so the
# DataLoader does not collate full-resolution images and caption strings every batch.
raw_train: Dataset = dataset['train']
train_ds: Dataset = raw_train.with_format('torch', device=DEVICE).map(
    to_classifier_example,
    remove_columns=raw_train.column_names,
)
dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

# TODO: Noise timescale

In [None]:
# Optimise the classifier with binary cross-entropy on its sigmoid output.
optim = torch.optim.Adam(model.parameters())
loss_fn = nn.BCELoss()

# Re-enable dropout: an inference cell may have left the model in eval mode.
model.train()
with tqdm(total=N_EPOCHS) as pbar:
    for epoch in range(N_EPOCHS):
        running_loss = 0.0
        n_batches = 0
        for batch in dataloader:
            inputs, labels = batch['input'], batch['label']
            optim.zero_grad()
            predicted = model(inputs)
            loss = loss_fn(predicted, labels)
            loss.backward()
            optim.step()
            running_loss += loss.item()
            n_batches += 1

        # Log the mean per-batch loss; the raw sum scales with the number of batches
        # and is not comparable across batch sizes. max(..., 1) guards an empty loader.
        mean_loss = running_loss / max(n_batches, 1)
        pbar.set_description(f"epoch {epoch + 1}/{N_EPOCHS} loss {mean_loss:.4f}")
        pbar.update(1)

In [None]:
# Inspect the classifier's prediction on a single image.
# NOTE: this row comes from the *training* split, so it only sanity-checks the
# pipeline; it says nothing about generalisation.
SAMPLE_IDX = 12

# Reuse the exact preprocessing applied during training (train_ds's 'input' column)
# so inference cannot drift from what the model was trained on.
sample_input = train_ds[SAMPLE_IDX]['input']

was_training = model.training
model.eval()  # disable dropout for a deterministic prediction
with torch.no_grad():  # no autograd graph needed for inference
    prob_woman = model(sample_input.unsqueeze(0)).item()
model.train(was_training)  # restore the previous mode so re-running training is unaffected

print(f"Predicted {prob_woman:.4f}")
fig, ax = plt.subplots()
ax.imshow(dataset['train'][SAMPLE_IDX]['image'])
ax.set_title(f"Predicted P(woman) = {prob_woman:.3f}")
ax.axis('off')
plt.show()