In [None]:
import os, glob
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [None]:
class CharTokenizer:
    """Character-level tokenizer over uppercase letters, space and apostrophe.

    Id 0 is reserved (``pad_id``); the characters of the alphabet receive
    consecutive ids starting at 1.

    Attributes:
        stoi: dict mapping each character to its integer id (1-based).
        itos: dict mapping each integer id back to its character.
        pad_id: the reserved id 0.
    """

    def __init__(self):
        chars = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ '")  # LRS2 has capital letters and space
        self.stoi = {ch: i + 1 for i, ch in enumerate(chars)}  # Reserve 0 for padding
        # Exact inverse of stoi. The ids stored in stoi are already 1-based, so
        # they must NOT be shifted again here (doing so breaks decode()).
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.pad_id = 0

    def encode(self, text):
        """Return the list of ids for ``text``.

        The text is upper-cased first; characters outside the alphabet are
        dropped silently.
        """
        text = text.upper()
        return [self.stoi[c] for c in text if c in self.stoi]

    def decode(self, indices):
        """Return the string for an iterable of ids.

        Ids <= 0 (padding) and unknown ids are skipped. Each id is passed
        through ``int()`` so that 0-d tensor elements are accepted as well as
        plain Python ints.
        """
        ids = (int(i) for i in indices)
        return ''.join(self.itos.get(i, '') for i in ids if i > 0)

In [None]:
import torch
from torch.utils.data import Dataset
import re

# class LRS2PTDataset(Dataset):
#     def __init__(self, pt_paths, tokenizer):
#         self.pt_paths = pt_paths
#         self.tokenizer = tokenizer

#     def __len__(self):
#         return len(self.pt_paths)

#     def __getitem__(self, idx):
#         try:
#             data = torch.load(self.pt_paths[idx])
#             video = data['video']
#             audio = data['audio']
#             match = re.search(r'Text:\s*(.*?)\n', data['label'])
#             if not match:
#                 raise ValueError("Invalid label format")
#             text = match.group(1).strip()
#             label = torch.tensor(self.tokenizer.encode(text), dtype=torch.long)
#             return video, audio, label
#         except Exception as e:
#             print(f"Skipping file {self.pt_paths[idx]} due to error: {e}")
#             return self.__getitem__((idx + 1) % len(self.pt_paths))
import torch
import re

class LRS2PTDataset(Dataset):
    """Dataset over preprocessed ``.pt`` sample files.

    Each file is expected to hold a dict with the keys ``'video'``, ``'audio'``
    and ``'label'``; the label is a string containing a ``Text: ...`` line.
    Unusable samples (unreadable file, missing ``Text:`` line, digits in the
    transcript, empty or out-of-range encoding) are skipped and the next path
    in ``pt_paths`` is tried instead.

    Args:
        pt_paths: sequence of paths to ``.pt`` files.
        tokenizer: object providing ``encode(text) -> list[int]`` and a
            ``stoi`` mapping (used only for its size).
    """

    def __init__(self, pt_paths, tokenizer):
        self.pt_paths = pt_paths
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.pt_paths)

    def _load_sample(self, path):
        """Load one file and return ``(video, audio, label)``, or ``None`` if it must be skipped."""
        # Load on CPU so that loading never depends on the device the file was saved from.
        data = torch.load(path, map_location="cpu")
        video = data['video']  # assumed (T_v, 1, 96, 96) — TODO confirm against preprocessing
        audio = data['audio']

        # Extract the transcript from the "Text: ..." line of the label.
        match = re.search(r'Text:\s*(.*?)\n', data['label'])
        if not match:
            print(f"Skipping {path}: invalid label format")
            return None
        text = match.group(1).strip()

        # Digits are not part of the tokenizer alphabet; skip such samples quietly.
        if any(c.isdigit() for c in text):
            return None

        # Truncate the label to the number of video frames (first video dim).
        label = self.tokenizer.encode(text)[:video.shape[0]]
        if len(label) == 0:
            print(f"Skipping {path}: empty label after encoding/truncation")
            return None

        # Valid ids are 0 .. len(stoi); anything else is rejected.
        vocab_size = len(self.tokenizer.stoi) + 1
        if any(token >= vocab_size or token < 0 for token in label):
            print(f"Skipping {path}: invalid tokens in label: {label}")
            return None

        return video, audio, torch.tensor(label, dtype=torch.long)

    def __getitem__(self, idx):
        """Return ``(video, audio, label)`` for ``idx`` or for the next usable sample.

        Paths are tried in order starting at ``idx`` and wrapping around, each
        at most once.

        Raises:
            RuntimeError: if no path yields a usable sample. (The previous
                implementation recursed without bound in that situation.)
        """
        total = len(self.pt_paths)
        for offset in range(total):
            path = self.pt_paths[(idx + offset) % total]
            try:
                sample = self._load_sample(path)
            except Exception as e:
                print(f"Skipping {path}: error loading file: {e}")
                continue
            if sample is not None:
                return sample
        raise RuntimeError("LRS2PTDataset: no usable sample found in pt_paths")

