In [1]:
import sys
sys.path.append('../')
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader
import torchinfo
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from src.AVDataset import AVDataset

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

if torch.cuda.is_available():
    print(os.environ['CUDA_VISIBLE_DEVICES'])
    for i in range(torch.cuda.device_count()):
       print(torch.cuda.get_device_properties(i).name)

  from .autonotebook import tqdm as notebook_tqdm


cuda:0
0,1,2,3
NVIDIA GeForce GTX 1080 Ti
NVIDIA GeForce GTX 1080 Ti
NVIDIA GeForce GTX 1080 Ti
NVIDIA GeForce GTX 1080 Ti


In [2]:
# Dataset
# =======
audio_root = r"../data/audios_denoised_16khz"
video_root = r"../data/dataset_2drt_video_only"
nSubs = [f"sub{str(i).zfill(3)}" for i in range(1, 2)]
keyword = "vcv"
dataset = AVDataset(audio_root=audio_root, 
                    video_root=video_root, 
                    subs=nSubs, 
                    filter_keyword=keyword, 
                    video_max_frames=None, # batch
                    audio_sampling_rate=16000,
                    frame_skip=1)

print("Number of pairs:", len(dataset))

dataloader = DataLoader(dataset, batch_size=2, shuffle=False, pin_memory=True, collate_fn=AVDataset.collate) # Batch / Collation

for i, (waveform, frames, audio_path, video_path) in enumerate(dataloader):
        print("Audio shape:", waveform.shape) 
        print("Video frames shape:", frames.shape)
        print("Audio file:", audio_path)
        print("Video file:", video_path)
        print("===========")
        if i > 1:
            break
    
"""
Audio shape: torch.Size([5, 619695])
Video frames shape: torch.Size([5, 3216, 1, 84, 84]) ... etc
"""

