In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms # Keep transforms
# No longer need: from torchvision import models
import torch.nn.functional as F
import cv2
import os
from tqdm.notebook import tqdm
import time
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import lpips # Still need this for the loss function object

# <<<--- Import the model class from models/resnet.py --->>>
# Ensure models/resnet.py contains the ResNetVAE_V2 definition with increased decoder capacity
from models.resnet import ResNetVAE_V2

In [2]:
# --- Use the same CustomDataset definition as before ---
# --- Includes __init__ checks and __getitem__ checks/error handling ---
# --- Ensure output images are 320x320 and normalized to [-1, 1] ---

class CustomDataset(Dataset):
    """Flat image-folder dataset yielding RGB tensors resized to ``target_size`` in [-1, 1].

    Image files are discovered once, non-recursively, at construction time and kept
    in sorted order so that index ``i`` refers to the same file on every run/machine.
    Images that cannot be read or transformed yield ``None`` from ``__getitem__``;
    pair this dataset with ``collate_fn_skip_none`` so those items are dropped.

    Args:
        image_folder: Directory that directly contains the image files.
        target_size: ``(height, width)`` passed to ``transforms.Resize``.
    """

    # Extensions accepted by the folder scan (matched case-insensitively).
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    def __init__(self, image_folder, target_size=(320, 320)):
        self.image_paths = []
        self.target_size = target_size
        print(f"Looking for images in: {image_folder}")
        try:
            self.image_paths = self._find_image_paths(image_folder)
            if not self.image_paths:
                print(f"WARNING: No valid image files found in {image_folder}")
            else:
                print(f"Found {len(self.image_paths)} image files.")
        except Exception as e:
            # Best-effort: the dataset stays empty so callers can check len() == 0.
            print(f"ERROR initializing CustomDataset: {e}")

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(self.target_size),
            transforms.ToTensor(),  # Converts to [0, 1] range
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Maps to [-1, 1] range
        ])

    @classmethod
    def _find_image_paths(cls, image_folder):
        """Return sorted full paths of image files directly inside ``image_folder``.

        Raises:
            FileNotFoundError: if ``image_folder`` is not an existing directory.
        """
        if not os.path.isdir(image_folder):
            raise FileNotFoundError(f"Dataset folder not found: {image_folder}")
        # os.listdir order is filesystem-dependent; sort for reproducible indexing.
        return [
            os.path.join(image_folder, name)
            for name in sorted(os.listdir(image_folder))
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(image_folder, name))  # skip dirs named like images
        ]

    def __len__(self):
        """Number of discovered image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Read the image at ``idx`` with OpenCV, convert BGR->RGB and apply ``self.transform``.

        Returns:
            The transformed image tensor, or ``None`` if reading/processing failed.

        Raises:
            IndexError: if ``idx`` is past the end of the dataset.
        """
        if idx >= len(self.image_paths):
            raise IndexError("Index out of bounds")
        img_path = self.image_paths[idx]
        try:
            image = cv2.imread(img_path)
            if image is None:
                print(f"ERROR: cv2.imread failed for {img_path}. Returning None.")
                return None
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
            return self.transform(image)
        except Exception as e:
            print(f"ERROR processing image {img_path} in __getitem__: {e}")
            return None

def collate_fn_skip_none(batch):
    """Collate a batch after discarding ``None`` entries (images that failed to load).

    Returns the default-collated batch, or ``None`` when nothing is left to collate
    or when collation itself fails (the error is printed, not raised).
    """
    valid_items = list(filter(lambda item: item is not None, batch))
    if len(valid_items) == 0:
        return None
    try:
        collated = torch.utils.data.dataloader.default_collate(valid_items)
    except Exception as err:
        print(f"Error during collate: {err}. Batch items: {len(valid_items)}")
        return None
    return collated

In [None]:
# --- Configuration ---
# Hyper-parameters for end-to-end training of ResNetVAE_V2.
LATENT_DIM = 512  # Keep at 512 due to VRAM limits observed
LEARNING_RATE = 0.00015  # Single LR for all parameters; slightly lower for end-to-end training
WEIGHT_DECAY = 1e-5
BATCH_SIZE = 16
NUM_WORKERS = 0
KLD_WEIGHT = 0.0001  # Weight of the KLD term (passed as M_N to loss_function); keep low
LPIPS_WEIGHT = 0.75  # Weight of the LPIPS perceptual term; keep higher

