In [2]:
import torch
import multiprocessing
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn

from utils.get_img_paths import get_img_paths
from utils.tile_dataset import TileDataset
from utils.generator import UNetGenerator
from utils.discriminator import PatchGANDiscriminator
from utils.training import train_conditional_gan
from utils.weights_init import weights_init

In [3]:
# Pick the compute device: CUDA first, then Apple's MPS backend, then CPU.
# `n_gpu` is a device *count*, so it is kept as an int (it was previously a
# float, which made the summary line print e.g. "1.0").
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    n_gpu = torch.cuda.device_count()
    device_name = torch.cuda.get_device_name(DEVICE)
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    device_name = "Apple Silicon"
    n_gpu = 0  # no CUDA devices are counted on the MPS path
else:
    DEVICE = torch.device("cpu")
    device_name = "CPU"
    n_gpu = 0

# Fix the PyTorch RNG seed so weight init and shuffling are repeatable.
torch.manual_seed(0)

# Logical CPU count; the config cell derives NUM_WORKERS from it.
n_cores = multiprocessing.cpu_count()
print(f"Number of GPUs: {n_gpu} / Number of CPU Cores: {n_cores}")
print(f"Training on {device_name} ({DEVICE})")

Number of GPUs: 1.0 / Number of CPU Cores: 24
Training on NVIDIA GeForce RTX 4090 (cuda)


In [4]:
# ---- Data -------------------------------------------------------------------
MASK_DIR = "./tile/ground_truth"                   # Base directory handed to TileDataset as mask_base_dir
ANOMALY_TYPES_TO_TRAIN = ["crack", "glue_strip"]   # Anomaly categories used for training
NUM_CLASSES = len(ANOMALY_TYPES_TO_TRAIN)          # One condition per anomaly category
IMG_SIZE = 256                                     # Handed to TileDataset as image_size
NUM_CHANNELS = 3                                   # RGB

# ---- Model ------------------------------------------------------------------
EMBED_SIZE = 16   # Length of the condition embedding vector (generator and discriminator)
NGF = 64          # Base feature count, generator
NDF = 64          # Base feature count, discriminator

# ---- Optimisation -----------------------------------------------------------
BATCH_SIZE = 32   # Tune to the available VRAM
EPOCHS = 300      # Training epochs
LR_G = 2e-4       # Generator learning rate
LR_D = 2e-4       # Discriminator learning rate
BETA1 = 0.5       # Adam beta1
BETA2 = 0.999     # Adam beta2
LAMBDA_L1 = 50.0  # Forwarded to train_conditional_gan as lambda_l1 (presumably the L1 term weight)
NUM_WORKERS = round(n_cores * 0.7)  # Use roughly 70% of the CPU cores for data loading

# ---- Output -----------------------------------------------------------------
CHECKPOINT_DIR = "./checkpoints_cgan_tile"  # Where model checkpoints are written
SAMPLE_DIR = "./samples_cgan_tile"          # Where generated sample images are written
SAVE_CHECKPOINT_FREQ = 30                   # Checkpoint every N epochs
SAVE_SAMPLES_FREQ = 10                      # Sample-saving interval (presumably also in epochs)

In [5]:
# Start from an empty path list for every tile category, then let
# get_img_paths fill the mapping from the "train" and "test" subfolders in turn.
tile_categories = ("crack", "glue_strip", "gray_stroke", "oil", "rough", "good")
img_paths = {category: [] for category in tile_categories}

for subfolder in ("train", "test"):
    img_paths = get_img_paths(paths=img_paths, subfolder=subfolder)

In [6]:
# Build the conditional dataset from the collected image paths, restricted to
# the anomaly types selected in the config cell.
conditional_dataset = TileDataset(
    image_paths_dict=img_paths,
    mask_base_dir=MASK_DIR,
    anomaly_types_for_training=ANOMALY_TYPES_TO_TRAIN,
    image_size=IMG_SIZE
)

# Shuffled loader; drop_last discards the final partial batch so every batch
# has exactly BATCH_SIZE samples.
conditional_dataloader = DataLoader(
    conditional_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Page-locked host memory only speeds up host-to-CUDA copies; on the
    # MPS/CPU fallbacks it is useless and makes PyTorch emit a warning.
    pin_memory=(DEVICE.type == "cuda"),
    drop_last=True
)

Processing category: crack
  -> Added 17 samples for anomaly type 'crack'.
Processing category: glue_strip
  -> Added 18 samples for anomaly type 'glue_strip'.
Processing category: good
  -> Added 263 'good' samples.
-> Total 298 samples prepared for the dataset.


In [7]:
# Instantiate both networks on the selected device. The embedding size comes
# from the EMBED_SIZE config constant (instead of a repeated literal) so the
# generator and discriminator always agree with the configuration cell.
generator = UNetGenerator(num_classes=NUM_CLASSES, embed_size=EMBED_SIZE).to(DEVICE)
discriminator = PatchGANDiscriminator(num_classes=NUM_CLASSES, embed_size=EMBED_SIZE).to(DEVICE)

# Module.apply visits every submodule, so weights_init is run on each layer.
generator.apply(weights_init)
discriminator.apply(weights_init)

# One Adam optimizer per network, each with its own learning rate.
optimizer_G = optim.Adam(generator.parameters(), lr=LR_G, betas=(BETA1, BETA2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=LR_D, betas=(BETA1, BETA2))

# BCEWithLogitsLoss applies the sigmoid internally, so it must be fed raw logits.
adversarial_loss = nn.BCEWithLogitsLoss()
reconstruction_loss = nn.L1Loss()

In [8]:
# Best effort: grab one batch to reuse for visualization. On any failure the
# variable stays None and the problem is only reported, not raised.
fixed_batch_for_vis = None
try:
    fixed_batch_for_vis = next(iter(conditional_dataloader))
    vis_batch_size = fixed_batch_for_vis[0].size(0)
    print(f"Using fixed batch of size {vis_batch_size} for visualization.")
except Exception as err:
    print(f"Could not get fixed batch for visualization: {err}")

Using fixed batch of size 32 for visualization.


In [9]:
# Launch training; every argument is passed by keyword and comes from the
# objects and constants created in the cells above.
train_conditional_gan(
    # networks and their optimizers
    generator=generator,
    optimizer_G=optimizer_G,
    discriminator=discriminator,
    optimizer_D=optimizer_D,
    # loss functions and the L1 weighting value
    adversarial_loss=adversarial_loss,
    reconstruction_loss=reconstruction_loss,
    lambda_l1=LAMBDA_L1,
    # data and run length
    dataloader=conditional_dataloader,
    epochs=EPOCHS,
    device=DEVICE,
    # visualization batch (may be None if it could not be fetched)
    fixed_batch_for_vis=fixed_batch_for_vis,
    # output locations and save intervals
    checkpoint_dir=CHECKPOINT_DIR,
    save_checkpoint_freq=SAVE_CHECKPOINT_FREQ,
    sample_dir=SAMPLE_DIR,
    save_samples_freq=SAVE_SAMPLES_FREQ,
)


Starting Conditional GAN Training for 300 epochs on cuda...
Using fixed batch of size 32 for visualization.


Total Training Progress:   0%|          | 0/2700 [00:00<?, ?it/s]

Training Finished after 300 epochs.
