In [None]:
## Import required libraries
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from google.colab import drive
from PIL import Image
from sklearn.model_selection import train_test_split

# Mount Google Drive
drive.mount('/content/drive')

# Install Kaggle API
!pip install kaggle

# Create Kaggle directory
!mkdir -p ~/.kaggle

# Create kaggle.json (REPLACE WITH YOUR CREDENTIALS)
!echo '{"username":"<your_kaggle_username>","key":"<your_31_letter_kaggel_key>"}' > ~/.kaggle/kaggle.json

# Set permissions
!chmod 600 ~/.kaggle/kaggle.json

print(os.path.expanduser("~/.kaggle/kaggle.json"))

# Download Anime Face Dataset
!kaggle datasets download -d splcher/animefacedataset -p /content/drive/<your_google_drive_name>/

# Unzip dataset
!unzip /content/drive/<your_google_drive_name>/animefacedataset.zip -d /content/drive/<your_google_drive_name>/anime_dataset

# Used for cleanup if needed, good practice
import shutil



In [None]:
# Define the target size required by the ViT model
VIT_INPUT_SIZE = 224


class AnimeDatasetPreprocessor:
    """
    Prepares an anime image dataset by selecting a random sample, resizing
    the images to a square of ``target_size`` pixels (224 by default) and
    writing them as PNG files into train / validation / test folders.

    Tip: start with a small ``sample_size`` (e.g. 500) and increase it
    gradually after every successful training run.

    Args:
        input_dir: Directory that contains an ``images`` subfolder with the
            raw pictures.
        output_dir: Directory under which ``train``, ``validation`` and
            ``test`` folders are created.
        sample_size: Maximum number of images to sample from the raw set.
        target_size: Edge length, in pixels, of the square output images.

    Raises:
        FileNotFoundError: If ``input_dir`` or ``input_dir/images`` is not
            an existing directory.
    """

    # With fewer than 4 images the 30% hold-out contains a single image,
    # which cannot be divided again into validation and test subsets
    # (train_test_split rejects a split that leaves one side empty).
    _MIN_IMAGES_FOR_SPLIT = 4

    def __init__(self, input_dir, output_dir, sample_size=1000, target_size=VIT_INPUT_SIZE):
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.sample_size = sample_size
        self.target_size = target_size  # Single int; images become target_size x target_size

        if not os.path.isdir(input_dir):
            raise FileNotFoundError(f"Input directory not found: {input_dir}")
        if not os.path.isdir(os.path.join(input_dir, 'images')):
            raise FileNotFoundError(f"Subdirectory 'images' not found in: {input_dir}")

        # Create output directories safely
        self.train_dir = os.path.join(output_dir, 'train')
        self.test_dir = os.path.join(output_dir, 'test')
        self.val_dir = os.path.join(output_dir, 'validation')
        for split_dir in (self.train_dir, self.test_dir, self.val_dir):
            os.makedirs(split_dir, exist_ok=True)

    def _process_and_save_images(self, image_list, images_subdir, split_folder_path):
        """Resize every image in ``image_list`` and save it as PNG.

        Failures on individual files are reported and skipped so that one
        corrupt image does not abort the whole run.

        Returns:
            The number of images that were written successfully.
        """
        count = 0
        for img_name in image_list:
            input_path = os.path.join(images_subdir, img_name)
            output_filename = os.path.splitext(img_name)[0] + ".png"
            output_path = os.path.join(split_folder_path, output_filename)

            try:
                # The context manager releases the source file handle as
                # soon as the pixel data has been converted.
                with Image.open(input_path) as img:
                    img_resized = img.convert('RGB').resize(
                        (self.target_size, self.target_size), Image.LANCZOS
                    )
                img_resized.save(output_path, format='PNG')
                count += 1
            except FileNotFoundError:
                print(f"Warning: Image file not found during processing: {input_path}")
            except Exception as e:
                print(f"Error processing {img_name}: {e}")
        return count

    def prepare_dataset(self):
        """
        Loads image names, selects a sample, splits it 70/15/15 into
        train/validation/test, then resizes and saves every image.

        Problems (missing folder, no images, too few images) are reported
        with a printed message and the method returns ``None`` without
        writing anything.
        """
        images_subdir = os.path.join(self.input_dir, 'images')
        try:
            image_files = [
                f for f in os.listdir(images_subdir)
                if f.lower().endswith(('.png', '.jpg', '.jpeg'))
            ]
        except FileNotFoundError:
            print(f"Error: 'images' subdirectory not found at {images_subdir}")
            return

        if not image_files:
            print(f"No images found in {images_subdir}")
            return

        print(f"Total images found: {len(image_files)}")

        # Ensure sample size is not larger than available images
        actual_sample_size = min(self.sample_size, len(image_files))
        if actual_sample_size < self.sample_size:
            print(f"Warning: Requested sample size {self.sample_size} > images available {len(image_files)}. Using {actual_sample_size}.")

        # Randomly select images (without replacement)
        selected_images = np.random.choice(
            image_files,
            size=actual_sample_size,
            replace=False
        )
        print(f"Selected {len(selected_images)} images for processing.")

        if len(selected_images) < self._MIN_IMAGES_FOR_SPLIT:
            print("Error: Not enough selected images to create train/validation/test splits.")
            return

        # Split dataset: 70% train, 30% held out for validation + test
        train_images, temp_images = train_test_split(
            selected_images,
            test_size=0.3,
            random_state=42
        )

        # Split the held-out part in half (15% of the total each)
        val_images, test_images = train_test_split(
            temp_images,
            test_size=0.5,
            random_state=42
        )

        # Process and save images for each split
        print(f"\nProcessing Training images (saving to {self.train_dir})...")
        train_count = self._process_and_save_images(train_images, images_subdir, self.train_dir)
        print(f"Processed {train_count} training images.")

        print(f"\nProcessing Validation images (saving to {self.val_dir})...")
        val_count = self._process_and_save_images(val_images, images_subdir, self.val_dir)
        print(f"Processed {val_count} validation images.")

        print(f"\nProcessing Test images (saving to {self.test_dir})...")
        test_count = self._process_and_save_images(test_images, images_subdir, self.test_dir)
        print(f"Processed {test_count} test images.")

        # Final dataset stats
        print("\n--- Dataset Preparation Summary ---")
        print(f"Target image size: {self.target_size}x{self.target_size}")
        print(f"Total images selected: {len(selected_images)}")
        print(f"Training images saved: {train_count} (in {self.train_dir})")
        print(f"Validation images saved: {val_count} (in {self.val_dir})")
        print(f"Testing images saved: {test_count} (in {self.test_dir})")
        print("-----------------------------------\n")

