In [1]:
# %pip list
# %pip install torchaudio

In [2]:
import cv2
import glob
import torch
import torchaudio
import numpy as np
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from video_loader import VideoDataset
from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_pil_image
from torchvision.transforms import Compose, Resize, Normalize, ToTensor

  warn(


In [3]:
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, distance_dim):
        super(VectorQuantizer, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.distance_dim = distance_dim
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)

    def forward(self, inputs):
        # Flatten inputs except for the embedding dimension
        flat_inputs = inputs.view(-1, self.embedding_dim)
        
        # Calculate distances
        distances = (torch.sum(flat_inputs ** 2, dim=1, keepdim=True) 
                    + torch.sum(self.embedding.weight ** 2, dim=1)
                    - 2 * torch.matmul(flat_inputs, self.embedding.weight.t()))
        
        # Reshape distances back to the original shape with depth, height, and width dimensions
        distances = distances.view(*inputs.shape[:-1], -1)
        
        # Encoding
        encoding_indices = torch.argmin(distances, dim=-1)
        encodings = F.one_hot(encoding_indices, self.num_embeddings).float()
        
        # Use encodings to gather embeddings and compute loss
        quantized = torch.matmul(encodings, self.embedding.weight).view(inputs.shape)
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss
        
        # Preserve gradients through the quantization
        quantized = inputs + (quantized - inputs).detach()
        return loss, quantized

In [4]:
class VQVAE(nn.Module):
    def __init__(self, input_channels, hidden_channels, num_embeddings, embedding_dim, commitment_cost):
        super(VQVAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, hidden_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, embedding_dim, kernel_size=4, stride=2, padding=1),
            nn.ReLU()
        )
        self.quantizer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost, distance_dim=2)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, hidden_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_channels, input_channels, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()  # Assuming input is normalized to [0, 1]
        )

    def forward(self, x):
        z = self.encoder(x)
        z = z.permute(0, 2, 3, 1).contiguous()
        loss, quantized = self.quantizer(z)
        quantized = quantized.permute(0, 3, 1, 2).contiguous()
        recon_x = self.decoder(quantized)
        return recon_x, loss

# Example model initialization
# vqvae = VQVAE(input_channels=3, hidden_channels=128, num_embeddings=512, embedding_dim=64, commitment_cost=0.25)

