In [2]:
import torch
import torch.nn as nn

class FinalLayer (nn.Module):
    """
    Final layer of the backbone
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
    
    def forward(self, x, y):
        scale, shift = self.adaLN_modulation(y).chunk(2, dim=-1) # 2x (B, C)
        x = modulate(self.norm_final(x), shift, scale) # x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) -> (B, T, C)
        x = self.linear(x) # (B, T, C) - > (B, T, patch_size * patch_size * out_channels)
        return x

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

In [3]:
layer = FinalLayer(7, 2, 4)

In [4]:
layer(torch.randn(1, 3, 7), torch.randn(1, 7)).shape

torch.Size([1, 3, 16])

In [5]:
import torch
import torch.nn as nn

class activation(nn.Module):
    def __init__(self, activation_layer= nn.GELU):
        super().__init__()
        self.act = activation_layer
    
    def forward (self, x):
        return self.act(x)

x = torch.randn(4)
a = activation(activation_layer=nn.GELU(approximate="tanh"))
a(x)


tensor([-0.0447, -0.0520, -0.0628,  0.8552])

In [7]:
from torch import Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F
class GatedMLP(nn.Module):
    def __init__(
            self,
            fan_in: int,
            fan_h: int = None,
            fan_out: int = None,
            act_layer = lambda:nn.GELU(approximate="tanh"),
            drop: float = 0.0,
            bias: bool = True,
    )-> None:
        super().__init__()
        fan_out = fan_out or fan_in # stores first truth value
        fan_h = fan_h or fan_in
        self.fc1 = nn.Linear(fan_in, 2*fan_h, bias=bias)
        self.fc2 = nn.Linear(fan_h, fan_out, bias=bias)
        self.act_layer = act_layer()
    
    def forward(self, x:Tensor)-> Tensor:
        x = self.fc1(x)
        x, scale = x.chunk(2, dim=-1)
        x = self.act_layer(x) * scale
        x = self.fc2(x)
        return x

# works! : when init calls approx_gelu(), it instantiates the GELU object
approx_gelu = lambda: nn.GELU(approximate="tanh")
m = GatedMLP(1, 2, act_layer=approx_gelu, drop=0, bias=False)
x = torch.randn(2, 1)
m(x)

tensor([[-4.7365e-05],
        [-1.8660e-03]], grad_fn=<MmBackward0>)

In [6]:
approx_gelu, nn.GELU

(<function __main__.<lambda>()>, torch.nn.modules.activation.GELU)

In [4]:
import torch
import torch.nn as nn

a = nn.Conv2d(3, 3, 4, 1, 0)

In [5]:
x = a.weight.data

In [6]:
x.shape

torch.Size([3, 3, 4, 4])

In [9]:
k = x.view (x.shape[0], -1)
k.shape, x.flatten(1, -1).shape

x.view(x.shape[0], -1).shape, x.flatten(1, -1).shape

(torch.Size([3, 48]), torch.Size([3, 48]))

In [20]:
from timm.models.vision_transformer import PatchEmbed
a = PatchEmbed(32, 2, 4, 64)

In [21]:
a.num_patches**0.5

16.0

In [22]:
a = nn.Linear(3,3)

In [23]:
a.bias._noinit = 1

In [24]:
not hasattr(a.bias, "_noinit")

False

In [25]:
for k,p in a.named_parameters():
    print (a)



Linear(in_features=3, out_features=3, bias=True)
Linear(in_features=3, out_features=3, bias=True)


In [30]:
if isinstance(a, nn.Linear):
    nn.init.constant_(a.weight, 16)

In [31]:
a.weight

Parameter containing:
tensor([[16., 16., 16.],
        [16., 16., 16.],
        [16., 16., 16.]], requires_grad=True)

In [41]:
import math
nn.init.kaiming_uniform_(a.weight, a=math.sqrt(5))

Parameter containing:
tensor([[-0.1373,  0.2918, -0.5081],
        [ 0.3046,  0.5191,  0.0401],
        [-0.1881, -0.4669, -0.4744]], requires_grad=True)

In [42]:
a.weight

Parameter containing:
tensor([[-0.1373,  0.2918, -0.5081],
        [ 0.3046,  0.5191,  0.0401],
        [-0.1881, -0.4669, -0.4744]], requires_grad=True)

In [43]:
with torch.no_grad():
    a.weight /= 2

In [44]:
a.weight

Parameter containing:
tensor([[-0.0687,  0.1459, -0.2540],
        [ 0.1523,  0.2596,  0.0200],
        [-0.0941, -0.2334, -0.2372]], requires_grad=True)

In [11]:
from CrossFusionMamba import CrossFusionMamba
from CrossAttentionFusion import CrossAttentionFusion
import torch
fuser_mamba = CrossFusionMamba(768, 20, 32)
fuser_att = CrossAttentionFusion(768)
a = torch.randn(1, 3, 384)
b = torch.randn(1, 3, 384)
fuser_att.to('cuda')
fuser_mamba.to('cuda')

b=b.to('cuda')
a=a.to('cuda')

att = fuser_att(a,b)
mamba = fuser_mamba(a,b)

print(f"Attention var:{att.var()}, Mamba var: {mamba.var()}")

Attention var:0.034645210951566696, Mamba var: 0.00042521519935689867


In [26]:
from mamba_ssm.modules.mamba_simple import CondMamba
x = CondMamba(d_model=100, d_cond=100).to('cuda')
a =torch.randn(2,10,100).to('cuda')

In [27]:
out = x(a,a[:,-1,:])

In [None]:
out.var()

tensor(0.0015, device='cuda:0', grad_fn=<VarBackward0>)

: 

In [2]:
import torch
from CrossAttentionFusion import CrossAttentionFusion
from mamba_ssm.modules.mamba_simple import CondMamba
x = CondMamba(d_model=64, d_cond=100).to('cuda')
c = CrossAttentionFusion(64)
a =torch.randn(2,10,100).to('cuda')

In [3]:
a = sum(p.numel()for p in x.parameters())
b = sum(p.numel()for p in c.parameters())
a,b

(45568, 10304)

In [2]:
import torch
import torch.nn as nn
from mamba_ssm.modules.mamba_simple import CondMamba, ArceeCondMamba

data = torch.randn(2, 10, 4).to('cuda')
x = ArceeCondMamba(d_model=4, d_cond=4).to('cuda')
y = ArceeCondMamba(d_model=4, d_cond=4).to('cuda')

x1 = CondMamba(d_model=4, d_cond=4).to('cuda')


activation = x(data, cond_emb=data[:,-1,:])
activation = y(activation, cond_emb=activation[:,-1,:])
loss = activation
loss.mean().backward()

for name, param in x.named_parameters():
    if param.grad is None:
        print(f"[NO GRAD] {name}")

print (f"ArceeCond parameters are :{sum(p.numel() for p in x1.parameters())} and base cond parameters{sum(p.numel() for p in x.parameters())}")


YAY
BEFORENEW PASS

YAY
AFTER NEW PASS

YAY
BEFORENEW PASS

YAY
AFTER NEW PASS
ArceeCond parameters are :592 and base cond parameters656


In [7]:
import torch

x = torch.randn(2, 4, requires_grad=True)
y = x[..., None].expand(-1, -1, 10)
z = y.sum()

# Reset gradients properly
if x.grad is not None:
    x.grad.zero_()

z.backward()
print(x.grad)

tensor([[10., 10., 10., 10.],
        [10., 10., 10., 10.]])


In [None]:
import torch
from torch import nn

class Test(nn.Module):
    def __init__(self):
        super().__init__()
        self.cond_proj = nn.Linear(4, 8)

    def forward(self, cond_emb):
        #cond_emb = self.cond_proj(cond_emb)[..., None].expand(-1, -1, 10)
        cond_emb = self.cond_proj(cond_emb)[..., None].repeat(1, 1, 10)
        return cond_emb.sum()

model = Test().cuda()
x = torch.randn(2, 4, device="cuda", requires_grad=True)
loss = model(x)
loss.backward()

print(model.cond_proj.weight.grad is not None)

True


In [None]:
import torch
from torch.autograd import gradcheck
import torch.nn.functional as F
B, D_cond, D_inner, L = 2, 3, 4, 5
cond_emb = torch.randn(B, D_cond, dtype=torch.double, requires_grad=True)
cond_proj_weight = torch.randn(D_inner, D_cond, dtype=torch.double, requires_grad=True)
cond_proj_bias = torch.randn(D_inner, dtype=torch.double, requires_grad=True)

def f(cond_emb):
    init_states = F.linear(cond_emb, cond_proj_weight, cond_proj_bias)
    return init_states.unsqueeze(-1).expand(-1, -1, L).sum()  # simple loss

gradcheck(f, (cond_emb,))

True