# def collate_fn(batch):
#     videos, audios, labels = zip(*batch)
#     max_len_v = max(v.shape[0] for v in videos)
#     max_len_a = max(a.shape[0] for a in audios)
    

#     def pad(x, target_len):
#         if x.shape[0] < target_len:
#             pad_size = (0, 0) * (x.dim() - 1) + (0, target_len - x.shape[0])
#             return torch.nn.functional.pad(x, pad_size)
#         return x
#     # for i in [pad(a, max_len_a) for a in audios]:
#     #     print(i.shape)

#     # print("For videos")
#     # for i in [pad(v, max_len_v) for v in videos]:
#     #     print(i.shape)
#     padded_videos = torch.stack([pad(v, max_len_v) for v in videos])  # (B, T, 1, 96, 96)
#     #padded_videos = padded_videos.permute(0, 2, 1, 3, 4)  # (B, 1, T, 96, 96)
#     padded_audios = torch.stack([pad(a, max_len_a) for a in audios])  # (B, T, 768)
#     padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=0)
#     label_lengths = torch.tensor([len(l) for l in labels])
    
#     #print("returning")
#     #print(padded_videos.shape)
#     return padded_videos, padded_audios, padded_labels, label_lengths
def collate_fn(batch):
    """Collate ``(video, audio, label)`` samples into zero-padded batch tensors.

    Args:
        batch: sequence of ``(video, audio, label)`` tuples. The time axis is
            dim 0 of ``video`` and ``audio``; dim 1 of ``video`` is the channel
            axis; ``label`` is a 1-D tensor of token ids.

    Returns:
        Tuple ``(padded_videos, padded_audios, padded_labels, label_lengths)``:
            padded_videos: videos stacked after zero-padding along time to
                ``max(longest video, longest label)``; 1-channel videos are
                repeated to 3 channels first.
            padded_audios: audios stacked after zero-padding along time to the
                longest audio in the batch.
            padded_labels: labels padded with 0, batch first.
            label_lengths: original label lengths.
    """
    videos, audios, labels = zip(*batch)

    def pad(x, target_len):
        """Zero-pad ``x`` at the end of dim 0 up to ``target_len`` (no-op if already long enough)."""
        missing = target_len - x.shape[0]
        if missing <= 0:
            return x
        # F.pad lists (left, right) pairs starting from the LAST dimension, so
        # the pair for dim 0 comes last; all other dims get (0, 0).
        pad_size = (0, 0) * (x.dim() - 1) + (0, missing)
        return torch.nn.functional.pad(x, pad_size)

    max_len_v = max(v.shape[0] for v in videos)  # Maximum video length
    max_len_a = max(a.shape[0] for a in audios)  # Maximum audio length
    max_len_labels = max(len(l) for l in labels)  # Maximum label length

    # Pad videos to at least max_len_labels to satisfy the CTC requirement
    # (input length >= label length).
    target_video_len = max(max_len_v, max_len_labels)

    # Convert 1-channel to 3-channel if needed
    videos = [v.repeat(1, 3, 1, 1) if v.shape[1] == 1 else v for v in videos]

    padded_videos = torch.stack([pad(v, target_video_len) for v in videos])
    padded_audios = torch.stack([pad(a, max_len_a) for a in audios])
    padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=0)
    label_lengths = torch.tensor([len(l) for l in labels])

    return padded_videos, padded_audios, padded_labels, label_lengths