In [5]:
class VQVAE3D(nn.Module):
    def __init__(self, input_channels, hidden_channels, num_embeddings, embedding_dim, commitment_cost):
        super(VQVAE3D, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(input_channels, hidden_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv3d(hidden_channels, embedding_dim, kernel_size=4, stride=2, padding=1),
            nn.ReLU()
        )
        self.quantizer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost, distance_dim=3)
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(embedding_dim, hidden_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(hidden_channels, input_channels, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()  # Assuming input is normalized to [0, 1]
        )

    def forward(self, x):
        z = self.encoder(x)
        z = z.permute(0, 2, 3, 4, 1).contiguous()  # Adjust for 3D
        loss, quantized = self.quantizer(z)
        quantized = quantized.permute(0, 4, 1, 2, 3).contiguous()  # Adjust back after quantization
        recon_x = self.decoder(quantized)
        return recon_x, loss, z
    
    def training_step(self, x, optimizer):
        optimizer.zero_grad()
        recon_x, loss, _ = self(x) # Get loss and compute gradients to update the model
        loss.backward()
        optimizer.step()
        return loss.item()

# Example model initialization
# vqvae3d = VQVAE3D(input_channels=3, hidden_channels=128, num_embeddings=512, embedding_dim=64, commitment_cost=0.25)

In [6]:
class AudioDecoder(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, output_length):
        super(AudioDecoder, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.output_length = output_length

        # Define the audio decoder architecture
        self.decoder = nn.Sequential(
            nn.Linear(65536, 512),  # Adjusted from 64 to 65536 to match the flattened input
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, self.output_length),
            nn.Tanh()  # Assuming the output audio is normalized between -1 and 1
        )

    
    def forward(self, encoded_video):
        # encoded_video shape: (batch_size, num_embeddings, embedding_dim)
        # Flatten the encoded video to match the decoder input
        batch_size = encoded_video.shape[0]
        encoded_video_flat = encoded_video.view(batch_size, -1)
        
        # Decode to audio waveform
        audio_output = self.decoder(encoded_video_flat)
        return audio_output

In [7]:
class VQVAE3DWithAudio(nn.Module):
    def __init__(self, input_channels, hidden_channels, num_embeddings, embedding_dim, commitment_cost, audio_output_length):
        super(VQVAE3DWithAudio, self).__init__()
        # VQ-VAE components
        self.vqvae3d = VQVAE3D(input_channels, hidden_channels, num_embeddings, embedding_dim, commitment_cost)
        
        # Audio Decoder component
        self.audio_decoder = AudioDecoder(num_embeddings, embedding_dim, audio_output_length)
    
    def forward(self, x):
        recon_x, loss, z = self.vqvae3d(x)
        
        # For simplicity, let's use the quantized output directly from the encoder as input to the audio decoder
        # A more sophisticated approach might involve selecting or processing specific parts of the encoding
        _, quantized = self.vqvae3d.quantizer(z)
        
        # Generate audio from the quantized video encoding
        audio_output = self.audio_decoder(quantized)
        
        return recon_x, audio_output, loss

In [8]:
# # Initialize the enhanced VQ-VAE model with audio decoding capability
# vqvae3d_with_audio = VQVAE3DWithAudio(
#     input_channels=3, 
#     hidden_channels=128, 
#     num_embeddings=512, 
#     embedding_dim=64, 
#     commitment_cost=0.25, 
#     audio_output_length=48000  # For 1 second of audio at 48kHz
# )

# # Example forward pass with a dummy video input
# # Assuming the input video is a 5D tensor: (batch_size, channels, depth, height, width)
# # For example, a batch of 4 videos, each with 3 color channels, 16 frames, and 64x64 resolution
# dummy_video_input = torch.randn(4, 3, 16, 64, 64)
# reconstructed_video, generated_audio, loss = vqvae3d_with_audio(dummy_video_input)

# # Here, `reconstructed_video` is the reconstructed video output,
# # `generated_audio` is the audio waveform generated from the video,
# # and `loss` is the loss from the VQ-VAE quantization process.

# generated_audio_int16 = (generated_audio * 32767).short()

# generated_audio_int16 = generated_audio_int16.unsqueeze(0)
# generated_audio_int16 = generated_audio_int16.squeeze()

# # Save the audio directly as a PyTorch tensor without converting to a numpy array
# torchaudio.save(f"generated_audio{len(glob.glob('*'))}.wav", generated_audio_int16, 48000)

In [9]:
def train_vqvae3d(model, data_loader, epochs=10, lr=2e-4, accumulation_steps=2, max_norm=1, plot_save_path=None, model_save_path='vqvae3d_model.pth'):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    loss_history = []  # Store loss over epochs
    
    for epoch in range(epochs):
        total_loss = 0
        optimizer.zero_grad()  # Reset gradients; do it once at the start
        
        for batch_idx, data in enumerate(data_loader):
            data = data.to(device)
            recon_x, loss, _ = model(data)
            loss = loss / accumulation_steps  # Normalize our loss (if averaged)
            loss.backward()  # Accumulate gradients

            if (batch_idx + 1) % accumulation_steps == 0:  # Perform optimization every 'accumulation_steps'
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item()

            # Print the current iteration, overwrite the previous text
            print(f'\rEpoch {epoch}, Iteration {batch_idx + 1}/{len(data_loader)}, Loss: {total_loss}', end='', flush=True)
        
        epoch_loss = total_loss / len(data_loader)
        loss_history.append(epoch_loss)
        print(f"Epoch {epoch+1}, Loss: {epoch_loss}")
        
    if not os.path.isdir(os.path.dirname(model_save_path)):
        os.makedirs(os.path.dirname(model_save_path))
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")
        
    plt.plot(range(1, epochs + 1), loss_history, marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Over Epochs')
    if plot_save_path:
        plt.savefig(plot_save_path)  # Save the plot if a path is provided
    plt.show()

In [10]:
width, height = 256, 144
videos_dir = "/data/ai_club/team_13_2023-24/videos/train"

In [11]:
# Original ImageNet stats for RGB images
mean_rgb = [0.485, 0.456, 0.406]
std_rgb = [0.229, 0.224, 0.225]

# Approximate stats for grayscale
mean_gray = sum(mean_rgb) / len(mean_rgb)
std_gray = sum(std_rgb) / len(std_rgb)

# Define transformations
transformations = Compose([
    Resize((width, height)),  # Resize to a common size
    ToTensor(),
    Normalize(mean=mean_gray, std=std_gray),  # Imagenet stats
])

# Create VideoDataset
video_dataset = VideoDataset(videos_dir, transform=transformations)

# Check dataset
print(f"Sample 0 shape: {video_dataset[0].shape}")

Sample 0 shape: torch.Size([1, 300, 256, 144])


In [12]:
# Device training on
device = "cuda"

# Create the model
vqvae3d = VQVAE3D(input_channels=1, hidden_channels=128, num_embeddings=512, embedding_dim=64, commitment_cost=0.25).to(device)

# Create DataLoader
data_loader = DataLoader(video_dataset, batch_size=2, shuffle=True)

# Hyperparameters
epochs = 32
learning_rate = 2e-4
accumulation_steps = 2

train_vqvae3d(vqvae3d, data_loader, epochs=epochs, lr=learning_rate, accumulation_steps=accumulation_steps, plot_save_path="./loss.png")

Epoch 0, Iteration 2/200, Loss: 0.031908092088997364

KeyboardInterrupt: 