Number of pairs: 6
Audio shape: torch.Size([2, 619695])
Video frames shape: torch.Size([2, 3216, 1, 84, 84])
Audio file: ['../data/audios_denoised_16khz/sub001/sub001_2drt_01_vcv1_r1_video.wav', '../data/audios_denoised_16khz/sub001/sub001_2drt_02_vcv2_r2_video.wav']
Video file: ['../data/dataset_2drt_video_only/sub001/2drt/video/sub001_2drt_01_vcv1_r1_video.mp4', '../data/dataset_2drt_video_only/sub001/2drt/video/sub001_2drt_02_vcv2_r2_video.mp4']
Audio shape: torch.Size([2, 569911])
Video frames shape: torch.Size([2, 2960, 1, 84, 84])
Audio file: ['../data/audios_denoised_16khz/sub001/sub001_2drt_02_vcv2_r1_video.wav', '../data/audios_denoised_16khz/sub001/sub001_2drt_03_vcv3_r2_video.wav']
Video file: ['../data/dataset_2drt_video_only/sub001/2drt/video/sub001_2drt_02_vcv2_r1_video.mp4', '../data/dataset_2drt_video_only/sub001/2drt/video/sub001_2drt_03_vcv3_r2_video.mp4']
Audio shape: torch.Size([2, 564710])
Video frames shape: torch.Size([2, 2932, 1, 84, 84])
Audio file: ['../data/a

'\nAudio shape: torch.Size([5, 619695])\nVideo frames shape: torch.Size([5, 3216, 1, 84, 84]) ... etc\n'

In [17]:
# Simple Unet for image
# =====================

class Simple_UNet(nn.Module):
    def __init__(self, base_channels=32):
        super(Simple_UNet, self).__init__()
        
        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv2d(1, base_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.enc2 = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(base_channels, base_channels*2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels*2, base_channels*2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        
        # Bridge
        self.bridge = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(base_channels*2, base_channels*4, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels*4, base_channels*4, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        
        # Decoder
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(base_channels*4, base_channels*2, kernel_size=2, stride=2), # can be a bilinear interpolation too
            nn.Conv2d(base_channels*2, base_channels*2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels*2, base_channels*2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(base_channels*2, base_channels, kernel_size=2, stride=2),
            nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        
        # Output
        self.final = nn.Conv2d(base_channels, 1, kernel_size=1)
        
    def forward(self, x):
        # x shape: (batch_size, num_frames, channels, height, width)#
        batch_size, num_frames = x.shape[0], x.shape[1]
        
        # Reshape to process all frames at once -> (batch_size*num_frames, channels, height, width)
        x = x.reshape(-1, x.shape[2], x.shape[3], x.shape[4])
        
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        
        # Bridge
        b = self.bridge(e2)
        
        # Decoder
        d1 = self.dec1(b)
        d2 = self.dec2(d1)
        
        # Output
        out = self.final(d2)
        
        # Reshape back to include frames dimension
        out = out.view(batch_size, num_frames, -1, out.shape[2], out.shape[3])
        
        return out

In [18]:
# First AV denosiing model
# ========================

class AVModel(nn.Module):
    def __init__(self, base_channels=32):
        super().__init__()
        self.unet = Simple_UNet(base_channels=base_channels)
        
        # Projection layers for audio features
        self.audio_proj = nn.Sequential(
            nn.Linear(768, 256),  # 768 is wav2vec2 feature dim
            nn.ReLU(inplace=True),
            nn.Linear(256, 84*84)  # Match video frame size
        )
        
    def forward(self, audio_features, video_frames):
        # Process video through UNet
        video_output = self.unet(video_frames)  # [B, T, 1, H, W]
        
        # Project audio features to match video dimensions
        B, T, _ = audio_features.shape
        audio_proj = self.audio_proj(audio_features)  # [B, T, H*W]
        audio_proj = audio_proj.view(B, T, 1, 84, 84)  # Match video dimensions

        return video_output, audio_proj
    

class AVModel2(nn.Module):
    def __init__(self, base_channels=32):
        super().__init__()
        
        # Projection layers for audio features
        self.audio_proj = nn.Sequential(
            nn.Linear(768, 256),  # 768 is wav2vec2 feature dim
            nn.ReLU(inplace=True),
            nn.Linear(256, 84*84)  # Match video frame size
        )
        
    def forward(self, audio_features):
        
        # Project audio features to match video dimensions
        B, T, _ = audio_features.shape
        audio_proj = self.audio_proj(audio_features)  # [B, T, H*W]
        audio_proj = audio_proj.view(B, T, 1, 84, 84)  # Match video dimensions

        return audio_proj
    

class AVModel3(nn.Module):
    def __init__(self, base_channels=32):
        super().__init__()
        self.unet = Simple_UNet(base_channels=base_channels)
        
        # Projection layers for audio features
        self.audio_proj = nn.Sequential(
            nn.Linear(768, 256),  # 768 is wav2vec2 feature dim
            nn.ReLU(inplace=True),
            nn.Linear(256, 84*84)  # Match video frame size
        )
        
    def forward(self, audio_features, video_frames):
        # Process video through UNet
        window_size = 5
        T = video_frames.shape[1]
        overlap = 1
        stride = window_size - overlap
        video_outputs = []

        # Process frames in sliding windows
        for start_idx in range(0, T, stride):
            end_idx = min(start_idx + window_size, T)
            

            window_output = self.unet(video_frames[:, start_idx:end_idx])
            
            if start_idx == 0:
                video_outputs.append(window_output[:, :-overlap] if end_idx < T else window_output)
            else:
                video_outputs.append(window_output[:, overlap:-overlap] if end_idx < T else window_output[:, overlap:])
            
            # Free memory
            torch.cuda.empty_cache()

        # Combine all windows
        video_output = torch.cat(video_outputs, dim=1)
        
        # Project audio features to match video dimensions
        B, T, _ = audio_features.shape
        audio_proj = self.audio_proj(audio_features)  # [B, T, H*W]
        audio_proj = audio_proj.view(B, T, 1, 84, 84)  # Match video dimensions

        return video_output, audio_proj

In [20]:
# Training 1 DataParallel (discouraged)
# ========
torch.backends.cuda.max_memory_allocated = 0  # Reset memory stats
torch.cuda.empty_cache()  # Clear GPU cache
adev = torch.device('cuda:0')
mdev = torch.device('cuda:0')
m_ids = [1, 2, 0, 3]

wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").to(adev)
wav2vec2.eval()

model = AVModel3()
# model= nn.DataParallel(model, device_ids=m_ids) #, output_device=3)
model.to(mdev)
model.train()
    
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01) # more consistent regularization 
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Loss functions
mse_loss = nn.MSELoss()
cosine_loss = nn.CosineSimilarity(dim=-1)

def train_step(audio_features, video_frames):
    # optimizer.zero_grad()
    for param in model.parameters():
        param.grad = None
    
    # Forward pass
    video_output, audio_proj = model(audio_features, video_frames)
    
    # Compute losses
    # MSE between UNet output and original frames
    # print(video_output.shape)
    reconstruction_loss = 1 #mse_loss(video_output, video_frames)
    
    # Cosine similarity between audio projection and video features
    # Reshape tensors for cosine similarity
    # v_flat = video_output.view(video_output.shape[0], video_output.shape[1], -1)
    # a_flat = audio_proj.view(audio_proj.shape[0], audio_proj.shape[1], -1)
    alignment_loss = 1 #-cosine_loss(v_flat, a_flat).mean()
    
    # Combined loss
    total_loss = reconstruction_loss + 0.5 * alignment_loss
    
    # Backward pass
    # total_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # mitigate exploding gradients
    optimizer.step()
    
    # return total_loss.item(), reconstruction_loss.item(), alignment_loss.item()
    return total_loss, reconstruction_loss, alignment_loss

# Training loop
n_epochs = 50
best_loss = float('inf')

for epoch in range(n_epochs):
    epoch_losses = []
    
    for batch_idx, (waveform, frames, _, _) in enumerate(dataloader):
        # Process audio through wav2vec2
        with torch.no_grad():
            inputs = wav2vec2_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
            inputs = {k: v.to(adev) for k, v in inputs.items()}
            audio_features = wav2vec2(**inputs).last_hidden_state
        
        # Move video frames to device
        frames = frames.to(mdev)
        
        # Train step
        total_loss, rec_loss, align_loss = train_step(audio_features, frames)
        epoch_losses.append(total_loss)
        
        if batch_idx % 10 == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}: Total Loss: {total_loss:.4f}, Rec Loss: {rec_loss:.4f}, Align Loss: {align_loss:.4f}")
    
    # Update learning rate
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    scheduler.step(avg_loss)
    
    # Save best model
    # if avg_loss < best_loss:
    #     best_loss = avg_loss
    #     torch.save(model.state_dict(), 'best_av_model.pth')

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 10.90 GiB of which 11.25 MiB is free. Including non-PyTorch memory, this process has 10.89 GiB memory in use. Of the allocated memory 10.59 GiB is allocated by PyTorch, and 142.43 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [9]:
# Training 1 DataParallel (discouraged) ----------------
# ========
def process_video_in_chunks(unet, video_frames, chunk_size=100):
    """
    Process long video sequences in smaller chunks to avoid memory issues.
    Args:
        audio_features: Tensor [B, T, D]
        video_frames: Tensor [B, T, C, H, W]
        chunk_size: Number of frames to process at once
    """
    B, T = video_frames.shape[:2]
    outputs = []
    
    for start_idx in range(0, T, chunk_size):
        end_idx = min(start_idx + chunk_size, T)
        print(video_frames[:, start_idx:end_idx].shape)
        # Process chunk
        v_out = unet(video_frames[:, start_idx:end_idx])
        
        outputs.append(v_out)
        
        # Free memory
        torch.cuda.empty_cache()
    
    # Concatenate chunks along temporal dimension
    video_output = torch.cat(outputs, dim=1)
    
    return video_output

torch.backends.cuda.max_memory_allocated = 0  # Reset memory stats
torch.cuda.empty_cache()  # Clear GPU cache
adev = torch.device('cuda:0')
mdev = torch.device('cuda:1')
m_ids = [1, 0, 2, 3]

wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").to(adev)
wav2vec2.eval()

amod = AVModel2()
# model= nn.DataParallel(model, device_ids=m_ids) #, output_device=3)
amod.to(adev)

model = Simple_UNet(base_channels=32)
model = nn.DataParallel(model, device_ids=m_ids) #, output_device=3)
model.to(mdev)
model.train()
    
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01) # more consistent regularization 
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Loss functions
mse_loss = nn.MSELoss()
cosine_loss = nn.CosineSimilarity(dim=-1)

def train_step(audio_features, video_frames):
    # optimizer.zero_grad()
    for param in model.parameters():
        param.grad = None
    
    # Forward pass
    
    video_output = process_video_in_chunks(model, video_frames)
    audio_proj = amod(audio_features)
    
    # Compute losses
    # MSE between UNet output and original frames
    reconstruction_loss = mse_loss(video_output, video_frames)
    
    # Cosine similarity between audio projection and video features
    # Reshape tensors for cosine similarity
    v_flat = video_output.view(video_output.shape[0], video_output.shape[1], -1)
    a_flat = audio_proj.view(audio_proj.shape[0], audio_proj.shape[1], -1)
    alignment_loss = -cosine_loss(v_flat, a_flat).mean()
    
    # Combined loss
    total_loss = reconstruction_loss + 0.5 * alignment_loss
    
    # Backward pass
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # mitigate exploding gradients
    optimizer.step()
    
    return total_loss.item(), reconstruction_loss.item(), alignment_loss.item()

# Training loop
n_epochs = 50
best_loss = float('inf')

for epoch in range(n_epochs):
    epoch_losses = []
    
    for batch_idx, (waveform, frames, _, _) in enumerate(dataloader):
        # Process audio through wav2vec2
        with torch.no_grad():
            inputs = wav2vec2_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
            inputs = {k: v.to(adev) for k, v in inputs.items()}
            audio_features = wav2vec2(**inputs).last_hidden_state
        
        # Move video frames to device
        frames = frames.to(mdev)
        
        # Train step
        total_loss, rec_loss, align_loss = train_step(audio_features, frames)
        epoch_losses.append(total_loss)
        
        if batch_idx % 10 == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}: Total Loss: {total_loss:.4f}, Rec Loss: {rec_loss:.4f}, Align Loss: {align_loss:.4f}")
    
    # Update learning rate
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    scheduler.step(avg_loss)
    
    # Save best model
    # if avg_loss < best_loss:
    #     best_loss = avg_loss
    #     torch.save(model.state_dict(), 'best_av_model.pth')

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 10.90 GiB of which 17.25 MiB is free. Including non-PyTorch memory, this process has 10.88 GiB memory in use. Of the allocated memory 10.63 GiB is allocated by PyTorch, and 52.49 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [16]:
# Training 2 (sub batch, reducing mem usage)
# ========

torch.backends.cuda.max_memory_allocated = 0  # Reset memory stats
torch.cuda.empty_cache()  # Clear GPU cache
adev = torch.device('cuda:0')
mdev = torch.device('cuda:0')
m_ids = [0,1,2,3]

wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").to(adev)
wav2vec2.eval()

model = AVModel()
model= nn.DataParallel(model, device_ids=m_ids) #, output_device=3)
model.to(mdev)
model.train()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01) # more consistent regularization 
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Loss functions
mse_loss = nn.MSELoss()
cosine_loss = nn.CosineSimilarity(dim=-1)

