In [None]:
!pip install -U typing_extensions --cache-dir /workspace/pip_cache

!pip show typing_extensions

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
from PIL import Image
import os
import random
import glob
from torchvision import transforms
from compressai.models import CompressionModel
from compressai.entropy_models import EntropyBottleneck, GaussianConditional
from compressai.layers import GDN

import json
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision.models.optical_flow import raft_small, Raft_Small_Weights
from tqdm import tqdm


import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.utils import flow_to_image

In [None]:
# UTILS and DATALOADER

class VimeoSeptupletDataset(Dataset):
    """Vimeo-90k septuplet loader yielding 5-frame sliding windows.

    Each listed sequence directory is expected to hold ``im1.png`` ..
    ``im7.png``. Every sequence contributes ``7 - frames_per_sample + 1``
    (= 3) samples: the windows starting at frame 1, 2 and 3.

    Args:
        root_dirs: one dataset root, or a list of roots.
        split_files: one split-list file, or a list of them paired
            positionally with ``root_dirs``. Relative paths are resolved
            against their root. Every non-empty line is a sequence name.
            Missing split files are reported and skipped.
        transform: optional callable applied to every PIL frame; when it is
            ``None`` frames go through ``transforms.ToTensor()``.

    Raises:
        ValueError: if ``root_dirs`` and ``split_files`` differ in length.

    Each item is ``(frame_curr, frame_prev, history_frames)``: the last frame
    of the window, the frame right before it, and the four frames preceding
    ``frame_curr`` concatenated along dim 0, most recent first.
    """

    def __init__(self, root_dirs, split_files, transform=None):
        self.transform = transform
        self.sequences = []
        # Root of each entry in ``self.sequences`` (same index). Kept per entry
        # so identical sequence names listed under different roots are each
        # loaded from their own root.
        self._roots = []
        # Name -> root lookup kept for existing callers; when a name appears
        # under several roots it only remembers the last one.
        self.sequence_to_root = {}
        self.frames_per_sample = 5
        self.samples_per_seq = 7 - self.frames_per_sample + 1

        root_dirs = [root_dirs] if isinstance(root_dirs, str) else list(root_dirs)
        split_files = [split_files] if isinstance(split_files, str) else list(split_files)
        if len(root_dirs) != len(split_files):
            # zip() would otherwise silently drop the unpaired entries.
            raise ValueError(
                f"root_dirs ({len(root_dirs)}) and split_files ({len(split_files)}) "
                "must have the same length"
            )

        for root_dir, split_file in zip(root_dirs, split_files):
            if not os.path.isabs(split_file):
                split_file = os.path.join(root_dir, split_file)

            print(f"Loading split from: {split_file}")

            if not os.path.exists(split_file):
                print(f" File not found {split_file}")
                continue

            count = 0
            with open(split_file, 'r') as f:
                for line in f:
                    name = line.strip()
                    if not name:
                        continue
                    self.sequences.append(name)
                    self._roots.append(root_dir)
                    self.sequence_to_root[name] = root_dir
                    count += 1
            print(f" Loaded {count} sequences")

        print(f"Total Dataset: {len(self.sequences)} sequences")

    def __len__(self):
        return len(self.sequences) * self.samples_per_seq

    def _get_sequence_path(self, seq_name, root_dir):
        """Return the first existing directory for ``seq_name`` under ``root_dir``.

        A layout-specific guess (based on the root's name) is tried first, then
        several generic layouts. If none exists, the first candidate is
        returned so the caller can report it.
        """
        possible_paths = []

        if 'vimeo_settuplet_1' in root_dir or root_dir == "":
            possible_paths.append(os.path.join(root_dir, "vimeo_settuplet_1", "sequences", seq_name))
        elif 'vimeo_part2' in root_dir:
            possible_paths.append(os.path.join(root_dir, "vimeo_settuplet_2", "sequence", seq_name))
        else:
            possible_paths.append(os.path.join(root_dir, "sequence", seq_name))

        possible_paths.extend([
            os.path.join(root_dir, "sequences", seq_name),
            os.path.join(root_dir, "sequence", seq_name),
            os.path.join(root_dir, "vimeo_settuplet_1", "sequences", seq_name),
            os.path.join(root_dir, "vimeo_settuplet_2", "sequence", seq_name),
        ])

        for path in possible_paths:
            if os.path.exists(path):
                return path

        # ``possible_paths`` always holds at least the layout-specific guess.
        return possible_paths[0]

    def __getitem__(self, idx):
        seq_idx, inner_idx = divmod(idx, self.samples_per_seq)
        seq_name = self.sequences[seq_idx]
        root_dir = self._roots[seq_idx]

        seq_path = self._get_sequence_path(seq_name, root_dir)

        if not os.path.exists(seq_path):
            raise FileNotFoundError(f"Sequence path not found: {seq_path}")

        # Frame files are 1-based: window k covers im{k+1} .. im{k+frames_per_sample}.
        start_frame = inner_idx + 1
        frames = []

        for i in range(self.frames_per_sample):
            frame_path = os.path.join(seq_path, f"im{start_frame+i}.png")

            if not os.path.exists(frame_path):
                raise FileNotFoundError(f"Frame not found: {frame_path}")

            # Context manager guarantees the file handle is released.
            with Image.open(frame_path) as pil_img:
                img = pil_img.convert("RGB")

            if self.transform:
                img = self.transform(img)
            else:
                img = transforms.ToTensor()(img)

            frames.append(img)

        frame_curr = frames[-1]
        frame_prev = frames[-2]
        # All frames before frame_curr, most recent first (frames[-2] .. frames[-5]).
        history_frames = torch.cat(frames[-2::-1], dim=0)

        return frame_curr, frame_prev, history_frames


