In [None]:
from google.colab import drive
# Make Google Drive ("My Drive") available under /content/gdrive for data and checkpoints.
drive.mount(mountpoint='/content/gdrive')

Mounted at /content/gdrive


In [None]:
from torchvision import transforms,datasets
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from skimage.io import imread
from skimage.transform import rotate
from PIL import Image
from torch.utils.data import Dataset, DataLoader, SequentialSampler
import torch.optim as optim
import torch.nn as nn
from torchvision import models


# Project folder on Google Drive; training checkpoints and metrics are written here.
destination_folder = "/content/gdrive/My Drive/Vision_BCN/"
# Training images, one sub-folder per class.
dir_blue_shirts = os.path.join(destination_folder, "blue/")
dir_no_blue_shirts = os.path.join(destination_folder, "no_blue/")


class CustomGuilleDataSet(Dataset):
    """Two-folder image dataset: blue shirts (label 0) and non-blue shirts (label 1).

    Samples alternate between the two folders (blue, no_blue, blue, no_blue, ...) so
    that any contiguous run of indices -- e.g. a sequentially sampled batch -- holds
    both classes, which the online triplet loss needs to form triplets. Once the
    smaller folder runs out, the remaining files of the larger one follow. Every file
    appears exactly once.

    Args:
        dir_1: folder with the blue-shirt images (label 0).
        dir_2: folder with the non-blue-shirt images (label 1).
        transform: callable applied to each image after conversion to RGB.
    """

    def __init__(self, dir_1, dir_2, transform):
        self.dir_1 = dir_1
        self.dir_2 = dir_2
        self.transform = transform
        # Sorted because os.listdir order is arbitrary; a stable order keeps the
        # index-based train/validation split reproducible.
        self.blue = sorted(os.listdir(dir_1))
        self.no_blue = sorted(os.listdir(dir_2))

        blue_samples = [(os.path.join(dir_1, name), 0) for name in self.blue]
        no_blue_samples = [(os.path.join(dir_2, name), 1) for name in self.no_blue]
        # (path, label) pairs in interleaved order; labels are stored, not derived
        # from index parity, so they always match the file they belong to.
        self.samples = []
        for i in range(max(len(blue_samples), len(no_blue_samples))):
            if i < len(blue_samples):
                self.samples.append(blue_samples[i])
            if i < len(no_blue_samples):
                self.samples.append(no_blue_samples[i])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(transform(image), label)``; label 0 == blue, 1 == no_blue."""
        img_loc, label = self.samples[idx]
        # `with` releases the file handle that PIL's lazy Image.open keeps open.
        with Image.open(img_loc) as img:
            image = img.convert("RGB")
        return self.transform(image), label



from torch.utils.data import Subset

# Training-time augmentation: random flips/rotations regularise the small dataset.
train_trsfm = transforms.Compose(
    [transforms.Resize((500, 500)),
     transforms.RandomVerticalFlip(),
     transforms.RandomRotation(45),
     transforms.ToTensor(),
     ])

# Deterministic preprocessing for validation: same resize, no random augmentation,
# so validation losses are comparable between evaluations.
eval_trsfm = transforms.Compose(
    [transforms.Resize((500, 500)),
     transforms.ToTensor(),
     ])

# `trsfm` keeps naming the training transform for later cells.
trsfm = train_trsfm

# Two views of the same files: augmented for training, plain for validation.
data = CustomGuilleDataSet(dir_1=dir_blue_shirts, dir_2=dir_no_blue_shirts, transform=train_trsfm)
data_eval = CustomGuilleDataSet(dir_1=dir_blue_shirts, dir_2=dir_no_blue_shirts, transform=eval_trsfm)
# Index i must refer to the same file in both datasets.
assert data.blue == data_eval.blue and data.no_blue == data_eval.no_blue

# Parameters
batch_size = 20
validation_split = 0.15
dataset_size = len(data)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))

# The first `split` samples are held out for validation, the rest are for training.
train_indices, val_indices = indices[split:], indices[:split]

# Restrict each loader to its own indices. SequentialSampler(train_indices) would iterate
# range(len(train_indices)) -- positions 0..N-1 of the *full* dataset -- so the validation
# samples would also be trained on. Order stays sequential (no shuffle) so each batch keeps
# the dataset's blue / no_blue alternation, which the online triplet loss relies on.
train_loader = DataLoader(Subset(data, train_indices), batch_size=batch_size, shuffle=False)
validation_loader = DataLoader(Subset(data_eval, val_indices), batch_size=batch_size, shuffle=False)

# Sanity check: the splits are disjoint and together cover the whole dataset.
assert not set(train_indices) & set(val_indices)
assert len(train_loader.dataset) + len(validation_loader.dataset) == dataset_size



In [None]:
# Pretrained ImageNet AlexNet; only its convolutional `features` stack is reused below.
# Once those layers are frozen, the trainable parameters are fc1's 9216*16 + 16 = 147,472
# (the count printed by this cell). Dropout(0.5) regularises them but does not reduce that count.
original_model = models.alexnet(pretrained=True)

class CustomGuilleNet(nn.Module):
    """Embedding network: a convolutional feature extractor plus one linear head.

    The feature maps go through an extra 3x3 / stride-2 max-pool and must then hold
    256 x 6 x 6 = 9216 values per sample. (For the 500x500 inputs produced by the
    notebook's transforms, AlexNet's features give 256x14x14 -> 256x6x6.) ``fc1``
    projects them to a 16-dimensional embedding.

    Args:
        features: optional feature extractor producing 256 channels, used as-is.
            Defaults to a *copy* of the notebook-level ``original_model.features``
            (pretrained AlexNet). The copy means separate instances such as
            ``model`` and ``best_model`` do not share -- and overwrite -- layers.
    """

    def __init__(self, features=None):
        super().__init__()
        if features is None:
            import copy  # only needed for the default, per-instance copy
            features = copy.deepcopy(original_model.features)
        self.features = features
        # The single fully connected layer: 9216 pooled features -> 16-d embedding.
        self.fc1 = nn.Linear(9216, 16)
        # fc1 holds nearly all trainable weights and the dataset is small, so regularise it.
        self.dropout = nn.Dropout(0.5)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        x = self.features(x)
        x = self.pool(x)
        # Flatten per sample. A spatial-size mismatch now fails loudly in fc1 instead
        # of view(-1, 9216) silently folding extra cells into the batch dimension.
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        return self.fc1(x)

    def get_embedding(self, x):
        """Return the embedding of ``x`` (same as ``forward``)."""
        return self.forward(x)

# Use the GPU when available; Module.to() moves the parameters and returns the module itself.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CustomGuilleNet().to(device)

def count_parameters(model):
    """Return the number of scalar parameters of ``model`` that have ``requires_grad`` set."""
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total

# Freeze the pretrained convolutional layers (`model.features`); only the embedding head
# fc1 stays trainable. Selecting by module rather than by position ("first 10 parameters")
# does not depend on how many parameters the backbone has or on their ordering.
for param in model.features.parameters():
    param.requires_grad = False

print(count_parameters(model))

Downloading: "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-4df8aa71.pth


HBox(children=(FloatProgress(value=0.0, max=244418560.0), HTML(value='')))


147472


In [None]:
from itertools import combinations
import torch.nn.functional as F

class AllTripletSelector:
    """Triplet selector that returns every valid triplet in a batch.

    A triplet ``[anchor, positive, negative]`` holds batch indices where the anchor and
    the positive share a label (each unordered pair once, anchor < positive) and the
    negative has a different label.
    """

    def get_triplets(self, embeddings, labels):
        """Return all triplets for ``labels`` as a ``(T, 3)`` long tensor.

        Args:
            embeddings: unused; accepted so selectors share one call signature.
            labels: 1-D tensor of class labels for the batch.

        Returns:
            ``torch.long`` tensor of shape ``(T, 3)``. It is ``(0, 3)`` when no label has
            two members together with a differently labelled sample (e.g. a
            single-class batch).
        """
        labels = labels.detach().cpu().numpy()
        triplets = []
        # np.unique gives a sorted, deterministic label order.
        for label in np.unique(labels):
            label_mask = labels == label
            label_indices = np.where(label_mask)[0]
            negative_indices = np.where(~label_mask)[0]
            # Every anchor-positive pair combined with every negative.
            triplets.extend(
                [anchor, positive, negative]
                for anchor, positive in combinations(label_indices, 2)
                for negative in negative_indices
            )

        if not triplets:
            # Keep the 2-D shape so callers can still index triplets[:, k].
            return torch.empty((0, 3), dtype=torch.long)
        return torch.as_tensor(np.array(triplets), dtype=torch.long)
    
    

class OnlineTripletLoss(nn.Module):
    """Online triplet loss.

    Takes a batch of embeddings and their labels, asks ``triplet_selector`` for
    ``(anchor, positive, negative)`` index triplets, and averages
    ``relu(|a - p|^2 - |a - n|^2 + margin)`` over them (squared Euclidean distances).

    Args:
        margin: required gap between the negative and the positive squared distances.
        triplet_selector: object whose ``get_triplets(embeddings, target)`` returns a
            long tensor of index triplets.
    """

    def __init__(self, margin, triplet_selector):
        super().__init__()
        self.margin = margin
        self.triplet_selector = triplet_selector

    def forward(self, embeddings, target):
        triplets = self.triplet_selector.get_triplets(embeddings, target)
        if triplets.numel() == 0:
            # No valid triplet (e.g. a single-class batch). The mean of an empty tensor is
            # NaN and would poison the running loss. Return a zero that stays attached to
            # the graph, so loss.backward() still works and yields zero gradients.
            return embeddings.sum() * 0.0

        # Follow the embeddings to whatever device they are on (not just "cuda").
        triplets = triplets.to(embeddings.device)
        anchors = embeddings[triplets[:, 0]]
        ap_distances = (anchors - embeddings[triplets[:, 1]]).pow(2).sum(1)  # squared L2
        an_distances = (anchors - embeddings[triplets[:, 2]]).pow(2).sum(1)  # squared L2
        losses = F.relu(ap_distances - an_distances + self.margin)
        return losses.mean()

In [None]:
def save_checkpoint(save_path, model, valid_loss):
    """Serialise ``model``'s weights together with their validation loss.

    Writes ``{'model_state_dict': ..., 'valid_loss': ...}`` to ``save_path`` with
    ``torch.save`` and reports where. Does nothing when ``save_path`` is None.
    """
    if save_path is None:
        return
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'valid_loss': valid_loss,
    }
    torch.save(checkpoint, save_path)
    print(f'Model saved to ==> {save_path}')

def load_checkpoint(load_path, model, map_location=None):
    """Restore weights written by ``save_checkpoint`` into ``model``.

    Args:
        load_path: checkpoint file. If None, nothing is loaded and None is returned.
        model: module whose ``load_state_dict`` receives the stored weights.
        map_location: where stored tensors are mapped. Defaults to the notebook-level
            ``device``.

    Returns:
        The validation loss stored alongside the weights.
    """
    if load_path is None:
        return None
    if map_location is None:
        map_location = device
    # torch.load unpickles the file: only load checkpoints you produced yourself.
    state_dict = torch.load(load_path, map_location=map_location)
    print(f'Model loaded from <== {load_path}')

    model.load_state_dict(state_dict['model_state_dict'])
    return state_dict['valid_loss']


def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
    """Save the loss history recorded by ``fit`` with ``torch.save``.

    Args:
        save_path: output file. If None, nothing is written.
        train_loss_list: average training loss at each evaluation.
        valid_loss_list: average validation loss at each evaluation.
        global_steps_list: optimisation step at which each evaluation happened.
    """
    if save_path is None:
        return

    state_dict = {'train_loss_list': train_loss_list,
                  'valid_loss_list': valid_loss_list,
                  'global_steps_list': global_steps_list}

    torch.save(state_dict, save_path)
    print(f'Metrics saved to ==> {save_path}')


def load_metrics(load_path, map_location=None):
    """Load the loss history written by ``save_metrics``.

    Args:
        load_path: metrics file. If None, None is returned.
        map_location: passed to ``torch.load``. Defaults to the notebook-level ``device``.

    Returns:
        ``(train_loss_list, valid_loss_list, global_steps_list)``.
    """
    if load_path is None:
        return None
    if map_location is None:
        map_location = device
    # torch.load unpickles the file: only load files you produced yourself.
    state_dict = torch.load(load_path, map_location=map_location)
    print(f'Metrics loaded from <== {load_path}')

    return state_dict['train_loss_list'], state_dict['valid_loss_list'], state_dict['global_steps_list']

In [None]:
def _evaluate(model, loss_fn, loader, device):
    """Mean ``loss_fn`` over ``loader`` with gradients off; leaves ``model`` in train mode.

    Returns NaN for an empty loader. The caller then never checkpoints, because
    ``x > nan`` is always False.
    """
    model.eval()
    total_loss, n_batches = 0.0, 0
    with torch.no_grad():
        for img, labels in loader:
            outputs = model(img.to(device))
            total_loss += loss_fn(outputs, labels.to(device)).item()
            n_batches += 1
    model.train()
    return total_loss / n_batches if n_batches else float('nan')


def fit(model,
        loss_fn,
        optimizer,
        batch_size,
        device,
        train_loader = train_loader,
        valid_loader = validation_loader,
        num_epochs = 500,
        eval_every = None,
        file_path = destination_folder,
        best_valid_loss = float("Inf")):
    """Train ``model`` and validate it every ``eval_every`` optimisation steps.

    At each evaluation the average train and validation losses are recorded. Whenever
    the validation loss improves, the weights are checkpointed to ``<file_path>/model.pt``
    and the loss history to ``<file_path>/metrics.pt``. The history is saved once more
    at the end.

    Args:
        model: network mapping an image batch to outputs accepted by ``loss_fn``.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        batch_size: unused; kept for call compatibility (the loaders do the batching).
        device: device the images and labels are moved to.
        train_loader: iterable of ``(images, labels)`` training batches.
        valid_loader: iterable of ``(images, labels)`` validation batches.
        num_epochs: number of passes over ``train_loader``.
        eval_every: validate every this many steps. Defaults to one epoch,
            ``len(train_loader)``.
        file_path: folder receiving ``model.pt`` and ``metrics.pt``.
        best_valid_loss: only validation losses below this are checkpointed.
    """
    # Resolved here rather than as a default value: defaults are evaluated once at `def`
    # time, so `len(train_loader)` would describe the global loader even when another
    # loader is passed in.
    if eval_every is None:
        eval_every = len(train_loader)
    eval_every = max(1, eval_every)  # avoid modulo-by-zero for an empty loader
    model_path = os.path.join(file_path, 'model.pt')
    metrics_path = os.path.join(file_path, 'metrics.pt')

    running_loss = 0.0
    global_step = 0
    train_loss_list = []
    valid_loss_list = []
    global_steps_list = []

    print('Lets start train!')
    model.train()
    for epoch in range(num_epochs):
        for img, labels in train_loader:
            img = img.to(device)
            labels = labels.to(device)
            outputs = model(img)
            loss = loss_fn(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            global_step += 1

            if global_step % eval_every != 0:
                continue

            # Evaluation
            average_train_loss = running_loss / eval_every
            average_valid_loss = _evaluate(model, loss_fn, valid_loader, device)
            train_loss_list.append(average_train_loss)
            valid_loss_list.append(average_valid_loss)
            global_steps_list.append(global_step)
            running_loss = 0.0

            print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f}, Valid Loss: {:.4f}'
                  .format(epoch+1, num_epochs, global_step, num_epochs*len(train_loader),
                          average_train_loss, average_valid_loss))

            # Checkpoint whenever the validation loss improves.
            if best_valid_loss > average_valid_loss:
                best_valid_loss = average_valid_loss
                save_checkpoint(model_path, model, best_valid_loss)
                save_metrics(metrics_path, train_loss_list, valid_loss_list, global_steps_list)

    save_metrics(metrics_path, train_loss_list, valid_loss_list, global_steps_list)
    print('Finished Training!')

In [None]:
# Hyper-parameters for triplet-loss training.
margin = 1.
lr = 1e-3

# Every (anchor, positive, negative) triplet in a batch contributes to the loss.
triplet_selector = AllTripletSelector()
loss_fn = OnlineTripletLoss(margin, triplet_selector)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

fit(model, loss_fn, optimizer, batch_size, device)

In [None]:
train_loss_list, valid_loss_list, global_steps_list = load_metrics(destination_folder + '/metrics.pt')

# Train vs. validation loss at each evaluation step.
fig, ax = plt.subplots()
ax.plot(global_steps_list, train_loss_list, label='Train')
ax.plot(global_steps_list, valid_loss_list, label='Valid')
ax.set_xlabel('Global Steps')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

In [None]:
# Fresh network with the training architecture, on the same device, filled with the best
# checkpoint's weights; the cell displays the stored validation loss.
best_model = CustomGuilleNet()
best_model.to(device)
load_checkpoint(destination_folder + '/model.pt', best_model)

In [None]:
def extract_embeddings(dataloader, model):
    """Run ``model.get_embedding`` over every batch of ``dataloader``.

    Only the samples the loader actually yields are returned. A loader over a subset
    (via a sampler or ``Subset``) therefore no longer leaves zero-filled, label-0 rows
    for the unvisited samples of ``dataloader.dataset``.

    Args:
        dataloader: iterable of ``(images, target)`` batches.
        model: module with a ``get_embedding`` method. Images are moved to the device
            of its first parameter; the model is left in eval mode.

    Returns:
        ``(embeddings, labels)`` as float64 NumPy arrays of shapes ``(N, D)`` and
        ``(N,)``. ``N`` is the number of samples yielded and ``D`` the embedding size.
    """
    model_device = next(model.parameters()).device
    embedding_chunks, label_chunks = [], []
    with torch.no_grad():
        model.eval()
        for images, target in dataloader:
            images = images.to(model_device)
            embedding_chunks.append(model.get_embedding(images).cpu().numpy())
            label_chunks.append(target.numpy())

    if not embedding_chunks:
        return np.zeros((0, 0)), np.zeros(0)
    embeddings = np.concatenate(embedding_chunks).astype(np.float64)
    labels = np.concatenate(label_chunks).astype(np.float64)
    return embeddings, labels

# Deterministic preprocessing (no random flips/rotations) for embedding extraction.
trsfm = transforms.Compose(
    [transforms.Resize((500, 500)),
     transforms.ToTensor(),
     ])
# Rebinding `trsfm` alone does not affect the dataset that was already built with the
# augmenting transform. Install it on `data`, the dataset the training loader reads from.
# NOTE: re-run the data cell before training again to restore augmentation.
data.transform = trsfm


In [None]:
# Evaluate the best checkpoint (`best_model`), not the last-epoch weights still held in `model`.
embeddings, labels = extract_embeddings(train_loader, best_model)

In [None]:
from sklearn.cluster import KMeans

# Cluster the embeddings into two groups; `kmeans` holds one cluster id per sample.
clusterer = KMeans(n_clusters=2, random_state=0)
kmeans = clusterer.fit_predict(embeddings)
kmeans.shape

(694,)

In [None]:
# KMeans numbers its clusters arbitrarily, so cluster 0 need not correspond to label 0 (blue).
# With two clusters, score the better of the two cluster -> label assignments.
agreement = float(np.mean(kmeans == labels))
accuracy = max(agreement, 1.0 - agreement)
print('Result:', accuracy)  # an earlier run reported 0.5648, close to chance level (0.5)

0.5648414985590778