from torch.utils.data import Sampler
import random
# Bucket sampler: groups samples of similar length into the same batch.
class BucketBatchSampler(Sampler):
    """Batch sampler that groups indices of similar length.

    Indices are sorted by ``lengths`` and cut into consecutive chunks of
    ``batch_size``; the order of the chunks (not their contents) is shuffled
    with the ``random`` module.

    Args:
        lengths: sequence where ``lengths[i]`` is the length of sample ``i``.
        batch_size: number of indices per batch.
        drop_last: if True, a final batch with fewer than ``batch_size``
            indices is discarded. (Previously this flag was stored but had no
            effect.)
    """

    def __init__(self, lengths, batch_size, drop_last=False):
        self.batch_size = batch_size
        self.drop_last = drop_last

        # Sort indices by length
        self.sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i])
        self.batches = [
            self.sorted_indices[i:i+batch_size]
            for i in range(0, len(self.sorted_indices), batch_size)
        ]
        # Only the last chunk can be short, because chunks are cut consecutively.
        if drop_last and self.batches and len(self.batches[-1]) < batch_size:
            self.batches.pop()
        random.shuffle(self.batches)

    def __iter__(self):
        # Reshuffle on every pass so each epoch sees the batches in a new
        # order; shuffle a copy so self.batches is not mutated mid-iteration.
        order = list(self.batches)
        random.shuffle(order)
        for batch in order:
            yield batch

    def __len__(self):
        return len(self.batches)

In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet18
from einops import rearrange

import torch
import torch.nn as nn
from torchvision.models import resnet18
from einops import rearrange
import torch
import torch.nn as nn
from torchvision.models import resnet34
from einops import rearrange

