In [260]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from einops import rearrange

import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

In [None]:
# Root directory handed to the dataset constructor below (where the data is stored / looked up).
PATH="datasets"

In [None]:
def imgs_to_SoP(imgs):
    """Convert a batch of images into a batch of sets of points.

    Every pixel becomes one point whose features are its channel values
    followed by its normalised (row, column) position.

    Args:
        imgs: tensor of shape (batch, channels, height, width).

    Returns:
        Tensor of shape (batch, channels + 2, height * width). The two extra
        channels hold ``row / height - 0.5`` and ``col / width - 0.5``, so the
        coordinates lie in [-0.5, 0.5). Points are listed in row-major order.
        The result lives on the same device as ``imgs``; for floating-point
        inputs it also keeps the dtype of ``imgs``.
    """
    batch = imgs.shape[0]
    height, width = imgs.shape[-2:]

    # Build the coordinates directly on the input's device so the final
    # concatenation also works for CUDA tensors.
    rows = torch.arange(height, dtype=torch.float32, device=imgs.device) / height - 0.5
    cols = torch.arange(width, dtype=torch.float32, device=imgs.device) / width - 0.5

    # Row-major enumeration: the row coordinate changes slowest.
    coords = torch.stack((rows.repeat_interleave(width), cols.repeat(height)))  # (2, h*w)
    if imgs.is_floating_point():
        # Avoid silently promoting e.g. half-precision images to float32.
        coords = coords.to(imgs.dtype)

    flat_imgs = imgs.flatten(start_dim=-2)  # (b, c, h*w)
    # expand() is a view; torch.cat makes the single copy that is needed.
    return torch.cat((flat_imgs, coords.expand(batch, -1, -1)), dim=1)

In [None]:
# Only enable CUDA-specific options when a GPU is actually present; a
# hard-coded True makes pin_memory (and any later `.cuda()` call guarded by
# this flag) misbehave on CPU-only machines.
cuda = torch.cuda.is_available()
batch_size = 64
train_dataset = datasets.CIFAR10(PATH, train=True, download=True,
    transform=transforms.Compose([
        # Transformations applied to each sample; the input is a PIL image
        # (Python Imaging Library). See the torchvision.transforms documentation.
        transforms.ToTensor (), # Transform the PIL image to a torch.Tensor
        # Per-channel (mean, std) normalisation; the constants are presumably
        # the CIFAR-10 training-set statistics -- TODO confirm.
        transforms.Normalize((0.491, 0.482, 0.447), (0.202, 0.199, 0.201)),
        #transforms.Lambda(imgs_to_SoP) # applied inside the network instead
        
        #transforms.RandomCrop(28),
        #transforms.Pad(2)
        #transforms.RandomHorizontalFlip(p=0.5)
    ]))

# pin_memory only helps host-to-GPU copies, so it follows the `cuda` flag.
train_loader = torch.utils.data.DataLoader(train_dataset,
                        batch_size=batch_size, shuffle=True, pin_memory=cuda, num_workers=2)

Files already downloaded and verified


In [None]:
# Sanity check: fetch one mini-batch and print its shape before and after
# the image -> set-of-points conversion.
X,y = next(iter(train_loader))
print(X.shape)
print(imgs_to_SoP(X).shape)

torch.Size([64, 3, 32, 32])
torch.Size([64, 5, 1024])


In [None]:
class PointReducer(nn.Module):
    """Thin wrapper around a single ``nn.Conv2d`` layer.

    With the default ``kernel_size=2`` and ``stride=2`` the convolution halves
    both spatial dimensions of its input while mapping ``in_chan`` channels to
    ``out_chan`` channels.

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
    """

    def __init__(self, in_chan, out_chan, kernel_size=2, stride=2):
        super().__init__()
        self.conv2d = nn.Conv2d(
            in_channels=in_chan,
            out_channels=out_chan,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, input):
        """Apply the convolution to ``input`` and return the result."""
        reduced = self.conv2d(input)
        return reduced

