## Gaussian Pyramid Levels

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `data` module imported below) are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# !pip install online_triplet_loss

import os
import random

import cv2
import numpy as np
import matplotlib.pyplot as plt
from online_triplet_loss.losses import *
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import torch.nn as nn
from torch.nn import functional as F
from torch import linalg
from scipy.spatial.distance import pdist
from sklearn.cluster import KMeans
from copy import deepcopy

from collections import defaultdict
from tqdm import tqdm



from torch.utils.data import DataLoader

import torch.optim as optim

from data import get_triplets

## Learning

In [3]:
class Backbone(nn.Module):
    """Small five-stage CNN that yields class logits and a feature embedding.

    Each stage is Conv2d(3x3) -> BatchNorm2d -> ReLU. The first three
    convolutions are unpadded, so each trims two pixels from the height and
    width; stages four and five are padded and followed by a 2x2 max-pool.
    A global average pool reduces the final 256-channel map to one vector,
    which is projected to ``out_dim`` features and then classified.

    Args:
        in_chan: Number of channels of the input images.
        out_dim: Dimensionality of the embedding produced by ``fc``.
        num_classes: Number of logits produced by ``classifier``.
    """

    def __init__(self, in_chan, out_dim, num_classes):
        super().__init__()

        # Stages 1-3: unpadded 3x3 convolutions, channel widths 16 -> 32 -> 64.
        self.conv1 = nn.Conv2d(in_channels=in_chan, out_channels=16, kernel_size=3, stride=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.relu2 = nn.ReLU()

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU()

        # Stages 4-5: padded convolutions, each followed by 2x2 max-pooling.
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.relu4 = nn.ReLU()
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.relu5 = nn.ReLU()
        self.maxpool5 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Global average pooling to a 1x1 map per channel.
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))

        # Embedding head and the classifier applied on top of it.
        self.fc = nn.Linear(in_features=256, out_features=out_dim)
        self.relu6 = nn.ReLU()

        self.classifier = nn.Linear(out_dim, num_classes)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_chan, height, width)``.

        Returns:
            A tuple ``(logits, embed)``: ``logits`` has shape
            ``(batch, num_classes)``; ``embed`` has shape ``(batch, out_dim)``
            and is the output of ``fc`` *before* the final ReLU.
        """
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))

        x = self.relu4(self.bn4(self.conv4(x)))
        x = self.maxpool4(x)

        x = self.relu5(self.bn5(self.conv5(x)))
        x = self.maxpool5(x)

        x = self.pooling(x)
        # Flatten everything except the batch axis. A bare squeeze() would
        # also drop the batch axis for a single-sample batch, leaving the
        # embedding and logits without their leading dimension.
        embed = self.fc(x.flatten(1))

        x = self.relu6(embed)
        x = self.classifier(x)

        return x, embed

In [4]:
class PetDataset(Dataset):
    """Map-style dataset that pairs image file paths with integer labels.

    Args:
        flist: Sequence of image file paths.
        transform: Callable applied to each opened PIL image.
        labels: Sequence of labels, one per entry of ``flist``; stored as an
            ``int64`` NumPy array.
    """

    def __init__(self, flist, transform, labels):
        self.flist, self.transform = flist, transform
        self.labels = np.asarray(labels).astype("int64")
        # Every file needs exactly one label.
        assert len(labels) == len(flist)

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.flist)

    def __getitem__(self, index):
        """Return ``(transformed image, label)`` for position ``index``."""
        path = self.flist[index]
        # Open the image from disk and push it through the transform.
        image = Image.open(path)
        return self.transform(image), self.labels[index]

In [5]:
# Preprocessing applied to every image before it reaches the network:
#   1. resize to a fixed 224x224,
#   2. convert the PIL image to a tensor,
#   3. normalize each of the three channels with the given mean / std.
# NOTE(review): the mean/std values look like the commonly used ImageNet
# channel statistics — confirm that this is the intended normalization.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [6]:
# Gather every entry of the augmented-image directory as a path, in sorted order.
src = "pet/train/output"
flist_full = sorted(os.path.join(src, name) for name in os.listdir(src))

In [7]:
"""
Steps in every epoch:

1. Create random labels, create a dataset, dataloader with these labels
2. Do a forward & backward pass, using the feature vectors perform k-means clustering
3. Use cluster assignments as the labels and redefine the dataset and dataloader
4. Go to 2

