In [1]:
from google.colab import drive

drive.mount('/content/drive/')

Mounted at /content/drive/


In [2]:
import os
os.chdir("drive/")
os.chdir('My Drive')
os.chdir('Experiment')

In [3]:
!pip install pytorch_lightning

Collecting pytorch_lightning
[?25l  Downloading https://files.pythonhosted.org/packages/07/0c/e2d52147ac12a77ee4e7fd7deb4b5f334cfb335af9133a0f2780c8bb9a2c/pytorch_lightning-1.2.10-py3-none-any.whl (841kB)
[K     |████████████████████████████████| 849kB 7.4MB/s 
[?25hCollecting future>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
[K     |████████████████████████████████| 829kB 15.7MB/s 
[?25hCollecting PyYAML!=5.4.*,>=5.1
[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)
[K     |████████████████████████████████| 276kB 32.7MB/s 
Collecting fsspec[http]>=0.8.1
[?25l  Downloading https://files.pythonhosted.org/packages/e9/91/2ef649137816850fa4f4c97c6f2eabb1a79bf0aa2c8ed198e387e373455e/fsspec-2021.4.0-py3-none-any.whl (108kB)
[K     |██████████████████████████████

In [4]:
!pip install einops

Collecting einops
  Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl
Installing collected packages: einops
Successfully installed einops-0.3.0


In [5]:
from math import sqrt
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange



#------------------------------
# helpers
#------------------------------

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d


#------------------------------
# class
#------------------------------

class GroupedFeedForward(nn.Module):
    def __init__(self, *, dim, groups, mult = 4):
        super().__init__()

        # levels * dim
        total_dim = dim * groups

        # a grouped feed forward network
        self.net = nn.Sequential(
            Rearrange('b n l d -> b (l d) n'),
            nn.Conv1d(total_dim, total_dim * mult, 1, groups = groups),
            nn.GELU(),
            nn.Conv1d(total_dim * mult, total_dim, 1, groups = groups),
            Rearrange('b (l d) n -> b n l d', l = groups)
        )

    def forward(self, levels):
        return self.net(levels)

class ConsensusAttention(nn.Module):
    def __init__(self, num_patches_side, attend_self = True, local_consensus_radius = 0):
        super().__init__()
        self.attend_self = attend_self
        self.local_consensus_radius = local_consensus_radius

        # constant
        self.TOKEN_ATTEND_SELF_VALUE = -5e-4

        if self.local_consensus_radius > 0:
            coors = torch.stack(torch.meshgrid(
                torch.arange(num_patches_side),
                torch.arange(num_patches_side)
            )).float()

            coors = rearrange(coors, 'c h w -> (h w) c')
            dist = torch.cdist(coors, coors)

            mask_non_local = dist > self.local_consensus_radius
            mask_non_local = rearrange(mask_non_local, 'i j -> () i j')
            
            self.register_buffer('non_local_mask', mask_non_local)


    def forward(self, levels):
        _, n, _, d, device = *levels.shape, levels.device
        q, k, v = levels, F.normalize(levels, dim = -1), levels

        sim = einsum('b i l d, b j l d -> b l i j', q, k) * (d ** -0.5)

        if not self.attend_self:
            self_mask = torch.eye(n, device = device, dtype = torch.bool)
            self_mask = rearrange(self_mask, 'i j -> () () i j')
            sim.masked_fill_(self_mask, self.TOKEN_ATTEND_SELF_VALUE)

        if self.local_consensus_radius > 0:
            max_neg_value = -torch.finfo(sim.dtype).max
            sim.masked_fill_(self.non_local_mask, max_neg_value)

        attn = sim.softmax(dim = -1)
        out = einsum('b l i j, b j l d -> b i l d', attn, levels)
        return out


#------------------------------
# main class
#------------------------------

class GLOM(nn.Module):
    def __init__(
        self,
        *,
        dim = 512,
        levels = 6,
        image_size = 224,
        patch_size = 14,
        consensus_self = False,
        local_consensus_radius = 0,
        img_channels=3
    ):
        super().__init__()
        # bottom level - incoming image, tokenize and add position
        num_patches_side = (image_size // patch_size)
        num_patches =  num_patches_side ** 2
        self.levels = levels

        self.image_to_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_size ** 2 * img_channels, dim)
        )
        self.pos_emb = nn.Embedding(num_patches, dim)

        # initial embeddings for all levels of a column
        self.init_levels = nn.Parameter(torch.randn(levels, dim))

        # bottom-up and top-down
        self.bottom_up = GroupedFeedForward(dim = dim, groups = levels)
        self.top_down = GroupedFeedForward(dim = dim, groups = levels - 1)

        # consensus attention
        self.attention = ConsensusAttention(num_patches_side, attend_self = consensus_self, local_consensus_radius = local_consensus_radius)


    def forward(self, img, iters=None, levels=None, return_all=False):
        b, device = img.shape[0], img.device

        # need to have twice the number of levels of iterations in order for information to propagate up and back down. can be overridden
        iters = default(iters, self.levels * 2)

        tokens = self.image_to_tokens(img)
        n = tokens.shape[1]

        pos_embs = self.pos_emb(torch.arange(n, device = device))
        pos_embs = rearrange(pos_embs, 'n d -> () n () d')

        bottom_level = tokens
        bottom_level = rearrange(bottom_level, 'b n d -> b n () d')

        if not exists(levels):
            levels = repeat(self.init_levels, 'l d -> b n l d', b = b, n = n)

        hiddens = [levels]

        num_contributions = torch.empty(self.levels, device = device).fill_(4)
        num_contributions[-1] = 3  # top level does not get a top-down contribution, so have to account for this when doing the weighted mean

        for _ in range(iters):
            levels_with_input = torch.cat((bottom_level, levels), dim = -2)  # each iteration, attach original input at the most bottom level, to be bottomed-up

            bottom_up_out = self.bottom_up(levels_with_input[..., :-1, :])

            top_down_out = self.top_down(levels_with_input[..., 2:, :] + pos_embs) # positional embeddings given to top-down networks
            top_down_out = F.pad(top_down_out, (0, 0, 0, 1), value = 0.)

            consensus = self.attention(levels)

            levels_sum = torch.stack((levels, bottom_up_out, top_down_out, consensus)).sum(dim = 0) # hinton said to use the weighted mean of (1) bottom up (2) top down (3) previous level value {t - 1} (4) consensus value
            levels_mean = levels_sum / rearrange(num_contributions, 'l -> () () l ()')

            levels = levels_mean  # set for next iteration
            hiddens.append(levels)

        if return_all:
            return torch.stack(hiddens)  # return (time step, batch, num columns, levels, dimension)

        return levels


In [12]:
from math import sqrt
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.autograd import Variable
import pytorch_lightning as pl

from einops import rearrange, repeat
from einops.layers.torch import Rearrange


#------------------------------
# main class
#------------------------------

class LightningGLOM(pl.LightningModule):
    def __init__(
        self,
        *,
        dim=512,
        levels=6,
        image_size=224,
        patch_size=14,
        consensus_self=False,
        local_consensus_radius=0,
        lr=1e-3,
        img_channels=3
    ):
        super().__init__()        
        self.lr, self.levels = lr, levels

        # a GLOM model
        self.glom = GLOM(
            dim=dim, 
            levels=self.levels,
            image_size=image_size, 
            patch_size=patch_size, 
            consensus_self=consensus_self, 
            local_consensus_radius=local_consensus_radius,
            img_channels=img_channels
        )

        # use MSE loss as a loss function
        self.loss_func = F.mse_loss

        # a network that converts the generated patches to images
        self.patches_to_images = nn.Sequential(
            nn.Linear(dim, patch_size ** 2 * img_channels),
            Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1 = 14, p2 = 14, h = (224 // 14))
        )

    
    def forward(self, img, iters=None, levels=None, return_all=False):
        all_levels = self.glom(img, iters=iters, levels=levels, return_all=return_all)
        return all_levels


    def calculate_loss(self, img):
        # change the shape of the tensor from 3D to 4D
        img = img.unsqueeze(0)
        # add random noise to images
        noised_img = img + torch.randn_like(img)

        # forward propagation
        all_levels = self(noised_img, return_all=True)


        # Reconstruct images from patches.

        # Get the top level embeddings after iteration n, where n is (# of levels + 1)
        # This is because the GLOM model needs to have twice the number of levels of iterations
        # in order for information to propagate up and back down.
        recon_img = all_levels[self.levels + 1, :, :, -1]

        # calculate loss
        loss = self.loss_func(img, recon_img.unsqueeze(0))

        return loss


    def training_step(self, batch, batch_idx):
        imgs, y = batch

        # if torch.cuda.is_available():
        #     loss = Variable(torch.zeros(1).cuda(), requires_grad=True)
        # else:
        #     loss = Variable(torch.zeros(1), requires_grad=True)\
        loss = None

        # iterate all images in the mini batch
        for img in imgs:
            if loss:
                loss += self.calculate_loss(img)
            else:
                loss = self.calculate_loss(img)
        tensorboard_logs = {'train_loss': loss}

        return {'loss': loss, 'log': tensorboard_logs}


    def validation_step(self, batch, batch_nb):
        imgs, y = batch

        # if torch.cuda.is_available():
        #     loss = Variable(torch.zeros(1).cuda(), requires_grad=True)
        # else:
        #     loss = Variable(torch.zeros(1), requires_grad=True)
        loss = None

        # iterate all images in the mini batch
        for img in imgs:
            if loss:
                loss += self.calculate_loss(img)
            else:
                loss = self.calculate_loss(img)
        tensorboard_logs = {'val_loss': loss}
        return {'loss': loss, 'val_loss': loss ,'log': tensorboard_logs}


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [13]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import os
from pytorch_lightning.callbacks import ModelCheckpoint

In [None]:
print('load dataset')

dataset = MNIST(os.getcwd(), download=True, transform=transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
]))

print('split dataset')
train, val = random_split(dataset, [55000, 5000])

print('load GLOM model')
glom = LightningGLOM(
    dim=256,         # dimension
    levels=6,        # number of levels
    image_size=256,  # image size
    patch_size=16,   # patch size
    img_channels=1
)
print('model loaded!')

# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='.',
    filename='mnist-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)

gpus = torch.cuda.device_count()
trainer = pl.Trainer(auto_scale_batch_size='power', gpus=gpus, deterministic=True, max_epochs=5)
print('start training')
trainer.fit(glom, DataLoader(train, batch_size=8, num_workers=2), DataLoader(val, batch_size=8, num_workers=2))

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type       | Params
-------------------------------------------------
0 | glom              | GLOM       | 5.9 M 
1 | patches_to_images | Sequential | 65.8 K
-------------------------------------------------
6.0 M     Trainable params
0         Non-trainable params
6.0 M     Total params
23.920    Total estimated model params size (MB)


load dataset
split dataset
load GLOM model
model loaded!
start training


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Please use self.log(...) inside the lightningModule instead.
# log on a step or aggregate epoch metric to the logger and/or progress bar (inside LightningModule)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…