# Dataset

In [1]:
# dataset.py
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import glob
import torch


class ImageFolderDataset(Dataset):
    """Dataset for loading images from a folder."""

    def __init__(self, img_dir, img_size=256, transform=None):
        self.img_dir = img_dir
        self.img_paths = (
            glob.glob(os.path.join(img_dir, "*.jpg"))
            + glob.glob(os.path.join(img_dir, "*.png"))
            + glob.glob(os.path.join(img_dir, "*.jpeg"))
        )
        self.img_size = img_size

        if transform is None:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((img_size, img_size)),
                    transforms.ToTensor(),
                    # Keep images in [0, 1] range for general use, VGG preprocessing applied later
                ]
            )
        else:
            self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        try:
            img = Image.open(img_path).convert("RGB")
            img = self.transform(img)
        except Exception as e:
            print(
                f"Warning: Could not load image {img_path}. Error: {e}. Returning None."
            )
            # You might want to return a placeholder or skip this index in the DataLoader collate_fn
            return None  # Handle this in the DataLoader's collate_fn
        return img

    def insert_additional_folder(self, additional_folder):
        """Insert additional folder to the dataset."""
        additional_paths = (
            glob.glob(os.path.join(additional_folder, "*.jpg"))
            + glob.glob(os.path.join(additional_folder, "*.png"))
            + glob.glob(os.path.join(additional_folder, "*.jpeg"))
        )
        self.img_paths.extend(additional_paths)

    def shuffle(self):
        """Shuffle the dataset."""
        import random

        random.shuffle(self.img_paths)


def custom_collate_fn(batch):
    """Collate function that filters out None values."""
    batch = list(filter(lambda x: x is not None, batch))
    if not batch:
        return None  # Return None if the whole batch failed
    return torch.utils.data.dataloader.default_collate(batch)

# VGG

In [2]:
# vgg.py
import torch
import torch.nn as nn
from torchvision.models import vgg19, VGG19_Weights

VGG19_LAYER_MAP = {
    "relu1_1": 1,
    "relu2_1": 6,
    "relu3_1": 11,  # Layer used
    "relu4_1": 20,
    "relu5_1": 29,
}


class Vgg19FeatureExtractor(nn.Module):
    def __init__(self, layers_to_extract, weights=VGG19_Weights.DEFAULT):
        """
        Initializes the VGG-19 feature extractor.

        Args:
            layers_to_extract (list): List of layer names (e.g., ['relu3_1'])
                                     from which to extract features.
            weights (VGG19_Weights): Pretrained weights to use.
        """
        super().__init__()
        self.layers_to_extract = sorted(
            [VGG19_LAYER_MAP[name] for name in layers_to_extract]
        )
        self.last_layer_index = self.layers_to_extract[-1]

        vgg = vgg19(weights=weights).features
        self.model = nn.Sequential(*[vgg[i] for i in range(self.last_layer_index + 1)])

        # Freeze VGG parameters
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x):
        """
        Extracts features from the specified layers.

        Args:
            x (torch.Tensor): Input image tensor (B x C x H x W).

        Returns:
            dict: A dictionary where keys are layer indices and values
                  are the corresponding feature maps.
        """
        features = {}
        current_layer_idx = 0
        for i, layer in enumerate(self.model):
            x = layer(x)
            if i in self.layers_to_extract:
                features[i] = x
                current_layer_idx += 1
                if current_layer_idx >= len(self.layers_to_extract):
                    break  # No need to compute further
        return features

# Inverse net

In [3]:
# inverse_net.py
import torch
import torch.nn as nn


class InverseNetwork(nn.Module):
    def __init__(self, input_channels=256):  # Channels for relu3_1 of VGG19
        """
        Initializes the Inverse Network based on Appendix Table A2.
        Assumes input is from VGG19's relu3_1 (256 channels).
        """
        super().__init__()

        # Input: 1/4 H x 1/4 W x 256 (relu3_1 for 256x256 image)
        self.layers = nn.Sequential(
            # Conv-InstanceNorm-ReLU Block 1
            nn.ConvTranspose2d(input_channels, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),  # Changed to BatchNorm2d for consistency
            nn.PReLU(),
            # Upsampling + Conv-InstanceNorm-ReLU Block 2
            nn.Upsample(scale_factor=2, mode="nearest"),  # To 1/2 H x 1/2 W
            nn.ConvTranspose2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),  # Changed to BatchNorm2d for consistency
            nn.PReLU(),
            # Conv-InstanceNorm-ReLU Block 3
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),  # Changed to BatchNorm2d for consistency
            nn.PReLU(),
            # Upsampling + Conv-InstanceNorm-ReLU Block 4
            nn.Upsample(scale_factor=2, mode="nearest"),  # To H x W
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),  # Changed to BatchNorm2d for consistency
            nn.PReLU(),
            # Output Convolution
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        return self.layers(x)

    def init_weights(self):
        """
        Initialize weights of the network.
        """
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