# Modified training loop with memory optimizations
n_epochs = 50
best_loss = float('inf')
grad_accumulation_steps = 2  # Gradient accumulation to reduce memory usage

for epoch in range(n_epochs):
    epoch_losses = []
    # optimizer.zero_grad()
    for param in model.parameters():
        param.grad = None
    
    for batch_idx, (waveform, frames, _, _) in enumerate(dataloader):
        # Process in smaller chunks if needed
        batch_size = waveform.shape[0]
        sub_batch_size = 1  # Process 2 samples at a time
        
        total_loss = 0
        rec_loss = 0
        align_loss = 0
        
        for i in range(0, batch_size, sub_batch_size):
            end_idx = min(i + sub_batch_size, batch_size)
            sub_waveform = waveform[i:end_idx]
            sub_frames = frames[i:end_idx].to(mdev)
            
            # Process audio through wav2vec2
            with torch.no_grad():
                inputs = wav2vec2_processor(sub_waveform.numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
                inputs = {k: v.to(adev) for k, v in inputs.items()}
                audio_features = wav2vec2(**inputs).last_hidden_state
            
            # Forward pass
            video_output, audio_proj = model(audio_features, sub_frames)
            
            # Compute losses
            batch_rec_loss = mse_loss(video_output, sub_frames)
            
            # Reshape tensors for cosine similarity
            v_flat = video_output.view(video_output.shape[0], video_output.shape[1], -1)
            a_flat = audio_proj.view(audio_proj.shape[0], audio_proj.shape[1], -1)
            batch_align_loss = -cosine_loss(v_flat, a_flat).mean()
            
            # Combined loss
            batch_total_loss = batch_rec_loss + 0.5 * batch_align_loss
            
            # Scale loss and backward pass
            scaled_loss = batch_total_loss / grad_accumulation_steps
            scaled_loss.backward()
            
            # Accumulate losses
            total_loss += batch_total_loss.item()
            rec_loss += batch_rec_loss.item()
            align_loss += batch_align_loss.item()
            
            # Clear memory
            del video_output, audio_proj, v_flat, a_flat
            torch.cuda.empty_cache()
        
        # Update weights after accumulating gradients
        if (batch_idx + 1) % grad_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
        
        # Average losses over sub-batches
        total_loss /= (batch_size / sub_batch_size)
        rec_loss /= (batch_size / sub_batch_size)
        align_loss /= (batch_size / sub_batch_size)
        
        epoch_losses.append(total_loss)
        
        if batch_idx % 10 == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}: Total Loss: {total_loss:.4f}, Rec Loss: {rec_loss:.4f}, Align Loss: {align_loss:.4f}")
    
    # Update learning rate
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    scheduler.step(avg_loss)

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.71 GiB. GPU 0 has a total capacity of 10.90 GiB of which 737.25 MiB is free. Including non-PyTorch memory, this process has 10.18 GiB memory in use. Of the allocated memory 9.63 GiB is allocated by PyTorch, and 284.03 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
def process_with_sliding_window(model, audio_features, video_frames, window_size=100, overlap=10):
    """
    Process video frames using sliding window approach to reduce memory usage.
    
    Args:
        model: AVModel instance
        audio_features: tensor of shape [B, T, 768]
        video_frames: tensor of shape [B, T, C, H, W]
        window_size: number of frames to process at once
        overlap: number of overlapping frames between windows
    """
    B, T, C, H, W = video_frames.shape
    stride = window_size - overlap
    
    # Initialize output tensors
    video_outputs = []
    audio_projs = []
    
    # Process each batch independently
    for b in range(B):
        batch_video_outputs = []
        batch_audio_projs = []
        
        # Process temporal chunks with sliding window
        for start_idx in range(0, T, stride):
            end_idx = min(start_idx + window_size, T)
            
            # Extract temporal chunk
            video_chunk = video_frames[b:b+1, start_idx:end_idx]
            audio_chunk = audio_features[b:b+1, start_idx:end_idx]
            
            # Process chunk through model
            with torch.amp.autocast("cuda"):  # Use mixed precision
                v_out, a_proj = model(audio_chunk, video_chunk)
            
            # Store results
            if start_idx == 0:
                # For first chunk, keep all frames
                batch_video_outputs.append(v_out[:, :-overlap] if end_idx < T else v_out)
                batch_audio_projs.append(a_proj[:, :-overlap] if end_idx < T else a_proj)
            else:
                # For subsequent chunks, remove overlapping frames from the start
                batch_video_outputs.append(v_out[:, overlap:-overlap] if end_idx < T else v_out[:, overlap:])
                batch_audio_projs.append(a_proj[:, overlap:-overlap] if end_idx < T else a_proj[:, overlap:])
        
        # Concatenate temporal chunks
        video_outputs.append(torch.cat(batch_video_outputs, dim=1))
        audio_projs.append(torch.cat(batch_audio_projs, dim=1))
        
        # Clear memory
        torch.cuda.empty_cache()
    
    # Combine all batches
    video_output = torch.cat(video_outputs, dim=0)
    audio_proj = torch.cat(audio_projs, dim=0)
    
    return video_output, audio_proj