In [53]:
def pairwise_cosine_sim(x1, x2):
    """Cosine similarity between every vector of ``x1`` and every vector of ``x2``.

    Both inputs are L2-normalised along their last dimension and then combined
    with a (batched) matrix product.

    Args:
        x1: tensor of shape (..., m, d).
        x2: tensor of shape (..., n, d).

    Returns:
        Tensor of shape (..., m, n) whose entry (i, j) is the cosine similarity
        between ``x1[..., i, :]`` and ``x2[..., j, :]``.
    """
    unit_x1 = F.normalize(x1, dim=-1)
    unit_x2_t = F.normalize(x2, dim=-1).transpose(-2, -1)
    # The matrix product broadcasts over leading (batch) dimensions.
    return torch.matmul(unit_x1, unit_x2_t)

In [None]:
class Cluster(nn.Module):
    """Multi-head clustering layer operating on feature maps.

    The input is projected twice with 1x1 convolutions (similarity features
    and values), the heads are moved into the batch axis, and the map is
    optionally split into ``fold_w x fold_h`` local patches. Inside every
    patch, cluster centers are proposed by adaptive average pooling, each
    point is assigned to its most similar center, the values are aggregated
    per cluster and then dispatched back to the points.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        heads: number of heads.
        head_dim: number of channels per head.
        proposal_w, proposal_h: size of the grid of proposed cluster centers.
        fold_w, fold_h: number of patches along each spatial axis. Folding is
            applied only when both are greater than 1.
    """

    def __init__(self,in_channels,out_channels,heads,head_dim,proposal_w=2,proposal_h=2,fold_w=2,fold_h=2):
        super().__init__()

        self.heads = heads
        self.head_dim = head_dim
        self.fold_w = fold_w
        self.fold_h = fold_h

        # heads*head_dim channels: the heads are moved into the batch axis in forward().
        self.fc1 = nn.Conv2d(in_channels, heads*head_dim, kernel_size=1)   # similarity features
        self.fcv = nn.Conv2d(in_channels, heads*head_dim, kernel_size=1)   # values
        self.fc2 = nn.Conv2d(heads*head_dim, out_channels, kernel_size=1)  # output projection
        self.center_proposal = nn.AdaptiveAvgPool2d((proposal_w, proposal_h))

        # Learnable scale and shift applied to the cosine similarity before the sigmoid.
        self.alpha = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self,x):
        """Cluster the points of ``x`` and return the refined feature map.

        Args:
            x: tensor of shape [b, c, w, h] with ``c == in_channels``.

        Returns:
            Tensor of shape [b, out_channels, w, h].
        """
        val = self.fcv(x)
        x = self.fc1(x)

        # separating heads: b (e c) w h -> (b e) c w h
        x = rearrange(x, "b (e c) w h -> (b e) c w h", e=self.heads)
        val = rearrange(val, "b (e c) w h -> (b e) c w h", e=self.heads)

        # splitting patches
        folded = self.fold_w > 1 and self.fold_h > 1
        if folded:
            _, _, full_w, full_h = x.shape
            assert full_w % self.fold_w == 0 and full_h % self.fold_h == 0, \
                "spatial dimensions must be divisible by (fold_w, fold_h)"
            x = rearrange(x, "b c (f1 w) (f2 h) -> (b f1 f2) c w h", f1=self.fold_w, f2=self.fold_h)
            val = rearrange(val, "b c (f1 w) (f2 h) -> (b f1 f2) c w h", f1=self.fold_w, f2=self.fold_h)

        # Shapes must be read AFTER folding: w and h are the per-patch sizes
        # needed to restore the grid further down.
        b, c, w, h = x.shape

        # computing cluster centers
        cluster_centers = self.center_proposal(x)  # (b, c, Cw, Ch)
        center_values = rearrange(self.center_proposal(val), 'b c w h -> b (w h) c')  # (b, m, c)

        # computing similarity between every center and every point
        sim = torch.sigmoid(
            self.beta
            + self.alpha * pairwise_cosine_sim(
                cluster_centers.reshape(b, c, -1).permute(0, 2, 1),  # (b, m, c)
                x.reshape(b, c, -1).permute(0, 2, 1),                # (b, n, c)
            )
        )  # (b, m, n)
        _, sim_argmax = sim.max(dim=1, keepdim=True)

        # assigning a single cluster to each point (one-hot over the m centers)
        mask = torch.zeros_like(sim)  # (b, m, n)
        mask.scatter_(1, sim_argmax, 1.)
        sim = sim * mask

        # computing the aggregated feature of every cluster
        val = rearrange(val, 'b c w h -> b (w h) c')  # (b, n, c)
        # sim @ val equals (val[:, None] * sim[..., None]).sum(2) without
        # materialising the (b, m, n, c) intermediate.
        # NOTE(review): the normaliser counts the assigned points (mask) rather
        # than summing their similarities -- confirm this is intended.
        out = (torch.matmul(sim, val) + center_values) / (mask.sum(dim=-1, keepdim=True) + 1)  # (b, m, c)

        # dispatching each cluster feature back to its points, weighted by similarity
        out = torch.matmul(sim.transpose(1, 2), out)  # (b, n, c)
        out = rearrange(out, "b (w h) c -> b c w h", w=w)

        # recovering split patches
        if folded:
            out = rearrange(out, "(b f1 f2) c w h -> b c (f1 w) (f2 h)", f1=self.fold_w, f2=self.fold_h)

        # regrouping heads
        out = rearrange(out, "(b e) c w h -> b (e c) w h", e=self.heads)
        out = self.fc2(out)
        return out

