<a href="https://colab.research.google.com/github/Mariannly/EPIC_4/blob/main/Stage2/Challenge1/training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training diffusion model

## Reading training data

In [14]:
#Mounting drive
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [15]:
!ls "/content/drive/MyDrive/EPIC_V/Stage2/training_data"

BC_Ch2_1.tif  data_preprocessing.ipynb	processed_target
BC_Ch2_2.tif  processed_source


In [16]:
source_path = "/content/drive/MyDrive/EPIC_V/Stage2/training_data/processed_source"
target_path = "/content/drive/MyDrive/EPIC_V/Stage2/training_data/processed_target"

In [17]:
#importing libreries
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch
from torch import nn
from diffusers import UNet2DModel, DDPMScheduler

In [18]:
#config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 10
batch_size = 16
lr = 1e-4
image_size = 128
num_timesteps = 1000

In [19]:
#loading training data
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1), # Add this line to convert to grayscale
    transforms.ToTensor(),  # TIFF to torch tensor
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Dataset setup
xz_dataset = datasets.ImageFolder(root=source_path, transform=transform)
xy_dataset = datasets.ImageFolder(root=target_path, transform=transform)

xz_dataloader = DataLoader(xz_dataset, batch_size=batch_size, shuffle=True)
xy_dataloader = DataLoader(xy_dataset, batch_size=batch_size, shuffle=True)

## Set up Model

In [20]:
model = UNet2DModel(
    sample_size=image_size,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(64, 128, 256),
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D")
)

In [21]:
scheduler = DDPMScheduler(num_train_timesteps=num_timesteps)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
criterion = nn.MSELoss()

## Training

In [22]:
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for step, (xz_batch, _) in enumerate(xz_dataloader):
        xz_batch = xz_batch.to(device)

        # 1️⃣ Sample a random timestep for each image in the batch
        timesteps = torch.randint(0, scheduler.num_train_timesteps, (xz_batch.size(0),), device=device).long()

        # 2️⃣ Add noise to the clean XZ images
        noise = torch.randn_like(xz_batch)
        noisy_imgs = scheduler.add_noise(xz_batch, noise, timesteps)

        # 3️⃣ Predict the noise using the model
        pred_noise = model(noisy_imgs, timesteps).sample  # .sample = output tensor

        # 4️⃣ Compute loss between predicted and true noise
        loss = criterion(pred_noise, noise)

        # 5️⃣ Backpropagation
        optimizer.zero_grad()  # 🔁 must go before loss.backward()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if step % 50 == 0:
            print(f"[Epoch {epoch+1}/{num_epochs}] Step {step}, Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(xz_dataloader)
    print(f"✅ Epoch {epoch+1} finished. Average Loss: {avg_loss:.4f}")




  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


RuntimeError: Given groups=1, weight of size [64, 1, 3, 3], expected input[16, 3, 128, 128] to have 1 channels, but got 3 channels instead