In [1]:
import torch
from torch import nn

torch.cuda.empty_cache()

In [2]:
from ultralytics import YOLO

model = YOLO("best.pt")

In [3]:
import os

import pandas as pd
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class CustomYOLODataset(Dataset):
    def __init__(
        self, images_dir, sequence_length=5, skip=2, img_size=640, transform=None
    ):
        self.images_dir = images_dir
        self.img_files = sorted(
            [f for f in os.listdir(images_dir) if f.endswith((".jpg", ".png", ".jpeg"))]
        )
        self.img_size = img_size
        self.transform = (
            transform
            if transform
            else transforms.Compose(
                [
                    transforms.Resize((img_size, img_size)),
                    transforms.ToTensor(),
                ]
            )
        )
        self.df = pd.DataFrame(self.img_files, columns=["img_name"])
        # Split image file names into 'map', 'batch_id', 'img_id' columns
        self.df[["map", "batch_id", "img_id"]] = self.df["img_name"].str.extract(
            r"(.+?)_(\d+)_manual_frame(\d+)\.png"
        )
        self.df["img_id"] = self.df["img_id"].astype(int)
        self.df["batch_id"] = self.df["batch_id"].astype(int)
        self.df["id"] = self.df["batch_id"] * 10 + self.df["img_id"]
        self.sequences = []
        self.k = sequence_length
        self.skip = skip
        for _, group in self.df.groupby("map"):
            group_sorted = group.sort_values(by=["batch_id", "img_id"])
            img_ids = group_sorted["img_id"].values
            indices = group_sorted.index.values
            max_start = len(img_ids) - (self.k - 1) * skip
            for i in range(max_start):
                seq_indices = indices[i : i + self.k * skip : skip]
                if len(seq_indices) == self.k:
                    self.sequences.append(
                        group["img_name"][seq_indices].values.tolist()
                    )

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        img_names = self.sequences[idx]
        images = []
        for img_name in img_names:
            image = self.process_img(img_name)
            images.append(image)
        images = torch.stack(images)
        return images

    def process_img(self, img_name: str):
        img_path = os.path.join(self.images_dir, img_name)
        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)
        return image

In [4]:
DATASET = "../datasets/DOOM_yolo_seg/"
dataset = CustomYOLODataset(images_dir=DATASET + "rgb")

In [5]:
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=8)

In [6]:
example = next(iter(dataloader))
len(example)

8

In [7]:
example.shape

torch.Size([8, 5, 3, 640, 640])

In [8]:
class YOLOFeatureExtractor(nn.Module):
    def __init__(self, yolo_model, layer_idx=17):
        super().__init__()
        self.model = yolo_model
        self.layer_idx = layer_idx
        self.features = None

        # Freeze YOLO model parameters
        for param in self.model.parameters():
            param.requires_grad = False

        # Register the hook
        self.hook = self.model.model.model[self.layer_idx].register_forward_hook(
            self._hook_fn
        )

    def _hook_fn(self, module, input, output):
        # Just store the output tensor
        self.features = output

    def forward(self, x):
        # Get batch shape dynamically
        batch_size, seq_len, channels, height, width = x.shape

        # Clear previous features
        self.features = None

        # Reshape for YOLO model
        x_reshaped = x.view(-1, channels, height, width)

        # Run model WITHOUT no_grad() so features can participate in gradient computation
        # The YOLO model is frozen via requires_grad=False, so it won't be trained
        self.model(x_reshaped, verbose=False)

        # Check features were extracted
        assert self.features is not None, "Features not extracted. Check the hook."

        # Reshape features back to sequence format
        features = self.features.view(batch_size, seq_len, -1)

        # Return features that can participate in gradient computation
        # The YOLO model is frozen, but the features can flow gradients to the LSTM
        return features

    def __del__(self):
        if hasattr(self, "hook"):
            self.hook.remove()


# Usage
feature_extractor = YOLOFeatureExtractor(model)
embeddings = feature_extractor(example)

In [9]:
batch_size, seq_len, emb_dim = embeddings.shape