# Style swap

In [4]:
# style_swap.py
import torch
import torch.nn as nn
import torch.nn.functional as F


def extract_patches(feature_map, patch_size=3, stride=1):
    """Extracts patches from a feature map.

    Args:
        feature_map (torch.Tensor): Input feature map (B x C x H x W).
        patch_size (int): Size of the square patches.
        stride (int): Stride for patch extraction.

    Returns:
        torch.Tensor: Extracted patches (B * n_patches_h * n_patches_w, C, patch_size, patch_size).
    """
    B, C, H, W = feature_map.shape
    # Use unfold to extract patches
    # unfold(dimension, size, step)
    patches = F.unfold(feature_map, kernel_size=patch_size, stride=stride)
    # patches shape: (B, C * patch_size * patch_size, n_patches_h * n_patches_w)
    patches = patches.permute(0, 2, 1).contiguous()
    # patches shape: (B, n_patches_h * n_patches_w, C * patch_size * patch_size)
    n_patches_total = patches.shape[1]
    patches = patches.view(B * n_patches_total, C, patch_size, patch_size)
    return patches


def style_swap_op(
    content_features, style_features, patch_size=3, stride=1, eps=1e-8, verbose=False
):
    """Performs the Style Swap operation.

    Args:
        content_features (torch.Tensor): Content feature map (B x C x H x W).
        style_features (torch.Tensor): Style feature map (B x C x H' x W').
        patch_size (int): Size of the patches.
        stride (int): Stride for patch matching and reconstruction.
        eps (float): Epsilon for numerical stability (e.g., in normalization).

    Returns:
        torch.Tensor: The resulting feature map after style swap (B x C x H x W).
    """
    device = content_features.device
    B_c, C_c, H_c, W_c = content_features.shape
    B_s, C_s, H_s, W_s = style_features.shape

    assert (
        C_c == C_s
    ), "Content and Style features must have the same number of channels."
    C = C_c

    # 1. Extract patches from style features
    # style_patches shape: (B_s * N_s, C, patch_size, patch_size) where N_s is num style patches
    if verbose:
        print(f"Extracting patches from style features...")
        print(f"Style features shape: {style_features.shape}")
        print(f"Content features: \n{content_features}")
        print(f"Style features shape: {style_features.shape}")
        print(f"Style features: \n{style_features}")

    style_patches = extract_patches(style_features, patch_size, stride)
    if verbose:
        print(f"Extracted style patches shape: {style_patches.shape}")
        print(f"Style patches: \n{style_patches}")

    # 2. Normalize style patches (for correlation calculation)
    # norm shape: (B_s * N_s, 1, 1, 1)
    style_patches_norm = torch.sqrt(
        torch.sum(style_patches**2, dim=(1, 2, 3), keepdim=True)
    )
    # Avoid division by zero for zero-patches
    style_patches_normalized = style_patches / (style_patches_norm + eps)

    # 3. Compute correlation using convolution
    # Use style patches as convolution filters
    # conv_filters shape: (n_style_patches, C, patch_size, patch_size)
    conv_filters = style_patches_normalized.to(device)
    # content_features shape: (B_c, C, H_c, W_c)
    # correlation_maps shape: (B_c, n_style_patches, H_out, W_out)
    correlation_maps = F.conv2d(
        content_features, conv_filters, stride=stride, padding=0
    )  # Using 0 padding

    if verbose:
        print(f"Correlation maps shape: {correlation_maps.shape}")
        print(f"Correlation maps: \n{correlation_maps}")

    # 4. Find the best matching style patch for each content patch (Channel-wise Argmax)
    # best_match_indices shape: (B_c, H_out, W_out)
    best_match_indices = torch.argmax(correlation_maps, dim=1)

    if verbose:
        print(f"Best match indices shape: {best_match_indices.shape}")
        print(f"Best match indices: \n{best_match_indices}")

    # 5. Create one-hot selection map
    # one_hot_map shape: (B_c, n_style_patches, H_out, W_out)
    H_out, W_out = best_match_indices.shape[1], best_match_indices.shape[2]
    one_hot_map = torch.zeros_like(correlation_maps, device=device)

    if verbose:
        print(f"One-hot map shape before scatter: {one_hot_map.shape}")
        print(f"One-hot map: \n{one_hot_map}")
    # Use scatter_ to place 1s at the argmax indices
    # scatter_(dimension, index_tensor, value)
    # index needs to be same shape as output after indexing dim -> add channel dim
    one_hot_map.scatter_(1, best_match_indices.unsqueeze(1), 1.0)
    if verbose:
        print(f"One-hot map shape after scatter: {one_hot_map.shape}")
        print(f"One-hot map: \n{one_hot_map}")

    # 6. Reconstruct using transposed convolution
    # Use original (unnormalized) style patches as filters
    # recon_filters shape: (n_style_patches, C, patch_size, patch_size)
    recon_filters = style_patches.to(device)
    # output_padding adjusts output size, often needed if stride > 1
    # calculate required output padding if needed, or ensure input sizes work well
    # For stride=1, output_padding is typically 0
    # swapped_features_sum shape: (B_c, C, H_rec, W_rec) -> should approximate H_c, W_c
    swapped_features_sum = F.conv_transpose2d(
        one_hot_map, recon_filters, stride=stride, padding=0
    )

    if verbose:
        print(f"Swapped features sum shape: {swapped_features_sum.shape}")
        print(f"Swapped features sum: \n{swapped_features_sum}")

    # 7. Normalize for overlapping patches
    # Create filters of ones for counting overlaps
    # count_filters shape: (1, 1, patch_size, patch_size)
    count_filters = torch.ones(1, 1, patch_size, patch_size, device=device)
    # Count contributions per pixel
    # one_hot_map shape: (B_c, n_style_patches, H_out, W_out)
    # Reduce one_hot_map along the patch dimension before counting
    # reduced_one_hot shape: (B_c, 1, H_out, W_out)
    reduced_one_hot = torch.sum(one_hot_map, dim=1, keepdim=True)

    if verbose:
        print(f"Reduced one-hot map shape: {reduced_one_hot.shape}")
        print(f"Reduced one-hot map: \n{reduced_one_hot}")
    # overlap_count shape: (B_c, 1, H_rec, W_rec)
    overlap_count = F.conv_transpose2d(
        reduced_one_hot, count_filters, stride=stride, padding=0
    )

    if verbose:
        print(f"Overlap count shape: {overlap_count.shape}")
        print(f"Overlap count: \n{overlap_count}")

    # Average the contributions
    # Add eps to avoid division by zero where there's no patch contribution
    swapped_features_avg = swapped_features_sum / (overlap_count + eps)

    if verbose:
        print(f"Swapped features average shape: {swapped_features_avg.shape}")
        print(f"Swapped features average: \n{swapped_features_avg}")

    return swapped_features_avg

