In [6]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from generator import GENERATOR
from discriminator import EfficientNetDiscriminator
from lossFunctions import discriminator_gan_loss, generator_gan_loss, l2_loss, yolo_class_loss
from train import get_dataloader, preprocess , denormalize
from ultralytics import YOLO

In [None]:
#download the dataset if it is not already downloaded
!curl -L "https://universe.roboflow.com/ds/glquRaJDf9?key=1kWinBXVTQ" > roboflow.zip; unzip roboflow.zip; rm roboflow.zip

In [7]:
data_yaml_path = "/home/salma/graduation_project/YOLO/yolov8/Self-Driving Cars.v1i.yolov8/data.yaml"
train_loader = get_dataloader(data_yaml_path, split='train', batch_size=16)

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 10
generator = GENERATOR(in_channels=3, out_channels=3).to(device)
discriminator = EfficientNetDiscriminator().to(device)
yolo = YOLO("../yolov8n_TrafficSigns.pt").to(device).eval()  # Freeze YOLO

opt_g = torch.optim.Adam(generator.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

Loaded pretrained weights for efficientnet-b0


In [10]:
for epoch in range(num_epochs):
    generator.train()
    discriminator.train()
    
    for batch_idx, (clean_images, _) in enumerate(train_loader):
        clean_images = clean_images.to(device)
        
        # --- Discriminator Training ---
        opt_d.zero_grad()
        
        # Generate adversarial images
        adv_images = generator(clean_images)
        
        # Preprocess for discriminator
        preprocessed_real = preprocess(clean_images)
        preprocessed_fake = preprocess(adv_images.detach())
        
        # Discriminator forward pass
        d_real = discriminator(preprocessed_real)
        d_fake = discriminator(preprocessed_fake)
        
        # Compute and backprop discriminator loss
        d_loss = discriminator_gan_loss(d_real, d_fake)
        d_loss.backward()
        opt_d.step()
        
        # --- Generator Training ---
        opt_g.zero_grad()
        
        # Generator forward pass
        adv_images = generator(clean_images)
        preprocessed_fake = preprocess(adv_images)
        
        # Get discriminator predictions
        d_fake = discriminator(preprocessed_fake)
        
        # Calculate losses
        g_loss_gan = generator_gan_loss(d_fake)
        g_loss_l2 = l2_loss(clean_images, adv_images)
        
        # YOLO evaluation
        with torch.no_grad():
            yolo_input = denormalize(adv_images)  # Scale to [0, 1]
            yolo_results = yolo(yolo_input)
        
        g_loss_yolo = yolo_class_loss(yolo_results)
        
        # Total generator loss
        g_loss_total = g_loss_gan + 0.05*g_loss_l2 + 1.0*g_loss_yolo
        g_loss_total.backward()
        opt_g.step()

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/salma/.local/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/home/salma/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
  File "/home/salma/.local/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 398, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/home/salma/.local/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 211, in collate
    return [
  File "/home/salma/.local/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 212, in <listcomp>
    collate(samples, collate_fn_map=collate_fn_map)
  File "/home/salma/.local/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 171, in collate
    {
  File "/home/salma/.local/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 172, in <dictcomp>
    key: collate(
  File "/home/salma/.local/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 155, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/home/salma/.local/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 272, in collate_tensor_fn
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [1, 4] at entry 0 and [2, 4] at entry 1
