## Context Free Network

by outerskyb

### import

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor, Lambda, Compose
from torchvision.io import read_image
import random
import cv2


In [2]:
from tqdm import trange
import numpy as np
import itertools
from scipy.spatial.distance import cdist

classes = 100
selection = 'max'


P_hat = np.array(list(itertools.permutations(list(range(9)), 9)))
n = P_hat.shape[0]

for i in trange(classes):
    if i==0:
        j = np.random.randint(n)
        P = np.array(P_hat[j]).reshape([1,-1])
    else:
        P = np.concatenate([P,P_hat[j].reshape([1,-1])],axis=0)
    
    P_hat = np.delete(P_hat,j,axis=0)
    D = cdist(P,P_hat, metric='hamming').mean(axis=0).flatten()
    
    if selection=='max':
        j = D.argmax()
    else:
        m = int(D.shape[0]/2)
        S = D.argsort()
        j = S[np.random.randint(m-10,m+10)]
    
        
print(P)


100%|██████████| 100/100 [00:30<00:00,  3.25it/s]

[[5 0 7 4 8 2 3 6 1]
 [0 1 2 3 4 5 6 7 8]
 [1 2 0 5 3 4 7 8 6]
 [2 3 1 0 5 6 8 4 7]
 [3 4 5 6 7 8 0 1 2]
 [4 5 3 8 6 7 1 2 0]
 [6 7 8 1 0 3 2 5 4]
 [7 8 6 2 1 0 4 3 5]
 [8 6 4 7 2 1 5 0 3]
 [0 1 2 3 6 8 5 7 4]
 [1 0 3 6 2 4 7 8 5]
 [2 3 0 7 8 1 4 5 6]
 [3 2 4 1 5 6 8 0 7]
 [4 6 1 0 7 5 2 3 8]
 [5 7 6 8 4 2 0 1 3]
 [8 4 5 2 0 7 3 6 1]
 [6 8 7 5 3 0 1 4 2]
 [7 5 8 4 1 3 6 2 0]
 [0 1 3 2 4 6 7 5 8]
 [1 0 2 8 7 5 4 6 3]
 [2 3 1 6 0 7 8 4 5]
 [3 2 0 7 6 4 5 8 1]
 [4 6 5 1 8 0 3 7 2]
 [5 4 6 3 2 8 1 0 7]
 [6 7 8 4 5 1 2 3 0]
 [7 8 4 5 1 3 0 2 6]
 [8 5 7 0 3 2 6 1 4]
 [0 1 3 7 2 5 4 6 8]
 [1 0 4 2 3 7 8 5 6]
 [2 3 1 8 7 6 5 0 4]
 [3 4 7 0 6 1 2 8 5]
 [4 6 2 5 0 8 7 1 3]
 [8 7 0 6 5 4 1 3 2]
 [5 2 6 4 8 0 3 7 1]
 [6 5 8 1 4 3 0 2 7]
 [7 8 5 3 1 2 6 4 0]
 [0 1 2 4 6 3 8 7 5]
 [1 0 3 2 7 4 5 8 6]
 [2 3 1 0 5 8 7 6 4]
 [3 2 0 6 4 7 1 5 8]
 [4 5 6 8 0 1 3 2 7]
 [5 4 7 1 8 6 0 3 2]
 [6 7 8 3 1 5 2 4 0]
 [7 8 4 5 2 0 6 1 3]
 [8 6 5 7 3 2 4 0 1]
 [0 1 2 6 3 7 4 5 8]
 [1 0 3 5 2 6 7 8 4]
 [2 3 4 1 8 0




### cuda check

In [3]:

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")


Using cuda device


### dataloading

In [4]:
import os

IMAGE_DIR = './VOCdevkit/VOC2012/JPEGImages'

In [5]:
class Voc2012Dataset(Dataset):
    def __init__(self,img_dir,permutation_set,transform = None ):
        self.img_paths = os.listdir(img_dir)
        self.transform = transform
        self.permutation_set = permutation_set
        return

    def __len__(self):
        return len(self.img_paths)*69
    
    def __getitem__(self,idx):
        img_path = os.path.join(IMAGE_DIR, self.img_paths[random.randint(0,len(self.img_paths)-1)])
        mat = cv2.imread(img_path)
        mat = cv2.resize(mat,(225,225))
        images = list()
        for i in range(3):
            for j in range(3):
                rx = random.randint(0,10)
                ry = random.randint(0,10)
                roi = mat[i*75+ry:i*75+ry+64,j*75+rx:j*75+rx+64,:]
                images.append(torch.Tensor(cv2.resize(roi,(75,75))).permute(2,0,1).to(device))
        
        permutation = random.randint(0,len(self.permutation_set)-1)
        
        out = list()
        for i in range(0,9):
            out.append(images[self.permutation_set[permutation][i]].to(device))
        return out, torch.Tensor([permutation]).to(device)#torch.nn.functional.one_hot(torch.Tensor(permutation),num_classes=len(self.permutation_set-1)).to(device)


In [6]:
dataset = Voc2012Dataset(IMAGE_DIR,P)
dataloader = DataLoader(dataset,batch_size=128,shuffle=True)

### Model Defining

In [7]:

class CFN(nn.Module):
    def __init__(self):
        super(CFN,self).__init__()
        
        self.siames = nn.Sequential(
            nn.Conv2d(3,96,(11,11),2,0),
            nn.ReLU(),
            nn.MaxPool2d((3,3),2),
            nn.LocalResponseNorm(5,0.0001,0.75,2),
            nn.Conv2d(96,256,(5,5),1,2),
            nn.ReLU(),
            nn.MaxPool2d((3,3),2),
            nn.LocalResponseNorm(5,0.0001,0.75,2),
            nn.Conv2d(256,384,(3,3),1,1),
            nn.ReLU(),
            nn.Conv2d(384,384,(3,3),1,1),
            nn.ReLU(),
            nn.Conv2d(384,256,(3,3),1,1),
            nn.ReLU(),
            nn.MaxPool2d((3,3),2,0),
            nn.Flatten(),
            nn.Linear(3*3*256,512)
        )
        
        self.flatten = nn.Flatten()

        self.fc = nn.Sequential(
            nn.Linear(4608,4096),
            nn.ReLU(),
            nn.Linear(4096,100),
            nn.Softmax()
        )
        
        return
    
    def forward(self, x):
        y = torch.Tensor().to(device)
        for idx in range(0,9):
            y = torch.cat((y,self.siames(x[idx])),1)
        representation = self.flatten(y)
        out = self.fc(representation)
        return out

In [8]:
from tqdm import tqdm

In [9]:


model = CFN().to(device)
loss_fn = nn.CrossEntropyLoss()
learning_rate = 0.2
num_epoches = 100
optimizier = torch.optim.Adam(model.parameters(),lr=learning_rate)
size = len(dataloader.dataset)



In [15]:
for epoch in range(num_epoches):
    pgbar = tqdm(enumerate(dataloader))
    pgbar.set_description(f"Epoch {epoch}/{num_epoches}")
    for batch_idx, (images,y) in pgbar:
        
        pred = model(images)
        #y = nn.functional.one_hot(y,num_classes=100).to(device)
        loss = loss_fn(pred,y.type(torch.LongTensor).to(device).squeeze())
        
        optimizier.zero_grad()
        loss.backward()
        optimizier.step()
        if batch_idx % 10 == 0:
            loss, current = loss.item(), batch_idx 
            pgbar.set_postfix_str(f"epoch: {epoch} loss: {loss:>7f} batch: [{batch_idx:>5d}/{size//128:>5d}] ")
            

  input = module(input)
Epoch 0/100: : 496it [06:58,  1.23it/s, epoch: 0 loss: 4.614395 batch: [  490/ 9231] ]