<a href="https://colab.research.google.com/github/RuwaAbey/Computer_vision_based_group_activity_detection/blob/main/Modified_Z_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from einops import rearrange

# Device configuration: use CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Global flags: module-level switches read by the layers defined below.
dropout_flag = False  # gates the dropout call in Unit2D.forward and the dropout rate tcn_unit_attention gives its temporal conv
scale_norm = False  # when True, tcn_unit_attention.forward applies ScaleNorm before computing Q/K/V
multi_matmul = False  # not read anywhere in the visible part of this file

def import_class(name):
    """Dynamically import an object from a dotted path string.

    Args:
        name: Dotted path such as ``"package.module.ClassName"``.

    Returns:
        The object the path refers to (the module itself when the path has a
        single component).

    Raises:
        ImportError: If the top-level package (or a dependency of a
            submodule) cannot be imported.
        AttributeError: If a later component is neither an attribute nor an
            importable submodule of the object resolved so far.
    """
    import importlib

    module_type = type(importlib)
    components = name.split('.')
    obj = importlib.import_module(components[0])
    for depth, comp in enumerate(components[1:], start=2):
        try:
            obj = getattr(obj, comp)
            continue
        except AttributeError:
            # Only modules can gain attributes by importing a submodule.
            if not isinstance(obj, module_type):
                raise
        # Importing a package does not load its submodules, so a plain
        # getattr chain fails for "pkg.sub.Class" unless "pkg.sub" was already
        # imported elsewhere. Import the submodule explicitly instead.
        qualified = '.'.join(components[:depth])
        try:
            obj = importlib.import_module(qualified)
        except ModuleNotFoundError as err:
            if err.name != qualified:
                # The submodule exists but one of its own imports is missing.
                raise
            raise AttributeError(
                f"module {obj.__name__!r} has no attribute {comp!r}"
            ) from None
    return obj

def conv_init(module):
    """Draw a convolution's weights from N(0, 2 / fan_out) and zero its bias.

    ``fan_out`` is ``out_channels`` multiplied by every entry of
    ``kernel_size``. The bias is left alone when the layer has none.
    """
    fan_out = module.out_channels * math.prod(module.kernel_size)
    std = math.sqrt(2. / fan_out)
    with torch.no_grad():
        module.weight.normal_(0, std)

    bias = module.bias
    if bias is None:
        return
    nn.init.constant_(bias, 0)

def bn_init(bn, scale):
    """Fill a batch-norm layer's weight with ``scale`` and its bias with zero."""
    for tensor, value in ((bn.weight, scale), (bn.bias, 0)):
        nn.init.constant_(tensor, value)

def conv_branch_init(conv, branches):
    """Initialize one convolution of a multi-branch GCN layer.

    The weight is drawn from ``N(0, 2 / (n * k1 * k2 * branches))`` where
    ``n``, ``k1`` and ``k2`` are the first three dimensions of ``conv.weight``.
    The bias, when the layer has one, is set to zero.

    Args:
        conv: Convolution layer whose ``weight`` has at least three dimensions.
        branches: Number of parallel branches; it scales the variance down.
    """
    weight = conv.weight
    n, k1, k2 = weight.size(0), weight.size(1), weight.size(2)
    nn.init.normal_(weight, 0, math.sqrt(2. / (n * k1 * k2 * branches)))
    # Layers built with bias=False have no bias tensor; conv_init applies the
    # same guard.
    if conv.bias is not None:
        nn.init.constant_(conv.bias, 0)

