In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from einops import rearrange

import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

In [None]:
#from timm.models.layers import DropPath

In [2]:
# Root directory handed to torchvision as the dataset download/lookup location.
PATH = "datasets"

In [288]:
# Only request pinned host memory when a CUDA device is actually available;
# a hard-coded True asks for pinning even on CPU-only machines.
cuda = torch.cuda.is_available()
batch_size = 64

# Per-sample preprocessing. The dataset yields PIL images (see the
# torchvision.transforms documentation for the available transforms).
train_transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> torch.Tensor
    transforms.Normalize((0.491, 0.482, 0.447), (0.202, 0.199, 0.201)),
    # transforms.Lambda(imgs_to_SoP)  # not applied here: done inside the network

    # Augmentations that were tried and are currently disabled:
    # transforms.RandomCrop(28),
    # transforms.Pad(2)
    # transforms.RandomHorizontalFlip(p=0.5)
])

train_dataset = datasets.CIFAR10(PATH, train=True, download=True,
                                 transform=train_transform)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=cuda,
    num_workers=2,
)

Files already downloaded and verified


In [34]:
def imgs_to_SoP(imgs):
    """Append normalised pixel coordinates to a batch of images.

    Two extra channels are concatenated after the image channels:
    the first holds the horizontal (width) coordinate of each pixel, the
    second the vertical (height) coordinate, both evenly spaced in
    [-0.5, 0.5]. An axis of length 1 gets the single coordinate -0.5.

    imgs : torch.Tensor([batch, chan, height, width])
    return : torch.Tensor([batch, chan + 2, height, width])
    """
    batch = imgs.shape[0]
    height = imgs.shape[-2]
    width = imgs.shape[-1]

    # max(..., 1) keeps a 1-pixel axis from evaluating 0/0 (NaN).
    # Coordinates are created on the input's device so the concatenation
    # below also works for CUDA batches.
    ind_w = torch.arange(width, device=imgs.device) / max(width - 1, 1) - 0.5
    ind_h = torch.arange(height, device=imgs.device) / max(height - 1, 1) - 0.5

    # Build the planes directly in (height, width) layout. Stacking a
    # (width, height, 2) grid and reshaping it to (2, height, width) would
    # only reinterpret memory and scramble the coordinates.
    grid_h, grid_w = torch.meshgrid(ind_h, ind_w, indexing='ij')
    ind = torch.stack((grid_w, grid_h), dim=0)  # (2, height, width)

    # expand() is a broadcast view; torch.cat materialises the result anyway.
    batch_ind = ind.unsqueeze(0).expand(batch, -1, -1, -1)

    SoP = torch.cat((imgs, batch_ind), dim=1)

    return SoP

In [32]:
# Scratch check of the coordinate grid on a non-square size.
height = 64
width = 32
ind_w = torch.arange(width)/(width-1)-0.5
ind_h = torch.arange(height)/(height-1)-0.5
# Build the planes directly in (height, width) layout; reshaping a
# (width, height, 2) stack to (2, height, width) would scramble the values.
grid_h, grid_w = torch.meshgrid(ind_h, ind_w, indexing = 'ij')
ind = torch.stack((grid_w, grid_h), dim = 0)  # channel 0: width coord, channel 1: height coord
ind.repeat(20,1,1,1).shape


torch.Size([20, 2, 64, 32])

In [36]:
# Shape check on a random, non-square batch. To try a real batch instead:
# X,y = next(iter(train_loader))
X = torch.randn(32, 3, 224, 122)
# Print the shape before and after the coordinate channels are appended.
for shape in (X.shape, imgs_to_SoP(X).shape):
    print(shape)

torch.Size([32, 3, 224, 122])
torch.Size([32, 5, 224, 122])


In [39]:
class PointReducer(nn.Module):
    """Reduce the number of points with a single strided 2-D convolution.

    in_chan / out_chan : channel counts of the convolution.
    kernel_size, stride : forwarded to nn.Conv2d (both default to 2).
    """

    def __init__(self, in_chan, out_chan, kernel_size=2, stride=2):
        super().__init__()
        self.conv2d = nn.Conv2d(
            in_channels=in_chan,
            out_channels=out_chan,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, input):
        """Apply the convolution to `input` and return the result."""
        reduced = self.conv2d(input)
        return reduced