def conv(in_channels, out_channels, kernel_size=5, stride=2):
    """Build a Conv2d with ``kernel_size // 2`` zero padding.

    With the defaults (5x5 kernel, stride 2) the spatial size is halved
    (rounded up).
    """
    same_pad = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size, stride, same_pad)

def deconv(in_channels, out_channels, kernel_size=5, stride=2):
    """Build a ConvTranspose2d that upsamples by exactly ``stride``.

    ``output_padding = stride - 1`` together with ``kernel_size // 2`` padding
    makes the output spatial size equal to ``stride`` times the input size
    for odd kernels.
    """
    extra = stride - 1
    return nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=kernel_size // 2,
        output_padding=extra,
    )

class ResBlock(nn.Module):
    """Residual block: conv3x3 (optionally dilated) -> [GN] -> ReLU -> conv3x3 -> [GN], plus skip.

    Args:
        c: channel count (input == output). With ``use_gn`` it must be
            divisible by 8, the GroupNorm group count.
        dilation: dilation (and matching padding) of the first convolution.
        use_gn: whether to apply GroupNorm after each convolution.
    """

    def __init__(self, c, dilation=1, use_gn=True):
        super().__init__()
        self.use_gn = use_gn
        self.conv1 = nn.Conv2d(c, c, kernel_size=3, padding=dilation, dilation=dilation)
        self.conv2 = nn.Conv2d(c, c, kernel_size=3, padding=1)
        if use_gn:
            self.gn1, self.gn2 = nn.GroupNorm(8, c), nn.GroupNorm(8, c)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv1(x)
        out = self.act(self.gn1(out) if self.use_gn else out)
        out = self.conv2(out)
        if self.use_gn:
            out = self.gn2(out)
        return x + out

class SimpleResBlock(nn.Module):
    """Normalization-free residual block: conv3x3 -> LeakyReLU(0.2) -> conv3x3, plus skip."""

    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual


# VAE architectures:

