# Imports

In [None]:
from cgan import * 
import wandb
from datetime import datetime

# Set up paths

In [None]:
# Local checkout of the project; the absolute paths below are built from it.
_repo_root = r"C:\Users\kamen\Dev\School\H25\IFT3710\IFT3710-Advanced-Project-in-ML-AI"
_unified_set = _repo_root + r"\data\preprocessing_outputs\unified_set"

# Input data: images and their corresponding label masks
image_dir = _unified_set + r"\images"
mask_dir = _unified_set + r"\labels"

# Output locations (the first two are relative to the working directory)
sample_dir = "big_unet"
checkpoint_dir = "checkpoints"
output_dir = _repo_root + r"\data\dataset_pix2pix\new_samples_big_unet"

# Single mask used to visualize generator progress during training
test_mask_path = (
    _repo_root
    + r"\src\data_augmentation\gans\base_gan\generated_samples"
    + r"\sample_1_epoch_86.png"
)

# Launch training of a cGAN that takes masks as input and generates images

In [None]:
# Training hyperparameters (all forwarded to train_gan further down)
batch_size = 8
epochs = 1          # single pass only; raise for a real training run
lr = 0.0001
beta1 = 0.5         # presumably the optimizer's first-moment decay; verify in train_gan
beta2 = 0.999       # presumably the optimizer's second-moment decay; verify in train_gan
lambda_L1 = 150     # passed to train_gan as lambda_L1

# Generator settings, defined once so that the values logged to wandb and the
# values actually passed to get_generator cannot drift apart.
# NOTE(review): the logged config previously claimed architecture "big_unet"
# and n_blocks 1, while the model was built with 'small_unet' and n_blocks 9.
# The model construction is unchanged here; only the logged/printed metadata
# now reflects it. Confirm 'small_unet' is the intended architecture (the
# output directory names still say "big_unet").
arch_type = 'small_unet'
input_nc = 1        # mask channels
output_nc = 3       # generated image channels
ngf = 256
use_dropout = True
n_blocks = 9

# Initialize wandb before any training happens
wandb.login()  # Prompts for an API key on first run
wandb.init(
    project="cell-gan",
    name=f"gan-training-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
    config={
        "architecture": arch_type,
        "batch_size": batch_size,
        "epochs": epochs,
        "learning_rate": lr,
        "beta1": beta1,
        "beta2": beta2,
        "lambda_L1": lambda_L1,
        "input_nc": input_nc,
        "output_nc": output_nc,
        "ngf": ngf,
        "use_dropout": use_dropout,
        "n_blocks": n_blocks,
    }
)

# Device configuration: use the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Data transformations
# Images: resize to 256x256, then map each of the 3 channels to [-1, 1]
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# Masks: resize to 256x256 and map to [-1, 1]
# (assumes CellGANDataset loads masks as single-channel — TODO confirm;
# no grayscale conversion is applied here)
mask_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Create dataset
dataset = CellGANDataset(
    image_dir, mask_dir,
    transform=image_transform,
    mask_transform=mask_transform
)

# Split into train (80%) and validation (remainder) sets
# NOTE(review): no generator is passed to random_split, so the split differs
# between runs unless a global seed is set elsewhere.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size]
)

# Create data loaders (only the training set is shuffled)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Initialize the generator from the shared settings above
generator = get_generator(
    arch_type=arch_type,
    input_nc=input_nc,
    output_nc=output_nc,
    ngf=ngf,
    norm_layer=nn.InstanceNorm2d,
    use_dropout=use_dropout,
    n_blocks=n_blocks    # may be unused by U-Net variants; verify in get_generator
)

# Initialize PatchGAN discriminator
discriminator = PatchGANDiscriminator(
    input_nc=input_nc + output_nc,   # 1 for mask + 3 for image
    ndf=64,
    n_layers=3,
    norm_layer=nn.BatchNorm2d
)

# Print model sizes
print(f"Generator Architecture: {arch_type}")
print(f"Generator Parameters: {count_parameters(generator):,}")
print(f"Discriminator Parameters: {count_parameters(discriminator):,}")

# Ensure the output directories exist (no error if they are already present)
for _dir_path in (sample_dir, checkpoint_dir, output_dir):
    os.makedirs(_dir_path, exist_ok=True)

# Run training; returns the trained models and the recorded history
trained_generator, trained_discriminator, history = train_gan(
    generator,
    discriminator,
    train_loader,
    val_loader,
    device,
    epochs=epochs,
    lr=lr,
    beta1=beta1,
    beta2=beta2,
    lambda_L1=lambda_L1,
    sample_dir=sample_dir,
    checkpoint_dir=checkpoint_dir,
    test_mask_path=test_mask_path,
)

# Plot the history returned by train_gan
plot_training_history(history)

# Directory of masks to run inference on
test_mask_dir = r"C:\Users\kamen\Dev\School\H25\IFT3710\IFT3710-Advanced-Project-in-ML-AI\src\data_augmentation\gans\base_gan\generated_samples"

# Run the trained generator over the mask directory, with output_dir as the
# destination and comparison output enabled
process_mask_directory(
    trained_generator, test_mask_dir, output_dir, device,
    make_comparison=True,
)