class TimeDistributed(nn.Module):
    """Apply a module independently to every time step of a batched sequence.

    The first two dimensions of the input are treated as (batch, time); they
    are merged into one before calling the wrapped module and split again
    afterwards. Any number of trailing dimensions is supported, both for the
    input and for the wrapped module's output.

    Args:
        module: module applied to the merged ``(B*T, ...)`` tensor.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        """Map ``(B, T, ...)`` to ``(B, T, ...')`` by applying the module per time step."""
        B, T = x.shape[0], x.shape[1]
        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, work too.
        x = x.reshape(B * T, *x.shape[2:])  # (B*T, ...)
        x = self.module(x)  # Apply module to each frame
        return x.reshape(B, T, *x.shape[1:])  # (B, T, ...')

class VisualEncoder(nn.Module):
    """Per-frame visual feature extractor: 3D CNN -> per-frame 2D CNN -> ResNet-34 -> linear.

    The time dimension is preserved throughout (all temporal kernels use
    stride 1 with padding 1, and pooling is spatial only), so the output has
    one feature vector per input frame.

    Args:
        embed_dim: size of the output feature vector per frame.

    Shape comments below assume 96x96 input frames.
    """

    def __init__(self, embed_dim=512):
        super().__init__()
        # 3D CNN backbone. Only the first conv (spatial stride 2) and the two
        # spatial max-pools reduce H and W: 96 -> 48 -> 24 -> 12.
        self.cnn3d = nn.Sequential(
            nn.Conv3d(3, 64, (3, 5, 5), (1, 2, 2), (1, 2, 2)),  # (B, 3, T, 96, 96) -> (B, 64, T, 48, 48)
            nn.BatchNorm3d(64),  # Stabilize activations
            nn.ReLU(),
            nn.Conv3d(64, 128, (3, 3, 3), padding=1),  # (B, 128, T, 48, 48)
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.MaxPool3d((1, 2, 2)),  # (B, 128, T, 24, 24)
            nn.Conv3d(128, 256, (3, 3, 3), padding=1),  # (B, 256, T, 24, 24)
            nn.BatchNorm3d(256),
            nn.ReLU(),
            # Additional 3D CNN layer for deeper spatiotemporal features
            nn.Conv3d(256, 256, (3, 3, 3), padding=1),  # (B, 256, T, 24, 24)
            nn.BatchNorm3d(256),
            nn.ReLU(),
            nn.MaxPool3d((1, 2, 2)),  # (B, 256, T, 12, 12)
        )

        # TimeDistributed 2D CNN for per-frame feature enhancement
        self.timedistributed = TimeDistributed(
            nn.Sequential(
                nn.Conv2d(256, 128, kernel_size=3, padding=1),  # (B*T, 256, 12, 12) -> (B*T, 128, 12, 12)
                nn.BatchNorm2d(128),
                nn.ReLU(),
                nn.Conv2d(128, 128, kernel_size=3, padding=1),  # (B*T, 128, 12, 12)
                nn.BatchNorm2d(128),
                nn.ReLU(),
            )
        )

        # ResNet-34 trunk with randomly initialised weights. Calling resnet34()
        # with no arguments avoids the `pretrained=` keyword, which torchvision
        # deprecated in 0.13, and means "no pretrained weights" on every version.
        resnet = resnet34()
        resnet.conv1 = nn.Conv2d(128, 64, kernel_size=7, stride=2, padding=3, bias=False)  # Adjusted for input channels
        self.resnet = nn.Sequential(*list(resnet.children())[:-2])  # Remove avgpool and fc
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, embed_dim)

        # Initialize weights to stabilize training
        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise every conv, batch-norm and linear layer of this module.

        Iterates ``self.modules()``, so the layers inside the ResNet trunk are
        covered as well.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.Conv2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):  # Input: (B, T, 3, 96, 96)
        """Encode a batch of frame sequences into per-frame embeddings ``(B, T, embed_dim)``."""
        B, T = x.shape[0], x.shape[1]
        x = x.permute(0, 2, 1, 3, 4)  # (B, 3, T, 96, 96): Conv3d expects channels before time
        x = self.cnn3d(x)  # (B, 256, T, 12, 12)
        x = rearrange(x, 'b c t h w -> b t c h w')  # (B, T, 256, 12, 12)
        x = self.timedistributed(x)  # (B, T, 128, 12, 12)
        x = rearrange(x, 'b t c h w -> (b t) c h w')  # (B*T, 128, 12, 12)
        x = self.resnet(x)  # (B*T, 512, H', W')
        x = self.pool(x).squeeze(-1).squeeze(-1)  # (B*T, 512)
        x = rearrange(x, '(b t) c -> b t c', b=B, t=T)  # (B, T, 512)
        return self.fc(x)  # (B, T, embed_dim)

# class VisualEncoder(nn.Module):
#     def __init__(self, embed_dim=512):
#         super().__init__()
#         self.cnn3d = nn.Sequential(
#             nn.Conv3d(3, 64, (3, 5, 5), (1, 2, 2), (1, 2, 2)),  # (B, 3, T, 96, 96) -> (B, 64, T, 24, 24)
#             nn.ReLU(),
#             nn.Conv3d(64, 128, (3, 3, 3), padding=1),           # (B, 128, T, 24, 24)
#             nn.ReLU(),
#             nn.MaxPool3d((1, 2, 2)),                            # (B, 128, T, 12, 12)
#             nn.Conv3d(128, 256, (3, 3, 3), padding=1),
#             nn.ReLU(),
#         )
        
#         # Feeding into ResNet
#         resnet = resnet18(pretrained=False)
#         resnet.conv1 = nn.Conv2d(256, 64, kernel_size=7, stride=2, padding=3, bias=False)
#         self.resnet = nn.Sequential(*list(resnet.children())[:-2])  # No avgpool/fc
#         self.pool = nn.AdaptiveAvgPool2d((1, 1))
#         self.fc = nn.Linear(512, embed_dim)

