In [None]:
import sys
sys.path.append('../')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchinfo
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from src.AVDataset import AVDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
       print(torch.cuda.get_device_properties(i).name)

In [None]:
# ==============
# Dataset
# ==============
audio_root = "../data/audios_denoised_16khz"
video_root = "../data/dataset_2drt_video_only"
filter_keyword = "vcv" 
nSubs = [f"sub{str(i).zfill(3)}" for i in range(1, 2)]

dataset = AVDataset(audio_root, video_root, subs=nSubs, filter_keyword=filter_keyword, video_max_frames=None, audio_sampling_rate=16000)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

print(len(dataset))
for i, (waveform, frames, audio_path, video_path) in enumerate(dataloader):
        print("Audio shape:", waveform.shape) 
        print("Video frames shape:", frames.shape)
        print("Audio file:", audio_path)
        print("Video file:", video_path)
        print("===========")
        if i > 0:
            break

In [21]:
# =================================================================
# U-Net for the image denoising with the audio (and FiLM conditioning) -> Perez 2018 / Dey 2022
# =================================================================

# ----------------- FiLM Layer -----------------
class FiLM(nn.Module):
    def __init__(self, in_channels, embedding_dim):
        super(FiLM, self).__init__()
        self.fc = nn.Linear(embedding_dim, in_channels * 2) # out_vector: one for scaling and another one for shifting

    def forward(self, x, audio_embedding):
        gamma, beta = self.fc(audio_embedding).chunk(2, dim=1)  # chunk for splitting the tensor into 2 equal parts ao9ng feat dim
        gamma = gamma.view(x.size(0), -1, 1, 1, 1)  # scaling
        beta = beta.view(x.size(0), -1, 1, 1, 1) # shifting/bias
        return x * (1 + gamma) + beta  # stability when gamma=0 

# ----------------- Audio-Informed 3D U-Net -----------------
class AudioConditioned3DUNet(nn.Module):
    def __init__(self, audio_embedding_dim=768, mode='nearest'):
        super(AudioConditioned3DUNet, self).__init__()
        self.mode = mode

        # Audio Branch
        self.audio_encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        for param in self.audio_encoder.parameters():
            param.requires_grad = False  # Freeze Wav2Vec2 params since Im using it as a feature extractor

        # Encoder_____________________________________________
        self.enc1 = self._block(1, 32)
        self.enc2 = self._block(32, 64)
        self.enc3 = self._block(64, 128)
        self.enc4 = self._block(128, 256) # this time i try smaller filters

        # Bottleneck__________________________________________
        self.bottleneck = self._block(256, 512)

        # FiLM layers for conditioning -----------------------
        self.film1 = FiLM(32, audio_embedding_dim)
        self.film2 = FiLM(64, audio_embedding_dim)
        self.film3 = FiLM(128, audio_embedding_dim)
        self.film4 = FiLM(256, audio_embedding_dim)

        # Decoder_____________________________________________
        self.dec1 = self._block(512, 256) # To avoid checkerboard artifacts that transposed convolutions can introduce (bilinear interp.)
        self.dec2 = self._block(256, 128) # smoother, lets try :)
        self.dec3 = self._block(128, 64)
        self.dec4 = self._block(64, 32)

        # Final Output Layer__________________________________
        self.final_conv = nn.Conv3d(32, 1, kernel_size=1)

    def _block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            # nn.ReLU(inplace=True)
            # nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1), # should I use 2 convs as in the original paper?
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, video, audio):
        """
        Args:
            video: Tensor of shape [B, frames, 1, H, W]
            audio: Tensor of shape [B, num_audio_samples]
        """
        video = video.permute(0, 2, 1, 3, 4) # [B, 1, frames, H, W] for conv3d

        # Audio Processing_____________________________________
        with torch.no_grad():
            audio_embedding = self.audio_encoder(audio).last_hidden_state  # [B, seq_length, 768]
            audio_embedding = torch.mean(audio_embedding, dim=1)  # [B, 768] --> mean pooling over the sequence length to aggregate all the embeedings

        # Video Encoder________________________________________
        enc1 = self.enc1(video)  # [B, 32, 2800, 84, 84]
        enc2 = self.enc2(nn.functional.max_pool3d(enc1, 2))  # [B, 64, 1400, 42, 42]
        enc3 = self.enc3(nn.functional.max_pool3d(enc2, 2))  # [B, 128, 700, 21, 21]
        enc4 = self.enc4(nn.functional.max_pool3d(enc3, 2))  # [B, 256, 350, 10, 10]

        # Bottleneck
        bottleneck = self.bottleneck(nn.functional.max_pool3d(enc4, 2))  # [B, 512, 175, 5, 5] also frames reduce

        # Decoder with FiLM conditioning_______________________
        dec1 = self.dec1(nn.functional.interpolate(bottleneck, size=enc4.shape[2:], mode=self.mode)) # upsample the img before convolution (default 'nearest')
        dec1 = self.film4(dec1, audio_embedding) + enc4 # skip connection

        dec2 = self.dec2(nn.functional.interpolate(dec1, size=enc3.shape[2:], mode=self.mode))
        dec2 = self.film3(dec2, audio_embedding) + enc3 # skip connection

        dec3 = self.dec3(nn.functional.interpolate(dec2, size=enc2.shape[2:], mode=self.mode))
        dec3 = self.film2(dec3, audio_embedding) + enc2 # skip connection

        dec4 = self.dec4(nn.functional.interpolate(dec3, size=enc1.shape[2:], mode=self.mode))
        dec4 = self.film1(dec4, audio_embedding) + enc1 # skip connection

        # Final Output Layer___________________________________
        out = self.final_conv(dec4)
        # if out.shape != x.shape:
        #     out = nn.functional.interpolate(out, size=(84, 84), mode='bilinear', align_corners=False)
        return out