# --- Device ---
# Resolved before the DataLoader so memory pinning is only requested when a GPU exists.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Dataset / DataLoader ---
dataset_path = r"C:\Users\Legion\Desktop\venv_wokplace\test_gpu_conda\frames_extracted"
train_dataset_resnet_v2 = CustomDataset(dataset_path)

if len(train_dataset_resnet_v2) > 0:
    train_loader_resnet_v2 = DataLoader(
        train_dataset_resnet_v2,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=(device.type == "cuda"),  # pinning only helps host->GPU copies
        collate_fn=collate_fn_skip_none,  # drops images that failed to load
    )
    print(f"DataLoader created. Batches per epoch: {len(train_loader_resnet_v2)}")
else:
    print("Dataset is empty. Cannot create DataLoader.")
    train_loader_resnet_v2 = None

# --- Model Setup ---
if train_loader_resnet_v2 is not None:
    print(f"Using device: {device}")

    # NOTE(review): per the notes in this notebook, ResNetVAE_V2 loads pre-trained ResNet
    # encoder weights internally while decoder/latent layers start random — confirm in
    # models/resnet.py.
    model_v2 = ResNetVAE_V2(latent_dim=LATENT_DIM)
    model_v2.to(device)
    print("Instantiated ResNetVAE_V2 from model.py and moved to device.")
    print("Training from scratch (using pre-trained encoder weights, random decoder/latent weights).")

    # --- LPIPS Loss Setup ---
    # LPIPS is used as a fixed perceptual metric: freeze its weights and keep it in eval mode.
    try:
        lpips_loss_fn = lpips.LPIPS(net='alex', verbose=False).to(device)
        for param in lpips_loss_fn.parameters():
            param.requires_grad = False
        lpips_loss_fn.eval()
        print("LPIPS loss function created.")
    except Exception as e:
        # Leaving lpips_loss_fn as None makes the training cell skip itself.
        print(f"ERROR setting up LPIPS: {e}.")
        lpips_loss_fn = None

    # --- Optimizer (single LR for all parameters) ---
    optimizer = optim.AdamW(model_v2.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    print(f"Optimizer created with single LR: {LEARNING_RATE}")

    # --- Scheduler ---
    # `verbose` is deprecated for LR schedulers since PyTorch 2.2, so it is not passed;
    # the current rate can be read from optimizer.param_groups[0]['lr'].
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=10)
    print("Scheduler defined.")
else:
    print("Skipping Model/Optimizer/LPIPS setup.")

