In [None]:
#  Disinstalla tutto
!pip uninstall -y numpy scipy pandas compressai torch torchvision torchaudio

#  Installa PyTorch per CUDA 12.1 (NB: la RTX 5090 / Blackwell sm_120 è supportata solo dalle build cu128, PyTorch >= 2.7)
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

!pip install numpy==1.24.3 scipy==1.10.1

!pip install compressai==1.2.4

#  Restarta il kernel

In [None]:
#  IGNORA
!pip uninstall -y numpy scipy compressai torch torchvision

#CompressAI requires specific numpy/scipy versions to function correctly.

!pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
!pip install numpy==1.24.3
!pip install scipy==1.10.1
!pip install compressai==1.2.4

# Restart the kernel !!!!!!!!!!!!!!!!

In [None]:
# 1. Aggiorna la libreria problematica
!pip install -U typing_extensions --cache-dir /workspace/pip_cache

# 2. Controllo opzionale (dovrebbe essere >= 4.10.0)
!pip show typing_extensions

In [None]:
# UTILS and DATALOADER
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms.functional as TF
from PIL import Image
import os
import random
import glob
from torchvision import transforms
from compressai.models import CompressionModel
from compressai.entropy_models import EntropyBottleneck, GaussianConditional
from compressai.layers import GDN

import json
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision.models.optical_flow import raft_small, Raft_Small_Weights
from tqdm import tqdm


import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.utils import flow_to_image


class VimeoSeptupletDataset(Dataset):
    """Sliding-window dataset over Vimeo septuplets (7-frame sequences).

    Every 7-frame sequence yields ``7 - frames_per_sample + 1`` samples (3 with
    the window of 5 frames used here). A sample is the tuple
    ``(frame_curr, frame_prev, history)``:

    * ``frame_curr``: last frame of the window;
    * ``frame_prev``: the frame right before it;
    * ``history``: the four frames preceding ``frame_curr`` (newest first),
      concatenated along dim 0 (12 channels for RGB tensors).

    Missing sequences yield all-zero tensors of size 256x448. Missing or
    unreadable frames are replaced by black 448x256 images.

    Args:
        root_dirs: one root directory or a list of them (one per dataset part).
        split_files: split list file(s) with one sequence name per line. Relative
            paths are resolved against the matching root. A single split file
            is reused for every root.
        transform: optional callable applied to each PIL frame. Defaults to
            ``transforms.ToTensor()``.

    Raises:
        ValueError: if ``root_dirs`` and ``split_files`` cannot be paired.
    """

    def __init__(self, root_dirs, split_files, transform=None):
        self.transform = transform
        self.sequences = []
        # Root directory of every entry of ``self.sequences`` (same index). The
        # name -> root dict cannot be used for lookups: a sequence name listed
        # by several parts would be overwritten by the last part.
        self.sequence_roots = []
        self.sequence_to_root = {}
        self.frames_per_sample = 5
        self.samples_per_seq = 7 - self.frames_per_sample + 1

        if isinstance(root_dirs, str):
            root_dirs = [root_dirs]
        if isinstance(split_files, str):
            split_files = [split_files]
        root_dirs = list(root_dirs)
        split_files = list(split_files)
        # A single split file for many roots is reused for each root, instead of
        # zip() silently loading only the first root.
        if len(split_files) == 1 and len(root_dirs) > 1:
            split_files = split_files * len(root_dirs)
        if len(root_dirs) != len(split_files):
            raise ValueError(
                "root_dirs and split_files must have the same length "
                f"(got {len(root_dirs)} and {len(split_files)})"
            )

        for root_dir, split_file in zip(root_dirs, split_files):
            if not os.path.isabs(split_file):
                split_file = os.path.join(root_dir, split_file)

            print(f"Loading split from: {split_file}")

            if not os.path.exists(split_file):
                print(f"Error: File not found {split_file}")
                continue

            count = 0
            with open(split_file, 'r') as f:
                for line in f:
                    name = line.strip()
                    if not name:
                        continue
                    self.sequences.append(name)
                    self.sequence_roots.append(root_dir)
                    self.sequence_to_root[name] = root_dir
                    count += 1
            print(f" Loaded {count} sequences")

        print(f" Total Dataset: {len(self.sequences)} sequences")

    def __len__(self):
        return len(self.sequences) * self.samples_per_seq

    def _get_sequence_path(self, seq_name, root_dir):
        """Return the first existing directory for ``seq_name``.

        If none exists, return the first candidate.
        """
        possible_paths = []

        # Layout-specific first guess, based on the part's directory name.
        if 'vimeo_settuplet_1' in root_dir or root_dir == "":
            possible_paths.append(os.path.join(root_dir, "vimeo_settuplet_1", "sequences", seq_name))
        elif 'vimeo_part2' in root_dir:
            possible_paths.append(os.path.join(root_dir, "vimeo_settuplet_2", "sequence", seq_name))
        else:
            possible_paths.append(os.path.join(root_dir, "sequence", seq_name))

        possible_paths.extend([
            os.path.join(root_dir, "sequences", seq_name),
            os.path.join(root_dir, "sequence", seq_name),
            os.path.join(root_dir, "vimeo_settuplet_1", "sequences", seq_name),
            os.path.join(root_dir, "vimeo_settuplet_2", "sequence", seq_name),
        ])

        for path in possible_paths:
            if os.path.exists(path):
                return path

        return possible_paths[0]

    def _load_frame(self, frame_path):
        """Load one frame as a tensor.

        Missing or unreadable files become black 448x256 images.
        """
        img = None
        if os.path.exists(frame_path):
            try:
                with Image.open(frame_path) as raw:
                    img = raw.convert("RGB")
            except Exception as err:
                # Best effort: a corrupt frame must not abort the epoch, but
                # (unlike a bare except) Ctrl-C / SystemExit still propagate.
                print(f"Warning: could not read {frame_path}: {err}")
        if img is None:
            img = Image.new('RGB', (448, 256))

        if self.transform:
            return self.transform(img)
        return transforms.ToTensor()(img)

    def __getitem__(self, idx):
        seq_idx, inner_idx = divmod(idx, self.samples_per_seq)
        seq_name = self.sequences[seq_idx]
        root_dir = self.sequence_roots[seq_idx]

        seq_path = self._get_sequence_path(seq_name, root_dir)

        if not os.path.exists(seq_path):
            return torch.zeros(3, 256, 448), torch.zeros(3, 256, 448), torch.zeros(12, 256, 448)

        # Frame files are 1-based: im1.png ... im7.png.
        start_frame = inner_idx + 1
        frames = [
            self._load_frame(os.path.join(seq_path, f"im{start_frame + i}.png"))
            for i in range(self.frames_per_sample)
        ]

        frame_curr = frames[-1]
        frame_prev = frames[-2]
        # History ordered newest -> oldest.
        history = [frames[-2], frames[-3], frames[-4], frames[-5]]
        history_frames = torch.cat(history, dim=0)

        return frame_curr, frame_prev, history_frames