In [None]:
class MLP(nn.Module):
    """Point-wise two-layer perceptron built from 1x1 convolutions.

    Layer order: conv(in -> hidden), act, conv(hidden -> out), act, followed
    by an optional dropout layer.

    Args:
        in_channels: number of input channels.
        hidden_channels: number of channels between the two convolutions.
        out_channels: number of output channels.
        act: zero-argument callable returning an activation module
            (it is invoked once per convolution).
        dropout: dropout probability; no dropout layer is added when falsy.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, act, dropout=0):
        super().__init__()

        stages = []
        channel_plan = ((in_channels, hidden_channels), (hidden_channels, out_channels))
        for c_in, c_out in channel_plan:
            # NOTE(review): an activation also follows the last convolution --
            # confirm that this is intended.
            stages += [nn.Conv2d(c_in, c_out, kernel_size=1), act()]
        if dropout:
            stages.append(nn.Dropout(dropout))

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stacked layers."""
        return self.net(x)

In [None]:
# TODO: add DropPath?
class ClusterBlock(nn.Module):
    """A ``Cluster`` layer followed by an ``MLP``, each wrapped in a skip connection.

    Args:
        in_channels: number of channels of the input; the output has the same number.
        act: zero-argument callable returning an activation module, forwarded to the MLP.
        mlp_ratio: the MLP hidden width is ``int(mlp_ratio * in_channels)``.
        dropout: dropout probability forwarded to the MLP.
        proposal_w, proposal_h, fold_w, fold_h, heads, head_dim: forwarded to ``Cluster``.
    """

    def __init__(
        self,
        in_channels,
        act,
        mlp_ratio=4,
        dropout=0,
        proposal_w=2,
        proposal_h=2,
        fold_w=2,
        fold_h=2,
        heads=4,
        head_dim=16,
    ):
        super().__init__()

        # The clustering layer keeps the channel count (in_channels -> in_channels).
        self.cluster = Cluster(
            in_channels, in_channels, heads, head_dim,
            proposal_w=proposal_w, proposal_h=proposal_h,
            fold_w=fold_w, fold_h=fold_h,
        )
        self.mlp = MLP(
            in_channels,
            int(mlp_ratio * in_channels),
            in_channels,
            act,
            dropout=dropout,
        )

    def forward(self, x):
        """Return ``y + mlp(y)`` where ``y = x + cluster(x)``."""
        clustered = x + self.cluster(x)  # first skip connection
        return clustered + self.mlp(clustered)  # second skip connection