In [None]:
# --- Training Loop ---
# Relies on the setup cell: train_loader_resnet_v2, model_v2, lpips_loss_fn,
# optimizer, scheduler, device, KLD_WEIGHT and LPIPS_WEIGHT.
# At notebook top level locals() is the global namespace, so these checks detect
# whether the setup cell actually defined the names.
if 'train_loader_resnet_v2' in locals() and train_loader_resnet_v2 is not None and \
   'model_v2' in locals() and 'lpips_loss_fn' in locals() and lpips_loss_fn is not None:

    # --- Training Params ---
    num_epochs = 200  # Train longer from scratch (early stopping may end the run sooner)
    gradient_clip = 1.0  # Only used if the clip_grad_norm_ call below is re-enabled
    early_stopping_patience = 15  # Epochs without a new best average loss before stopping
    best_loss = float("inf")  # Reset best loss for new run
    early_stop_counter = 0

    # --- Paths ---
    # Same V2 directory as earlier runs; files with these names are overwritten.
    model_save_dir = r"C:\Users\Legion\Desktop\venv_wokplace\test_gpu_conda\model_saved_resnet_v2_finetune"
    os.makedirs(model_save_dir, exist_ok=True)
    best_model_path = os.path.join(model_save_dir, "resnet_v2_scratch_vae_best.pth")
    final_model_path = os.path.join(model_save_dir, "resnet_v2_scratch_vae_final.pth")

    print("\n--- Starting ResNetVAE_V2 Training Loop (From Scratch) ---")
    model_v2.train()  # Set model to training mode

    for epoch in range(num_epochs):
        epoch_loss = epoch_recon_loss = epoch_kld_loss = epoch_lpips_loss = 0.0
        # Count only batches that completed an optimizer step. Dividing by len(loader)
        # would bias the averages (and therefore the scheduler, early stopping and
        # best-checkpoint decisions) low whenever batches are skipped or fail.
        num_batches = 0
        pbar = tqdm(train_loader_resnet_v2, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

        for batch in pbar:
            if batch is None:
                continue  # every image in this batch failed to load
            images = batch.to(device)
            optimizer.zero_grad()

            try:
                reconstruction, original_input, mu, log_var = model_v2(images)
                loss_dict = model_v2.loss_function(reconstruction, original_input, mu, log_var,
                                                   M_N=KLD_WEIGHT,
                                                   lpips_model=lpips_loss_fn,
                                                   lpips_weight=LPIPS_WEIGHT)
                loss = loss_dict['loss']
                recon_loss_val = loss_dict['Reconstruction_Loss_L1'].item()
                kld_loss_val = loss_dict['KLD'].item()
                lpips_loss_val = loss_dict['Perceptual_Loss'].item()

                loss.backward()
                # torch.nn.utils.clip_grad_norm_(model_v2.parameters(), gradient_clip)
                optimizer.step()
                loss_val = loss.item()
            except Exception as e:
                # Best-effort: log the failure and continue with the next batch.
                print(f"\nERROR during training step: {e}")
                continue

            epoch_loss += loss_val
            epoch_recon_loss += recon_loss_val
            epoch_kld_loss += kld_loss_val
            epoch_lpips_loss += lpips_loss_val
            num_batches += 1
            pbar.set_postfix(Loss=loss_val, L1=recon_loss_val, KLD=kld_loss_val, LPIPS=lpips_loss_val)

        if num_batches == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] - no batches were processed successfully.")
            break

        avg_epoch_loss = epoch_loss / num_batches
        avg_recon_loss = epoch_recon_loss / num_batches
        avg_kld_loss = epoch_kld_loss / num_batches
        avg_lpips_loss = epoch_lpips_loss / num_batches
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {avg_epoch_loss:.4f} | L1: {avg_recon_loss:.4f} | "
              f"KLD: {avg_kld_loss:.4f} | LPIPS: {avg_lpips_loss:.4f} | LR: {current_lr:.2e}")

        scheduler.step(avg_epoch_loss)

        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            early_stop_counter = 0
            torch.save(model_v2.state_dict(), best_model_path)
            print(f"Model Improved & Saved to {best_model_path}!")
        else:
            early_stop_counter += 1
            print(f"No Improvement ({early_stop_counter}/{early_stopping_patience})")
            if early_stop_counter >= early_stopping_patience:
                print("Early Stopping Triggered! Training Stopped.")
                break

    print("\n--- Training Complete ---")
    # Save final model regardless of how the loop ended
    torch.save(model_v2.state_dict(), final_model_path)
    print(f"Final Model Saved Successfully to {final_model_path}")

else:
    print("Skipping Training Loop.")

In [None]:
# --- Standard Visualization ---
# Loads the best checkpoint saved by the training cell and plots a few random
# originals (top row) against their reconstructions (bottom row).
# On load failure `inference_model` is set to None so the heatmap cell skips itself
# instead of running on randomly initialised weights.