def conv(in_channels, out_channels, kernel_size=5, stride=2):
    """Build a Conv2d with "same"-style padding (``kernel_size // 2``).

    With the default stride of 2 this halves the spatial size (rounding up).
    """
    pad = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad)

def deconv(in_channels, out_channels, kernel_size=5, stride=2):
    """Build a ConvTranspose2d that upsamples by exactly ``stride``.

    Padding is ``kernel_size // 2`` and ``output_padding = stride - 1``, so an
    H x W input becomes (H * stride) x (W * stride) for odd kernel sizes.
    """
    pad = kernel_size // 2
    extra = stride - 1
    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, pad, extra)

class ResBlock(nn.Module):
    """Residual block: ``x + conv2(relu(conv1(x)))`` with optional GroupNorm.

    Args:
        c: number of channels (kept constant). Must be divisible by 8 when
            ``use_gn`` is True, because GroupNorm uses 8 groups.
        dilation: dilation of the first 3x3 conv (padding matches, size kept).
        use_gn: apply GroupNorm after each conv.
    """

    def __init__(self, c, dilation=1, use_gn=True):
        super().__init__()
        self.use_gn = use_gn
        # Construction order is kept, so seeded initialisation is unchanged.
        self.conv1 = nn.Conv2d(c, c, 3, padding=dilation, dilation=dilation)
        self.conv2 = nn.Conv2d(c, c, 3, padding=1)
        if use_gn:
            self.gn1 = nn.GroupNorm(8, c)
            self.gn2 = nn.GroupNorm(8, c)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv1(x)
        out = self.gn1(out) if self.use_gn else out
        out = self.conv2(self.act(out))
        if self.use_gn:
            out = self.gn2(out)
        return out + x