class ScaleHyperprior(CompressionModel):
    """Scale-hyperprior autoencoder with configurable input/output channels.

    ``g_a``/``g_s`` are the analysis/synthesis transforms (four stride-2
    stages each). ``h_a``/``h_s`` form the hyperprior: ``h_s(z_hat)`` gives the
    scales that ``gaussian_conditional`` uses to code ``y``. The hyper-latent
    ``z`` is coded by ``entropy_bottleneck``.

    Args:
        N: channel width of the hidden layers and of ``z``.
        M: channel width of ``y``.
        in_channels: channels of the tensor fed to ``g_a``.
        out_channels: channels produced by ``g_s``.
        **kwargs: forwarded to ``CompressionModel``.
    """

    def __init__(self, N, M, in_channels=2, out_channels=2, **kwargs):
        super().__init__(**kwargs)

        self.entropy_bottleneck = EntropyBottleneck(N)

        self.g_a = nn.Sequential(
            conv(in_channels, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, M),
        )

        self.g_s = nn.Sequential(
            deconv(M, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, out_channels),
        )

        self.h_a = nn.Sequential(
            conv(M, N, stride=1, kernel_size=3),
            nn.ReLU(inplace=True),
            conv(N, N),
            nn.ReLU(inplace=True),
            conv(N, N),
        )

        self.h_s = nn.Sequential(
            deconv(N, N),
            nn.ReLU(inplace=True),
            deconv(N, N),
            nn.ReLU(inplace=True),
            conv(N, M, stride=1, kernel_size=3),
            nn.ReLU(inplace=True),
        )

        self.gaussian_conditional = GaussianConditional(None)
        self.N = int(N)
        self.M = int(M)

    @property
    def downsampling_factor(self) -> int:
        """Required spatial multiple of the input.

        ``g_a`` has four stride-2 convolutions and ``h_a`` two more, so ``z`` is
        1/64 of the input. Inputs must be a multiple of 2**6 for ``h_s(z_hat)``
        to match the spatial size of ``y``.
        """
        return 2 ** (4 + 2)

    def forward(self, x):
        """Return ``{"x_hat": reconstruction, "likelihoods": {"y": ..., "z": ...}}``."""
        y = self.g_a(x)
        z = self.h_a(torch.abs(y))
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        scales_hat = self.h_s(z_hat)
        y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat)
        x_hat = self.g_s(y_hat)

        return {
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`.

        ``N``, ``M`` and the input/output channel counts are inferred from the
        weight shapes of the first/last layers of ``g_a`` and ``g_s``.
        """
        # Conv2d weights are (out, in, kH, kW).
        N = state_dict["g_a.0.weight"].size(0)
        M = state_dict["g_a.6.weight"].size(0)
        in_channels = state_dict["g_a.0.weight"].size(1)
        # ConvTranspose2d weights are (in, out / groups, kH, kW): the output
        # channel count is dim 1, not dim 0 (dim 0 would be N).
        out_channels = state_dict["g_s.6.weight"].size(1)

        net = cls(N, M, in_channels=in_channels, out_channels=out_channels)
        net.load_state_dict(state_dict)
        return net

    def compress(self, x):
        """Entropy-code ``x``; returns ``{"strings": [y_strings, z_strings], "shape": z spatial size}``."""
        y = self.g_a(x)
        z = self.h_a(torch.abs(y))

        z_strings = self.entropy_bottleneck.compress(z)
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

        scales_hat = self.h_s(z_hat)
        indexes = self.gaussian_conditional.build_indexes(scales_hat)
        y_strings = self.gaussian_conditional.compress(y, indexes)
        return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}

    def decompress(self, strings, shape):
        """Inverse of ``compress``; returns ``{"x_hat": reconstruction}``."""
        assert isinstance(strings, list) and len(strings) == 2
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
        scales_hat = self.h_s(z_hat)
        indexes = self.gaussian_conditional.build_indexes(scales_hat)
        y_hat = self.gaussian_conditional.decompress(strings[0], indexes, z_hat.dtype)

        x_hat = self.g_s(y_hat)

        return {"x_hat": x_hat}

# Flow postprocessing net

class MotionRefineNET(nn.Module):
    """Refines a decoded flow field using the frame history.

    The concatenation of ``flow_hat`` and ``history_4f`` must have 14 channels
    (the stem's input width). The network predicts an additive correction,
    optionally scaled by a sigmoid gate. Both heads are zero-initialized, so at
    initialization the output equals ``flow_hat``.

    Args:
        base: feature width.
        blocks: number of ResBlocks; the last two use dilation 2.
        use_gn: GroupNorm inside the ResBlocks.
        use_gate: multiply the correction by a learned per-pixel gate.
    """

    def __init__(self, base=64, blocks=8, use_gn=True, use_gate=True):
        super().__init__()
        in_ch = 14

        self.stem = nn.Sequential(
            nn.Conv2d(in_ch, base, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        dilated_from = blocks - 2
        self.body = nn.Sequential(*(
            ResBlock(base, dilation=2 if idx >= dilated_from else 1, use_gn=use_gn)
            for idx in range(blocks)
        ))

        self.delta_head = nn.Conv2d(base, 2, kernel_size=3, padding=1)
        nn.init.zeros_(self.delta_head.weight)
        nn.init.zeros_(self.delta_head.bias)

        self.use_gate = use_gate
        if use_gate:
            self.gate_head = nn.Conv2d(base, 1, kernel_size=3, padding=1)
            nn.init.zeros_(self.gate_head.weight)
            nn.init.zeros_(self.gate_head.bias)

    def forward(self, flow_hat, history_4f):
        features = self.body(self.stem(torch.cat((flow_hat, history_4f), dim=1)))
        correction = self.delta_head(features)
        if self.use_gate:
            correction = torch.sigmoid(self.gate_head(features)) * correction
        return flow_hat + correction

# Residual postprocessing NET

class ResRefiNET(nn.Module):
    """Residual refiner: head conv -> ResBlocks -> tail conv, added back to the input.

    The tail is zero-initialized, so at initialization the module is the
    identity mapping.
    """

    def __init__(self, in_channels=3, mid_channels=64, num_blocks=6):
        super().__init__()
        self.head = nn.Conv2d(in_channels, mid_channels, 3, padding=1)
        blocks = [ResBlock(mid_channels, use_gn=True) for _ in range(num_blocks)]
        self.body = nn.Sequential(*blocks)
        self.tail = nn.Conv2d(mid_channels, in_channels, 3, padding=1)
        for param in (self.tail.weight, self.tail.bias):
            nn.init.zeros_(param)

    def forward(self, x):
        return x + self.tail(self.body(self.head(x)))

# Frame reconstruction NET

class AdaptiveRefiNET(nn.Module):
    """Mask-gated refinement of the intermediate reconstruction.

    The four inputs are concatenated on dim 1 and must total 19 channels. In
    this training script they are a 3-channel recon, a 3-channel warp, a
    1-channel mask and a 12-channel history (TODO confirm for other callers).
    The predicted correction is multiplied by ``mask`` before being added to
    ``recon``; the result is clamped to [0, 1].
    """

    def __init__(self, base=64, num_blocks=10):
        super().__init__()
        in_channels = 19

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, base, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(base, base, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

        self.res_blocks = nn.ModuleList(SimpleResBlock(base) for _ in range(num_blocks))

        self.decoder = nn.Sequential(
            nn.Conv2d(base, base, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(base, 3, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, recon, warped, mask, history):
        features = self.encoder(torch.cat((recon, warped, mask, history), dim=1))
        for res_block in self.res_blocks:
            features = res_block(features)
        delta = self.decoder(features)
        # Where mask is 0 the reconstruction is left as-is (up to clamping).
        return torch.clamp(recon + delta * mask, 0, 1)

In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Dataset roots; each pairs positionally with TRAIN_LISTS / TEST_LISTS.
# "" makes the split/sequence paths relative to the working directory.
VIMEO_PARTS = ["", "vimeo_part2", "vimeo_part4", "vimeo_part9"]

# Split lists, relative to the matching root in VIMEO_PARTS.
TRAIN_LISTS = [
    "vimeo_settuplet_1/sep_trainlist.txt",
    "vimeo_settuplet_2/sep_trainlist.txt",
    "sep_trainlist.txt",
    "sep_trainlist.txt",
]
# Same layout as TRAIN_LISTS, pointing at the test split.
TEST_LISTS = [path.replace("sep_trainlist", "sep_testlist") for path in TRAIN_LISTS]

# Pretrained, frozen pipeline checkpoints and the AdaptiveRefiNET checkpoint to resume.
MVAE_CHECKPOINT, MREFINE_CHECKPOINT = "FlowVAE_finetune_ep11.pth", "RefiFlow.pth"
RVAE_CHECKPOINT, RREFINE_CHECKPOINT = "ResidualVAE_HardMode_Ep4.pth", "ResRefiNET.pth"
ADAPTIVE_CHECKPOINT = "AdaptiveNET.pth"

# Training configuration.
SAVE_DIR = "checkpoints_adaptive"
BATCH_SIZE, NUM_EPOCHS, START_EPOCH = 16, 10, 1
LR = 1e-4
# Flow magnitude (pixels) used to clip and normalize flow to [-1, 1].
MAX_FLOW = 100.0



def flow_warp(x, flow):
    """Backward-warp ``x`` by a dense per-pixel displacement field.

    Output pixel (i, j) is sampled bilinearly from ``x`` at
    ``(j + flow[:, 0, i, j], i + flow[:, 1, i, j])``. Samples outside the
    image take the nearest border value.

    Args:
        x: tensor of shape (B, C, H, W).
        flow: displacement in pixels, channel 0 horizontal (x), channel 1
            vertical (y); shape (B, 2, H, W) or broadcastable to it.

    Returns:
        Tensor with the same shape as ``x``.
    """
    B, _, H, W = x.size()
    # Build the base pixel grid directly on x's device (float32), instead of
    # creating a B x 2 x H x W grid on the CPU and copying it over every call.
    ys = torch.arange(H, device=x.device, dtype=torch.float32)
    xs = torch.arange(W, device=x.device, dtype=torch.float32)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    grid = torch.stack((grid_x, grid_y), dim=0).unsqueeze(0).expand(B, -1, -1, -1)
    vgrid = grid + flow
    # Map pixel coordinates to [-1, 1]; with align_corners=True, -1 and 1 are
    # the centers of the first and last pixels. max(.., 1) guards H or W == 1.
    gx = 2.0 * vgrid[:, 0] / max(W - 1, 1) - 1.0
    gy = 2.0 * vgrid[:, 1] / max(H - 1, 1) - 1.0
    sample_grid = torch.stack((gx, gy), dim=-1)  # (B, H, W, 2), as grid_sample expects
    return F.grid_sample(x, sample_grid, align_corners=True, mode='bilinear', padding_mode='border')


def normalize_flow(flow, max_flow):
    """Clip ``flow`` to [-max_flow, max_flow] and rescale it to [-1, 1]."""
    clipped = torch.clamp(flow, min=-max_flow, max=max_flow)
    return clipped / max_flow


def compute_adaptive_mask(residual, lambda_param=1.5, epsilon=1e-6):
    """Per-pixel soft mask in [0, 1) from the residual's relative magnitude.

    For each pixel, the L2 norm over dim 1 is divided by that image's mean
    norm over dims 2 and 3, scaled by ``lambda_param`` and passed through
    ``tanh``. ``epsilon`` avoids division by zero for an all-zero residual.

    Returns:
        Tensor shaped like ``residual`` but with size 1 in dim 1.
    """
    magnitude = residual.pow(2).sum(dim=1, keepdim=True).sqrt()
    num_pixels = residual.shape[2] * residual.shape[3]
    mean_magnitude = magnitude.sum(dim=(2, 3), keepdim=True) / num_pixels
    return torch.tanh(lambda_param * magnitude / (mean_magnitude + epsilon))


def fix_compressai_state_dict(state_dict):
    """Rename legacy EntropyBottleneck parameter keys to the current CompressAI names.

    Older checkpoints stored the bottleneck's parameters as lists, giving keys
    like ``entropy_bottleneck._matrices.0`` (or, in some exports,
    ``entropy_bottleneck.matrices.0``). Current CompressAI expects
    ``entropy_bottleneck._matrix0``; likewise biases -> ``_bias`` and
    factors -> ``_factor``. All other keys are copied unchanged.

    Args:
        state_dict: mapping of parameter names to values.

    Returns:
        A new dict with the renamed keys, in the original order.
    """
    import re

    # Optional leading underscore covers both legacy spellings.
    legacy = re.compile(r"entropy_bottleneck\._?(matrices|biases|factors)\.")
    singular = {"matrices": "_matrix", "biases": "_bias", "factors": "_factor"}

    def rename(match):
        return f"entropy_bottleneck.{singular[match.group(1)]}"

    return {legacy.sub(rename, key): value for key, value in state_dict.items()}



def _run_frozen_pipeline(mvae, mrefinement, rvae, rrefinement, raft, raft_transforms,
                         frame_curr, frame_prev, history_frames, criterion):
    """Run the frozen motion and residual branches without tracking gradients.

    Returns:
        ``(recon_intermediate, warped_image, mask, mse_inter)``:
        - the clamped warp + refined-residual reconstruction;
        - ``frame_prev`` warped by the refined flow;
        - the adaptive mask from the refined residual;
        - ``criterion(recon_intermediate, frame_curr)``.
    """
    with torch.no_grad():
        # Motion branch: RAFT flow -> motion VAE -> motion refinement -> warp.
        img1, img2 = raft_transforms(frame_prev, frame_curr)
        flow_gt = raft(img1, img2)[-1]  # last entry of RAFT's returned list
        flow_gt_norm = normalize_flow(flow_gt, MAX_FLOW)
        flow_coarse_norm = mvae(flow_gt_norm)["x_hat"]
        flow_refined_norm = mrefinement(flow_coarse_norm, history_frames)
        warped_image = flow_warp(frame_prev, flow_refined_norm * MAX_FLOW)

        # Residual branch: residual VAE -> residual refinement.
        residual_gt = frame_curr - warped_image
        residual_coarse = rvae(residual_gt)["x_hat"]
        residual_refined = rrefinement(residual_coarse)

        recon_intermediate = (warped_image + residual_refined).clamp(0, 1)
        mse_inter = criterion(recon_intermediate, frame_curr)
        mask = compute_adaptive_mask(residual_refined)
    return recon_intermediate, warped_image, mask, mse_inter


def _batch_mean(total, loader):
    """Average a per-batch accumulator over the ``len(loader)`` batches."""
    n_batches = len(loader)
    if n_batches == 0:
        raise ValueError("loader yielded no batches; check the dataset split files")
    return total / n_batches


def train_epoch(mvae, mrefinement, rvae, rrefinement, adaptive_net, raft, raft_transforms,
                loader, optimizer, criterion, device):
    """Train ``adaptive_net`` for one epoch on top of the frozen pipeline.

    Only ``adaptive_net`` is put in train mode and updated; gradients are
    clipped to norm 1.0.

    Returns:
        ``(avg_final_loss, avg_intermediate_mse)``, each a mean of per-batch values.

    Raises:
        ValueError: if ``loader`` has no batches.
    """
    adaptive_net.train()
    total_loss = 0.0
    total_mse_inter = 0.0

    pbar = tqdm(loader, desc="Train", leave=False)

    for frame_curr, frame_prev, history_frames in pbar:
        frame_curr = frame_curr.to(device, non_blocking=True)
        frame_prev = frame_prev.to(device, non_blocking=True)
        history_frames = history_frames.to(device, non_blocking=True)

        recon_intermediate, warped_image, mask, mse_inter = _run_frozen_pipeline(
            mvae, mrefinement, rvae, rrefinement, raft, raft_transforms,
            frame_curr, frame_prev, history_frames, criterion)

        # Adaptive refinement (the only trainable stage).
        final_recon = adaptive_net(recon_intermediate, warped_image, mask, history_frames)
        loss = criterion(final_recon, frame_curr)

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(adaptive_net.parameters(), 1.0)
        optimizer.step()

        # Read each scalar once: every .item() forces a device synchronization.
        loss_val = loss.item()
        inter_val = mse_inter.item()
        total_loss += loss_val
        total_mse_inter += inter_val
        pbar.set_postfix({"mse_final": f"{loss_val:.6f}", "mse_inter": f"{inter_val:.6f}"})

    return _batch_mean(total_loss, loader), _batch_mean(total_mse_inter, loader)


def validate(mvae, mrefinement, rvae, rrefinement, adaptive_net, raft, raft_transforms,
             loader, criterion, device):
    """Evaluate ``adaptive_net`` (eval mode, no gradients) over ``loader``.

    Returns:
        ``(avg_final_mse, avg_intermediate_mse)``, each a mean of per-batch values.

    Raises:
        ValueError: if ``loader`` has no batches.
    """
    adaptive_net.eval()
    total_loss = 0.0
    total_mse_intermediate = 0.0

    pbar = tqdm(loader, desc="Val", leave=False)

    with torch.no_grad():
        for frame_curr, frame_prev, history_frames in pbar:
            frame_curr = frame_curr.to(device, non_blocking=True)
            frame_prev = frame_prev.to(device, non_blocking=True)
            history_frames = history_frames.to(device, non_blocking=True)

            recon_intermediate, warped_image, mask, mse_inter = _run_frozen_pipeline(
                mvae, mrefinement, rvae, rrefinement, raft, raft_transforms,
                frame_curr, frame_prev, history_frames, criterion)

            final_recon = adaptive_net(recon_intermediate, warped_image, mask, history_frames)
            loss = criterion(final_recon, frame_curr)

            total_loss += loss.item()
            total_mse_intermediate += mse_inter.item()

    return _batch_mean(total_loss, loader), _batch_mean(total_mse_intermediate, loader)


def _extract_state_dict(checkpoint):
    """Return the model weights held by a loaded checkpoint object.

    Accepts either a training checkpoint dict with a ``"model_state_dict"``
    entry or a bare state dict, which is returned unchanged.
    """
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        return checkpoint["model_state_dict"]
    return checkpoint


def _load_frozen(model, path, device, fix_keys=None):
    """Load weights from ``path`` into ``model``, set eval mode and disable its gradients.

    Args:
        model: module to populate.
        path: checkpoint file passed to ``torch.load``.
        device: ``map_location`` for ``torch.load``.
        fix_keys: optional callable applied to the state dict before
            ``load_state_dict`` (e.g. ``fix_compressai_state_dict``).

    Returns:
        ``model`` itself, so construction and loading can be chained.
    """
    # weights_only=False unpickles arbitrary Python objects: only load trusted files.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    state_dict = _extract_state_dict(checkpoint)
    if fix_keys is not None:
        state_dict = fix_keys(state_dict)
    model.load_state_dict(state_dict)
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    return model


def main():
    """Fine-tune AdaptiveRefiNET on top of the frozen motion/residual pipeline.

    Steps:
    - Load and freeze the motion VAE, motion refiner, residual VAE, residual
      refiner and RAFT.
    - Resume AdaptiveRefiNET, and its optimizer if saved, from
      ``ADAPTIVE_CHECKPOINT``. Without that file, return without training.
    - Train for epochs ``START_EPOCH``..``NUM_EPOCHS``.

    After every epoch it writes ``adaptive_latest.pth`` and the metrics JSON to
    ``SAVE_DIR``. It also writes ``adaptive_best.pth`` whenever the validation
    loss beats the best so far. That best starts at the loss stored in the
    resumed checkpoint, if any.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)

    # Frozen pipeline (eval mode, no gradients).
    mvae = _load_frozen(
        ScaleHyperprior(N=192, M=192, in_channels=2).to(DEVICE), MVAE_CHECKPOINT, DEVICE)
    mrefinement = _load_frozen(
        MotionRefineNET(base=64, blocks=8).to(DEVICE), MREFINE_CHECKPOINT, DEVICE)
    rvae = _load_frozen(
        ScaleHyperprior(N=128, M=128, in_channels=3, out_channels=3).to(DEVICE),
        RVAE_CHECKPOINT, DEVICE, fix_keys=fix_compressai_state_dict)
    rrefinement = _load_frozen(ResRefiNET().to(DEVICE), RREFINE_CHECKPOINT, DEVICE)

    # RAFT optical flow (raft_small / Raft_Small_Weights come from the module-level imports).
    raft = raft_small(weights=Raft_Small_Weights.DEFAULT, progress=False).to(DEVICE).eval()
    for p in raft.parameters():
        p.requires_grad = False
    raft_transforms = Raft_Small_Weights.DEFAULT.transforms()

    print(f"\nLoading AdaptiveRefiNET from checkpoint: {ADAPTIVE_CHECKPOINT}")
    adaptive_net = AdaptiveRefiNET(base=64, num_blocks=10).to(DEVICE)

    if not os.path.exists(ADAPTIVE_CHECKPOINT):
        print(f"Checkpoint not found {ADAPTIVE_CHECKPOINT}")
        return

    checkpoint = torch.load(ADAPTIVE_CHECKPOINT, map_location=DEVICE, weights_only=False)
    adaptive_net.load_state_dict(_extract_state_dict(checkpoint))

    if "epoch" in checkpoint:
        print(f"Loaded from Epoch {checkpoint['epoch']}")
    if "val_loss_final" in checkpoint:
        print(f"Previous Val Loss: {checkpoint['val_loss_final']:.6f}")
    # "improvement_pct" is the key written by the best-model save below.
    prev_improvement = checkpoint.get("val_improvement_pct", checkpoint.get("improvement_pct"))
    if prev_improvement is not None:
        print(f" Previous Improvement: {prev_improvement:.2f}%")

    print(" AdaptiveRefiNET loaded (TRAINABLE)")
    n_trainable = sum(p.numel() for p in adaptive_net.parameters() if p.requires_grad)
    print(f"Trainable parameters: {n_trainable:,}")

    train_ds = VimeoSeptupletDataset(VIMEO_PARTS, TRAIN_LISTS, transforms.ToTensor())
    val_ds = VimeoSeptupletDataset(VIMEO_PARTS, TEST_LISTS, transforms.ToTensor())

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=4, pin_memory=True)

    optimizer = optim.Adam(adaptive_net.parameters(), lr=LR)

    if "optimizer_state_dict" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        print(" Optimizer state restored")
        # NOTE(review): the restored state carries the param-group "lr" that was
        # current when the checkpoint was saved, so it overrides LR above.
        # CosineAnnealingLR starts from the group's current lr, so the schedule
        # anneals from that value rather than from LR. Confirm this is intended.

    remaining_epochs = NUM_EPOCHS - START_EPOCH + 1
    scheduler = CosineAnnealingLR(optimizer, T_max=remaining_epochs, eta_min=1e-6)

    criterion = nn.MSELoss().to(DEVICE)

    # A new "best" must beat the validation loss stored in the resumed checkpoint.
    best_loss = checkpoint.get("val_loss_final", float("inf"))
    print(f" prev loss: {best_loss:.6f}")

    metrics_file = os.path.join(SAVE_DIR, "adaptive_metrics.json")
    if os.path.exists(metrics_file):
        with open(metrics_file, 'r') as f:
            metrics = json.load(f)
        print(f" Loaded previous metrics ({len(metrics)} epochs)")
    else:
        metrics = []

    for epoch in range(START_EPOCH, NUM_EPOCHS + 1):
        print(f"\nEpoch {epoch}/{NUM_EPOCHS} | LR: {scheduler.get_last_lr()[0]:.2e}")

        tr_loss, tr_inter = train_epoch(mvae, mrefinement, rvae, rrefinement, adaptive_net,
                                        raft, raft_transforms, train_loader, optimizer, criterion, DEVICE)

        va_loss, va_inter = validate(mvae, mrefinement, rvae, rrefinement, adaptive_net,
                                     raft, raft_transforms, val_loader, criterion, DEVICE)

        # Relative MSE reduction of the refined output over the intermediate reconstruction.
        tr_improvement = ((tr_inter - tr_loss) / tr_inter * 100) if tr_inter > 0 else 0
        va_improvement = ((va_inter - va_loss) / va_inter * 100) if va_inter > 0 else 0

        print(f"Train  | MSE Inter: {tr_inter:.6f} | MSE Final: {tr_loss:.6f} | Improvement: {tr_improvement:.2f}%")
        print(f"Val    | MSE Inter: {va_inter:.6f} | MSE Final: {va_loss:.6f} | Improvement: {va_improvement:.2f}%")

        metrics.append({
            "epoch": epoch,
            "train_mse_intermediate": tr_inter,
            "train_mse_final": tr_loss,
            "train_improvement_pct": tr_improvement,
            "val_mse_intermediate": va_inter,
            "val_mse_final": va_loss,
            "val_improvement_pct": va_improvement
        })

        state = {
            'epoch': epoch,
            'model_state_dict': adaptive_net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss_final': va_loss,
            'val_loss_intermediate': va_inter,
        }

        if va_loss < best_loss:
            best_loss = va_loss
            torch.save({**state, 'improvement_pct': va_improvement},
                       os.path.join(SAVE_DIR, "adaptive_best.pth"))
            print(" Best Model Saved")

        torch.save(state, os.path.join(SAVE_DIR, "adaptive_latest.pth"))

        scheduler.step()

        with open(metrics_file, 'w') as f:
            json.dump(metrics, f, indent=2)

    print("finish")
    print(f"Best Validation Loss: {best_loss:.6f}")



if __name__ == "__main__":
    # Run training when this file (or notebook cell) is executed directly;
    # importing the module only defines the classes and functions above.
    main()