# Pattern Grouping Model

## Architecture

The model has three components:

1. ResNet18 to extract visual features.
2. UNet to extract structural features (more on that later).
3. Positional attributes.

These three features are combined into a 32D feature vector. The architecture is trained on triplets sampled from annotated graphics using a max-margin loss. **The three components are merged by a final linear layer.**
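
To make the training objective concrete, here is a minimal NumPy sketch of a max-margin triplet loss over a batch of embeddings. The function name `triplet_margin_loss` and the margin value of 1.0 are assumptions for illustration; the notebook does not state the margin or show how the loss is implemented in `vectorrvnn`.

```python
import numpy as np

def triplet_margin_loss(ref, plus, minus, margin=1.0):
    """Mean of max(0, d(ref, plus) - d(ref, minus) + margin) over a batch.

    ref, plus, minus: arrays of shape (batch, embedding_size).
    margin: hypothetical value, not taken from the notebook.
    """
    d_plus = np.linalg.norm(ref - plus, axis=1)
    d_minus = np.linalg.norm(ref - minus, axis=1)
    return np.maximum(0.0, d_plus - d_minus + margin).mean()

# Two toy triplets in a 32D embedding space.
ref = np.zeros((2, 32))
plus = np.zeros((2, 32))
minus = np.zeros((2, 32))
plus[:, 0] = [0.5, 2.0]   # distances from ref to plus
minus[:, 0] = [3.0, 1.0]  # distances from ref to minus

# First triplet is satisfied with room to spare (term 0),
# second is violated (term 2.0 - 1.0 + 1.0 = 2.0).
loss = triplet_margin_loss(ref, plus, minus, margin=1.0)
print(loss)
```

A triplet contributes zero loss only once `minus` is farther from `ref` than `plus` by at least the margin.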

## What do I want to study here?

1. Why is there a significant gap between the loss on training triplets and on validation triplets?
2. What are the individual components doing?


In [None]:
from vectorrvnn.interfaces import *
from vectorrvnn.trainutils import *
from vectorrvnn.data import *
from vectorrvnn.utils import *
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
import numpy as np
from functools import partial
from copy import deepcopy
from sklearn.manifold import TSNE

opts = Options().parse(testing=[
    '--batch_size', '32',
    '--checkpoints_dir', '../results',
    '--dataroot', '../data/All',
    '--embedding_size', '32',
    '--load_ckpt', 'pattern_oneof/best_0-770.pth',
    '--modelcls', 'PatternGrouping',
    '--name', 'pattern_oneof',
    '--structure_embedding_size', '8',
    '--samplercls', 'DiscriminativeSampler',
    '--device', 'cuda:1'
])


In [None]:
model = buildModel(opts)
data = buildData(opts)

## What are the features?

In [None]:

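def plotVisImage (im) : 
    """Show the ref, plus and minus visualizations of a triplet side by side."""
    nodeType = ['ref', 'plus', 'minus']
    print(nodeType)
    ims = []
    for ntype, im_ in zip(nodeType, im) :
        # (C, H, W) tensor -> (H, W, C) array, as expected by imshow
        numpyIm = im_.detach().cpu().numpy()
        numpyIm = np.transpose(numpyIm, (1, 2, 0))
        ims.append(numpyIm)
    
    fig, ax = plt.subplots(1, 1, dpi=150)
    # Concatenate along the width axis
    ax.imshow(np.concatenate(ims, 1))
    plt.show()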
        
trainData, valData, trainDataLoader, valDataLoader = data

for trainBatch in trainDataLoader : 
    break
    
for valBatch in valDataLoader :
    break
    
tripletviscallback = TripletVisCallback()

trainTripletImage = tripletviscallback.visualized_image(trainBatch, dict(mask=None), False)
valTripletImage   = tripletviscallback.visualized_image(valBatch  , dict(mask=None), True)

print("Plotting train triplet")
plotVisImage(trainTripletImage)
print("Plotting val triplet")
plotVisImage(valTripletImage)

1. The ResNet18 branch processes the last column, which shows just the path/group.
2. The UNet branch processes the middle column: (3, 256, 256) -- UNet --> (8, 256, 256). The UNet output is averaged over pixels, weighted by the bitmap (the first column), to produce an 8D vector. This branch is supposed to capture the global context of the graphic from the perspective of the path/group.
3. The positional attribute is just the bounding box normalized by the document's coordinates.
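
The bitmap-weighted average in step 2 can be sketched in NumPy as follows. This is an illustration on a toy 4x4 feature map rather than the model's actual code; the helper name `masked_average` and the shapes are assumptions.

```python
import numpy as np

def masked_average(features, bitmap, eps=1e-6):
    """Average a (C, H, W) feature map over pixels, weighted by an (H, W) bitmap.

    Returns a (C,) vector. eps guards against an all-zero bitmap.
    """
    weighted = (features * bitmap[None]).sum(axis=(1, 2))
    return weighted / (bitmap.sum() + eps)

# Toy 8-channel feature map: top half is 1, bottom half is 5.
features = np.zeros((8, 4, 4))
features[:, :2, :] = 1.0
features[:, 2:, :] = 5.0

# The path covers only the top half of the canvas.
bitmap = np.zeros((4, 4))
bitmap[:2, :] = 1.0

vec = masked_average(features, bitmap)
print(vec.shape)
```

Only the pixels covered by the path contribute, so the resulting vector describes the UNet features at the location of the path/group and ignores the rest of the canvas.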

## What are hard examples among validation triplets?

I'll collect all the validation triplets for which ||ref - minus|| < ||ref - plus|| and inspect them visually to see if I can spot any pattern.

In [None]:
interface = TripletInterface(opts, model, trainData, valData)

rets = [] 
model.eval()
valDataLoader.reset()
with torch.no_grad() :
    for batch in tqdm(valDataLoader) : 
        rets.append(interface.forward(batch))
        
print('loss = ', avg(map(lambda r : r['loss'], rets)))
print('% hard triplets = ', avg(map(lambda r : r['hardpct'], rets)))

In [None]:
allValTriplets = aggregateDict(list(valDataLoader), torch.cat)
allValRets = aggregateDict(rets, torch.cat, keys=[('dplus',), ('dminus',)])

In [None]:
mask = allValRets['dminus'] <= allValRets['dplus']
print("# hard triplets: ", mask.sum())
for i in range((int(mask.sum()) // 5)) : 
    print(i)
    hardTriplet = tripletviscallback.visualized_image(
        allValTriplets, 
        dict(mask=mask), 
        True,
        i=i
    )
    plotVisImage(hardTriplet)

### What is the problem?

1. Data bugs: Look at images 28, 22, 17, 9, 10, 36. There is definitely something wrong with some of the groups. However, only around 6 of the 42 * 3 = 126 nodes display this problem, which I don't consider substantial.
2. Size: Hard triplets often contain paths that are quite small relative to the canvas. Examples: 42, 39, 32, 35, 29, 28, 26, 25, 10. This can be solved by also providing crops/zooms as features. Why have I not done this already? Because of bugs in pathfinder. Pathfinder *only* accepts a graphic with a viewbox that coincides with the origin. If there is any path outside the viewbox, pathfinder crashes. The temporary workaround is to render a big raster and then crop the required path. I will do that.
3. Sampling: This is the most important issue. Currently, I sample triplets such that ||REF - PLUS|| < ||REF - MINUS||, where the distance is the length of the path between the two nodes in the tree. But one can easily construct examples where ||PLUS - MINUS|| < ||REF - PLUS|| < ||REF - MINUS||. In such a case, PLUS is closer to MINUS than to REF, so the triplet just confuses the network. I think a majority of the hard triplets above fit into this last case. Examples: 41, 40, 39, 38, 37, 35, 27, 24, 0, 1, 2. Again, this is easy to solve by sampling triplets under the criteria ||REF - PLUS|| < ||PLUS - MINUS|| and ||REF - PLUS|| < ||REF - MINUS||.
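
The two sampling criteria in point 3 can be compared on a toy hierarchy. The sketch below is pure Python and does not use the notebook's sampler classes; the tree, the node names and the helpers `tree_distance`, `is_valid_old` and `is_valid_new` are all hypothetical.

```python
from itertools import permutations

# Hypothetical toy hierarchy as a child -> parent map; lowercase letters are paths.
parent = {
    "a": "g1", "b": "g1",
    "c": "g2", "d": "g2",
    "g1": "g3", "e": "g3",
    "g3": "root", "g2": "root",
}

def ancestors(node):
    """Return the chain [node, parent, ..., root]."""
    chain = [node]
    while chain[-1] in parent:
        chain.append(parent[chain[-1]])
    return chain

def tree_distance(u, v):
    """Number of edges on the path between u and v in the tree."""
    depth_in_v = {n: i for i, n in enumerate(ancestors(v))}
    for i, n in enumerate(ancestors(u)):
        if n in depth_in_v:  # first common ancestor
            return i + depth_in_v[n]
    raise ValueError("nodes are not in the same tree")

def is_valid_old(ref, plus, minus):
    """Current criterion: ||REF - PLUS|| < ||REF - MINUS||."""
    return tree_distance(ref, plus) < tree_distance(ref, minus)

def is_valid_new(ref, plus, minus):
    """Proposed criterion: additionally require ||REF - PLUS|| < ||PLUS - MINUS||."""
    d_rp = tree_distance(ref, plus)
    return d_rp < tree_distance(ref, minus) and d_rp < tree_distance(plus, minus)

leaves = ["a", "b", "c", "d", "e"]
old = [t for t in permutations(leaves, 3) if is_valid_old(*t)]
new = [t for t in permutations(leaves, 3) if is_valid_new(*t)]

# (c, e, a) passes the old test, yet e is closer to a (3 edges) than to c (4 edges).
print(is_valid_old("c", "e", "a"), is_valid_new("c", "e", "a"))
print(len(old), len(new))
```

In this toy tree the triplet (c, e, a) is accepted by the current criterion even though PLUS and MINUS sit in the same subtree, which is exactly the confusing case described above; the stricter criterion rejects it.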

With that, I think I have addressed the question of why the validation loss is greater than the training loss. Now let's see what the different branches do!

## What are the individual components doing?

Here, I'll study:

1. What is the output of the UNet?
2. Have I correctly computed the bbox features?
3. What if I used just one of the components to perform clustering? 

In [None]:
# Visualize the per-pixel squared L2 norm of the final 8-channel UNet output.
# It doesn't convey much information.
with torch.no_grad() :
    for i in [10, 13, 14, 17] : 
        fig, axes = plt.subplots(1, 2, dpi=70)
        im1 = allValTriplets['ref']['im'][i].unsqueeze(0)
        unetout = model.unet(im1.to(opts.device)).squeeze()
        im1 = im1.squeeze().detach().cpu().numpy().transpose(1,2,0)
        im2 = (unetout * unetout).sum(0).detach().cpu().numpy()
        axes[0].imshow(im1)
        axes[1].imshow(im2)
        plt.show()


Really not sure what I learned from this visualization. Maybe the TSNE plots of the embedding will clear things up later on.
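
For reference, a t-SNE projection of node embeddings could be sketched as below. The embeddings here are synthetic stand-ins (two clusters of 32D vectors), not outputs of the model, and the perplexity value is an arbitrary choice for this toy size.

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)

# Stand-in for per-node embeddings: two synthetic clusters of 32D vectors.
embeddings = np.concatenate([
    rng.normal(0.0, 0.1, size=(20, 32)),
    rng.normal(3.0, 0.1, size=(20, 32)),
])

# perplexity must be smaller than the number of samples (40 here).
tsne = TSNE(n_components=2, perplexity=10, init="random", random_state=0)
coords = tsne.fit_transform(embeddings)
print(coords.shape)
```

The resulting `coords` array has one 2D point per node and can be passed to `ax.scatter(coords[:, 0], coords[:, 1])`, colored by group, to check whether nodes from the same group land near each other.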

In [None]:
def drawBBox (bbox, ax) : 
    x, y, w, h = bbox.view(-1).tolist()
    X = x + w
    Y = y + h
    y = 1 - y
    Y = 1 - Y
    ax.set_aspect('equal')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.plot([x, X], [y, y], color='blue')
    ax.plot([x, X], [Y, Y], color='blue')
    ax.plot([x, x], [y, Y], color='blue')
    ax.plot([X, X], [y, Y], color='blue')

for i in range(10) : 
    fig, axes = plt.subplots(1, 2, dpi=70)
    axes[0].imshow(allValTriplets['ref']['whole'][i].detach().cpu().numpy().transpose(1, 2, 0))
    drawBBox(allValTriplets['ref']['bbox'][i], axes[1])
    plt.show()


The top-left and bottom-right bounding box coordinates are processed by a linear layer. Such a layer cannot represent interactions between the x and y dimensions, i.e. the area attributes. I think I should pass those in as well.
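
One way to pass the area in explicitly is to append it to the bounding box features before the linear layer. The sketch below assumes the (x, y, w, h) layout that `drawBBox` unpacks; the helper `augment_bbox` is hypothetical and not part of `vectorrvnn`.

```python
import numpy as np

def augment_bbox(bbox):
    """Append the area w * h to each bounding box.

    bbox: (N, 4) array of (x, y, w, h), normalized to [0, 1] (assumed layout).
    Returns an (N, 5) array of (x, y, w, h, area).
    """
    x, y, w, h = bbox.T
    area = w * h
    return np.stack([x, y, w, h, area], axis=1)

bbox = np.array([
    [0.1, 0.2, 0.5, 0.4],
    [0.0, 0.0, 1.0, 1.0],
])
feats = augment_bbox(bbox)
print(feats.shape)
```

Since the area is a product of two inputs, no single linear layer over (x, y, w, h) can compute it, so supplying it as an extra input is the simplest fix that keeps the rest of the branch unchanged.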

In [None]:
class UNetEmbedder(PatternGrouping) :
    
    def embedding (self, node, **kwargs) : 
        im = node['im']
        bitmap = node['bitmap']
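        # Bitmap-weighted average of the UNet output. Note that bitmap.sum() also
        # reduces over the batch, so this is a per-sample average only for batch size 1.
        f2 = (self.unet(im) * bitmap).sum(dim=(2, 3)) / (bitmap.sum() + 1e-6)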
        return f2
    
class ResnetEmbedder (PatternGrouping) :
    
    def embedding (self, node, **kwargs) : 
        whole = node['whole']
        f1 = self.conv(whole)
        return f1
        
class BBoxEmbedder (PatternGrouping) : 
    
    def embedding(self, node, **kwargs) :
        f3 = node['bbox']
        return f3
    
data2Vis = list(valData)[:20]

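def visualizeEmbeddings (modelcls) : 
    """Plot the greedy tree obtained with modelcls for each graphic in data2Vis."""

    # Reuse the trained weights, but use the embedding method of modelcls.
    embedder = deepcopy(model)
    embedder.__class__ = modelcls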

    with torch.no_grad() :
        for v in data2Vis : 
            d = deepcopy(v)
            t = embedder.greedyTree(d)
            d.initTree(t)
            d.initGraphic(d.doc)
            pathSets = [d.nodes[n]['pathSet'] for n in d.nodes]
            features = [PatternGrouping.nodeFeatures(d, ps, opts) for ps in pathSets]
            tensorApply(
                features,
                lambda x : x.to(opts.device).unsqueeze(0)
            )
            embeddings = np.concatenate([embedder.embedding(f).detach().cpu().numpy() for f in features])
            fig, axes = plt.subplots(1, 1, dpi=100)
            treeAxisFromGraph(d, axes)
            plt.show()

### Hierarchies for UNetEmbedder

Looking at the trees produced, I don't think there is any difference between what is produced here and what would be produced had I just given the bitmap of the graphic as input. However, in some examples, such as the Hat, the Wing Lady, and the Wine Glasses, color seems to be taken into account.

In [None]:
visualizeEmbeddings(UNetEmbedder)

### Hierarchies for ResnetEmbedder

The ResNet embedder on its own is actually quite good. Its differentiation between stroke and fill isn't quite accurate. It can detect symmetry, for example in the flag of South Korea.

In [None]:
visualizeEmbeddings(ResnetEmbedder)

### Hierarchies and TSNE plots for BBoxEmbedder

Not bad actually. It's just the bounding box. So not going to say much!

In [None]:
visualizeEmbeddings(BBoxEmbedder)

### Complete model

In [None]:
visualizeEmbeddings(PatternGrouping)