# Modified training step
def train_step(model, audio_features, video_frames, optimizer, mse_loss, cosine_loss):
    optimizer.zero_grad()
    
    # Process with sliding window
    video_output, audio_proj = process_with_sliding_window(
        model, 
        audio_features, 
        video_frames,
        window_size=150,  # Adjust based on your GPU memory
        overlap=10
    )
    
    # Compute losses
    reconstruction_loss = mse_loss(video_output, video_frames)
    
    # Reshape tensors for cosine similarity
    v_flat = video_output.view(video_output.shape[0], video_output.shape[1], -1)
    a_flat = audio_proj.view(audio_proj.shape[0], audio_proj.shape[1], -1)
    alignment_loss = -cosine_loss(v_flat, a_flat).mean()
    
    # Combined loss
    total_loss = reconstruction_loss + 0.5 * alignment_loss
    
    # Backward pass
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    
    return total_loss.item(), reconstruction_loss.item(), alignment_loss.item()

In [27]:
torch.backends.cuda.matmul.allow_tf32 = False

In [28]:
torch.cuda.empty_cache()

In [39]:
print(torch.cuda.memory_summary(1))

|                  PyTorch CUDA memory summary, device ID 1                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 32           |        cudaMalloc retries: 32        |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   1324 MiB |  10870 MiB | 134641 MiB | 133317 MiB |
|       from large pool |   1307 MiB |  10865 MiB | 134604 MiB | 133296 MiB |
|       from small pool |     16 MiB |     16 MiB |     37 MiB |     20 MiB |
|---------------------------------------------------------------------------|
| Active memory         |   1324 MiB |  10870 MiB | 134641 MiB | 133317 MiB |
|       from large pool |   1307 MiB |  10865 MiB | 134604 MiB | 133296 MiB |
|       from small pool |     16 MiB |     16 MiB |     37 MiB |     20 MiB |
|---------------------------------------------------------------