Commit: reformat
BenediktAlkin committed Dec 30, 2023
1 parent be29750 commit 86b39b7
Showing 21 changed files with 41 additions and 25 deletions.
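Every hunk below is a pure style fix. As an illustrative sketch (not code from this repository) of the three conventions the commit applies (alphabetically sorted imports with stdlib, third-party, and local groups separated by blank lines; two blank lines before top-level definitions per PEP 8; a single newline at end of file):

import unittest  # stdlib imports first

import torch  # third-party imports second

from kappamodules.layers import Identity  # local package imports last


class Example(unittest.TestCase):  # two blank lines precede top-level definitions
    def test_identity(self):
        # assumes Identity is a pass-through module, as its name suggests
        x = torch.randn(2, 4)
        self.assertTrue(torch.equal(Identity()(x), x))
# the file ends with exactly one trailing newline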
2 changes: 1 addition & 1 deletion kappamodules/attention/linear_attention.py
@@ -93,4 +93,4 @@ def to_channel_last(self, x):
 
     def to_channel_first(self, x, og_shape):
         _, _, h, w, d = og_shape
-        return einops.rearrange(x, "bs (h w d) dim -> bs dim h w d", h=h, w=w, d=d)
\ No newline at end of file
+        return einops.rearrange(x, "bs (h w d) dim -> bs dim h w d", h=h, w=w, d=d)
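For context, to_channel_first undoes the token flattening performed by to_channel_last. A minimal standalone sketch of that round trip, with shapes chosen purely for illustration:

import einops
import torch

x = torch.randn(2, 16, 4, 4, 4)  # bs dim h w d
tokens = einops.rearrange(x, "bs dim h w d -> bs (h w d) dim")  # channel-last token sequence
restored = einops.rearrange(tokens, "bs (h w d) dim -> bs dim h w d", h=4, w=4, d=4)
assert torch.equal(x, restored)  # rearrange only permutes elements, so the round trip is lossless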
3 changes: 2 additions & 1 deletion kappamodules/attention/perceiver_attention.py
@@ -9,6 +9,7 @@
     init_truncnormal_zero_bias,
 )
 
+
 class PerceiverAttention(nn.Module):
     def __init__(self, dim, num_heads=8, bias=True, concat_query_to_kv=False, init_weights="truncnormal"):
         super().__init__()
@@ -65,4 +66,4 @@ def forward(self, q, kv, attn_mask=None):
 
 
 class PerceiverAttention1d(PerceiverAttention):
-    pass
\ No newline at end of file
+    pass
2 changes: 2 additions & 0 deletions kappamodules/init/functional.py
@@ -42,6 +42,7 @@ def init_norm_as_noaffine(m):
     if m.weight is not None:
         nn.init.constant_(m.weight, 1.)
 
+
 # LEGACY remove
 def init_norms_as_noaffine(m):
     if isinstance(m, ALL_NORMS):
@@ -74,6 +75,7 @@ def init_norm_as_identity(m):
     if m.weight is not None:
         nn.init.constant_(m.weight, 0.)
 
+
 # LEGACY remove
 def init_norms_as_identity(m):
     if isinstance(m, ALL_NORMS):
4 changes: 2 additions & 2 deletions kappamodules/layers/__init__.py
@@ -3,7 +3,7 @@
 from .identity import Identity
 from .normalize import Normalize
 from .paramless_batchnorm import ParamlessBatchNorm1d
-from .residual import Residual
-from .weight_norm_linear import WeightNormLinear
 from .regular_grid_sincos_embed import RegularGridSincosEmbed
+from .residual import Residual
 from .rms_norm import RMSNorm
+from .weight_norm_linear import WeightNormLinear
2 changes: 1 addition & 1 deletion kappamodules/modulation/__init__.py
@@ -1,3 +1,3 @@
 from .dit import Dit
 from .film import Film
-from .timestep_embed import TimestepEmbed
\ No newline at end of file
+from .timestep_embed import TimestepEmbed
4 changes: 1 addition & 3 deletions kappamodules/modulation/film.py
@@ -1,10 +1,8 @@
 import math
 
 from torch import nn
 
+from kappamodules.init import init_xavier_uniform_merged_linear
 from kappamodules.utils.shapes import to_ndim
-
-from kappamodules.init import init_xavier_uniform_merged_linear
-
 class Film(nn.Module):
     def __init__(self, dim_cond, dim_out, init_weights="xavier_uniform"):
2 changes: 1 addition & 1 deletion kappamodules/transformer/__init__.py
@@ -1,2 +1,2 @@
 from .perceiver_block import PerceiverBlock
-from .postnorm_block import PostnormBlock
\ No newline at end of file
+from .postnorm_block import PostnormBlock
2 changes: 1 addition & 1 deletion kappamodules/unet/__init__.py
@@ -1 +1 @@
-from .unet_denoising_diffusion import UnetDenoisingDiffusion
\ No newline at end of file
+from .unet_denoising_diffusion import UnetDenoisingDiffusion
1 change: 0 additions & 1 deletion kappamodules/vit/vit_batchnorm.py
@@ -1,4 +1,3 @@
-import einops
 from torch import nn
 
 
1 change: 1 addition & 0 deletions kappamodules/vit/vit_pos_embed.py
@@ -34,6 +34,7 @@ def forward(self, x):
         embed = self.embed
         return x + embed
 
+
 # LEGACY remove
 class VitPosEmbedNd(VitPosEmbed):
     pass
1 change: 1 addition & 0 deletions original_modules/mae_pos_embed.py
@@ -1,5 +1,6 @@
 import numpy as np
 
+
 def get_2d_sincos_pos_embed(embed_dim, grid_size):
     if isinstance(grid_size, int):
         grid_size = (grid_size, grid_size)
7 changes: 4 additions & 3 deletions original_modules/original_perceiver_attention.py
@@ -1,6 +1,7 @@
+import einops
 import torch
 from torch import nn
-import einops
 
+
 class OriginalPerceiverAttention(nn.Module):
     """
@@ -13,7 +14,7 @@ class OriginalPerceiverAttention(nn.Module):
 
     def __init__(self, dim, dim_head=64, heads=8):
         super().__init__()
-        self.scale = dim_head**-0.5
+        self.scale = dim_head ** -0.5
         self.heads = heads
         inner_dim = dim_head * heads
 
@@ -45,4 +46,4 @@ def forward(self, x, latents):
 
         out = torch.einsum("... i j, ... j d -> ... i d", attn, v)
         out = einops.rearrange(out, "b h tn d -> b tn (h d)", h=h)
-        return self.to_out(out)
\ No newline at end of file
+        return self.to_out(out)
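The dim_head ** -0.5 factor touched above is the standard dot-product attention scaling. A standalone sketch of the step it feeds into; names and shapes are illustrative, not the module's exact internals:

import torch

dim_head = 64
scale = dim_head ** -0.5
q = torch.randn(2, 8, 10, dim_head)  # batch, heads, query tokens, head dim
k = torch.randn(2, 8, 20, dim_head)  # batch, heads, key tokens, head dim
sim = torch.einsum("b h i d, b h j d -> b h i j", q, k) * scale  # scaled similarities
attn = sim.softmax(dim=-1)  # each query's weights sum to 1 over the keys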
5 changes: 3 additions & 2 deletions tests_unit/attention/test_perceiver_attention.py
@@ -1,8 +1,10 @@
 import unittest
+
+import torch
+
 from kappamodules.attention import PerceiverAttention1d
 from original_modules.original_perceiver_attention import OriginalPerceiverAttention
-import torch
 
 
 class TestPerceiverAttention(unittest.TestCase):
     def test_shape(self):
@@ -51,4 +53,3 @@ def test_equal_to_original(self):
         self.assertEqual(q.shape, y_og.shape)
         self.assertEqual(y_kc.shape, y_og.shape)
         self.assertTrue(torch.allclose(y_kc, y_og))
-
5 changes: 4 additions & 1 deletion tests_unit/functional/test_pos_embed.py
@@ -1,9 +1,12 @@
+import unittest
+
 import einops
 import torch
-import unittest
 
 from kappamodules.functional.pos_embed import get_sincos_pos_embed_from_seqlens
 from original_modules.mae_pos_embed import get_2d_sincos_pos_embed
+
+
 class TestPosEmbed(unittest.TestCase):
     def test_shapes(self):
         self.assertEqual((10, 5), get_sincos_pos_embed_from_seqlens(seqlens=(10,), dim=5).shape)
4 changes: 3 additions & 1 deletion tests_unit/layers/test_continuous_sincos_embed.py
@@ -2,8 +2,10 @@
 
 import einops
 import torch
-from kappamodules.layers import ContinuousSincosEmbed
 
 from kappamodules.functional.pos_embed import get_sincos_pos_embed_from_seqlens
+from kappamodules.layers import ContinuousSincosEmbed
+
+
 class TestContinuousSincosEmbed(unittest.TestCase):
     def test_shape(self):
3 changes: 3 additions & 0 deletions tests_unit/transformer/test_perceiver_block.py
@@ -1,7 +1,10 @@
 import unittest
+
 import torch
+
 from kappamodules.transformer import PerceiverBlock
 
+
 class TestPerceiverBlock(unittest.TestCase):
     def test_shape(self):
         dim = 8
6 changes: 4 additions & 2 deletions tests_unit/unet/test_unet_denoising_diffusion.py
@@ -1,8 +1,10 @@
-import torch
 import unittest
+
+import torch
 
 from kappamodules.unet import UnetDenoisingDiffusion
 
+
 class TestUnetDenoisingDiffusion(unittest.TestCase):
     def test_1d_uncond(self):
         torch.manual_seed(9823)
@@ -32,4 +34,4 @@ def test_1d_cond(self):
         cond = torch.randn(2, 4)
         y = model(x, cond=cond)
         self.assertEqual(x.shape, y.shape)
-        self.assertTrue(torch.isclose(y.mean(), torch.tensor(0.54554283618927)))
\ No newline at end of file
+        self.assertTrue(torch.isclose(y.mean(), torch.tensor(0.54554283618927)))
1 change: 1 addition & 0 deletions tests_unit/vit/test_dit_block.py
@@ -5,6 +5,7 @@
 from kappamodules.vit import DitBlock
 from original_modules.original_dit_block import OriginalDitBlock
 
+
 class TestDitBlock(unittest.TestCase):
     def test_equal_to_original(self):
         dim = 12
6 changes: 4 additions & 2 deletions tests_unit/vit/test_vit_batchnorm.py
@@ -1,8 +1,10 @@
 import unittest
-from kappamodules.vit import VitBatchNorm
+
 import torch
 
+from kappamodules.vit import VitBatchNorm
+
 
 class TestVitBatchNorm(unittest.TestCase):
     def test_0d(self):
         dim = 4
@@ -38,4 +40,4 @@ def test_3d(self):
         y = bn(x)
         self.assertEqual(x.shape, y.shape)
         self.assertTrue(torch.allclose(y.mean(dim=[0, 1, 2, 3]), torch.zeros(size=(dim,)), atol=1e-6))
-        self.assertTrue(torch.allclose(y.std(dim=[0, 1, 2, 3]), torch.ones(size=(dim,)), atol=1e-1))
\ No newline at end of file
+        self.assertTrue(torch.allclose(y.std(dim=[0, 1, 2, 3]), torch.ones(size=(dim,)), atol=1e-1))
3 changes: 2 additions & 1 deletion tests_unit/vit/test_vit_block.py
@@ -4,10 +4,11 @@
 
 from kappamodules.vit import VitBlock
 
+
 class TestVitBlock(unittest.TestCase):
     def test(self):
         dim = 4
         block = VitBlock(dim=dim, num_heads=2)
         x = torch.randn(2, 6, dim, generator=torch.Generator().manual_seed(9834))
         y = block(x)
-        self.assertEqual(x.shape, y.shape)
\ No newline at end of file
+        self.assertEqual(x.shape, y.shape)
2 changes: 0 additions & 2 deletions tests_unit/vit/test_vit_pos_embed.py
@@ -51,7 +51,6 @@ def test_3d_fixed(self):
     def test_3d_learnable(self):
         self._test3d(is_learnable=True)
 
-
     def test_interpolate_2d(self):
         seqlens = (8, 12)
         dim = 64
@@ -63,4 +62,3 @@ def test_interpolate_2d(self):
         x_half = torch.zeros(2, *[seqlen // 2 for seqlen in seqlens], dim)
         y_half = pos_embed(x_half)
         self.assertEqual(x_half.shape, y_half.shape)
-