In [1]:
import os
import sys
import math
import fire
import json
from tqdm import tqdm
from math import floor, log2
from random import random
from shutil import rmtree
from functools import partial
import multiprocessing

import numpy as np
import torch
from torch import nn
from torch.utils import data
import torch.nn.functional as F

from torch.optim import Adam
from torch.autograd import grad as torch_grad

import torchvision
from torchvision import transforms

from linear_attention_transformer import ImageLinearAttention

from PIL import Image
from pathlib import Path

In [2]:
# NOTE(review): presumably the channel count of the input images (3 = RGB); this
# name is not referenced anywhere else in the visible notebook — TODO confirm.
num_init_filters = 3

In [3]:
# Base multiplier for the per-layer filter counts: the `filters` cell below
# computes network_capacity * 2 ** (i + 1) for each layer index i.
network_capacity = 16

In [5]:
# Image side length used to size the network; previously an inline magic number.
image_size = 128
# Number of layers derived from the image size: log2(image_size) - 1.
num_layers = int(log2(image_size) - 1)

In [6]:
num_layers

6

In [12]:
# Filter count per layer, largest first: network_capacity * 2 ** k for
# k = num_layers, num_layers - 1, ..., 1.
filters = [network_capacity * 2 ** exponent for exponent in range(num_layers, 0, -1)]

In [13]:
filters

[1024, 512, 256, 128, 64, 32]

In [15]:
# Upper bound applied to every entry of `filters` in the next cell.
fmap_max = 512
# partial(min, fmap_max)(n) evaluates min(fmap_max, n), i.e. clamps n to at most fmap_max.
set_fmap_max = partial(min, fmap_max)

In [18]:
# Clamp each per-layer filter count to at most fmap_max.
filters = [set_fmap_max(width) for width in filters]

In [22]:
# First entry of the clamped filter list; a later cell prepends it to `filters`.
init_channels = filters[0]

In [20]:
# NOTE: a starred expression cannot stand alone as a statement, so this cell raises
# the SyntaxError recorded below. Unpacking is only valid inside a list/tuple/set
# display or a call, e.g. [init_channels, *filters].
*filters

SyntaxError: can't use starred expression here (<ipython-input-20-1c98ba60d54f>, line 4)

In [23]:
# Prepend init_channels so that consecutive entries of `filters` can be paired up.
# Re-running this cell prepends another copy each time.
filters = [init_channels] + list(filters)

In [24]:
filters

[512, 512, 512, 256, 128, 64, 32]

In [25]:
filters[:-1]

[512, 512, 512, 256, 128, 64]

In [26]:
filters[1:]

[512, 512, 256, 128, 64, 32]

In [28]:
# Pair every entry of `filters` with its successor.
# Materialised as a list because a bare zip object is a one-shot iterator: once the
# next cell consumes it with set(), re-running that cell would silently see no pairs.
result = list(zip(filters[:-1], filters[1:]))

In [29]:
# Collect the pairs into a set. Note that this discards their order and collapses
# the repeated (512, 512) pair into a single entry, as the output below shows.
result_set = {*result}

In [30]:
result_set

{(64, 32), (128, 64), (256, 128), (512, 256), (512, 512)}

In [31]:
import linear_attention_transformer

In [32]:
import torch
from torch import nn

class ImageLinearAttention(nn.Module):
    """Multi-head linear attention over 2D feature maps.

    Queries, keys and values are produced by three convolutions of the input.
    Keys are softmax-normalised over positions (and queries, optionally, over
    the per-head feature axis). A per-head (key_dim x value_dim) summary of
    keys and values is formed first and then applied to the queries, and a
    final convolution maps the result to ``chan_out`` channels.

    Args:
        chan: number of input channels.
        chan_out: number of output channels; defaults to ``chan``.
        kernel_size: kernel size of all four convolutions.
        padding: padding of all four convolutions.
        stride: stride of the query/key/value convolutions (the output
            convolution keeps the default stride).
        key_dim: per-head channel count of queries and keys.
        value_dim: per-head channel count of values.
        heads: number of attention heads.
        norm_queries: if true, apply a softmax over the per-head feature axis
            of the queries.
    """

    def __init__(self, chan, chan_out = None, kernel_size = 1, padding = 0, stride = 1, key_dim = 64, value_dim = 64, heads = 8, norm_queries = True):
        super().__init__()
        self.chan = chan
        chan_out = chan if chan_out is None else chan_out

        self.key_dim = key_dim
        self.value_dim = value_dim
        self.heads = heads

        self.norm_queries = norm_queries

        conv_kwargs = {'padding': padding, 'stride': stride}
        self.to_q = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
        self.to_k = nn.Conv2d(chan, key_dim * heads, kernel_size, **conv_kwargs)
        self.to_v = nn.Conv2d(chan, value_dim * heads, kernel_size, **conv_kwargs)

        # The output projection deliberately omits the stride.
        out_conv_kwargs = {'padding': padding}
        self.to_out = nn.Conv2d(value_dim * heads, chan_out, kernel_size, **out_conv_kwargs)

    def forward(self, x, context = None):
        """Apply linear attention to ``x``.

        Args:
            x: 4-D tensor unpacked as (batch, channels, height, width).
            context: optional tensor that is reshaped to
                (batch, channels, 1, -1); its keys and values are appended to
                those computed from ``x``.

        Returns:
            The output of the final convolution. With the default
            ``kernel_size``/``padding``/``stride`` its shape is
            (batch, chan_out, height, width).
        """
        # Four-way unpack keeps the original requirement that x is 4-D.
        b, c, _, _ = x.shape
        heads, k_dim, v_dim = self.heads, self.key_dim, self.value_dim

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        # Take the spatial size from the conv output rather than from x, so that
        # kernel/padding/stride settings that change the resolution still work.
        h_out, w_out = q.shape[-2:]

        # (batch, heads, per-head features, positions)
        q = q.reshape(b, heads, k_dim, -1)
        k = k.reshape(b, heads, k_dim, -1)
        v = v.reshape(b, heads, v_dim, -1)

        scale = k_dim ** -0.25
        q, k = q * scale, k * scale

        if context is not None:
            context = context.reshape(b, c, 1, -1)
            # Context values must be split by value_dim, not key_dim; using key_dim
            # for both broke concatenation whenever key_dim != value_dim.
            ck = self.to_k(context).reshape(b, heads, k_dim, -1)
            cv = self.to_v(context).reshape(b, heads, v_dim, -1)
            # NOTE(review): unlike k, the context keys are not multiplied by `scale`;
            # kept as in the original — confirm whether this is intended.
            k = torch.cat((k, ck), dim=3)
            v = torch.cat((v, cv), dim=3)

        # Normalise keys across positions (including any context positions).
        k = k.softmax(dim=-1)

        if self.norm_queries:
            q = q.softmax(dim=-2)

        # Per-head key/value summary, then read it out with the queries.
        kv = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhdn,bhde->bhen', q, kv)
        out = out.reshape(b, -1, h_out, w_out)
        return self.to_out(out)