if 'model_v2' in locals() and 'train_dataset_resnet_v2' in locals() and len(train_dataset_resnet_v2) > 0:
    print("\n--- Visualizing ResNetVAE_V2 (Scratch Trained) Sample Reconstructions ---")

    model_class = ResNetVAE_V2  # Use imported class
    latent_dim_used = LATENT_DIM  # Use LATENT_DIM from setup cell
    model_load_dir = r"C:\Users\Legion\Desktop\venv_wokplace\test_gpu_conda\model_saved_resnet_v2_finetune"
    best_model_load_path = os.path.join(model_load_dir, "resnet_v2_scratch_vae_best.pth")

    if 'device' not in locals(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Instantiate the SAME architecture used for training
    inference_model = model_class(latent_dim=latent_dim_used)

    # --- Load weights (kept separate so a plotting error doesn't discard a loaded model) ---
    try:
        print(f"Loading model state from: {best_model_load_path}")
        inference_model.load_state_dict(torch.load(best_model_load_path, map_location=device))
        inference_model.to(device)
        inference_model.eval()
        print("Inference model loaded successfully.")
    except FileNotFoundError:
        print(f"Error: Saved model not found at {best_model_load_path}.")
        inference_model = None
    except Exception as e:
        print(f"Could not load model: {e}")
        inference_model = None

    # --- Reconstruct and plot ---
    if inference_model is not None:
        try:
            vis_loader = DataLoader(train_dataset_resnet_v2, batch_size=3, shuffle=True, collate_fn=collate_fn_skip_none)
            sample_batch = next(iter(vis_loader))
            if sample_batch is None:
                print("Could not get a valid batch for visualization.")
            else:
                with torch.no_grad():
                    reconstructed, _, _, _ = inference_model(sample_batch.to(device))
                originals_np = sample_batch.numpy()
                reconstructed_np = reconstructed.cpu().numpy()

                n_show = originals_np.shape[0]
                # squeeze=False keeps axes 2-D even when only one sample survived collation.
                fig, axes = plt.subplots(2, n_show, figsize=(5 * n_show, 10), squeeze=False)
                fig.suptitle("Original vs Reconstructed (ResNetVAE_V2)", fontsize=16)
                rows = ((originals_np, "Original"), (reconstructed_np, "Reconstructed"))
                for i in range(n_show):
                    for row, (images_np, label) in enumerate(rows):
                        # De-normalize [-1, 1] -> [0, 1] and CHW -> HWC for imshow.
                        img = np.clip(np.transpose(images_np[i] * 0.5 + 0.5, (1, 2, 0)), 0, 1)
                        axes[row, i].imshow(img)
                        axes[row, i].set_title(f"{label} {i+1}")
                        axes[row, i].axis('off')

                plt.tight_layout(rect=[0, 0.03, 1, 0.95])
                plt.show()
        except StopIteration:
            print("Could not get batch from vis_loader.")
        except Exception as e:
            print(f"Could not visualize samples: {e}")
else: print("Skipping visualization.")

In [None]:
# --- Heatmap Visualization ---
# Uses `inference_model` loaded by a reconstruction-visualization cell and shows, per
# sample: original, reconstruction, and the per-pixel absolute error (mean over RGB)
# on the de-normalized [0, 1] scale.

if 'inference_model' in locals() and inference_model is not None and \
   'train_dataset_resnet_v2' in locals() and len(train_dataset_resnet_v2) > 0:
    print("\n--- Visualizing Reconstruction Differences (Heatmap) ---")

    inference_model.eval()  # Already in eval mode if loaded by the previous cell

    num_heatmap_samples = 3
    heatmap_loader = DataLoader(train_dataset_resnet_v2, batch_size=num_heatmap_samples, shuffle=True, collate_fn=collate_fn_skip_none)

    try:
        heatmap_batch = next(iter(heatmap_loader))
        if heatmap_batch is None:
            print("Could not get a valid batch for heatmap visualization.")
        else:
            # Run on whatever device the weights live on, so this cell doesn't need `device`.
            model_device = next(inference_model.parameters()).device
            with torch.no_grad():
                recon_batch, _, _, _ = inference_model(heatmap_batch.to(model_device))

            # [-1, 1] -> [0, 1] and NCHW -> NHWC for imshow.
            originals = np.clip(heatmap_batch.numpy().transpose(0, 2, 3, 1) * 0.5 + 0.5, 0, 1)
            recons = np.clip(recon_batch.cpu().numpy().transpose(0, 2, 3, 1) * 0.5 + 0.5, 0, 1)
            # Mean absolute error across channels -> one HxW map per sample, values in [0, 1].
            error_maps = np.abs(originals - recons).mean(axis=3)

            n_show = originals.shape[0]
            fig, axes = plt.subplots(n_show, 3, figsize=(15, 5 * n_show), squeeze=False)
            fig.suptitle("Original / Reconstruction / Absolute Error", fontsize=16)
            for i in range(n_show):
                axes[i, 0].imshow(originals[i])
                axes[i, 0].set_title(f"Original {i+1}")
                axes[i, 1].imshow(recons[i])
                axes[i, 1].set_title(f"Reconstructed {i+1}")
                # Fixed [0, 1] colour range so heatmaps are comparable across samples.
                heat = axes[i, 2].imshow(error_maps[i], cmap='jet', vmin=0.0, vmax=1.0)
                axes[i, 2].set_title(f"Abs Error (mean {error_maps[i].mean():.4f})")
                fig.colorbar(heat, ax=axes[i, 2], fraction=0.046, pad=0.04)
                for ax in axes[i]:
                    ax.axis('off')

            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
            plt.show()

    except StopIteration: print("Could not get batch from heatmap_loader.")
    except Exception as e: print(f"Could not generate heatmap visualization: {e}")
else:
    print("Skipping Heatmap visualization as prerequisites not met (model not loaded?).")

In [None]:
# --- Reconstruction Visualization ---
# Loads the best checkpoint from disk and plots originals (top row) vs reconstructions
# (bottom row). Only proceed if training ran and model exists.
# On load failure `inference_model` is set to None so later cells don't silently use
# randomly initialised weights.
if 'model_v2' in locals() and 'train_dataset_resnet_v2' in locals() and len(train_dataset_resnet_v2) > 0:
    print("\n--- Visualizing ResNetVAE_V2 (Scratch Trained) Sample Reconstructions ---")

    model_class = ResNetVAE_V2  # Use imported class
    latent_dim_used = LATENT_DIM  # Use LATENT_DIM from setup cell
    model_load_dir = r"C:\Users\Legion\Desktop\venv_wokplace\test_gpu_conda\model_saved_resnet_v2_finetune"
    best_model_load_path = os.path.join(model_load_dir, "resnet_v2_scratch_vae_best.pth")

    if 'device' not in locals(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Instantiate the SAME architecture used for training
    inference_model = model_class(latent_dim=latent_dim_used)

    # Load from the path computed in this cell, so the cell does not depend on
    # variables that only exist if the training cell ran in this session.
    try:
        print(f"Loading model state from: {best_model_load_path}")
        inference_model.load_state_dict(torch.load(best_model_load_path, map_location=device))
        inference_model.to(device)
        inference_model.eval()
        print("Inference model loaded successfully.")
    except FileNotFoundError:
        print(f"Error: Saved model not found at {best_model_load_path}.")
        inference_model = None
    except Exception as e:
        print(f"Could not load model: {e}")
        inference_model = None

    if inference_model is not None:
        try:
            vis_loader = DataLoader(train_dataset_resnet_v2, batch_size=3, shuffle=True, collate_fn=collate_fn_skip_none)
            sample_batch = next(iter(vis_loader))

            if sample_batch is None:
                print("Could not get a valid batch for visualization.")
            else:
                sample_images_cpu = sample_batch
                with torch.no_grad():
                    reconstructed_gpu, _, _, _ = inference_model(sample_images_cpu.to(device))

                sample_images_np = sample_images_cpu.numpy()
                reconstructed_np = reconstructed_gpu.cpu().numpy()

                num_images_to_show = sample_images_np.shape[0]
                # squeeze=False keeps `axes` 2-D even for a single column.
                fig, axes = plt.subplots(2, num_images_to_show, figsize=(5 * num_images_to_show, 10), squeeze=False)
                fig.suptitle("Original vs Reconstructed (ResNetVAEV2 + LPIPS in 90 Epochs)", fontsize=16)

                for i in range(num_images_to_show):
                    original_img = sample_images_np[i] * 0.5 + 0.5  # De-normalize [-1, 1] -> [0, 1]
                    reconstructed_img = reconstructed_np[i] * 0.5 + 0.5  # De-normalize
                    axes[0, i].imshow(np.clip(np.transpose(original_img, (1, 2, 0)), 0, 1))
                    axes[0, i].set_title(f"Original {i+1}")
                    axes[0, i].axis('off')
                    axes[1, i].imshow(np.clip(np.transpose(reconstructed_img, (1, 2, 0)), 0, 1))
                    axes[1, i].set_title(f"Reconstructed {i+1}")
                    axes[1, i].axis('off')

                plt.tight_layout(rect=[0, 0.03, 1, 0.95])
                plt.show()

        except StopIteration:
            print("Could not get batch from vis_loader.")
        except Exception as e:
            print(f"Could not visualize samples: {e}")
else:
    print("Skipping visualization.")

In [None]:
# --- TESTING SCRIPT (Corrected) --- #
# Loads a ResNetVAE_V2 checkpoint, reconstructs a random sample of images from
# `images_dir`, and saves an original / reconstruction / error-heatmap figure per
# image into `output_dir`.

import os
import random
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import cv2

# --- IMPORT THE MODEL --- #
# Make sure models/resnet.py has the ResNetVAE_V2 definition matching the checkpoint
from models.resnet import ResNetVAE_V2

# --- SETTINGS --- #
# <<<--- Make sure this points to the checkpoint from the V2 training run --->>>
model_path = r'C:\Users\Legion\Desktop\venv_wokplace\test_gpu_conda\model_saved_resnet_v2_finetune\resnet_v2_scratch_vae_best.pth'  # _best usually preferred
images_dir = r'C:\Users\Legion\Desktop\venv_wokplace\test_gpu_conda\processed_frames_320'
output_dir = r'C:\Users\Legion\Desktop\venv_wokplace\test_gpu_conda\test_outputs_v2'
os.makedirs(output_dir, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# <<<--- Ensure this matches the latent dim used for the loaded model --->>>
latent_dim = 512
input_size = 320
num_samples = 10  # Maximum number of images to test
random_seed = None  # Set to an int for a reproducible image sample

# --- TRANSFORM (same normalization as training) --- #
transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),  # To [0, 1] range
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # To [-1, 1] range
])