In [None]:
class Model(nn.Module):
    """Network that turns images into sets of points and feeds them to ``res``.

    Args:
        res: callable (typically an ``nn.Module`` such as an ``nn.Sequential``
            of point reducers and context-cluster blocks) applied to the
            point features laid out as [batch, chan+2, height, width].
    """

    def __init__(self,res):
        super().__init__()
        self.res = res  # sequence of (point reducer, context cluster block) stages
        # Cache of the coordinate channels; rebuilt whenever the input
        # shape, device or dtype changes.
        self.old_shape = None
        self.batch_idx = None

    def imgs_to_SoP(self,imgs):
        """
        Transform a batch of images to a batch of sets of points

        imgs : torch.Tensor([batch,chan,height,width])
        return : torch.Tensor([batch,chan+2,height*width])

        The two extra channels hold the row and column position of every
        pixel, scaled to [-0.5, 0.5).
        """
        shape = imgs.shape
        coord_dtype = imgs.dtype if imgs.is_floating_point() else torch.float32
        stale = (
            self.batch_idx is None
            or self.old_shape != shape
            or self.batch_idx.device != imgs.device
            or self.batch_idx.dtype != coord_dtype
        )
        if stale:
            self.old_shape = shape
            batch = shape[0]
            height, width = shape[-2:]

            # Build on the input's device so the concatenation below also
            # works once the model and data are moved to the GPU.
            rows = torch.arange(height, dtype=torch.float32, device=imgs.device) / height - 0.5
            cols = torch.arange(width, dtype=torch.float32, device=imgs.device) / width - 0.5
            # Row-major enumeration of the pixels: (2, height*width).
            coords = torch.stack((rows.repeat_interleave(width), cols.repeat(height)))

            # expand() is only a view; torch.cat below copies, so no clone is needed.
            self.batch_idx = coords.to(coord_dtype).expand(batch, -1, -1)

        flat_imgs = imgs.flatten(start_dim=-2)
        SoP = torch.cat((flat_imgs, self.batch_idx), dim=1)

        return SoP

    def forward(self,input):
        """Augment ``input`` with pixel coordinates and apply ``self.res``.

        input : torch.Tensor([batch,chan,height,width])
        return : whatever ``self.res`` returns for the
                 [batch,chan+2,height,width] point features
        """
        SoP = self.imgs_to_SoP(input)
        height, width = input.shape[-2:]
        # Restore the spatial grid so that 2-D (convolution based) stages can
        # consume the points.
        points = SoP.reshape(SoP.shape[0], -1, height, width)
        pred = self.res(points)

        return pred

In [261]:
# Random stand-in input for shape experiments: 32 samples, 5 channels, 32x32 grid.
X_test = torch.randn(32,5,32,32)

In [262]:
# 1x1 convolution from 5 to 4*16 = 64 channels (same form as Cluster.fc1 with heads=4, head_dim=16).
fct = nn.Conv2d(5,4*16,kernel_size = 1)

In [266]:
# Project the input, then move the 4 heads from the channel axis into the batch axis.
y_test = fct(X_test)
y_test = rearrange(y_test, "b (e c) w h -> (b e) c w h", e=4)
y_test.shape

torch.Size([128, 16, 32, 32])

In [268]:
# Split every feature map into a 2x2 grid of patches and stack the patches along the batch axis.
h_test = rearrange(y_test, "b c (f1 w) (f2 h) -> (b f1 f2) c w h", f1=2, f2=2)

In [269]:
# Display the shape obtained after splitting into patches.
h_test.shape

torch.Size([512, 16, 16, 16])

In [270]:
# Adaptive average pooling down to a fixed 2x2 output grid (same layer type as Cluster.center_proposal).
pool = nn.AdaptiveAvgPool2d((2,2))

In [271]:
# Shape of the pooled patches.
pool(h_test).shape

torch.Size([512, 16, 2, 2])

In [272]:
# Flatten the pooled 2x2 grid into a single axis and move the channels last.
rearrange(pool(h_test) , 'b c w h -> b (w h) c').shape

torch.Size([512, 4, 16])