In [33]:
# Random test input: a single 3-channel 32x32 image (batch size 1).
A = torch.randn((1, 3, 32, 32))

In [35]:
# Attention layer for 3-channel inputs; every other setting keeps its default.
module = ImageLinearAttention(chan=3, norm_queries=True)

In [36]:
# Forward pass. With the defaults used above (kernel_size=1, stride=1, chan_out=None)
# the output has the same shape as the input, as the next cell shows.
out = module(A)

In [37]:
out.size()

torch.Size([1, 3, 32, 32])

In [38]:
import retry

In [39]:
# Rebind A to a larger random batch: 32 images of 3 channels, 32x32 each.
A = torch.randn((32, 3, 32, 32))

In [40]:
A

tensor([[[[ 0.3704, -0.6113,  0.7308,  ...,  0.5591,  1.4293, -0.9304],
          [ 1.8500,  1.6959, -0.4120,  ...,  0.3360, -1.7933, -0.7164],
          [ 1.1696,  0.1751,  0.2366,  ...,  1.2030,  0.7360, -0.7904],
          ...,
          [ 0.2714, -0.9956,  0.1947,  ..., -0.2655,  0.4167,  0.2662],
          [ 1.2610,  1.4827,  0.7524,  ...,  0.3342,  1.1273, -0.6111],
          [-0.3727, -0.0563, -0.5229,  ..., -0.9922, -0.3278, -1.2602]],

         [[ 1.1477,  0.4592,  0.6177,  ..., -1.2424,  0.2933, -2.4751],
          [ 1.5440,  1.7549, -1.3749,  ...,  0.2297, -0.0814,  1.1810],
          [-1.1205, -1.9034, -1.6984,  ..., -0.5263,  0.4011,  1.1830],
          ...,
          [-1.7841,  0.2078, -0.7121,  ..., -0.6268,  1.3279,  0.3946],
          [-0.6363, -0.2628,  0.9047,  ..., -0.4399, -1.4968,  1.5452],
          [ 1.3672, -0.6764,  0.8107,  ..., -0.6315,  0.6567, -0.1027]],

         [[ 1.1908,  0.9489, -0.8995,  ..., -1.7106, -0.3303, -0.7173],
          [-0.2954,  0.2926,  

In [41]:
# Expanding A to the shape it already has (32, 3, 32, 32) changes nothing; the
# result is only displayed here, not stored.
A.expand(32, 3, 32, 32)

tensor([[[[ 0.3704, -0.6113,  0.7308,  ...,  0.5591,  1.4293, -0.9304],
          [ 1.8500,  1.6959, -0.4120,  ...,  0.3360, -1.7933, -0.7164],
          [ 1.1696,  0.1751,  0.2366,  ...,  1.2030,  0.7360, -0.7904],
          ...,
          [ 0.2714, -0.9956,  0.1947,  ..., -0.2655,  0.4167,  0.2662],
          [ 1.2610,  1.4827,  0.7524,  ...,  0.3342,  1.1273, -0.6111],
          [-0.3727, -0.0563, -0.5229,  ..., -0.9922, -0.3278, -1.2602]],

         [[ 1.1477,  0.4592,  0.6177,  ..., -1.2424,  0.2933, -2.4751],
          [ 1.5440,  1.7549, -1.3749,  ...,  0.2297, -0.0814,  1.1810],
          [-1.1205, -1.9034, -1.6984,  ..., -0.5263,  0.4011,  1.1830],
          ...,
          [-1.7841,  0.2078, -0.7121,  ..., -0.6268,  1.3279,  0.3946],
          [-0.6363, -0.2628,  0.9047,  ..., -0.4399, -1.4968,  1.5452],
          [ 1.3672, -0.6764,  0.8107,  ..., -0.6315,  0.6567, -0.1027]],

         [[ 1.1908,  0.9489, -0.8995,  ..., -1.7106, -0.3303, -0.7173],
          [-0.2954,  0.2926,  