In [89]:
def pairwise_cosine_sim(x1, x2):
    """Cosine similarity between every vector of `x1` and every vector of `x2`.

    Both inputs are normalised along their last dimension, then multiplied
    as `x1 @ x2^T` over the last two dimensions, so leading (batch)
    dimensions are handled pair by pair.
    """
    unit_1, unit_2 = (F.normalize(t, dim=-1) for t in (x1, x2))
    return unit_1 @ unit_2.transpose(-2, -1)

In [214]:
class Cluster(nn.Module):
    """Multi-head clustering layer over a grid of points.

    Points are projected to a similarity space (`fc1`) and a value space
    (`fcv`). Cluster centers are proposed by adaptive average pooling, each
    point is assigned to its most similar center, values are aggregated per
    center and then dispatched back to the points before the output
    projection `fc2`.

    in_channels, out_channels : channels of the input / output feature map.
    heads, head_dim : number of heads and channels per head.
    proposal_w, proposal_h : number of centers along the width / height.
    fold_w, fold_h : number of local regions along the width / height;
        clustering is done independently inside each region.
    """

    def __init__(self,in_channels,out_channels,heads,head_dim,proposal_w=2,proposal_h=2,fold_w=2,fold_h=2):
        super().__init__()

        self.heads = heads
        self.head_dim = head_dim
        self.fold_w = fold_w
        self.fold_h = fold_h

        # heads*head_dim channels: the heads are moved into the batch dim in forward().
        self.fc1 = nn.Conv2d(in_channels,heads*head_dim,kernel_size = 1)
        self.fcv = nn.Conv2d(in_channels,heads*head_dim,kernel_size = 1)
        self.fc2 = nn.Conv2d(heads*head_dim,out_channels,kernel_size = 1)
        # AdaptiveAvgPool2d expects (out_height, out_width).
        self.center_proposal = nn.AdaptiveAvgPool2d((proposal_h,proposal_w))

        # Learnable scale and shift applied to the cosine similarity before the
        # sigmoid. TODO: the ones/zeros initialisation was flagged as odd.
        self.alpha = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self,x):
        """
        x : [b,c,h,w]
        return : [b,out_channels,h,w]
        """
        _,_,h,w = x.shape

        val = self.fcv(x)
        x = self.fc1(x)

        # separating heads
        x = rearrange(x, "b (e c) h w -> (b e) c h w", e=self.heads)
        val = rearrange(val, "b (e c) h w -> (b e) c h w", e=self.heads)

        # splitting into local regions; fold along either axis is enough to
        # require it (a fold of 1 on the other axis is a no-op split)
        folded = self.fold_w>1 or self.fold_h>1
        if folded:
            assert w%self.fold_w==0 and h%self.fold_h==0, \
                f"feature map ({h}x{w}) is not divisible by folds ({self.fold_h}x{self.fold_w})"
            x = rearrange(x, "b c (f1 h) (f2 w) -> (b f1 f2) c h w", f1=self.fold_h, f2=self.fold_w)
            val = rearrange(val, "b c (f1 h) (f2 w) -> (b f1 f2) c h w", f1=self.fold_h, f2=self.fold_w)
        _,_,h,w = x.shape

        # computing cluster centers
        cluster_centers = self.center_proposal(x) #(b,c,Ch,Cw)
        center_values = rearrange(self.center_proposal(val) , 'b c h w -> b (h w) c') #(b,Ch*Cw,c) = (b,m,c)

        b,c,_,_ = cluster_centers.shape

        # computing similarity between every center and every point
        centers_flat = cluster_centers.reshape(b,c,-1).permute(0,2,1) #(b,m,c)
        points_flat = x.reshape(b,c,-1).permute(0,2,1) #(b,n,c)
        sim = torch.sigmoid(self.beta + self.alpha*pairwise_cosine_sim(centers_flat, points_flat)) #(b,m,n)
        _, sim_argmax = sim.max(dim = 1, keepdim = True)

        # assigning a single cluster to each point (one-hot over the m centers)
        mask = torch.zeros_like(sim) #(b,m,n)
        mask.scatter_(1, sim_argmax, 1.)
        sim = sim*mask

        # aggregating the values of each cluster: sim @ val sums, for every
        # center, the similarity-weighted values of its points -> (b,m,c)
        # NOTE(review): the normaliser counts assigned points (mask.sum);
        # confirm whether the summed similarities (sim.sum) were intended.
        val = rearrange(val, 'b c h w -> b (h w) c')
        out = (sim @ val + center_values)/(mask.sum(dim=-1,keepdim=True)+ 1)

        # dispatching each aggregated feature back to its points -> (b,n,c)
        out = sim.transpose(1,2) @ out
        out = rearrange(out, "b (h w) c -> b c h w", h=h)

        # recovering the split regions
        if folded:
            out = rearrange(out, "(b f1 f2) c h w -> b c (f1 h) (f2 w)", f1=self.fold_h, f2=self.fold_w)

        # regrouping heads
        out = rearrange(out, "(b e) c h w -> b (e c) h w", e=self.heads)
        out = self.fc2(out)

        return out