In [10]:
class ModernEmbeddingLSTM(nn.Module):
    def __init__(
        self,
        embedding_dim,
        hidden_dim=256,
        num_layers=2,
        dropout=0.1,
        bidirectional=False,
        use_attention=True,
    ):
        super().__init__()

        # Calculate directions
        self.directions = 2 if bidirectional else 1
        self.use_attention = use_attention

        # LSTM layer with dropout between layers
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional,
        )

        self.layer_norm = nn.LayerNorm(hidden_dim * self.directions)
        # Attention mechanism
        if use_attention:
            self.attention = nn.Sequential(
                nn.Linear(hidden_dim * self.directions, hidden_dim),
                nn.Tanh(),
                nn.Linear(hidden_dim, 1),
            )
        # Output projection to match embedding dimension
        self.proj = nn.Sequential(
            nn.Linear(hidden_dim * self.directions, hidden_dim * 2),
            nn.GELU(),  # Modern activation function
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, embedding_dim),
        )
        # Initialize weights for better convergence
        self._init_weights()

    def _init_weights(self):
        for name, param in self.named_parameters():
            if "weight" in name:
                if "lstm" in name:
                    nn.init.orthogonal_(param)  # Orthogonal initialization for RNNs
                elif len(param.shape) >= 2:
                    nn.init.xavier_uniform_(
                        param
                    )  # Xavier for linear layers (2D+ weights)
                else:
                    nn.init.normal_(
                        param, mean=0.0, std=0.02
                    )  # Normal init for 1D weights
            elif "bias" in name:
                nn.init.zeros_(param)  # Initialize biases to zero

    def forward(self, x, return_attention=False):
        """
        Args:
            x: Input tensor of shape [batch_size, seq_len, embedding_dim]
            return_attention: Whether to return attention weights

        Returns:
            Predicted next embedding of shape [batch_size, embedding_dim]
            Optionally, attention weights if return_attention=True
        """
        # x shape: [B, T, embedding_dim]
        batch_size, seq_len = x.shape[0], x.shape[1]

        # Run through LSTM
        lstm_out, (hidden, cell) = self.lstm(x)
        # lstm_out shape: [B, T, hidden_dim * directions]

        # Apply layer normalization
        lstm_out = self.layer_norm(lstm_out)

        if self.use_attention:
            # Calculate attention scores
            attention_scores = self.attention(lstm_out).squeeze(-1)  # [B, T]
            attention_weights = torch.softmax(attention_scores, dim=1)  # [B, T]

            # Apply attention to get context vector
            attention_weights = attention_weights.unsqueeze(-1)  # [B, T, 1]
            context = torch.sum(
                attention_weights * lstm_out, dim=1
            )  # [B, hidden_dim * directions]
        else:
            # Just use the last output
            context = lstm_out[:, -1]  # [B, hidden_dim * directions]

        # Project to embedding dimension
        next_emb = self.proj(context)  # [B, embedding_dim]

        if return_attention and self.use_attention:
            return next_emb, attention_weights.squeeze(-1)
        return next_emb

    def predict_sequence(self, initial_sequence, steps=5):
        """
        Generate a sequence of future embeddings

        Args:
            initial_sequence: Initial sequence of embeddings [B, T, embedding_dim]
            steps: Number of future steps to predict

        Returns:
            Sequence of predicted embeddings [B, steps, embedding_dim]
        """
        device = next(self.parameters()).device
        batch_size = initial_sequence.shape[0]
        emb_dim = initial_sequence.shape[2]

        # Start with the initial sequence
        current_seq = initial_sequence
        predictions = torch.zeros(batch_size, steps, emb_dim, device=device)

        # Autoregressive prediction
        for i in range(steps):
            # Predict next embedding
            next_emb = self(current_seq)
            predictions[:, i] = next_emb

            # Update sequence for next prediction (drop oldest, add newest)
            if i < steps - 1:  # No need to update for the last step
                current_seq = torch.cat(
                    [current_seq[:, 1:], next_emb.unsqueeze(1)], dim=1
                )

        return predictions

In [11]:
# Initialize the model
lstm_model = ModernEmbeddingLSTM(
    embedding_dim=emb_dim, hidden_dim=512, num_layers=2, dropout=0.1, use_attention=True
)
lstm_model.to("cuda" if torch.cuda.is_available() else "cpu")