#     def forward(self, x):  # (B, T, 3, 96, 96)
#         B, T = x.shape[0], x.shape[1]
#         x = x.permute(0, 2, 1, 3, 4)       # (B, 3, T, 96, 96)
#         x = self.cnn3d(x)                  # (B, 256, T, H, W)
#         x = rearrange(x, 'b c t h w -> (b t) c h w')
#         x = self.resnet(x)                # (B*T, 512, H', W')
#         x = self.pool(x).squeeze(-1).squeeze(-1)  # (B*T, 512)
#         x = rearrange(x, '(b t) c -> b t c', b=B, t=T)
#         return self.fc(x)  # (B, T, 512)

In [None]:
# import torch.nn as nn

# class SyncVSR(nn.Module):
#     def __init__(self, visual_encoder, vocab_size, audio_dim=768):
#         super().__init__()
#         self.visual_encoder = visual_encoder
#         self.temporal = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=512, nhead=8), num_layers=4)
#         self.ctc_head = nn.Linear(512, vocab_size)
#         self.sync_head = nn.Linear(512, audio_dim)
#         # print("initialized")

#     def forward(self, video):
#         x = self.visual_encoder(video)          # (B, T, 512)
#         x = self.temporal(x.permute(1, 0, 2))   # (T, B, 512)
#         x = x.permute(1, 0, 2)                  # (B, T, 512)
#         ctc_logits = self.ctc_head(x)
#         audio_pred = self.sync_head(x)
#         return ctc_logits, audio_pred
# #"The problem is here we are not receiving (B, T, 512) but rather torch.Size([B, 1, T, 96, 96])"
class SyncVSR(nn.Module):
    """Visual speech recognition model with a CTC head and an audio-prediction head.

    Frame features from ``visual_encoder`` receive a learned positional
    encoding, pass through a Transformer encoder and a LayerNorm, and are then
    projected by two linear heads.

    Args:
        visual_encoder: module mapping the video batch to ``(B, T, 512)`` features.
        vocab_size: number of CTC output classes.
        audio_dim: size of the per-frame audio prediction.
        max_len: number of rows in the learned positional table, i.e. the
            longest sequence (in frames) the model accepts. Defaults to 500,
            the previously hard-coded value, so existing checkpoints load
            unchanged.
    """

    def __init__(self, visual_encoder, vocab_size, audio_dim=768, max_len=500):
        super().__init__()
        self.visual_encoder = visual_encoder
        self.max_len = max_len
        self.pos_encoding = nn.Parameter(torch.randn(max_len, 512))  # (T_max, 512)
        encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, dropout=0.1)
        self.temporal = nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.norm = nn.LayerNorm(512)
        self.ctc_head = nn.Linear(512, vocab_size)
        self.sync_head = nn.Linear(512, audio_dim)

    def forward(self, video):  # (B, T, 3, 96, 96)
        """Return ``(ctc_logits, audio_pred)`` of shapes ``(B, T, vocab_size)`` and ``(B, T, audio_dim)``.

        Raises:
            ValueError: if the sequence is longer than the positional table.
                Without this check the addition below fails with an opaque
                broadcasting error.
        """
        x = self.visual_encoder(video)       # (B, T, 512)
        T = x.size(1)
        table_len = self.pos_encoding.size(0)
        if T > table_len:
            raise ValueError(
                f"Sequence length {T} exceeds the positional encoding table ({table_len} frames)"
            )
        x = x + self.pos_encoding[:T]        # Add temporal pos encoding
        # The encoder layer is built with the default batch_first=False, so it
        # expects time-major input.
        x = self.temporal(x.permute(1, 0, 2))  # (T, B, 512)
        x = self.norm(x).permute(1, 0, 2)      # (B, T, 512)
        ctc_logits = self.ctc_head(x)         # (B, T, vocab_size)
        audio_pred = self.sync_head(x)        # (B, T, audio_dim)
        return ctc_logits, audio_pred

In [None]:


device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load data
tokenizer = CharTokenizer()
# glob returns files in arbitrary order; sort so indices (and therefore the
# length buckets) are reproducible between runs.
pt_paths = sorted(glob.glob('/kaggle/input/sync-vsr-better/preprocessed/lrs2_subset_all2/*/*.pt'))
dataset = LRS2PTDataset(pt_paths, tokenizer)

