In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torch.optim import AdamW
from torchvision import datasets, transforms
import torchvision.utils as vutils
from torch.utils.data import ConcatDataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
from diffusers.optimization import get_cosine_schedule_with_warmup
from accelerate import Accelerator
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import random 
import timeit
import os

In [None]:
# Configuration
RANDOM_SEED = 42
IMG_SIZE = 64
BATCH_SIZE = 4
LEARNING_RATE = 1e-4
NUM_EPOCHS = 1
NUM_GENERATE_IMAGES = 9
NUM_TIMESTEPS = 1000
GRADIENT_ACCUMULATION_STEPS = 1

# Set seeds for reproducibility (Python, NumPy, torch CPU and CUDA generators)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
# Deterministic cuDNN kernels trade speed for run-to-run reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 autocast is a GPU feature; Accelerate may reject "fp16" on a CPU-only
# machine, so fall back to full precision there instead of failing later.
MIXED_PRECISION = "fp16" if device == "cuda" else "no"
if device == "cuda":
    torch.cuda.empty_cache()
PATH = '../data-students/TRAIN'  # ImageFolder root: one sub-directory per class

In [None]:
# Augmentation stage (albumentations works on numpy arrays). Every step has
# p=1.0, so all of them run on each call with freshly sampled parameters;
# ToTensorV2 converts the augmented array to a CHW tensor at the end.
albumentations_transform = A.Compose(transforms=[
    A.Rotate(p=1.0, limit=15, border_mode=cv2.BORDER_REFLECT),
    A.ColorJitter(p=1.0, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    A.Perspective(p=1.0, scale=(0.05, 0.1), keep_size=True, pad_mode=cv2.BORDER_REFLECT),
    ToTensorV2(),
])

# Model-input stage (PIL image in): resize to IMG_SIZE x IMG_SIZE, convert to a
# float tensor in [0, 1], then map each channel to [-1, 1] via (x - 0.5) / 0.5.
torchvision_transform = transforms.Compose(transforms=[
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])

# Combined transform class
class CombinedTransform:
    def __init__(self, albumentations_transform, torchvision_transform):
        self.albumentations_transform = albumentations_transform
        self.torchvision_transform = torchvision_transform

    def __call__(self, image):
        # Apply Albumentations transforms
        image = np.array(image)
        augmented = self.albumentations_transform(image=image)
        image = augmented['image']

        # Convert tensor to PIL Image for torchvision transforms
        image = transforms.ToPILImage()(image)

        # Apply torchvision transforms
        image = self.torchvision_transform(image)
        return image

# Use the combined transformation
combined_transform = CombinedTransform(albumentations_transform, torchvision_transform)

# Un-augmented dataset; also the source of the class list and label indices.
init_dataset = datasets.ImageFolder(root=PATH, transform=torchvision_transform)
classes = init_dataset.classes
print("Classes in dataset:", classes)

# Number of augmented copies appended after the original images.
enlarge_factor = 3

# Original dataset first, then `enlarge_factor` copies whose items are augmented
# by `combined_transform` when they are loaded.
combined_datasets = [init_dataset]
for _ in range(enlarge_factor):
    transformed_dataset = datasets.ImageFolder(root=PATH, transform=combined_transform)
    combined_datasets.append(transformed_dataset)

# Concatenate the datasets into a single dataset
enlarged_dataset = ConcatDataset(combined_datasets)

# Create a dictionary to store DataLoaders for each class
class_data_loaders = {}
num_classes = len(classes)

# Read labels from ImageFolder.targets instead of iterating the dataset:
# enumerating `enlarged_dataset` would decode (and augment) every image once per
# class just to look at its label. ConcatDataset lays the sub-datasets out back to
# back, so concatenating their target lists yields the label of each global index.
enlarged_targets = [label for ds in combined_datasets for label in ds.targets]
assert len(enlarged_targets) == len(enlarged_dataset)

indices_by_class = {class_idx: [] for class_idx in range(num_classes)}
for global_idx, label in enumerate(enlarged_targets):
    indices_by_class[label].append(global_idx)

# One shuffled DataLoader per class over that class's (original + augmented) items.
for class_idx in range(num_classes):
    class_subset = Subset(enlarged_dataset, indices_by_class[class_idx])
    class_data_loaders[class_idx] = DataLoader(class_subset, batch_size=BATCH_SIZE, shuffle=True)

print("Data loaders created for each class.")
train_dataloader = DataLoader(init_dataset, batch_size=BATCH_SIZE, shuffle=True)

torch.cuda.empty_cache()

In [None]:
# Plot one batch from a single class's loader to sanity-check the augmentations.
# Class 1 was previously hard-coded; fall back to class 0 for single-class data.
preview_class_idx = min(1, num_classes - 1)
real_batch = next(iter(class_data_loaders[preview_class_idx]))
# make_grid works on CPU tensors, so no device round-trip is needed.
preview_grid = vutils.make_grid(real_batch[0][:64], padding=2, normalize=True)
fig, ax = plt.subplots(figsize=(8, 8))
ax.axis("off")
ax.set_title(f"Training Images (class: {classes[preview_class_idx]})")
ax.imshow(preview_grid.permute(1, 2, 0).numpy())
plt.show()

In [None]:
def model_init():
    """Build a fresh, untrained UNet2DModel for IMG_SIZE x IMG_SIZE RGB images.

    Six resolution levels with 2 ResNet layers each; the fifth down level and
    the second up level use attention blocks. The model is not moved to any
    device here; callers do that.

    Returns:
        UNet2DModel: the newly constructed model.
    """
    channels_per_level = (64, 64, 128, 128, 256, 256)
    down_types = ("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D")
    up_types = ("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4
    return UNet2DModel(
        sample_size=IMG_SIZE,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=channels_per_level,
        down_block_types=down_types,
        up_block_types=up_types,
    )

In [None]:
def sample_image_generation(model, noise_scheduler, num_generate_images, random_seed, num_timesteps):
    """Sample images from ``model`` with a DDPM pipeline and show them in a grid.

    Relies on the module-level ``accelerator`` (created in the training cell)
    to strip Accelerate's wrappers from ``model`` before building the pipeline.

    Args:
        model: the (possibly accelerate-prepared) UNet being trained.
        noise_scheduler: scheduler used for the reverse diffusion process.
        num_generate_images: number of images to sample (the pipeline batch size).
        random_seed: seed for a private generator, so samples are repeatable.
        num_timesteps: number of denoising steps passed to the pipeline.

    Returns:
        The pipeline's ``.images`` list (PIL images with the default output type).
    """
    pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

    # A dedicated generator: torch.manual_seed() would reseed the *global* RNG and
    # make all later training noise / timesteps / shuffling restart from the same
    # state every time this function is called. Same seed => same samples as before.
    generator = torch.Generator(device="cpu").manual_seed(random_seed)
    images = pipeline(
        batch_size=num_generate_images,
        generator=generator,
        num_inference_steps=num_timesteps,
    ).images

    # Near-square grid sized to the request (a fixed 3x3 grid broke for > 9 images;
    # 9 images still give 3x3).
    n_cols = max(1, int(np.ceil(np.sqrt(num_generate_images))))
    n_rows = max(1, -(-num_generate_images // n_cols))
    fig = plt.figure()
    for position, image in enumerate(images, start=1):
        ax = fig.add_subplot(n_rows, n_cols, position)
        ax.imshow(image)
        ax.axis("off")
    plt.show()
    return images

In [None]:
models = []
model_path = 'saved_models'
os.makedirs(model_path, exist_ok=True)
torch.cuda.empty_cache()

# Generate a sample grid every this many epochs (epoch 0 always samples).
SAMPLE_EVERY_EPOCHS = 250
# Warm-up length for the cosine LR schedule. If a class has fewer total steps
# than this, its LR never reaches LEARNING_RATE.
LR_WARMUP_STEPS = 500

# A single Accelerator for the whole cell; the previous class's prepared objects
# are released with free_memory() at the end of each iteration.
# sample_image_generation() reads this module-level name.
accelerator = Accelerator(
    mixed_precision=MIXED_PRECISION,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS
)

for class_idx in range(num_classes):
    model = model_init().to(device)
    noise_scheduler = DDPMScheduler(num_train_timesteps=NUM_TIMESTEPS)

    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=LR_WARMUP_STEPS,
        num_training_steps=len(class_data_loaders[class_idx]) * NUM_EPOCHS
    )

    # Prepare into a local name so the shared `class_data_loaders` dict keeps the
    # raw DataLoader; re-running this cell would otherwise re-prepare an
    # already-prepared loader.
    model, optimizer, train_loader, lr_scheduler = accelerator.prepare(
        model, optimizer, class_data_loaders[class_idx], lr_scheduler
    )

    # Training loop
    start = timeit.default_timer()
    for epoch in tqdm(range(NUM_EPOCHS), position=0, leave=True):
        model.train()
        train_running_loss = 0.0
        num_batches = 0
        for batch in tqdm(train_loader, position=0, leave=True):
            # Loaders built from ImageFolder subsets yield (images, labels); labels unused.
            clean_images = batch[0].to(device)
            noise = torch.randn_like(clean_images)
            # One random diffusion timestep per image in the (possibly short last) batch.
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (clean_images.shape[0],), device=clean_images.device
            )
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                # The UNet is trained to predict the added noise.
                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                # Clip only when gradients are actually synchronised (i.e. on the
                # step that will update the weights under gradient accumulation).
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            train_running_loss += loss.item()
            num_batches += 1

        # Guard against an empty loader (previously a NameError on `idx`).
        train_loss = train_running_loss / max(num_batches, 1)
        train_learning_rate = lr_scheduler.get_last_lr()[0]
        print("-" * 30)
        # Plain prints: the old "\r" + end="" pair made the LR line overwrite the loss line.
        print(f"Train Loss EPOCH: {epoch + 1}: {train_loss:.4f}")
        print(f"Train Learning Rate EPOCH: {epoch + 1}: {train_learning_rate}")
        if epoch % SAMPLE_EVERY_EPOCHS == 0:
            sample_image_generation(model, noise_scheduler, NUM_GENERATE_IMAGES, RANDOM_SEED, NUM_TIMESTEPS)
        print("-" * 30)

    stop = timeit.default_timer()
    print(f"Training Time: {stop - start:.2f}s")

    model_path_ = os.path.join(model_path, f"model_DDPM_{class_idx}.pth")
    # Save the unwrapped module's weights (no DDP "module." prefix) and only
    # from the main process.
    accelerator.wait_for_everyone()
    accelerator.save(accelerator.unwrap_model(model).state_dict(), model_path_)

    # Drop the accelerator's references to this class's model/optimizer/loader
    # before the next class allocates its own.
    accelerator.free_memory()
    torch.cuda.empty_cache()