class SimpleResBlock(nn.Module):
    """Residual block ``x + conv(leaky_relu(conv(x)))`` with 3x3 convs and no normalisation."""

    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(channels, channels, 3, 1, 1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return residual + x


# VAEs architectures:

class ScaleHyperprior(CompressionModel):
    """Scale-hyperprior VAE (CompressAI ``ScaleHyperprior`` layout) with configurable I/O channels.

    * ``g_a``: four stride-2 5x5 convs with GDN. The latent ``y`` has ``M``
      channels at 1/16 resolution.
    * ``h_a``: a stride-1 3x3 conv plus two stride-2 convs. The hyper-latent
      ``z`` has ``N`` channels at 1/64 resolution.
    * ``z`` is coded with ``entropy_bottleneck``. ``h_s(z_hat)`` predicts the
      scales used by ``gaussian_conditional`` to code ``y``.
    * ``g_s`` mirrors ``g_a`` with inverse GDN back to ``out_channels``.

    Args:
        N: channels of the hidden layers and of ``z``.
        M: channels of ``y``.
        in_channels: input channels (2 for flow, 3 for RGB residuals).
        out_channels: reconstruction channels.
    """

    def __init__(self, N, M, in_channels=2, out_channels=2, **kwargs):
        super().__init__(**kwargs)

        self.entropy_bottleneck = EntropyBottleneck(N)

        self.g_a = nn.Sequential(
            conv(in_channels, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, M),
        )

        self.g_s = nn.Sequential(
            deconv(M, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, out_channels),
        )

        self.h_a = nn.Sequential(
            conv(M, N, stride=1, kernel_size=3),
            nn.ReLU(inplace=True),
            conv(N, N),
            nn.ReLU(inplace=True),
            conv(N, N),
        )

        self.h_s = nn.Sequential(
            deconv(N, N),
            nn.ReLU(inplace=True),
            deconv(N, N),
            nn.ReLU(inplace=True),
            conv(N, M, stride=1, kernel_size=3),
            nn.ReLU(inplace=True),
        )

        self.gaussian_conditional = GaussianConditional(None)
        self.N = int(N)
        self.M = int(M)

    @property
    def downsampling_factor(self) -> int:
        # 4 stride-2 layers in g_a plus 2 more in h_a: inputs must be divisible
        # by 2**6 for z to have an integer size.
        return 2 ** (4 + 2)

    def forward(self, x):
        """Return ``{"x_hat", "likelihoods": {"y", "z"}}`` (noise-based quantisation in train mode)."""
        y = self.g_a(x)
        z = self.h_a(torch.abs(y))
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        scales_hat = self.h_s(z_hat)
        y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat)
        x_hat = self.g_s(y_hat)

        return {
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`.

        Channel counts are inferred from the weights. Conv2d weights are
        ``(out, in, k, k)``, while ConvTranspose2d weights are
        ``(in, out, k, k)``, so the output channels of the last ``g_s`` layer
        live in dim 1.
        """
        N = state_dict["g_a.0.weight"].size(0)
        M = state_dict["g_a.6.weight"].size(0)
        in_channels = state_dict["g_a.0.weight"].size(1)
        out_channels = state_dict["g_s.6.weight"].size(1)

        net = cls(N, M, in_channels=in_channels, out_channels=out_channels)
        net.load_state_dict(state_dict)
        return net

    def compress(self, x):
        """Entropy-code ``x``.

        Returns ``{"strings": [y_strings, z_strings], "shape": z_hw}``.
        """
        y = self.g_a(x)
        z = self.h_a(torch.abs(y))

        z_strings = self.entropy_bottleneck.compress(z)
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

        scales_hat = self.h_s(z_hat)
        indexes = self.gaussian_conditional.build_indexes(scales_hat)
        y_strings = self.gaussian_conditional.compress(y, indexes)
        return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}

    def decompress(self, strings, shape):
        """Decode the output of :meth:`compress` back to ``{"x_hat": ...}``."""
        assert isinstance(strings, list) and len(strings) == 2
        z_hat = self.entropy_bottleneck.decompress(strings[1], shape)
        scales_hat = self.h_s(z_hat)
        indexes = self.gaussian_conditional.build_indexes(scales_hat)
        y_hat = self.gaussian_conditional.decompress(strings[0], indexes, z_hat.dtype)

        x_hat = self.g_s(y_hat)

        return {"x_hat": x_hat}

# Flow postprocessing net

class MotionRefineNET(nn.Module):
    """Refine a decoded (normalised) flow field using the frame history.

    Input channels: 2 (flow) + 12 (history, four RGB frames stacked as produced
    by ``VimeoSeptupletDataset``). The network predicts a correction ``delta``
    for the flow, optionally scaled by a sigmoid gate, and returns
    ``flow_hat + delta``. Both heads are zero-initialised, so the untrained
    network returns ``flow_hat`` unchanged.
    """

    def __init__(self, base=64, blocks=8, use_gn=True, use_gate=True):
        super().__init__()

        flow_channels, history_channels = 2, 12
        self.stem = nn.Sequential(
            nn.Conv2d(flow_channels + history_channels, base, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        # The last two residual blocks are dilated (2) to widen the receptive field.
        self.body = nn.Sequential(*[
            ResBlock(base, dilation=2 if idx >= blocks - 2 else 1, use_gn=use_gn)
            for idx in range(blocks)
        ])

        self.delta_head = nn.Conv2d(base, 2, 3, padding=1)
        for tensor in (self.delta_head.weight, self.delta_head.bias):
            nn.init.zeros_(tensor)

        self.use_gate = use_gate
        if use_gate:
            self.gate_head = nn.Conv2d(base, 1, 3, padding=1)
            for tensor in (self.gate_head.weight, self.gate_head.bias):
                nn.init.zeros_(tensor)

    def forward(self, flow_hat, history_4f):
        stacked = torch.cat([flow_hat, history_4f], dim=1)
        features = self.body(self.stem(stacked))

        delta = self.delta_head(features)
        if self.use_gate:
            delta = delta * torch.sigmoid(self.gate_head(features))

        return flow_hat + delta

# Residual postprocessing NET

class ResRefiNET(nn.Module):
    """Residual post-processing net: ``x + tail(body(head(x)))``.

    ``tail`` is zero-initialised, so the untrained network is the identity.
    """

    def __init__(self, in_channels=3, mid_channels=64, num_blocks=6):
        super().__init__()
        self.head = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1)
        blocks = [ResBlock(mid_channels, use_gn=True) for _ in range(num_blocks)]
        self.body = nn.Sequential(*blocks)
        self.tail = nn.Conv2d(mid_channels, in_channels, kernel_size=3, padding=1)
        for tensor in (self.tail.weight, self.tail.bias):
            nn.init.zeros_(tensor)

    def forward(self, x):
        return x + self.tail(self.body(self.head(x)))

# Frame reconstruction NET

class AdaptiveRefiNET(nn.Module):
    """Final frame refinement, with the correction applied through the adaptive mask.

    Input channels: 3 (recon) + 3 (warped) + 1 (mask) + 12 (history) = 19.
    Output: ``clamp(recon + correction * mask, 0, 1)``.
    """

    def __init__(self, base=64, num_blocks=10):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(19, base, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(base, base, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.res_blocks = nn.ModuleList(SimpleResBlock(base) for _ in range(num_blocks))

        self.decoder = nn.Sequential(
            nn.Conv2d(base, base, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(base, 3, 3, 1, 1),
        )

    def forward(self, recon, warped, mask, history):
        hidden = self.encoder(torch.cat([recon, warped, mask, history], dim=1))
        for res_block in self.res_blocks:
            hidden = res_block(hidden)
        correction = self.decoder(hidden)
        return torch.clamp(recon + correction * mask, 0, 1)

In [None]:
# Qui il problema più grande è stato far andare la 5090 senza crash di CUDA.
# RunPod, quando si cambia da una GPU x a una GPU y, fa casini con la cache :(
# Senza almeno 32GB di VRAM il training sarebbe stato infattibile.
# Volevamo finetunare anche il Motion Estimator (RAFT), ma sarebbe stato troppo pesante e lento.
# Avevamo provato con una A100 con 80GB di VRAM, ma il setup era molto complicato e, anche una volta avviato, andava a 3 fps di inferenza.
# Il bilancio costo/qualità non ne valeva la pena.

# La vera sfida è stata la scelta degli iperparametri giusti di rete e training, poiché il tempo e le risorse per sbagliare e
# ricominciare erano davvero limitati.


# CUDA caching-allocator settings tuned for the 32 GB (GDDR7) card.
# NOTE(review): PyTorch reads this variable when the CUDA allocator is first
# initialised, so it presumably has no effect if CUDA was already used earlier
# in the same kernel session.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,roundup_power2_divisions:16"

# GPU-specific Inductor/Triton cache directories, so kernels compiled for a
# different GPU on the same RunPod volume are not reused. setdefault keeps any
# value already present in the environment.
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", "/tmp/torchinductor_5090")
os.environ.setdefault("TRITON_CACHE_DIR", "/tmp/triton_5090")

# TF32 math for matmuls and cuDNN convolutions, plus cuDNN autotuning.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Throughput settings for the RTX 5090 ------------------------------------
# The per-step batch must stay <= 16 (tested with accumulation 2), otherwise
# the RunPod pod crashes. Current choice: 12 per step x 3 accumulation steps
# = 36 samples per optimizer step. Runs at about 3 fps; RAFT dominates the cost.
BATCH_SIZE = 12
GRADIENT_ACCUMULATION = 3

# DataLoader settings.
NUM_WORKERS = 8
PREFETCH_FACTOR = 4
PIN_MEMORY = True
PERSISTENT_WORKERS = True

# Mixed-precision dtype used inside autocast.
USE_BFLOAT16 = True
AUTOCAST_DTYPE = torch.bfloat16 if USE_BFLOAT16 else torch.float16

MAX_FLOW = 50.0    # flow is clamped to +/- MAX_FLOW px and divided by it before the motion VAE
AUX_WEIGHT = 0.0   # weight of the entropy models' auxiliary loss in the total loss
EXTRA_EPOCHS = 10  # epochs trained after the resumed checkpoint

USE_COMPILE = True
# aot_eager avoids Triton/Inductor, which had compatibility issues on the RTX 5090.
COMPILE_BACKEND = "aot_eager"

# A per-epoch lambda schedule was planned but dropped to keep things simple:
# lambda stays at 1024 for every listed epoch (6..15).
LAMBDA_SCHEDULE = {epoch: 1024.0 for epoch in range(6, 16)}

DEBUG_RECON_RANGE_EVERY = 200  # not referenced in the code visible here

# Dataset parts and their split lists, paired by position (same length).
VIMEO_PARTS = [
    "", "vimeo_part2", "vimeo-part3", "vimeo_part4",
    "vimeo_part5", "vimeo_part6", "vimeo-part7",
    "vimeo-part8", "vimeo_part9",
]
TRAIN_LISTS = [
    "vimeo_settuplet_1/sep_trainlist.txt", "vimeo_settuplet_2/sep_trainlist.txt",
    "sep_trainlist.txt", "sep_trainlist.txt", "sep_trainlist.txt",
    "sep_trainlist.txt", "sep_trainlist.txt", "sep_trainlist.txt", "sep_trainlist.txt",
]
TEST_LISTS = [
    "vimeo_settuplet_1/sep_testlist.txt", "vimeo_settuplet_2/sep_testlist.txt",
    "sep_testlist.txt", "sep_testlist.txt", "sep_testlist.txt",
    "sep_testlist.txt", "sep_testlist.txt", "sep_testlist.txt", "sep_testlist.txt",
]

OUTPUT_DIR = "workspace/"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# NOTE(review): resolved relative to the working directory, while new
# checkpoints are written into OUTPUT_DIR. Confirm this is intended.
RESUME_PATH = "joint_epoch_6.pth"


def seed_everything(seed=1234):
    """Seed Python's, NumPy's and PyTorch's (CPU and all CUDA devices) RNGs with ``seed``."""
    # Imported locally: numpy is not among the notebook's top-level imports,
    # so a bare ``np`` here raised NameError.
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


class RateDistortionLossMSE(nn.Module):
    """Rate-distortion loss ``rate + lmbda * MSE``.

    The rate is the estimated bits per pixel: the sum of ``-log2`` of the
    ``"y"`` and ``"z"`` likelihoods of every model, divided by ``N * H * W``
    of ``target``.
    """

    def __init__(self, lmbda=1024.0):
        super().__init__()
        self.lmbda = float(lmbda)

    def forward(self, recon, target, likelihoods_list):
        """Compute the loss.

        Args:
            recon: reconstruction, shaped like ``target`` (N, C, H, W).
            target: ground-truth frames.
            likelihoods_list: iterable of likelihood dicts (or ``None`` to skip).

        Returns:
            ``(loss, mse, rate, mse, zero)``. The MSE is returned twice (as
            distortion and as raw MSE) and the last element is a scalar zero
            tensor, kept for call-site compatibility. ``rate`` is always a
            tensor, even when no likelihoods are given.
        """
        # Imported locally: math is not among the notebook's top-level imports.
        import math

        recon_f = recon.float()
        target_f = target.float()

        mse = F.mse_loss(recon_f, target_f)

        n, _, h, w = target_f.size()
        num_pixels = n * h * w

        # Start from a tensor so callers can always call .item() on the rate.
        rate = torch.zeros((), device=target_f.device)
        for likelihoods in likelihoods_list:
            if likelihoods is None:
                continue
            for key in ["y", "z"]:
                if key in likelihoods:
                    ll = torch.log(likelihoods[key] + 1e-10)
                    rate = rate + ll.sum() / (-math.log(2) * num_pixels)

        loss = rate + self.lmbda * mse
        return loss, mse, rate, mse, torch.tensor(0.0, device=recon.device)


def unwrap(model: nn.Module) -> nn.Module:
    """Return the eager module behind a ``torch.compile`` wrapper, or ``model`` itself."""
    # Compiled modules keep the original module in ``_orig_mod``.
    return getattr(model, "_orig_mod", model)


def safe_update_entropy(model: nn.Module, force: bool = True):
    """Call ``update(force=force)`` on the unwrapped model if it defines one; otherwise do nothing."""
    target = unwrap(model)
    if not hasattr(target, "update"):
        return
    target.update(force=force)


def safe_state_dict(model: nn.Module):
    """Return the state dict of the unwrapped model, so keys carry no ``_orig_mod.`` prefix."""
    base = unwrap(model)
    return base.state_dict()


def normalize_flow(flow, max_flow):
    """Clamp ``flow`` to ``[-max_flow, max_flow]`` and scale it into ``[-1, 1]``."""
    clipped = torch.clamp(flow, min=-max_flow, max=max_flow)
    return clipped / max_flow


def flow_warp(x, flow):
    """Backward-warp ``x`` with a dense displacement field using bilinear sampling.

    Args:
        x: tensor of shape (B, C, H, W).
        flow: displacement in pixels, shape (B, 2, H, W) or broadcastable to it.
            Channel 0 is the horizontal offset and channel 1 the vertical one.

    Returns:
        Tensor shaped like ``x``, sampled at ``(col + flow_x, row + flow_y)``.
        Out-of-range samples take the border value.
    """
    batch, _, height, width = x.size()

    rows = torch.arange(0, height, device=x.device, dtype=x.dtype)
    cols = torch.arange(0, width, device=x.device, dtype=x.dtype)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing="ij")
    base = torch.stack((grid_x, grid_y), dim=0).unsqueeze(0).expand(batch, -1, -1, -1)

    coords = base + flow
    # Pixel coordinates -> [-1, 1], following the align_corners=True convention.
    norm_x = 2.0 * coords[:, 0] / max(width - 1, 1) - 1.0
    norm_y = 2.0 * coords[:, 1] / max(height - 1, 1) - 1.0
    sample_grid = torch.stack((norm_x, norm_y), dim=-1)

    return F.grid_sample(x, sample_grid, mode='bilinear', padding_mode='border', align_corners=True)


def compute_adaptive_mask(residual, lambda_param=1.5, epsilon=1e-6):
    """Per-pixel soft mask in [0, 1) from the residual magnitude.

    The per-pixel L2 norm across channels is divided by its per-image mean.
    The mask is ``tanh(lambda_param * ratio)``, shaped (B, 1, H, W).
    """
    magnitude = (residual ** 2).sum(dim=1, keepdim=True).add(epsilon).sqrt()
    pixel_count = residual.shape[2] * residual.shape[3]
    mean_magnitude = magnitude.sum(dim=(2, 3), keepdim=True) / pixel_count
    return torch.tanh(lambda_param * magnitude / (mean_magnitude + epsilon))


def total_aux_loss(models):
    """Sum ``aux_loss()`` over the motion and residual VAEs that provide it.

    Returns 0.0 if neither does.
    """
    vaes = (unwrap(models[key]) for key in ('mvae', 'rvae'))
    return sum((vae.aux_loss() for vae in vaes if hasattr(vae, "aux_loss")), 0.0)


def freeze_entropy_modules(scale_hyperprior_model: nn.Module):
    """Disable gradients for the ``entropy_bottleneck`` / ``gaussian_conditional`` parameters.

    Returns:
        Number of scalar parameters that were trainable and are now frozen.
    """
    base = unwrap(scale_hyperprior_model)

    total = 0
    for attr_name in ("entropy_bottleneck", "gaussian_conditional"):
        if not hasattr(base, attr_name):
            continue
        submodule = getattr(base, attr_name)
        trainable = [param for param in submodule.parameters() if param.requires_grad]
        for param in trainable:
            param.requires_grad = False
            total += param.numel()
    return total


def resume_from_checkpoint(path, models, optimizer=None, scheduler=None):
    """Load sub-network weights (and optionally optimizer/scheduler state) from ``path``.

    Args:
        path: file written with ``torch.save``, holding the keys 'mvae',
            'mrefine', 'rvae', 'rrefine', 'adaptive' and optionally
            'optimizer', 'scheduler', 'epoch'.
        models: dict containing at least those five sub-network keys.
        optimizer, scheduler: restored only when given and present in the file.

    Returns:
        The epoch to start from: stored ``epoch`` + 1 (1 if none was saved).

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Resume checkpoint non trovato: {path}")

    print(f"Loading checkpoint from {path}...")
    # weights_only=False unpickles arbitrary objects: only resume from trusted files.
    ckpt = torch.load(path, map_location=DEVICE, weights_only=False)

    def load_weights(name, model, state_dict):
        try:
            model.load_state_dict(state_dict, strict=True)
            return
        except RuntimeError as err:
            first_error = err
        # Retry without torch.compile's "_orig_mod." prefix. strict=False would
        # otherwise hide keys that still do not match, so report them.
        stripped = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
        result = model.load_state_dict(stripped, strict=False)
        missing = list(getattr(result, "missing_keys", None) or [])
        unexpected = list(getattr(result, "unexpected_keys", None) or [])
        if missing or unexpected:
            print(
                f"  WARNING [{name}]: non-strict load after {type(first_error).__name__}: "
                f"{len(missing)} missing key(s) {missing[:5]}, "
                f"{len(unexpected)} unexpected key(s) {unexpected[:5]}"
            )

    for key in ('mvae', 'mrefine', 'rvae', 'rrefine', 'adaptive'):
        load_weights(key, models[key], ckpt[key])

    if optimizer is not None and 'optimizer' in ckpt:
        optimizer.load_state_dict(ckpt['optimizer'])
    if scheduler is not None and 'scheduler' in ckpt:
        scheduler.load_state_dict(ckpt['scheduler'])

    last_epoch = int(ckpt.get('epoch', 0))
    start_epoch = last_epoch + 1

    print(f"  Last epoch: {last_epoch}  →  Start epoch: {start_epoch}")
    return start_epoch


def train_joint_epoch(models, loader, optimizer, criterion, raft_transforms, scaler, epoch):
    """Run one epoch of joint training with gradient accumulation.

    ``models['raft']`` stays in eval mode and only supplies the target flow
    (no gradients). Every other entry of ``models`` is put in train mode.

    Args:
        models: dict with keys 'raft', 'mvae', 'mrefine', 'rvae', 'rrefine', 'adaptive'.
        loader: iterable of ``(frame_curr, frame_prev, history)`` batches.
        optimizer: optimizer over the trainable sub-networks.
        criterion: loss returning ``(loss, dist, rate, mse, _)`` and exposing ``lmbda``.
        raft_transforms: callable preparing a frame pair for RAFT.
        scaler: gradient scaler used for backward/step.
        epoch: epoch index (progress-bar label only).

    Returns:
        dict with the per-batch means of 'loss', 'rate', 'lamdist' and 'mse'
        (all 0.0 if the loader yields no batches).
    """
    for name, model in models.items():
        if name == 'raft':
            model.eval()
        else:
            model.train()

    running_total = 0.0
    running_rate = 0.0
    running_lamdist = 0.0
    running_mse = 0.0
    num_batches = 0

    # Known length lets us flush a trailing, incomplete accumulation window,
    # whose gradients were previously dropped by the next epoch's zero_grad.
    try:
        num_iters = len(loader)
    except TypeError:
        num_iters = None

    pbar = tqdm(loader, desc=f"Epoch {epoch} - RTX 5090 Joint Training", ncols=120)
    optimizer.zero_grad(set_to_none=True)

    for batch_idx, (frame_curr, frame_prev, history) in enumerate(pbar):
        frame_curr = frame_curr.to(DEVICE, non_blocking=True).clamp(0, 1)
        frame_prev = frame_prev.to(DEVICE, non_blocking=True).clamp(0, 1)
        history = history.to(DEVICE, non_blocking=True)

        # 1. Target optical flow from the frozen RAFT model.
        with torch.no_grad():
            img1, img2 = raft_transforms(frame_prev, frame_curr)
            flow_gt = models['raft'](img1, img2)[-1]

        # 2. Forward + loss under mixed precision. torch.cuda.amp.autocast takes
        # no device argument (autocast("cuda", ...) binds "cuda" to `enabled`),
        # so use the device-generic torch.autocast.
        with torch.autocast(device_type="cuda", dtype=AUTOCAST_DTYPE):
            flow_norm = normalize_flow(flow_gt, MAX_FLOW)

            mvae_out = models['mvae'](flow_norm)
            flow_refined_norm = models['mrefine'](mvae_out["x_hat"], history)
            warped = flow_warp(frame_prev, flow_refined_norm * MAX_FLOW)

            residual_gt = frame_curr - warped
            rvae_out = models['rvae'](residual_gt)
            residual_refined = models['rrefine'](rvae_out["x_hat"])

            recon_intermediate = (warped + residual_refined).clamp(0, 1)
            mask = compute_adaptive_mask(residual_refined)
            final_recon = models['adaptive'](recon_intermediate, warped, mask, history).clamp(0, 1)

            loss_rd, dist, rate, mse, _ = criterion(
                final_recon,
                frame_curr,
                [mvae_out.get("likelihoods", None), rvae_out.get("likelihoods", None)]
            )

            aux = total_aux_loss(models)
            loss_total = loss_rd + AUX_WEIGHT * aux
            loss_total = loss_total / GRADIENT_ACCUMULATION

        scaler.scale(loss_total).backward()

        window_full = (batch_idx + 1) % GRADIENT_ACCUMULATION == 0
        last_batch = num_iters is not None and batch_idx + 1 == num_iters
        if window_full or last_batch:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                [p for group in optimizer.param_groups for p in group['params']],
                max_norm=1.0
            )
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        # float() works for tensors and plain numbers; one host sync per value.
        batch_loss = float(loss_total) * GRADIENT_ACCUMULATION
        contrib_rate = float(rate)
        contrib_lamdist = float(criterion.lmbda * dist)
        batch_mse = float(mse)

        running_total += batch_loss
        running_rate += contrib_rate
        running_lamdist += contrib_lamdist
        running_mse += batch_mse
        num_batches += 1

        pbar.set_postfix({
            "L": f"{batch_loss:.4f}",
            "R": f"{contrib_rate:.3f}",
            "MSE": f"{batch_mse:.1e}",
        })

    denom = max(num_batches, 1)
    return {
        "loss": running_total / denom,
        "rate": running_rate / denom,
        "lamdist": running_lamdist / denom,
        "mse": running_mse / denom,
    }


@torch.no_grad()
def validate(models, loader, criterion, raft_transforms):
    """Evaluate the full pipeline on ``loader`` without gradients.

    Every entry of ``models`` is switched to eval mode and left in it.

    Returns:
        dict with the per-batch means of 'loss', 'rate', 'lamdist' and 'mse'
        (all 0.0 if the loader yields no batches).
    """
    for model in models.values():
        model.eval()

    running_total = 0.0
    running_rate = 0.0
    running_lamdist = 0.0
    running_mse = 0.0
    num_batches = 0

    pbar = tqdm(loader, desc="Validation", ncols=120)

    for (frame_curr, frame_prev, history) in pbar:
        frame_curr = frame_curr.to(DEVICE, non_blocking=True).clamp(0, 1)
        frame_prev = frame_prev.to(DEVICE, non_blocking=True).clamp(0, 1)
        history = history.to(DEVICE, non_blocking=True)

        img1, img2 = raft_transforms(frame_prev, frame_curr)
        flow_gt = models['raft'](img1, img2)[-1]

        # torch.cuda.amp.autocast has no device argument; use torch.autocast.
        with torch.autocast(device_type="cuda", dtype=AUTOCAST_DTYPE):
            flow_norm = normalize_flow(flow_gt, MAX_FLOW)
            mvae_out = models['mvae'](flow_norm)
            flow_refined_norm = models['mrefine'](mvae_out["x_hat"], history)

            warped = flow_warp(frame_prev, flow_refined_norm * MAX_FLOW)
            residual_gt = frame_curr - warped
            rvae_out = models['rvae'](residual_gt)
            residual_refined = models['rrefine'](rvae_out["x_hat"])

            recon_intermediate = (warped + residual_refined).clamp(0, 1)
            mask = compute_adaptive_mask(residual_refined)
            final_recon = models['adaptive'](recon_intermediate, warped, mask, history).clamp(0, 1)

            loss_rd, dist, rate, mse, _ = criterion(
                final_recon,
                frame_curr,
                [mvae_out.get("likelihoods", None), rvae_out.get("likelihoods", None)]
            )

        batch_loss = float(loss_rd)
        contrib_rate = float(rate)
        contrib_lamdist = float(criterion.lmbda * dist)
        batch_mse = float(mse)

        running_total += batch_loss
        running_rate += contrib_rate
        running_lamdist += contrib_lamdist
        running_mse += batch_mse
        num_batches += 1

        pbar.set_postfix({
            "L": f"{batch_loss:.4f}",
            "R": f"{contrib_rate:.3f}",
            "MSE": f"{batch_mse:.1e}",
        })

    denom = max(num_batches, 1)
    return {
        "loss": running_total / denom,
        "rate": running_rate / denom,
        "lamdist": running_lamdist / denom,
        "mse": running_mse / denom,
    }


# Sub-networks that are trained and checkpointed (RAFT stays frozen and is not saved).
_TRAINABLE_MODEL_KEYS = ('mvae', 'mrefine', 'rvae', 'rrefine', 'adaptive')


def _trainable_state_dicts(models):
    """Return ``{name: state_dict}`` for the trainable sub-networks, without compile wrappers."""
    return {name: safe_state_dict(models[name]) for name in _TRAINABLE_MODEL_KEYS}


def _cuda_sync():
    """Synchronise the current CUDA device if CUDA is available; otherwise do nothing."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _make_loader(dataset, *, shuffle, num_workers, drop_last=False):
    """Build a DataLoader with the global batch/worker settings.

    Worker-only options (persistent workers, prefetch factor, multiprocessing
    context) are passed only when ``num_workers > 0``. DataLoader raises
    ValueError for them in single-process mode.
    """
    kwargs = {
        'batch_size': BATCH_SIZE,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'pin_memory': PIN_MEMORY,
        'drop_last': drop_last,
    }
    if num_workers > 0:
        kwargs.update(
            persistent_workers=PERSISTENT_WORKERS,
            prefetch_factor=PREFETCH_FACTOR,
            multiprocessing_context='fork' if os.name != 'nt' else 'spawn',
        )
    return DataLoader(dataset, **kwargs)


def main_joint():
    """Resume joint training from RESUME_PATH and train EXTRA_EPOCHS more epochs.

    Steps:

    1. Build the codec sub-networks and a frozen RAFT-small flow estimator.
    2. Load weights from the checkpoint. Optimizer/scheduler state is not
       restored.
    3. Freeze the entropy-model parameters of both VAEs.
    4. Optionally compile the refinement nets.
    5. Alternate train / validate each epoch. Each epoch writes
       ``joint_epoch_{epoch}.pth`` to OUTPUT_DIR, plus ``joint_best.pth``
       whenever the validation loss improves.
    """
    # Imported locally: math is not among the notebook's top-level imports.
    import math

    seed_everything(1234)

    if torch.cuda.is_available():
        print(f"Device: {DEVICE} ({torch.cuda.get_device_name(0)})")
        vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"VRAM: {vram_gb:.1f} GB")
    else:
        print(f"Device: {DEVICE}")

    eff_bs = BATCH_SIZE * GRADIENT_ACCUMULATION
    print(f"\nBatch size (per step): {BATCH_SIZE}")
    print(f"Gradient Accumulation: {GRADIENT_ACCUMULATION}")
    print(f" Effective Batch Size: {eff_bs} (Equivalente al config A100)")

    print(f"Workers: {NUM_WORKERS} / Prefetch: {PREFETCH_FACTOR}")
    print(f"Precision: {'BFloat16' if USE_BFLOAT16 else 'Float16'}")
    print(f"Torch Compile: {'ENABLED' if USE_COMPILE else 'DISABLED'} | backend={COMPILE_BACKEND}")

    print("\n Loading models...")

    mvae = ScaleHyperprior(N=192, M=192, in_channels=2).to(DEVICE)
    mrefine = MotionRefineNET(base=64, blocks=8).to(DEVICE)
    rvae = ScaleHyperprior(N=128, M=128, in_channels=3, out_channels=3).to(DEVICE)
    rrefine = ResRefiNET().to(DEVICE)
    adaptive = AdaptiveRefiNET(base=64, num_blocks=10).to(DEVICE)

    # RAFT-small is only used as a frozen flow "ground truth" provider.
    raft = raft_small(weights=Raft_Small_Weights.DEFAULT, progress=False).to(DEVICE)
    raft.eval()
    for p in raft.parameters():
        p.requires_grad = False
    raft_transforms = Raft_Small_Weights.DEFAULT.transforms()

    models = {
        'mvae': mvae,
        'mrefine': mrefine,
        'rvae': rvae,
        'rrefine': rrefine,
        'adaptive': adaptive,
        'raft': raft
    }

    # Square-root LR scaling relative to a reference batch of 8.
    lr_scale = math.sqrt(eff_bs / 8.0)
    optimizer = optim.AdamW([
        {'params': mvae.parameters(),     'lr': 1e-6  * lr_scale, 'weight_decay': 1e-4},
        {'params': rvae.parameters(),     'lr': 1e-6  * lr_scale, 'weight_decay': 1e-4},
        {'params': mrefine.parameters(),  'lr': 5e-6  * lr_scale, 'weight_decay': 1e-4},
        {'params': rrefine.parameters(),  'lr': 5e-6  * lr_scale, 'weight_decay': 1e-4},
        {'params': adaptive.parameters(), 'lr': 2.5e-5 * lr_scale, 'weight_decay': 1e-4},
    ], betas=(0.9, 0.999), eps=1e-8)

    scheduler = CosineAnnealingLR(optimizer, T_max=EXTRA_EPOCHS, eta_min=1e-6)

    criterion = RateDistortionLossMSE(lmbda=1024.0).to(DEVICE)
    scaler = GradScaler(enabled=True)

    # Optimizer/scheduler are intentionally fresh: only weights are resumed.
    start_epoch = resume_from_checkpoint(RESUME_PATH, models, optimizer=None, scheduler=None)

    _cuda_sync()

    safe_update_entropy(models['mvae'], force=True)
    safe_update_entropy(models['rvae'], force=True)

    frozen_m = freeze_entropy_modules(models['mvae'])
    frozen_r = freeze_entropy_modules(models['rvae'])
    print(f" Frozen entropy-model params: mvae={frozen_m:,} | rvae={frozen_r:,}")

    if USE_COMPILE:
        print(f"\n Compiling custom models with torch.compile (backend={COMPILE_BACKEND})...")
        models['mrefine']  = torch.compile(unwrap(models['mrefine']),  backend=COMPILE_BACKEND)
        models['rrefine']  = torch.compile(unwrap(models['rrefine']),  backend=COMPILE_BACKEND)
        models['adaptive'] = torch.compile(unwrap(models['adaptive']), backend=COMPILE_BACKEND)
        print(" Custom models compiled.")

    print("\n Loading datasets.")
    train_ds = VimeoSeptupletDataset(VIMEO_PARTS, TRAIN_LISTS, transforms.ToTensor())
    val_ds = VimeoSeptupletDataset(VIMEO_PARTS, TEST_LISTS, transforms.ToTensor())

    train_loader = _make_loader(train_ds, shuffle=True, num_workers=NUM_WORKERS, drop_last=True)
    val_loader = _make_loader(val_ds, shuffle=False, num_workers=NUM_WORKERS // 2)

    end_epoch = start_epoch + EXTRA_EPOCHS - 1
    best_val = float('inf')

    for epoch in range(start_epoch, end_epoch + 1):
        if epoch in LAMBDA_SCHEDULE:
            criterion.lmbda = float(LAMBDA_SCHEDULE[epoch])
            print(f"\n Lambda updated to {criterion.lmbda:.0f}")

        print(f"Epoch {epoch}/{end_epoch} | λ={criterion.lmbda:.0f} | Eff.Batch={eff_bs}")

        train_metrics = train_joint_epoch(
            models, train_loader, optimizer, criterion,
            raft_transforms, scaler, epoch
        )
        print(f"\n Epoch {epoch} (TRAIN): Loss={train_metrics['loss']:.6f} | MSE={train_metrics['mse']:.6e}")

        # Sync, then refresh the entropy-coder CDF tables (eager modules).
        _cuda_sync()
        safe_update_entropy(models['mvae'], force=True)
        safe_update_entropy(models['rvae'], force=True)

        val_metrics = validate(models, val_loader, criterion, raft_transforms)
        print(f" Epoch {epoch} (VAL): Loss={val_metrics['loss']:.6f} | MSE={val_metrics['mse']:.6e}")

        scheduler.step()

        checkpoint_path = os.path.join(OUTPUT_DIR, f"joint_epoch_{epoch}.pth")
        torch.save({
            'epoch': epoch,
            **_trainable_state_dicts(models),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
        }, checkpoint_path)

        if val_metrics['loss'] < best_val:
            best_val = val_metrics['loss']
            best_path = os.path.join(OUTPUT_DIR, "joint_best.pth")
            torch.save({
                'epoch': epoch,
                **_trainable_state_dicts(models),
                'val_metrics': val_metrics,
            }, best_path)
            print(f" model saved Val loss: {best_val:.6f}")

    print("\n finish !")


if __name__ == "__main__":
    main_joint()
