# In [None]:

# 1. Generator: maps a latent vector z to a sequence of frame embeddings with a linear layer,
#    uses a transformer encoder to capture temporal relationships between video frames, and
#    upscales each embedding into a video frame with transposed-convolution (deconvolution) layers.
# 2. Discriminator: processes each frame with convolutional layers, then aggregates the features
#    across frames to evaluate both spatial quality and temporal consistency.
# Both networks are meant to be trained together in an adversarial loop with a GAN-style loss.

# Future goals: use a more sophisticated transformer architecture (e.g., attention masking for
# sequence positions), add perceptual or temporal losses for more realistic video, and incorporate
# motion-dynamics models for better temporal coherence.



import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class Generator(nn.Module):
    """Map a latent vector to a short video clip.

    A linear layer expands ``z`` into ``seq_len`` per-frame embeddings. A
    transformer encoder then mixes information across the frames of each clip.
    Finally, a stack of transposed convolutions decodes every frame embedding
    into an image.

    Args:
        img_channels: Number of channels in each generated frame.
        latent_dim: Size of the input latent vector ``z``.
        seq_len: Number of frames per generated clip.
        embed_dim: Per-frame embedding width. Must be divisible by the 8
            attention heads used by the transformer.
        img_size: Side length of the square output frames. Must be a power of
            two >= 8. Defaults to 64, the frame size ``Discriminator`` expects
            by default.

    Raises:
        ValueError: If ``img_size`` is not a power of two >= 8.
    """

    def __init__(self, img_channels, latent_dim, seq_len, embed_dim, img_size=64):
        super(Generator, self).__init__()
        if img_size < 8 or img_size & (img_size - 1):
            raise ValueError(f"img_size must be a power of two >= 8, got {img_size}")
        self.seq_len = seq_len
        self.latent_dim = latent_dim
        self.embed_dim = embed_dim
        self.img_size = img_size

        # Linear mapping from the latent vector to seq_len frame embeddings.
        self.fc = nn.Linear(latent_dim, seq_len * embed_dim)

        # Transformer encoder for temporal consistency. batch_first=True is
        # required because forward() passes (batch, seq, embed) tensors. With
        # the default (seq, batch, embed) layout, attention would mix different
        # samples of the batch instead of the frames of one clip.
        encoder_layer = TransformerEncoderLayer(
            d_model=embed_dim, nhead=8, dim_feedforward=512, batch_first=True
        )
        self.transformer = TransformerEncoder(encoder_layer, num_layers=4)

        # Decoder from a 1x1 embedding to an img_size x img_size frame.
        # The first layer (kernel 4, stride 1, padding 0) maps 1x1 -> 4x4.
        # Each later layer (kernel 4, stride 2, padding 1) doubles the side length.
        num_upsamples = img_size.bit_length() - 3  # log2(img_size / 4)
        # Channel widths halve as resolution grows, e.g. [256, 128, 64, 32] for 64x64.
        widths = [32 * 2 ** i for i in reversed(range(num_upsamples))]
        layers = [
            nn.ConvTranspose2d(embed_dim, widths[0], kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(),
        ]
        for in_ch, out_ch in zip(widths, widths[1:]):
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]
        layers += [
            nn.ConvTranspose2d(widths[-1], img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.deconv = nn.Sequential(*layers)

    def forward(self, z):
        """Generate one clip per latent vector.

        Args:
            z: Latent batch of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len, img_channels, img_size, img_size)``
            with values in ``[-1, 1]`` (Tanh output).
        """
        batch_size = z.size(0)
        # (batch, latent_dim) -> (batch, seq_len, embed_dim)
        sequence_embed = self.fc(z).view(batch_size, self.seq_len, self.embed_dim)

        # Self-attention across the seq_len frames of each clip.
        temporal_output = self.transformer(sequence_embed)

        # Decode every frame independently. Fold time into the batch axis and
        # treat each embedding as a 1x1 feature map.
        frames = temporal_output.reshape(batch_size * self.seq_len, self.embed_dim, 1, 1)
        decoded = self.deconv(frames)
        # Unfold time again, taking C/H/W from the decoder output instead of
        # hard-coding a resolution.
        return decoded.view(batch_size, self.seq_len, *decoded.shape[1:])

class Discriminator(nn.Module):
    """Score video clips with a per-frame CNN followed by a temporal MLP.

    Each frame is encoded by a shared convolutional network. The flattened
    per-frame features are concatenated along time and fed to an MLP that ends
    in a Sigmoid. The score can therefore depend on both per-frame content and
    on how frames relate to each other.

    Args:
        img_channels: Number of channels in each input frame.
        seq_len: Number of frames per clip. It is fixed at construction
            because the classifier's input width depends on it.
        img_size: Side length of the square input frames (>= 8). Defaults to
            64, which reproduces the classifier width ``256 * 8 * 8 * seq_len``.

    Raises:
        ValueError: If ``img_size`` is smaller than 8.
    """

    def __init__(self, img_channels, seq_len, img_size=64):
        super(Discriminator, self).__init__()
        if img_size < 8:
            raise ValueError(f"img_size must be >= 8, got {img_size}")
        self.seq_len = seq_len
        self.img_size = img_size

        # Convolution layers for frame-level feature extraction.
        self.conv = nn.Sequential(
            nn.Conv2d(img_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2)
        )

        # Each conv above (kernel 4, stride 2, padding 1) maps side s -> s // 2,
        # so three of them leave img_size // 8 pixels per side.
        feat_side = img_size // 8

        # Classifier over the concatenated features of all frames.
        self.fc = nn.Sequential(
            nn.Linear(256 * feat_side * feat_side * seq_len, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, video_frames):
        """Score a batch of clips.

        Args:
            video_frames: Tensor of shape
                ``(batch, seq_len, img_channels, img_size, img_size)``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding Sigmoid scores in ``[0, 1]``.
        """
        # Encode frames one at a time instead of folding time into the batch
        # axis. This way the BatchNorm layers normalise each frame position
        # over the batch only.
        features = [
            self.conv(frame).flatten(start_dim=1)
            for frame in video_frames.unbind(dim=1)
        ]
        # (batch, seq_len * 256 * feat_side * feat_side), frames in time order
        temporal_features = torch.cat(features, dim=1)
        return self.fc(temporal_features)

# Hyperparameters
img_channels = 3
latent_dim = 128
seq_len = 16
embed_dim = 256

# Initialize models
generator = Generator(img_channels, latent_dim, seq_len, embed_dim)
discriminator = Discriminator(img_channels, seq_len)

# Smoke test: sample a batch of clips, then score them with the discriminator.
# Running both models end to end checks that the generator's output frames
# have the shape the discriminator's classifier was built for.
z = torch.randn(4, latent_dim)  # Batch size of 4
generated_video = generator(z)  # Generated video output
print("Generated video shape:", generated_video.shape)  # [Batch, Seq, Channels, Height, Width]
scores = discriminator(generated_video)
print("Discriminator output shape:", scores.shape)  # [Batch, 1]

