In [None]:
import os

!pip uninstall -y tensorflow
!pip install -q "protobuf==3.20.3" bitsandbytes transformers diffusers accelerate sentencepiece

from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
import wandb

try:
    user_secrets = UserSecretsClient()
    
    hf_token = user_secrets.get_secret("HF_TOKEN")
    login(token=hf_token)
    
    wandb_api = user_secrets.get_secret("WANDB_API_KEY")
    wandb.login(key=wandb_api)
    print("‚úÖ Logged in to HuggingFace and WandB")
except Exception as e:
    print(f"‚ö†Ô∏è Auth Warning: {e}")
    print("Ensure 'HF_TOKEN' and 'WANDB_API_KEY' are in Add-ons -> Secrets")

!git clone https://github.com/Particle1904/SingleStreamDiT_T5Gemma2
print("‚úÖ Repo Cloned")

repo_dir = "SingleStreamDiT_T5Gemma2"
output_dir = os.path.join(repo_dir, "output")

In [None]:
import os
import shutil  # was missing: shutil.rmtree below raised NameError

repo_name = "SingleStreamDiT_T5Gemma2"

# `output_dir` (defined in cell 1) is relative to the notebook's starting
# directory. Resolve it *before* changing directory; otherwise the existence
# check below looks for <repo>/<repo>/output and the cleanup never happens.
# NOTE(review): on Kaggle, config.py sets output_dir to /kaggle/working/output,
# not this in-repo folder — confirm which directory is meant to be cleaned.
output_dir_abs = os.path.abspath(output_dir)

if os.path.isdir(repo_name):
    os.chdir(repo_name)
    print(f"📂 Changed directory to: {os.getcwd()}")
elif os.path.basename(os.getcwd()) == repo_name:
    # Cell re-run: a previous execution already changed into the repo.
    print(f"📂 Already in: {os.getcwd()}")
else:
    print("❌ Repo folder not found. Did Cell 1 run successfully?")

if os.path.exists(output_dir_abs):
    shutil.rmtree(output_dir_abs)
    print(f"🧹 Cleaned: Removed existing '{output_dir}' to start fresh.")
else:
    print("ℹ️ No output folder found in repo, starting with clean slate.")

# Same path config.py uses as cache_dir when running on Kaggle.
expected_path = "/kaggle/input/oxfordflowers/cached_data"
if os.path.exists(expected_path):
    print("✅ Dataset found.")
else:
    print(f"❌ WARNING: Dataset not found at {expected_path}")
    print("Ensure your dataset slug is 'oxfordflowers' and it contains 'cached_data'.")

In [None]:
%%writefile config.py
import os
import torch
from accelerate.utils import set_seed
from accelerate import Accelerator, DistributedDataParallelKwargs

ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
_accelerator = Accelerator(log_with="wandb", kwargs_handlers=[ddp_kwargs])
set_seed(42)

class Config:
    is_kaggle = os.path.exists("/kaggle/working")
    
    if is_kaggle:
        output_dir = "/kaggle/working/output"
        cache_dir = "/kaggle/input/oxfordflowers/cached_data"
    else:
        cache_dir = "./cached_data"  
        output_dir = "./output"
    
    project_name = "flowers"
    dataset_dir = "./dataset"
    checkpoint_dir = os.path.join(output_dir, "checkpoints")
    samples_dir = os.path.join(output_dir, "samples")
    log_dir = os.path.join(output_dir, "logs")
    log_file = os.path.join(log_dir, f"{project_name}_log.csv")    
    target_file = os.path.join(cache_dir, "39.pt")        
    resume_from = None
    reset_optmizer = True
        
    text_embed_dim = 1152
    in_channels = 16    

    hidden_size = 768
    num_heads = 12
    depth = 16
    refiner_depth = 2
    max_token_length = 128
    patch_size = 2
    rope_base = 10_000
    
    vae_id = "diffusers/FLUX.1-vae"
    text_model_id = "google/t5gemma-2-1b-1b"
    
    target_resolution = 448
    bucket_alignment = 32
    vae_scaling_factor = 0.3611
    vae_downsample_factor= 8
    dataset_mean = 0.0
    dataset_std = 1.0

    learning_rate = 2e-4   
    epochs = 1200
    batch_size = 24
    accum_steps = 1
    loss_type = "mse"
    
    model_dropout = 0.05
    weight_decay = 0.05
    optimizer_warmup = 0.05
    offset_noise = 0.05
    text_dropout = 0.15
    flip_aug = False       
    
    shift_val = 1.0        
    
    use_self_eval = True
    start_self_eval_at = 0.90
    self_eval_lambda = 0.3
    
    fal_lambda = 0.05
    fcl_lambda = 0.05
    fourier_stack_depth = 2
    if is_kaggle:
        dtype = torch.float32
    else:
        dtype = torch.bfloat16
    gradient_checkpointing = True
    use_ema = True
    ema_decay = 0.999
    
    accelerator = _accelerator
    device = _accelerator.device
    load_entire_dataset = True
    num_workers = 2 if os.name != 'nt' else 0
    
    save_every = 100
    validate_every = 50
    validate_cfg = 3.00
    validate_steps = 30 
    validate_sampler = "euler"
    
    inference_steps = 50
    guidance_scale = 3.5
    sampler = "rk4"

In [None]:
import builtins
from config import Config

silencer_code = """
import builtins
from config import Config
# If this is a worker GPU, kill the print function
if not Config.accelerator.is_main_process:
    builtins.print = lambda *args, **kwargs: None
"""

if os.path.exists("train.py"):
    with open("train.py", "r") as f: content = f.read()
    if "builtins.print = lambda" not in content: # Avoid double injection
        with open("train.py", "w") as f: f.write(silencer_code + "\n" + content)
        print("‚úÖ Muted logs on Worker GPU (train.py)")

!sed -i 's/tqdm(self.files, desc="Loading Dataset")/tqdm(self.files, desc="Loading Dataset", disable=not Config.accelerator.is_main_process)/g' dataset.py
print("‚úÖ Fixed double progress bar (dataset.py)")

#!sed -i "s/if sys.platform.startswith('linux'):/if False: # Disabled for T4 stability/g" train.py
#print("‚úÖ Disabled torch.compile (Fixes T4/FFT crash)")
!sed -i 's/mode="max-autotune"/mode="reduce-overhead"/g' train.py
print("‚úÖ Set torch.compile to reduce-overhead mode")

print("‚úÖ Patched config.py: Enabled find_unused_parameters=True")

print("üöÄ Launching Accelerator...")
!accelerate launch --multi_gpu --num_processes=2 --mixed_precision=fp16 train.py