# --- LOAD MODEL --- #
# Failures abort with SystemExit rather than exit(): in a notebook, exit() ends the
# kernel session (losing all state), whereas SystemExit only stops this cell.
print(f"Loading model definition (latent_dim={latent_dim})...")
# NOTE(review): the training/visualization cells construct ResNetVAE_V2 with only
# latent_dim — confirm the class accepts `input_height` and that it matches training.
model = ResNetVAE_V2(latent_dim=latent_dim, input_height=input_size).to(device)
print(f"Loading weights from: {model_path}")
try:
    # strict=True (default) should work if the model definition and checkpoint match
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print("Model loaded successfully!")
except FileNotFoundError as e:
    print(f"ERROR: Model checkpoint not found at {model_path}")
    raise SystemExit(1) from e
except RuntimeError as e:
    print(f"ERROR loading state_dict: {e}")
    print("This usually means the model architecture in model.py doesn't match the saved checkpoint.")
    print("Make sure you are loading the correct .pth file for the instantiated ResNetVAE_V2 architecture.")
    raise SystemExit(1) from e


# --- LOAD RANDOM IMAGES --- #
try:
    # Sorted so the sample depends only on `random_seed`, not on os.listdir order.
    all_imgs = sorted(f for f in os.listdir(images_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg')))
except FileNotFoundError as e:
    print(f"ERROR: Images directory not found: {images_dir}")
    raise SystemExit(1) from e
if not all_imgs:
    print(f"ERROR: No images found in {images_dir}")
    raise SystemExit(1)
random_imgs = random.Random(random_seed).sample(all_imgs, min(num_samples, len(all_imgs)))
print(f"Processing {len(random_imgs)} random images...")

for idx, img_name in enumerate(random_imgs):
    img_path = os.path.join(images_dir, img_name)
    try:
        # Context manager closes the file handle deterministically.
        with Image.open(img_path) as pil_img:
            orig_pil = pil_img.convert('RGB')

        # Apply the full transform (including normalization)
        input_tensor = transform(orig_pil).unsqueeze(0).to(device)

        # Run inference; the model takes and returns images in [-1, 1]
        with torch.no_grad():
            recons, _, _, _ = model(input_tensor)

        # (1, C, H, W) -> (H, W, C) on CPU for display
        orig_np = input_tensor.squeeze().cpu().permute(1, 2, 0).numpy()
        recon_np = recons.squeeze().cpu().permute(1, 2, 0).numpy()

        # De-normalize from [-1, 1] back to [0, 1] for displaying/saving
        orig_vis = np.clip(orig_np * 0.5 + 0.5, 0, 1)
        recon_vis = np.clip(recon_np * 0.5 + 0.5, 0, 1)

        # Difference on the [0, 1] scale
        diff = np.abs(orig_vis - recon_vis)
        mean_diff_val = np.mean(diff)  # Mean absolute error over all pixels/channels
        heatmap_gray = np.mean(diff, axis=2)  # HxW, values in [0, 1]
        heatmap_uint8 = np.uint8(255 * heatmap_gray)  # uint8 input required by applyColorMap
        heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
        heatmap_color_rgb = heatmap_color[..., ::-1]  # OpenCV BGR -> RGB for Matplotlib

        # --- Plotting --- #
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        try:
            fig.suptitle(f"Image: {img_name}", fontsize=12)

            axs[0].imshow(orig_vis)
            axs[0].set_title('Original')

            axs[1].imshow(recon_vis)
            axs[1].set_title('Reconstruction')

            axs[2].imshow(heatmap_color_rgb)
            axs[2].set_title(f'Difference Heatmap\nMean Abs Err: {mean_diff_val:.4f}')

            for ax in axs:
                ax.axis('off')

            out_path = os.path.join(output_dir, f'result_{idx+1}_{img_name}.png')
            fig.savefig(out_path, bbox_inches='tight')
        finally:
            plt.close(fig)  # Always free the figure, even if plotting/saving failed

    except Exception as e:
        print(f"ERROR processing image {img_name}: {e}")

print(f"\n Done! Results saved in: {output_dir}")