In [15]:
unet = AudioConditioned3DUNet(audio_embedding_dim=768, mode='nearest').to(device)

# unet(torch.randn(1, 2800, 1, 84, 84).to(device), torch.randn(1, 619695).to(device))

# print(torchinfo.summary(unet, input_size=((1, 84, 1, 84, 84), (1, 1000)), device=device))

In [24]:
# ===================
# Training unsuperv
# ===================
class PerceptualLoss(nn.Module):
    def __init__(self, vgg_model):
        super(PerceptualLoss, self).__init__()
        self.vgg_layers = vgg_model.features[:16].eval()  # Here I use early layers for texture & structure
        for param in self.vgg_layers.parameters():
            param.requires_grad = False  # freeze VGG weights

    def forward(self, x, y):
        return torch.nn.functional.l1_loss(self.vgg_layers(x), self.vgg_layers(y))
    
def train_model_u(model, dataloader, num_epochs=10, lr=1e-4, device='cuda'):
    # model= nn.DataParallel(model)
    model.to(device)
    
    # Optim / later include scheduler?
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    # Losses
    mse_loss = nn.MSELoss()
    perceptual_loss = PerceptualLoss(torchvision.models.vgg16(pretrained=True)).to(device)
    cosine_loss = nn.CosineEmbeddingLoss()

    # Train
    model.train()
    print("Training...")
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        
        for waveform, video_frames, audio_path, video_path in dataloader:
            video_frames = video_frames.to(device)  # [B, T, 1, 84, 84]
            waveform = waveform.to(device)          # [B, num_samples]
            
            optimizer.zero_grad()
            
            # fwd
            denoised_output = model(video_frames, waveform)
            
            # Loss 1: MSE loss (pixel-level reconstruction)
            loss_mse = mse_loss(denoised_output, video_frames)
            
            # Loss 2: Perceptual loss (structural preservation)
            denoised_resized = denoised_output.view(-1, 1, 84, 84).repeat(1, 3, 1, 1) # reshape for VGG-> [B*T, 1, 84, 84] -> [B*T, 3, 84, 84]
            video_resized = video_frames.view(-1, 1, 84, 84).repeat(1, 3, 1, 1)
            loss_perceptual = perceptual_loss(denoised_resized, video_resized) # maybe I should not do this, takes time!
            
            # Loss 3: Audio Consistency Loss
            # video_features = model.enc1(video_frames)  # Here I take the first encoder randomly
            # audio_features = model.audio_encoder(waveform).last_hidden_state # takes time
            # audio_features = torch.mean(audio_features, dim=1)  # Global audio embedding
            
            # video_features_flat = video_features.view(video_features.size(0), -1)
            # target_labels = torch.ones(video_features_flat.size(0)).to(device) #unsupervided
            loss_audio = 0 #cosine_loss(video_features_flat, audio_features, target_labels)

            # Total loss but idk how to weight the losses!
            total_loss = loss_mse + 0.01 * loss_perceptual + 0.05 * loss_audio
            total_loss.backward()
            optimizer.step()

            epoch_loss += total_loss.item()
        
        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}] - Avg. Loss: {avg_loss:.4f}")
    
    return model

In [25]:
trained_model = train_model_u(unet, dataloader, num_epochs=2, lr=1e-4, device=device) # I need the gpus :(

# torch.save(trained_model.state_dict(), "audio_conditioned_3dunet.pth")

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacity of 10.90 GiB of which 3.25 MiB is free. Including non-PyTorch memory, this process has 10.90 GiB memory in use. Of the allocated memory 10.59 GiB is allocated by PyTorch, and 52.45 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
trained_model.eval()
for waveform, video_frames, audio_path, video_path in dataloader:
    dns_out = trained_model(video_frames, waveform)
    print(dns_out.shape)
    plt.imshow(video_frames[0,5,0], cmap="gray")
    plt.show()
    plt.imshow(dns_out[0, 0, 0].detach().cpu(), cmap="gray")
    plt.show()
    break

In [14]:
\(°^°)/

SyntaxError: invalid character '°' (U+00B0) (2867040718.py, line 1)