class Unit2D(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU unit whose kernel spans a single axis.

    The kernel is ``(kernel_size, 1)`` when ``dim == 2`` and
    ``(1, kernel_size)`` when ``dim == 3``; padding and stride are applied on
    that same axis only.

    Args:
        D_in: Number of input channels.
        D_out: Number of output channels.
        kernel_size: Kernel length along the selected axis.
        stride: Stride along the selected axis.
        dim: 2 to convolve along tensor dimension 2, 3 for dimension 3.
        dropout: Dropout probability. A dropout layer is only created when
            this is positive, and it only runs while the module-level
            ``dropout_flag`` is true.
        bias: Whether the convolution has a bias term.

    Raises:
        ValueError: If ``dim`` is neither 2 nor 3.
    """
    def __init__(self, D_in, D_out, kernel_size, stride=1, dim=2, dropout=0, bias=True):
        super(Unit2D, self).__init__()
        pad = int((kernel_size - 1) / 2)
        if dim == 2:
            self.conv = nn.Conv2d(D_in, D_out, kernel_size=(kernel_size, 1), padding=(pad, 0), stride=(stride, 1), bias=bias)
        elif dim == 3:
            self.conv = nn.Conv2d(D_in, D_out, kernel_size=(1, kernel_size), padding=(0, pad), stride=(1, stride), bias=bias)
        else:
            raise ValueError("dim must be 2 or 3")
        self.bn = nn.BatchNorm2d(D_out)
        self.relu = nn.ReLU(inplace=True)
        # Dropout runs on the caller's input tensor, so it must not be
        # in-place: callers such as tcn_unit_attention reuse that same tensor
        # for their skip connection.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
        conv_init(self.conv)
        bn_init(self.bn, 1)

    def forward(self, x):
        """Apply optional dropout, then convolution, batch norm and ReLU."""
        if self.dropout is not None and dropout_flag:
            x = self.dropout(x)
        return self.relu(self.bn(self.conv(x)))

class unit_tcn(nn.Module):
    """Temporal convolution: a single Unit2D convolving along tensor dimension 2.

    Dropout is disabled for the wrapped unit.
    """
    def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
        super().__init__()
        self.conv = Unit2D(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dim=2,
            dropout=0,
        )

    def forward(self, x):
        """Run the wrapped Unit2D on ``x``."""
        out = self.conv(x)
        return out

class unit_tcn_m(nn.Module):
    """Multi-scale temporal convolution with three parallel branches.

    Each branch is a 1x1 convolution followed by a ``(k, 1)`` convolution
    (kernel, padding and stride along tensor dimension 2 only). The branch
    outputs are concatenated on the channel axis and batch-normalised; no
    activation is applied.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels after concatenation.
        stride: Stride of the second convolution of every branch.
        kernel_size: Sequence whose first three entries are the branch kernel
            lengths.
    """
    def __init__(self, in_channels, out_channels, stride=1, kernel_size=(1, 3, 7)):
        super(unit_tcn_m, self).__init__()
        mid_channels = out_channels // 3
        # The concatenated branches feed BatchNorm2d(out_channels), so their
        # widths must add up to out_channels. The last branch absorbs the
        # remainder when out_channels is not a multiple of three (it equals
        # mid_channels otherwise).
        last_channels = out_channels - 2 * mid_channels
        pad1 = (kernel_size[0] - 1) // 2
        pad2 = (kernel_size[1] - 1) // 2
        pad3 = (kernel_size[2] - 1) // 2

        self.conv11 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 1))
        self.conv21 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 1))
        self.conv31 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 1))
        self.conv12 = nn.Conv2d(in_channels, mid_channels, kernel_size=(kernel_size[0], 1), padding=(pad1, 0), stride=(stride, 1))
        self.conv22 = nn.Conv2d(in_channels, mid_channels, kernel_size=(kernel_size[1], 1), padding=(pad2, 0), stride=(stride, 1))
        self.conv32 = nn.Conv2d(in_channels, last_channels, kernel_size=(kernel_size[2], 1), padding=(pad3, 0), stride=(stride, 1))
        self.bn = nn.BatchNorm2d(out_channels)

        for conv in (self.conv11, self.conv21, self.conv31,
                     self.conv12, self.conv22, self.conv32):
            conv_init(conv)
        bn_init(self.bn, 1)

    def forward(self, x):
        """Run the three branches on ``x`` and return their batch-normalised concatenation."""
        x1 = self.conv12(self.conv11(x))
        x2 = self.conv22(self.conv21(x))
        x3 = self.conv32(self.conv31(x))
        return self.bn(torch.cat([x1, x2, x3], dim=1))

class unit_gcn(nn.Module):
    """Graph convolution unit with a fixed, a learned and a data-dependent adjacency.

    For every subset ``i`` the adjacency used is
    ``softmax(conv_a(x) @ conv_b(x) / K) + A[i] + PA[i]``, where ``A`` is the
    fixed matrix passed in, ``PA`` is a learnable matrix initialised to 1e-6
    and ``K`` is the size of the contracted dimension. The subset outputs are
    summed, batch-normalised, added to a shortcut of the input and passed
    through ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        A: NumPy array indexed as ``A[i]`` for each subset (presumably of
            shape ``(num_subset, V, V)``; confirm against the graph class).
        coff_embedding: ``out_channels`` is divided by this to get the
            embedding width used for the data-dependent adjacency.
        num_subset: Number of adjacency subsets / parallel branches.
    """
    def __init__(self, in_channels, out_channels, A, coff_embedding=4, num_subset=3):
        super(unit_gcn, self).__init__()
        inter_channels = out_channels // coff_embedding
        self.inter_c = inter_channels
        self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)))
        nn.init.constant_(self.PA, 1e-6)
        # A non-persistent buffer follows the module through .to()/.cuda()
        # without adding a key to state_dict, so existing checkpoints still
        # load and the layer no longer depends on a device chosen at import
        # time.
        self.register_buffer('A', torch.from_numpy(A.astype(np.float32)), persistent=False)
        self.num_subset = num_subset

        self.conv_a = nn.ModuleList()
        self.conv_b = nn.ModuleList()
        self.conv_d = nn.ModuleList()
        for _ in range(num_subset):
            self.conv_a.append(nn.Conv2d(in_channels, inter_channels, 1))
            self.conv_b.append(nn.Conv2d(in_channels, inter_channels, 1))
            self.conv_d.append(nn.Conv2d(in_channels, out_channels, 1))

        # nn.Identity (parameter-free) instead of a lambda keeps the module
        # picklable when the shortcut needs no projection.
        if in_channels != out_channels:
            self.down = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.down = nn.Identity()

        self.bn = nn.BatchNorm2d(out_channels)
        self.soft = nn.Softmax(-2)
        self.relu = nn.ReLU(inplace=True)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)
        # The output batch norm starts with a near-zero scale, so the unit
        # initially outputs (almost) only its shortcut.
        bn_init(self.bn, 1e-6)
        for conv in self.conv_d:
            conv_branch_init(conv, num_subset)

    def forward(self, x):
        """Apply the graph convolution to a 4-D tensor unpacked as ``(N, C, T, V)``."""
        N, C, T, V = x.size()
        A = self.A + self.PA
        # reshape (not view) so a non-contiguous input is accepted; the
        # flattened input is the same for every subset.
        x_flat = x.reshape(N, C * T, V)

        y = None
        for i, (conv_a, conv_b, conv_d) in enumerate(zip(self.conv_a, self.conv_b, self.conv_d)):
            A1 = conv_a(x).permute(0, 3, 1, 2).contiguous().view(N, V, self.inter_c * T)
            A2 = conv_b(x).view(N, self.inter_c * T, V)
            A1 = self.soft(torch.matmul(A1, A2) / A1.size(-1))
            A1 = A1 + A[i]
            z = conv_d(torch.matmul(x_flat, A1).view(N, C, T, V))
            y = z + y if y is not None else z

        y = self.bn(y) + self.down(x)
        return self.relu(y)

class TCN_GCN_unit(nn.Module):
    """Graph convolution followed by a multi-scale temporal convolution, with an optional residual.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        A: Adjacency array forwarded to ``unit_gcn``.
        stride: Stride of the temporal convolution (and of the residual
            projection when one is needed).
        residual: When false, no residual is added.
    """
    def __init__(self, in_channels, out_channels, A, stride=1, residual=True):
        super(TCN_GCN_unit, self).__init__()
        self.gcn1 = unit_gcn(in_channels, out_channels, A)
        self.tcn1 = unit_tcn_m(out_channels, out_channels, stride=stride)
        self.relu = nn.ReLU(inplace=True)
        self.residual = self._make_residual(in_channels, out_channels, stride, residual)

    @staticmethod
    def _make_residual(in_channels, out_channels, stride, residual):
        """Return the callable that produces the residual term for an input.

        The former one-line conditional parsed as
        ``lambda x: (x if residual else (lambda x: 0))``, so with
        ``residual=False`` it returned a function instead of ``0`` and the
        addition in ``forward`` raised ``TypeError``.
        """
        if not residual:
            return lambda x: 0
        if in_channels == out_channels and stride == 1:
            return lambda x: x
        # Shapes differ: project the input with a kernel-size-1 temporal unit.
        return unit_tcn(in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x):
        """Return ``relu(tcn(gcn(x)) + residual(x))``."""
        return self.relu(self.tcn1(self.gcn1(x)) + self.residual(x))

class ScaleNorm(nn.Module):
    """Rescale vectors along the last dimension to length ``scale``.

    ``eps`` is added to the norm before dividing, so all-zero vectors stay zero.
    """
    def __init__(self, scale, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.scale = scale

    def forward(self, x):
        """Return ``x * scale / (||x|| + eps)`` with the norm taken over the last dimension."""
        length = torch.norm(x, dim=-1, keepdim=True)
        gain = self.scale / (length + self.eps)
        return x * gain

class tcn_unit_attention(nn.Module):
    """Multi-head self-attention along tensor dimension 2, with an optional parallel temporal conv.

    The input is a 4-D tensor unpacked as ``(N, C, T, V)`` (presumably batch,
    channels, frames, joints; confirm against the caller). Every index of the
    last axis is treated as an independent sequence of length ``T``.
    Attention produces ``dv`` channels; unless ``only_temporal_att`` is set, a
    ``Unit2D`` temporal convolution produces the remaining
    ``out_channels - dv`` channels and the two are concatenated.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        dv_factor: ``dv = int(dv_factor * out_channels)`` (ignored when
            ``only_temporal_att`` is true, in which case ``dv = out_channels``).
        dk_factor: ``dk = int(dk_factor * out_channels)``.
        Nh: Number of attention heads; must divide both ``dk`` and ``dv``.
        kernel_size_temporal: Kernel length of the parallel temporal conv.
        stride: Stride along dimension 2, applied to the Q/K/V projection, the
            temporal conv and the skip projection alike.
        num_point: Size of the last input axis expected by ``data_bn``.
        relative: Add learned relative-position logits.
        only_temporal_att: Use attention for all output channels.
        bn_flag: Apply BatchNorm2d to the result.
        data_normalization: Apply BatchNorm1d over ``C * V`` features first.
        skip_conn: Add a (possibly projected) copy of the input.
        drop_connect: While training, drop attention keys with probability 0.5.
    """
    def __init__(self, in_channels, out_channels, dv_factor=0.5, dk_factor=0.5, Nh=3,
                 kernel_size_temporal=9, stride=1, num_point=18, relative=True,
                 only_temporal_att=False, bn_flag=True, data_normalization=True,
                 skip_conn=True, drop_connect=True):
        super(tcn_unit_attention, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dk = int(dk_factor * out_channels)
        self.dv = int(dv_factor * out_channels) if not only_temporal_att else out_channels
        self.Nh = Nh
        self.kernel_size_temporal = kernel_size_temporal
        self.stride = stride
        self.num_point = num_point
        self.relative = relative
        self.only_temporal_att = only_temporal_att
        self.bn_flag = bn_flag
        self.data_normalization = data_normalization
        self.skip_conn = skip_conn
        self.drop_connect = drop_connect

        # Checked before any layer size is derived from the head split, so an
        # invalid Nh reports these messages rather than a ZeroDivisionError.
        assert self.Nh > 0, "Nh must be >= 1"
        assert self.dk % self.Nh == 0, f"dk ({self.dk}) must be divisible by Nh ({self.Nh})"
        assert self.dv % self.Nh == 0, f"dv ({self.dv}) must be divisible by Nh ({self.Nh})"

        self.bn = nn.BatchNorm2d(out_channels) if bn_flag else nn.Identity()
        self.relu = nn.ReLU(inplace=True)

        if data_normalization:
            self.data_bn = nn.BatchNorm1d(in_channels * num_point)

        if not only_temporal_att:
            self.tcn_conv = Unit2D(in_channels, out_channels - self.dv, kernel_size=kernel_size_temporal,
                                   stride=stride, dim=2, dropout=0.25 if dropout_flag else 0)

        # The projection sees tensors laid out as (N*V, C, 1, T), so the
        # stride goes on the last axis. Without it the attention branch keeps
        # length T while tcn_conv/down shorten it, and the concat/add fails
        # for stride != 1.
        self.qkv_conv = nn.Conv2d(in_channels, 2 * self.dk + self.dv, kernel_size=1, stride=(1, stride))
        self.attn_out = nn.Conv2d(self.dv, self.dv, kernel_size=1, stride=1)

        if relative:
            # One embedding per relative offset in [-9, 9]; longer or shorter
            # sequences are handled in _relative_keys.
            self.key_rel = nn.Parameter(torch.randn(2 * 10 - 1, self.dk // Nh))

        self.down = Unit2D(in_channels, out_channels, kernel_size=1, stride=stride, dim=2, dropout=0) if (skip_conn and (in_channels != out_channels or stride != 1)) else None

    def forward(self, x):
        """Apply the attention unit.

        Raises:
            ValueError: If data normalisation is enabled and the last input
                axis does not have ``num_point`` entries.
        """
        N, C, T, V = x.size()
        x_sum = x

        if self.data_normalization:
            if V != self.num_point:
                raise ValueError(
                    f"expected {self.num_point} entries on the last axis (num_point), got {V}"
                )
            x = x.permute(0, 1, 3, 2).reshape(N, C * V, T)
            x = self.data_bn(x)
            x = x.reshape(N, C, V, T).permute(0, 1, 3, 2)

        # Fold the last axis into the batch: (N, C, T, V) -> (N*V, C, 1, T).
        x = x.permute(0, 3, 1, 2).reshape(-1, C, 1, T)

        if scale_norm:
            x = ScaleNorm(scale=C ** 0.5)(x)

        flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh)
        # T_out differs from T when the projection is strided.
        B, Nh, _, T_out = flat_q.size()

        logits = torch.matmul(flat_q.transpose(2, 3), flat_k)

        if self.relative:
            logits = logits + self.relative_logits(q)

        weights = F.softmax(logits, dim=-1)

        if self.drop_connect and self.training:
            # One Bernoulli(0.5) draw per key position, shared by all queries.
            mask = torch.bernoulli(weights.new_full((B, Nh, 1, T_out), 0.5))
            weights = weights * mask
            # Renormalise AFTER masking so every row sums to one again;
            # dividing by the pre-mask sum (always 1) was a no-op.
            weights = weights / (weights.sum(3, keepdim=True) + 1e-8)

        # (B, Nh, T_out, dv/Nh) -> (B, Nh, dv/Nh, 1, T_out). B is N*V here and
        # the channel/time axes must be swapped with a permute, not a reshape.
        attn_out = torch.matmul(weights, flat_v.transpose(2, 3))
        attn_out = attn_out.reshape(B, Nh, 1, T_out, self.dv // Nh).permute(0, 1, 4, 2, 3)
        attn_out = self.combine_heads_2d(attn_out)
        attn_out = self.attn_out(attn_out)
        # Unfold the batch again: (N*V, dv, 1, T_out) -> (N, dv, T_out, V).
        attn_out = attn_out.reshape(N, V, -1, T_out).permute(0, 2, 3, 1)

        if self.skip_conn:
            result = attn_out
            if not self.only_temporal_att:
                x_tcn = self.tcn_conv(x_sum)
                result = torch.cat((x_tcn, attn_out), dim=1)
            result = result + (x_sum if self.down is None else self.down(x_sum))
        else:
            result = attn_out

        result = self.bn(result)
        return self.relu(result)

    def compute_flat_qkv(self, x, dk, dv, Nh):
        """Project ``x`` to Q/K/V, split heads and flatten the last two axes.

        Returns:
            ``(flat_q, flat_k, flat_v, q, k, v)``; ``q`` is already scaled by
            ``(dk // Nh) ** -0.5``.
        """
        qkv = self.qkv_conv(x)
        N, _, V1, T1 = qkv.size()
        q, k, v = torch.split(qkv, [dk, dk, dv], dim=1)
        q = self.split_heads_2d(q, Nh)
        k = self.split_heads_2d(k, Nh)
        v = self.split_heads_2d(v, Nh)
        dkh = dk // Nh
        q = q * (dkh ** -0.5)
        flat_q = q.reshape(N, Nh, dkh, V1 * T1)
        flat_k = k.reshape(N, Nh, dkh, V1 * T1)
        flat_v = v.reshape(N, Nh, dv // Nh, V1 * T1)
        return flat_q, flat_k, flat_v, q, k, v

    def split_heads_2d(self, x, Nh):
        """Reshape ``(B, channels, H, W)`` to ``(B, Nh, channels // Nh, H, W)``."""
        B, channels, height, width = x.size()
        return x.reshape(B, Nh, channels // Nh, height, width)

    def combine_heads_2d(self, x):
        """Reshape ``(B, Nh, d, H, W)`` to ``(B, Nh * d, H, W)``."""
        batch, Nh, dv, height, width = x.size()
        return x.reshape(batch, Nh * dv, height, width)

    def _relative_keys(self, length):
        """Return the ``2 * length - 1`` relative-key rows for a sequence of ``length``.

        Row ``m`` corresponds to the offset ``m - (length - 1)``. Offsets
        beyond the range covered by ``key_rel`` are clipped to its first/last
        row, so any sequence length works with the fixed-size parameter.
        """
        max_dist = (self.key_rel.size(0) - 1) // 2
        offsets = torch.arange(-(length - 1), length, device=self.key_rel.device)
        index = offsets.clamp(-max_dist, max_dist) + max_dist
        return self.key_rel[index]

    def relative_logits(self, q):
        """Compute relative-position logits of shape ``(B, Nh, L, L)`` with ``L = V * T``."""
        B, Nh, dk, V, T = q.size()
        length = V * T
        q = q.permute(0, 1, 3, 4, 2).reshape(B, Nh, length, dk)
        rel_logits = torch.einsum('bhld,md->bhlm', q, self._relative_keys(length))
        rel_logits = self.rel_to_abs(rel_logits)
        return rel_logits.reshape(B, Nh, length, length)

    def rel_to_abs(self, x):
        """Convert relative-offset logits to absolute-position logits.

        Args:
            x: Tensor of shape ``(B, Nh, L, 2L - 1)`` whose last axis is
                indexed by ``j - i + L - 1``.

        Returns:
            Tensor of shape ``(B, Nh, L, L)`` with ``out[..., i, j] = x[..., i, j - i + L - 1]``.

        Raises:
            ValueError: If the last axis does not have ``2L - 1`` entries.
        """
        B, Nh, L, K = x.size()
        if K != 2 * L - 1:
            raise ValueError(f"expected {2 * L - 1} relative positions for length {L}, got {K}")
        # Pad one column, flatten, pad L-1 more and re-view with rows of
        # length K: each row is shifted by one relative to the previous one,
        # which lines the offsets up with absolute positions.
        col_pad = x.new_zeros(B, Nh, L, 1)
        x = torch.cat((x, col_pad), dim=3)
        flat_x = x.reshape(B, Nh, L * (K + 1))
        flat_pad = x.new_zeros(B, Nh, L - 1)
        flat_x_padded = torch.cat((flat_x, flat_pad), dim=2)
        final_x = flat_x_padded.reshape(B, Nh, L + 1, K)
        return final_x[:, :, :L, L - 1:]

class Residual(nn.Module):
    """Wrap a callable so that its input is added to its output."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``."""
        transformed = self.fn(x, **kwargs)
        return transformed + x

class LayerNormalize(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input before calling the wrapped callable."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

class MLP_Block(nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Dropout -> Linear -> Dropout.

    Both linear layers use Xavier-uniform weights and biases drawn from a
    normal distribution with standard deviation 1e-6.
    """
    def __init__(self, dim, hid_dim, dropout=0.1):
        super().__init__()
        self.nn1 = self._make_linear(dim, hid_dim)
        self.af1 = nn.ReLU()
        self.do1 = nn.Dropout(dropout)
        self.nn2 = self._make_linear(hid_dim, dim)
        self.do2 = nn.Dropout(dropout)

    @staticmethod
    def _make_linear(in_features, out_features):
        """Create a linear layer and apply this block's initialisation to it."""
        layer = nn.Linear(in_features, out_features)
        nn.init.xavier_uniform_(layer.weight)
        nn.init.normal_(layer.bias, std=1e-6)
        return layer

    def forward(self, x):
        """Run ``x`` through the five stages in order."""
        for stage in (self.nn1, self.af1, self.do1, self.nn2, self.do2):
            x = stage(x)
        return x

class Attention(nn.Module):
    """Multi-head self-attention over a 3-D input, with an optional additive mask.

    Logits are scaled by ``dim ** -0.5``. When a mask is given, the logits
    become the average of the scaled dot products and the mask before the
    softmax.
    """
    def __init__(self, dim, out_dim, heads=3, dropout=0.1):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5

        # Joint projection to queries, keys and values.
        self.to_qkv = nn.Linear(dim, dim * 3, bias=True)
        nn.init.xavier_uniform_(self.to_qkv.weight)
        nn.init.zeros_(self.to_qkv.bias)

        # Output projection followed by dropout.
        self.nn1 = nn.Linear(dim, out_dim)
        nn.init.xavier_uniform_(self.nn1.weight)
        nn.init.zeros_(self.nn1.bias)
        self.do1 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Attend over dimension 1 of ``x`` and project to ``out_dim`` features."""
        _batch, _tokens, _features = x.shape  # unpacking requires a 3-D input
        heads = self.heads

        projected = self.to_qkv(x)
        q, k, v = rearrange(projected, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=heads)
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            assert mask.shape[-1] == dots.shape[-1], 'Mask has incorrect dimensions'
            dots = (dots + mask) * 0.5

        attn = dots.softmax(dim=-1)
        context = torch.einsum('bhij,bhjd->bhid', attn, v)
        merged = rearrange(context, 'b h n d -> b n (h d)')
        return self.do1(self.nn1(merged))

class Transformer(nn.Module):
    """Stack of ``depth`` layers, each an attention block followed by a residual MLP.

    The attention block is wrapped in a residual connection only when ``dim``
    equals ``mlp_dim``. The MLP is layer-normalised on its input, has hidden
    width ``2 * mlp_dim`` and is always residual.
    """
    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            attention = Attention(dim, mlp_dim, heads=heads, dropout=dropout)
            if dim == mlp_dim:
                attention = Residual(attention)
            feed_forward = MLP_Block(mlp_dim, mlp_dim * 2, dropout=dropout)
            feed_forward = Residual(LayerNormalize(mlp_dim, feed_forward))
            self.layers.append(nn.ModuleList([attention, feed_forward]))

    def forward(self, x, mask=None):
        """Pass ``x`` through every layer; ``mask`` is forwarded to each attention block."""
        for pair in self.layers:
            attention, mlp = pair
            x = mlp(attention(x, mask=mask))
        return x

class TCN_STRANSF_unit(nn.Module):
    """Two parallel branches (Transformer + multi-scale TCN, temporal attention) plus a residual.

    The input is a 4-D tensor unpacked as ``(B, C, T, V)``. Branch 1 runs a
    one-layer Transformer over the last axis for every ``(batch, time)`` pair
    and then a multi-scale temporal convolution. Branch 2 is a
    ``tcn_unit_attention``. The branch outputs and the residual of the block
    input are summed and passed through ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        heads: Number of attention heads for both branches.
        stride: Stride along tensor dimension 2.
        residual: When false, no residual of the input is added.
        dropout: Dropout probability inside the Transformer.
        mask: Optional tensor stored as the default attention mask.
        mask_grad: Whether the stored mask is trainable.
        num_point: Size of the last input axis; sizes the attention branch's
            input batch norm.
    """
    def __init__(self, in_channels=3, out_channels=48, heads=3, stride=1, residual=True,
                 dropout=0.1, mask=None, mask_grad=True, num_point=18):
        super(TCN_STRANSF_unit, self).__init__()
        self.attn = tcn_unit_attention(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size_temporal=9,
            stride=stride,
            Nh=heads,
            dk_factor=0.5,
            dv_factor=0.5,
            num_point=num_point,
            relative=True,
            only_temporal_att=False,
            skip_conn=True,
            bn_flag=True,
            drop_connect=True
        )
        self.transf1 = Transformer(dim=in_channels, depth=1, heads=heads, mlp_dim=in_channels, dropout=dropout)
        self.tcn1 = unit_tcn_m(in_channels, out_channels, stride=stride)
        self.relu = nn.ReLU(inplace=True)
        self.out_channels = out_channels

        self.residual = self._make_residual(in_channels, out_channels, stride, residual)

        # Clone so every unit owns its mask: nn.Parameter(mask) would share
        # storage with the caller's tensor, and therefore with every other
        # unit built from it, so training one mask would silently change the
        # "frozen" ones too.
        self.mask = nn.Parameter(mask.detach().clone(), requires_grad=mask_grad) if mask is not None else None

    @staticmethod
    def _make_residual(in_channels, out_channels, stride, residual):
        """Return the callable that produces the residual term for an input.

        The former one-line conditional parsed as
        ``lambda x: (x if residual else (lambda x: 0))`` and returned a
        function, not ``0``, when ``residual`` was false.
        """
        if not residual:
            return lambda x: 0
        if in_channels == out_channels and stride == 1:
            return lambda x: x
        return unit_tcn(in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x, mask=None):
        """Apply the unit; ``mask`` overrides the stored mask when given."""
        B, C, T, V = x.size()

        # Branch 1: Transformer + TCN
        tx = x.permute(0, 2, 3, 1).contiguous().view(B * T, V, C)
        tx = self.transf1(tx, self.mask if mask is None else mask)
        tx = tx.view(B, T, V, C).permute(0, 3, 1, 2).contiguous()
        tx = self.tcn1(tx)

        # Branch 2: Temporal attention
        tcn_tx = self.attn(x)

        # Combine outputs. The residual must come from the block INPUT; taking
        # it from the combined output doubled the result when shapes matched
        # and fed out_channels into an in_channels projection otherwise.
        out = tx + tcn_tx + self.residual(x)
        return self.relu(out)

class ZiT(nn.Module):
    """Skeleton feature extraction module.

    Args:
        in_channels: Channels per joint in the input.
        num_person: Number of persons; sizes the input batch norm.
        num_point: Number of joints; sizes the batch norms and attention units.
        num_head: Number of attention heads.
        graph: Dotted path of the graph class; required.
        graph_args: Keyword arguments for the graph class (none when omitted).

    Raises:
        ValueError: If ``graph`` is None.
    """
    def __init__(self, in_channels=3, num_person=5, num_point=18, num_head=6, graph=None, graph_args=None):
        super(ZiT, self).__init__()
        self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)
        bn_init(self.data_bn, 1)
        self.heads = num_head

        if graph is None:
            raise ValueError("Graph is required")
        Graph = import_class(graph)
        self.graph = Graph(**(graph_args or {}))
        # First adjacency subset; used as the attention mask of every unit.
        self.A = torch.from_numpy(self.graph.A[0].astype(np.float32))

        self.l1 = TCN_GCN_unit(in_channels, 48, self.graph.A, residual=False)
        self.l2 = TCN_STRANSF_unit(48, 48, heads=num_head, mask=self.A, mask_grad=False, num_point=num_point)
        self.l3 = TCN_STRANSF_unit(48, 48, heads=num_head, mask=self.A, mask_grad=False, num_point=num_point)
        self.l4 = TCN_STRANSF_unit(48, 96, heads=num_head, stride=2, mask=self.A, mask_grad=True, num_point=num_point)
        self.l5 = TCN_STRANSF_unit(96, 96, heads=num_head, mask=self.A, mask_grad=True, num_point=num_point)
        self.l6 = TCN_STRANSF_unit(96, 192, heads=num_head, stride=2, mask=self.A, mask_grad=True, num_point=num_point)
        self.l7 = TCN_STRANSF_unit(192, 192, heads=num_head, mask=self.A, mask_grad=True, num_point=num_point)

    def forward(self, x):
        """Map a 5-D input unpacked as ``(N, C, T, V, M)`` to per-person features ``(N, 192, T', M)``."""
        N, C, T, V, M = x.size()
        x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
        x = self.data_bn(x)
        # Fold the person axis into the batch: (N*M, C, T, V).
        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)

        for layer in (self.l1, self.l2, self.l3, self.l4, self.l5, self.l6, self.l7):
            x = layer(x)

        B, C_, T_, V_ = x.size()
        # Average over the last (joint) axis and move the person axis last.
        x = x.view(N, M, C_, T_, V_).mean(4).permute(0, 2, 3, 1).contiguous()
        return x

class ZoT(nn.Module):
    """Group interaction and classification module.

    Args:
        num_class: Number of output classes.
        num_head: Number of attention heads (and mask channels).
        num_person: Size of the last input axis; sizes the attention units'
            input batch norm, which would otherwise assume 18 entries.
    """
    def __init__(self, num_class=15, num_head=6, num_person=5):
        super(ZoT, self).__init__()
        self.heads = num_head
        self.conv1 = nn.Conv2d(192, num_head, kernel_size=(1, 1))
        self.conv2 = nn.Conv2d(192, num_head, kernel_size=(1, 1))
        conv_init(self.conv1)
        conv_init(self.conv2)

        self.l1 = TCN_STRANSF_unit(192, 276, heads=num_head, num_point=num_person)
        self.l2 = TCN_STRANSF_unit(276, 276, heads=num_head, num_point=num_person)

        self.fc = nn.Linear(276, num_class)
        nn.init.normal_(self.fc.weight, 0, math.sqrt(2. / num_class))

    def forward(self, x):
        """Classify a 4-D input with 192 channels; returns ``(N, num_class)`` scores."""
        x1 = self.conv1(x).unsqueeze(3)
        x2 = self.conv2(x).unsqueeze(4)
        # Pairwise differences over the last axis, one map per head and frame.
        mask = x1 - x2
        N, C, T, M, M2 = mask.shape
        # NOTE(review): detach() means conv1/conv2 never receive gradients
        # through the mask; confirm this is intended.
        mask = mask.permute(0, 2, 1, 3, 4).contiguous().view(N * T, C, M, M2).detach()
        mask = mask.softmax(dim=-1)

        x = self.l1(x, mask)
        x = self.l2(x, mask)
        x = x.mean(3).mean(2)
        return self.fc(x)

class Model(nn.Module):
    """Full model: ZiT feature extractor followed by the ZoT group classifier."""
    def __init__(self, num_class=15, in_channels=3, num_person=5, num_point=18, num_head=6, graph=None, graph_args=None):
        super(Model, self).__init__()
        self.body_transf = ZiT(in_channels=in_channels, num_person=num_person, num_point=num_point, num_head=num_head, graph=graph, graph_args=graph_args)
        self.group_transf = ZoT(num_class=num_class, num_head=num_head, num_person=num_person)

    def forward(self, x):
        """Return class scores for ``x``."""
        x = self.body_transf(x)
        x = self.group_transf(x)
        return x