In [215]:
class MLP(nn.Module):
    """Point-wise two-layer perceptron built from 1x1 convolutions.

    Layout: conv -> act -> conv -> act, followed by nn.Dropout when
    `dropout` is truthy.

    in_channels, hidden_channels, out_channels : channel sizes of the two convs.
    act : activation class (not an instance); instantiated once per use.
    dropout : dropout probability; 0 disables the dropout layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, act, dropout=0):
        super().__init__()

        expand = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        hidden_act = act()
        project = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
        output_act = act()

        stages = [expand, hidden_act, project, output_act]
        if dropout:
            stages += [nn.Dropout(dropout)]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run `x` through the sequential stack."""
        return self.net(x)

In [223]:
# TODO: decide whether DropPath should be added to this block.
class ClusterBlock(nn.Module):
    """Residual block: a Cluster layer followed by an MLP, each with a skip connection.

    in_channels : channels of the input; the block preserves them.
    act, dropout : forwarded to the MLP.
    mlp_ratio : hidden width of the MLP as a multiple of `in_channels`.
    proposal_w, proposal_h, fold_w, fold_h, heads, head_dim : forwarded to Cluster.
    """

    def __init__(self,in_channels,act=nn.GELU,mlp_ratio=4,dropout=0,proposal_w=2,proposal_h=2,fold_w=2,fold_h=2,heads=4,head_dim=16):
        super().__init__()

        # Cluster maps in_channels -> in_channels so the residual sum is valid.
        self.cluster = Cluster(
            in_channels, in_channels, heads, head_dim,
            proposal_w, proposal_h, fold_w, fold_h,
        )
        self.mlp = MLP(
            in_channels,
            int(mlp_ratio*in_channels),
            in_channels,
            act,
            dropout=dropout,
        )

    def forward(self,x):
        """Return `x` after the two residual sub-layers."""
        clustered = x + self.cluster(x)  # first skip connection
        return clustered + self.mlp(clustered)  # second skip connection