# Get video lengths (number of frames) for bucketing. Every file is loaded
# once here, so this pass can be slow on large datasets.
video_lengths = []
for path in dataset.pt_paths:
    try:
        data = torch.load(path, map_location="cpu")
        video_lengths.append(data['video'].shape[0])  # T (frames)
    except Exception as e:
        # Unreadable files get length 0 so indices stay aligned with pt_paths.
        # Catch Exception (not a bare except) so that Ctrl-C / kernel
        # interrupts still stop the loop.
        print(f"Could not read video length from {path}: {e}")
        video_lengths.append(0)  # Just to avoid crashing

sampler = BucketBatchSampler(video_lengths, batch_size=4)
loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_fn)

print("wow")
# Build model
model = SyncVSR(VisualEncoder(), vocab_size=len(tokenizer.stoi)+1).to(device)
print("Entered training")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
lambda_sync = 1.0  # weight of the audio-sync (MSE) loss relative to the CTC loss
counter = 0        # global step counter used by the training loop for periodic logging

In [None]:
import glob
import json
import os

# Directory for saving checkpoints
# NOTE(review): this path is under /kaggle/input, which is normally a
# read-only dataset mount on Kaggle — saving checkpoints here may fail.
# TODO confirm, and consider a separate writable output directory.
checkpoint_dir = '/kaggle/input/checkpoint-syncvsr/pytorch/default/1/checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Initialize start epoch
start_epoch = 0


def _checkpoint_epoch(path):
    """Return the epoch number encoded in 'metadata_epoch_<N>.json', or -1 if it cannot be parsed."""
    stem = os.path.basename(path).split('_')[-1].split('.')[0]
    return int(stem) if stem.isdigit() else -1


# Find all metadata files and pick the one with the highest epoch number.
# File timestamps are not used: copied or mounted files can share the same
# ctime or have it reset, which makes "latest by ctime" unreliable.
metadata_files = [
    path
    for path in glob.glob(os.path.join(checkpoint_dir, 'metadata_epoch_*.json'))
    if _checkpoint_epoch(path) >= 0
]
if metadata_files:
    latest_metadata_file = max(metadata_files, key=_checkpoint_epoch)
    epoch_num = _checkpoint_epoch(latest_metadata_file)

    # Build file paths
    model_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch_num}.pt')
    optimizer_path = os.path.join(checkpoint_dir, f'optimizer_epoch_{epoch_num}.pt')

    # Load model and optimizer state dicts
    model.load_state_dict(torch.load(model_path, map_location=device))
    optimizer.load_state_dict(torch.load(optimizer_path, map_location=device))

    # Load metadata; the stored 'epoch' is the number of completed epochs.
    with open(latest_metadata_file, 'r') as f:
        metadata = json.load(f)
        start_epoch = metadata['epoch']

    print(f"✅ Loaded checkpoint: Epoch {start_epoch}")
else:
    print("⚠️ No checkpoint found. Starting fresh.")

