In [None]:
!curl -fL -o dataset.zip \
  https://www.kaggle.com/api/v1/datasets/download/ludehsar/apple-disease-dataset
!unzip -oq dataset.zip -d datasets

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from diffusers import DDPMScheduler
from model import UNet2DModel, CNNModel
from diffusers.optimization import get_cosine_schedule_with_warmup
from accelerate import Accelerator
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import os

# Configuration
image_size = 128
batch_size = 8
num_epochs = 1000
lr_unet = 1e-4
lr_cls = 1e-3
output_dir = "./ddpm_unet"
data_dir = "datasets/apple_disease"
seed = 42

# Seed torch's RNGs (CPU and CUDA) so noise sampling, timestep sampling and
# dataloader shuffling are reproducible across Restart & Run All.
torch.manual_seed(seed)

# Dataset Loading
dataset_train = load_dataset(data_dir, split="train")
dataset_test = load_dataset(data_dir, split="test")

class AppleDiseaseDataset(Dataset):
    """Expose a Hugging Face image-classification split as a PyTorch ``Dataset``.

    Each item is a dict ``{"pixel_values": image, "labels": label}``. ``image`` is
    the row's ``'image'`` converted to RGB and then passed through ``transform``
    (when one is given); ``label`` is the row's ``'label'`` value, unchanged.

    Args:
        hf_dataset: sized, indexable collection of rows with an ``'image'`` entry
            (an object with ``.convert``, e.g. a PIL image) and a ``'label'`` entry.
        transform: optional callable applied to the RGB image.
    """

    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Fetch the row once. Indexing a Hugging Face dataset returns the whole
        # row (image included), so looking it up twice decoded the image twice.
        row = self.dataset[idx]
        image = row['image'].convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return {"pixel_values": image, 'labels': row['label']}

# Preprocessing: resize, convert to a [0, 1] tensor, then rescale to [-1, 1].
# DDPM training expects data in [-1, 1]: DDPMScheduler clips predicted samples to
# that range and DDPMPipeline maps its output back with `image / 2 + 0.5`, so
# training on [0, 1] data produces washed-out samples.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

dataset_train = AppleDiseaseDataset(dataset_train, transform=transform)
dataset_test = AppleDiseaseDataset(dataset_test, transform=transform)

num_workers = 4
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Models and Setup
# Channel widths of the UNet blocks. The classifier consumes the UNet's mid-block
# features, so its input width is tied to the last entry below.
block_out_channels = (64, 32, 16, 8)
model = UNet2DModel(
    sample_size=image_size,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=block_out_channels,
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"),
    norm_num_groups=8
)

# Number of target classes, read from the dataset rather than hard-coded.
# Prefer the ClassLabel feature's count; otherwise assume labels are 0..K-1.
label_feature = dataset_train.dataset.features["label"]
num_classes = getattr(label_feature, "num_classes", None)
if num_classes is None:
    num_classes = max(dataset_train.dataset["label"]) + 1

# NOTE(review): assumes the project's UNet exposes mid-block features with
# block_out_channels[-1] channels, as diffusers' UNet2DModel does. The previous
# hard-coded value (8) equals block_out_channels[-1].
classifier = CNNModel(input_channels=block_out_channels[-1], num_cls=num_classes)

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

# Training Functions
def train_one_epoch(model, classifier, dataloader, optimizer, lr_scheduler, epoch, noise_scheduler, train_classifier, freeze_unet, feature_timestep=0):
    """Run one epoch of UNet denoising training and/or classifier training.

    Args:
        model: the UNet. It is called as ``model(images, timesteps)`` and its output
            is read through ``.sample`` (noise prediction) and ``.mid_block``
            (features for the classifier).
        classifier: module mapping mid-block features to class logits.
        dataloader: yields dicts with ``"pixel_values"`` and ``"labels"``.
        optimizer: stepped once per batch.
        lr_scheduler: stepped once per batch.
        epoch: epoch index, used only in the progress bar and log line.
        noise_scheduler: DDPM scheduler used to add noise to the images.
        train_classifier: if True, add the cross-entropy loss of ``classifier`` on
            the UNet mid-block features of the clean images.
        freeze_unet: if True, disable gradients for the UNet and skip the
            denoising loss, which could not be optimized anyway.
        feature_timestep: timestep passed to the UNet when extracting classifier
            features. Keep it equal to the value ``evaluate`` uses.

    Returns:
        Mean loss over the epoch (``nan`` if the dataloader is empty).

    Raises:
        ValueError: if ``freeze_unet`` is set without ``train_classifier``, since
            then nothing requires gradients.

    Relies on the module-level ``loss_recon``, ``loss_cls`` and ``accelerator``.
    """
    if freeze_unet and not train_classifier:
        raise ValueError("Nothing to train: freeze_unet=True requires train_classifier=True.")

    # A frozen UNet is a fixed feature extractor, so run it in eval mode (as
    # evaluate() does). Otherwise it is being trained.
    model.train(not freeze_unet)
    for param in model.parameters():
        param.requires_grad_(not freeze_unet)
    classifier.train(train_classifier)

    losses = []
    for batch in tqdm(dataloader, desc=f"Epoch {epoch}"):
        clean_images = batch["pixel_values"]
        labels = batch["labels"]
        n_images = clean_images.shape[0]
        device = clean_images.device
        loss_terms = []

        if not freeze_unet:
            # DDPM objective: predict the noise added at a random timestep.
            noise = torch.randn_like(clean_images)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (n_images,), device=device
            ).long()
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
            noise_pred = model(noisy_images, timesteps).sample
            loss_terms.append(loss_recon(noise_pred, noise))

        if train_classifier:
            feature_timesteps = torch.full(
                (n_images,), feature_timestep, dtype=torch.long, device=device
            )
            # Block gradients into the UNet only when it is frozen. An unconditional
            # torch.no_grad() here kept the classifier loss from ever reaching an
            # unfrozen UNet, so joint training never updated it.
            with torch.set_grad_enabled(not freeze_unet):
                mid_block = model(clean_images, feature_timesteps).mid_block
            cls_outputs = classifier(mid_block)
            loss_terms.append(loss_cls(cls_outputs, labels))

        loss = sum(loss_terms)

        optimizer.zero_grad()
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        losses.append(loss.item())

    mean_loss = sum(losses) / len(losses) if losses else float("nan")
    print(f"Epoch {epoch} - Loss: {mean_loss:.4f}")
    return mean_loss

def evaluate(model, classifier, dataloader, feature_timestep=0):
    """Evaluate the classifier on UNet mid-block features and report metrics.

    Prints accuracy and weighted precision/recall/F1 and also returns them.

    Args:
        model: UNet called as ``model(images, timesteps)``. The ``.mid_block``
            attribute of its output is fed to ``classifier``.
        classifier: module mapping mid-block features to class logits.
        dataloader: yields dicts with ``"pixel_values"`` and ``"labels"``.
        feature_timestep: timestep passed to the UNet for feature extraction.
            Defaults to 0, the value train_one_epoch uses for the same features.
            The former value 1000 did not match training and lies outside the
            scheduler's training range [0, num_train_timesteps - 1].

    Returns:
        dict with ``accuracy``, ``precision``, ``recall`` and ``f1``, or an empty
        dict if the dataloader yields no samples.
    """
    model.eval()
    classifier.eval()

    y_true, y_pred = [], []
    with torch.no_grad():
        for batch in dataloader:
            clean_images = batch["pixel_values"]
            labels = batch["labels"]

            timesteps = torch.full(
                (clean_images.shape[0],), feature_timestep,
                dtype=torch.long, device=clean_images.device,
            )
            mid_block = model(clean_images, timesteps).mid_block
            preds = classifier(mid_block).argmax(dim=1)

            y_true.extend(labels.cpu().tolist())
            y_pred.extend(preds.cpu().tolist())

    if not y_true:
        print("Evaluation - no samples")
        return {}

    # zero_division=0 returns the same values as sklearn's default, minus the
    # warning emitted for classes that are never predicted.
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    rec = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)

    print(f"Evaluation - Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}, F1-score: {f1:.4f}")
    return {"accuracy": acc, "precision": prec, "recall": rec, "f1": f1}


# Stage 1: train the UNet as a denoiser. Set train_classifier=True to also train
# the classifier jointly on the UNet's mid-block features.
train_classifier = False
freeze_unet = False

params = [{"params": model.parameters(), "lr": lr_unet}]
if train_classifier:
    params.append({"params": classifier.parameters(), "lr": lr_cls})

optimizer = torch.optim.AdamW(params)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(dataloader_train) * num_epochs,
)

loss_recon = torch.nn.MSELoss()
loss_cls = torch.nn.CrossEntropyLoss()

accelerator = Accelerator()
# The scheduler is prepared as well. num_training_steps above was computed from
# the unsharded dataloader, and accelerate's prepared scheduler accounts for the
# number of processes, so the cosine schedule spans the whole run on multi-GPU too.
model, classifier, optimizer, dataloader_train, dataloader_test, lr_scheduler = accelerator.prepare(
    model, classifier, optimizer, dataloader_train, dataloader_test, lr_scheduler
)

# Training Loop: pass the flags above so the loop does what the optimizer was built for.
for epoch in range(num_epochs):
    train_one_epoch(model, classifier, dataloader_train, optimizer, lr_scheduler, epoch,
                    noise_scheduler, train_classifier, freeze_unet)
    if train_classifier:
        evaluate(model, classifier, dataloader_test)

# Save the trained UNet weights to output_dir so the run is not lost with the kernel.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    os.makedirs(output_dir, exist_ok=True)
    torch.save(accelerator.unwrap_model(model).state_dict(), os.path.join(output_dir, "unet.pt"))

In [None]:
from diffusers import DDPMPipeline

# Sample images from the trained UNet, on the main process only.
images = []
if accelerator.is_main_process:
    # Hand the pipeline the plain UNet (undoing any wrapper added by
    # accelerator.prepare), and run on the Accelerator's device instead of
    # assuming CUDA is available.
    unet = accelerator.unwrap_model(model)
    unet.eval()
    pipeline = DDPMPipeline(unet=unet, scheduler=noise_scheduler)
    pipeline.to(accelerator.device)
    images = pipeline(num_inference_steps=1000).images
# Display the first sample. `images` stays empty on non-main processes, where the
# unguarded `images[0]` used to raise NameError.
images[0] if images else None

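### Train Classifier

Freeze the UNet and train `CNNModel` on the UNet's mid-block features of the clean images, evaluating on the test split after every epoch.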

In [None]:
# Stage 2: train only the classifier on features from the frozen UNet.
train_classifier = True
freeze_unet = True
# Separate name so this cell does not overwrite the UNet's num_epochs.
num_epochs_cls = 100

optimizer = torch.optim.AdamW(classifier.parameters(), lr=lr_cls)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=500,
    # dataloader_train is already prepared (per-process length), and this
    # scheduler is stepped once per local batch, so the step count matches.
    num_training_steps=len(dataloader_train) * num_epochs_cls,
)

# Reuse the Accelerator from the UNet training cell. Model, classifier and
# dataloaders are already prepared, and preparing them again would wrap them a
# second time. Only the new optimizer needs preparing.
optimizer = accelerator.prepare(optimizer)

# Training Loop
for epoch in range(num_epochs_cls):
    train_one_epoch(model, classifier, dataloader_train, optimizer, lr_scheduler, epoch,
                    noise_scheduler, train_classifier, freeze_unet)
    evaluate(model, classifier, dataloader_test)