"""

'\nSteps in every epoch:\n\n1. Create random labels, create a dataset, dataloader with these labels\n2. Do a forward & backward pass, using the feature vectors perform k-means clustering\n3. Use cluster assignments as the labels and redefine the dataset and dataloader\n4. Go to 2\n\n'

In [8]:
# NUM_CLASSES sizes the classifier head of the network; NUM_CLUSTERS is the
# number of k-means clusters. They are tied together because the cluster
# assignments are used as the classification targets.
NUM_CLASSES = 20
NUM_CLUSTERS = NUM_CLASSES

In [10]:
# cluster the features
def cluster_features(features, num_clusters=5, random_state=None):
    """Cluster feature vectors with k-means and return per-sample assignments.

    Args:
        features: 2-D array-like of shape ``(n_samples, n_features)``.
        num_clusters: Number of clusters to fit.
        random_state: Optional seed forwarded to ``KMeans`` so the clustering
            can be made reproducible. ``None`` (the default) leaves the
            initialization unseeded, which is the previous behavior.

    Returns:
        The ``labels_`` attribute of the fitted estimator: one cluster index
        per row of ``features``, in the same order as the rows.
    """
    kmeans = KMeans(n_clusters=num_clusters, random_state=random_state)
    kmeans.fit(features)
    return kmeans.labels_


def get_sizes(labels):
    """Count how many times each label occurs.

    Args:
        labels: Iterable of hashable labels.

    Returns:
        A ``defaultdict(int)`` mapping each label to its number of
        occurrences, with keys in order of first appearance.
    """
    tally = defaultdict(int)
    for label in labels:
        tally[label] = tally.get(label, 0) + 1

    return tally

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
FEAT_DIM = 512

# declare the network
model = Backbone(in_chan=3, out_dim=FEAT_DIM, num_classes=NUM_CLASSES)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)
# print(model)

model = model.to(device)
model.train()

# Number of batches per epoch; the features are re-clustered after each epoch.
CLUSTER_EVERY = 500
BATCH_SIZE = 16
NUM_WORKERS = 4

# Work on a random subset that fills exactly CLUSTER_EVERY batches.
flist = random.sample(flist_full, CLUSTER_EVERY * BATCH_SIZE)
num_samples = len(flist)

# create random labels (one pseudo-label per file, in `flist` order)
labels = np.random.randint(low=0, high=NUM_CLASSES, size=num_samples)


def make_epoch_loader(epoch_labels):
    """Build the dataset and a shuffled loader whose visiting order is known.

    The loader walks an explicit random permutation instead of relying on
    ``shuffle=True``. This way the i-th embedding collected during the epoch
    can be traced back to dataset index ``visit_order[i]``.

    Returns:
        ``(dataset, loader, visit_order)`` where ``visit_order`` is the
        permutation of dataset indices in the order the loader yields them.
    """
    dataset = PetDataset(flist, transform, epoch_labels)
    visit_order = np.random.permutation(len(dataset))
    loader = DataLoader(
        torch.utils.data.Subset(dataset, visit_order.tolist()),
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
    )
    return dataset, loader, visit_order


# create the dataset and dataloader
train_dataset, train_loader, visit_order = make_epoch_loader(labels)

num_epochs = 20
for epoch in range(num_epochs):

    running_loss = 0.0

    # Per-batch embeddings, collected in the order the loader yields samples.
    embed_chunks = []
    for x, y in tqdm(train_loader):

        images_ = x.to(device)
        labels_ = y.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs, batch_embeds = model(images_)

        # Accumulate the embeddings (joined once after the loop rather than
        # re-copying the whole array on every batch)
        embed_chunks.append(batch_embeds.detach().cpu().numpy())

        loss = criterion(outputs, labels_)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    embeds = np.concatenate(embed_chunks, axis=0)
    assert embeds.shape[0] == num_samples

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}')

    # perform the clustering
    # NOTE(review): the embeddings were gathered while the weights were still
    # being updated, and the classifier layer is kept across re-clusterings
    # even though k-means cluster ids are arbitrary — confirm both are intended.
    assignments = cluster_features(embeds, num_clusters=NUM_CLUSTERS)

    # Row i of `embeds` belongs to dataset index visit_order[i]; scatter the
    # assignments back so that labels[j] is the cluster of flist[j].
    labels = np.empty(num_samples, dtype=np.int64)
    labels[visit_order] = assignments

    cluster_log = get_sizes(labels)

    print(f"Size of clusters: {cluster_log}")

    # re-define the dataset and loader with the new pseudo-labels
    train_dataset, train_loader, visit_order = make_epoch_loader(labels)

100%|█████████▉| 499/500 [01:21<00:00,  6.13it/s]


Epoch [1/20], Loss: 3.0030
Size of clusters: defaultdict(<class 'int'>, {11: 146, 17: 161, 19: 15, 13: 379, 16: 455, 2: 1877, 1: 48, 5: 202, 4: 104, 6: 1729, 15: 1124, 3: 48, 7: 101, 14: 209, 8: 349, 0: 682, 10: 169, 9: 130, 18: 37, 12: 35})


100%|█████████▉| 499/500 [01:20<00:00,  6.21it/s]


Epoch [2/20], Loss: 2.4623
Size of clusters: defaultdict(<class 'int'>, {13: 1065, 5: 765, 8: 171, 6: 1091, 14: 242, 7: 59, 1: 108, 2: 455, 10: 51, 3: 29, 17: 86, 12: 784, 15: 25, 16: 324, 18: 70, 9: 116, 11: 153, 4: 601, 19: 687, 0: 1118})


100%|█████████▉| 499/500 [01:20<00:00,  6.18it/s]


Epoch [3/20], Loss: 2.5926
Size of clusters: defaultdict(<class 'int'>, {17: 96, 10: 21, 2: 213, 0: 66, 15: 81, 7: 404, 3: 59, 14: 415, 18: 171, 8: 189, 6: 426, 16: 7, 4: 447, 19: 552, 12: 1329, 1: 1455, 11: 826, 13: 499, 9: 198, 5: 546})


100%|█████████▉| 499/500 [01:20<00:00,  6.20it/s]


Epoch [4/20], Loss: 2.6187
Size of clusters: defaultdict(<class 'int'>, {4: 295, 14: 572, 6: 741, 16: 737, 18: 409, 7: 148, 12: 834, 19: 304, 5: 12, 2: 676, 11: 632, 0: 898, 8: 487, 17: 212, 10: 70, 1: 154, 9: 365, 3: 177, 13: 127, 15: 150})


100%|█████████▉| 499/500 [01:20<00:00,  6.18it/s]


Epoch [5/20], Loss: 2.7891
Size of clusters: defaultdict(<class 'int'>, {15: 124, 1: 73, 8: 68, 7: 130, 9: 162, 13: 31, 6: 167, 18: 406, 4: 360, 12: 93, 10: 813, 2: 1470, 17: 795, 16: 567, 3: 223, 0: 347, 11: 239, 5: 603, 19: 975, 14: 354})


100%|█████████▉| 499/500 [01:20<00:00,  6.20it/s]


Epoch [6/20], Loss: 2.6650
Size of clusters: defaultdict(<class 'int'>, {11: 49, 6: 74, 1: 78, 18: 92, 5: 34, 19: 87, 17: 40, 10: 109, 2: 239, 3: 21, 13: 67, 8: 257, 14: 282, 12: 191, 16: 673, 4: 732, 7: 267, 9: 906, 0: 2193, 15: 1609})


100%|█████████▉| 499/500 [01:20<00:00,  6.22it/s]


Epoch [7/20], Loss: 2.3402
Size of clusters: defaultdict(<class 'int'>, {12: 128, 1: 96, 18: 142, 8: 291, 13: 42, 2: 103, 10: 86, 9: 40, 5: 8, 14: 98, 16: 67, 3: 163, 17: 57, 6: 375, 15: 237, 0: 1688, 7: 1058, 19: 1456, 4: 1076, 11: 789})


100%|█████████▉| 499/500 [01:17<00:00,  6.46it/s]


Epoch [8/20], Loss: 2.4066
Size of clusters: defaultdict(<class 'int'>, {8: 98, 16: 63, 2: 63, 18: 80, 6: 96, 9: 96, 13: 144, 4: 144, 0: 160, 17: 192, 3: 656, 10: 464, 7: 704, 11: 976, 15: 624, 5: 1248, 19: 448, 14: 624, 1: 384, 12: 736})


100%|█████████▉| 499/500 [01:19<00:00,  6.27it/s]


Epoch [9/20], Loss: 2.7264
Size of clusters: defaultdict(<class 'int'>, {12: 112, 3: 80, 14: 80, 7: 64, 15: 96, 1: 128, 8: 160, 10: 208, 5: 208, 19: 176, 2: 336, 17: 192, 13: 304, 18: 704, 4: 576, 11: 704, 6: 1296, 16: 912, 0: 656, 9: 1008})


100%|█████████▉| 499/500 [01:19<00:00,  6.26it/s]


Epoch [10/20], Loss: 2.7586
Size of clusters: defaultdict(<class 'int'>, {5: 96, 14: 64, 1: 80, 10: 48, 16: 64, 4: 64, 18: 64, 8: 64, 12: 64, 13: 112, 0: 944, 19: 304, 6: 416, 11: 560, 3: 656, 17: 320, 2: 1312, 9: 1872, 7: 512, 15: 384})


100%|█████████▉| 499/500 [01:19<00:00,  6.28it/s]


Epoch [11/20], Loss: 2.5238
Size of clusters: defaultdict(<class 'int'>, {0: 96, 9: 256, 16: 688, 4: 944, 18: 416, 12: 944, 2: 272, 15: 320, 6: 384, 17: 112, 11: 128, 3: 192, 8: 208, 13: 128, 19: 288, 1: 624, 5: 880, 14: 592, 7: 224, 10: 304})


100%|█████████▉| 499/500 [01:19<00:00,  6.25it/s]


Epoch [12/20], Loss: 2.9217
Size of clusters: defaultdict(<class 'int'>, {11: 320, 8: 400, 1: 694, 7: 141, 4: 99, 14: 18, 13: 50, 2: 130, 18: 214, 17: 105, 6: 431, 10: 851, 12: 277, 19: 251, 3: 134, 5: 2221, 9: 795, 16: 436, 15: 56, 0: 377})


100%|█████████▉| 499/500 [01:20<00:00,  6.21it/s]


Epoch [13/20], Loss: 2.5435
Size of clusters: defaultdict(<class 'int'>, {9: 143, 15: 7, 10: 53, 7: 32, 1: 119, 6: 289, 19: 179, 13: 169, 11: 202, 5: 56, 2: 130, 16: 274, 0: 575, 17: 751, 4: 412, 18: 188, 12: 1323, 8: 1476, 14: 662, 3: 960})


100%|█████████▉| 499/500 [01:20<00:00,  6.21it/s]


Epoch [14/20], Loss: 2.6321
Size of clusters: defaultdict(<class 'int'>, {8: 128, 4: 96, 6: 112, 13: 96, 12: 256, 0: 464, 5: 240, 11: 240, 18: 336, 2: 448, 7: 816, 17: 240, 1: 304, 10: 624, 19: 368, 9: 320, 3: 577, 15: 1056, 14: 991, 16: 288})


100%|█████████▉| 499/500 [01:19<00:00,  6.24it/s]


Epoch [15/20], Loss: 2.8572
Size of clusters: defaultdict(<class 'int'>, {3: 240, 17: 176, 1: 208, 11: 273, 18: 108, 2: 106, 19: 122, 9: 92, 8: 102, 12: 143, 5: 124, 10: 116, 6: 182, 16: 190, 0: 1530, 7: 1679, 15: 980, 14: 640, 13: 676, 4: 313})


100%|█████████▉| 499/500 [01:19<00:00,  6.26it/s]


Epoch [16/20], Loss: 2.5388
Size of clusters: defaultdict(<class 'int'>, {15: 279, 7: 490, 4: 414, 6: 504, 13: 1338, 0: 1730, 11: 1427, 3: 681, 12: 862, 5: 15, 8: 47, 17: 50, 14: 1, 18: 28, 2: 23, 9: 17, 16: 16, 1: 20, 19: 29, 10: 29})


100%|█████████▉| 499/500 [01:20<00:00,  6.22it/s]


Epoch [17/20], Loss: 2.2266
Size of clusters: defaultdict(<class 'int'>, {4: 122, 12: 82, 17: 65, 2: 188, 6: 44, 18: 9, 9: 245, 1: 88, 8: 25, 7: 1571, 19: 24, 11: 676, 14: 76, 13: 185, 10: 345, 5: 502, 3: 83, 15: 255, 16: 972, 0: 2443})


100%|█████████▉| 499/500 [01:20<00:00,  6.23it/s]


Epoch [18/20], Loss: 2.2622
Size of clusters: defaultdict(<class 'int'>, {17: 20, 5: 106, 1: 104, 13: 70, 10: 20, 16: 193, 11: 43, 8: 168, 15: 5, 2: 89, 19: 193, 14: 78, 12: 306, 7: 34, 4: 9, 9: 823, 6: 276, 3: 368, 0: 903, 18: 4192})


100%|█████████▉| 499/500 [01:19<00:00,  6.24it/s]


Epoch [19/20], Loss: 1.9366
Size of clusters: defaultdict(<class 'int'>, {14: 112, 1: 89, 17: 18, 13: 82, 4: 67, 11: 24, 16: 65, 8: 24, 3: 69, 19: 53, 15: 99, 7: 6, 5: 153, 18: 187, 12: 63, 6: 439, 2: 147, 9: 2952, 10: 695, 0: 2656})


100%|█████████▉| 499/500 [01:18<00:00,  6.36it/s]


Epoch [20/20], Loss: 1.8510
Size of clusters: defaultdict(<class 'int'>, {9: 96, 1: 80, 19: 48, 11: 80, 3: 80, 18: 96, 6: 128, 12: 160, 16: 400, 0: 800, 10: 496, 4: 384, 15: 304, 5: 576, 14: 449, 8: 799, 2: 929, 13: 560, 7: 960, 17: 575})


In [12]:
# Persist only the model's state dict (not the full module object) to a file
# in the current working directory.
torch.save(model.state_dict(), "pet_data_dc_e20.pth")



