In [1]:
from __future__ import annotations

import math

import torch
import torch.nn as nn
from torch import Tensor, BoolTensor
from torch.nn import functional as F

In [13]:
from model.transformer import TransformerBlock, CausalAttention

In [5]:
from model.time2vec import SineActivation, CosineActivation

# Names of the time2vec activations. QCCT.__init__ below branches on these same
# two strings but does not read this list itself.
VALID_T2V_ACTIVATION = ["sin", "cos"]

In [9]:
import torch

# Scratch check: what is left after dropping feature 0 when there is only one
# feature per event?
B, S, C = (4, 5, 1)  # batch size, sequence length, feature count
tensor = torch.randn((B, S, C))

# `1:` skips feature 0 but keeps the last axis, so the result is still 3-D;
# with C == 1 that last axis is simply empty.
slice_tensor = tensor[..., 1:]

# Report both shapes side by side.
print("Original shape:", tensor.shape)  # (B, S, C)
print("Slice shape:", slice_tensor.shape)  # (B, S, C - 1), i.e. (4, 5, 0) here

Original shape: torch.Size([4, 5, 1])
Slice shape: torch.Size([4, 5, 0])


In [45]:
class QCCT(nn.Module):
    """QUIC Congestion Control Transformer.

    Embeds a sequence of events, passes it through a stack of transformer
    blocks and projects every position to one scalar.

    The input is laid out as ``(B, S, C)``. Feature 0 of each event is treated
    as the timestamp and embedded with time2vec; the remaining ``C - 1``
    features are embedded with a small MLP. The two embeddings are summed.

    Args:
        n_features: Number of input features per event, timestamp included.
        hidden_size: Embedding width used throughout the model.
        n_heads: Forwarded to every ``TransformerBlock`` as ``num_heads``.
        n_layers: Number of transformer blocks.
        expand_size: Hidden width of the feature MLP; also forwarded to every
            ``TransformerBlock``.
        context_size: Forwarded to every ``TransformerBlock``.
        t2v_act: Time2vec activation, ``"sin"`` or ``"cos"``.
        act: Activation class (not an instance); it is called with no
            arguments for the feature MLP and forwarded to every block.
        attention: Forwarded to every ``TransformerBlock``.
        drop: Dropout probability used for every dropout in the model.
        bias: Whether the ``nn.Linear`` layers created here carry a bias;
            also forwarded to every block.

    Raises:
        ValueError: If ``t2v_act`` is neither ``"sin"`` nor ``"cos"``.
    """

    def __init__(
        self,
        n_features: int,
        hidden_size: int,
        n_heads: int,
        n_layers: int,
        expand_size: int,
        context_size: int,
        t2v_act: str = "sin",
        act: nn.Module = nn.GELU,
        attention: nn.Module = CausalAttention,
        drop: float = 0.1,
        bias: bool = True,
    ):
        super().__init__()

        # Kept for the depth-scaled initialisation in _init_weights.
        self.n_layers = n_layers

        # 1. Features:
        # 1.1 timestamp
        if t2v_act == "sin":
            self.t2v = SineActivation(1, hidden_size)
        elif t2v_act == "cos":
            self.t2v = CosineActivation(1, hidden_size)
        else:
            # ValueError is a subclass of Exception, so callers that caught
            # the previous bare Exception keep working.
            raise ValueError(f"Unsupported activation:{t2v_act} for time2vec")
        # 1.2 other features
        self.o2v = nn.ModuleList(
            [
                nn.Linear(n_features - 1, expand_size, bias=bias),
                act(),
                nn.Linear(expand_size, hidden_size, bias=bias),
                nn.Dropout(drop),
            ]
        )
        # 1.3 feature dropout
        self.f_drop = nn.Dropout(drop)

        # 2. transformer blocks
        # initialize n_layers transformer layers
        self.tfm_blocks = nn.ModuleList(
            [
                TransformerBlock(
                    hidden_size=hidden_size,
                    num_heads=n_heads,
                    context_size=context_size,
                    expand_size=expand_size,
                    attention=attention,
                    act=act,
                    bias=bias,
                    attn_drop=drop,
                    out_drop=drop,
                    ffn_drop=drop,
                )
                for _ in range(n_layers)
            ]
        )

        # 3. output
        self.final = nn.Linear(hidden_size, 1, bias=bias)

        # 4. init parameters
        # Walk named_modules() instead of using apply(): apply() hands over the
        # module only, and the attribute name is what identifies the "fc2"
        # layers that get the depth-scaled init.
        for name, module in self.named_modules():
            self._init_weights(module, name)

    def forward(self, x: Tensor):
        """Map a batch of event sequences to one scalar per event.

        Args:
            x: Tensor of shape ``(B, S, C)`` (batch, events, features) whose
                feature 0 is the timestamp.

        Returns:
            Tensor of shape ``(B, S, 1)``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        # [Input]: (B, S, C)
        # B: batch_size, S: n_events, C: n_features
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape (B, S, C), got {tuple(x.shape)}"
            )
        # Step 1: (B, S, C) -> (B, S, D)
        # B: batch_size, S: n_events, D: hidden_size

        # Step 1.1: timestamp
        # (B, S, 1)
        timestamp = x[:, :, 0].unsqueeze(-1)
        # (B, S, D)
        f_ts = self.t2v(timestamp)

        # Step 1.2: other features
        # (B, S, C-1)
        f_others = x[:, :, 1:]
        # (B, S, D)
        for layer in self.o2v:
            f_others = layer(f_others)

        # Step 1.3: Addition
        f_all = self.f_drop(f_ts + f_others)

        # Step 2: transformer blocks
        for block in self.tfm_blocks:
            f_all = block(f_all)

        # Step 3: next congestion control window
        return self.final(f_all)

    def _init_weights(self, module, name: str = ""):
        """Initialise one ``nn.Linear``; every other module type is left alone.

        Weights are drawn from N(0, 0.02^2) and biases are zeroed. A layer
        whose attribute name is ``fc2`` uses the GPT-2 style standard deviation
        ``0.02 / sqrt(2 * n_layers)`` instead.

        Args:
            module: Module to initialise.
            name: Qualified name of ``module`` as yielded by
                ``named_modules()``. It defaults to ``""`` so that
                ``self.apply(self._init_weights)`` still works; in that case
                every Linear gets the plain 0.02 init.
        """
        if not isinstance(module, nn.Linear):
            return

        # Module._get_name() returns the class name ("Linear"), so the layer
        # has to be recognised by the last component of its qualified name.
        if name.rsplit(".", 1)[-1] == "fc2":
            # GPT-2 style FFN init
            std = 0.02 / math.sqrt(2 * self.n_layers)
        else:
            std = 0.02
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

In [46]:
# Hyper-parameters of a small model used to smoke-test QCCT below.
n_features, hidden_size = 8, 64  # features per event (timestamp included), embedding width
n_heads, n_layers = 4, 4  # attention heads per block, number of blocks
expand_size, context_size = 128, 32  # MLP hidden width, context size passed to each block

In [47]:
# Build a QCCT from the hyper-parameters defined in the previous cell; every
# constructor argument not listed here keeps its default.
model = QCCT(
    n_features=n_features,
    hidden_size=hidden_size,
    n_heads=n_heads,
    n_layers=n_layers,
    expand_size=expand_size,
    context_size=context_size,
)
# Bare expression so the notebook displays the module tree.
model

QCCT(
  (t2v): SineActivation()
  (o2v): ModuleList(
    (0): Linear(in_features=7, out_features=128, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): Dropout(p=0.1, inplace=False)
  )
  (f_drop): Dropout(p=0.1, inplace=False)
  (tfm_blocks): ModuleList(
    (0-3): 4 x TransformerBlock(
      (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (attn): CausalAttention(
        (Wqkv): Linear(in_features=64, out_features=192, bias=True)
        (attn_drop): Dropout(p=0.1, inplace=False)
        (Wo): Linear(in_features=64, out_features=64, bias=True)
        (out_drop): Dropout(p=0.1, inplace=False)
      )
      (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ffn): FeedForward(
        (fc1): Linear(in_features=64, out_features=128, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=128, out_features=64, bias=True)
        (drop): Dropout(p=0.1, inplac

In [48]:
# One random event (batch of 1, sequence of 1) carrying 8 features, used as a
# minimal input for the model built above.
B, S, C = (1, 1, 8)
tensor = torch.randn((B, S, C))
tensor

tensor([[[ 0.6939, -0.1266,  0.1713,  1.2234,  0.2717,  0.8448, -1.5471,
          -0.5695]]])

In [49]:
# Smoke test: push the random (1, 1, 8) input through the model; the final
# Linear projects to one value per event, so the result is (1, 1, 1).
# NOTE(review): no `model.eval()` call is visible above, so dropout is
# presumably active and the value changes from run to run — confirm.
model(tensor)

tensor([[[0.2523]]], grad_fn=<ViewBackward0>)

In [7]:
class GPTForCausalLM(GPT):
    """GPT with a next-token prediction loss computed in ``forward``.

    Args:
        loss_fn: Module called as ``loss_fn(logits_2d, labels_1d)``. When
            omitted, each instance creates its own ``nn.CrossEntropyLoss()``
            rather than sharing one default object across all models.
        **kwargs: Passed through unchanged to ``GPT.__init__``.
    """

    def __init__(self, loss_fn: "nn.Module | None" = None, **kwargs):
        super().__init__(**kwargs)
        # A default written in the signature is evaluated once, so the same
        # module instance would be registered as a child of every model.
        self.loss_fn = nn.CrossEntropyLoss() if loss_fn is None else loss_fn

    def forward(self, x: Tensor):
        """Compute logits and the next-token loss for a batch of token ids.

        Args:
            x: Tensor indexed as ``(B, S)`` (batch, sequence length).

        Returns:
            Dict with ``"logits"`` (output of ``GPT.forward`` on all but the
            last token) and ``"loss"`` (``loss_fn`` applied to the flattened
            logits and the flattened shifted labels).
        """
        # the labels are the next token, so shift the labels over one
        # & resize inputs to same length as labels by dropping last token
        inputs = x[:, :-1]
        labels = x[:, 1:]

        # logits are of shape batch, sequence length, vocab size (B, S, VS),
        # labels are of shape batch, sequence length (B, S)
        logits = super().forward(inputs)

        # flatten logits into (B*S, VS) and labels into (B*S) & calculate loss.
        # reshape() rather than view(): x[:, 1:] is a non-contiguous slice, and
        # view(-1) on it raises a RuntimeError as soon as the batch has more
        # than one row.
        loss = self.loss_fn(
            logits.reshape(-1, logits.shape[-1]), labels.reshape(-1)
        )

        # return both the logits and the loss
        return {"logits": logits, "loss": loss}