In [None]:
import json
# Train up to num_epochs in total. Starting at start_epoch (set by the
# checkpoint cell; 0 for a fresh run) makes a resumed run continue where it
# stopped instead of restarting at epoch 0 and overwriting checkpoints.
num_epochs = 50
for epoch in range(start_epoch, num_epochs):
    print("Entered training")
    model.train()
    print("training")
    total_loss = 0
    num_batches = 0  # batches that actually contributed an optimizer step
    vocab_size = len(tokenizer.stoi)+1
    for batch in loader:
        video, audio, labels, label_lens = [x.to(device) for x in batch]
        ctc_logits, audio_preds = model(video)  # (B, T, Vocab), (B, T, audio_dim)

        # Every sample is given the full (padded) time length as CTC input
        # length, because the collate function does not return per-sample
        # video lengths.
        input_lengths = torch.full((video.size(0),), ctc_logits.size(1), dtype=torch.long, device=device)

        # CTC needs input length >= label length; skip the batch otherwise.
        if (input_lengths < label_lens).any():
            print(f"input_lengths: {input_lengths},\t label_lens: {label_lens}")
            print("💥 Skipping batch due to input < label length")
            print(f"video.size(0): {video.size(0)}")
            print(f"Video shape: {video.shape}")
            print(f"CTC logits shape: {ctc_logits.shape}")
            print(f"Input lengths: {input_lengths}")
            print(f"Label lengths: {label_lens}")
            print(f"Label shape: {labels.shape}")
            continue

        # Align the predicted audio features with the target along time:
        # zero-pad the prediction if it is shorter, crop it if it is longer.
        if audio_preds.size(1) < audio.size(1):
            pad_size = audio.size(1) - audio_preds.size(1)
            audio_preds = F.pad(audio_preds, (0, 0, 0, pad_size))  # pad along time dimension
        elif audio_preds.size(1) > audio.size(1):
            audio_preds = audio_preds[:, :audio.size(1), :]

        # Diagnostics only; the non-finite loss check below decides whether
        # the batch is skipped.
        if not torch.isfinite(ctc_logits).all():
            print("CTC logits contain NaN or Inf!")
        if torch.any(labels >= vocab_size):
            print("Label has index out of vocab bounds!")

        # Losses. F.ctc_loss expects time-major log-probabilities (T, B, C).
        ctc = F.ctc_loss(ctc_logits.log_softmax(2).transpose(0, 1), labels, input_lengths, label_lens, blank=0)
        sync = F.mse_loss(audio_preds, audio)
        loss = ctc + lambda_sync * sync
        if not torch.isfinite(loss):
            # Skip batches with a NaN/Inf loss so they cannot corrupt the weights.
            print()
            continue

        if counter%100 == 0:
            print(f"loss = {ctc} + {lambda_sync} * {sync}")
            print(f"Video shape: {video.shape}")
            print(f"Audio shape: {audio.shape}")
            print(f"Label lengths: {label_lens}")
            print(f"Input lengths: {input_lengths}")
            print(f"CTC logits stats: min={ctc_logits.min().item()}, max={ctc_logits.max().item()}")

        counter+=1
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Average over the batches that were actually optimised; skipped batches
    # add nothing to total_loss and must not dilute the mean.
    avg_loss = total_loss / max(num_batches, 1)

    # Save a checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        model_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch+1}.pt')
        optimizer_path = os.path.join(checkpoint_dir, f'optimizer_epoch_{epoch+1}.pt')
        metadata_path = os.path.join(checkpoint_dir, f'metadata_epoch_{epoch+1}.json')

        # Save model and optimizer separately
        torch.save(model.state_dict(), model_path)
        torch.save(optimizer.state_dict(), optimizer_path)

        # Save metadata in a readable JSON format
        metadata = {
            'epoch': epoch + 1,
            'loss': avg_loss
        }
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=4)

        print(f"Saved model: {model_path}")
        print(f"Saved optimizer: {optimizer_path}")
        print(f"Saved metadata: {metadata_path}")

    print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f}")

In [None]:
# import json
# import random
# import torch
# import torch.nn.functional as F
# from torch.cuda.amp import GradScaler, autocast
# from torch.optim.lr_scheduler import ReduceLROnPlateau

# def greedy_ctc_decode(logits, tokenizer):
#     """Greedy CTC decoding: select the most likely token at each time step."""
#     # logits: (T, vocab_size)
#     predicted_ids = torch.argmax(logits, dim=-1)  # (T,)
#     # Collapse repeated tokens and remove blanks (0)
#     output = []
#     prev_id = None
#     for id in predicted_ids:
#         if id != 0 and id != prev_id:  # Skip blanks and repeats
#             output.append(id.item())
#         prev_id = id
#     return tokenizer.decode(output)

# def run_inference(model, dataset, tokenizer, device, sample_idx=None):
#     """Run inference on a single sample and return predicted/ground truth text."""
#     model.eval()
#     with torch.no_grad():
#         # Select a random sample if no index is provided
#         idx = sample_idx if sample_idx is not None else random.randint(0, len(dataset) - 1)
#         try:
#             video, _, label = dataset[idx]  # Only need video and label for inference
#             video = video.unsqueeze(0).to(device)  # (1, T, 3, 96, 96)
#             ctc_logits, _ = model(video)  # (1, T, vocab_size)
#             ctc_logits = ctc_logits.squeeze(0)  # (T, vocab_size)
            