# Example usage (adjust the paths to your own Drive layout).
# Author's caution: start with sample_size=500 and increase gradually; do not
# run this locally unless you have at least an RTX 4090, more than 16 GB of
# RAM and roughly 1 TB of free space on the C: drive.
print("Running Preprocessor...")
preprocessor = AnimeDatasetPreprocessor(
    input_dir='/content/drive/<your_google_drive_name>/anime_dataset',                # Directory containing the 'images' subfolder
    output_dir='/content/drive/<your_google_drive_name>/processed_anime_dataset_224', # New output directory for 224x224 images
    sample_size=1000                                                                  # Number of images to sample
)
# Report the directory the splits are actually written to (the previous
# "~/.content/..." path expanded to a location that is never used).
print(preprocessor.output_dir)
preprocessor.prepare_dataset()
print("Preprocessor finished.")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.models import vit_b_16, ViT_B_16_Weights # Example ViT


In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms, datasets # Make sure datasets is imported
import matplotlib.pyplot as plt
import torchvision.utils as vutils

In [None]:
# ==============================================================================
#                            COMPLETE COLAB SCRIPT
# ==============================================================================

# ------------------------------------------
# 1. IMPORTS
# ------------------------------------------
import glob
import time
from PIL import Image

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
import torchvision.utils as vutils

from sklearn.model_selection import train_test_split # For preprocessor split

# The imports above bind only submodule aliases (`transforms`, `vutils`),
# never the top-level `torchvision` name, so import it explicitly before
# reading its version.
import torchvision

print(f"PyTorch Version: {torch.__version__}")
print(f"Torchvision Version: {torchvision.__version__}")


In [None]:
# ------------------------------------------
# 2. MOUNT GOOGLE DRIVE
# ------------------------------------------
# Attach Google Drive at /content/drive; a failure is reported but does not
# stop the notebook, so later cells may still fail on missing Drive paths.
print("Mounting Google Drive...")
try:
    drive.mount('/content/drive')
except Exception as mount_error:
    print(f"Error mounting Google Drive: {mount_error}")
    print("Please ensure you have authorized Google Drive access.")
    # Depending on the workflow, you might want to exit here
    # raise SystemExit("Drive mounting failed.")
else:
    print("Google Drive mounted successfully.")


In [None]:
# ------------------------------------------
# 3. CONFIGURATION
# ------------------------------------------
# --- Paths ---
# !! ADJUST THESE PATHS !!
DRIVE_SAVE_DIR = '/content/drive/<your_google_drive_name>/Colab_Outputs/AnimeVAE' # Base directory on Drive for saving models/images
# Folder that CONTAINS the 'images' subfolder. AnimeDatasetPreprocessor appends
# 'images' itself, so the path must not already end in '/images'.
RAW_DATASET_DIR = '/content/drive/<your_google_drive_name>/anime_dataset'
# ------------------
# Location of the resized images, expressed relative to the working directory.
# NOTE(review): with Colab's default working directory (/content) this resolves
# to /content/drive/<your_google_drive_name>/processed_anime_dataset_224, i.e. a
# folder on Drive rather than runtime-local storage - confirm this is intended.
PROCESSED_DATA_DIR_NAME = '/drive/<your_google_drive_name>/processed_anime_dataset_224'
PROCESSED_DATA_ROOT = f'./{PROCESSED_DATA_DIR_NAME}' # Root folder holding train/validation/test

# --- Model & Data Parameters ---
VIT_INPUT_SIZE = 224        # Input size required by ViT-B/16
LATENT_DIM = 256            # Dimensionality of the VAE latent space
BETA = 1.0                  # Weight for the KL divergence term (Beta-VAE)

# --- Training Parameters ---
LEARNING_RATE = 0.0001      # Learning rate for the Adam optimizer (VAE might need smaller LR)
BATCH_SIZE = 32              # Adjust based on GPU memory (start lower if memory errors occur)
NUM_EPOCHS = 400             # Number of training epochs (as requested)
SAVE_EVERY_EPOCHS = 10       # How often to save the model weights to Drive
GENERATE_EVERY_EPOCHS = 10   # How often to generate and save sample images during training
# --- Preprocessing Parameters ---
SAMPLE_SIZE_PREPROCESSOR = 2000 # Max number of images to process from the raw dataset

# --- Generation Parameters ---
NUM_IMAGES_TO_GENERATE = 16 # Number of images to generate after training


# --- Create Save Directory on Drive ---
os.makedirs(DRIVE_SAVE_DIR, exist_ok=True)
print(f"Save directory on Drive: {DRIVE_SAVE_DIR}")

In [None]:


# --- 4.2 Custom Dataset Class ---
class ImageDatasetNoLabels(Dataset):
    """Loads images from a single folder (no class subdirectories).

    Files are matched by extension case-insensitively (.png, .jpg, .jpeg),
    the same rule the preprocessor uses, and are kept in sorted path order.

    Args:
        root_dir: Folder that directly contains the image files.
        transform: Optional callable applied to each loaded RGB PIL image.

    Raises:
        FileNotFoundError: If no matching image file is found in ``root_dir``.
    """

    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # One wildcard pass plus a lower-cased suffix test catches mixed-case
        # extensions and cannot list a file twice on case-insensitive
        # filesystems, unlike one glob pattern per spelling.
        candidates = glob.glob(os.path.join(root_dir, '*'))
        self.image_files = sorted(
            path for path in candidates
            if path.lower().endswith(self._IMAGE_EXTENSIONS)
        )

        if not self.image_files:
            # Raise error if the folder is empty after preprocessing claims success
            raise FileNotFoundError(f"CRITICAL: No image files found in the processed directory {root_dir}. Check preprocessing output.")
        print(f"Initialized Dataset: Found {len(self.image_files)} images in {root_dir}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image at ``idx``.

        If the file cannot be read, a zero tensor of shape
        (3, VIT_INPUT_SIZE, VIT_INPUT_SIZE) is returned so batching keeps working.
        """
        img_path = self.image_files[idx]
        try:
            # Close the file handle as soon as the pixels are decoded.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')
        except Exception as e:
            print(f"Warning: Error loading image {img_path}: {e}. Returning dummy tensor.")
            return torch.zeros((3, VIT_INPUT_SIZE, VIT_INPUT_SIZE))

        if self.transform:
            image = self.transform(image)

        # Return only the image tensor (DataLoader handles batching)
        return image


In [None]:
def visualize_and_save_samples(samples_tensor, epoch_num, save_dir, prefix="generated"):
    """Denormalize, display and save a grid of generated images.

    Args:
        samples_tensor: Batch of images (N, C, H, W), or None.
        epoch_num: Epoch number used in the plot title and the file name.
        save_dir: Directory the PNG grid is written to (created if missing).
        prefix: Label used in log messages, the title and the file name.

    Returns:
        None. Errors during plotting or saving are printed with a traceback
        instead of being raised, so a failed preview never stops training.
    """
    if samples_tensor is None or samples_tensor.shape[0] == 0:
        print(f"[{prefix} Epoch {epoch_num}] No samples provided for visualization.")
        return

    print(f"[{prefix} Epoch {epoch_num}] Visualizing and saving {samples_tensor.shape[0]} samples...")
    try:
        # Detach so .numpy() works even if the tensor still tracks gradients
        samples_tensor = samples_tensor.detach().cpu()

        # Invert an ImageNet mean/std normalization.
        # NOTE(review): this assumes the samples live in ImageNet-normalized
        # space; VisionTransformerVAE.decode in this file ends in a sigmoid
        # ([0, 1] output) - confirm against the training transforms.
        inv_normalize = transforms.Normalize(
           mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
           std=[1/0.229, 1/0.224, 1/0.225]
        )
        samples_denorm = torch.stack([inv_normalize(img) for img in samples_tensor])
        samples_denorm = torch.clamp(samples_denorm, 0, 1) # Ensure valid range [0, 1]

        # Square-ish layout, shared by the displayed and the saved grid
        nrow = int(samples_denorm.shape[0] ** 0.5)
        grid = vutils.make_grid(samples_denorm,
                                padding=2,
                                normalize=False, # Already in [0,1]
                                nrow=nrow)

        # Plotting; close the figure afterwards so repeated calls during
        # training do not accumulate open figures.
        fig = plt.figure(figsize=(8, 8))
        try:
            plt.imshow(grid.permute(1, 2, 0).numpy()) # CHW -> HWC for matplotlib
            plt.title(f'{prefix.capitalize()} Anime VAE Images (Epoch {epoch_num})')
            plt.axis('off')
            plt.show()
        finally:
            plt.close(fig)

        # Saving to Drive with the same layout that was displayed
        os.makedirs(save_dir, exist_ok=True) # Ensure directory exists
        save_image_path = os.path.join(save_dir, f'{prefix}_anime_vae_samples_epoch_{epoch_num}.png')
        vutils.save_image(samples_denorm, save_image_path, nrow=nrow, padding=2, normalize=False)
        print(f"[{prefix} Epoch {epoch_num}] Image grid saved to: {save_image_path}")

    except Exception as e:
        print(f"\nError during visualization or saving of {prefix} images at epoch {epoch_num}: {e}")
        import traceback
        traceback.print_exc()

In [None]:
# Define the target size required by the ViT model
VIT_INPUT_SIZE = 224

# --- Dummy Classes (Replace with your actual implementations) ---
class VisionTransformerVAE(nn.Module):
    """Placeholder VAE built from plain linear layers.

    Despite the name, no Vision Transformer is used here: the image is
    flattened, passed through one hidden layer to produce ``mu`` and
    ``logvar``, and decoded back through one hidden layer. Replace the
    dummy layers with a real ViT encoder / decoder.

    Args:
        latent_dim: Size of the latent vector.
        vit_input_size: Edge length of the square RGB input images.
    """

    def __init__(self, latent_dim, vit_input_size):
        super().__init__()
        self.latent_dim = latent_dim
        self.input_size = vit_input_size
        flat_features = 3 * vit_input_size * vit_input_size
        # Dummy layers - replace with actual ViT Encoder + Decoder + Latent mapping
        self.encoder_dummy = nn.Linear(flat_features, 512)
        self.fc_mu = nn.Linear(512, latent_dim)
        self.fc_logvar = nn.Linear(512, latent_dim)
        self.decoder_input = nn.Linear(latent_dim, 512)
        self.decoder_dummy = nn.Linear(512, flat_features)
        print(f"Placeholder VisionTransformerVAE initialized (Latent: {latent_dim}, Input: {vit_input_size})")

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of images ``x`` of shape (N, 3, H, W)."""
        # flatten() copies when needed, so non-contiguous batches (e.g.
        # channels_last or permuted tensors) work where view() would raise.
        x = torch.flatten(x, start_dim=1)
        h = F.relu(self.encoder_dummy(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors (N, latent_dim) to images (N, 3, H, W) with values in [0, 1]."""
        h = F.relu(self.decoder_input(z))
        recon = torch.sigmoid(self.decoder_dummy(h)) # Use sigmoid for [0,1] range
        # Reshape back to image format
        return recon.view(recon.size(0), 3, self.input_size, self.input_size)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the input batch."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

class ImageDatasetNoLabels(torch.utils.data.Dataset):
    """Placeholder dataset that loads every image found directly in one folder.

    Args:
        root_dir: Folder containing .png / .jpg / .jpeg files (any letter case).
        transform: Optional callable applied to each loaded RGB PIL image.

    Raises:
        FileNotFoundError: If ``root_dir`` does not exist or holds no images.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        try:
            # os.listdir returns entries in arbitrary order; sort so that a
            # given index always refers to the same file across runs.
            self.image_files = sorted(
                os.path.join(root_dir, f) for f in os.listdir(root_dir)
                if os.path.isfile(os.path.join(root_dir, f))
                and f.lower().endswith(('.png', '.jpg', '.jpeg'))
            )
            if not self.image_files:
                raise FileNotFoundError(f"No image files found in {root_dir}")
        except FileNotFoundError as e:
            print(f"Error initializing Dataset: {e}")
            raise
        print(f"Placeholder ImageDatasetNoLabels found {len(self.image_files)} images in {root_dir}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image at ``idx``.

        On any loading or transform error a zero tensor of shape
        (3, VIT_INPUT_SIZE, VIT_INPUT_SIZE) is returned, so the default
        collate function still receives a tensor for every sample.
        """
        img_path = self.image_files[idx]
        try:
            # Release the file handle once the pixel data has been decoded.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image
        except Exception as e:
            print(f"Warning: Error loading image {img_path}: {e}. Returning dummy tensor.")
            # A dummy tensor keeps batching intact; filtering bad samples in a
            # custom collate_fn would be the alternative.
            return torch.zeros((3, VIT_INPUT_SIZE, VIT_INPUT_SIZE))
             

In [None]:
# ==============================================================================
#                      IMPORTS AND SETUP (Assumed)
# ==============================================================================
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import traceback

In [None]:
class AnimeGANTrainer:
    """Handles VAE training, checkpointing, saving, and generation.

    Args:
        model: Module whose forward pass returns ``(reconstruction, mu, logvar)``
            and which exposes ``latent_dim`` and ``decode`` for generation.
        learning_rate: Learning rate for the Adam optimizer.
        beta: Weight of the KL-divergence term in the loss.
        device: Device the model and batches are moved to.
        model_save_path: Path of the final ``state_dict`` file; its directory
            also receives the ``training_generations`` folder.
        checkpoint_path: Path of the resumable checkpoint file.
        generate_every_epochs: Sample-image interval in epochs (0 disables).
        start_epoch: Last epoch already completed (0 for a fresh run).
        optimizer_state: Optional optimizer ``state_dict`` from a checkpoint.
    """

    def __init__(self, model, learning_rate, beta, device,
                 model_save_path, checkpoint_path, generate_every_epochs,
                 start_epoch=0, optimizer_state=None):

        self.device = device
        self.model = model.to(self.device)
        self.learning_rate = learning_rate
        self.beta = beta
        self.model_save_path = model_save_path # Base path for final .pth file
        self.checkpoint_path = checkpoint_path # Path for .pth checkpoint file
        self.generate_every_epochs = generate_every_epochs
        # Ensure results_save_dir exists even if model_save_path dir doesn't initially
        base_save_dir = os.path.dirname(model_save_path)
        if not base_save_dir: # Handle case where path is just a filename
            base_save_dir = "."
        self.results_save_dir = os.path.join(base_save_dir, "training_generations")
        os.makedirs(self.results_save_dir, exist_ok=True) # Create generation dir

        self.current_epoch = start_epoch # The last *completed* epoch

        # Define Loss (using reduction='sum' as common for VAEs, average later)
        self.reconstruction_loss_fn = nn.MSELoss(reduction='sum')

        # Initialize Optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        print(f"Optimizer initialized: Adam, LR: {self.learning_rate}")

        # Load Optimizer State if resuming from checkpoint
        if optimizer_state:
            print("Loading optimizer state from checkpoint...")
            try:
                self.optimizer.load_state_dict(optimizer_state)
                # Move optimizer state tensors to the correct device
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to(self.device)
                print("Optimizer state loaded successfully.")
                print(f"  Optimizer LR after loading state: {self.optimizer.param_groups[0]['lr']}") # Log loaded LR
            except Exception as e:
                print(f"\nWarning: Could not load optimizer state: {e}")
                print("Optimizer will start from scratch.")
        else:
            print("No optimizer state provided, starting optimizer fresh.")

        print("-" * 30)
        print(f"Trainer Initialized:")
        print(f"  Device: {self.device}")
        print(f"  Start Epoch (Next to Run): {self.current_epoch + 1}") # Training starts at next epoch
        print(f"  Learning Rate: {self.optimizer.param_groups[0]['lr']}") # Current LR in optimizer
        print(f"  Beta (KL Weight): {self.beta}")
        print(f"  Checkpoint Path: {self.checkpoint_path}")
        print(f"  Final Model Save Path: {self.model_save_path}")
        print(f"  Generate Images Every: {self.generate_every_epochs} epochs")
        print(f"  Intermediate Images Dir: {self.results_save_dir}")
        print("-" * 30)

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain ``path``, if it names one."""
        parent_dir = os.path.dirname(path)
        # A bare filename has no directory component; os.makedirs('') would raise.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    def save_checkpoint(self, epoch_completed):
        """Saves a checkpoint including model, optimizer, epoch, and hyperparams.

        Errors while writing are printed, not raised, and the temporary file
        is removed so a failed save never replaces the previous checkpoint.
        """
        if not self.checkpoint_path:
            print("Warning: Checkpoint path not set. Cannot save checkpoint.")
            return

        self._ensure_parent_dir(self.checkpoint_path)

        checkpoint = {
            'epoch': epoch_completed, # Record the epoch that just finished
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'learning_rate': self.optimizer.param_groups[0]['lr'], # Save current LR from optimizer
            'beta': self.beta
        }
        # Write to a temporary file first, then swap it into place so that an
        # interrupted write cannot corrupt the existing checkpoint.
        temp_path = self.checkpoint_path + ".tmp"
        try:
            torch.save(checkpoint, temp_path)
            os.replace(temp_path, self.checkpoint_path)
            print(f"Checkpoint saved successfully for epoch {epoch_completed} to {self.checkpoint_path}")
        except Exception as e:
            print(f"\nError saving checkpoint for epoch {epoch_completed}: {e}")
            if os.path.exists(temp_path):
                os.remove(temp_path) # Clean up temp file on error

    def save_model(self, path):
        """Saves only the model's state_dict to ``path`` (errors are printed, not raised)."""
        self._ensure_parent_dir(path)
        temp_path = path + ".tmp"
        try:
            print(f"\nSaving model state_dict to: {path}")
            # Use a temporary file and rename so a failed write leaves no partial file
            torch.save(self.model.state_dict(), temp_path)
            os.replace(temp_path, path)
            print("Model state_dict saved successfully.")
        except Exception as e:
            print(f"\nError saving model state_dict to {path}: {e}")
            if os.path.exists(temp_path):
                os.remove(temp_path)

    def calculate_vae_loss(self, reconstructed_x, original_x, mu, logvar):
        """Calculates the VAE loss components, each averaged per sample.

        Returns:
            ``(total_loss, recon_loss, kl_div)`` where
            ``total_loss = recon_loss + beta * kl_div``. For an empty batch
            all three are zero tensors on the trainer's device.
        """
        batch_size = original_x.size(0)
        if batch_size == 0:
            zero = torch.tensor(0.0, device=self.device)
            return zero, zero.clone(), zero.clone()

        # Reconstruction loss: squared error summed over all pixels and channels
        recon_loss = self.reconstruction_loss_fn(reconstructed_x, original_x)

        # KL divergence to a standard normal, summed over latent dimensions
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

        # Average losses over the batch
        avg_recon_loss_per_sample = recon_loss / batch_size
        avg_kl_div_per_sample = torch.sum(kl_div) / batch_size

        # Total loss per sample
        total_loss = avg_recon_loss_per_sample + self.beta * avg_kl_div_per_sample

        return total_loss, avg_recon_loss_per_sample, avg_kl_div_per_sample

    def train(self, train_loader, epochs_to_run, total_target_epochs, save_every):
        """Trains the model for a specified number of epochs with checkpointing.

        Args:
            train_loader: Iterable with ``len()`` yielding image batches (N, 3, H, W).
            epochs_to_run: Number of epochs to execute in this call.
            total_target_epochs: Overall epoch target; reaching it triggers
                the final model save.
            save_every: Interval in epochs for extra per-epoch model files
                (0 disables them). A checkpoint is written after every epoch
                regardless of this value.
        """
        if epochs_to_run <= 0:
            print("No epochs remaining to run based on start_epoch and total_target_epochs.")
            return

        print(f"\n--- Starting Training ---")
        print(f"Running for {epochs_to_run} epochs (Epoch {self.current_epoch + 1} to {total_target_epochs})")
        start_time_total = time.time()
        self.model.train()

        initial_epoch_completed = self.current_epoch

        for epoch_idx in range(epochs_to_run):
            self.current_epoch = initial_epoch_completed + 1 + epoch_idx
            epoch_start_time = time.time()

            total_epoch_loss = 0.0
            total_epoch_recon_loss = 0.0
            total_epoch_kl_loss = 0.0
            samples_processed_this_epoch = 0

            progress_bar = tqdm(enumerate(train_loader),
                                total=len(train_loader),
                                desc=f"Epoch {self.current_epoch}/{total_target_epochs}",
                                unit="batch",
                                leave=True)

            for i, batch_data in progress_bar:
                try:
                    images = batch_data.to(self.device, non_blocking=True)
                    batch_size = images.size(0)
                    if batch_size == 0: continue
                except Exception as e:
                    print(f"\nError moving batch {i} to device in epoch {self.current_epoch}: {e}")
                    continue

                if images.ndim != 4 or images.shape[1] != 3:
                    print(f"\nWarning: Unexpected batch tensor shape {images.shape} in epoch {self.current_epoch}, batch {i}. Skipping.")
                    continue

                self.optimizer.zero_grad(set_to_none=True)
                try:
                    reconstructed_images, mu, logvar = self.model(images)
                    loss, recon_loss, kl_div = self.calculate_vae_loss(reconstructed_images, images, mu, logvar)
                except Exception as e:
                    print(f"\nError during forward pass or loss calculation in epoch {self.current_epoch}, batch {i}: {e}")
                    traceback.print_exc()
                    print("Skipping batch.")
                    continue

                # A non-finite loss would poison the weights; drop the batch instead.
                if not torch.isfinite(loss):
                    print(f"\nWarning: NaN or Inf loss detected in epoch {self.current_epoch}, batch {i}. Loss: {loss.item()}. Skipping batch.")
                    continue

                try:
                    loss.backward()
                    self.optimizer.step()
                except Exception as e:
                    print(f"\nError during backward pass or optimizer step in epoch {self.current_epoch}, batch {i}: {e}")
                    traceback.print_exc()
                    print("Skipping optimizer step for this batch.")
                    continue

                batch_loss_val = loss.item()
                batch_recon_loss_val = recon_loss.item()
                batch_kl_loss_val = kl_div.item()

                # Losses are per-sample means; weight by batch size so the
                # epoch average stays correct when the last batch is smaller.
                total_epoch_loss += batch_loss_val * batch_size
                total_epoch_recon_loss += batch_recon_loss_val * batch_size
                total_epoch_kl_loss += batch_kl_loss_val * batch_size
                samples_processed_this_epoch += batch_size

                if samples_processed_this_epoch > 0:
                    progress_bar.set_postfix({
                        'Loss': f"{total_epoch_loss / samples_processed_this_epoch:.4f}",
                        'Recon': f"{total_epoch_recon_loss / samples_processed_this_epoch:.4f}",
                        'KL': f"{total_epoch_kl_loss / samples_processed_this_epoch:.4f}"
                    })

            epoch_duration = time.time() - epoch_start_time

            if samples_processed_this_epoch > 0:
                avg_loss = total_epoch_loss / samples_processed_this_epoch
                avg_recon_loss = total_epoch_recon_loss / samples_processed_this_epoch
                avg_kl_loss = total_epoch_kl_loss / samples_processed_this_epoch

                print(f"Epoch [{self.current_epoch}/{total_target_epochs}] Summary"
                      f" | Avg Loss: {avg_loss:.4f}"
                      f" | Recon: {avg_recon_loss:.4f}"
                      f" | KL: {avg_kl_loss:.4f}"
                      f" | Time: {epoch_duration:.2f}s"
                      f" | LR: {self.optimizer.param_groups[0]['lr']:.1e}")
            else:
                print(f"Epoch [{self.current_epoch}/{total_target_epochs}] | No samples processed successfully. Time: {epoch_duration:.2f}s")

            self.save_checkpoint(epoch_completed=self.current_epoch)

            # Periodic per-epoch model files; the final epoch is skipped here
            # because the final model is saved once after the loop.
            is_last_epoch_overall = (self.current_epoch == total_target_epochs)
            if save_every > 0 and (self.current_epoch % save_every == 0 or is_last_epoch_overall):
                if not is_last_epoch_overall:
                    epoch_save_path = f"{os.path.splitext(self.model_save_path)[0]}_epoch_{self.current_epoch}.pth"
                    self.save_model(epoch_save_path)

            # --- Periodic Image Generation ---
            if self.generate_every_epochs > 0 and (self.current_epoch % self.generate_every_epochs == 0 or is_last_epoch_overall):
                print(f"\n--- Generating samples at end of Epoch {self.current_epoch} ---")
                generated_samples = self.generate_images(num_images=NUM_IMAGES_TO_GENERATE)
                if generated_samples is not None:
                    visualize_and_save_samples(generated_samples, self.current_epoch, self.results_save_dir, prefix="intermediate")
                self.model.train() # Switch back to train mode after generation

        total_training_time = time.time() - start_time_total
        print(f"\n--- Training Loop Finished ---")
        print(f"Completed epochs {initial_epoch_completed + 1} through {self.current_epoch}")
        print(f"Total Time for this run: {total_training_time // 60:.0f}m {total_training_time % 60:.0f}s")

        if self.current_epoch == total_target_epochs:
            print("\nSaving final model state (target epoch reached)...")
            self.save_model(self.model_save_path)
        else:
            print(f"\nTarget epoch {total_target_epochs} not reached (stopped at {self.current_epoch}). Final model not saved to '{self.model_save_path}'. Use the latest checkpoint.")

    def generate_images(self, num_images=16):
        """Generates images from random noise using the decoder.

        Leaves the model in eval mode; callers that continue training must
        switch back with ``model.train()``.

        Returns:
            A CPU tensor of decoded images, or ``None`` if ``num_images`` is
            not positive or generation fails.
        """
        if num_images <= 0:
            print("Number of images to generate must be positive.")
            return None
        print(f"Generating {num_images} images...")
        self.model.eval()
        generated_images = None
        try:
            latent_dim = self.model.latent_dim
            with torch.no_grad():
                z = torch.randn(num_images, latent_dim).to(self.device)
                generated_images = self.model.decode(z)
            print("Image generation complete.")
            return generated_images.cpu()
        except AttributeError as e:
            print(f"Error generating images: Model might be missing 'latent_dim' attribute or 'decode' method. {e}")
            return None
        except Exception as e:
            print(f"An unexpected error occurred during image generation: {e}")
            traceback.print_exc()
            return None


In [None]:
if __name__ == "__main__":

    # --- Step 5.1: Run Preprocessor ---
    # Preprocessing is skipped when the processed 'train' folder already
    # contains at least one entry, so re-running this cell does not redo work.
    print("\n--- Running Preprocessor ---")
    train_output_dir = os.path.join(PROCESSED_DATA_ROOT, 'train')
    # isdir (rather than exists) keeps os.listdir from raising when the path
    # happens to be a regular file.
    if os.path.isdir(train_output_dir) and os.listdir(train_output_dir):
        print(f"Processed data found at '{train_output_dir}'. Skipping preprocessing.")
    else:
        print(f"Processed data not found or empty at '{train_output_dir}'. Running preprocessing...")
        try:
            preprocessor = AnimeDatasetPreprocessor(
                input_dir=RAW_DATASET_DIR,
                output_dir=PROCESSED_DATA_ROOT,
                sample_size=SAMPLE_SIZE_PREPROCESSOR,
                target_size=VIT_INPUT_SIZE
            )
            preprocessor.prepare_dataset()
            # Verify that the preprocessor actually produced training files.
            if not (os.path.isdir(train_output_dir) and os.listdir(train_output_dir)):
                raise RuntimeError("Preprocessing finished, but output training directory is still empty or missing!")
            print("Preprocessing completed successfully.")
        except NameError as e:
            # Any undefined name lands here (the class, a config constant, or
            # a name used inside prepare_dataset), so report the real one.
            print(f"ERROR: A required name is not defined (e.g., AnimeDatasetPreprocessor). Details: {e}")
            sys.exit("Preprocessing definition failed.")
        except FileNotFoundError as e:
            print(f"\nCritical Error during preprocessing setup: {e}")
            print(f"Please ensure the RAW_DATASET_DIR ('{RAW_DATASET_DIR}') exists.")
            sys.exit("Preprocessing failed. Cannot continue.")
        except Exception as e:
            print(f"\nAn unexpected error occurred during preprocessing: {e}")
            traceback.print_exc()
            sys.exit("Preprocessing failed. Cannot continue.")
    print("--- Preprocessor Step Done ---")


    # --- Step 5.2: Create DataLoader ---
    # Builds the training dataset from the processed 'train' folder and wraps
    # it in a shuffled DataLoader. Any failure terminates the script.
    print("\n--- Setting up DataLoader ---")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    try:
        # NOTE(review): these mean/std values look like the usual ImageNet
        # statistics - confirm they match what the ViT encoder and the
        # visualization code expect.
        data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        train_dataset = ImageDatasetNoLabels(
            root_dir=train_output_dir,
            transform=data_transform
        )
        if len(train_dataset) == 0:
            raise RuntimeError(f"Dataset created but contains 0 images. Check path: {train_output_dir}")

        on_cuda = device.type == 'cuda'
        num_loader_workers = 2 if on_cuda else 0
        train_loader = DataLoader(
            train_dataset, batch_size=BATCH_SIZE, shuffle=True,
            num_workers=num_loader_workers,
            # Pinned host memory only helps host-to-GPU copies, so tie it to
            # the device rather than to the worker count.
            pin_memory=on_cuda,
            # persistent_workers is only valid with at least one worker.
            persistent_workers=(num_loader_workers > 0),
            # drop_last discards the final partial batch; with fewer samples
            # than BATCH_SIZE the loader yields zero batches.
            drop_last=True
        )
        print(f"DataLoader created for {len(train_dataset)} samples.")
        print(f"Batch size: {BATCH_SIZE}, Batches/epoch: {len(train_loader)}, Workers: {num_loader_workers}")

    except NameError as e:
        # The undefined name may be the dataset class, DataLoader, transforms
        # or a config constant - print the actual message instead of guessing.
        print(f"ERROR: A required name is not defined (e.g., ImageDatasetNoLabels or DataLoader). Details: {e}")
        sys.exit("DataLoader setup failed.")
    except FileNotFoundError as e:
        print(f"Error: Failed to create Dataset - Path not found: {e}")
        sys.exit("DataLoader setup failed (FileNotFound).")
    except RuntimeError as e:
        print(f"Error: Failed to create Dataset or DataLoader: {e}")
        sys.exit("DataLoader setup failed (RuntimeError).")
    except Exception as e:
        print(f"An unexpected error occurred during DataLoader setup: {e}")
        traceback.print_exc()
        sys.exit("DataLoader setup failed (Unexpected).")
    print("--- DataLoader Setup Done ---")


    # --- Step 5.3: Initialize Model and Trainer (with Resume Logic) ---
    # Builds the VAE, restores it from the rolling checkpoint when one exists,
    # and constructs the trainer with the resume epoch and optimizer state.
    print("\n--- Initializing Model and Trainer ---")
    try:
        checkpoint_dir = os.path.join(DRIVE_SAVE_DIR, 'checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, 'anime_vae_vit_checkpoint.pth')
        final_model_save_path = os.path.join(DRIVE_SAVE_DIR, f'anime_vae_vit_final_{NUM_EPOCHS}epochs.pth')
        print(f"Final model path: {final_model_save_path}")
        print(f"Checkpoint path: {checkpoint_path}")

        vae_model = VisionTransformerVAE(
            latent_dim=LATENT_DIM, vit_input_size=VIT_INPUT_SIZE
        )
        print(f"{type(vae_model).__name__} structure initialized.")

        # Defaults describe a fresh run; they are only overwritten when a
        # checkpoint is restored successfully.
        start_epoch = 0
        optimizer_state_dict = None
        model_loaded_from_checkpoint = False

        if os.path.exists(checkpoint_path):
            print(f"\nCheckpoint found: '{checkpoint_path}'. Loading...")
            try:
                # Load onto the CPU first; the model is moved to `device` below.
                checkpoint = torch.load(checkpoint_path, map_location='cpu')
                print("Checkpoint dictionary loaded.")

                if 'model_state_dict' in checkpoint:
                    vae_model.load_state_dict(checkpoint['model_state_dict'])
                    print(" -> Model weights loaded.")
                    model_loaded_from_checkpoint = True
                else:
                    print(" -> WARNING: 'model_state_dict' missing.")

                if 'optimizer_state_dict' in checkpoint:
                    optimizer_state_dict = checkpoint['optimizer_state_dict']
                    print(" -> Optimizer state dict found.")
                else:
                    print(" -> WARNING: 'optimizer_state_dict' missing.")

                start_epoch = checkpoint.get('epoch', 0)
                print(f" -> Last completed epoch: {start_epoch}.")
                if 'epoch' not in checkpoint:
                    print("    (Warning: 'epoch' key missing, using 0)")

                # Informational only: the trainer below is built with the
                # configured LEARNING_RATE / BETA, not the checkpoint values.
                loaded_lr = checkpoint.get('learning_rate')
                if loaded_lr is not None:
                    print(f" -> Checkpoint LR: {loaded_lr}")
                loaded_beta = checkpoint.get('beta')
                if loaded_beta is not None:
                    print(f" -> Checkpoint Beta: {loaded_beta}")

                if not model_loaded_from_checkpoint:
                    # Resuming the epoch counter or optimizer state against
                    # untrained weights would silently shorten training.
                    print(" -> WARNING: Model weights were not restored; ignoring checkpoint epoch and optimizer state.")
                    start_epoch = 0
                    optimizer_state_dict = None

            except Exception as e:
                print(f"\nERROR loading checkpoint: {e}. Starting fresh.")
                traceback.print_exc()
                start_epoch = 0
                optimizer_state_dict = None
                model_loaded_from_checkpoint = False
                # load_state_dict may already have copied some tensors before
                # raising, so rebuild the model to guarantee a fresh start.
                vae_model = VisionTransformerVAE(
                    latent_dim=LATENT_DIM, vit_input_size=VIT_INPUT_SIZE
                )
        else:
            print("\nNo checkpoint found. Starting fresh.")

        vae_model.to(device)
        print(f"\nModel moved to: {device}")

        actual_start_epoch_for_loop = start_epoch + 1
        epochs_remaining_to_run = max(0, NUM_EPOCHS - start_epoch)
        print("-" * 40)
        print(" Pre-Trainer Init Summary ".center(40, "-"))
        print(f"  Model Loaded: {'Yes' if model_loaded_from_checkpoint else 'No'}")
        print(f"  Optimizer State Found: {'Yes' if optimizer_state_dict is not None else 'No'}")
        print(f"  Last Completed Epoch: {start_epoch}")
        print(f"  Next Epoch to Run: {actual_start_epoch_for_loop}")
        print(f"  Epochs Remaining: {epochs_remaining_to_run}")
        print(f"  LR for Trainer: {LEARNING_RATE}")
        print("-" * 40)

        trainer = AnimeGANTrainer(
            model=vae_model, learning_rate=LEARNING_RATE, beta=BETA, device=device,
            model_save_path=final_model_save_path, checkpoint_path=checkpoint_path,
            generate_every_epochs=GENERATE_EVERY_EPOCHS, start_epoch=start_epoch,
            optimizer_state=optimizer_state_dict
        )

    except NameError as e:
        print(f"ERROR: Class definition missing (e.g., VisionTransformerVAE). Details: {e}")
        sys.exit("Initialization failed - Missing Class.")
    except Exception as e:
        print(f"\nCRITICAL ERROR during Model/Trainer Initialization: {e}")
        traceback.print_exc()
        sys.exit("Initialization failed.")
    print("--- Model and Trainer Initialized ---")


    # --- Step 5.4: Train the Model ---
    # First confirm the loader yields at least one batch, then run whatever
    # epochs remain. On interruption or error a checkpoint save is attempted.
    print("\n--- Verifying DataLoader for Training ---")
    try:
        num_batches = len(train_loader)
        if num_batches == 0:
            raise ValueError("train_loader is empty.")
        print(f"DataLoader OK: {len(train_dataset)} samples, {num_batches} batches.")
    except Exception as verify_err:
        print(f"Error verifying DataLoader: {verify_err}")
        sys.exit("ERROR: DataLoader verification failed.")

    print("\n--- Preparing for Model Training ---")
    # Stays False unless the try body completes or the user interrupts.
    generate_final_images = False
    try:
        epochs_already_done = trainer.current_epoch
        epochs_to_run = NUM_EPOCHS - epochs_already_done

        if epochs_to_run > 0:
            print(f"Starting training: Epoch {epochs_already_done + 1} -> {NUM_EPOCHS} ({epochs_to_run} epochs)")
            trainer.train(
                train_loader,
                epochs_to_run=epochs_to_run,
                total_target_epochs=NUM_EPOCHS,
                save_every=SAVE_EVERY_EPOCHS,
            )
            print(f"\nTraining run finished. Last completed epoch: {trainer.current_epoch}.")
        else:
            print(f"\nModel already trained for {epochs_already_done} epochs (Target: {NUM_EPOCHS}). Skipping training.")

        # Reached on both paths: samples are generated whether training ran
        # to completion or was skipped because the target was already met.
        generate_final_images = True

    except KeyboardInterrupt:
        print("\n--- Training Interrupted (KeyboardInterrupt) ---")
        print("Attempting checkpoint save...")
        trainer.save_checkpoint(epoch_completed=trainer.current_epoch)
        print(f"Training stopped. Last completed epoch: {trainer.current_epoch}.")
        # A manual stop still allows sample generation from the current weights.
        generate_final_images = True
    except Exception as train_err:
        print(f"\n--- An unexpected error occurred during training: {train_err} ---")
        traceback.print_exc()
        print("Attempting checkpoint save...")
        trainer.save_checkpoint(epoch_completed=trainer.current_epoch)
        print(f"Training halted. Last completed epoch: {trainer.current_epoch}.")
        # No generation after a failure; the script terminates here.
        generate_final_images = False
        sys.exit("Training halted due to error.")
    print("--- Model Training Step Finished ---")


    # --- Step 5.5: Generate Final Images ---
    # Leaves `final_generated_samples` set to a tensor on success, or to None
    # when generation is skipped, returns nothing, or raises.
    if not generate_final_images:
        print("\n--- Skipping Final Image Generation ---")
        final_generated_samples = None
    else:
        print("\n--- Generating Final Sample Images ---")
        try:
            final_generated_samples = trainer.generate_images(num_images=NUM_IMAGES_TO_GENERATE)
            if final_generated_samples is None:
                print("Image generation returned None.")
            else:
                print(f"Generated final samples tensor shape: {final_generated_samples.shape}")
        except Exception as gen_err:
            print(f"\nError during final image generation: {gen_err}")
            # Discard any partial result so the next step is skipped.
            final_generated_samples = None
        print("--- Final Image Generation Done ---")


    # --- Step 5.6: Visualize and Save Final Generated Images ---
    # Writes the final sample grid to DRIVE_SAVE_DIR, tagged with the last
    # completed epoch. Failures here are reported but do not abort the script.
    if final_generated_samples is not None:
        print("\n--- Visualizing and Saving Final Generated Images ---")
        try:
            last_completed_epoch = trainer.current_epoch
            visualize_and_save_samples(
                final_generated_samples, epoch=last_completed_epoch,
                save_dir=DRIVE_SAVE_DIR, prefix=f"final_generated_epoch{last_completed_epoch}"
            )
            print(f"Saved final generated images for epoch {last_completed_epoch}.")
        except NameError as e:
            # Include the exception text: the undefined name is not
            # necessarily the helper function itself.
            print(f"ERROR: visualize_and_save_samples function not found. Details: {e}")
        except Exception as e:
            print(f"Error during visualization/saving: {e}")
    else:
        print("\nSkipping final visualization/saving.")

    # Keep the newline outside the centered text; otherwise str.center pads
    # around it and the banner is split across two lines.
    print("\n" + " SCRIPT COMPLETE ".center(80, "="))