ModernEmbeddingLSTM(
  (lstm): LSTM(102400, 512, num_layers=2, batch_first=True, dropout=0.1)
  (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (attention): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): Tanh()
    (2): Linear(in_features=512, out_features=1, bias=True)
  )
  (proj): Sequential(
    (0): Linear(in_features=512, out_features=1024, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=1024, out_features=102400, bias=True)
  )
)

In [12]:
# Predict next embedding
with torch.inference_mode():
    next_embedding = lstm_model(embeddings)
    print(f"Next embedding shape: {next_embedding.shape}")  # [batch_size, emb_dim]

    # With attention weights
    next_embedding, attention = lstm_model(embeddings, return_attention=True)
    print(f"Attention weights shape: {attention.shape}")  # [batch_size, seq_len]

    # Generate a sequence of future embeddings
    future_sequence = lstm_model.predict_sequence(embeddings, steps=1)
    print(
        f"Future sequence shape: {future_sequence.shape}"
    )  # [batch_size, 10, emb_dim]

Next embedding shape: torch.Size([8, 102400])
Attention weights shape: torch.Size([8, 5])
Future sequence shape: torch.Size([8, 1, 102400])


In [13]:
# Create a train/validation split
from torch.utils.data import random_split

# Define split ratio
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=8)

In [14]:
def train_embedding_lstm(
    feature_extractor,
    lstm_model,
    dataloader,
    num_epochs=50,
    learning_rate=0.001,
    weight_decay=1e-5,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    # Move models to device
    feature_extractor.to(device)
    lstm_model.to(device)

    # Set up optimizer and loss function
    optimizer = torch.optim.AdamW(
        lstm_model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    criterion = nn.MSELoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    # Training history
    history = {"train_loss": []}
    best_loss = float("inf")
    history = {"train_loss": []}
    # Training loop
    for epoch in range(num_epochs):
        lstm_model.train()
        epoch_loss = 0.0
        for batch_idx, batch in enumerate(dataloader):
            batch = batch.to(device)
            # Extract embeddings from the entire sequence
            embeddings = feature_extractor(batch)  # [B, T, emb_dim]
            # No need for requires_grad_(True) - embeddings are frozen features
            embeddings = embeddings.to(device)

            # Use first N-1 embeddings to predict the Nth embedding_(True)
            input_embeddings = embeddings[
                :, :-1, :
            ]  # [B, T-1, emb_dim])  # Enable gradients for embeddings
            target_embedding = embeddings[
                :, -1, :
            ]  # [B, emb_dim]1 embeddings to predict the Nth embedding
            # Forward pass_dim]
            optimizer.zero_grad()
            predicted_embedding = lstm_model(input_embeddings)

            # Compute loss
            loss = criterion(predicted_embedding, target_embedding)

            # Backward pass and optimize
            loss.backward()
            torch.nn.utils.clip_grad_norm_(lstm_model.parameters(), max_norm=1.0)
            optimizer.step()
            # Update metricsoptimizer.step()
            epoch_loss += loss.item()

            if batch_idx % 10 == 0:
                print(
                    f"Epoch {epoch + 1}/{num_epochs}, Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.6f}"
                )

        # End of epoch
        avg_loss = epoch_loss / len(dataloader)
        history["train_loss"].append(avg_loss)

        # Update learning rate# End of epoch
        scheduler.step()
        history["train_loss"].append(avg_loss)
        print(
            f"Epoch {epoch + 1}/{num_epochs} completed. Avg Loss: {avg_loss:.6f}, LR: {scheduler.get_last_lr()[0]:.6f}"
        )

        # Save best model
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": lstm_model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": best_loss,
                },
                "best_lstm_model.pt",
            )
            print(f"Model saved with loss: {best_loss:.6f}")

    return lstm_model, history

In [15]:
# Create the feature extractor with dynamic shape handling
feature_extractor = YOLOFeatureExtractor(model, layer_idx=17)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=8)

# Train the model
device = "cuda" if torch.cuda.is_available() else "cpu"
trained_model, history = train_embedding_lstm(
    feature_extractor=feature_extractor,
    lstm_model=lstm_model,
    dataloader=train_loader,
    num_epochs=50,
    learning_rate=0.001,
    device=device,
)

RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.