#             # Decode prediction
#             pred_text = greedy_ctc_decode(ctc_logits, tokenizer)
#             # Ground truth
#             gt_text = tokenizer.decode(label.tolist())
#             return pred_text, gt_text, idx
#         except Exception as e:
#             print(f"Inference failed for sample {idx}: {e}")
#             return None, None, idx

# # Training setup (assumes prior code for dataset, model, etc.)
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# scaler = GradScaler()
# scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

# # Fixed sample for consistent inference (optional, set to None for random)
# fixed_inference_idx = None  # Change to an index (e.g., 0) for a specific sample

# for epoch in range(start_epoch, 50):
#     print(f"Epoch {epoch+1} started")
#     model.train()
#     total_loss = 0
#     vocab_size = len(tokenizer.stoi) + 1
#     for batch_idx, batch in enumerate(loader):
#         video, audio, labels, label_lens = [x.to(device) for x in batch]
#         input_lengths = torch.full((video.size(0),), video.size(1), dtype=torch.long).to(device)

#         optimizer.zero_grad(set_to_none=True)
#         with autocast():
#             ctc_logits, audio_preds = model(video)
#             if torch.any(torch.isnan(ctc_logits)) or torch.any(torch.isinf(ctc_logits)):
#                 print(f"Batch {batch_idx}: NaN/Inf in CTC logits")
#                 continue

#             # Adjust audio lengths
#             min_len = min(audio_preds.size(1), audio.size(1))
#             audio_preds = audio_preds[:, :min_len, :]
#             audio = audio[:, :min_len, :]

#             ctc = F.ctc_loss(ctc_logits.log_softmax(2).transpose(0, 1), labels, input_lengths, label_lens, blank=0)
#             sync = F.mse_loss(audio_preds, audio)
#             loss = ctc + lambda_sync * sync

#         if torch.isnan(loss) or torch.isinf(loss):
#             print(f"Batch {batch_idx}: NaN/Inf Loss: ctc={ctc.item()}, sync={sync.item()}")
#             continue

#         scaler.scale(loss).backward()
#         scaler.unscale_(optimizer)
#         torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
#         scaler.step(optimizer)
#         scaler.update()

#         total_loss += loss.item()
#         if batch_idx % 100 == 0:
#             print(f"Batch {batch_idx}: Loss = {loss.item():.4f} (CTC: {ctc.item():.4f}, Sync: {sync.item():.4f})")
#             print(f"Video shape: {video.shape}, Label lengths: {label_lens}, Input lengths: {input_lengths}")

#     avg_loss = total_loss / len(loader)
#     scheduler.step(avg_loss)

#     # Inference after each epoch
#     pred_text, gt_text, sample_idx = run_inference(model, dataset, tokenizer, device, fixed_inference_idx)
#     if pred_text is not None:
#         print(f"\n[Inference after Epoch {epoch+1}] Sample {sample_idx}")
#         print(f"Predicted: {pred_text}")
#         print(f"Ground Truth: {gt_text}\n")
#     else:
#         print(f"\n[Inference after Epoch {epoch+1}] Failed for sample {sample_idx}\n")

#     # Save checkpoints every 5 epochs
#     if (epoch + 1) % 10 == 0:
#         model_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch+1}.pt')
#         optimizer_path = os.path.join(checkpoint_dir, f'optimizer_epoch_{epoch+1}.pt')
#         metadata_path = os.path.join(checkpoint_dir, f'metadata_epoch_{epoch+1}.json')
        
#         torch.save(model.state_dict(), model_path)
#         torch.save(optimizer.state_dict(), optimizer_path)
#         metadata = {'epoch': epoch + 1, 'loss': avg_loss}
#         with open(metadata_path, 'w') as f:
#             json.dump(metadata, f, indent=4)
        
#         print(f"Saved: {model_path}, {optimizer_path}, {metadata_path}")

#     print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f}")