<a href="https://colab.research.google.com/github/BlackBoyZeus/Architecture/blob/main/VideoEmbedModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm

# Model Definitions (as provided previously)

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The feature axis of every input is split into ``heads`` chunks of
    ``head_dim = embed_size // heads`` features. One bias-free
    ``head_dim -> head_dim`` projection per role (values, keys, queries) is
    applied to every head, and the concatenated head outputs pass through a
    final ``embed_size -> embed_size`` linear layer.

    Args:
        embed_size: Size of the last (feature) axis of the inputs.
        heads: Number of attention heads; must divide ``embed_size`` evenly.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` positions over ``keys``/``values`` positions.

        Args:
            values: Tensor of shape ``(N, value_len, embed_size)``.
            keys: Tensor of shape ``(N, key_len, embed_size)``; ``key_len``
                must equal ``value_len``.
            query: Tensor of shape ``(N, query_len, embed_size)``.
            mask: ``None``, or a tensor broadcastable to
                ``(N, heads, query_len, key_len)``. Positions where the mask
                equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(N, query_len, embed_size)``.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the feature axis into (heads, head_dim).
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Raw attention scores, one (query_len x key_len) matrix per head.
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        # Each score is a dot product over head_dim features, so the
        # temperature is sqrt(head_dim), not sqrt(embed_size).
        energy = energy / (self.head_dim ** 0.5)

        if mask is not None:
            # A large negative score drives the softmax weight to zero while
            # keeping fully-masked rows finite (uniform weights, no NaN).
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy, dim=3)

        # Weighted sum of values, then merge the heads back into one axis.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )

        return self.fc_out(out)

class TransformerBlock(nn.Module):
    """One encoder block: attention and feed-forward, each with a residual.

    Both sub-layers follow the same pattern: add the sub-layer input to its
    output, layer-normalise the sum, then apply dropout.

    Args:
        embed_size: Feature width of the inputs and outputs.
        heads: Number of attention heads.
        dropout: Dropout probability used after each normalisation.
        forward_expansion: Multiplier for the hidden width of the
            feed-forward sub-layer.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        hidden_size = forward_expansion * embed_size

        self.attention = MultiHeadSelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run the block; the residual connections are taken from ``query``."""
        attended = self.attention(value, key, query, mask)

        # Attention sub-layer: residual -> norm -> dropout.
        normed_attention = self.norm1(attended + query)
        hidden = self.dropout(normed_attention)

        # Feed-forward sub-layer: residual -> norm -> dropout.
        expanded = self.feed_forward(hidden)
        normed_output = self.norm2(expanded + hidden)
        return self.dropout(normed_output)

class VideoTransformer(nn.Module):
    """A stack of ``num_layers`` transformer blocks followed by dropout.

    Args:
        embed_size: Feature width passed to every block.
        heads: Number of attention heads in every block.
        num_layers: How many blocks to stack.
        forward_expansion: Feed-forward width multiplier for every block.
        dropout: Dropout probability for the blocks and the final dropout.
    """

    def __init__(self, embed_size, heads, num_layers, forward_expansion, dropout):
        super().__init__()

        blocks = []
        for _ in range(num_layers):
            block = TransformerBlock(
                embed_size,
                heads,
                dropout=dropout,
                forward_expansion=forward_expansion,
            )
            blocks.append(block)
        self.layers = nn.ModuleList(blocks)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Feed ``x`` through every block as value, key and query alike."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, hidden, hidden, mask)
        return self.dropout(hidden)

class VideoModel(nn.Module):
    """CNN feature extractor, transformer encoder and linear head.

    Every input frame is reduced by the CNN to a ``128 x pool_h x pool_w``
    feature map, flattened to one vector, passed through the transformer and
    projected to ``num_classes`` outputs.

    Args:
        in_channels: Number of channels in the input frames.
        num_classes: Width of the output vector per frame.
        num_layers: Number of transformer blocks.
        forward_expansion: Feed-forward width multiplier in each block.
        heads: Number of attention heads; must divide the flattened feature
            size ``128 * pool_h * pool_w``.
        dropout: Dropout probability inside the transformer.
        pool_size: Output grid of the adaptive average pool, as an int or an
            ``(height, width)`` pair. The default ``(224, 224)`` keeps the
            original feature size of ``224*224*128`` (6,422,528). At that
            size each transformer linear layer holds on the order of 1e13
            weights, so pass a much smaller grid for a trainable model.
    """

    # Output channels of the last convolution.
    _CNN_OUT_CHANNELS = 128

    def __init__(self, in_channels=3, num_classes=1, num_layers=2, forward_expansion=2, heads=4, dropout=0.5,
                 pool_size=(224, 224)):
        super().__init__()
        if isinstance(pool_size, int):
            pool_size = (pool_size, pool_size)
        pool_h, pool_w = tuple(pool_size)

        # Length of the flattened CNN output; also the transformer width.
        self.feature_dim = self._CNN_OUT_CHANNELS * pool_h * pool_w

        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, self._CNN_OUT_CHANNELS, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((pool_h, pool_w))
        )

        self.transformer = VideoTransformer(
            embed_size=self.feature_dim,
            heads=heads,
            num_layers=num_layers,
            forward_expansion=forward_expansion,
            dropout=dropout
        )

        self.fc = nn.Linear(self.feature_dim, num_classes)

    def forward(self, x):
        """Map frames ``(N, in_channels, H, W)`` to ``(N, num_classes)``."""
        features = self.cnn(x).flatten(start_dim=1)  # (N, feature_dim)

        # The transformer expects (N, seq_len, embed). Feeding the flat 2-D
        # tensor made the attention reshape fail, so give each frame an
        # explicit length-1 sequence axis and drop it again afterwards.
        # NOTE(review): this keeps batch items independent; confirm that
        # attention across frames was not the intent.
        tokens = features.unsqueeze(1)                # (N, 1, feature_dim)
        encoded = self.transformer(tokens).squeeze(1)  # (N, feature_dim)

        return self.fc(encoded)

# NOTE(review): this instance is only printed; nothing later in the script
# uses `model`. With the default sizes its layers are extremely large, so
# consider removing it if construction runs out of memory.
model = VideoModel()
print(model)

# 1. Data Preparation
video_path = '/content/segment.mp4'
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    # Fail with a clear message instead of an opaque tensor error later on.
    raise IOError(f"Could not open video file: {video_path}")

frames = []
try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(cv2.resize(frame, (224, 224)))
finally:
    # Release the capture even if decoding or resizing raises.
    cap.release()

if not frames:
    raise ValueError(f"No frames could be decoded from: {video_path}")

# Scale to [0, 1] directly in float32 (avoids a float64 copy of the whole
# video) and reorder from (N, H, W, C) to (N, C, H, W).
frames = np.stack(frames).astype(np.float32) / 255.0
frames = torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous()

# Split frame indices, not the tensor itself. The random split is the same
# as before (same sample count, sizes and random_state), but the test
# indices can then be sorted so that test_frames stay in temporal order.
frame_indices = np.arange(len(frames))
train_idx, test_idx = train_test_split(frame_indices, test_size=0.2, random_state=42)
train_idx, val_idx = train_test_split(train_idx, test_size=0.2, random_state=42)
test_idx = np.sort(test_idx)

train_frames = frames[torch.as_tensor(train_idx, dtype=torch.long)]
val_frames = frames[torch.as_tensor(val_idx, dtype=torch.long)]
test_frames = frames[torch.as_tensor(test_idx, dtype=torch.long)]

# Sizes shared by the three models below.
_FRAME_PIXELS = 224 * 224            # one value per pixel of a 224x224 frame
_EMBED_DIM = _FRAME_PIXELS * 128     # width of the embedding vector

# 2. Embedding Model
# Maps a frame to an _EMBED_DIM-wide vector; trained with mean squared error.
# NOTE(review): linear layers of this width hold on the order of 1e13
# weights; confirm this is feasible on the target hardware.

embedding_model = VideoModel(num_classes=_EMBED_DIM)
embedding_optimizer = optim.Adam(embedding_model.parameters(), lr=0.001)
embedding_criterion = nn.MSELoss()

# 3. Segmentation Model
# Maps an embedding to one probability per pixel (sigmoid output, BCE loss).

_segmentation_layers = [
    nn.Linear(_EMBED_DIM, _EMBED_DIM),
    nn.ReLU(),
    nn.Linear(_EMBED_DIM, _FRAME_PIXELS),
    nn.Sigmoid(),
]
segmentation_model = nn.Sequential(*_segmentation_layers)
segmentation_optimizer = optim.Adam(segmentation_model.parameters(), lr=0.001)
segmentation_criterion = nn.BCELoss()

# Targets are uniform random values shaped like one channel of each frame.
# NOTE(review): these look like placeholders for real masks; confirm.
train_masks = torch.rand_like(train_frames[:, 0, :, :])
val_masks = torch.rand_like(val_frames[:, 0, :, :])

# 4. Decoder Model
# Maps an embedding to 3 * 224 * 224 values; trained with mean squared error.

_decoder_layers = [
    nn.Linear(_EMBED_DIM, _EMBED_DIM),
    nn.ReLU(),
    nn.Linear(_EMBED_DIM, _FRAME_PIXELS * 3),
]
decoder_model = nn.Sequential(*_decoder_layers)
decoder_optimizer = optim.Adam(decoder_model.parameters(), lr=0.001)
decoder_criterion = nn.MSELoss()

# Training Loops

epochs = 10

# The loaders do not change between epochs, so build them once; with
# shuffle=True each pass over a loader still draws a fresh order.
embedding_loader = DataLoader(train_frames, batch_size=32, shuffle=True)
segmentation_loader = DataLoader(TensorDataset(train_frames, train_masks), batch_size=32, shuffle=True)
decoder_loader = DataLoader(train_frames, batch_size=32, shuffle=True)

# Training Embedding Model
# NOTE(review): embedding_model is built with num_classes=224*224*128, while
# the target below is the flattened input frame (C*224*224 values). These
# widths look incompatible for MSELoss; confirm the intended target.
for epoch in range(epochs):
    embedding_model.train()
    for frame in tqdm(embedding_loader, desc=f"Embedding Epoch {epoch + 1}"):
        embedding_optimizer.zero_grad()
        output = embedding_model(frame)
        loss = embedding_criterion(output, frame.view(frame.size(0), -1))
        loss.backward()
        embedding_optimizer.step()

# The embedding model is only a fixed feature extractor from here on. Switch
# it to eval mode so its dropout is off, and compute embeddings under
# no_grad so no autograd graph is built for it.
embedding_model.eval()

# Training Segmentation Model
for epoch in range(epochs):
    segmentation_model.train()
    for frame, mask in tqdm(segmentation_loader, desc=f"Segmentation Epoch {epoch + 1}"):
        segmentation_optimizer.zero_grad()
        with torch.no_grad():
            embedding = embedding_model(frame)
        output = segmentation_model(embedding)
        loss = segmentation_criterion(output, mask.view(mask.size(0), -1))
        loss.backward()
        segmentation_optimizer.step()

# Training Decoder Model
for epoch in range(epochs):
    decoder_model.train()
    for frame in tqdm(decoder_loader, desc=f"Decoder Epoch {epoch + 1}"):
        decoder_optimizer.zero_grad()
        with torch.no_grad():
            embedding = embedding_model(frame)
        output = decoder_model(embedding)
        loss = decoder_criterion(output, frame.view(frame.size(0), -1))
        loss.backward()
        decoder_optimizer.step()

# 5. Generate Output Video

# Inference only: turn off dropout in every model.
embedding_model.eval()
segmentation_model.eval()
decoder_model.eval()

output_frames = []
with torch.no_grad():
    for frame in tqdm(test_frames, desc="Generating Output Video"):
        embedding = embedding_model(frame.unsqueeze(0))
        mask = segmentation_model(embedding).view(224, 224).numpy()
        reconstructed_frame = decoder_model(embedding).view(3, 224, 224).numpy()

        # The (224, 224) mask broadcasts over the channel axis.
        masked_frame = (1 - mask) * reconstructed_frame

        # The decoder was trained on frames scaled to [0, 1] in
        # channel-first layout. Map back to 0..255 before the uint8 cast
        # (a direct cast would truncate nearly everything to 0) and move
        # the channels last, the layout the frames had when read by OpenCV.
        pixels = np.clip(masked_frame * 255.0, 0.0, 255.0).round().astype(np.uint8)
        output_frame = np.ascontiguousarray(pixels.transpose(1, 2, 0))
        output_frames.append(output_frame)

fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter('output_video.avi', fourcc, 20.0, (224, 224))
if not out.isOpened():
    # Without this check a missing codec or unwritable path yields no video
    # and no error.
    raise IOError("Could not open 'output_video.avi' for writing")
try:
    for frame in output_frames:
        out.write(frame)
finally:
    # Finalise the file even if a write fails.
    out.release()