In [264]:
class BasicBlock(nn.Module):
    """One stage: a PointReducer followed by a ClusterBlock.

    in_channels : channels entering the point reducer.
    out_channels : channels produced by the reducer and kept by the cluster block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.point_red = PointReducer(in_channels, out_channels)
        self.cluster_b = ClusterBlock(out_channels)

    def forward(self, x):
        """Reduce the points, then apply the cluster block."""
        return self.cluster_b(self.point_red(x))

In [285]:
class Model(nn.Module):
    """Work-in-progress model: currently only turns images into sets of points."""

    def __init__(self):
        super().__init__()
        self.res = None #sequence of +(Point reducer, context cluster block)
        # Cache of the last coordinate tensor and the input shape it was built for.
        self.old_shape = None
        self.batch_ind = None

    def imgs_to_SoP(self,imgs):
        """
        Transform a batch of images to a batch of sets of points by
        appending two coordinate channels (width coordinate first, then
        height coordinate), each evenly spaced in [-0.5, 0.5]. An axis of
        length 1 gets the single coordinate -0.5.

        The coordinate tensor is cached and rebuilt only when the input
        shape or device changes.

        imgs : torch.Tensor([batch,chan,height,width])
        return : torch.Tensor([batch,chan+2,height,width])
        """
        shape = imgs.shape
        stale = (
            self.batch_ind is None
            or self.old_shape != shape
            or self.batch_ind.device != imgs.device
        )
        if stale:
            self.old_shape = shape
            height = shape[-2]
            width = shape[-1]
            batch = shape[0]

            # max(..., 1) keeps a 1-pixel axis from evaluating 0/0 (NaN);
            # coordinates live on the input's device so torch.cat works on CUDA.
            ind_w = torch.arange(width, device=imgs.device)/max(width-1, 1)-0.5
            ind_h = torch.arange(height, device=imgs.device)/max(height-1, 1)-0.5

            # Build the planes directly in (height, width) layout; reshaping a
            # (width, height, 2) stack to (2, height, width) would scramble them.
            grid_h, grid_w = torch.meshgrid(ind_h, ind_w, indexing = 'ij')
            ind = torch.stack((grid_w, grid_h), dim = 0)

            # Broadcast view over the batch; torch.cat below copies the data,
            # so the cached tensor is never aliased by the output.
            self.batch_ind = ind.unsqueeze(0).expand(batch,-1,-1,-1)

        SoP = torch.cat((imgs,self.batch_ind),dim=1)

        return SoP

    def forward(self,input):
        """Return the set-of-points representation of `input`."""
        SoP = self.imgs_to_SoP(input)

        return SoP

# Tests

In [265]:
# Four (PointReducer, ClusterBlock) stages; channels go 5 -> 16 -> 32 -> 64 -> 128.
# The 5 input channels match the 3 image channels plus 2 coordinate channels.
pr1, cb1 = PointReducer(5, 16), ClusterBlock(16)
pr2, cb2 = PointReducer(16, 32), ClusterBlock(32)
pr3, cb3 = PointReducer(32, 64), ClusterBlock(64)
pr4, cb4 = PointReducer(64, 128), ClusterBlock(128)

# Linear classifier head mapping the 128 final channels to 10 outputs.
clf = nn.Linear(128, 10)

In [266]:
# Random non-square batch with the coordinate channels appended (3 -> 5 channels).
X = imgs_to_SoP(torch.randn(32, 3, 64, 128))
X.shape

torch.Size([32, 5, 64, 128])

In [267]:
# Stage 1: run the point reducer then the cluster block, printing the shape after each.
for stage in (pr1, cb1):
    X = stage(X)
    print(X.shape)

torch.Size([32, 16, 32, 64])
torch.Size([32, 16, 32, 64])


In [268]:
# Stage 2: run the point reducer then the cluster block, printing the shape after each.
for stage in (pr2, cb2):
    X = stage(X)
    print(X.shape)

torch.Size([32, 32, 16, 32])
torch.Size([32, 32, 16, 32])


In [269]:
# Stage 3: run the point reducer then the cluster block, printing the shape after each.
for stage in (pr3, cb3):
    X = stage(X)
    print(X.shape)

torch.Size([32, 64, 8, 16])
torch.Size([32, 64, 8, 16])


In [270]:
# Stage 4: run the point reducer then the cluster block, printing the shape after each.
for stage in (pr4, cb4):
    X = stage(X)
    print(X.shape)

torch.Size([32, 128, 4, 8])
torch.Size([32, 128, 4, 8])


In [271]:
# Average over the two spatial dimensions, leaving one feature vector per sample.
X = X.mean(dim=(2, 3))
X.shape

torch.Size([32, 128])

In [272]:
# Apply the linear classifier head to the pooled features and show the output shape.
y = clf(X)
y.shape

torch.Size([32, 10])

# Tests Basic Blocks

In [274]:
# Same four stages as before, each wrapped in a BasicBlock; built in order b1..b4.
b1, b2, b3, b4 = (
    BasicBlock(c_in, c_out)
    for c_in, c_out in ((5, 16), (16, 32), (32, 64), (64, 128))
)
# Linear classifier head mapping the 128 final channels to 10 outputs.
clf = nn.Linear(128, 10)

In [275]:
# End-to-end pass: images -> set of points -> four BasicBlocks -> pooled -> classifier.
X = imgs_to_SoP(torch.randn(32, 3, 64, 128))

for block in (b1, b2, b3, b4):
    X = block(X)

# Average over the two spatial dimensions before the linear head.
X = X.mean(dim=(2, 3))

y = clf(X)
print(y.shape)

torch.Size([32, 10])
