In [1]:
import torch
import torch.nn as nn
import torch.functional as F

In [2]:
from transformers import RobertaModel, RobertaTokenizer


In [3]:
model = RobertaModel.from_pretrained("roberta-base")
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [132]:
text = ["this is a sample text", "another text about lorem ipsum"]

In [136]:
ids = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

In [137]:
ids

{'input_ids': tensor([[    0,  9226,    16,    10,  7728,  2788,     2,     1,     1,     1],
        [    0, 30303,  2788,    59, 36307,   119,  1437,  7418,   783,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [138]:
res = model(**ids)

In [142]:
O = res.last_hidden_state

In [203]:
model.training

False

In [326]:
import abc
import torch
import torch.nn as nn
import einops
from torch import Tensor
from typing import Any, Optional


class Whitening2d(nn.Module):
    def __init__(self, 
        num_features,
        iterations=4, 
        use_running_stats_train=True,
        use_batch_whitening=False,
        use_only_running_stats_eval=False,
        track_running_stats: bool = True,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        device=None,
        dtype=None,
                ):
        super(Whitening2d, self).__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.num_features=num_features
        self.iterations=iterations
        self.use_batch_whitening=use_batch_whitening
        self.use_running_stats_train=use_running_stats_train
        self.use_only_running_stats_eval=use_only_running_stats_eval
        self.track_running_stats=track_running_stats
        self.momentum=momentum
        self.affine=affine

        if self.affine:
            self.weight=torch.nn.Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias=torch.nn.Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        
        if self.track_running_stats:
            self.register_buffer(
                "running_mean", torch.zeros(num_features, **factory_kwargs)
            )
            self.register_buffer(
                "running_covariance", torch.ones(num_features, num_features, **factory_kwargs)
            )
            self.register_buffer(
                "running_whitening", torch.ones(num_features, num_features, **factory_kwargs)
            )
            self.running_mean: Optional[Tensor]
            self.running_covariance: Optional[Tensor]
            self.running_whitening: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_covariance")
            self.running_whitening("running_whitening")
        self.reset_parameters()

        
    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            self.running_mean.zero_()  
            self.running_covariance.fill_(1)
            self.running_whitening.fill_(1)

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            torch.nn.init.ones_(self.weight)
            torch.nn.init.zeros_(self.bias)

    def update_running_statistic(self, running_statistic, value):
        cur = getattr(self, running_statistic,)
        setattr(self, running_statistic, 
                (1-self.momentum)*cur + self.momentum*value
                )

    def forward_train(self, x):
        
        batch_size, w_dim = x.size(0), x.size(-1)
        
        m_r = x.mean(1, keepdim=True)
        if self.use_running_stats_train:
            m = (1-self.momentum)*self.running_mean + self.momentum*m_r
        else:
            m = m_r
        
        xn = x - m

        eye, sigma_r = self.calc_eye_sigma(xn, w_dim=w_dim, batch_size=batch_size)
        if self.use_running_stats_train:
            sigma = (1-self.momentum)*self.running_covariance[None, :, :] + self.momentum*sigma_r
        else:
            sigma = sigma_r

        wh_matrix = self.whiten_matrix(sigma=sigma, eye=eye)

        if self.track_running_stats:
            self.update_running_statistic("running_mean", m_r.mean(dim=0))
            self.update_running_statistic("running_covariance", sigma_r.mean(dim=0))
            self.update_running_statistic("running_whitening", wh_matrix.mean(dim=0))

        decorrelated = torch.bmm(xn, wh_matrix)
        return decorrelated
    
    @torch.no_grad
    def forward_test(self, x):

        batch_size, w_dim = x.size(0), x.size(-1)

        if self.use_only_running_stats_eval:
            xn = x - self.running_mean
            decorrelated = torch.bmm(xn, 
                                     einops.repeat(self.running_whitening, "feats1 feats2 -> batch feats1 feats2", batch=batch_size), 
                                     )
            return decorrelated
        
        m = x.mean(1, keepdim=True)
        m = (1-self.momentum)*self.running_mean + self.momentum*m

        xn = x - m

        eye, sigma = self.calc_eye_sigma(xn, w_dim=w_dim, batch_size=batch_size)
        sigma = (1-self.momentum)*self.running_covariance[None, :, :] + self.momentum*sigma

        wh_matrix = self.whiten_matrix(sigma=sigma, eye=eye)
        decorrelated = torch.bmm(xn, wh_matrix)
        return decorrelated

    def forward(self, x):

        if self.training:
            x = self.forward_train(x=x)
            return x
        x = self.forward_test(x=x)
        return x

    @abc.abstractmethod
    def whiten_matrix(self, sigma, eye):
        pass

    def calc_eye_sigma(self, xn, w_dim, batch_size):
        eye = einops.repeat(torch.eye(w_dim).type(xn.type()), 
                "feats1 feats2 -> batch feats1 feats2", batch=batch_size).to(xn.device)
        if self.use_batch_whitening:
            batch_cov = einops.rearrange(xn, "batch sequence feats -> (batch sequence) feats")
            sigma = einops.einsum(batch_cov, batch_cov, 
                                  "batch_seq feats1, batch_seq feats2 -> feats1 feats2") / (batch_cov.shape[0] - 1)
            sigma = einops.repeat(sigma, "feats1 feats2 -> batch feats1 feats2", batch=batch_size)
        else:
            sigma = einops.einsum(xn, xn, 
                                  "batch seq feats1, batch seq feats2 -> batch feats1 feats2") / (xn.shape[1] - 1)
        return eye, sigma

    def extra_repr(self):
        return (
            "{num_features}, iterations={iterations}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}, use_batch_whitening={use_batch_whitening}, "
            "use_running_stats_train={use_running_stats_train}, use_only_running_stats_eval={use_only_running_stats_eval}".format(**self.__dict__)
        )

class Whitening2dIterNorm(Whitening2d):

    def whiten_matrix(self, sigma, eye):
        trace = sigma.diagonal(offset=0, dim1=-2, dim2=-1).sum(-1)
        trace = trace.reshape(sigma.size(0), 1, 1)
        sigma_norm = sigma * trace.reciprocal()

        projection = eye
        for _ in range(self.iterations):
            projection = torch.baddbmm(projection, torch.matrix_power(projection, 3), sigma_norm, beta=1.5, alpha=-0.5)
        wm = projection.mul_(trace.reciprocal().sqrt())
        return wm
    






In [339]:
WH = Whitening2dIterNorm(num_features=75, iterations=4, use_batch_whitening=False, use_only_running_stats_eval=False, use_running_stats_train=False, device="cpu")





In [341]:
WH.train()

Whitening2dIterNorm(75, iterations=4, momentum=0.1, affine=True, track_running_stats=True, use_batch_whitening=False, use_running_stats_train=False, use_only_running_stats_eval=False)

In [349]:
X = WH(A)

In [348]:
WH.running_covariance

tensor([[0.9080, 0.8992, 0.8997,  ..., 0.9000, 0.9001, 0.9004],
        [0.8992, 0.9082, 0.9002,  ..., 0.9003, 0.8999, 0.8998],
        [0.8997, 0.9002, 0.9082,  ..., 0.8999, 0.9002, 0.9001],
        ...,
        [0.9000, 0.9003, 0.8999,  ..., 0.9082, 0.9004, 0.8995],
        [0.9001, 0.8999, 0.9002,  ..., 0.9004, 0.9084, 0.8999],
        [0.9004, 0.8998, 0.9001,  ..., 0.8995, 0.8999, 0.9083]])

In [346]:
WH.eval()

Whitening2dIterNorm(75, iterations=4, momentum=0.1, affine=True, track_running_stats=True, use_batch_whitening=False, use_running_stats_train=False, use_only_running_stats_eval=False)

In [350]:
A.shape

torch.Size([12, 100, 75])

In [351]:
X.shape, 

AttributeError: 'NoneType' object has no attribute 'shape'

In [211]:
(A.permute(0, 2, 1) @ A).mean(dim=0)

tensor([[33.7785, 23.8803, 25.4830,  ..., 24.8048, 24.9052, 25.8465],
        [23.8803, 31.8603, 24.9091,  ..., 24.0745, 23.5968, 24.2057],
        [25.4830, 24.9091, 33.8999,  ..., 24.6719, 24.9012, 25.5308],
        ...,
        [24.8048, 24.0745, 24.6719,  ..., 32.0714, 24.2643, 24.0076],
        [24.9052, 23.5968, 24.9012,  ..., 24.2643, 32.1149, 24.2961],
        [25.8465, 24.2057, 25.5308,  ..., 24.0076, 24.2961, 33.3672]])

In [212]:
(A.reshape(-1, 75).T @ A.reshape(-1, 75))/12

tensor([[33.7785, 23.8803, 25.4830,  ..., 24.8048, 24.9052, 25.8465],
        [23.8803, 31.8603, 24.9091,  ..., 24.0745, 23.5968, 24.2057],
        [25.4830, 24.9091, 33.8999,  ..., 24.6719, 24.9012, 25.5308],
        ...,
        [24.8048, 24.0745, 24.6719,  ..., 32.0714, 24.2643, 24.0076],
        [24.9052, 23.5968, 24.9012,  ..., 24.2643, 32.1149, 24.2961],
        [25.8465, 24.2057, 25.5308,  ..., 24.0076, 24.2961, 33.3672]])

In [107]:
import einops

A = torch.rand(12, 100, 75)

B = A - A.mean(dim=1, keepdim=True)

In [90]:
[TL(x, axis=0) for x in A]

[tensor(62.9139),
 tensor(62.9044),
 tensor(63.0548),
 tensor(63.3632),
 tensor(62.6881),
 tensor(63.1737),
 tensor(63.2783),
 tensor(62.7499),
 tensor(62.8380),
 tensor(62.8024),
 tensor(62.9698),
 tensor(63.0622)]

In [91]:
C = torch.bmm(B.permute(0, 2, 1), B) / 49

In [152]:
C.diagonal(offset=0, dim1=1, dim2=2).add(-1).pow(2).sum(dim=1)

tensor([62.9139, 62.9044, 63.0548, 63.3632, 62.6881, 63.1737, 63.2783, 62.7499,
        62.8380, 62.8024, 62.9698, 63.0622])

In [153]:
def trace_loss(output):
    output = output - output.mean(dim=1, keepdim=True)
    output = torch.bmm(output.permute(0, 2, 1), output)/(output.shape[1] - 1)
    tl = output.diagonal(offset=0, dim1=1, dim2=2).add(-1).pow(2).sum(dim=1).mean()
    return tl


In [94]:
[((1-torch.diag(x))**2).sum() for x in C]

[tensor(62.9139),
 tensor(62.9044),
 tensor(63.0548),
 tensor(63.3632),
 tensor(62.6881),
 tensor(63.1737),
 tensor(63.2783),
 tensor(62.7499),
 tensor(62.8380),
 tensor(62.8024),
 tensor(62.9698),
 tensor(63.0622)]

In [158]:
model

RobertaModel(
  (embeddings): RobertaEmbeddings(
    (word_embeddings): Embedding(50265, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (token_type_embeddings): Embedding(1, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): RobertaEncoder(
    (layer): ModuleList(
      (0-11): 12 x RobertaLayer(
        (attention): RobertaAttention(
          (self): RobertaSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): RobertaSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dr

In [162]:
model.base_model.embeddings.LayerNorm

LayerNorm((768,), eps=1e-05, elementwise_affine=True)

In [164]:
model.base_model.encoder.layer[1].output

RobertaIntermediate(
  (dense): Linear(in_features=768, out_features=3072, bias=True)
  (intermediate_act_fn): GELUActivation()
)

In [166]:
import regex as re

In [172]:
for name, module in model.named_modules():
    if name == "embeddings" or re.search("encoder\.layer\.[0-9]+\.output$", name):
        print(name)

embeddings
encoder.layer.0.output
encoder.layer.1.output
encoder.layer.2.output
encoder.layer.3.output
encoder.layer.4.output
encoder.layer.5.output
encoder.layer.6.output
encoder.layer.7.output
encoder.layer.8.output
encoder.layer.9.output
encoder.layer.10.output
encoder.layer.11.output


In [187]:
class RobertaClassifier(torch.nn.Module):
    def __init__(self, n_classes, cls_dropout=0.1, use_trace_loss=True):
        super().__init__()
        
        self.roberta = RobertaModel.from_pretrained("roberta-base")
        self.layer_losses = 0  # List to store layer-wise losses
        self.eff_ranks = {}
        self._register_eff_rank_hooks()

        if use_trace_loss:
            self._register_trace_loss_hooks()


        self.classifier = nn.Sequential(
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Dropout(cls_dropout),
            nn.Linear(768, n_classes))
        
        if n_classes == 1:
            self.criterion = nn.MSELoss()
        else:
            self.criterion = nn.CrossEntropyLoss()

    def _register_trace_loss_hooks(self):
        """Register hooks to calculate and store layer-wise losses"""
        def get_loss_hook(layer_name):
            def hook(module, input, output):
                if isinstance(output, tuple):
                    output = output[0]  # Handle cases where output is a tuple
                # Calculate MSE loss between input and output
                self.layer_losses += 0 # fix later
                return None
            return hook

        # Register hooks for specific layers
        for name, module in self.roberta.named_modules():
            if isinstance(module, nn.Linear):  # You can modify this condition to target specific layers
                module.register_forward_hook(get_loss_hook(name))

    def _register_eff_rank_hooks(self):
        """Register hooks to calculate and store layer-wise losses"""
        def get_loss_hook(layer_name):
            def hook(module, input, output):
                if isinstance(output, tuple):
                    output = output[0].clone().detach()  # Handle cases where output is a tuple
                self.eff_ranks[layer_name] = (
                    torch.linalg.matrix_norm(output, ord="fro", dim=(-2, -1)) / torch.linalg.matrix_norm(output, ord=2, dim=(-2, -1))
                    ).mean()
                return None
            return hook

        # Register hooks for specific layers
        for name, module in self.roberta.named_modules():
            if name == "embeddings" or re.search("encoder\.layer\.[0-9]+\.output$", name):
                module.register_forward_hook(get_loss_hook(name))
    

    def forward(self, input_ids, attention_mask, labels=None, **batch):
        self.layer_losses = 0  # Clear previous losses
        self.eff_ranks = {}
        roberta_output = self.roberta(input_ids, attention_mask=attention_mask)
        pooler = roberta_output[0][:, 0]
        logits = self.classifier(pooler)
        if labels is not None:
            loss = self.criterion(logits.squeeze(), labels)
            return {"loss": loss, "logits": logits, "layer_losses": self.layer_losses}
        return {"logits": logits, "layer_losses": self.layer_losses}


In [188]:
model = RobertaClassifier(n_classes=2,)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [189]:
model(**ids)

{'logits': tensor([[ 0.1355, -0.0559],
         [ 0.1264, -0.0002]], grad_fn=<AddmmBackward0>),
 'layer_losses': 0}

In [190]:
model.eff_ranks

{'embeddings': tensor(1.8343, grad_fn=<MeanBackward0>),
 'encoder.layer.0.output': tensor(1.3521, grad_fn=<MeanBackward0>),
 'encoder.layer.1.output': tensor(1.1625, grad_fn=<MeanBackward0>),
 'encoder.layer.2.output': tensor(1.1425, grad_fn=<MeanBackward0>),
 'encoder.layer.3.output': tensor(1.1349, grad_fn=<MeanBackward0>),
 'encoder.layer.4.output': tensor(1.1327, grad_fn=<MeanBackward0>),
 'encoder.layer.5.output': tensor(1.1328, grad_fn=<MeanBackward0>),
 'encoder.layer.6.output': tensor(1.1404, grad_fn=<MeanBackward0>),
 'encoder.layer.7.output': tensor(1.1391, grad_fn=<MeanBackward0>),
 'encoder.layer.8.output': tensor(1.1367, grad_fn=<MeanBackward0>),
 'encoder.layer.9.output': tensor(1.1119, grad_fn=<MeanBackward0>),
 'encoder.layer.10.output': tensor(1.1065, grad_fn=<MeanBackward0>),
 'encoder.layer.11.output': tensor(1.0817, grad_fn=<MeanBackward0>)}