In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.backends.cudnn as cudnn

import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

In [None]:
# Root directory handed to the torchvision dataset loader.
PATH = "datasets"

In [None]:
def imgs_to_SoP(imgs):
    """
    Transform images into sets of points (SoP).

    Every pixel becomes one point whose features are its channel values
    followed by its normalised position: ``row / height - 0.5`` and
    ``col / width - 0.5``. Points are listed in row-major order.

    imgs : torch.Tensor([..., chan, height, width]) -- usually a batch
           [batch, chan, height, width]; a single unbatched image
           [chan, height, width] is accepted as well.
    return : torch.Tensor([..., chan+2, height*width])

    Raises ValueError if ``imgs`` has fewer than 3 dimensions.
    """
    if imgs.dim() < 3:
        raise ValueError(
            "imgs_to_SoP expects a tensor of shape [..., chan, height, width], "
            f"got shape {tuple(imgs.shape)}"
        )

    height, width = imgs.shape[-2:]
    n_points = height * width

    # Create the coordinates on the same device as the images: a CPU-only
    # coordinate tensor makes torch.cat fail once the images are on the GPU.
    rows = torch.arange(height, dtype=torch.float32, device=imgs.device) / height - 0.5
    cols = torch.arange(width, dtype=torch.float32, device=imgs.device) / width - 0.5

    # Row-major pixel order: the row coordinate changes every `width` points,
    # the column coordinate cycles through all columns for each row.
    coords = torch.stack((rows.repeat_interleave(width), cols.repeat(height)))

    flat_imgs = imgs.flatten(start_dim=-2)
    # Broadcast the (2, height*width) coordinates over all leading dimensions
    # (none for an unbatched image) without copying them.
    coords = coords.expand(*imgs.shape[:-3], 2, n_points)

    return torch.cat((flat_imgs, coords), dim=-2)

In [None]:
# Request pinned host memory only when a CUDA device is actually present;
# a hard-coded True asks for it on CPU-only machines as well.
cuda = torch.cuda.is_available()
batch_size = 64

# Transformations applied to every sample, which arrives as a PIL image
# (Python Image Library); see the torchvision.transforms documentation.
train_transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> torch.Tensor
    transforms.Normalize(mean=(0.491, 0.482, 0.447), std=(0.202, 0.199, 0.201)),
    # imgs_to_SoP is deliberately not applied here: it is used inside the network.
    # Disabled augmentations: RandomCrop(28), Pad(2), RandomHorizontalFlip(p=0.5).
])

train_dataset = datasets.CIFAR10(PATH, train=True, download=True,
                                 transform=train_transform)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=cuda,
    num_workers=2,
)

Files already downloaded and verified


In [None]:
# Fetch one mini-batch and show its shape before and after the conversion
# of the images to sets of points.
X, y = next(iter(train_loader))
print(X.shape, imgs_to_SoP(X).shape, sep="\n")

torch.Size([64, 3, 32, 32])
torch.Size([64, 5, 1024])


In [None]:
class PointReducer(nn.Module):
    """
    Point reducer stage: a single 2D convolution applied to its input.

    in_chan : number of input channels of the convolution
    out_chan : number of output channels of the convolution
    kernel_size : convolution kernel size (default 2)
    stride : convolution stride (default 2)
    """
    def __init__(self, in_chan, out_chan, kernel_size=2, stride=2):
        super().__init__()
        self.conv2d = nn.Conv2d(
            in_channels=in_chan,
            out_channels=out_chan,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, input):
        """Run the convolution on `input` and return the result."""
        reduced = self.conv2d(input)
        return reduced

In [None]:
class Model(nn.Module):
    """
    Network wrapper: converts the input images into sets of points and feeds
    them to the wrapped module `res`.
    """
    def __init__(self,res):
        """
        res : callable nn.Module applied to the sets of points; intended to be
              a sequence of (point reducer, context cluster block) stages.
        """
        super().__init__()
        self.res = res

    def imgs_to_SoP(self,imgs):
        """
        Transform a batch of images to a batch of sets of points.

        Every pixel becomes one point whose features are its channel values
        followed by its normalised position: ``row / height - 0.5`` and
        ``col / width - 0.5``. Points are listed in row-major order.

        imgs : torch.Tensor([batch,chan,height,width])
               (any number of leading dimensions, including none, is accepted)
        return : torch.Tensor([batch,chan+2,height*width])

        Raises ValueError if ``imgs`` has fewer than 3 dimensions.
        """
        if imgs.dim() < 3:
            raise ValueError(
                "imgs_to_SoP expects a tensor of shape [..., chan, height, width], "
                f"got shape {tuple(imgs.shape)}"
            )

        height, width = imgs.shape[-2:]
        n_points = height * width

        # Create the coordinates on the same device as the images: a CPU-only
        # coordinate tensor makes torch.cat fail once the model runs on the GPU.
        rows = torch.arange(height, dtype=torch.float32, device=imgs.device) / height - 0.5
        cols = torch.arange(width, dtype=torch.float32, device=imgs.device) / width - 0.5

        # Row-major pixel order: the row coordinate changes every `width`
        # points, the column coordinate cycles through all columns per row.
        coords = torch.stack((rows.repeat_interleave(width), cols.repeat(height)))

        flat_imgs = imgs.flatten(start_dim=-2)
        # Broadcast the (2, height*width) coordinates over the leading dims.
        coords = coords.expand(*imgs.shape[:-3], 2, n_points)

        return torch.cat((flat_imgs, coords), dim=-2)

    def forward(self,input):
        """
        input : torch.Tensor([batch,chan,height,width])
        return : output of `self.res` applied to the sets of points
        """
        SoP = self.imgs_to_SoP(input)
        # Apply the wrapped stages to the points; previously the module object
        # itself was returned and SoP was never used.
        # NOTE(review): SoP is passed flattened as [batch, chan+2, height*width];
        # Conv2d-based stages such as PointReducer expect a 4D layout -- confirm
        # the layout `res` is meant to receive.
        pred = self.res(SoP)

        return pred