# Loss

In [5]:
# loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class TVLoss(nn.Module):
    """Total Variation Loss"""

    def __init__(self, weight=1.0):
        super().__init__()
        self.weight = weight

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input image tensor (B x C x H x W).
        Returns:
            torch.Tensor: Scalar TV loss value.
        """
        batch_size = x.size(0)
        h_x = x.size(2)
        w_x = x.size(3)
        count_h = (h_x - 1) * w_x
        count_w = h_x * (w_x - 1)
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :-1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :-1]), 2).sum()
        loss = self.weight * (h_tv / count_h + w_tv / count_w) / batch_size
        return loss

# Utils

In [6]:
# utils.py
import torch
from torchvision import transforms
from PIL import Image
import os

# VGG preprocessing values
VGG_MEAN = [0.485, 0.456, 0.406]
VGG_STD = [0.229, 0.224, 0.225]


def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_image(image_path, img_size=None):
    """Loads an image and optionally resizes it."""
    img = Image.open(image_path).convert("RGB")
    if img_size is not None:
        img = img.resize((img_size, img_size), Image.LANCZOS)
    return img


def preprocess_image(img, device):
    """Preprocesses a PIL image for VGG input."""
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=VGG_MEAN, std=VGG_STD),
        ]
    )
    return transform(img).unsqueeze(0).to(device)


def postprocess_image(tensor):
    """Postprocesses a tensor back to a PIL image."""
    # Ensure tensor is on CPU
    tensor = tensor.squeeze(0).cpu().detach().clone()
    # Denormalize
    mean = torch.tensor(VGG_MEAN).view(3, 1, 1)
    std = torch.tensor(VGG_STD).view(3, 1, 1)
    tensor = tensor * std + mean
    # Clamp values to [0, 1]
    tensor = torch.clamp(tensor, 0, 1)
    # Convert to PIL Image
    img = transforms.ToPILImage()(tensor)
    return img


def save_image(pil_img, save_path):
    """Saves a PIL image."""
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    pil_img.save(save_path)

# Training Loop

In [7]:
CONTENT_DIR = "/kaggle/input/coco-wikiart-nst-dataset-512-100000/content"
STYLE_DIR = "/kaggle/input/coco-wikiart-nst-dataset-512-100000/style"
CHECKPOINT_DIR = "/kaggle/working/checkpoints/"
LAYER = "relu3_1"
EPOCHS = 5
BATCH_SIZE = 16
IMG_SIZE = 512
LR = 0.001
LR_DECAY = 0.0001
LAMBDA_TV = 1e-6
LAMBDA_CONTENT = 1.0
LAMBDA_PIXEL = 1.4
LOSS_TYPE = "mse"
NUM_WORKERS = 4
SAVE_INTERVAL = 5000
LOG_INTERVAL = 100
MAX_CHECKPOINTS_NUM = 10
RESUME = None

In [8]:
# pretrain_invnet.py
import torch
import torch.nn as nn
from PIL import Image
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms  # Needed for VGG preprocessing
from tqdm import tqdm
import os

# --- Setup ---
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
device = get_device()
print(f"Using device: {device}")

# VGG Preprocessing Transformation (Applied during training loop)
vgg_preprocess = transforms.Normalize(mean=VGG_MEAN, std=VGG_STD)

# Target VGG layer index
try:
    target_layer_idx = VGG19_LAYER_MAP[LAYER]
except KeyError:
    print(
        f"Error: Invalid VGG layer name '{LAYER}'. Choose from {list(VGG19_LAYER_MAP.keys())}"
    )
    exit(1)

# Models
print("Loading models...")
vgg_extractor = Vgg19FeatureExtractor([LAYER]).to(device).eval()
# Freeze VGG (already done in class definition)

# Determine input channels based on chosen VGG layer
if LAYER == "relu3_1":
    inv_net_channels = 256
elif LAYER == "relu4_1":
    inv_net_channels = 512
else:
    inv_net_channels = 256  # Default guess

base_inverse_net = InverseNetwork(input_channels=inv_net_channels).to(device)
base_inverse_net.init_weights()
base_inverse_net.to(device)

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    print(f"Let's use {torch.cuda.device_count()} GPUs!")
    inverse_net = nn.DataParallel(base_inverse_net)
else:
    inverse_net = base_inverse_net # Use the base model directly if not using DataParallel

# Optimizer
optimizer = optim.Adam(
    inverse_net.parameters(), lr=LR, weight_decay=LR_DECAY
)

# Loss Functions
if LOSS_TYPE == "mse":
    pixel_loss_fn = nn.MSELoss().to(device)
else:  # l1
    pixel_loss_fn = nn.L1Loss().to(device)

tv_loss_fn = TVLoss(weight=LAMBDA_TV).to(device)

# Dataset and Dataloader
print("Loading dataset...")
# Dataset transform expects images in [0, 1] range
dataset_transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ]
)
train_dataset = ImageFolderDataset(
    CONTENT_DIR, img_size=IMG_SIZE, transform=dataset_transform
)

train_dataset.insert_additional_folder(STYLE_DIR)
train_dataset.shuffle()  # Shuffle dataset

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=custom_collate_fn,
    drop_last=True,
)  # drop_last helps if filtering causes variable batch sizes

# --- Resume training if checkpoint provided ---
start_epoch = 0
global_step = 0
if RESUME:
    if os.path.isfile(RESUME):
        print(f"=> Loading checkpoint '{RESUME}'")
        try:
            # Use weights_only=False as we saved optimizer state and args
            checkpoint = torch.load(
                RESUME, map_location=device, weights_only=False
            )
            start_epoch = checkpoint.get(
                "epoch", 0
            )  # Get epoch, default to 0 if missing
            global_step = checkpoint.get("global_step", 0)  # Get step, default to 0
            inverse_net.load_state_dict(checkpoint["model_state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            # Restore learning rate from optimizer state if needed, or keep args.lr
            # Example: optimizer.param_groups[0]['lr'] = checkpoint.get('learning_rate', args.lr)
            print(
                f"=> Loaded checkpoint '{RESUME}' (epoch {start_epoch}, step {global_step})"
            )
        except TypeError as e:
            if "unexpected keyword argument 'weights_only'" in str(e):
                print(
                    "Warning: PyTorch version might be too old for 'weights_only'. Loading without it."
                )
                checkpoint = torch.load(RESUME, map_location=device)
                start_epoch = checkpoint.get("epoch", 0)
                global_step = checkpoint.get("global_step", 0)
                inverse_net.load_state_dict(checkpoint["model_state_dict"])
                optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                print(
                    f"=> Loaded checkpoint '{RESUME}' (epoch {start_epoch}, step {global_step})"
                )
            else:
                print(f"Error loading checkpoint: {e}")
                exit(1)  # Exit if it's another TypeError
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            # Decide if you want to exit or continue training from scratch
            print("Could not load checkpoint, starting training from scratch.")
            start_epoch = 0
            global_step = 0
    else:
        print(
            f"=> No checkpoint found at '{RESUME}', starting training from scratch."
        )

        inverse_net.init_weights()  # Initialize weights if starting from scratch

# --- Pre-training Loop ---
print("Starting Pre-training...")
# We resume from the epoch *after* the saved one
for epoch in range(start_epoch, EPOCHS):
    inverse_net.train()  # Set model to training mode
    progress_bar = tqdm(
        enumerate(train_loader),
        total=len(train_loader),
        desc=f"Epoch {epoch+1}/{EPOCHS}",
    )

    for batch_idx, original_images in progress_bar:
        # Handle potential None batches from loader/collate_fn
        if original_images is None:
            # print(f"Warning: Skipping empty batch at index {batch_idx}.") # Can be noisy
            continue

        original_images = original_images.to(device)

        # --- Forward Pass ---
        images_vgg = vgg_preprocess(original_images)
        with torch.no_grad():
            features = vgg_extractor(images_vgg)[target_layer_idx]
        reconstructed_images = inverse_net(features)

        # --- Loss Calculation ---

        clamped_reconstructed_images = torch.clamp(reconstructed_images, 0, 1)

        reconstructed_images_vgg = vgg_preprocess(clamped_reconstructed_images)

        loss_content = LAMBDA_CONTENT * pixel_loss_fn(
            reconstructed_images_vgg, images_vgg
        )
        loss_tv = tv_loss_fn(reconstructed_images)
        loss_pixel = LAMBDA_PIXEL * pixel_loss_fn(
            clamped_reconstructed_images, original_images
        )
        total_loss = loss_pixel + loss_tv + loss_content

        # --- Backward Pass & Optimization ---
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # --- Logging ---
        if global_step % LOG_INTERVAL == 0:
            progress_bar.set_postfix(
                {
                    "Loss": f"{total_loss.item():.4f}",
                    "Pixel": f"{loss_pixel.item():.4f}",
                    "Content": f"{loss_content.item():.4f}",
                    "TV": f"{loss_tv.item():.4f}",
                }
            )

        # --- Checkpointing ---
        # Also save at the end of each epoch? Optional.
        if global_step > 0 and global_step % SAVE_INTERVAL == 0:
            ckpt_path = os.path.join(
                CHECKPOINT_DIR, f"invnet_pretrain_step_{global_step}.pth"
            )

            list_of_checkpoints = sorted(
                glob.glob(os.path.join(CHECKPOINT_DIR, "invnet_pretrain_step_*.pth"))
            )
            if len(list_of_checkpoints) > MAX_CHECKPOINTS_NUM:
                # Remove the oldest checkpoint
                os.remove(list_of_checkpoints[0])
                print(f"Removed oldest checkpoint: {list_of_checkpoints[0]}")
            
            torch.save(
                {
                    "epoch": epoch,  # Save the *current* epoch number
                    "global_step": global_step,
                    "model_state_dict": inverse_net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "args": {
                        "layer": LAYER,
                        "epochs": EPOCHS,
                        "batch_size": BATCH_SIZE,
                        "img_size": IMG_SIZE,
                        "lr": LR,
                        "lr_decay": LR_DECAY,
                        "lambda_tv": LAMBDA_TV,
                        "lambda_content": LAMBDA_CONTENT,
                        "lambda_pixel": LAMBDA_PIXEL,
                        "loss_type": LOSS_TYPE,
                    },
                },
                ckpt_path,
            )

            global_step += 1

            INSPECT_DIR = "/kaggle/working/inspect_pretrain/"
            VALIDATE_IMG = "/kaggle/input/coco-wikiart-nst-dataset-512-100000/content/000000000291.jpg"

            img = Image.open(VALIDATE_IMG).convert("RGB")

            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(mean=VGG_MEAN, std=VGG_STD),
                ]
            )

            img = transform(img).unsqueeze(0).to(device)  # Add batch dimension

            with torch.no_grad():
                features = vgg_extractor(img)[target_layer_idx]
                reconstructed_images = inverse_net(features)
                clamped_reconstructed_images = torch.clamp(reconstructed_images, 0, 1)

            if clamped_reconstructed_images.ndim == 4:
                last_batch_image = clamped_reconstructed_images[
                    0
                ]  # Get the first image of the last batch
            elif clamped_reconstructed_images.ndim == 3:
                last_batch_image = clamped_reconstructed_images

            os.makedirs(INSPECT_DIR, exist_ok=True)
            save_path = os.path.join(INSPECT_DIR, f"invnet_last_batch_epoch_{epoch+1}.png")
            # Convert to PIL and save
            to_pil = transforms.ToPILImage()
            last_img_pil = to_pil(last_batch_image.cpu())
            last_img_pil.save(save_path)
            print(f"Last batch of generated images saved to {save_path}")

# --- Save Final Model ---
final_path = os.path.join(CHECKPOINT_DIR, "invnet_pretrain_final.pth")
torch.save(
    {
        "epoch": EPOCHS,  # Save final epoch count
        "global_step": global_step,
        "model_state_dict": inverse_net.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "args": {
            "layer": LAYER,
            "epochs": EPOCHS,
            "batch_size": BATCH_SIZE,
            "img_size": IMG_SIZE,
            "lr": LR,
            "lr_decay": LR_DECAY,
            "lambda_tv": LAMBDA_TV,
            "lambda_content": LAMBDA_CONTENT,
            "lambda_pixel": LAMBDA_PIXEL,
            "loss_type": LOSS_TYPE,
        },
    },
    final_path,
)
print(f"\nPre-training finished. Final model saved to {final_path}")


Using device: cuda
Loading models...


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:02<00:00, 230MB/s]


Let's use 2 GPUs!
Loading dataset...
Starting Pre-training...


Epoch 1/5: 100%|██████████| 6247/6247 [1:07:20<00:00,  1.55it/s, Loss=0.0694, Pixel=0.0046, Content=0.0647, TV=0.0000]
Epoch 2/5: 100%|██████████| 6247/6247 [1:07:28<00:00,  1.54it/s, Loss=0.0902, Pixel=0.0060, Content=0.0842, TV=0.0000]
Epoch 3/5: 100%|██████████| 6247/6247 [1:07:27<00:00,  1.54it/s, Loss=0.0428, Pixel=0.0029, Content=0.0399, TV=0.0000]
Epoch 4/5: 100%|██████████| 6247/6247 [1:07:27<00:00,  1.54it/s, Loss=0.0678, Pixel=0.0045, Content=0.0633, TV=0.0000]
Epoch 5/5: 100%|██████████| 6247/6247 [1:07:26<00:00,  1.54it/s, Loss=0.0509, Pixel=0.0034, Content=0.0475, TV=0.0000]


Pre-training finished. Final model saved to /kaggle/working/checkpoints/invnet_pretrain_final.pth



