<a href="https://colab.research.google.com/github/Redd-hope/HumanPatterns/blob/main/StudentModelMaking_smm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
from transformers import ViTModel

class SelfAttention(nn.Module):
    """Residual multi-head self-attention followed by layer normalisation.

    The block computes ``LayerNorm(x + MultiheadAttention(x, x, x))`` and
    therefore returns a tensor with the same shape as its input.

    Args:
        hidden_dim: Size of the feature (last) dimension of the input. It is
            used as ``embed_dim`` of the attention layer and as the
            normalised shape of the layer norm.
        num_heads: Number of attention heads. Defaults to 4, the value that
            used to be hard-coded. ``nn.MultiheadAttention`` rejects values
            that do not divide ``hidden_dim`` evenly.
    """

    def __init__(self, hidden_dim: int, num_heads: int = 4) -> None:
        super().__init__()
        # batch_first=True: inputs and outputs are laid out [batch, seq, feature].
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` (query = key = value) and normalise the residual sum.

        Args:
            x: Input tensor whose last dimension equals ``hidden_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The attention weights (second return value) are not needed here.
        attn_output, _ = self.attn(x, x, x)
        return self.norm(x + attn_output)

class HybridDeepModel(nn.Module):
    """ViT backbone followed by interleaved CNN / BiLSTM / self-attention stages.

    The model encodes an image batch with a pretrained ViT, drops the CLS
    token, and runs the remaining patch tokens through six stages that each
    end in a ``SelfAttention(256)`` block. A linear classifier is applied to
    every token and the result is returned time-major (``[T, B, C]``).

    Notes:
        * ``cnn2`` and ``bi_lstm`` are each used by several stages, so those
          stages share weights.
        * The "ViT" branches of the parallel merges use only the first 256 of
          the 768 ViT channels (plain slicing, no learned projection).

    Args:
        num_classes: Output size of the per-token classifier.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        # ViT output: [B, N+1, 768] (N patch tokens plus the leading CLS token).
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

        # Conv stack that compresses the 768 ViT channels to 256.
        # kernel_size=3 with padding=1 keeps the token count unchanged.
        self.cnn1 = nn.Sequential(
            nn.Conv1d(768, 512, 3, padding=1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Conv1d(512, 256, 3, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU()
        )

        # 256 -> 256 conv stack, reused by the later CNN stages.
        self.cnn2 = nn.Sequential(
            nn.Conv1d(256, 256, 3, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU()
        )

        # Bidirectional with hidden size 128, so the output width is 2 * 128 = 256.
        self.bi_lstm = nn.LSTM(256, 128, num_layers=2, batch_first=True, bidirectional=True)

        # Six attention blocks registered as attn1..attn6, in that order, so
        # parameter names and construction order stay stable. The factory is
        # a class-level static method rather than a lambda stored on the
        # instance: a lambda attribute cannot be pickled, which breaks
        # pickling of the whole module.
        for stage in range(1, 7):
            setattr(self, f"attn{stage}", self.attn_256())

        self.classifier = nn.Linear(256, num_classes)

    @staticmethod
    def attn_256():
        """Return a fresh ``SelfAttention`` block of width 256."""
        return SelfAttention(256)

    @staticmethod
    def _conv_tokens(conv_stack, tokens):
        """Apply a Conv1d stack to ``[B, T, C]`` tokens and return ``[B, T, C']``."""
        # Conv1d expects channels-first input, so swap the last two axes
        # around the convolution.
        return conv_stack(tokens.transpose(1, 2)).transpose(1, 2)

    def forward(self, x):
        """Run the full pipeline on a batch of images.

        Args:
            x: Tensor passed to the ViT as ``pixel_values``.

        Returns:
            Per-token class scores permuted to ``[T, B, num_classes]``. These
            are raw linear outputs; no softmax / log-softmax is applied here.
        """
        # === ViT: drop the CLS token, [B, N+1, 768] -> [B, N, 768]
        vit_out = self.vit(pixel_values=x).last_hidden_state[:, 1:, :]
        # First 256 ViT channels, reused by both "ViT" merge branches below.
        vit_compress = vit_out[:, :, :256]  # [B, N, 256]

        # === Parallel merge: (ViT || CNN) -> Attention
        cnn_out = self._conv_tokens(self.cnn1, vit_out)  # [B, N, 256]
        out = self.attn1(cnn_out + vit_compress)

        # === CNN -> Attention
        out = self.attn2(self._conv_tokens(self.cnn2, out))

        # === BiLSTM -> Attention
        out, _ = self.bi_lstm(out)
        out = self.attn3(out)

        # === (CNN || BiLSTM) -> Attention
        cnn_branch = self._conv_tokens(self.cnn2, out)
        lstm_branch, _ = self.bi_lstm(out)
        out = self.attn4(cnn_branch + lstm_branch)

        # === (ViT || BiLSTM) -> Attention
        lstm_branch2, _ = self.bi_lstm(out)
        out = self.attn5(vit_compress + lstm_branch2)

        # === Final BiLSTM -> Attention
        out, _ = self.bi_lstm(out)
        out = self.attn6(out)

        # === Classification
        logits = self.classifier(out)  # [B, N, num_classes]

        return logits.permute(1, 0, 2)  # CTC format: [T, B, C]
